risingwave_expr_impl/scalar/
regexp.rs1use std::str::FromStr;
18
19use fancy_regex::{Regex, RegexBuilder};
20use risingwave_common::types::ScalarRefImpl;
21use risingwave_expr::{ExprError, Result, bail, function};
22use thiserror_ext::AsReport;
23
/// Pre-compiled state shared by the regexp functions in this file; constructed by the
/// `prebuild` expressions of their `#[function]` attributes.
#[derive(Debug)]
pub struct RegexpContext {
    /// The compiled pattern, wrapped in `(?i:...)` when the `i` flag was given.
    pub regex: Regex,
    /// Whether the `g` (global) flag was given.
    pub global: bool,
    /// Replacement template after translation by `make_replacement` (PostgreSQL `\1`/`\&`
    /// escapes rewritten to `${1}`/`${0}`); empty for functions that do not replace.
    pub replacement: String,
}
30
31impl RegexpContext {
32 fn new(pattern: &str, flags: &str, replacement: &str) -> Result<Self> {
33 let options = RegexpOptions::from_str(flags)?;
34
35 let origin = if options.case_insensitive {
36 format!("(?i:{})", pattern)
37 } else {
38 pattern.to_owned()
39 };
40
41 Ok(Self {
42 regex: RegexBuilder::new(&origin)
43 .build()
44 .map_err(|e| ExprError::Parse(e.to_report_string().into()))?,
45 global: options.global,
46 replacement: make_replacement(replacement),
47 })
48 }
49
50 pub fn from_pattern(pattern: &str) -> Result<Self> {
51 Self::new(pattern, "", "")
52 }
53
54 pub fn from_pattern_flags(pattern: &str, flags: &str) -> Result<Self> {
55 Self::new(pattern, flags, "")
56 }
57
58 pub fn from_pattern_flags_for_count(pattern: &str, flags: &str) -> Result<Self> {
59 if flags.contains('g') {
60 bail!("regexp_count() does not support the global option");
61 }
62 Self::new(pattern, flags, "")
63 }
64
65 pub fn from_pattern_replacement(pattern: &str, replacement: &str) -> Result<Self> {
66 Self::new(pattern, "", replacement)
67 }
68
69 pub fn from_pattern_replacement_flags(
70 pattern: &str,
71 replacement: &str,
72 flags: &str,
73 ) -> Result<Self> {
74 Self::new(pattern, flags, replacement)
75 }
76}
77
/// Translates a PostgreSQL `regexp_replace` replacement string into the `$`-based template
/// syntax understood by `fancy_regex` (`Captures::expand`, `Regex::replace*`).
///
/// ```text
/// \&          -> ${0}        whole match
/// \1 .. \9    -> ${1}..${9}  braces keep a following digit literal: `\10` -> `${1}0`
/// \\          -> \           escaped backslash, as in PostgreSQL
/// $           -> $$          `$` is literal in PostgreSQL but special to the expander
/// \<other>    -> \<other>    unknown escapes are copied verbatim (`\$` -> `\$$`)
/// trailing \  -> \
/// ```
fn make_replacement(s: &str) -> String {
    use std::fmt::Write;

    let mut ret = String::with_capacity(s.len());
    let mut chars = s.chars();
    while let Some(c) = chars.next() {
        match c {
            // A bare `$` would otherwise start a group reference in the expander.
            '$' => ret.push_str("$$"),
            '\\' => match chars.next() {
                Some('&') => ret.push_str("${0}"),
                Some(d @ '1'..='9') => write!(ret, "${{{d}}}").unwrap(),
                // PostgreSQL: `\\` transfers a single backslash to the output.
                Some('\\') => ret.push('\\'),
                // Unknown escape: keep the backslash; the `$` itself is still literal.
                Some('$') => ret.push_str("\\$$"),
                Some(other) => {
                    ret.push('\\');
                    ret.push(other);
                }
                None => ret.push('\\'),
            },
            _ => ret.push(c),
        }
    }
    ret
}
105
/// Option letters accepted in the `flags` argument of the regexp functions in this file.
#[derive(Default, Debug)]
struct RegexpOptions {
    /// Set by `i`, cleared by `c`; when both appear, the last letter wins.
    case_insensitive: bool,
    /// Set by `g`.
    global: bool,
}
114
115impl FromStr for RegexpOptions {
116 type Err = ExprError;
117
118 fn from_str(s: &str) -> Result<Self> {
119 let mut opts = Self::default();
120 for c in s.chars() {
121 match c {
122 'c' => opts.case_insensitive = false,
124 'i' => opts.case_insensitive = true,
126 'g' => opts.global = true,
128 _ => {
129 bail!("invalid regular expression option: \"{c}\"");
130 }
131 }
132 }
133 Ok(opts)
134 }
135}
136
137#[function(
138 "regexp_eq(varchar, varchar) -> boolean",
140 prebuild = "RegexpContext::from_pattern($1)?"
141)]
142fn regexp_eq(text: &str, regex: &RegexpContext) -> bool {
143 regex.regex.is_match(text).unwrap()
144}
145
146#[function(
147 "regexp_match(varchar, varchar) -> varchar[]",
149 prebuild = "RegexpContext::from_pattern($1)?"
150)]
151#[function(
152 "regexp_match(varchar, varchar, varchar) -> varchar[]",
154 prebuild = "RegexpContext::from_pattern_flags($1, $2)?"
155)]
156fn regexp_match(
157 text: &str,
158 regex: &RegexpContext,
159 writer: &mut impl risingwave_common::array::ListWrite,
160) -> Option<()> {
161 let skip_first = regex.regex.captures_len() > 1;
164 let capture = regex.regex.captures(text).unwrap()?;
165 let iter = capture
166 .iter()
167 .skip(if skip_first { 1 } else { 0 })
168 .map(|mat| mat.map(|m| ScalarRefImpl::Utf8(m.as_str())));
169 writer.write_iter(iter);
170 Some(())
171}
172
173#[function(
174 "regexp_count(varchar, varchar) -> int4",
176 prebuild = "RegexpContext::from_pattern($1)?"
177)]
178fn regexp_count_start0(text: &str, regex: &RegexpContext) -> Result<i32> {
179 regexp_count(text, 1, regex)
180}
181
182#[function(
183 "regexp_count(varchar, varchar, int4) -> int4",
185 prebuild = "RegexpContext::from_pattern($1)?"
186)]
187#[function(
188 "regexp_count(varchar, varchar, int4, varchar) -> int4",
190 prebuild = "RegexpContext::from_pattern_flags_for_count($1, $3)?"
191)]
192fn regexp_count(text: &str, start: i32, regex: &RegexpContext) -> Result<i32> {
193 let start = match start {
195 ..=0 => {
196 return Err(ExprError::InvalidParam {
197 name: "start",
198 reason: start.to_string().into(),
199 });
200 }
201 _ => start as usize - 1,
202 };
203
204 let mut start = match text.char_indices().nth(start) {
206 Some((idx, _)) => idx,
207 None => return Ok(0),
209 };
210
211 let mut count = 0;
212 while let Ok(Some(captures)) = regex.regex.captures(&text[start..]) {
213 count += 1;
214 start += captures.get(0).unwrap().end();
215 }
216 Ok(count)
217}
218
219#[function(
220 "regexp_replace(varchar, varchar, varchar) -> varchar",
222 prebuild = "RegexpContext::from_pattern_replacement($1, $2)?"
223)]
224#[function(
225 "regexp_replace(varchar, varchar, varchar, varchar) -> varchar",
227 prebuild = "RegexpContext::from_pattern_replacement_flags($1, $2, $3)?"
228)]
229fn regexp_replace0(text: &str, ctx: &RegexpContext) -> Result<Box<str>> {
230 regexp_replace(text, 1, None, ctx)
231}
232
233#[function(
234 "regexp_replace(varchar, varchar, varchar, int4) -> varchar",
236 prebuild = "RegexpContext::from_pattern_replacement($1, $2)?"
237)]
238fn regexp_replace_with_start(text: &str, start: i32, ctx: &RegexpContext) -> Result<Box<str>> {
239 regexp_replace(text, start, None, ctx)
240}
241
242#[function(
243 "regexp_replace(varchar, varchar, varchar, int4, int4) -> varchar",
245 prebuild = "RegexpContext::from_pattern_replacement($1, $2)?"
246)]
247fn regexp_replace_with_start_n(
248 text: &str,
249 start: i32,
250 n: i32,
251 ctx: &RegexpContext,
252) -> Result<Box<str>> {
253 regexp_replace(text, start, Some(n), ctx)
254}
255
256#[function(
257 "regexp_replace(varchar, varchar, varchar, int4, int4, varchar) -> varchar",
259 prebuild = "RegexpContext::from_pattern_replacement_flags($1, $2, $5)?"
260)]
261fn regexp_replace_with_start_n_flags(
262 text: &str,
263 start: i32,
264 n: i32,
265 ctx: &RegexpContext,
266) -> Result<Box<str>> {
267 regexp_replace(text, start, Some(n), ctx)
268}
269
270fn regexp_replace(
272 text: &str,
273 start: i32,
274 n: Option<i32>, ctx: &RegexpContext,
276) -> Result<Box<str>> {
277 let start = match start {
279 ..=0 => {
280 return Err(ExprError::InvalidParam {
281 name: "start",
282 reason: start.to_string().into(),
283 });
284 }
285 _ => start as usize - 1,
286 };
287
288 let start = match text.char_indices().nth(start) {
290 Some((idx, _)) => idx,
291 None => return Ok(text.into()),
293 };
294
295 if n.is_none() && ctx.global || n == Some(0) {
296 if ctx.regex.captures_len() <= 1 {
303 Ok(format!(
306 "{}{}",
307 &text[..start],
308 ctx.regex.replace_all(&text[start..], &ctx.replacement)
309 )
310 .into())
311 } else {
312 let mut search_start = start;
314
315 let mut ret = text[..search_start].to_string();
317
318 while let Ok(Some(capture)) = ctx.regex.captures(&text[search_start..]) {
320 let match_start = capture.get(0).unwrap().start();
321 let match_end = capture.get(0).unwrap().end();
322
323 if match_start == match_end {
324 search_start += 1;
326 continue;
327 }
328
329 ret.push_str(&text[search_start..search_start + match_start]);
331
332 capture.expand(&ctx.replacement, &mut ret);
335
336 search_start += match_end;
338 }
339
340 ret.push_str(&text[search_start..]);
342
343 Ok(ret.into())
344 }
345 } else {
346 let mut ret = if start > 1 {
353 text[..start].to_string()
354 } else {
355 "".to_owned()
356 };
357
358 if ctx.regex.captures_len() <= 1 {
360 if let Some(n) = n {
362 let mut count = 1;
364 let mut search_start = start;
366 while let Ok(Some(capture)) = ctx.regex.captures(&text[search_start..]) {
367 let match_start = capture.get(0).unwrap().start();
369 let match_end = capture.get(0).unwrap().end();
370
371 if count == n {
372 ret = format!(
375 "{}{}{}",
376 &text[..search_start + match_start],
377 &ctx.replacement,
378 &text[search_start + match_end..]
379 );
380 break;
381 }
382
383 count += 1;
385
386 search_start += match_end;
388 }
389 } else {
390 ret.push_str(&ctx.regex.replacen(&text[start..], 1, &ctx.replacement));
392 }
393 } else {
394 ret = "".to_owned();
397 if let Some(n) = n {
398 let mut count = 1;
400 while let Ok(Some(capture)) = ctx.regex.captures(&text[start..]) {
401 if count == n {
402 let match_start = capture.get(0).unwrap().start();
404 let match_end = capture.get(0).unwrap().end();
405
406 capture.expand(&ctx.replacement, &mut ret);
408
409 ret = format!(
411 "{}{}{}",
412 &text[..start + match_start],
413 ret,
414 &text[start + match_end..]
415 );
416 }
417
418 count += 1;
420 }
421
422 if ret.is_empty() {
424 ret = text.into();
425 }
426 } else {
427 if let Ok(None) = ctx.regex.captures(&text[start..]) {
429 return Ok(text.into());
431 }
432
433 if let Ok(Some(capture)) = ctx.regex.captures(&text[start..]) {
435 let match_start = capture.get(0).unwrap().start();
436 let match_end = capture.get(0).unwrap().end();
437
438 capture.expand(&ctx.replacement, &mut ret);
440
441 ret = format!(
443 "{}{}{}",
444 &text[..start + match_start],
445 ret,
446 &text[start + match_end..]
447 );
448 }
449 }
450 }
451
452 Ok(ret.into())
453 }
454}
455
456#[function(
457 "regexp_split_to_array(varchar, varchar) -> varchar[]",
459 prebuild = "RegexpContext::from_pattern($1)?"
460)]
461#[function(
462 "regexp_split_to_array(varchar, varchar, varchar) -> varchar[]",
464 prebuild = "RegexpContext::from_pattern_flags($1, $2)?"
465)]
466fn regexp_split_to_array(
467 text: &str,
468 regex: &RegexpContext,
469 writer: &mut impl risingwave_common::array::ListWrite,
470) -> Option<()> {
471 let n = text.len();
472 let mut start = 0;
473 let mut empty_flag = false;
474
475 loop {
476 if start >= n {
477 break;
479 }
480
481 let capture = regex.regex.captures(&text[start..]).unwrap();
482
483 if capture.is_none() {
484 break;
485 }
486
487 let whole_match = capture.unwrap().get(0);
488 debug_assert!(whole_match.is_some(), "Expected `whole_match` to be valid");
489
490 let begin = whole_match.unwrap().start() + start;
491 let end = whole_match.unwrap().end() + start;
492
493 if begin == end {
494 empty_flag = true;
496
497 if begin == text.len() {
498 start = begin;
500 break;
501 }
502 writer.write(Some(ScalarRefImpl::Utf8(&text[start..begin + 1])));
503 start = end + 1;
504 continue;
505 }
506
507 if start == begin {
508 if !empty_flag {
510 writer.write(Some(ScalarRefImpl::Utf8("")));
513 }
514 start = end;
515 continue;
516 }
517
518 if begin != 0 {
519 writer.write(Some(ScalarRefImpl::Utf8(&text[start..begin])));
521 }
522
523 start = end;
525 }
526
527 if start < n {
528 writer.write(Some(ScalarRefImpl::Utf8(&text[start..])));
532 }
533
534 if start == n && !empty_flag {
535 writer.write(Some(ScalarRefImpl::Utf8("")));
536 }
537
538 Some(())
539}