risingwave_expr_impl/scalar/
split_part.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_expr::{ExprError, Result, function};
16
17#[function("split_part(varchar, varchar, int4) -> varchar")]
18pub fn split_part(
19    string_expr: &str,
20    delimiter_expr: &str,
21    nth_expr: i32,
22    writer: &mut impl std::fmt::Write,
23) -> Result<()> {
24    if nth_expr == 0 {
25        return Err(ExprError::InvalidParam {
26            name: "data",
27            reason: "can't be zero".into(),
28        });
29    };
30
31    let mut split = string_expr.split(delimiter_expr);
32    let nth_val = if string_expr.is_empty() {
33        // postgres: return empty string for empty input string
34        Default::default()
35    } else if delimiter_expr.is_empty() {
36        // postgres: handle empty field separator
37        //           if first or last field, return input string, else empty string
38        if nth_expr == 1 || nth_expr == -1 {
39            string_expr
40        } else {
41            Default::default()
42        }
43    } else {
44        match nth_expr.cmp(&0) {
45            std::cmp::Ordering::Equal => unreachable!(),
46
47            // Since `nth_expr` can not be 0, so the `abs()` of it can not be smaller than 1
48            // (that's `abs(1)` or `abs(-1)`).  Hence the result of sub 1 can not be less than 0.
49            // postgres: if nonexistent field, return empty string
50            std::cmp::Ordering::Greater => split.nth(nth_expr as usize - 1).unwrap_or_default(),
51            std::cmp::Ordering::Less => {
52                let split = split.collect::<Vec<_>>();
53                split
54                    .iter()
55                    .rev()
56                    .nth(nth_expr.unsigned_abs() as usize - 1)
57                    .cloned()
58                    .unwrap_or_default()
59            }
60        }
61    };
62    writer.write_str(nth_val).unwrap();
63    Ok(())
64}
65
#[cfg(test)]
mod tests {
    use super::split_part;

    /// Table-driven check of `split_part`.
    ///
    /// Each row is `(string, delimiter, nth, expected)`, where an `expected` of `None`
    /// means the call must return an error.
    #[test]
    fn test_split_part() {
        let cases: &[(&str, &str, i32, Option<&str>)] = &[
            // postgres cases
            ("", "@", 1, Some("")),
            ("", "@", -1, Some("")),
            ("joeuser@mydatabase", "", 1, Some("joeuser@mydatabase")),
            ("joeuser@mydatabase", "", 2, Some("")),
            ("joeuser@mydatabase", "", -1, Some("joeuser@mydatabase")),
            ("joeuser@mydatabase", "", -2, Some("")),
            ("joeuser@mydatabase", "@", 0, None),
            ("joeuser@mydatabase", "@@", 1, Some("joeuser@mydatabase")),
            ("joeuser@mydatabase", "@@", 2, Some("")),
            ("joeuser@mydatabase", "@", 1, Some("joeuser")),
            ("joeuser@mydatabase", "@", 2, Some("mydatabase")),
            ("joeuser@mydatabase", "@", 3, Some("")),
            ("@joeuser@mydatabase@", "@", 2, Some("joeuser")),
            ("joeuser@mydatabase", "@", -1, Some("mydatabase")),
            ("joeuser@mydatabase", "@", -2, Some("joeuser")),
            ("joeuser@mydatabase", "@", -3, Some("")),
            ("@joeuser@mydatabase@", "@", -2, Some("mydatabase")),
            // other cases

            // guards against `rsplit` being used internally for a negative `nth`
            ("@@@", "@@", -1, Some("@")),
        ];

        for (i, case) in cases.iter().enumerate() {
            let (input, delimiter, nth, expected) = *case;

            // Every case writes into its own fresh buffer.
            let mut buffer = String::new();
            let outcome = split_part(input, delimiter, nth, &mut buffer);
            let actual = outcome.ok().map(|()| buffer.as_str());

            assert_eq!(actual, expected, "\nat case {i}: {:?}\n", case);
        }
    }
}