risingwave_expr_impl/scalar/
string.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! String functions
16//!
17//! <https://www.postgresql.org/docs/current/functions-string.html>
18
19use risingwave_common::util::quote_ident::QuoteIdent;
20use risingwave_expr::function;
21
22/// Returns the character with the specified Unicode code point.
23///
24/// # Example
25///
26/// ```slt
27/// query T
28/// select chr(65);
29/// ----
30/// A
31/// ```
32#[function("chr(int4) -> varchar")]
33pub fn chr(code: i32, writer: &mut impl std::fmt::Write) {
34    if let Some(c) = std::char::from_u32(code as u32) {
35        write!(writer, "{}", c).unwrap();
36    }
37}
38
39/// Returns true if the given string starts with the specified prefix.
40///
41/// # Example
42///
43/// ```slt
44/// query T
45/// select starts_with('abcdef', 'abc');
46/// ----
47/// t
48///
49/// query T
50/// select 'abcdef' ^@ 'abc';
51/// ----
52/// t
53///
54/// query T
55/// select 'abcdef' ^@ some(array['x', 'a', 't']);
56/// ----
57/// t
58/// ```
59#[function("starts_with(varchar, varchar) -> boolean")]
60pub fn starts_with(s: &str, prefix: &str) -> bool {
61    s.starts_with(prefix)
62}
63
64/// Capitalizes the first letter of each word in the given string.
65///
66/// # Example
67///
68/// ```slt
69/// query T
70/// select initcap('the quick brown fox');
71/// ----
72/// The Quick Brown Fox
73/// ```
74#[function("initcap(varchar) -> varchar")]
75pub fn initcap(s: &str, writer: &mut impl std::fmt::Write) {
76    let mut capitalize_next = true;
77    for c in s.chars() {
78        if capitalize_next {
79            write!(writer, "{}", c.to_uppercase()).unwrap();
80            capitalize_next = false;
81        } else {
82            write!(writer, "{}", c.to_lowercase()).unwrap();
83        }
84        if c.is_whitespace() {
85            capitalize_next = true;
86        }
87    }
88}
89
90/// Extends the given string on the left until it is at least the specified length,
91/// using the specified fill character (or a space by default).
92///
93/// # Example
94///
95/// ```slt
96/// query T
97/// select lpad('abc', 5);
98/// ----
99///   abc
100///
101/// query T
102/// select lpad('abcdef', 3);
103/// ----
104/// abc
105/// ```
106#[function("lpad(varchar, int4) -> varchar")]
107pub fn lpad(s: &str, length: i32, writer: &mut impl std::fmt::Write) {
108    lpad_fill(s, length, " ", writer);
109}
110
111/// Extends the string to the specified length by prepending the characters fill.
112/// If the string is already longer than the specified length, it is truncated on the right.
113///
114/// # Example
115///
116/// ```slt
117/// query T
118/// select lpad('hi', 5, 'xy');
119/// ----
120/// xyxhi
121///
122/// query T
123/// select lpad('hi', 5, '');
124/// ----
125/// hi
126/// ```
127#[function("lpad(varchar, int4, varchar) -> varchar")]
128pub fn lpad_fill(s: &str, length: i32, fill: &str, writer: &mut impl std::fmt::Write) {
129    let s_len = s.chars().count();
130    let fill_len = fill.chars().count();
131
132    if length <= 0 {
133        return;
134    }
135    if s_len >= length as usize {
136        for c in s.chars().take(length as usize) {
137            write!(writer, "{c}").unwrap();
138        }
139    } else if fill_len == 0 {
140        write!(writer, "{s}").unwrap();
141    } else {
142        let mut remaining_length = length as usize - s_len;
143        while remaining_length >= fill_len {
144            write!(writer, "{fill}").unwrap();
145            remaining_length -= fill_len;
146        }
147        for c in fill.chars().take(remaining_length) {
148            write!(writer, "{c}").unwrap();
149        }
150        write!(writer, "{s}").unwrap();
151    }
152}
153
154/// Extends the given string on the right until it is at least the specified length,
155/// using the specified fill character (or a space by default).
156///
157/// # Example
158///
159/// ```slt
160/// query T
161/// select rpad('abc', 5);
162/// ----
163/// abc
164///
165/// query T
166/// select rpad('abcdef', 3);
167/// ----
168/// abc
169/// ```
170#[function("rpad(varchar, int4) -> varchar")]
171pub fn rpad(s: &str, length: i32, writer: &mut impl std::fmt::Write) {
172    rpad_fill(s, length, " ", writer);
173}
174
175/// Extends the given string on the right until it is at least the specified length,
176/// using the specified fill string, truncating the string if it is already longer
177/// than the specified length.
178///
179/// # Example
180///
181/// ```slt
182/// query T
183/// select rpad('hi', 5, 'xy');
184/// ----
185/// hixyx
186///
187/// query T
188/// select rpad('abc', 5, '😀');
189/// ----
190/// abc😀😀
191///
192/// query T
193/// select rpad('abcdef', 3, '0');
194/// ----
195/// abc
196///
197/// query T
198/// select rpad('hi', 5, '');
199/// ----
200/// hi
201/// ```
202#[function("rpad(varchar, int4, varchar) -> varchar")]
203pub fn rpad_fill(s: &str, length: i32, fill: &str, writer: &mut impl std::fmt::Write) {
204    let s_len = s.chars().count();
205    let fill_len = fill.chars().count();
206
207    if length <= 0 {
208        return;
209    }
210
211    if s_len >= length as usize {
212        for c in s.chars().take(length as usize) {
213            write!(writer, "{c}").unwrap();
214        }
215    } else if fill_len == 0 {
216        write!(writer, "{s}").unwrap();
217    } else {
218        write!(writer, "{s}").unwrap();
219        let mut remaining_length = length as usize - s_len;
220        while remaining_length >= fill_len {
221            write!(writer, "{fill}").unwrap();
222            remaining_length -= fill_len;
223        }
224        for c in fill.chars().take(remaining_length) {
225            write!(writer, "{c}").unwrap();
226        }
227    }
228}
229
230/// Reverses the characters in the given string.
231///
232/// # Example
233///
234/// ```slt
235/// query T
236/// select reverse('abcdef');
237/// ----
238/// fedcba
239/// ```
240#[function("reverse(varchar) -> varchar")]
241pub fn reverse(s: &str, writer: &mut impl std::fmt::Write) {
242    for c in s.chars().rev() {
243        write!(writer, "{}", c).unwrap();
244    }
245}
246
247/// Converts the input string to ASCII by dropping accents, assuming that the input string
248/// is encoded in one of the supported encodings (Latin1, Latin2, Latin9, or WIN1250).
249///
250/// # Example
251///
252/// ```slt
253/// query T
254/// select to_ascii('Karél');
255/// ----
256/// Karel
257/// ```
258#[function("to_ascii(varchar) -> varchar")]
259pub fn to_ascii(s: &str, writer: &mut impl std::fmt::Write) {
260    for c in s.chars() {
261        let ascii = match c {
262            'Á' | 'À' | 'Â' | 'Ã' => 'A',
263            'á' | 'à' | 'â' | 'ã' => 'a',
264            'Č' | 'Ć' | 'Ç' => 'C',
265            'č' | 'ć' | 'ç' => 'c',
266            'Ď' => 'D',
267            'ď' => 'd',
268            'É' | 'È' | 'Ê' | 'Ẽ' => 'E',
269            'é' | 'è' | 'ê' | 'ẽ' => 'e',
270            'Í' | 'Ì' | 'Î' | 'Ĩ' => 'I',
271            'í' | 'ì' | 'î' | 'ĩ' => 'i',
272            'Ľ' => 'L',
273            'ľ' => 'l',
274            'Ň' => 'N',
275            'ň' => 'n',
276            'Ó' | 'Ò' | 'Ô' | 'Õ' => 'O',
277            'ó' | 'ò' | 'ô' | 'õ' => 'o',
278            'Ŕ' => 'R',
279            'ŕ' => 'r',
280            'Š' | 'Ś' => 'S',
281            'š' | 'ś' => 's',
282            'Ť' => 'T',
283            'ť' => 't',
284            'Ú' | 'Ù' | 'Û' | 'Ũ' => 'U',
285            'ú' | 'ù' | 'û' | 'ũ' => 'u',
286            'Ý' | 'Ỳ' => 'Y',
287            'ý' | 'ỳ' => 'y',
288            'Ž' | 'Ź' | 'Ż' => 'Z',
289            'ž' | 'ź' | 'ż' => 'z',
290            _ => c,
291        };
292        write!(writer, "{}", ascii).unwrap();
293    }
294}
295
296/// Converts the given integer to its equivalent hexadecimal representation.
297///
298/// # Example
299///
300/// ```slt
301/// query T
302/// select to_hex(2147483647);
303/// ----
304/// 7fffffff
305///
306/// query T
307/// select to_hex(-2147483648);
308/// ----
309/// 80000000
310///
311/// query T
312/// select to_hex(9223372036854775807);
313/// ----
314/// 7fffffffffffffff
315///
316/// query T
317/// select to_hex(-9223372036854775808);
318/// ----
319/// 8000000000000000
320/// ```
321#[function("to_hex(int4) -> varchar")]
322pub fn to_hex_i32(n: i32, writer: &mut impl std::fmt::Write) {
323    write!(writer, "{:x}", n).unwrap();
324}
325
326#[function("to_hex(int8) -> varchar")]
327pub fn to_hex_i64(n: i64, writer: &mut impl std::fmt::Write) {
328    write!(writer, "{:x}", n).unwrap();
329}
330
331/// Returns the given string suitably quoted to be used as an identifier in an SQL statement string.
332/// Quotes are added only if necessary (i.e., if the string contains non-identifier characters or
333/// would be case-folded). Embedded quotes are properly doubled.
334///
335/// Refer to <https://github.com/postgres/postgres/blob/90189eefc1e11822794e3386d9bafafd3ba3a6e8/src/backend/utils/adt/ruleutils.c#L11506>
336///
337/// # Example
338///
339/// ```slt
340/// query T
341/// select quote_ident('foo bar')
342/// ----
343/// "foo bar"
344///
345/// query T
346/// select quote_ident('FooBar')
347/// ----
348/// "FooBar"
349///
350/// query T
351/// select quote_ident('foo_bar')
352/// ----
353/// foo_bar
354///
355/// query T
356/// select quote_ident('foo"bar')
357/// ----
358/// "foo""bar"
359///
360/// # FIXME: quote SQL keywords is not supported yet
361/// query T
362/// select quote_ident('select')
363/// ----
364/// select
365/// ```
366#[function("quote_ident(varchar) -> varchar")]
367pub fn quote_ident(s: &str, writer: &mut impl std::fmt::Write) {
368    write!(writer, "{}", QuoteIdent(s)).unwrap();
369}
370
371/// Returns the first n characters in the string.
372/// If n is a negative value, the function will return all but last |n| characters.
373///
374/// # Example
375///
376/// ```slt
377/// query T
378/// select left('RisingWave', 6)
379/// ----
380/// Rising
381///
382/// query T
383/// select left('RisingWave', 42)
384/// ----
385/// RisingWave
386///
387/// query T
388/// select left('RisingWave', 0)
389/// ----
390/// (empty)
391///
392/// query T
393/// select left('RisingWave', -4)
394/// ----
395/// Rising
396///
397/// query T
398/// select left('RisingWave', -2147483648);
399/// ----
400/// (empty)
401/// ```
402#[function("left(varchar, int4) -> varchar")]
403pub fn left(s: &str, n: i32, writer: &mut impl std::fmt::Write) {
404    let n = if n >= 0 {
405        n as usize
406    } else {
407        s.chars().count().saturating_add_signed(n as isize)
408    };
409
410    s.chars()
411        .take(n)
412        .for_each(|c| writer.write_char(c).unwrap());
413}
414
415/// Returns the last n characters in the string.
416/// If n is a negative value, the function will return all but first |n| characters.
417///
418/// # Example
419///
420/// ```slt
421/// query T
422/// select right('RisingWave', 4)
423/// ----
424/// Wave
425///
426/// query T
427/// select left('RisingWave', 42)
428/// ----
429/// RisingWave
430///
431/// query T
432/// select right('RisingWave', 0)
433/// ----
434/// (empty)
435///
436/// query T
437/// select right('RisingWave', -6)
438/// ----
439/// Wave
440///
441/// # PostgreSQL returns the whole string due to an overflow bug, which we do not follow.
442/// query T
443/// select right('RisingWave', -2147483648);
444/// ----
445/// (empty)
446/// ```
447#[function("right(varchar, int4) -> varchar")]
448pub fn right(s: &str, n: i32, writer: &mut impl std::fmt::Write) {
449    let skip = if n >= 0 {
450        s.chars().count().saturating_sub(n as usize)
451    } else {
452        // `n as usize` is signed extended. This is `-n` without overflow.
453        usize::MAX - (n as usize) + 1
454    };
455
456    s.chars()
457        .skip(skip)
458        .for_each(|c| writer.write_char(c).unwrap());
459}
460
461/// `quote_literal(string text)`
462/// `quote_literal(value anyelement)`
463///
464/// Return the given string suitably quoted to be used as a string literal in an SQL statement
465/// string. Embedded single-quotes and backslashes are properly doubled.
466/// Note that `quote_literal` returns null on null input; if the argument might be null,
467/// `quote_nullable` is often more suitable.
468///
469/// # Example
470///
471/// Note that the quotes are part of the output string.
472///
473/// ```slt
474/// query T
475/// select quote_literal(E'O\'Reilly')
476/// ----
477/// 'O''Reilly'
478///
479/// query T
480/// select quote_literal(E'C:\\Windows\\')
481/// ----
482/// E'C:\\Windows\\'
483///
484/// query T
485/// select quote_literal(42.5)
486/// ----
487/// '42.5'
488///
489/// query T
490/// select quote_literal('hello'::bytea);
491/// ----
492/// E'\\x68656c6c6f'
493///
494/// query T
495/// select quote_literal('{"hello":"world","foo":233}'::jsonb);
496/// ----
497/// '{"foo": 233, "hello": "world"}'
498/// ```
499#[function("quote_literal(varchar) -> varchar")]
500pub fn quote_literal(s: &str, writer: &mut impl std::fmt::Write) {
501    if s.contains('\\') {
502        // use escape format: E'...'
503        write!(writer, "E").unwrap();
504    }
505    write!(writer, "'").unwrap();
506    for c in s.chars() {
507        match c {
508            '\'' => write!(writer, "''").unwrap(),
509            '\\' => write!(writer, "\\\\").unwrap(),
510            _ => write!(writer, "{}", c).unwrap(),
511        }
512    }
513    write!(writer, "'").unwrap();
514}
515
516/// `quote_nullable(string text)`
517///
518/// Return the given string suitably quoted to be used as a string literal in an SQL statement
519/// string; or, if the argument is null, return NULL.
520/// Embedded single-quotes and backslashes are properly doubled.
521#[function("quote_nullable(varchar) -> varchar")]
522pub fn quote_nullable(s: Option<&str>, writer: &mut impl std::fmt::Write) {
523    match s {
524        Some(s) => quote_literal(s, writer),
525        None => write!(writer, "NULL").unwrap(),
526    }
527}
528
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `left` and `right` over one shared table of inputs, including the
    /// boundary values of `n` and empty input.
    #[test]
    fn test_left_and_right() {
        let s = "cxscgccdd";
        let us = "上海自来水来自海上";

        // (input, n, expected `left` output, expected `right` output)
        let cases = [
            (s, 3, "cxs", "cdd"),
            (s, -3, "cxscgc", "cgccdd"),
            (s, 0, "", ""),
            (s, 15, "cxscgccdd", "cxscgccdd"),
            // `|n|` equal to the string length
            (s, 9, "cxscgccdd", "cxscgccdd"),
            (s, -9, "", ""),
            // Extremes of `n` must not overflow
            (s, i32::MAX, "cxscgccdd", "cxscgccdd"),
            (s, i32::MIN, "", ""),
            // Empty input
            ("", 3, "", ""),
            ("", -3, "", ""),
            // Unicode test
            (us, 5, "上海自来水", "水来自海上"),
            (us, -6, "上海自", "自海上"),
        ];

        for (input, n, left_expected, right_expected) in cases {
            let mut left_writer = String::new();
            let mut right_writer = String::new();
            left(input, n, &mut left_writer);
            right(input, n, &mut right_writer);
            // Name the failing case in the assertion message.
            assert_eq!(left_writer, left_expected, "left({input:?}, {n})");
            assert_eq!(right_writer, right_expected, "right({input:?}, {n})");
        }
    }
}