risingwave_expr_impl/scalar/
regexp.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Regular expression functions.
16
17use std::str::FromStr;
18
19use fancy_regex::{Regex, RegexBuilder};
20use risingwave_common::types::ScalarRefImpl;
21use risingwave_expr::{ExprError, Result, bail, function};
22use thiserror_ext::AsReport;
23
/// Pre-built state shared by the regexp functions in this module.
///
/// Instances are created through the `from_pattern*` constructors, which are
/// referenced from the `prebuild` attribute of each `#[function]` below.
#[derive(Debug)]
pub struct RegexpContext {
    /// The compiled pattern (wrapped in `(?i:...)` when the `i` flag is given).
    pub regex: Regex,
    /// Whether the `g` flag was given.
    pub global: bool,
    /// The replacement template, already converted from the `\n` / `\&`
    /// notation into the `${n}` notation by `make_replacement`.
    /// Empty for constructors that do not take a replacement.
    pub replacement: String,
}
30
31impl RegexpContext {
32    fn new(pattern: &str, flags: &str, replacement: &str) -> Result<Self> {
33        let options = RegexpOptions::from_str(flags)?;
34
35        let origin = if options.case_insensitive {
36            format!("(?i:{})", pattern)
37        } else {
38            pattern.to_owned()
39        };
40
41        Ok(Self {
42            regex: RegexBuilder::new(&origin)
43                .build()
44                .map_err(|e| ExprError::Parse(e.to_report_string().into()))?,
45            global: options.global,
46            replacement: make_replacement(replacement),
47        })
48    }
49
50    pub fn from_pattern(pattern: &str) -> Result<Self> {
51        Self::new(pattern, "", "")
52    }
53
54    pub fn from_pattern_flags(pattern: &str, flags: &str) -> Result<Self> {
55        Self::new(pattern, flags, "")
56    }
57
58    pub fn from_pattern_flags_for_count(pattern: &str, flags: &str) -> Result<Self> {
59        if flags.contains('g') {
60            bail!("regexp_count() does not support the global option");
61        }
62        Self::new(pattern, flags, "")
63    }
64
65    pub fn from_pattern_replacement(pattern: &str, replacement: &str) -> Result<Self> {
66        Self::new(pattern, "", replacement)
67    }
68
69    pub fn from_pattern_replacement_flags(
70        pattern: &str,
71        replacement: &str,
72        flags: &str,
73    ) -> Result<Self> {
74        Self::new(pattern, flags, replacement)
75    }
76}
77
/// Translate a PostgreSQL-style replacement string into the `$`-based template
/// understood by `Captures::expand`.
/// <https://docs.rs/regex/latest/regex/struct.Captures.html#method.expand>
///
/// ```text
/// \&  -> ${0}   (the whole match)
/// \1  -> ${1}
/// ...
/// \9  -> ${9}
/// \\  -> \      (a single literal backslash)
/// \x  -> \x     (any other escaped character is kept verbatim)
/// ```
///
/// A trailing lone backslash is kept as ordinary data. Every character that is
/// not part of a backslash sequence is copied unchanged.
///
/// NOTE(review): a literal `$` in the input is copied unchanged as well, so it
/// presumably remains meaningful to the expander, unlike in PostgreSQL where
/// `$` is plain text. Confirm whether it should be escaped.
fn make_replacement(s: &str) -> String {
    let mut ret = String::with_capacity(s.len());
    let mut chars = s.chars();
    while let Some(c) = chars.next() {
        if c != '\\' {
            ret.push(c);
            continue;
        }
        // `c` is a backslash: what follows decides how it is interpreted.
        match chars.next() {
            Some('&') => ret.push_str("${0}"),
            Some(digit @ '1'..='9') => {
                ret.push_str("${");
                ret.push(digit);
                ret.push('}');
            }
            // `\\` stands for exactly one backslash in the output.
            Some('\\') => ret.push('\\'),
            // Unknown escape: keep both the backslash and the character.
            Some(other) => {
                ret.push('\\');
                ret.push(other);
            }
            // Trailing backslash: treat it as ordinary data.
            None => ret.push('\\'),
        }
    }
    ret
}
105
/// Options parsed from the `flags` argument of the regexp functions.
///
/// Only the subset handled by the `FromStr` implementation (`c`, `i`, `g`) is
/// supported; see the PostgreSQL option table for their meaning:
/// <https://www.postgresql.org/docs/current/functions-matching.html#POSIX-EMBEDDED-OPTIONS-TABLE>
#[derive(Default, Debug)]
struct RegexpOptions {
    /// Controlled by `c` (clears it) and `i` (sets it); `false` by default,
    /// i.e. matching is case-sensitive unless `i` is given.
    case_insensitive: bool,
    /// Set by `g`; `false` by default.
    global: bool,
}
114
115impl FromStr for RegexpOptions {
116    type Err = ExprError;
117
118    fn from_str(s: &str) -> Result<Self> {
119        let mut opts = Self::default();
120        for c in s.chars() {
121            match c {
122                // Case sensitive matching here
123                'c' => opts.case_insensitive = false,
124                // Case insensitive matching here
125                'i' => opts.case_insensitive = true,
126                // Global matching here
127                'g' => opts.global = true,
128                _ => {
129                    bail!("invalid regular expression option: \"{c}\"");
130                }
131            }
132        }
133        Ok(opts)
134    }
135}
136
137#[function(
138    // source ~ pattern
139    "regexp_eq(varchar, varchar) -> boolean",
140    prebuild = "RegexpContext::from_pattern($1)?"
141)]
142fn regexp_eq(text: &str, regex: &RegexpContext) -> bool {
143    regex.regex.is_match(text).unwrap()
144}
145
146#[function(
147    // regexp_match(source, pattern)
148    "regexp_match(varchar, varchar) -> varchar[]",
149    prebuild = "RegexpContext::from_pattern($1)?"
150)]
151#[function(
152    // regexp_match(source, pattern, flags)
153    "regexp_match(varchar, varchar, varchar) -> varchar[]",
154    prebuild = "RegexpContext::from_pattern_flags($1, $2)?"
155)]
156fn regexp_match(
157    text: &str,
158    regex: &RegexpContext,
159    writer: &mut impl risingwave_common::array::ListWrite,
160) -> Option<()> {
161    // If there are multiple captures, then the first one is the whole match, and should be
162    // ignored in PostgreSQL's behavior.
163    let skip_first = regex.regex.captures_len() > 1;
164    let capture = regex.regex.captures(text).unwrap()?;
165    let iter = capture
166        .iter()
167        .skip(if skip_first { 1 } else { 0 })
168        .map(|mat| mat.map(|m| ScalarRefImpl::Utf8(m.as_str())));
169    writer.write_iter(iter);
170    Some(())
171}
172
173#[function(
174    // regexp_count(source, pattern)
175    "regexp_count(varchar, varchar) -> int4",
176    prebuild = "RegexpContext::from_pattern($1)?"
177)]
178fn regexp_count_start0(text: &str, regex: &RegexpContext) -> Result<i32> {
179    regexp_count(text, 1, regex)
180}
181
182#[function(
183    // regexp_count(source, pattern, start)
184    "regexp_count(varchar, varchar, int4) -> int4",
185    prebuild = "RegexpContext::from_pattern($1)?"
186)]
187#[function(
188    // regexp_count(source, pattern, start, flags)
189    "regexp_count(varchar, varchar, int4, varchar) -> int4",
190    prebuild = "RegexpContext::from_pattern_flags_for_count($1, $3)?"
191)]
192fn regexp_count(text: &str, start: i32, regex: &RegexpContext) -> Result<i32> {
193    // First get the start position to count for
194    let start = match start {
195        ..=0 => {
196            return Err(ExprError::InvalidParam {
197                name: "start",
198                reason: start.to_string().into(),
199            });
200        }
201        _ => start as usize - 1,
202    };
203
204    // Find the start byte index considering the unicode
205    let mut start = match text.char_indices().nth(start) {
206        Some((idx, _)) => idx,
207        // The `start` is out of bound
208        None => return Ok(0),
209    };
210
211    let mut count = 0;
212    while let Ok(Some(captures)) = regex.regex.captures(&text[start..]) {
213        count += 1;
214        start += captures.get(0).unwrap().end();
215    }
216    Ok(count)
217}
218
219#[function(
220    // regexp_replace(source, pattern, replacement)
221    "regexp_replace(varchar, varchar, varchar) -> varchar",
222    prebuild = "RegexpContext::from_pattern_replacement($1, $2)?"
223)]
224#[function(
225    // regexp_replace(source, pattern, replacement, flags)
226    "regexp_replace(varchar, varchar, varchar, varchar) -> varchar",
227    prebuild = "RegexpContext::from_pattern_replacement_flags($1, $2, $3)?"
228)]
229fn regexp_replace0(text: &str, ctx: &RegexpContext) -> Result<Box<str>> {
230    regexp_replace(text, 1, None, ctx)
231}
232
233#[function(
234    // regexp_replace(source, pattern, replacement, start)
235    "regexp_replace(varchar, varchar, varchar, int4) -> varchar",
236    prebuild = "RegexpContext::from_pattern_replacement($1, $2)?"
237)]
238fn regexp_replace_with_start(text: &str, start: i32, ctx: &RegexpContext) -> Result<Box<str>> {
239    regexp_replace(text, start, None, ctx)
240}
241
242#[function(
243    // regexp_replace(source, pattern, replacement, start, N)
244    "regexp_replace(varchar, varchar, varchar, int4, int4) -> varchar",
245    prebuild = "RegexpContext::from_pattern_replacement($1, $2)?"
246)]
247fn regexp_replace_with_start_n(
248    text: &str,
249    start: i32,
250    n: i32,
251    ctx: &RegexpContext,
252) -> Result<Box<str>> {
253    regexp_replace(text, start, Some(n), ctx)
254}
255
256#[function(
257    // regexp_replace(source, pattern, replacement, start, N, flags)
258    "regexp_replace(varchar, varchar, varchar, int4, int4, varchar) -> varchar",
259    prebuild = "RegexpContext::from_pattern_replacement_flags($1, $2, $5)?"
260)]
261fn regexp_replace_with_start_n_flags(
262    text: &str,
263    start: i32,
264    n: i32,
265    ctx: &RegexpContext,
266) -> Result<Box<str>> {
267    regexp_replace(text, start, Some(n), ctx)
268}
269
270// regexp_replace(source, pattern, replacement [, start [, N ]] [, flags ])
271fn regexp_replace(
272    text: &str,
273    start: i32,
274    n: Option<i32>, // `None` if not specified
275    ctx: &RegexpContext,
276) -> Result<Box<str>> {
277    // The start position to begin the search
278    let start = match start {
279        ..=0 => {
280            return Err(ExprError::InvalidParam {
281                name: "start",
282                reason: start.to_string().into(),
283            });
284        }
285        _ => start as usize - 1,
286    };
287
288    // This is because the source text may contain unicode
289    let start = match text.char_indices().nth(start) {
290        Some((idx, _)) => idx,
291        // With no match
292        None => return Ok(text.into()),
293    };
294
295    if n.is_none() && ctx.global || n == Some(0) {
296        // --------------------------------------------------------------
297        // `-g` enabled (& `N` is not specified) or `N` is `0`          |
298        // We need to replace all the occurrence of the matched pattern |
299        // --------------------------------------------------------------
300
301        // See if there is capture group or not
302        if ctx.regex.captures_len() <= 1 {
303            // There is no capture groups in the regex
304            // Just replace all matched patterns after `start`
305            Ok(format!(
306                "{}{}",
307                &text[..start],
308                ctx.regex.replace_all(&text[start..], &ctx.replacement)
309            )
310            .into())
311        } else {
312            // The position to start searching for replacement
313            let mut search_start = start;
314
315            // Construct the return string
316            let mut ret = text[..search_start].to_string();
317
318            // Begin the actual replace logic
319            while let Ok(Some(capture)) = ctx.regex.captures(&text[search_start..]) {
320                let match_start = capture.get(0).unwrap().start();
321                let match_end = capture.get(0).unwrap().end();
322
323                if match_start == match_end {
324                    // If this is an empty match
325                    search_start += 1;
326                    continue;
327                }
328
329                // Append the portion of the text from `search_start` to `match_start`
330                ret.push_str(&text[search_start..search_start + match_start]);
331
332                // Start to replacing
333                // Note that the result will be written directly to `ret` buffer
334                capture.expand(&ctx.replacement, &mut ret);
335
336                // Update the `search_start`
337                search_start += match_end;
338            }
339
340            // Push the rest of the text to return string
341            ret.push_str(&text[search_start..]);
342
343            Ok(ret.into())
344        }
345    } else {
346        // -------------------------------------------------
347        // Only replace the first matched pattern          |
348        // Or the N-th matched pattern if `N` is specified |
349        // -------------------------------------------------
350
351        // Construct the return string
352        let mut ret = if start > 1 {
353            text[..start].to_string()
354        } else {
355            "".to_owned()
356        };
357
358        // See if there is capture group or not
359        if ctx.regex.captures_len() <= 1 {
360            // There is no capture groups in the regex
361            if let Some(n) = n {
362                // Replace only the N-th match
363                let mut count = 1;
364                // The absolute index for the start of searching
365                let mut search_start = start;
366                while let Ok(Some(capture)) = ctx.regex.captures(&text[search_start..]) {
367                    // Get the current start & end index
368                    let match_start = capture.get(0).unwrap().start();
369                    let match_end = capture.get(0).unwrap().end();
370
371                    if count == n {
372                        // We've reached the pattern to replace
373                        // Let's construct the return string
374                        ret = format!(
375                            "{}{}{}",
376                            &text[..search_start + match_start],
377                            &ctx.replacement,
378                            &text[search_start + match_end..]
379                        );
380                        break;
381                    }
382
383                    // Update the counter
384                    count += 1;
385
386                    // Update `start`
387                    search_start += match_end;
388                }
389            } else {
390                // `N` is not specified
391                ret.push_str(&ctx.regex.replacen(&text[start..], 1, &ctx.replacement));
392            }
393        } else {
394            // There are capture groups in the regex
395            // Reset return string at the beginning
396            ret = "".to_owned();
397            if let Some(n) = n {
398                // Replace only the N-th match
399                let mut count = 1;
400                while let Ok(Some(capture)) = ctx.regex.captures(&text[start..]) {
401                    if count == n {
402                        // We've reached the pattern to replace
403                        let match_start = capture.get(0).unwrap().start();
404                        let match_end = capture.get(0).unwrap().end();
405
406                        // Get the replaced string and expand it
407                        capture.expand(&ctx.replacement, &mut ret);
408
409                        // Construct the return string
410                        ret = format!(
411                            "{}{}{}",
412                            &text[..start + match_start],
413                            ret,
414                            &text[start + match_end..]
415                        );
416                    }
417
418                    // Update the counter
419                    count += 1;
420                }
421
422                // If there is no match, just return the original string
423                if ret.is_empty() {
424                    ret = text.into();
425                }
426            } else {
427                // `N` is not specified
428                if let Ok(None) = ctx.regex.captures(&text[start..]) {
429                    // No match
430                    return Ok(text.into());
431                }
432
433                // Otherwise replace the source text
434                if let Ok(Some(capture)) = ctx.regex.captures(&text[start..]) {
435                    let match_start = capture.get(0).unwrap().start();
436                    let match_end = capture.get(0).unwrap().end();
437
438                    // Get the replaced string and expand it
439                    capture.expand(&ctx.replacement, &mut ret);
440
441                    // Construct the return string
442                    ret = format!(
443                        "{}{}{}",
444                        &text[..start + match_start],
445                        ret,
446                        &text[start + match_end..]
447                    );
448                }
449            }
450        }
451
452        Ok(ret.into())
453    }
454}
455
456#[function(
457    // regexp_split_to_array(source, pattern)
458    "regexp_split_to_array(varchar, varchar) -> varchar[]",
459    prebuild = "RegexpContext::from_pattern($1)?"
460)]
461#[function(
462    // regexp_split_to_array(source, pattern, flags)
463    "regexp_split_to_array(varchar, varchar, varchar) -> varchar[]",
464    prebuild = "RegexpContext::from_pattern_flags($1, $2)?"
465)]
466fn regexp_split_to_array(
467    text: &str,
468    regex: &RegexpContext,
469    writer: &mut impl risingwave_common::array::ListWrite,
470) -> Option<()> {
471    let n = text.len();
472    let mut start = 0;
473    let mut empty_flag = false;
474
475    loop {
476        if start >= n {
477            // Prevent overflow
478            break;
479        }
480
481        let capture = regex.regex.captures(&text[start..]).unwrap();
482
483        if capture.is_none() {
484            break;
485        }
486
487        let whole_match = capture.unwrap().get(0);
488        debug_assert!(whole_match.is_some(), "Expected `whole_match` to be valid");
489
490        let begin = whole_match.unwrap().start() + start;
491        let end = whole_match.unwrap().end() + start;
492
493        if begin == end {
494            // Empty match (i.e., `\s*`)
495            empty_flag = true;
496
497            if begin == text.len() {
498                // We do not need to push extra stuff to the result list
499                start = begin;
500                break;
501            }
502            writer.write(Some(ScalarRefImpl::Utf8(&text[start..begin + 1])));
503            start = end + 1;
504            continue;
505        }
506
507        if start == begin {
508            // The before match is possibly empty
509            if !empty_flag {
510                // We'll push an empty string to conform with postgres
511                // If there does not exists a empty match before
512                writer.write(Some(ScalarRefImpl::Utf8("")));
513            }
514            start = end;
515            continue;
516        }
517
518        if begin != 0 {
519            // Normal case
520            writer.write(Some(ScalarRefImpl::Utf8(&text[start..begin])));
521        }
522
523        // We should update the `start` no matter `begin` is zero or not
524        start = end;
525    }
526
527    if start < n {
528        // Push the extra text to the list
529        // Note that this will implicitly push the entire text to the list
530        // If there is no match, which is the expected behavior
531        writer.write(Some(ScalarRefImpl::Utf8(&text[start..])));
532    }
533
534    if start == n && !empty_flag {
535        writer.write(Some(ScalarRefImpl::Utf8("")));
536    }
537
538    Some(())
539}