risingwave_expr_impl/scalar/
trim.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_expr::function;
16
17#[function("trim(varchar) -> varchar")]
18pub fn trim(s: &str, writer: &mut impl std::fmt::Write) {
19    writer.write_str(s.trim()).unwrap();
20}
21
22/// Note: the behavior of `ltrim` in `PostgreSQL` and `trim_start` (or `trim_left`) in Rust
23/// are actually different when the string is in right-to-left languages like Arabic or Hebrew.
24/// Since we would like to simplify the implementation, currently we omit this case.
25#[function("ltrim(varchar) -> varchar")]
26pub fn ltrim(s: &str, writer: &mut impl std::fmt::Write) {
27    writer.write_str(s.trim_start()).unwrap();
28}
29
30/// Note: the behavior of `rtrim` in `PostgreSQL` and `trim_end` (or `trim_right`) in Rust
31/// are actually different when the string is in right-to-left languages like Arabic or Hebrew.
32/// Since we would like to simplify the implementation, currently we omit this case.
33#[function("rtrim(varchar) -> varchar")]
34pub fn rtrim(s: &str, writer: &mut impl std::fmt::Write) {
35    writer.write_str(s.trim_end()).unwrap();
36}
37
38#[function("trim(varchar, varchar) -> varchar")]
39pub fn trim_characters(s: &str, characters: &str, writer: &mut impl std::fmt::Write) {
40    let pattern = |c| characters.chars().any(|ch| ch == c);
41    // We remark that feeding a &str and a slice of chars into trim_left/right_matches
42    // means different, one is matching with the entire string and the other one is matching
43    // with any char in the slice.
44    writer.write_str(s.trim_matches(pattern)).unwrap();
45}
46
47#[function("ltrim(varchar, varchar) -> varchar")]
48pub fn ltrim_characters(s: &str, characters: &str, writer: &mut impl std::fmt::Write) {
49    let pattern = |c| characters.chars().any(|ch| ch == c);
50    writer.write_str(s.trim_start_matches(pattern)).unwrap();
51}
52
53#[function("rtrim(varchar, varchar) -> varchar")]
54pub fn rtrim_characters(s: &str, characters: &str, writer: &mut impl std::fmt::Write) {
55    let pattern = |c| characters.chars().any(|ch| ch == c);
56    writer.write_str(s.trim_end_matches(pattern)).unwrap();
57}
58
/// Computes the `(start, end)` bounds of `bytes` once every leading and trailing byte that
/// occurs in `bytesremoved` has been stripped.
///
/// * `start` is the index of the first byte that is not in `bytesremoved`, or `bytes.len()`
///   when there is no such byte.
/// * `end` is one past the index of the last byte that is not in `bytesremoved`, or `0` when
///   there is no such byte.
///
/// When a non-empty `bytes` consists solely of removable bytes the result is
/// `(bytes.len(), 0)`, i.e. `start > end`, so callers must not index with `start..end`
/// unchecked.
fn trim_bound(bytes: &[u8], bytesremoved: &[u8]) -> (usize, usize) {
    // Length of the removable run at the front.
    let leading = bytes
        .iter()
        .take_while(|&b| bytesremoved.contains(b))
        .count();

    // Length of the removable run at the back. The two runs are counted independently, so
    // they overlap completely when every byte is removable.
    let trailing = bytes
        .iter()
        .rev()
        .take_while(|&b| bytesremoved.contains(b))
        .count();

    (leading, bytes.len() - trailing)
}
75
76///  Removes the longest string containing only bytes appearing in bytesremoved from the start,
77///  end, or both ends (BOTH is the default) of bytes.
78///
79/// # Example
80///
81/// ```slt
82/// query T
83/// SELECT trim('\x9012'::bytea from '\x1234567890'::bytea);
84/// ----
85/// \x345678
86/// ```
87#[function("trim(bytea, bytea) -> bytea")]
88pub fn trim_bytea(bytes: &[u8], bytesremoved: &[u8], writer: &mut impl std::io::Write) {
89    let (start, end) = trim_bound(bytes, bytesremoved);
90
91    if let Some(slice) = bytes.get(start..end) {
92        writer.write_all(slice).unwrap();
93    }
94}
95
96/// Removes the longest string containing only bytes appearing in bytesremoved
97/// from the start of bytes.
98///
99/// # Example
100///
101/// ```slt
102/// query T
103/// SELECT ltrim('\x1234567890'::bytea, '\x9012'::bytea);
104/// ----
105/// \x34567890
106/// ```
107#[function("ltrim(bytea, bytea) -> bytea")]
108pub fn ltrim_bytea(bytes: &[u8], bytesremoved: &[u8], writer: &mut impl std::io::Write) {
109    let (start, _) = trim_bound(bytes, bytesremoved);
110    writer.write_all(&bytes[start..]).unwrap();
111}
112
113/// Removes the longest string containing only bytes appearing in bytesremoved
114/// from the end of bytes.
115///
116/// # Example
117///
118/// ```slt
119/// query T
120/// SELECT rtrim('\x1234567890'::bytea, '\x9012'::bytea);
121/// ----
122/// \x12345678
123/// ```
124#[function("rtrim(bytea, bytea) -> bytea")]
125pub fn rtrim_bytea(bytes: &[u8], bytesremoved: &[u8], writer: &mut impl std::io::Write) {
126    let (_, end) = trim_bound(bytes, bytesremoved);
127    writer.write_all(&bytes[..end]).unwrap();
128}
129
#[cfg(test)]
mod tests {
    use hex_literal::hex;

