risingwave_connector/sink/encoder/
template.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::collections::HashMap;
17
18use regex::{Captures, Regex};
19use risingwave_common::catalog::{Field, Schema};
20use risingwave_common::row::Row;
21use risingwave_common::types::{DataType, ScalarRefImpl, ToText};
22use thiserror_ext::AsReport;
23
24use super::{Result, RowEncoder};
25use crate::sink::SinkError;
26use crate::sink::encoder::SerTo;
27
28pub enum TemplateEncoder {
29    String(TemplateStringEncoder),
30    RedisGeoKey(TemplateRedisGeoKeyEncoder),
31    RedisGeoValue(TemplateRedisGeoValueEncoder),
32    RedisPubSubKey(TemplateRedisPubSubKeyEncoder),
33}
34impl TemplateEncoder {
35    pub fn new_string(schema: Schema, col_indices: Option<Vec<usize>>, template: String) -> Self {
36        TemplateEncoder::String(TemplateStringEncoder::new(schema, col_indices, template))
37    }
38
39    pub fn new_geo_value(
40        schema: Schema,
41        col_indices: Option<Vec<usize>>,
42        lat_name: &str,
43        lon_name: &str,
44    ) -> Result<Self> {
45        Ok(TemplateEncoder::RedisGeoValue(
46            TemplateRedisGeoValueEncoder::new(schema, col_indices, lat_name, lon_name)?,
47        ))
48    }
49
50    pub fn new_geo_key(
51        schema: Schema,
52        col_indices: Option<Vec<usize>>,
53        member_name: &str,
54        template: String,
55    ) -> Result<Self> {
56        Ok(TemplateEncoder::RedisGeoKey(
57            TemplateRedisGeoKeyEncoder::new(schema, col_indices, member_name, template)?,
58        ))
59    }
60
61    pub fn new_pubsub_key(
62        schema: Schema,
63        col_indices: Option<Vec<usize>>,
64        channel: Option<String>,
65        channel_column: Option<String>,
66    ) -> Result<Self> {
67        Ok(TemplateEncoder::RedisPubSubKey(
68            TemplateRedisPubSubKeyEncoder::new(schema, col_indices, channel, channel_column)?,
69        ))
70    }
71}
72impl RowEncoder for TemplateEncoder {
73    type Output = TemplateEncoderOutput;
74
75    fn schema(&self) -> &Schema {
76        match self {
77            TemplateEncoder::String(encoder) => &encoder.schema,
78            TemplateEncoder::RedisGeoValue(encoder) => &encoder.schema,
79            TemplateEncoder::RedisGeoKey(encoder) => &encoder.key_encoder.schema,
80            TemplateEncoder::RedisPubSubKey(encoder) => &encoder.schema,
81        }
82    }
83
84    fn col_indices(&self) -> Option<&[usize]> {
85        match self {
86            TemplateEncoder::String(encoder) => encoder.col_indices.as_deref(),
87            TemplateEncoder::RedisGeoValue(encoder) => encoder.col_indices.as_deref(),
88            TemplateEncoder::RedisGeoKey(encoder) => encoder.key_encoder.col_indices.as_deref(),
89            TemplateEncoder::RedisPubSubKey(encoder) => encoder.col_indices.as_deref(),
90        }
91    }
92
93    fn encode_cols(
94        &self,
95        row: impl Row,
96        col_indices: impl Iterator<Item = usize>,
97    ) -> Result<Self::Output> {
98        match self {
99            TemplateEncoder::String(encoder) => Ok(TemplateEncoderOutput::String(
100                encoder.encode_cols(row, col_indices)?,
101            )),
102            TemplateEncoder::RedisGeoValue(encoder) => encoder.encode_cols(row, col_indices),
103            TemplateEncoder::RedisGeoKey(encoder) => encoder.encode_cols(row, col_indices),
104            TemplateEncoder::RedisPubSubKey(encoder) => encoder.encode_cols(row, col_indices),
105        }
106    }
107}
108/// Encode a row according to a specified string template `user_id:{user_id}`.
109/// Data is encoded to string with [`ToText`].
110pub struct TemplateStringEncoder {
111    field_name_to_index: HashMap<String, (usize, Field)>,
112    col_indices: Option<Vec<usize>>,
113    template: String,
114    schema: Schema,
115}
116
117/// todo! improve the performance.
118impl TemplateStringEncoder {
119    pub fn new(schema: Schema, col_indices: Option<Vec<usize>>, template: String) -> Self {
120        let field_name_to_index = schema
121            .fields()
122            .iter()
123            .enumerate()
124            .map(|(index, field)| (field.name.clone(), (index, field.clone())))
125            .collect();
126        Self {
127            field_name_to_index,
128            col_indices,
129            template,
130            schema,
131        }
132    }
133
134    pub fn check_string_format(format: &str, map: &HashMap<String, DataType>) -> Result<()> {
135        // We will check if the string inside {} corresponds to a column name in rw.
136        let re = Regex::new(r"(\\\})|(\\\{)|\{([^}]*)\}").unwrap();
137        if !re.is_match(format) {
138            return Err(SinkError::Redis(
139                "Can't find {} in key_format or value_format".to_owned(),
140            ));
141        }
142        for capture in re.captures_iter(format) {
143            if let Some(inner_content) = capture.get(3)
144                && !map.contains_key(inner_content.as_str())
145            {
146                return Err(SinkError::Redis(format!(
147                    "Can't find field({:?}) in key_format or value_format",
148                    inner_content.as_str()
149                )));
150            }
151        }
152        Ok(())
153    }
154
155    pub fn encode_cols(
156        &self,
157        row: impl Row,
158        col_indices: impl Iterator<Item = usize>,
159    ) -> Result<String> {
160        let s = self.template.clone();
161        let re = Regex::new(r"(\\\})|(\\\{)|\{([^}]*)\}").unwrap();
162        let col_indices: Vec<_> = col_indices.collect();
163        let replaced = re.replace_all(s.as_ref(), |caps: &Captures<'_>| {
164            if caps.get(1).is_some() {
165                Cow::Borrowed("}")
166            } else if caps.get(2).is_some() {
167                Cow::Borrowed("{")
168            } else if let Some(content) = caps.get(3) {
169                let (idx, field) = self.field_name_to_index.get(content.as_str()).unwrap();
170                if col_indices.contains(idx) {
171                    let data = row.datum_at(*idx).to_text_with_type(&field.data_type);
172                    Cow::Owned(data)
173                } else {
174                    Cow::Borrowed("")
175                }
176            } else {
177                Cow::Borrowed("")
178            }
179        });
180        Ok(replaced.to_string())
181    }
182}
183
184pub struct TemplateRedisGeoValueEncoder {
185    schema: Schema,
186    col_indices: Option<Vec<usize>>,
187    lat_col: usize,
188    lon_col: usize,
189}
190
191impl TemplateRedisGeoValueEncoder {
192    pub fn new(
193        schema: Schema,
194        col_indices: Option<Vec<usize>>,
195        lat_name: &str,
196        lon_name: &str,
197    ) -> Result<Self> {
198        let lat_col = schema
199            .names_str()
200            .iter()
201            .position(|name| name == &lat_name)
202            .ok_or_else(|| {
203                SinkError::Redis(format!("Can't find lat column({}) in schema", lat_name))
204            })?;
205        let lon_col = schema
206            .names_str()
207            .iter()
208            .position(|name| name == &lon_name)
209            .ok_or_else(|| {
210                SinkError::Redis(format!("Can't find lon column({}) in schema", lon_name))
211            })?;
212        Ok(Self {
213            schema,
214            col_indices,
215            lat_col,
216            lon_col,
217        })
218    }
219
220    pub fn encode_cols(
221        &self,
222        row: impl Row,
223        _col_indices: impl Iterator<Item = usize>,
224    ) -> Result<TemplateEncoderOutput> {
225        let lat = into_string_from_scalar(
226            row.datum_at(self.lat_col)
227                .ok_or_else(|| SinkError::Redis("lat is null".to_owned()))?,
228        )?;
229        let lon = into_string_from_scalar(
230            row.datum_at(self.lon_col)
231                .ok_or_else(|| SinkError::Redis("lon is null".to_owned()))?,
232        )?;
233        Ok(TemplateEncoderOutput::RedisGeoValue((lat, lon)))
234    }
235}
236
237fn into_string_from_scalar(scalar: ScalarRefImpl<'_>) -> Result<String> {
238    match scalar {
239        ScalarRefImpl::Float32(ordered_float) => Ok(Into::<f32>::into(ordered_float).to_string()),
240        ScalarRefImpl::Float64(ordered_float) => Ok(Into::<f64>::into(ordered_float).to_string()),
241        ScalarRefImpl::Utf8(s) => Ok(s.to_owned()),
242        _ => Err(SinkError::Encode(
243            "Only f32 and f64 can convert to redis geo".to_owned(),
244        )),
245    }
246}
247
248pub struct TemplateRedisGeoKeyEncoder {
249    key_encoder: TemplateStringEncoder,
250    member_col: usize,
251}
252
253impl TemplateRedisGeoKeyEncoder {
254    pub fn new(
255        schema: Schema,
256        col_indices: Option<Vec<usize>>,
257        member_name: &str,
258        template: String,
259    ) -> Result<Self> {
260        let member_col = schema
261            .names_str()
262            .iter()
263            .position(|name| name == &member_name)
264            .ok_or_else(|| {
265                SinkError::Redis(format!(
266                    "Can't find member column({}) in schema",
267                    member_name
268                ))
269            })?;
270        let key_encoder = TemplateStringEncoder::new(schema, col_indices, template);
271        Ok(Self {
272            key_encoder,
273            member_col,
274        })
275    }
276
277    pub fn encode_cols(
278        &self,
279        row: impl Row,
280        col_indices: impl Iterator<Item = usize>,
281    ) -> Result<TemplateEncoderOutput> {
282        let member = row
283            .datum_at(self.member_col)
284            .ok_or_else(|| SinkError::Redis("member is null".to_owned()))?
285            .to_text()
286            .clone();
287        let key = self.key_encoder.encode_cols(row, col_indices)?;
288        Ok(TemplateEncoderOutput::RedisGeoKey((key, member)))
289    }
290}
291
/// Where a pub/sub channel name comes from.
pub enum TemplateRedisPubSubKeyEncoderInner {
    /// A fixed channel name, used for every row.
    PubSubName(String),
    /// Position of the column whose text is the channel name for each row.
    PubSubColumnIndex(usize),
}
296pub struct TemplateRedisPubSubKeyEncoder {
297    inner: TemplateRedisPubSubKeyEncoderInner,
298    schema: Schema,
299    col_indices: Option<Vec<usize>>,
300}
301
302impl TemplateRedisPubSubKeyEncoder {
303    pub fn new(
304        schema: Schema,
305        col_indices: Option<Vec<usize>>,
306        channel: Option<String>,
307        channel_column: Option<String>,
308    ) -> Result<Self> {
309        if let Some(channel) = channel {
310            return Ok(Self {
311                inner: TemplateRedisPubSubKeyEncoderInner::PubSubName(channel),
312                schema,
313                col_indices,
314            });
315        }
316        if let Some(channel_column) = channel_column {
317            let channel_column_index = schema
318                .names_str()
319                .iter()
320                .position(|name| name == &channel_column)
321                .ok_or_else(|| {
322                    SinkError::Redis(format!(
323                        "Can't find pubsub column({}) in schema",
324                        channel_column
325                    ))
326                })?;
327            return Ok(Self {
328                inner: TemplateRedisPubSubKeyEncoderInner::PubSubColumnIndex(channel_column_index),
329                schema,
330                col_indices,
331            });
332        }
333        Err(SinkError::Redis(
334            "`channel` or `channel_column` must be set".to_owned(),
335        ))
336    }
337
338    pub fn encode_cols(
339        &self,
340        row: impl Row,
341        _col_indices: impl Iterator<Item = usize>,
342    ) -> Result<TemplateEncoderOutput> {
343        match &self.inner {
344            TemplateRedisPubSubKeyEncoderInner::PubSubName(channel) => {
345                Ok(TemplateEncoderOutput::RedisPubSubKey(channel.clone()))
346            }
347            TemplateRedisPubSubKeyEncoderInner::PubSubColumnIndex(pubsub_col) => {
348                let pubsub_key = row
349                    .datum_at(*pubsub_col)
350                    .ok_or_else(|| SinkError::Redis("pubsub_key is null".to_owned()))?
351                    .to_text()
352                    .clone();
353                Ok(TemplateEncoderOutput::RedisPubSubKey(pubsub_key))
354            }
355        }
356    }
357}
358
/// Output of [`TemplateEncoder`]; the variant matches the encoder variant that produced it.
pub enum TemplateEncoderOutput {
    /// String formatted according to the template.
    String(String),
    /// Redis geospatial value as `(latitude, longitude)`, each rendered as text.
    RedisGeoValue((String, String)),
    /// Redis geospatial key as `(key, member)`: the rendered key template and the
    /// member column's text.
    RedisGeoKey((String, String)),
    /// Redis pub/sub channel name.
    RedisPubSubKey(String),
}
369
370impl TemplateEncoderOutput {
371    pub fn into_string(self) -> Result<String> {
372        match self {
373            TemplateEncoderOutput::String(s) => Ok(s),
374            TemplateEncoderOutput::RedisGeoKey(_) => Err(SinkError::Encode(
375                "RedisGeoKey can't convert to string".to_owned(),
376            )),
377            TemplateEncoderOutput::RedisGeoValue(_) => Err(SinkError::Encode(
378                "RedisGeoVelue can't convert to string".to_owned(),
379            )),
380            TemplateEncoderOutput::RedisPubSubKey(s) => Ok(s),
381        }
382    }
383}
384
385impl SerTo<String> for TemplateEncoderOutput {
386    fn ser_to(self) -> Result<String> {
387        match self {
388            TemplateEncoderOutput::String(s) => Ok(s),
389            TemplateEncoderOutput::RedisGeoKey(_) => Err(SinkError::Encode(
390                "RedisGeoKey can't convert to string".to_owned(),
391            )),
392            TemplateEncoderOutput::RedisGeoValue(_) => Err(SinkError::Encode(
393                "RedisGeoVelue can't convert to string".to_owned(),
394            )),
395            TemplateEncoderOutput::RedisPubSubKey(s) => Ok(s),
396        }
397    }
398}
399
/// The enum of inputs to `RedisSinkPayloadWriter`.
#[derive(Debug)]
pub enum RedisSinkPayloadWriterInput {
    /// A string payload; JSON and template-string encoder outputs are both carried here.
    String(String),
    /// Redis geospatial value as `(latitude, longitude)` when converted from
    /// [`TemplateEncoderOutput::RedisGeoValue`].
    RedisGeoValue((String, String)),
    /// Redis geospatial key as `(key, member)` when converted from
    /// [`TemplateEncoderOutput::RedisGeoKey`].
    RedisGeoKey((String, String)),
    /// Redis pub/sub channel name.
    RedisPubSubKey(String),
}
411
412impl SerTo<RedisSinkPayloadWriterInput> for TemplateEncoderOutput {
413    fn ser_to(self) -> Result<RedisSinkPayloadWriterInput> {
414        match self {
415            TemplateEncoderOutput::String(s) => Ok(RedisSinkPayloadWriterInput::String(s)),
416            TemplateEncoderOutput::RedisGeoKey((lat, lon)) => {
417                Ok(RedisSinkPayloadWriterInput::RedisGeoKey((lat, lon)))
418            }
419            TemplateEncoderOutput::RedisGeoValue((key, member)) => {
420                Ok(RedisSinkPayloadWriterInput::RedisGeoValue((key, member)))
421            }
422            TemplateEncoderOutput::RedisPubSubKey(s) => {
423                Ok(RedisSinkPayloadWriterInput::RedisPubSubKey(s))
424            }
425        }
426    }
427}
428
429impl<T: SerTo<Vec<u8>>> SerTo<RedisSinkPayloadWriterInput> for T {
430    default fn ser_to(self) -> Result<RedisSinkPayloadWriterInput> {
431        let bytes = self.ser_to()?;
432        Ok(RedisSinkPayloadWriterInput::String(
433            String::from_utf8(bytes).map_err(|e| SinkError::Redis(e.to_report_string()))?,
434        ))
435    }
436}
437
#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::{DataType, ScalarImpl};

    use super::*;

    /// Three-column schema shared by all tests: `id` (Int32), `name` and `email` (Varchar).
    fn test_schema() -> Schema {
        Schema::new(vec![
            Field {
                data_type: DataType::Int32,
                name: "id".to_owned(),
            },
            Field {
                data_type: DataType::Varchar,
                name: "name".to_owned(),
            },
            Field {
                data_type: DataType::Varchar,
                name: "email".to_owned(),
            },
        ])
    }

    /// Column-name -> data-type map, in the shape `check_string_format` takes.
    fn name_to_type_map(schema: &Schema) -> HashMap<String, DataType> {
        schema
            .fields()
            .iter()
            .map(|field| (field.name.clone(), field.data_type.clone()))
            .collect()
    }

    /// A row matching [`test_schema`].
    fn test_row() -> OwnedRow {
        OwnedRow::new(vec![
            Some(ScalarImpl::Int32(123)),
            Some(ScalarImpl::Utf8("John Doe".into())),
            Some(ScalarImpl::Utf8("john@example.com".into())),
        ])
    }

    #[test]
    fn test_template_format_validation() {
        let map = name_to_type_map(&test_schema());

        // Test various template formats
        let valid_templates = vec![
            "user:{id}",
            "user:\\{{id}",
            "user:\\{{id}\\}",
            "user:\\{{id},{name}\\}",
            "user:\\{prefix{id},suffix{name}\\}",
            "user:\\{prefix{id},suffix{name},email:{email}\\}",
            "user:\\{nested\\{deeply{id}\\}\\}",
            "user:\\{outer\\{inner{id}\\},another{name}\\}",
            "user:\\{complex\\{structure\\{with{id}\\},and{name}\\},email:{email}\\}",
            "user:{id}{name}",
            "user:\\\\{id}",
            "user:\\\\\\{id}",
            "user:\\a{id}",
            "user:\\b{name}",
            "user:{id}{name}{email}",
        ];

        for template in valid_templates {
            assert!(
                TemplateStringEncoder::check_string_format(template, &map).is_ok(),
                "Template '{}' should be valid",
                template
            );
        }

        // Test invalid templates
        let invalid_templates = vec![
            "user:no_braces",        // No braces
            "user:{invalid_column}", // Non-existent column
            "user:{id",              // Unclosed brace
            "user:id}",              // Unopened brace
            "sadsadsad{}qw4e2ewq21", // Empty braces
            "user:{}",
            "user:{\\id}",
        ];

        for template in invalid_templates {
            assert!(
                TemplateStringEncoder::check_string_format(template, &map).is_err(),
                "Template '{}' should be invalid",
                template
            );
        }
    }

    #[test]
    fn test_template_encoding() {
        let schema = test_schema();

        // Test cases with different template formats
        let test_cases = vec![
            ("user:{id}", "user:123", vec![0]),
            ("user:\\{id\\}", "user:{id}", vec![0]),
            ("user:\\{id,name\\}", "user:{id,name}", vec![0, 1]),
            (
                "user:\\{prefix{id},suffix{name}\\}",
                "user:{prefix123,suffixJohn Doe}",
                vec![0, 1],
            ),
            (
                "user:\\{nested\\{deeply{id}\\}\\}",
                "user:{nested{deeply123}}",
                vec![0],
            ),
            (
                "user:\\{outer\\{inner{id}\\},another{name}\\}",
                "user:{outer{inner123},anotherJohn Doe}",
                vec![0, 1],
            ),
            ("user:{id}{name}", "user:123John Doe", vec![0, 1]),
            ("user:\\{id\\}{name}", "user:{id}John Doe", vec![0, 1]),
            ("user:\\\\{id}", "user:\\{id}", vec![0]),
            ("user:\\\\\\{id}", "user:\\\\{id}", vec![0]),
            ("user:\\a{id}", "user:\\a123", vec![0]),
            ("user:\\b{name}", "user:\\bJohn Doe", vec![1]),
            (
                "user:{id}{name}{email}",
                "user:123John Doejohn@example.com",
                vec![0, 1, 2],
            ),
        ];

        for (template, expected, col_indices) in test_cases {
            let encoder = TemplateStringEncoder::new(
                schema.clone(),
                Some(col_indices.clone()),
                template.to_owned(),
            );

            let result = encoder
                .encode_cols(test_row(), col_indices.into_iter())
                .unwrap();

            assert_eq!(result, expected, "Template '{}' encoding failed", template);
        }
    }

    #[test]
    fn test_complex_nested_template() {
        let schema = test_schema();
        let map = name_to_type_map(&schema);

        // Test a very complex nested template
        let complex_template = "user:\\{prefix{id},suffix{name},email:{email},nested\\{deeply{id}\\},outer\\{inner{name}\\}\\}";

        assert!(TemplateStringEncoder::check_string_format(complex_template, &map).is_ok());

        let encoder = TemplateStringEncoder::new(
            schema,
            Some(vec![0, 1, 2]), // Include all columns
            complex_template.to_owned(),
        );

        let result = encoder
            .encode_cols(test_row(), vec![0, 1, 2].into_iter())
            .unwrap();

        // Check that all column values are in the result
        assert_eq!(
            result,
            "user:{prefix123,suffixJohn Doe,email:john@example.com,nested{deeply123},outer{innerJohn Doe}}"
        );
    }
}