risingwave_expr_impl/scalar/
regexp.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Regular expression functions.
16
17use std::str::FromStr;
18
19use fancy_regex::{Regex, RegexBuilder};
20use risingwave_common::array::{ArrayBuilder, ListValue, Utf8Array, Utf8ArrayBuilder};
21use risingwave_expr::{ExprError, Result, bail, function};
22use thiserror_ext::AsReport;
23
24#[derive(Debug)]
25pub struct RegexpContext {
26    pub regex: Regex,
27    pub global: bool,
28    pub replacement: String,
29}
30
31impl RegexpContext {
32    fn new(pattern: &str, flags: &str, replacement: &str) -> Result<Self> {
33        let options = RegexpOptions::from_str(flags)?;
34
35        let origin = if options.case_insensitive {
36            format!("(?i:{})", pattern)
37        } else {
38            pattern.to_owned()
39        };
40
41        Ok(Self {
42            regex: RegexBuilder::new(&origin)
43                .build()
44                .map_err(|e| ExprError::Parse(e.to_report_string().into()))?,
45            global: options.global,
46            replacement: make_replacement(replacement),
47        })
48    }
49
50    pub fn from_pattern(pattern: &str) -> Result<Self> {
51        Self::new(pattern, "", "")
52    }
53
54    pub fn from_pattern_flags(pattern: &str, flags: &str) -> Result<Self> {
55        Self::new(pattern, flags, "")
56    }
57
58    pub fn from_pattern_flags_for_count(pattern: &str, flags: &str) -> Result<Self> {
59        if flags.contains('g') {
60            bail!("regexp_count() does not support the global option");
61        }
62        Self::new(pattern, flags, "")
63    }
64
65    pub fn from_pattern_replacement(pattern: &str, replacement: &str) -> Result<Self> {
66        Self::new(pattern, "", replacement)
67    }
68
69    pub fn from_pattern_replacement_flags(
70        pattern: &str,
71        replacement: &str,
72        flags: &str,
73    ) -> Result<Self> {
74        Self::new(pattern, flags, replacement)
75    }
76}
77
/// Translates a PostgreSQL replacement string into the `$`-expansion template syntax used by
/// `Captures::expand` (the same syntax as
/// <https://docs.rs/regex/latest/regex/struct.Captures.html#method.expand>).
///
/// ```text
/// \&       -> ${0}           (the whole match)
/// \1 .. \9 -> ${1} .. ${9}   (the n-th capture group)
/// \\       -> \              (a literal backslash)
/// $        -> $$             (a literal dollar sign; a bare `$` would start a group reference)
/// \x       -> \x             (any other escaped character is kept, backslash included)
/// ```
///
/// A lone trailing backslash is kept as a literal backslash.
fn make_replacement(s: &str) -> String {
    let mut ret = String::with_capacity(s.len());
    let mut chars = s.chars();
    while let Some(c) = chars.next() {
        match c {
            // `expand` treats `$` as the start of a group reference; PostgreSQL does not.
            '$' => ret.push_str("$$"),
            '\\' => match chars.next() {
                Some('&') => ret.push_str("${0}"),
                Some(digit @ '1'..='9') => {
                    ret.push_str("${");
                    ret.push(digit);
                    ret.push('}');
                }
                // `\\` is an escaped backslash; a lone trailing `\` is kept as-is.
                Some('\\') | None => ret.push('\\'),
                // Unknown escape: keep the backslash and treat the next character literally,
                // which for `$` means escaping it as well.
                Some('$') => ret.push_str("\\$$"),
                Some(other) => {
                    ret.push('\\');
                    ret.push(other);
                }
            },
            _ => ret.push(c),
        }
    }
    ret
}
105
/// Flags accepted by the `regexp_*` functions, parsed from their `flags` argument.
///
/// Only a subset of PostgreSQL's options is recognised (`c`, `i`, `g`); see
/// <https://www.postgresql.org/docs/current/functions-matching.html#POSIX-EMBEDDED-OPTIONS-TABLE>
#[derive(Debug, Default)]
struct RegexpOptions {
    /// Set by `i`, cleared by `c`; whichever appears last wins.
    case_insensitive: bool,
    /// Set by `g`: global matching (e.g. `regexp_replace` replaces every match).
    global: bool,
}
114
115impl FromStr for RegexpOptions {
116    type Err = ExprError;
117
118    fn from_str(s: &str) -> Result<Self> {
119        let mut opts = Self::default();
120        for c in s.chars() {
121            match c {
122                // Case sensitive matching here
123                'c' => opts.case_insensitive = false,
124                // Case insensitive matching here
125                'i' => opts.case_insensitive = true,
126                // Global matching here
127                'g' => opts.global = true,
128                _ => {
129                    bail!("invalid regular expression option: \"{c}\"");
130                }
131            }
132        }
133        Ok(opts)
134    }
135}
136
137#[function(
138    // source ~ pattern
139    "regexp_eq(varchar, varchar) -> boolean",
140    prebuild = "RegexpContext::from_pattern($1)?"
141)]
142fn regexp_eq(text: &str, regex: &RegexpContext) -> bool {
143    regex.regex.is_match(text).unwrap()
144}
145
146#[function(
147    // regexp_match(source, pattern)
148    "regexp_match(varchar, varchar) -> varchar[]",
149    prebuild = "RegexpContext::from_pattern($1)?"
150)]
151#[function(
152    // regexp_match(source, pattern, flags)
153    "regexp_match(varchar, varchar, varchar) -> varchar[]",
154    prebuild = "RegexpContext::from_pattern_flags($1, $2)?"
155)]
156fn regexp_match(text: &str, regex: &RegexpContext) -> Option<ListValue> {
157    // If there are multiple captures, then the first one is the whole match, and should be
158    // ignored in PostgreSQL's behavior.
159    let skip_first = regex.regex.captures_len() > 1;
160    let capture = regex.regex.captures(text).unwrap()?;
161    let list = capture
162        .iter()
163        .skip(if skip_first { 1 } else { 0 })
164        .map(|mat| mat.map(|m| m.as_str()))
165        .collect::<Utf8Array>();
166    Some(ListValue::new(list.into()))
167}
168
169#[function(
170    // regexp_count(source, pattern)
171    "regexp_count(varchar, varchar) -> int4",
172    prebuild = "RegexpContext::from_pattern($1)?"
173)]
174fn regexp_count_start0(text: &str, regex: &RegexpContext) -> Result<i32> {
175    regexp_count(text, 1, regex)
176}
177
178#[function(
179    // regexp_count(source, pattern, start)
180    "regexp_count(varchar, varchar, int4) -> int4",
181    prebuild = "RegexpContext::from_pattern($1)?"
182)]
183#[function(
184    // regexp_count(source, pattern, start, flags)
185    "regexp_count(varchar, varchar, int4, varchar) -> int4",
186    prebuild = "RegexpContext::from_pattern_flags_for_count($1, $3)?"
187)]
188fn regexp_count(text: &str, start: i32, regex: &RegexpContext) -> Result<i32> {
189    // First get the start position to count for
190    let start = match start {
191        ..=0 => {
192            return Err(ExprError::InvalidParam {
193                name: "start",
194                reason: start.to_string().into(),
195            });
196        }
197        _ => start as usize - 1,
198    };
199
200    // Find the start byte index considering the unicode
201    let mut start = match text.char_indices().nth(start) {
202        Some((idx, _)) => idx,
203        // The `start` is out of bound
204        None => return Ok(0),
205    };
206
207    let mut count = 0;
208    while let Ok(Some(captures)) = regex.regex.captures(&text[start..]) {
209        count += 1;
210        start += captures.get(0).unwrap().end();
211    }
212    Ok(count)
213}
214
215#[function(
216    // regexp_replace(source, pattern, replacement)
217    "regexp_replace(varchar, varchar, varchar) -> varchar",
218    prebuild = "RegexpContext::from_pattern_replacement($1, $2)?"
219)]
220#[function(
221    // regexp_replace(source, pattern, replacement, flags)
222    "regexp_replace(varchar, varchar, varchar, varchar) -> varchar",
223    prebuild = "RegexpContext::from_pattern_replacement_flags($1, $2, $3)?"
224)]
225fn regexp_replace0(text: &str, ctx: &RegexpContext) -> Result<Box<str>> {
226    regexp_replace(text, 1, None, ctx)
227}
228
229#[function(
230    // regexp_replace(source, pattern, replacement, start)
231    "regexp_replace(varchar, varchar, varchar, int4) -> varchar",
232    prebuild = "RegexpContext::from_pattern_replacement($1, $2)?"
233)]
234fn regexp_replace_with_start(text: &str, start: i32, ctx: &RegexpContext) -> Result<Box<str>> {
235    regexp_replace(text, start, None, ctx)
236}
237
238#[function(
239    // regexp_replace(source, pattern, replacement, start, N)
240    "regexp_replace(varchar, varchar, varchar, int4, int4) -> varchar",
241    prebuild = "RegexpContext::from_pattern_replacement($1, $2)?"
242)]
243fn regexp_replace_with_start_n(
244    text: &str,
245    start: i32,
246    n: i32,
247    ctx: &RegexpContext,
248) -> Result<Box<str>> {
249    regexp_replace(text, start, Some(n), ctx)
250}
251
252#[function(
253    // regexp_replace(source, pattern, replacement, start, N, flags)
254    "regexp_replace(varchar, varchar, varchar, int4, int4, varchar) -> varchar",
255    prebuild = "RegexpContext::from_pattern_replacement_flags($1, $2, $5)?"
256)]
257fn regexp_replace_with_start_n_flags(
258    text: &str,
259    start: i32,
260    n: i32,
261    ctx: &RegexpContext,
262) -> Result<Box<str>> {
263    regexp_replace(text, start, Some(n), ctx)
264}
265
266// regexp_replace(source, pattern, replacement [, start [, N ]] [, flags ])
267fn regexp_replace(
268    text: &str,
269    start: i32,
270    n: Option<i32>, // `None` if not specified
271    ctx: &RegexpContext,
272) -> Result<Box<str>> {
273    // The start position to begin the search
274    let start = match start {
275        ..=0 => {
276            return Err(ExprError::InvalidParam {
277                name: "start",
278                reason: start.to_string().into(),
279            });
280        }
281        _ => start as usize - 1,
282    };
283
284    // This is because the source text may contain unicode
285    let start = match text.char_indices().nth(start) {
286        Some((idx, _)) => idx,
287        // With no match
288        None => return Ok(text.into()),
289    };
290
291    if n.is_none() && ctx.global || n == Some(0) {
292        // --------------------------------------------------------------
293        // `-g` enabled (& `N` is not specified) or `N` is `0`          |
294        // We need to replace all the occurrence of the matched pattern |
295        // --------------------------------------------------------------
296
297        // See if there is capture group or not
298        if ctx.regex.captures_len() <= 1 {
299            // There is no capture groups in the regex
300            // Just replace all matched patterns after `start`
301            Ok(format!(
302                "{}{}",
303                &text[..start],
304                ctx.regex.replace_all(&text[start..], &ctx.replacement)
305            )
306            .into())
307        } else {
308            // The position to start searching for replacement
309            let mut search_start = start;
310
311            // Construct the return string
312            let mut ret = text[..search_start].to_string();
313
314            // Begin the actual replace logic
315            while let Ok(Some(capture)) = ctx.regex.captures(&text[search_start..]) {
316                let match_start = capture.get(0).unwrap().start();
317                let match_end = capture.get(0).unwrap().end();
318
319                if match_start == match_end {
320                    // If this is an empty match
321                    search_start += 1;
322                    continue;
323                }
324
325                // Append the portion of the text from `search_start` to `match_start`
326                ret.push_str(&text[search_start..search_start + match_start]);
327
328                // Start to replacing
329                // Note that the result will be written directly to `ret` buffer
330                capture.expand(&ctx.replacement, &mut ret);
331
332                // Update the `search_start`
333                search_start += match_end;
334            }
335
336            // Push the rest of the text to return string
337            ret.push_str(&text[search_start..]);
338
339            Ok(ret.into())
340        }
341    } else {
342        // -------------------------------------------------
343        // Only replace the first matched pattern          |
344        // Or the N-th matched pattern if `N` is specified |
345        // -------------------------------------------------
346
347        // Construct the return string
348        let mut ret = if start > 1 {
349            text[..start].to_string()
350        } else {
351            "".to_owned()
352        };
353
354        // See if there is capture group or not
355        if ctx.regex.captures_len() <= 1 {
356            // There is no capture groups in the regex
357            if let Some(n) = n {
358                // Replace only the N-th match
359                let mut count = 1;
360                // The absolute index for the start of searching
361                let mut search_start = start;
362                while let Ok(Some(capture)) = ctx.regex.captures(&text[search_start..]) {
363                    // Get the current start & end index
364                    let match_start = capture.get(0).unwrap().start();
365                    let match_end = capture.get(0).unwrap().end();
366
367                    if count == n {
368                        // We've reached the pattern to replace
369                        // Let's construct the return string
370                        ret = format!(
371                            "{}{}{}",
372                            &text[..search_start + match_start],
373                            &ctx.replacement,
374                            &text[search_start + match_end..]
375                        );
376                        break;
377                    }
378
379                    // Update the counter
380                    count += 1;
381
382                    // Update `start`
383                    search_start += match_end;
384                }
385            } else {
386                // `N` is not specified
387                ret.push_str(&ctx.regex.replacen(&text[start..], 1, &ctx.replacement));
388            }
389        } else {
390            // There are capture groups in the regex
391            // Reset return string at the beginning
392            ret = "".to_owned();
393            if let Some(n) = n {
394                // Replace only the N-th match
395                let mut count = 1;
396                while let Ok(Some(capture)) = ctx.regex.captures(&text[start..]) {
397                    if count == n {
398                        // We've reached the pattern to replace
399                        let match_start = capture.get(0).unwrap().start();
400                        let match_end = capture.get(0).unwrap().end();
401
402                        // Get the replaced string and expand it
403                        capture.expand(&ctx.replacement, &mut ret);
404
405                        // Construct the return string
406                        ret = format!(
407                            "{}{}{}",
408                            &text[..start + match_start],
409                            ret,
410                            &text[start + match_end..]
411                        );
412                    }
413
414                    // Update the counter
415                    count += 1;
416                }
417
418                // If there is no match, just return the original string
419                if ret.is_empty() {
420                    ret = text.into();
421                }
422            } else {
423                // `N` is not specified
424                if let Ok(None) = ctx.regex.captures(&text[start..]) {
425                    // No match
426                    return Ok(text.into());
427                }
428
429                // Otherwise replace the source text
430                if let Ok(Some(capture)) = ctx.regex.captures(&text[start..]) {
431                    let match_start = capture.get(0).unwrap().start();
432                    let match_end = capture.get(0).unwrap().end();
433
434                    // Get the replaced string and expand it
435                    capture.expand(&ctx.replacement, &mut ret);
436
437                    // Construct the return string
438                    ret = format!(
439                        "{}{}{}",
440                        &text[..start + match_start],
441                        ret,
442                        &text[start + match_end..]
443                    );
444                }
445            }
446        }
447
448        Ok(ret.into())
449    }
450}
451
452#[function(
453    // regexp_split_to_array(source, pattern)
454    "regexp_split_to_array(varchar, varchar) -> varchar[]",
455    prebuild = "RegexpContext::from_pattern($1)?"
456)]
457#[function(
458    // regexp_split_to_array(source, pattern, flags)
459    "regexp_split_to_array(varchar, varchar, varchar) -> varchar[]",
460    prebuild = "RegexpContext::from_pattern_flags($1, $2)?"
461)]
462fn regexp_split_to_array(text: &str, regex: &RegexpContext) -> Option<ListValue> {
463    let n = text.len();
464    let mut start = 0;
465    let mut builder = Utf8ArrayBuilder::new(0);
466    let mut empty_flag = false;
467
468    loop {
469        if start >= n {
470            // Prevent overflow
471            break;
472        }
473
474        let capture = regex.regex.captures(&text[start..]).unwrap();
475
476        if capture.is_none() {
477            break;
478        }
479
480        let whole_match = capture.unwrap().get(0);
481        debug_assert!(whole_match.is_some(), "Expected `whole_match` to be valid");
482
483        let begin = whole_match.unwrap().start() + start;
484        let end = whole_match.unwrap().end() + start;
485
486        if begin == end {
487            // Empty match (i.e., `\s*`)
488            empty_flag = true;
489
490            if begin == text.len() {
491                // We do not need to push extra stuff to the result list
492                start = begin;
493                break;
494            }
495            builder.append(Some(&text[start..begin + 1]));
496            start = end + 1;
497            continue;
498        }
499
500        if start == begin {
501            // The before match is possibly empty
502            if !empty_flag {
503                // We'll push an empty string to conform with postgres
504                // If there does not exists a empty match before
505                builder.append(Some(""));
506            }
507            start = end;
508            continue;
509        }
510
511        if begin != 0 {
512            // Normal case
513            builder.append(Some(&text[start..begin]));
514        }
515
516        // We should update the `start` no matter `begin` is zero or not
517        start = end;
518    }
519
520    if start < n {
521        // Push the extra text to the list
522        // Note that this will implicitly push the entire text to the list
523        // If there is no match, which is the expected behavior
524        builder.append(Some(&text[start..]));
525    }
526
527    if start == n && !empty_flag {
528        builder.append(Some(""));
529    }
530
531    Some(ListValue::new(builder.finish().into()))
532}