risingwave_expr_impl/scalar/
substr.rs1use risingwave_expr::{ExprError, Result, function};
16
17#[function("substr(varchar, int4) -> varchar")]
18pub fn substr_start(s: &str, start: i32, writer: &mut impl std::fmt::Write) -> Result<()> {
19 let skip = start.saturating_sub(1).max(0) as usize;
20
21 let substr = s.chars().skip(skip);
22 for char in substr {
23 writer.write_char(char).unwrap();
24 }
25
26 Ok(())
27}
28
29#[function("substr(bytea, int4) -> bytea")]
30pub fn substr_start_bytea(s: &[u8], start: i32, writer: &mut impl std::io::Write) {
31 let skip = start.saturating_sub(1).max(0) as usize;
32 if skip >= s.len() {
33 return;
34 }
35 writer.write_all(&s[skip..]).unwrap();
36}
37
38fn convert_args(start: i32, count: i32) -> Result<(usize, usize)> {
39 if count < 0 {
40 return Err(ExprError::InvalidParam {
41 name: "length",
42 reason: "negative substring length not allowed".into(),
43 });
44 }
45
46 let skip = start.saturating_sub(1).max(0) as usize;
47 let take = if start >= 1 {
48 count as usize
49 } else {
50 count.saturating_add(start.saturating_sub(1)).max(0) as usize
51 };
52
53 Ok((skip, take))
56}
57
58#[function("substr(varchar, int4, int4) -> varchar")]
59pub fn substr_start_for(
60 s: &str,
61 start: i32,
62 count: i32,
63 writer: &mut impl std::fmt::Write,
64) -> Result<()> {
65 let (skip, take) = convert_args(start, count)?;
66
67 let substr = s.chars().skip(skip).take(take);
68 for char in substr {
69 writer.write_char(char).unwrap();
70 }
71
72 Ok(())
73}
74
75#[function("substr(bytea, int4, int4) -> bytea")]
76pub fn substr_start_for_bytea(
77 s: &[u8],
78 start: i32,
79 count: i32,
80 writer: &mut impl std::io::Write,
81) -> Result<()> {
82 let (skip, take) = convert_args(start, count)?;
83
84 if skip >= s.len() {
85 return Ok(());
86 }
87 let end = (skip + take).min(s.len());
88 writer.write_all(&s[skip..end]).unwrap();
89 Ok(())
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_substr() -> Result<()> {
        let s = "cxscgccdd";
        let us = "上海自来水来自海上";

        let cases = [
            (s, 4, None, "cgccdd"),
            (s, 4, Some(-2), "[unused result]"),
            (s, 4, Some(2), "cg"),
            (s, -1, Some(-5), "[unused result]"),
            (s, -1, Some(0), ""),
            (s, -1, Some(1), ""),
            (s, -1, Some(2), ""),
            (s, -1, Some(3), "c"),
            (s, -1, Some(5), "cxs"),
            // A start position past the last character yields an empty string.
            (s, 100, None, ""),
            (s, 100, Some(2), ""),
            (us, 1, Some(3), "上海自"),
            (us, 3, Some(3), "自来水"),
            (us, 6, Some(2), "来自"),
            (us, 6, Some(100), "来自海上"),
            (us, 6, None, "来自海上"),
            ("Mér", 1, Some(2), "Mé"),
        ];

        for (s, off, len, expected) in cases {
            let mut writer = String::new();
            match len {
                Some(len) => {
                    let result = substr_start_for(s, off, len, &mut writer);
                    if len < 0 {
                        // Negative lengths are rejected; `expected` is unused here.
                        assert!(result.is_err());
                        continue;
                    }
                    result?;
                }
                None => substr_start(s, off, &mut writer)?,
            }
            assert_eq!(writer, expected);
        }
        Ok(())
    }

    #[test]
    fn test_substr_bytea() -> Result<()> {
        let bytes: &[u8] = b"abcdef";

        let cases: [(&[u8], i32, Option<i32>, &[u8]); 10] = [
            (bytes, 3, None, b"cdef"),
            (bytes, -5, None, b"abcdef"),
            (bytes, 7, None, b""),
            (bytes, 100, None, b""),
            (bytes, 2, Some(3), b"bcd"),
            (bytes, -1, Some(3), b"a"),
            (bytes, 0, Some(1), b""),
            (bytes, 5, Some(100), b"ef"),
            (bytes, 10, Some(2), b""),
            (bytes, 1, Some(-1), b"[unused result]"),
        ];

        for (input, off, len, expected) in cases {
            let mut writer = Vec::new();
            match len {
                Some(len) => {
                    let result = substr_start_for_bytea(input, off, len, &mut writer);
                    if len < 0 {
                        // Negative lengths are rejected; `expected` is unused here.
                        assert!(result.is_err());
                        continue;
                    }
                    result?;
                }
                None => substr_start_bytea(input, off, &mut writer),
            }
            assert_eq!(writer, expected);
        }
        Ok(())
    }
}