risingwave_expr_impl/scalar/
substr.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::fmt::Write;
16
17use risingwave_expr::{ExprError, Result, function};
18
19#[function("substr(varchar, int4) -> varchar")]
20pub fn substr_start(s: &str, start: i32, writer: &mut impl Write) -> Result<()> {
21    let skip = start.saturating_sub(1).max(0) as usize;
22
23    let substr = s.chars().skip(skip);
24    for char in substr {
25        writer.write_char(char).unwrap();
26    }
27
28    Ok(())
29}
30
31#[function("substr(bytea, int4) -> bytea")]
32pub fn substr_start_bytea(s: &[u8], start: i32) -> Box<[u8]> {
33    let skip = start.saturating_sub(1).max(0) as usize;
34
35    s.iter().copied().skip(skip).collect()
36}
37
38fn convert_args(start: i32, count: i32) -> Result<(usize, usize)> {
39    if count < 0 {
40        return Err(ExprError::InvalidParam {
41            name: "length",
42            reason: "negative substring length not allowed".into(),
43        });
44    }
45
46    let skip = start.saturating_sub(1).max(0) as usize;
47    let take = if start >= 1 {
48        count as usize
49    } else {
50        count.saturating_add(start.saturating_sub(1)).max(0) as usize
51    };
52
53    // The returned args may still go out of bounds.
54    // So `skip` and `take` on iterator is safer than `[skip..(skip+take)]`
55    Ok((skip, take))
56}
57
58#[function("substr(varchar, int4, int4) -> varchar")]
59pub fn substr_start_for(s: &str, start: i32, count: i32, writer: &mut impl Write) -> Result<()> {
60    let (skip, take) = convert_args(start, count)?;
61
62    let substr = s.chars().skip(skip).take(take);
63    for char in substr {
64        writer.write_char(char).unwrap();
65    }
66
67    Ok(())
68}
69
70#[function("substr(bytea, int4, int4) -> bytea")]
71pub fn substr_start_for_bytea(s: &[u8], start: i32, count: i32) -> Result<Box<[u8]>> {
72    let (skip, take) = convert_args(start, count)?;
73
74    Ok(s.iter().copied().skip(skip).take(take).collect())
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_substr() -> Result<()> {
        let s = "cxscgccdd";
        let us = "上海自来水来自海上";

        // `None` exercises `substr_start`; `Some(len)` exercises `substr_start_for`.
        // Cases with a negative length must error, so their expected value is unused.
        let cases = [
            (s, 4, None, "cgccdd"),
            (s, 4, Some(-2), "[unused result]"),
            (s, 4, Some(2), "cg"),
            (s, -1, Some(-5), "[unused result]"),
            (s, -1, Some(0), ""),
            (s, -1, Some(1), ""),
            (s, -1, Some(2), ""),
            (s, -1, Some(3), "c"),
            (s, -1, Some(5), "cxs"),
            // Start at or before the first character
            (s, 0, None, "cxscgccdd"),
            (s, -5, None, "cxscgccdd"),
            (s, 0, Some(3), "cx"),
            // Start at or past the last character
            (s, 9, Some(1), "d"),
            (s, 10, None, ""),
            (s, 10, Some(3), ""),
            ("", 1, Some(5), ""),
            // Extreme arguments must neither overflow nor panic
            (s, i32::MAX, None, ""),
            (s, i32::MAX, Some(i32::MAX), ""),
            (s, 2, Some(i32::MAX), "xscgccdd"),
            (s, i32::MIN, None, "cxscgccdd"),
            (s, i32::MIN, Some(i32::MAX), ""),
            // Unicode test
            (us, 1, Some(3), "上海自"),
            (us, 3, Some(3), "自来水"),
            (us, 6, Some(2), "来自"),
            (us, 6, Some(100), "来自海上"),
            (us, 6, None, "来自海上"),
            (us, 0, Some(2), "上"),
            ("Mér", 1, Some(2), "Mé"),
        ];

        for (s, off, len, expected) in cases {
            let mut writer = String::new();
            match len {
                Some(len) if len < 0 => {
                    assert!(
                        substr_start_for(s, off, len, &mut writer).is_err(),
                        "negative length must be rejected: substr({s:?}, {off}, {len})"
                    );
                    continue;
                }
                Some(len) => substr_start_for(s, off, len, &mut writer)?,
                None => substr_start(s, off, &mut writer)?,
            }
            assert_eq!(writer, expected, "substr({s:?}, {off}, {len:?})");
        }
        Ok(())
    }

    #[test]
    fn test_substr_bytea() -> Result<()> {
        let b: &[u8] = b"cxscgccdd";

        assert_eq!(&*substr_start_bytea(b, 4), b"cgccdd");
        assert_eq!(&*substr_start_bytea(b, 0), b"cxscgccdd");
        assert_eq!(&*substr_start_bytea(b, i32::MIN), b"cxscgccdd");
        assert_eq!(&*substr_start_bytea(b, 10), b"");
        assert_eq!(&*substr_start_bytea(b, i32::MAX), b"");

        assert_eq!(&*substr_start_for_bytea(b, 4, 2)?, b"cg");
        assert_eq!(&*substr_start_for_bytea(b, -1, 3)?, b"c");
        assert_eq!(&*substr_start_for_bytea(b, -1, 1)?, b"");
        assert_eq!(&*substr_start_for_bytea(b, 2, i32::MAX)?, b"xscgccdd");
        assert_eq!(&*substr_start_for_bytea(b, i32::MIN, i32::MAX)?, b"");
        assert!(substr_start_for_bytea(b, 4, -2).is_err());

        // bytea positions count bytes, not characters: "é" is two bytes in UTF-8.
        assert_eq!(&*substr_start_for_bytea("Mér".as_bytes(), 1, 2)?, b"M\xc3");
        Ok(())
    }
}