risingwave_expr_impl/scalar/
regexp.rs1use std::str::FromStr;
18
19use fancy_regex::{Regex, RegexBuilder};
20use risingwave_common::array::{ArrayBuilder, ListValue, Utf8Array, Utf8ArrayBuilder};
21use risingwave_expr::{ExprError, Result, bail, function};
22use thiserror_ext::AsReport;
23
24#[derive(Debug)]
25pub struct RegexpContext {
26    pub regex: Regex,
27    pub global: bool,
28    pub replacement: String,
29}
30
31impl RegexpContext {
32    fn new(pattern: &str, flags: &str, replacement: &str) -> Result<Self> {
33        let options = RegexpOptions::from_str(flags)?;
34
35        let origin = if options.case_insensitive {
36            format!("(?i:{})", pattern)
37        } else {
38            pattern.to_owned()
39        };
40
41        Ok(Self {
42            regex: RegexBuilder::new(&origin)
43                .build()
44                .map_err(|e| ExprError::Parse(e.to_report_string().into()))?,
45            global: options.global,
46            replacement: make_replacement(replacement),
47        })
48    }
49
50    pub fn from_pattern(pattern: &str) -> Result<Self> {
51        Self::new(pattern, "", "")
52    }
53
54    pub fn from_pattern_flags(pattern: &str, flags: &str) -> Result<Self> {
55        Self::new(pattern, flags, "")
56    }
57
58    pub fn from_pattern_flags_for_count(pattern: &str, flags: &str) -> Result<Self> {
59        if flags.contains('g') {
60            bail!("regexp_count() does not support the global option");
61        }
62        Self::new(pattern, flags, "")
63    }
64
65    pub fn from_pattern_replacement(pattern: &str, replacement: &str) -> Result<Self> {
66        Self::new(pattern, "", replacement)
67    }
68
69    pub fn from_pattern_replacement_flags(
70        pattern: &str,
71        replacement: &str,
72        flags: &str,
73    ) -> Result<Self> {
74        Self::new(pattern, flags, replacement)
75    }
76}
77
/// Translates a PostgreSQL-style replacement string into a template for
/// `fancy_regex::Captures::expand` (the same syntax `Regex::replace_all` / `replacen` use).
///
/// ```text
/// \&       -> ${0}       (the whole match)
/// \1 .. \9 -> ${1} .. ${9}
/// \\       -> \          (one literal backslash)
/// $        -> $$         (`$` is the expansion sigil, so a literal dollar must be escaped)
/// ```
///
/// A backslash followed by any other character is kept as a literal backslash followed by
/// that character (which is itself escaped if it is `$`). A trailing lone backslash is kept.
fn make_replacement(s: &str) -> String {
    use std::fmt::Write;

    /// Appends `c` as literal template text, escaping the expander's `$` sigil.
    fn push_literal(out: &mut String, c: char) {
        if c == '$' {
            out.push_str("$$");
        } else {
            out.push(c);
        }
    }

    let mut ret = String::with_capacity(s.len());
    let mut chars = s.chars();
    while let Some(c) = chars.next() {
        if c != '\\' {
            push_literal(&mut ret, c);
            continue;
        }
        match chars.next() {
            Some('&') => ret.push_str("${0}"),
            Some(d @ '1'..='9') => write!(ret, "${{{d}}}").unwrap(),
            // `\\` stands for a single backslash in PostgreSQL replacement strings.
            Some('\\') => ret.push('\\'),
            Some(other) => {
                ret.push('\\');
                push_literal(&mut ret, other);
            }
            None => ret.push('\\'),
        }
    }
    ret
}
105
106#[derive(Default, Debug)]
108struct RegexpOptions {
109    case_insensitive: bool,
111    global: bool,
113}
114
115impl FromStr for RegexpOptions {
116    type Err = ExprError;
117
118    fn from_str(s: &str) -> Result<Self> {
119        let mut opts = Self::default();
120        for c in s.chars() {
121            match c {
122                'c' => opts.case_insensitive = false,
124                'i' => opts.case_insensitive = true,
126                'g' => opts.global = true,
128                _ => {
129                    bail!("invalid regular expression option: \"{c}\"");
130                }
131            }
132        }
133        Ok(opts)
134    }
135}
136
137#[function(
138    "regexp_eq(varchar, varchar) -> boolean",
140    prebuild = "RegexpContext::from_pattern($1)?"
141)]
142fn regexp_eq(text: &str, regex: &RegexpContext) -> bool {
143    regex.regex.is_match(text).unwrap()
144}
145
146#[function(
147    "regexp_match(varchar, varchar) -> varchar[]",
149    prebuild = "RegexpContext::from_pattern($1)?"
150)]
151#[function(
152    "regexp_match(varchar, varchar, varchar) -> varchar[]",
154    prebuild = "RegexpContext::from_pattern_flags($1, $2)?"
155)]
156fn regexp_match(text: &str, regex: &RegexpContext) -> Option<ListValue> {
157    let skip_first = regex.regex.captures_len() > 1;
160    let capture = regex.regex.captures(text).unwrap()?;
161    let list = capture
162        .iter()
163        .skip(if skip_first { 1 } else { 0 })
164        .map(|mat| mat.map(|m| m.as_str()))
165        .collect::<Utf8Array>();
166    Some(ListValue::new(list.into()))
167}
168
169#[function(
170    "regexp_count(varchar, varchar) -> int4",
172    prebuild = "RegexpContext::from_pattern($1)?"
173)]
174fn regexp_count_start0(text: &str, regex: &RegexpContext) -> Result<i32> {
175    regexp_count(text, 1, regex)
176}
177
178#[function(
179    "regexp_count(varchar, varchar, int4) -> int4",
181    prebuild = "RegexpContext::from_pattern($1)?"
182)]
183#[function(
184    "regexp_count(varchar, varchar, int4, varchar) -> int4",
186    prebuild = "RegexpContext::from_pattern_flags_for_count($1, $3)?"
187)]
188fn regexp_count(text: &str, start: i32, regex: &RegexpContext) -> Result<i32> {
189    let start = match start {
191        ..=0 => {
192            return Err(ExprError::InvalidParam {
193                name: "start",
194                reason: start.to_string().into(),
195            });
196        }
197        _ => start as usize - 1,
198    };
199
200    let mut start = match text.char_indices().nth(start) {
202        Some((idx, _)) => idx,
203        None => return Ok(0),
205    };
206
207    let mut count = 0;
208    while let Ok(Some(captures)) = regex.regex.captures(&text[start..]) {
209        count += 1;
210        start += captures.get(0).unwrap().end();
211    }
212    Ok(count)
213}
214
215#[function(
216    "regexp_replace(varchar, varchar, varchar) -> varchar",
218    prebuild = "RegexpContext::from_pattern_replacement($1, $2)?"
219)]
220#[function(
221    "regexp_replace(varchar, varchar, varchar, varchar) -> varchar",
223    prebuild = "RegexpContext::from_pattern_replacement_flags($1, $2, $3)?"
224)]
225fn regexp_replace0(text: &str, ctx: &RegexpContext) -> Result<Box<str>> {
226    regexp_replace(text, 1, None, ctx)
227}
228
229#[function(
230    "regexp_replace(varchar, varchar, varchar, int4) -> varchar",
232    prebuild = "RegexpContext::from_pattern_replacement($1, $2)?"
233)]
234fn regexp_replace_with_start(text: &str, start: i32, ctx: &RegexpContext) -> Result<Box<str>> {
235    regexp_replace(text, start, None, ctx)
236}
237
238#[function(
239    "regexp_replace(varchar, varchar, varchar, int4, int4) -> varchar",
241    prebuild = "RegexpContext::from_pattern_replacement($1, $2)?"
242)]
243fn regexp_replace_with_start_n(
244    text: &str,
245    start: i32,
246    n: i32,
247    ctx: &RegexpContext,
248) -> Result<Box<str>> {
249    regexp_replace(text, start, Some(n), ctx)
250}
251
252#[function(
253    "regexp_replace(varchar, varchar, varchar, int4, int4, varchar) -> varchar",
255    prebuild = "RegexpContext::from_pattern_replacement_flags($1, $2, $5)?"
256)]
257fn regexp_replace_with_start_n_flags(
258    text: &str,
259    start: i32,
260    n: i32,
261    ctx: &RegexpContext,
262) -> Result<Box<str>> {
263    regexp_replace(text, start, Some(n), ctx)
264}
265
266fn regexp_replace(
268    text: &str,
269    start: i32,
270    n: Option<i32>, ctx: &RegexpContext,
272) -> Result<Box<str>> {
273    let start = match start {
275        ..=0 => {
276            return Err(ExprError::InvalidParam {
277                name: "start",
278                reason: start.to_string().into(),
279            });
280        }
281        _ => start as usize - 1,
282    };
283
284    let start = match text.char_indices().nth(start) {
286        Some((idx, _)) => idx,
287        None => return Ok(text.into()),
289    };
290
291    if n.is_none() && ctx.global || n == Some(0) {
292        if ctx.regex.captures_len() <= 1 {
299            Ok(format!(
302                "{}{}",
303                &text[..start],
304                ctx.regex.replace_all(&text[start..], &ctx.replacement)
305            )
306            .into())
307        } else {
308            let mut search_start = start;
310
311            let mut ret = text[..search_start].to_string();
313
314            while let Ok(Some(capture)) = ctx.regex.captures(&text[search_start..]) {
316                let match_start = capture.get(0).unwrap().start();
317                let match_end = capture.get(0).unwrap().end();
318
319                if match_start == match_end {
320                    search_start += 1;
322                    continue;
323                }
324
325                ret.push_str(&text[search_start..search_start + match_start]);
327
328                capture.expand(&ctx.replacement, &mut ret);
331
332                search_start += match_end;
334            }
335
336            ret.push_str(&text[search_start..]);
338
339            Ok(ret.into())
340        }
341    } else {
342        let mut ret = if start > 1 {
349            text[..start].to_string()
350        } else {
351            "".to_owned()
352        };
353
354        if ctx.regex.captures_len() <= 1 {
356            if let Some(n) = n {
358                let mut count = 1;
360                let mut search_start = start;
362                while let Ok(Some(capture)) = ctx.regex.captures(&text[search_start..]) {
363                    let match_start = capture.get(0).unwrap().start();
365                    let match_end = capture.get(0).unwrap().end();
366
367                    if count == n {
368                        ret = format!(
371                            "{}{}{}",
372                            &text[..search_start + match_start],
373                            &ctx.replacement,
374                            &text[search_start + match_end..]
375                        );
376                        break;
377                    }
378
379                    count += 1;
381
382                    search_start += match_end;
384                }
385            } else {
386                ret.push_str(&ctx.regex.replacen(&text[start..], 1, &ctx.replacement));
388            }
389        } else {
390            ret = "".to_owned();
393            if let Some(n) = n {
394                let mut count = 1;
396                while let Ok(Some(capture)) = ctx.regex.captures(&text[start..]) {
397                    if count == n {
398                        let match_start = capture.get(0).unwrap().start();
400                        let match_end = capture.get(0).unwrap().end();
401
402                        capture.expand(&ctx.replacement, &mut ret);
404
405                        ret = format!(
407                            "{}{}{}",
408                            &text[..start + match_start],
409                            ret,
410                            &text[start + match_end..]
411                        );
412                    }
413
414                    count += 1;
416                }
417
418                if ret.is_empty() {
420                    ret = text.into();
421                }
422            } else {
423                if let Ok(None) = ctx.regex.captures(&text[start..]) {
425                    return Ok(text.into());
427                }
428
429                if let Ok(Some(capture)) = ctx.regex.captures(&text[start..]) {
431                    let match_start = capture.get(0).unwrap().start();
432                    let match_end = capture.get(0).unwrap().end();
433
434                    capture.expand(&ctx.replacement, &mut ret);
436
437                    ret = format!(
439                        "{}{}{}",
440                        &text[..start + match_start],
441                        ret,
442                        &text[start + match_end..]
443                    );
444                }
445            }
446        }
447
448        Ok(ret.into())
449    }
450}
451
452#[function(
453    "regexp_split_to_array(varchar, varchar) -> varchar[]",
455    prebuild = "RegexpContext::from_pattern($1)?"
456)]
457#[function(
458    "regexp_split_to_array(varchar, varchar, varchar) -> varchar[]",
460    prebuild = "RegexpContext::from_pattern_flags($1, $2)?"
461)]
462fn regexp_split_to_array(text: &str, regex: &RegexpContext) -> Option<ListValue> {
463    let n = text.len();
464    let mut start = 0;
465    let mut builder = Utf8ArrayBuilder::new(0);
466    let mut empty_flag = false;
467
468    loop {
469        if start >= n {
470            break;
472        }
473
474        let capture = regex.regex.captures(&text[start..]).unwrap();
475
476        if capture.is_none() {
477            break;
478        }
479
480        let whole_match = capture.unwrap().get(0);
481        debug_assert!(whole_match.is_some(), "Expected `whole_match` to be valid");
482
483        let begin = whole_match.unwrap().start() + start;
484        let end = whole_match.unwrap().end() + start;
485
486        if begin == end {
487            empty_flag = true;
489
490            if begin == text.len() {
491                start = begin;
493                break;
494            }
495            builder.append(Some(&text[start..begin + 1]));
496            start = end + 1;
497            continue;
498        }
499
500        if start == begin {
501            if !empty_flag {
503                builder.append(Some(""));
506            }
507            start = end;
508            continue;
509        }
510
511        if begin != 0 {
512            builder.append(Some(&text[start..begin]));
514        }
515
516        start = end;
518    }
519
520    if start < n {
521        builder.append(Some(&text[start..]));
525    }
526
527    if start == n && !empty_flag {
528        builder.append(Some(""));
529    }
530
531    Some(ListValue::new(builder.finish().into()))
532}