risingwave_expr_impl/scalar/
format.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::str::FromStr;
16
17use risingwave_common::row::Row;
18use risingwave_common::types::{ScalarRefImpl, ToText};
19use risingwave_expr::{ExprError, Result, function};
20
21use super::string::quote_ident;
22
23/// Formats arguments according to a format string.
24///
25/// # Example
26///
27/// ```slt
28/// query T
29/// select format('%s %s', 'Hello', 'World');
30/// ----
31/// Hello World
32///
33/// query T
34/// select format('%s %s', variadic array['Hello', 'World']);
35/// ----
36/// Hello World
37/// ```
38#[function(
39    "format(varchar, variadic anyarray) -> varchar",
40    prebuild = "Formatter::from_str($0).map_err(|e| ExprError::Parse(e.to_report_string().into()))?"
41)]
42fn format(row: impl Row, formatter: &Formatter, writer: &mut impl std::fmt::Write) -> Result<()> {
43    let mut args = row.iter();
44    for node in &formatter.nodes {
45        match node {
46            FormatterNode::Literal(literal) => writer.write_str(literal).unwrap(),
47            FormatterNode::Specifier(sp) => {
48                let arg = args.next().ok_or(ExprError::TooFewArguments)?;
49                match sp.ty {
50                    SpecifierType::SimpleString => {
51                        if let Some(scalar) = arg {
52                            scalar.write(writer).unwrap();
53                        }
54                    }
55                    SpecifierType::SqlIdentifier => match arg {
56                        Some(ScalarRefImpl::Utf8(arg)) => quote_ident(arg, writer),
57                        _ => {
58                            return Err(ExprError::UnsupportedFunction(
59                                "unsupported data for specifier type 'I'".to_owned(),
60                            ));
61                        }
62                    },
63                    SpecifierType::SqlLiteral => {
64                        return Err(ExprError::UnsupportedFunction(
65                            "unsupported specifier type 'L'".to_owned(),
66                        ));
67                    }
68                }
69            }
70        }
71    }
72    Ok(())
73}
74
/// Which conversion a `%` specifier applies to its argument, keyed by the letter after `%`.
///
/// The letters and their null handling follow PostgreSQL's `format()`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SpecifierType {
    /// `%s`: write the argument's plain text form; a null argument contributes an empty
    /// string.
    SimpleString,
    /// `%I`: write the argument as an SQL identifier, double-quoted only when needed (as
    /// `quote_ident` does); a null argument is an error.
    SqlIdentifier,
    /// `%L`: write the argument as a quoted SQL literal; a null argument is rendered as the
    /// bare word NULL, without quotes (as `quote_nullable` does).
    SqlLiteral,
}
88
89impl TryFrom<char> for SpecifierType {
90    type Error = ();
91
92    fn try_from(c: char) -> std::result::Result<Self, Self::Error> {
93        match c {
94            's' => Ok(SpecifierType::SimpleString),
95            'I' => Ok(SpecifierType::SqlIdentifier),
96            'L' => Ok(SpecifierType::SqlLiteral),
97            _ => Err(()),
98        }
99    }
100}
101
102#[derive(Debug)]
103struct Specifier {
104    // TODO: support position, flags and width.
105    ty: SpecifierType,
106}
107
108#[derive(Debug)]
109enum FormatterNode {
110    Specifier(Specifier),
111    Literal(String),
112}
113
114#[derive(Debug)]
115struct Formatter {
116    nodes: Vec<FormatterNode>,
117}
118
119#[derive(Debug, thiserror::Error)]
120enum ParseFormatError {
121    #[error("unrecognized format() type specifier \"{0}\"")]
122    UnrecognizedSpecifierType(char),
123    #[error("unterminated format() type specifier")]
124    UnterminatedSpecifier,
125}
126
127impl FromStr for Formatter {
128    type Err = ParseFormatError;
129
130    /// Parse the format string into a high-efficient representation.
131    /// <https://www.postgresql.org/docs/current/functions-string.html#FUNCTIONS-STRING-FORMAT>
132    fn from_str(format: &str) -> std::result::Result<Self, ParseFormatError> {
133        // 8 is a good magic number here, it can cover an input like 'Testing %s, %s, %s, %%'.
134        let mut nodes = Vec::with_capacity(8);
135        let mut after_percent = false;
136        let mut literal = String::with_capacity(8);
137        for c in format.chars() {
138            if after_percent && c == '%' {
139                literal.push('%');
140                after_percent = false;
141            } else if after_percent {
142                // TODO: support position, flags and width.
143                if let Ok(ty) = SpecifierType::try_from(c) {
144                    if !literal.is_empty() {
145                        nodes.push(FormatterNode::Literal(std::mem::take(&mut literal)));
146                    }
147                    nodes.push(FormatterNode::Specifier(Specifier { ty }));
148                } else {
149                    return Err(ParseFormatError::UnrecognizedSpecifierType(c));
150                }
151                after_percent = false;
152            } else if c == '%' {
153                after_percent = true;
154            } else {
155                literal.push(c);
156            }
157        }
158
159        if after_percent {
160            return Err(ParseFormatError::UnterminatedSpecifier);
161        }
162
163        if !literal.is_empty() {
164            nodes.push(FormatterNode::Literal(literal));
165        }
166
167        Ok(Formatter { nodes })
168    }
169}
170
#[cfg(test)]
mod tests {
    use risingwave_common::array::DataChunk;
    use risingwave_common::row::Row;
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::build_from_pretty;

    /// Checks `format` through both the vectorized (`eval`) and the row-at-a-time
    /// (`eval_row`) paths against the same expected column.
    #[tokio::test]
    async fn test_format() {
        let expr = build_from_pretty("(format:varchar $0:varchar $1:varchar $2:varchar)");
        // The first three columns are inputs; the fourth holds the expected output.
        let chunk = DataChunk::from_pretty(
            "T          T       T       T
             Hello%s    World   .       HelloWorld
             %s%s       Hello   World   HelloWorld
             %I         &&      .       \"&&\"
             .          a       b       .",
        );
        let (input, expected) = chunk.split_column_at(3);

        // Whole-chunk evaluation.
        let actual_column = expr.eval(&input).await.unwrap();
        assert_eq!(&actual_column, expected.column_at(0));

        // Row-by-row evaluation must agree with the expected column as well.
        for (input_row, expected_row) in input.rows().zip_eq_debug(expected.rows()) {
            let actual = expr.eval_row(&input_row.to_owned_row()).await.unwrap();
            assert_eq!(actual, expected_row.datum_at(0).to_owned_datum());
        }
    }
}