risingwave_expr_impl/scalar/
string.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! String functions
16//!
17//! <https://www.postgresql.org/docs/current/functions-string.html>
18
19use std::fmt::Write;
20
21use risingwave_expr::function;
22
23/// Returns the character with the specified Unicode code point.
24///
25/// # Example
26///
27/// ```slt
28/// query T
29/// select chr(65);
30/// ----
31/// A
32/// ```
33#[function("chr(int4) -> varchar")]
34pub fn chr(code: i32, writer: &mut impl Write) {
35    if let Some(c) = std::char::from_u32(code as u32) {
36        write!(writer, "{}", c).unwrap();
37    }
38}
39
40/// Returns true if the given string starts with the specified prefix.
41///
42/// # Example
43///
44/// ```slt
45/// query T
46/// select starts_with('abcdef', 'abc');
47/// ----
48/// t
49///
50/// query T
51/// select 'abcdef' ^@ 'abc';
52/// ----
53/// t
54///
55/// query T
56/// select 'abcdef' ^@ some(array['x', 'a', 't']);
57/// ----
58/// t
59/// ```
60#[function("starts_with(varchar, varchar) -> boolean")]
61pub fn starts_with(s: &str, prefix: &str) -> bool {
62    s.starts_with(prefix)
63}
64
65/// Capitalizes the first letter of each word in the given string.
66///
67/// # Example
68///
69/// ```slt
70/// query T
71/// select initcap('the quick brown fox');
72/// ----
73/// The Quick Brown Fox
74/// ```
75#[function("initcap(varchar) -> varchar")]
76pub fn initcap(s: &str, writer: &mut impl Write) {
77    let mut capitalize_next = true;
78    for c in s.chars() {
79        if capitalize_next {
80            write!(writer, "{}", c.to_uppercase()).unwrap();
81            capitalize_next = false;
82        } else {
83            write!(writer, "{}", c.to_lowercase()).unwrap();
84        }
85        if c.is_whitespace() {
86            capitalize_next = true;
87        }
88    }
89}
90
91/// Extends the given string on the left until it is at least the specified length,
92/// using the specified fill character (or a space by default).
93///
94/// # Example
95///
96/// ```slt
97/// query T
98/// select lpad('abc', 5);
99/// ----
100///   abc
101///
102/// query T
103/// select lpad('abcdef', 3);
104/// ----
105/// abc
106/// ```
107#[function("lpad(varchar, int4) -> varchar")]
108pub fn lpad(s: &str, length: i32, writer: &mut impl Write) {
109    lpad_fill(s, length, " ", writer);
110}
111
112/// Extends the string to the specified length by prepending the characters fill.
113/// If the string is already longer than the specified length, it is truncated on the right.
114///
115/// # Example
116///
117/// ```slt
118/// query T
119/// select lpad('hi', 5, 'xy');
120/// ----
121/// xyxhi
122///
123/// query T
124/// select lpad('hi', 5, '');
125/// ----
126/// hi
127/// ```
128#[function("lpad(varchar, int4, varchar) -> varchar")]
129pub fn lpad_fill(s: &str, length: i32, fill: &str, writer: &mut impl Write) {
130    let s_len = s.chars().count();
131    let fill_len = fill.chars().count();
132
133    if length <= 0 {
134        return;
135    }
136    if s_len >= length as usize {
137        for c in s.chars().take(length as usize) {
138            write!(writer, "{c}").unwrap();
139        }
140    } else if fill_len == 0 {
141        write!(writer, "{s}").unwrap();
142    } else {
143        let mut remaining_length = length as usize - s_len;
144        while remaining_length >= fill_len {
145            write!(writer, "{fill}").unwrap();
146            remaining_length -= fill_len;
147        }
148        for c in fill.chars().take(remaining_length) {
149            write!(writer, "{c}").unwrap();
150        }
151        write!(writer, "{s}").unwrap();
152    }
153}
154
155/// Extends the given string on the right until it is at least the specified length,
156/// using the specified fill character (or a space by default).
157///
158/// # Example
159///
160/// ```slt
161/// query T
162/// select rpad('abc', 5);
163/// ----
164/// abc
165///
166/// query T
167/// select rpad('abcdef', 3);
168/// ----
169/// abc
170/// ```
171#[function("rpad(varchar, int4) -> varchar")]
172pub fn rpad(s: &str, length: i32, writer: &mut impl Write) {
173    rpad_fill(s, length, " ", writer);
174}
175
176/// Extends the given string on the right until it is at least the specified length,
177/// using the specified fill string, truncating the string if it is already longer
178/// than the specified length.
179///
180/// # Example
181///
182/// ```slt
183/// query T
184/// select rpad('hi', 5, 'xy');
185/// ----
186/// hixyx
187///
188/// query T
189/// select rpad('abc', 5, '😀');
190/// ----
191/// abc😀😀
192///
193/// query T
194/// select rpad('abcdef', 3, '0');
195/// ----
196/// abc
197///
198/// query T
199/// select rpad('hi', 5, '');
200/// ----
201/// hi
202/// ```
203#[function("rpad(varchar, int4, varchar) -> varchar")]
204pub fn rpad_fill(s: &str, length: i32, fill: &str, writer: &mut impl Write) {
205    let s_len = s.chars().count();
206    let fill_len = fill.chars().count();
207
208    if length <= 0 {
209        return;
210    }
211
212    if s_len >= length as usize {
213        for c in s.chars().take(length as usize) {
214            write!(writer, "{c}").unwrap();
215        }
216    } else if fill_len == 0 {
217        write!(writer, "{s}").unwrap();
218    } else {
219        write!(writer, "{s}").unwrap();
220        let mut remaining_length = length as usize - s_len;
221        while remaining_length >= fill_len {
222            write!(writer, "{fill}").unwrap();
223            remaining_length -= fill_len;
224        }
225        for c in fill.chars().take(remaining_length) {
226            write!(writer, "{c}").unwrap();
227        }
228    }
229}
230
231/// Reverses the characters in the given string.
232///
233/// # Example
234///
235/// ```slt
236/// query T
237/// select reverse('abcdef');
238/// ----
239/// fedcba
240/// ```
241#[function("reverse(varchar) -> varchar")]
242pub fn reverse(s: &str, writer: &mut impl Write) {
243    for c in s.chars().rev() {
244        write!(writer, "{}", c).unwrap();
245    }
246}
247
248/// Converts the input string to ASCII by dropping accents, assuming that the input string
249/// is encoded in one of the supported encodings (Latin1, Latin2, Latin9, or WIN1250).
250///
251/// # Example
252///
253/// ```slt
254/// query T
255/// select to_ascii('Karél');
256/// ----
257/// Karel
258/// ```
259#[function("to_ascii(varchar) -> varchar")]
260pub fn to_ascii(s: &str, writer: &mut impl Write) {
261    for c in s.chars() {
262        let ascii = match c {
263            'Á' | 'À' | 'Â' | 'Ã' => 'A',
264            'á' | 'à' | 'â' | 'ã' => 'a',
265            'Č' | 'Ć' | 'Ç' => 'C',
266            'č' | 'ć' | 'ç' => 'c',
267            'Ď' => 'D',
268            'ď' => 'd',
269            'É' | 'È' | 'Ê' | 'Ẽ' => 'E',
270            'é' | 'è' | 'ê' | 'ẽ' => 'e',
271            'Í' | 'Ì' | 'Î' | 'Ĩ' => 'I',
272            'í' | 'ì' | 'î' | 'ĩ' => 'i',
273            'Ľ' => 'L',
274            'ľ' => 'l',
275            'Ň' => 'N',
276            'ň' => 'n',
277            'Ó' | 'Ò' | 'Ô' | 'Õ' => 'O',
278            'ó' | 'ò' | 'ô' | 'õ' => 'o',
279            'Ŕ' => 'R',
280            'ŕ' => 'r',
281            'Š' | 'Ś' => 'S',
282            'š' | 'ś' => 's',
283            'Ť' => 'T',
284            'ť' => 't',
285            'Ú' | 'Ù' | 'Û' | 'Ũ' => 'U',
286            'ú' | 'ù' | 'û' | 'ũ' => 'u',
287            'Ý' | 'Ỳ' => 'Y',
288            'ý' | 'ỳ' => 'y',
289            'Ž' | 'Ź' | 'Ż' => 'Z',
290            'ž' | 'ź' | 'ż' => 'z',
291            _ => c,
292        };
293        write!(writer, "{}", ascii).unwrap();
294    }
295}
296
297/// Converts the given integer to its equivalent hexadecimal representation.
298///
299/// # Example
300///
301/// ```slt
302/// query T
303/// select to_hex(2147483647);
304/// ----
305/// 7fffffff
306///
307/// query T
308/// select to_hex(-2147483648);
309/// ----
310/// 80000000
311///
312/// query T
313/// select to_hex(9223372036854775807);
314/// ----
315/// 7fffffffffffffff
316///
317/// query T
318/// select to_hex(-9223372036854775808);
319/// ----
320/// 8000000000000000
321/// ```
322#[function("to_hex(int4) -> varchar")]
323pub fn to_hex_i32(n: i32, writer: &mut impl Write) {
324    write!(writer, "{:x}", n).unwrap();
325}
326
327#[function("to_hex(int8) -> varchar")]
328pub fn to_hex_i64(n: i64, writer: &mut impl Write) {
329    write!(writer, "{:x}", n).unwrap();
330}
331
332/// Returns the given string suitably quoted to be used as an identifier in an SQL statement string.
333/// Quotes are added only if necessary (i.e., if the string contains non-identifier characters or
334/// would be case-folded). Embedded quotes are properly doubled.
335///
336/// Refer to <https://github.com/postgres/postgres/blob/90189eefc1e11822794e3386d9bafafd3ba3a6e8/src/backend/utils/adt/ruleutils.c#L11506>
337///
338/// # Example
339///
340/// ```slt
341/// query T
342/// select quote_ident('foo bar')
343/// ----
344/// "foo bar"
345///
346/// query T
347/// select quote_ident('FooBar')
348/// ----
349/// "FooBar"
350///
351/// query T
352/// select quote_ident('foo_bar')
353/// ----
354/// foo_bar
355///
356/// query T
357/// select quote_ident('foo"bar')
358/// ----
359/// "foo""bar"
360///
361/// # FIXME: quote SQL keywords is not supported yet
362/// query T
363/// select quote_ident('select')
364/// ----
365/// select
366/// ```
367#[function("quote_ident(varchar) -> varchar")]
368pub fn quote_ident(s: &str, writer: &mut impl Write) {
369    let needs_quotes = s.chars().any(|c| !matches!(c, 'a'..='z' | '0'..='9' | '_'));
370    if !needs_quotes {
371        write!(writer, "{}", s).unwrap();
372        return;
373    }
374    write!(writer, "\"").unwrap();
375    for c in s.chars() {
376        if c == '"' {
377            write!(writer, "\"\"").unwrap();
378        } else {
379            write!(writer, "{c}").unwrap();
380        }
381    }
382    write!(writer, "\"").unwrap();
383}
384
385/// Returns the first n characters in the string.
386/// If n is a negative value, the function will return all but last |n| characters.
387///
388/// # Example
389///
390/// ```slt
391/// query T
392/// select left('RisingWave', 6)
393/// ----
394/// Rising
395///
396/// query T
397/// select left('RisingWave', 42)
398/// ----
399/// RisingWave
400///
401/// query T
402/// select left('RisingWave', 0)
403/// ----
404/// (empty)
405///
406/// query T
407/// select left('RisingWave', -4)
408/// ----
409/// Rising
410///
411/// query T
412/// select left('RisingWave', -2147483648);
413/// ----
414/// (empty)
415/// ```
416#[function("left(varchar, int4) -> varchar")]
417pub fn left(s: &str, n: i32, writer: &mut impl Write) {
418    let n = if n >= 0 {
419        n as usize
420    } else {
421        s.chars().count().saturating_add_signed(n as isize)
422    };
423
424    s.chars()
425        .take(n)
426        .for_each(|c| writer.write_char(c).unwrap());
427}
428
429/// Returns the last n characters in the string.
430/// If n is a negative value, the function will return all but first |n| characters.
431///
432/// # Example
433///
434/// ```slt
435/// query T
436/// select right('RisingWave', 4)
437/// ----
438/// Wave
439///
440/// query T
441/// select left('RisingWave', 42)
442/// ----
443/// RisingWave
444///
445/// query T
446/// select right('RisingWave', 0)
447/// ----
448/// (empty)
449///
450/// query T
451/// select right('RisingWave', -6)
452/// ----
453/// Wave
454///
455/// # PostgreSQL returns the whole string due to an overflow bug, which we do not follow.
456/// query T
457/// select right('RisingWave', -2147483648);
458/// ----
459/// (empty)
460/// ```
461#[function("right(varchar, int4) -> varchar")]
462pub fn right(s: &str, n: i32, writer: &mut impl Write) {
463    let skip = if n >= 0 {
464        s.chars().count().saturating_sub(n as usize)
465    } else {
466        // `n as usize` is signed extended. This is `-n` without overflow.
467        usize::MAX - (n as usize) + 1
468    };
469
470    s.chars()
471        .skip(skip)
472        .for_each(|c| writer.write_char(c).unwrap());
473}
474
475/// `quote_literal(string text)`
476/// `quote_literal(value anyelement)`
477///
478/// Return the given string suitably quoted to be used as a string literal in an SQL statement
479/// string. Embedded single-quotes and backslashes are properly doubled.
480/// Note that `quote_literal` returns null on null input; if the argument might be null,
481/// `quote_nullable` is often more suitable.
482///
483/// # Example
484///
485/// Note that the quotes are part of the output string.
486///
487/// ```slt
488/// query T
489/// select quote_literal(E'O\'Reilly')
490/// ----
491/// 'O''Reilly'
492///
493/// query T
494/// select quote_literal(E'C:\\Windows\\')
495/// ----
496/// E'C:\\Windows\\'
497///
498/// query T
499/// select quote_literal(42.5)
500/// ----
501/// '42.5'
502///
503/// query T
504/// select quote_literal('hello'::bytea);
505/// ----
506/// E'\\x68656c6c6f'
507///
508/// query T
509/// select quote_literal('{"hello":"world","foo":233}'::jsonb);
510/// ----
511/// '{"foo": 233, "hello": "world"}'
512/// ```
513#[function("quote_literal(varchar) -> varchar")]
514pub fn quote_literal(s: &str, writer: &mut impl Write) {
515    if s.contains('\\') {
516        // use escape format: E'...'
517        write!(writer, "E").unwrap();
518    }
519    write!(writer, "'").unwrap();
520    for c in s.chars() {
521        match c {
522            '\'' => write!(writer, "''").unwrap(),
523            '\\' => write!(writer, "\\\\").unwrap(),
524            _ => write!(writer, "{}", c).unwrap(),
525        }
526    }
527    write!(writer, "'").unwrap();
528}
529
530/// `quote_nullable(string text)`
531///
532/// Return the given string suitably quoted to be used as a string literal in an SQL statement
533/// string; or, if the argument is null, return NULL.
534/// Embedded single-quotes and backslashes are properly doubled.
535#[function("quote_nullable(varchar) -> varchar")]
536pub fn quote_nullable(s: Option<&str>, writer: &mut impl Write) {
537    match s {
538        Some(s) => quote_literal(s, writer),
539        None => write!(writer, "NULL").unwrap(),
540    }
541}
542
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_left_and_right() {
        let s = "cxscgccdd";
        let us = "上海自来水来自海上";

        // (input, n, expected `left(input, n)`, expected `right(input, n)`)
        let cases = [
            (s, 3, "cxs", "cdd"),
            (s, -3, "cxscgc", "cgccdd"),
            (s, 0, "", ""),
            (s, 15, "cxscgccdd", "cxscgccdd"),
            // A negative `n` larger in magnitude than the length yields an empty string.
            (s, -15, "", ""),
            // Extreme values must not overflow.
            (s, i32::MIN, "", ""),
            (s, i32::MAX, "cxscgccdd", "cxscgccdd"),
            // Unicode test
            (us, 5, "上海自来水", "水来自海上"),
            (us, -6, "上海自", "自海上"),
        ];

        for (s, n, left_expected, right_expected) in cases {
            let mut left_writer = String::new();
            let mut right_writer = String::new();
            left(s, n, &mut left_writer);
            right(s, n, &mut right_writer);
            assert_eq!(left_writer, left_expected, "left({s:?}, {n})");
            assert_eq!(right_writer, right_expected, "right({s:?}, {n})");
        }
    }
}