    use super::*;

    /// One bytea test case: `(input, bytes to strip, expected output)`.
    type ByteaCase<'a> = (&'a [u8], &'a [u8], &'a [u8]);

    /// Runs `call` against a fresh `String` and returns everything it wrote.
    fn written_str(call: impl FnOnce(&mut String)) -> String {
        let mut out = String::new();
        call(&mut out);
        out
    }

    /// Runs `call` against a fresh byte buffer and returns everything it wrote.
    fn written_bytes(call: impl FnOnce(&mut Vec<u8>)) -> Vec<u8> {
        let mut out = Vec::new();
        call(&mut out);
        out
    }

    // Whitespace trimming (single-argument varchar functions).

    #[test]
    fn test_trim() {
        let cases = [
            (" Hello\tworld\t", "Hello\tworld"),
            (" 空I ❤️ databases空 ", "空I ❤️ databases空"),
        ];

        for (input, want) in cases {
            assert_eq!(written_str(|w| trim(input, w)), want);
        }
    }

    #[test]
    fn test_ltrim() {
        let cases = [
            (" \tHello\tworld\t", "Hello\tworld\t"),
            (" \t空I ❤️ databases空 ", "空I ❤️ databases空 "),
        ];

        for (input, want) in cases {
            assert_eq!(written_str(|w| ltrim(input, w)), want);
        }
    }

    #[test]
    fn test_rtrim() {
        let cases = [
            (" \tHello\tworld\t ", " \tHello\tworld"),
            (" \t空I ❤️ databases空\t ", " \t空I ❤️ databases空"),
        ];

        for (input, want) in cases {
            assert_eq!(written_str(|w| rtrim(input, w)), want);
        }
    }

    // Character-set trimming (two-argument varchar functions).

    #[test]
    fn test_trim_characters() {
        let cases = [
            ("Hello world", "Hdl", "ello wor"),
            ("abcde", "aae", "bcd"),
            ("zxy", "yxz", ""),
        ];

        for (input, set, want) in cases {
            assert_eq!(written_str(|w| trim_characters(input, set, w)), want);
        }
    }

    #[test]
    fn test_ltrim_characters() {
        let cases = [
            ("Hello world", "Hdl", "ello world"),
            ("abcde", "aae", "bcde"),
            ("zxy", "yxz", ""),
        ];

        for (input, set, want) in cases {
            assert_eq!(written_str(|w| ltrim_characters(input, set, w)), want);
        }
    }

    #[test]
    fn test_rtrim_characters() {
        let cases = [
            ("Hello world", "Hdl", "Hello wor"),
            ("abcde", "aae", "abcd"),
            ("zxy", "yxz", ""),
        ];

        for (input, set, want) in cases {
            assert_eq!(written_str(|w| rtrim_characters(input, set, w)), want);
        }
    }

    // Byte-set trimming (bytea functions). The last case of each test strips everything.

    #[test]
    fn test_trim_bytea() {
        let cases: [ByteaCase<'_>; 3] = [
            (&hex!("1234567890"), &hex!("9012"), &hex!("345678")),
            (&hex!("abcdef"), &hex!("00abcf"), &hex!("cdef")),
            (&hex!("11112222"), &hex!("1122"), b""),
        ];

        for (input, removed, want) in cases {
            assert_eq!(written_bytes(|w| trim_bytea(input, removed, w)), want);
        }
    }

    #[test]
    fn test_ltrim_bytea() {
        let cases: [ByteaCase<'_>; 3] = [
            (&hex!("1234567890"), &hex!("9012"), &hex!("34567890")),
            (&hex!("abcdef"), &hex!("00abcf"), &hex!("cdef")),
            (&hex!("11112222"), &hex!("1122"), b""),
        ];

        for (input, removed, want) in cases {
            assert_eq!(written_bytes(|w| ltrim_bytea(input, removed, w)), want);
        }
    }

    #[test]
    fn test_rtrim_bytea() {
        let cases: [ByteaCase<'_>; 3] = [
            (&hex!("1234567890"), &hex!("9012"), &hex!("12345678")),
            (&hex!("abcdef"), &hex!("00abcf"), &hex!("abcdef")),
            (&hex!("11112222"), &hex!("1122"), b""),
        ];

        for (input, removed, want) in cases {
            assert_eq!(written_bytes(|w| rtrim_bytea(input, removed, w)), want);
        }
    }
}