risingwave_expr_impl/scalar/
regexp.rs1use std::str::FromStr;
18
19use fancy_regex::{Regex, RegexBuilder};
20use risingwave_common::array::{ArrayBuilder, ListValue, Utf8Array, Utf8ArrayBuilder};
21use risingwave_expr::{ExprError, Result, bail, function};
22use thiserror_ext::AsReport;
23
24#[derive(Debug)]
25pub struct RegexpContext {
26 pub regex: Regex,
27 pub global: bool,
28 pub replacement: String,
29}
30
31impl RegexpContext {
32 fn new(pattern: &str, flags: &str, replacement: &str) -> Result<Self> {
33 let options = RegexpOptions::from_str(flags)?;
34
35 let origin = if options.case_insensitive {
36 format!("(?i:{})", pattern)
37 } else {
38 pattern.to_owned()
39 };
40
41 Ok(Self {
42 regex: RegexBuilder::new(&origin)
43 .build()
44 .map_err(|e| ExprError::Parse(e.to_report_string().into()))?,
45 global: options.global,
46 replacement: make_replacement(replacement),
47 })
48 }
49
50 pub fn from_pattern(pattern: &str) -> Result<Self> {
51 Self::new(pattern, "", "")
52 }
53
54 pub fn from_pattern_flags(pattern: &str, flags: &str) -> Result<Self> {
55 Self::new(pattern, flags, "")
56 }
57
58 pub fn from_pattern_flags_for_count(pattern: &str, flags: &str) -> Result<Self> {
59 if flags.contains('g') {
60 bail!("regexp_count() does not support the global option");
61 }
62 Self::new(pattern, flags, "")
63 }
64
65 pub fn from_pattern_replacement(pattern: &str, replacement: &str) -> Result<Self> {
66 Self::new(pattern, "", replacement)
67 }
68
69 pub fn from_pattern_replacement_flags(
70 pattern: &str,
71 replacement: &str,
72 flags: &str,
73 ) -> Result<Self> {
74 Self::new(pattern, flags, replacement)
75 }
76}
77
/// Translates a PostgreSQL `regexp_replace` replacement string into the
/// `$`-based expansion template understood by `fancy_regex`
/// (`Captures::expand`, `Regex::replace*`).
///
/// * `\&` becomes `${0}` (the whole match).
/// * `\1` ..= `\9` become `${1}` ..= `${9}` (capture groups).
/// * `\\` becomes a single literal backslash.
/// * A literal `$` is escaped as `$$`, so text such as `$1` or `${x}` in the
///   replacement is copied verbatim instead of being read as a group reference.
/// * A backslash before any other character, or at the very end, is kept and
///   the following character is treated as ordinary text.
///
/// The result is a template: it must be expanded against a match, never
/// inserted into the output verbatim.
fn make_replacement(s: &str) -> String {
    let mut ret = String::with_capacity(s.len());
    let mut chars = s.chars();
    while let Some(c) = chars.next() {
        match c {
            // A bare `$` would start a group reference in the template syntax.
            '$' => ret.push_str("$$"),
            '\\' => match chars.next() {
                Some('&') => ret.push_str("${0}"),
                Some(digit @ '1'..='9') => {
                    ret.push_str("${");
                    ret.push(digit);
                    ret.push('}');
                }
                Some('\\') => ret.push('\\'),
                // Unrecognized escape: keep the backslash; the `$` that follows
                // is ordinary text and therefore still needs escaping.
                Some('$') => ret.push_str("\\$$"),
                Some(other) => {
                    ret.push('\\');
                    ret.push(other);
                }
                // Trailing lone backslash is copied as-is.
                None => ret.push('\\'),
            },
            other => ret.push(other),
        }
    }
    ret
}
105
106#[derive(Default, Debug)]
108struct RegexpOptions {
109 case_insensitive: bool,
111 global: bool,
113}
114
115impl FromStr for RegexpOptions {
116 type Err = ExprError;
117
118 fn from_str(s: &str) -> Result<Self> {
119 let mut opts = Self::default();
120 for c in s.chars() {
121 match c {
122 'c' => opts.case_insensitive = false,
124 'i' => opts.case_insensitive = true,
126 'g' => opts.global = true,
128 _ => {
129 bail!("invalid regular expression option: \"{c}\"");
130 }
131 }
132 }
133 Ok(opts)
134 }
135}
136
137#[function(
138 "regexp_eq(varchar, varchar) -> boolean",
140 prebuild = "RegexpContext::from_pattern($1)?"
141)]
142fn regexp_eq(text: &str, regex: &RegexpContext) -> bool {
143 regex.regex.is_match(text).unwrap()
144}
145
146#[function(
147 "regexp_match(varchar, varchar) -> varchar[]",
149 prebuild = "RegexpContext::from_pattern($1)?"
150)]
151#[function(
152 "regexp_match(varchar, varchar, varchar) -> varchar[]",
154 prebuild = "RegexpContext::from_pattern_flags($1, $2)?"
155)]
156fn regexp_match(text: &str, regex: &RegexpContext) -> Option<ListValue> {
157 let skip_first = regex.regex.captures_len() > 1;
160 let capture = regex.regex.captures(text).unwrap()?;
161 let list = capture
162 .iter()
163 .skip(if skip_first { 1 } else { 0 })
164 .map(|mat| mat.map(|m| m.as_str()))
165 .collect::<Utf8Array>();
166 Some(ListValue::new(list.into()))
167}
168
169#[function(
170 "regexp_count(varchar, varchar) -> int4",
172 prebuild = "RegexpContext::from_pattern($1)?"
173)]
174fn regexp_count_start0(text: &str, regex: &RegexpContext) -> Result<i32> {
175 regexp_count(text, 1, regex)
176}
177
178#[function(
179 "regexp_count(varchar, varchar, int4) -> int4",
181 prebuild = "RegexpContext::from_pattern($1)?"
182)]
183#[function(
184 "regexp_count(varchar, varchar, int4, varchar) -> int4",
186 prebuild = "RegexpContext::from_pattern_flags_for_count($1, $3)?"
187)]
188fn regexp_count(text: &str, start: i32, regex: &RegexpContext) -> Result<i32> {
189 let start = match start {
191 ..=0 => {
192 return Err(ExprError::InvalidParam {
193 name: "start",
194 reason: start.to_string().into(),
195 });
196 }
197 _ => start as usize - 1,
198 };
199
200 let mut start = match text.char_indices().nth(start) {
202 Some((idx, _)) => idx,
203 None => return Ok(0),
205 };
206
207 let mut count = 0;
208 while let Ok(Some(captures)) = regex.regex.captures(&text[start..]) {
209 count += 1;
210 start += captures.get(0).unwrap().end();
211 }
212 Ok(count)
213}
214
215#[function(
216 "regexp_replace(varchar, varchar, varchar) -> varchar",
218 prebuild = "RegexpContext::from_pattern_replacement($1, $2)?"
219)]
220#[function(
221 "regexp_replace(varchar, varchar, varchar, varchar) -> varchar",
223 prebuild = "RegexpContext::from_pattern_replacement_flags($1, $2, $3)?"
224)]
225fn regexp_replace0(text: &str, ctx: &RegexpContext) -> Result<Box<str>> {
226 regexp_replace(text, 1, None, ctx)
227}
228
229#[function(
230 "regexp_replace(varchar, varchar, varchar, int4) -> varchar",
232 prebuild = "RegexpContext::from_pattern_replacement($1, $2)?"
233)]
234fn regexp_replace_with_start(text: &str, start: i32, ctx: &RegexpContext) -> Result<Box<str>> {
235 regexp_replace(text, start, None, ctx)
236}
237
238#[function(
239 "regexp_replace(varchar, varchar, varchar, int4, int4) -> varchar",
241 prebuild = "RegexpContext::from_pattern_replacement($1, $2)?"
242)]
243fn regexp_replace_with_start_n(
244 text: &str,
245 start: i32,
246 n: i32,
247 ctx: &RegexpContext,
248) -> Result<Box<str>> {
249 regexp_replace(text, start, Some(n), ctx)
250}
251
252#[function(
253 "regexp_replace(varchar, varchar, varchar, int4, int4, varchar) -> varchar",
255 prebuild = "RegexpContext::from_pattern_replacement_flags($1, $2, $5)?"
256)]
257fn regexp_replace_with_start_n_flags(
258 text: &str,
259 start: i32,
260 n: i32,
261 ctx: &RegexpContext,
262) -> Result<Box<str>> {
263 regexp_replace(text, start, Some(n), ctx)
264}
265
266fn regexp_replace(
268 text: &str,
269 start: i32,
270 n: Option<i32>, ctx: &RegexpContext,
272) -> Result<Box<str>> {
273 let start = match start {
275 ..=0 => {
276 return Err(ExprError::InvalidParam {
277 name: "start",
278 reason: start.to_string().into(),
279 });
280 }
281 _ => start as usize - 1,
282 };
283
284 let start = match text.char_indices().nth(start) {
286 Some((idx, _)) => idx,
287 None => return Ok(text.into()),
289 };
290
291 if n.is_none() && ctx.global || n == Some(0) {
292 if ctx.regex.captures_len() <= 1 {
299 Ok(format!(
302 "{}{}",
303 &text[..start],
304 ctx.regex.replace_all(&text[start..], &ctx.replacement)
305 )
306 .into())
307 } else {
308 let mut search_start = start;
310
311 let mut ret = text[..search_start].to_string();
313
314 while let Ok(Some(capture)) = ctx.regex.captures(&text[search_start..]) {
316 let match_start = capture.get(0).unwrap().start();
317 let match_end = capture.get(0).unwrap().end();
318
319 if match_start == match_end {
320 search_start += 1;
322 continue;
323 }
324
325 ret.push_str(&text[search_start..search_start + match_start]);
327
328 capture.expand(&ctx.replacement, &mut ret);
331
332 search_start += match_end;
334 }
335
336 ret.push_str(&text[search_start..]);
338
339 Ok(ret.into())
340 }
341 } else {
342 let mut ret = if start > 1 {
349 text[..start].to_string()
350 } else {
351 "".to_owned()
352 };
353
354 if ctx.regex.captures_len() <= 1 {
356 if let Some(n) = n {
358 let mut count = 1;
360 let mut search_start = start;
362 while let Ok(Some(capture)) = ctx.regex.captures(&text[search_start..]) {
363 let match_start = capture.get(0).unwrap().start();
365 let match_end = capture.get(0).unwrap().end();
366
367 if count == n {
368 ret = format!(
371 "{}{}{}",
372 &text[..search_start + match_start],
373 &ctx.replacement,
374 &text[search_start + match_end..]
375 );
376 break;
377 }
378
379 count += 1;
381
382 search_start += match_end;
384 }
385 } else {
386 ret.push_str(&ctx.regex.replacen(&text[start..], 1, &ctx.replacement));
388 }
389 } else {
390 ret = "".to_owned();
393 if let Some(n) = n {
394 let mut count = 1;
396 while let Ok(Some(capture)) = ctx.regex.captures(&text[start..]) {
397 if count == n {
398 let match_start = capture.get(0).unwrap().start();
400 let match_end = capture.get(0).unwrap().end();
401
402 capture.expand(&ctx.replacement, &mut ret);
404
405 ret = format!(
407 "{}{}{}",
408 &text[..start + match_start],
409 ret,
410 &text[start + match_end..]
411 );
412 }
413
414 count += 1;
416 }
417
418 if ret.is_empty() {
420 ret = text.into();
421 }
422 } else {
423 if let Ok(None) = ctx.regex.captures(&text[start..]) {
425 return Ok(text.into());
427 }
428
429 if let Ok(Some(capture)) = ctx.regex.captures(&text[start..]) {
431 let match_start = capture.get(0).unwrap().start();
432 let match_end = capture.get(0).unwrap().end();
433
434 capture.expand(&ctx.replacement, &mut ret);
436
437 ret = format!(
439 "{}{}{}",
440 &text[..start + match_start],
441 ret,
442 &text[start + match_end..]
443 );
444 }
445 }
446 }
447
448 Ok(ret.into())
449 }
450}
451
452#[function(
453 "regexp_split_to_array(varchar, varchar) -> varchar[]",
455 prebuild = "RegexpContext::from_pattern($1)?"
456)]
457#[function(
458 "regexp_split_to_array(varchar, varchar, varchar) -> varchar[]",
460 prebuild = "RegexpContext::from_pattern_flags($1, $2)?"
461)]
462fn regexp_split_to_array(text: &str, regex: &RegexpContext) -> Option<ListValue> {
463 let n = text.len();
464 let mut start = 0;
465 let mut builder = Utf8ArrayBuilder::new(0);
466 let mut empty_flag = false;
467
468 loop {
469 if start >= n {
470 break;
472 }
473
474 let capture = regex.regex.captures(&text[start..]).unwrap();
475
476 if capture.is_none() {
477 break;
478 }
479
480 let whole_match = capture.unwrap().get(0);
481 debug_assert!(whole_match.is_some(), "Expected `whole_match` to be valid");
482
483 let begin = whole_match.unwrap().start() + start;
484 let end = whole_match.unwrap().end() + start;
485
486 if begin == end {
487 empty_flag = true;
489
490 if begin == text.len() {
491 start = begin;
493 break;
494 }
495 builder.append(Some(&text[start..begin + 1]));
496 start = end + 1;
497 continue;
498 }
499
500 if start == begin {
501 if !empty_flag {
503 builder.append(Some(""));
506 }
507 start = end;
508 continue;
509 }
510
511 if begin != 0 {
512 builder.append(Some(&text[start..begin]));
514 }
515
516 start = end;
518 }
519
520 if start < n {
521 builder.append(Some(&text[start..]));
525 }
526
527 if start == n && !empty_flag {
528 builder.append(Some(""));
529 }
530
531 Some(ListValue::new(builder.finish().into()))
532}