risingwave_expr_impl/scalar/
trim.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Write;
16
17use risingwave_expr::function;
18
19#[function("trim(varchar) -> varchar")]
20pub fn trim(s: &str, writer: &mut impl Write) {
21    writer.write_str(s.trim()).unwrap();
22}
23
24/// Note: the behavior of `ltrim` in `PostgreSQL` and `trim_start` (or `trim_left`) in Rust
25/// are actually different when the string is in right-to-left languages like Arabic or Hebrew.
26/// Since we would like to simplify the implementation, currently we omit this case.
27#[function("ltrim(varchar) -> varchar")]
28pub fn ltrim(s: &str, writer: &mut impl Write) {
29    writer.write_str(s.trim_start()).unwrap();
30}
31
32/// Note: the behavior of `rtrim` in `PostgreSQL` and `trim_end` (or `trim_right`) in Rust
33/// are actually different when the string is in right-to-left languages like Arabic or Hebrew.
34/// Since we would like to simplify the implementation, currently we omit this case.
35#[function("rtrim(varchar) -> varchar")]
36pub fn rtrim(s: &str, writer: &mut impl Write) {
37    writer.write_str(s.trim_end()).unwrap();
38}
39
40#[function("trim(varchar, varchar) -> varchar")]
41pub fn trim_characters(s: &str, characters: &str, writer: &mut impl Write) {
42    let pattern = |c| characters.chars().any(|ch| ch == c);
43    // We remark that feeding a &str and a slice of chars into trim_left/right_matches
44    // means different, one is matching with the entire string and the other one is matching
45    // with any char in the slice.
46    writer.write_str(s.trim_matches(pattern)).unwrap();
47}
48
49#[function("ltrim(varchar, varchar) -> varchar")]
50pub fn ltrim_characters(s: &str, characters: &str, writer: &mut impl Write) {
51    let pattern = |c| characters.chars().any(|ch| ch == c);
52    writer.write_str(s.trim_start_matches(pattern)).unwrap();
53}
54
55#[function("rtrim(varchar, varchar) -> varchar")]
56pub fn rtrim_characters(s: &str, characters: &str, writer: &mut impl Write) {
57    let pattern = |c| characters.chars().any(|ch| ch == c);
58    writer.write_str(s.trim_end_matches(pattern)).unwrap();
59}
60
/// Computes how much of `bytes` survives stripping, from both ends, every byte that occurs
/// in `bytesremoved`.
///
/// Returns `(start, end)`:
/// - `start` is the number of leading removable bytes, i.e. the index of the first kept byte.
/// - `end` is `bytes.len()` minus the number of trailing removable bytes, i.e. one past the
///   last kept byte.
///
/// If every byte is removable, the result is `(bytes.len(), 0)`, so `start > end` for any
/// non-empty input. Callers must not slice `bytes[start..end]` unchecked in that case.
fn trim_bound(bytes: &[u8], bytesremoved: &[u8]) -> (usize, usize) {
    let is_removable = |b: &&u8| bytesremoved.contains(*b);

    let leading = bytes.iter().take_while(is_removable).count();
    let trailing = bytes.iter().rev().take_while(is_removable).count();

    // `trailing <= bytes.len()`, so this subtraction cannot underflow.
    (leading, bytes.len() - trailing)
}
77
78///  Removes the longest string containing only bytes appearing in bytesremoved from the start,
79///  end, or both ends (BOTH is the default) of bytes.
80///
81/// # Example
82///
83/// ```slt
84/// query T
85/// SELECT trim('\x9012'::bytea from '\x1234567890'::bytea);
86/// ----
87/// \x345678
88/// ```
89#[function("trim(bytea, bytea) -> bytea")]
90pub fn trim_bytea(bytes: &[u8], bytesremoved: &[u8]) -> Box<[u8]> {
91    let (start, end) = trim_bound(bytes, bytesremoved);
92
93    bytes
94        .get(start..end)
95        .map(|s| s.iter().copied().collect())
96        .unwrap_or_else(|| Box::<[u8]>::from([]))
97}
98
99/// Removes the longest string containing only bytes appearing in bytesremoved
100/// from the start of bytes.
101///
102/// # Example
103///
104/// ```slt
105/// query T
106/// SELECT ltrim('\x1234567890'::bytea, '\x9012'::bytea);
107/// ----
108/// \x34567890
109/// ```
110#[function("ltrim(bytea, bytea) -> bytea")]
111pub fn ltrim_bytea(bytes: &[u8], bytesremoved: &[u8]) -> Box<[u8]> {
112    let (start, _) = trim_bound(bytes, bytesremoved);
113    bytes[start..].iter().copied().collect()
114}
115
116/// Removes the longest string containing only bytes appearing in bytesremoved
117/// from the end of bytes.
118///
119/// # Example
120///
121/// ```slt
122/// query T
123/// SELECT rtrim('\x1234567890'::bytea, '\x9012'::bytea);
124/// ----
125/// \x12345678
126/// ```
127#[function("rtrim(bytea, bytea) -> bytea")]
128pub fn rtrim_bytea(bytes: &[u8], bytesremoved: &[u8]) -> Box<[u8]> {
129    let (_, end) = trim_bound(bytes, bytesremoved);
130    bytes[..end].iter().copied().collect()
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_trim() {
        let cases = [
            (" Hello\tworld\t", "Hello\tworld"),
            (" 空I ❤️ databases空 ", "空I ❤️ databases空"),
        ];

        for (s, expected) in cases {
            let mut writer = String::new();
            trim(s, &mut writer);
            assert_eq!(writer, expected);
        }
    }

    #[test]
    fn test_ltrim() {
        let cases = [
            (" \tHello\tworld\t", "Hello\tworld\t"),
            (" \t空I ❤️ databases空 ", "空I ❤️ databases空 "),
        ];

        for (s, expected) in cases {
            let mut writer = String::new();
            ltrim(s, &mut writer);
            assert_eq!(writer, expected);
        }
    }

    #[test]
    fn test_rtrim() {
        let cases = [
            (" \tHello\tworld\t ", " \tHello\tworld"),
            (" \t空I ❤️ databases空\t ", " \t空I ❤️ databases空"),
        ];

        for (s, expected) in cases {
            let mut writer = String::new();
            rtrim(s, &mut writer);
            assert_eq!(writer, expected);
        }
    }

    #[test]
    fn test_trim_characters() {
        let cases = [
            ("Hello world", "Hdl", "ello wor"),
            ("abcde", "aae", "bcd"),
            ("zxy", "yxz", ""),
            // An empty character set removes nothing.
            ("  abc  ", "", "  abc  "),
            // Multibyte characters are matched as whole `char`s.
            ("空abc空", "空", "abc"),
            // `characters` is a set, not a substring to strip.
            ("xxaxyxx", "xy", "a"),
        ];

        for (s, characters, expected) in cases {
            let mut writer = String::new();
            trim_characters(s, characters, &mut writer);
            assert_eq!(writer, expected);
        }
    }

    #[test]
    fn test_ltrim_characters() {
        let cases = [
            ("Hello world", "Hdl", "ello world"),
            ("abcde", "aae", "bcde"),
            ("zxy", "yxz", ""),
            ("  abc  ", "", "  abc  "),
            ("空abc空", "空", "abc空"),
            ("xxaxyxx", "xy", "axyxx"),
        ];

        for (s, characters, expected) in cases {
            let mut writer = String::new();
            ltrim_characters(s, characters, &mut writer);
            assert_eq!(writer, expected);
        }
    }

    #[test]
    fn test_rtrim_characters() {
        let cases = [
            ("Hello world", "Hdl", "Hello wor"),
            ("abcde", "aae", "abcd"),
            ("zxy", "yxz", ""),
            ("  abc  ", "", "  abc  "),
            ("空abc空", "空", "空abc"),
            ("xxaxyxx", "xy", "xxa"),
        ];

        for (s, characters, expected) in cases {
            let mut writer = String::new();
            rtrim_characters(s, characters, &mut writer);
            assert_eq!(writer, expected);
        }
    }

    #[test]
    fn test_trim_bytea() {
        assert_eq!(
            &*trim_bytea(b"\x12\x34\x56\x78\x90", b"\x90\x12"),
            b"\x34\x56\x78"
        );
        // Every byte is removable: `trim_bound` yields `start > end`; result must be empty.
        assert_eq!(&*trim_bytea(b"\x01\x02\x01", b"\x01\x02"), b"");
        assert_eq!(&*trim_bytea(b"", b"\x01"), b"");
        assert_eq!(&*trim_bytea(b"\x01\x02", b""), b"\x01\x02");
    }

    #[test]
    fn test_ltrim_bytea() {
        assert_eq!(
            &*ltrim_bytea(b"\x12\x34\x56\x78\x90", b"\x90\x12"),
            b"\x34\x56\x78\x90"
        );
        assert_eq!(&*ltrim_bytea(b"\x01\x02\x01", b"\x01\x02"), b"");
        assert_eq!(&*ltrim_bytea(b"", b"\x01"), b"");
        // Only the start is trimmed; a removable trailing byte is kept.
        assert_eq!(&*ltrim_bytea(b"\x01\x02\x03", b"\x03"), b"\x01\x02\x03");
    }

    #[test]
    fn test_rtrim_bytea() {
        assert_eq!(
            &*rtrim_bytea(b"\x12\x34\x56\x78\x90", b"\x90\x12"),
            b"\x12\x34\x56\x78"
        );
        assert_eq!(&*rtrim_bytea(b"\x01\x02\x01", b"\x01\x02"), b"");
        assert_eq!(&*rtrim_bytea(b"", b"\x01"), b"");
        // Only the end is trimmed; a removable leading byte is kept.
        assert_eq!(&*rtrim_bytea(b"\x01\x02\x03", b"\x01"), b"\x01\x02\x03");
    }
}