1use std::borrow::Cow;
16use std::collections::HashMap;
17
18use regex::{Captures, Regex};
19use risingwave_common::catalog::{Field, Schema};
20use risingwave_common::row::Row;
21use risingwave_common::types::{DataType, ScalarRefImpl, ToText};
22use thiserror_ext::AsReport;
23
24use super::{Result, RowEncoder};
25use crate::sink::SinkError;
26use crate::sink::encoder::SerTo;
27
28pub enum TemplateEncoder {
29    String(TemplateStringEncoder),
30    RedisGeoKey(TemplateRedisGeoKeyEncoder),
31    RedisGeoValue(TemplateRedisGeoValueEncoder),
32    RedisPubSubKey(TemplateRedisPubSubKeyEncoder),
33}
34impl TemplateEncoder {
35    pub fn new_string(schema: Schema, col_indices: Option<Vec<usize>>, template: String) -> Self {
36        TemplateEncoder::String(TemplateStringEncoder::new(schema, col_indices, template))
37    }
38
39    pub fn new_geo_value(
40        schema: Schema,
41        col_indices: Option<Vec<usize>>,
42        lat_name: &str,
43        lon_name: &str,
44    ) -> Result<Self> {
45        Ok(TemplateEncoder::RedisGeoValue(
46            TemplateRedisGeoValueEncoder::new(schema, col_indices, lat_name, lon_name)?,
47        ))
48    }
49
50    pub fn new_geo_key(
51        schema: Schema,
52        col_indices: Option<Vec<usize>>,
53        member_name: &str,
54        template: String,
55    ) -> Result<Self> {
56        Ok(TemplateEncoder::RedisGeoKey(
57            TemplateRedisGeoKeyEncoder::new(schema, col_indices, member_name, template)?,
58        ))
59    }
60
61    pub fn new_pubsub_key(
62        schema: Schema,
63        col_indices: Option<Vec<usize>>,
64        channel: Option<String>,
65        channel_column: Option<String>,
66    ) -> Result<Self> {
67        Ok(TemplateEncoder::RedisPubSubKey(
68            TemplateRedisPubSubKeyEncoder::new(schema, col_indices, channel, channel_column)?,
69        ))
70    }
71}
72impl RowEncoder for TemplateEncoder {
73    type Output = TemplateEncoderOutput;
74
75    fn schema(&self) -> &Schema {
76        match self {
77            TemplateEncoder::String(encoder) => &encoder.schema,
78            TemplateEncoder::RedisGeoValue(encoder) => &encoder.schema,
79            TemplateEncoder::RedisGeoKey(encoder) => &encoder.key_encoder.schema,
80            TemplateEncoder::RedisPubSubKey(encoder) => &encoder.schema,
81        }
82    }
83
84    fn col_indices(&self) -> Option<&[usize]> {
85        match self {
86            TemplateEncoder::String(encoder) => encoder.col_indices.as_deref(),
87            TemplateEncoder::RedisGeoValue(encoder) => encoder.col_indices.as_deref(),
88            TemplateEncoder::RedisGeoKey(encoder) => encoder.key_encoder.col_indices.as_deref(),
89            TemplateEncoder::RedisPubSubKey(encoder) => encoder.col_indices.as_deref(),
90        }
91    }
92
93    fn encode_cols(
94        &self,
95        row: impl Row,
96        col_indices: impl Iterator<Item = usize>,
97    ) -> Result<Self::Output> {
98        match self {
99            TemplateEncoder::String(encoder) => Ok(TemplateEncoderOutput::String(
100                encoder.encode_cols(row, col_indices)?,
101            )),
102            TemplateEncoder::RedisGeoValue(encoder) => encoder.encode_cols(row, col_indices),
103            TemplateEncoder::RedisGeoKey(encoder) => encoder.encode_cols(row, col_indices),
104            TemplateEncoder::RedisPubSubKey(encoder) => encoder.encode_cols(row, col_indices),
105        }
106    }
107}
108pub struct TemplateStringEncoder {
111    field_name_to_index: HashMap<String, (usize, Field)>,
112    col_indices: Option<Vec<usize>>,
113    template: String,
114    schema: Schema,
115}
116
117impl TemplateStringEncoder {
119    pub fn new(schema: Schema, col_indices: Option<Vec<usize>>, template: String) -> Self {
120        let field_name_to_index = schema
121            .fields()
122            .iter()
123            .enumerate()
124            .map(|(index, field)| (field.name.clone(), (index, field.clone())))
125            .collect();
126        Self {
127            field_name_to_index,
128            col_indices,
129            template,
130            schema,
131        }
132    }
133
134    pub fn check_string_format(format: &str, map: &HashMap<String, DataType>) -> Result<()> {
135        let re = Regex::new(r"(\\\})|(\\\{)|\{([^}]*)\}").unwrap();
137        if !re.is_match(format) {
138            return Err(SinkError::Redis(
139                "Can't find {} in key_format or value_format".to_owned(),
140            ));
141        }
142        for capture in re.captures_iter(format) {
143            if let Some(inner_content) = capture.get(3)
144                && !map.contains_key(inner_content.as_str())
145            {
146                return Err(SinkError::Redis(format!(
147                    "Can't find field({:?}) in key_format or value_format",
148                    inner_content.as_str()
149                )));
150            }
151        }
152        Ok(())
153    }
154
155    pub fn encode_cols(
156        &self,
157        row: impl Row,
158        col_indices: impl Iterator<Item = usize>,
159    ) -> Result<String> {
160        let s = self.template.clone();
161        let re = Regex::new(r"(\\\})|(\\\{)|\{([^}]*)\}").unwrap();
162        let col_indices: Vec<_> = col_indices.collect();
163        let replaced = re.replace_all(s.as_ref(), |caps: &Captures<'_>| {
164            if caps.get(1).is_some() {
165                Cow::Borrowed("}")
166            } else if caps.get(2).is_some() {
167                Cow::Borrowed("{")
168            } else if let Some(content) = caps.get(3) {
169                let (idx, field) = self.field_name_to_index.get(content.as_str()).unwrap();
170                if col_indices.contains(idx) {
171                    let data = row.datum_at(*idx).to_text_with_type(&field.data_type);
172                    Cow::Owned(data)
173                } else {
174                    Cow::Borrowed("")
175                }
176            } else {
177                Cow::Borrowed("")
178            }
179        });
180        Ok(replaced.to_string())
181    }
182}
183
184pub struct TemplateRedisGeoValueEncoder {
185    schema: Schema,
186    col_indices: Option<Vec<usize>>,
187    lat_col: usize,
188    lon_col: usize,
189}
190
191impl TemplateRedisGeoValueEncoder {
192    pub fn new(
193        schema: Schema,
194        col_indices: Option<Vec<usize>>,
195        lat_name: &str,
196        lon_name: &str,
197    ) -> Result<Self> {
198        let lat_col = schema
199            .names_str()
200            .iter()
201            .position(|name| name == &lat_name)
202            .ok_or_else(|| {
203                SinkError::Redis(format!("Can't find lat column({}) in schema", lat_name))
204            })?;
205        let lon_col = schema
206            .names_str()
207            .iter()
208            .position(|name| name == &lon_name)
209            .ok_or_else(|| {
210                SinkError::Redis(format!("Can't find lon column({}) in schema", lon_name))
211            })?;
212        Ok(Self {
213            schema,
214            col_indices,
215            lat_col,
216            lon_col,
217        })
218    }
219
220    pub fn encode_cols(
221        &self,
222        row: impl Row,
223        _col_indices: impl Iterator<Item = usize>,
224    ) -> Result<TemplateEncoderOutput> {
225        let lat = into_string_from_scalar(
226            row.datum_at(self.lat_col)
227                .ok_or_else(|| SinkError::Redis("lat is null".to_owned()))?,
228        )?;
229        let lon = into_string_from_scalar(
230            row.datum_at(self.lon_col)
231                .ok_or_else(|| SinkError::Redis("lon is null".to_owned()))?,
232        )?;
233        Ok(TemplateEncoderOutput::RedisGeoValue((lat, lon)))
234    }
235}
236
237fn into_string_from_scalar(scalar: ScalarRefImpl<'_>) -> Result<String> {
238    match scalar {
239        ScalarRefImpl::Float32(ordered_float) => Ok(Into::<f32>::into(ordered_float).to_string()),
240        ScalarRefImpl::Float64(ordered_float) => Ok(Into::<f64>::into(ordered_float).to_string()),
241        ScalarRefImpl::Utf8(s) => Ok(s.to_owned()),
242        _ => Err(SinkError::Encode(
243            "Only f32 and f64 can convert to redis geo".to_owned(),
244        )),
245    }
246}
247
248pub struct TemplateRedisGeoKeyEncoder {
249    key_encoder: TemplateStringEncoder,
250    member_col: usize,
251}
252
253impl TemplateRedisGeoKeyEncoder {
254    pub fn new(
255        schema: Schema,
256        col_indices: Option<Vec<usize>>,
257        member_name: &str,
258        template: String,
259    ) -> Result<Self> {
260        let member_col = schema
261            .names_str()
262            .iter()
263            .position(|name| name == &member_name)
264            .ok_or_else(|| {
265                SinkError::Redis(format!(
266                    "Can't find member column({}) in schema",
267                    member_name
268                ))
269            })?;
270        let key_encoder = TemplateStringEncoder::new(schema, col_indices, template);
271        Ok(Self {
272            key_encoder,
273            member_col,
274        })
275    }
276
277    pub fn encode_cols(
278        &self,
279        row: impl Row,
280        col_indices: impl Iterator<Item = usize>,
281    ) -> Result<TemplateEncoderOutput> {
282        let member = row
283            .datum_at(self.member_col)
284            .ok_or_else(|| SinkError::Redis("member is null".to_owned()))?
285            .to_text();
286        let key = self.key_encoder.encode_cols(row, col_indices)?;
287        Ok(TemplateEncoderOutput::RedisGeoKey((key, member)))
288    }
289}
290
291pub enum TemplateRedisPubSubKeyEncoderInner {
292    PubSubName(String),
293    PubSubColumnIndex(usize),
294}
295pub struct TemplateRedisPubSubKeyEncoder {
296    inner: TemplateRedisPubSubKeyEncoderInner,
297    schema: Schema,
298    col_indices: Option<Vec<usize>>,
299}
300
301impl TemplateRedisPubSubKeyEncoder {
302    pub fn new(
303        schema: Schema,
304        col_indices: Option<Vec<usize>>,
305        channel: Option<String>,
306        channel_column: Option<String>,
307    ) -> Result<Self> {
308        if let Some(channel) = channel {
309            return Ok(Self {
310                inner: TemplateRedisPubSubKeyEncoderInner::PubSubName(channel),
311                schema,
312                col_indices,
313            });
314        }
315        if let Some(channel_column) = channel_column {
316            let channel_column_index = schema
317                .names_str()
318                .iter()
319                .position(|name| name == &channel_column)
320                .ok_or_else(|| {
321                    SinkError::Redis(format!(
322                        "Can't find pubsub column({}) in schema",
323                        channel_column
324                    ))
325                })?;
326            return Ok(Self {
327                inner: TemplateRedisPubSubKeyEncoderInner::PubSubColumnIndex(channel_column_index),
328                schema,
329                col_indices,
330            });
331        }
332        Err(SinkError::Redis(
333            "`channel` or `channel_column` must be set".to_owned(),
334        ))
335    }
336
337    pub fn encode_cols(
338        &self,
339        row: impl Row,
340        _col_indices: impl Iterator<Item = usize>,
341    ) -> Result<TemplateEncoderOutput> {
342        match &self.inner {
343            TemplateRedisPubSubKeyEncoderInner::PubSubName(channel) => {
344                Ok(TemplateEncoderOutput::RedisPubSubKey(channel.clone()))
345            }
346            TemplateRedisPubSubKeyEncoderInner::PubSubColumnIndex(pubsub_col) => {
347                let pubsub_key = row
348                    .datum_at(*pubsub_col)
349                    .ok_or_else(|| SinkError::Redis("pubsub_key is null".to_owned()))?
350                    .to_text();
351                Ok(TemplateEncoderOutput::RedisPubSubKey(pubsub_key))
352            }
353        }
354    }
355}
356
357pub enum TemplateEncoderOutput {
358    String(String),
360    RedisGeoValue((String, String)),
362    RedisGeoKey((String, String)),
364
365    RedisPubSubKey(String),
366}
367
368impl TemplateEncoderOutput {
369    pub fn into_string(self) -> Result<String> {
370        match self {
371            TemplateEncoderOutput::String(s) => Ok(s),
372            TemplateEncoderOutput::RedisGeoKey(_) => Err(SinkError::Encode(
373                "RedisGeoKey can't convert to string".to_owned(),
374            )),
375            TemplateEncoderOutput::RedisGeoValue(_) => Err(SinkError::Encode(
376                "RedisGeoVelue can't convert to string".to_owned(),
377            )),
378            TemplateEncoderOutput::RedisPubSubKey(s) => Ok(s),
379        }
380    }
381}
382
383impl SerTo<String> for TemplateEncoderOutput {
384    fn ser_to(self) -> Result<String> {
385        match self {
386            TemplateEncoderOutput::String(s) => Ok(s),
387            TemplateEncoderOutput::RedisGeoKey(_) => Err(SinkError::Encode(
388                "RedisGeoKey can't convert to string".to_owned(),
389            )),
390            TemplateEncoderOutput::RedisGeoValue(_) => Err(SinkError::Encode(
391                "RedisGeoVelue can't convert to string".to_owned(),
392            )),
393            TemplateEncoderOutput::RedisPubSubKey(s) => Ok(s),
394        }
395    }
396}
397
398#[derive(Debug)]
400pub enum RedisSinkPayloadWriterInput {
401    String(String),
403    RedisGeoValue((String, String)),
405    RedisGeoKey((String, String)),
407    RedisPubSubKey(String),
408}
409
410impl SerTo<RedisSinkPayloadWriterInput> for TemplateEncoderOutput {
411    fn ser_to(self) -> Result<RedisSinkPayloadWriterInput> {
412        match self {
413            TemplateEncoderOutput::String(s) => Ok(RedisSinkPayloadWriterInput::String(s)),
414            TemplateEncoderOutput::RedisGeoKey((lat, lon)) => {
415                Ok(RedisSinkPayloadWriterInput::RedisGeoKey((lat, lon)))
416            }
417            TemplateEncoderOutput::RedisGeoValue((key, member)) => {
418                Ok(RedisSinkPayloadWriterInput::RedisGeoValue((key, member)))
419            }
420            TemplateEncoderOutput::RedisPubSubKey(s) => {
421                Ok(RedisSinkPayloadWriterInput::RedisPubSubKey(s))
422            }
423        }
424    }
425}
426
427impl<T: SerTo<Vec<u8>>> SerTo<RedisSinkPayloadWriterInput> for T {
428    default fn ser_to(self) -> Result<RedisSinkPayloadWriterInput> {
429        let bytes = self.ser_to()?;
430        Ok(RedisSinkPayloadWriterInput::String(
431            String::from_utf8(bytes).map_err(|e| SinkError::Redis(e.to_report_string()))?,
432        ))
433    }
434}
435
#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::{DataType, ScalarImpl};

    use super::*;

    /// Schema shared by every test: `id INT`, `name VARCHAR`, `email VARCHAR`.
    fn sample_schema() -> Schema {
        let columns = [
            (DataType::Int32, "id"),
            (DataType::Varchar, "name"),
            (DataType::Varchar, "email"),
        ];
        Schema::new(
            columns
                .into_iter()
                .map(|(data_type, name)| Field {
                    data_type,
                    name: name.to_owned(),
                })
                .collect(),
        )
    }

    /// Column name -> data type lookup, in the shape `check_string_format` expects.
    fn sample_type_map(schema: &Schema) -> HashMap<String, DataType> {
        schema
            .fields()
            .iter()
            .map(|field| (field.name.clone(), field.data_type.clone()))
            .collect()
    }

    /// Row matching `sample_schema`: `(123, "John Doe", "john@example.com")`.
    fn sample_row() -> OwnedRow {
        OwnedRow::new(vec![
            Some(ScalarImpl::Int32(123)),
            Some(ScalarImpl::Utf8("John Doe".into())),
            Some(ScalarImpl::Utf8("john@example.com".into())),
        ])
    }

    #[test]
    fn test_template_format_validation() {
        let map = sample_type_map(&sample_schema());

        let valid_templates = [
            "user:{id}",
            "user:\\{{id}",
            "user:\\{{id}\\}",
            "user:\\{{id},{name}\\}",
            "user:\\{prefix{id},suffix{name}\\}",
            "user:\\{prefix{id},suffix{name},email:{email}\\}",
            "user:\\{nested\\{deeply{id}\\}\\}",
            "user:\\{outer\\{inner{id}\\},another{name}\\}",
            "user:\\{complex\\{structure\\{with{id}\\},and{name}\\},email:{email}\\}",
            "user:{id}{name}",
            "user:\\\\{id}",
            "user:\\\\\\{id}",
            "user:\\a{id}",
            "user:\\b{name}",
            "user:{id}{name}{email}",
        ];
        for template in valid_templates {
            assert!(
                TemplateStringEncoder::check_string_format(template, &map).is_ok(),
                "Template '{}' should be valid",
                template
            );
        }

        let invalid_templates = [
            "user:no_braces",
            "user:{invalid_column}",
            "user:{id",
            "user:id}",
            "sadsadsad{}qw4e2ewq21",
            "user:{}",
            "user:{\\id}",
        ];
        for template in invalid_templates {
            assert!(
                TemplateStringEncoder::check_string_format(template, &map).is_err(),
                "Template '{}' should be invalid",
                template
            );
        }
    }

    #[test]
    fn test_template_encoding() {
        let schema = sample_schema();

        let cases = [
            ("user:{id}", "user:123", vec![0]),
            ("user:\\{id\\}", "user:{id}", vec![0]),
            ("user:\\{id,name\\}", "user:{id,name}", vec![0, 1]),
            (
                "user:\\{prefix{id},suffix{name}\\}",
                "user:{prefix123,suffixJohn Doe}",
                vec![0, 1],
            ),
            (
                "user:\\{nested\\{deeply{id}\\}\\}",
                "user:{nested{deeply123}}",
                vec![0],
            ),
            (
                "user:\\{outer\\{inner{id}\\},another{name}\\}",
                "user:{outer{inner123},anotherJohn Doe}",
                vec![0, 1],
            ),
            ("user:{id}{name}", "user:123John Doe", vec![0, 1]),
            ("user:\\{id\\}{name}", "user:{id}John Doe", vec![0, 1]),
            ("user:\\\\{id}", "user:\\{id}", vec![0]),
            ("user:\\\\\\{id}", "user:\\\\{id}", vec![0]),
            ("user:\\a{id}", "user:\\a123", vec![0]),
            ("user:\\b{name}", "user:\\bJohn Doe", vec![1]),
            (
                "user:{id}{name}{email}",
                "user:123John Doejohn@example.com",
                vec![0, 1, 2],
            ),
        ];

        for (template, expected, col_indices) in cases {
            let encoder = TemplateStringEncoder::new(
                schema.clone(),
                Some(col_indices.clone()),
                template.to_owned(),
            );
            let result = encoder
                .encode_cols(sample_row(), col_indices.into_iter())
                .unwrap();
            assert_eq!(result, expected, "Template '{}' encoding failed", template);
        }
    }

    #[test]
    fn test_complex_nested_template() {
        let schema = sample_schema();
        let map = sample_type_map(&schema);

        let complex_template = "user:\\{prefix{id},suffix{name},email:{email},nested\\{deeply{id}\\},outer\\{inner{name}\\}\\}";
        assert!(TemplateStringEncoder::check_string_format(complex_template, &map).is_ok());

        let all_columns = vec![0, 1, 2];
        let encoder = TemplateStringEncoder::new(
            schema,
            Some(all_columns.clone()),
            complex_template.to_owned(),
        );
        let result = encoder
            .encode_cols(sample_row(), all_columns.into_iter())
            .unwrap();

        assert_eq!(
            result,
            "user:{prefix123,suffixJohn Doe,email:john@example.com,nested{deeply123},outer{innerJohn Doe}}"
        );
    }
}