risingwave_expr_impl/scalar/
format.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Write;
16use std::str::FromStr;
17
18use risingwave_common::row::Row;
19use risingwave_common::types::{ScalarRefImpl, ToText};
20use risingwave_expr::{ExprError, Result, function};
21
22use super::string::quote_ident;
23
24/// Formats arguments according to a format string.
25///
26/// # Example
27///
28/// ```slt
29/// query T
30/// select format('%s %s', 'Hello', 'World');
31/// ----
32/// Hello World
33///
34/// query T
35/// select format('%s %s', variadic array['Hello', 'World']);
36/// ----
37/// Hello World
38/// ```
39#[function(
40    "format(varchar, variadic anyarray) -> varchar",
41    prebuild = "Formatter::from_str($0).map_err(|e| ExprError::Parse(e.to_report_string().into()))?"
42)]
43fn format(row: impl Row, formatter: &Formatter, writer: &mut impl Write) -> Result<()> {
44    let mut args = row.iter();
45    for node in &formatter.nodes {
46        match node {
47            FormatterNode::Literal(literal) => writer.write_str(literal).unwrap(),
48            FormatterNode::Specifier(sp) => {
49                let arg = args.next().ok_or(ExprError::TooFewArguments)?;
50                match sp.ty {
51                    SpecifierType::SimpleString => {
52                        if let Some(scalar) = arg {
53                            scalar.write(writer).unwrap();
54                        }
55                    }
56                    SpecifierType::SqlIdentifier => match arg {
57                        Some(ScalarRefImpl::Utf8(arg)) => quote_ident(arg, writer),
58                        _ => {
59                            return Err(ExprError::UnsupportedFunction(
60                                "unsupported data for specifier type 'I'".to_owned(),
61                            ));
62                        }
63                    },
64                    SpecifierType::SqlLiteral => {
65                        return Err(ExprError::UnsupportedFunction(
66                            "unsupported specifier type 'L'".to_owned(),
67                        ));
68                    }
69                }
70            }
71        }
72    }
73    Ok(())
74}
75
/// Conversion a format specifier applies to its argument to produce output text.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SpecifierType {
    /// `%s`: writes the argument as a simple string. A null argument is treated as an empty
    /// string.
    SimpleString,
    /// `%I`: writes the argument as an SQL identifier, double-quoting it when necessary (same as
    /// `quote_ident`). A null argument is an error.
    SqlIdentifier,
    /// `%L`: writes the argument as a quoted SQL literal. A null argument is shown as the string
    /// NULL, without quotes (same as `quote_nullable`).
    SqlLiteral,
}

impl TryFrom<char> for SpecifierType {
    type Error = ();

    /// Maps the character that follows `%` to its conversion kind.
    ///
    /// Only `s`, `I` and `L` are recognized (case-sensitive); any other character yields
    /// `Err(())`.
    fn try_from(ch: char) -> std::result::Result<Self, Self::Error> {
        let kind = match ch {
            's' => Self::SimpleString,
            'I' => Self::SqlIdentifier,
            'L' => Self::SqlLiteral,
            _ => return Err(()),
        };
        Ok(kind)
    }
}
102
103#[derive(Debug)]
104struct Specifier {
105    // TODO: support position, flags and width.
106    ty: SpecifierType,
107}
108
109#[derive(Debug)]
110enum FormatterNode {
111    Specifier(Specifier),
112    Literal(String),
113}
114
115#[derive(Debug)]
116struct Formatter {
117    nodes: Vec<FormatterNode>,
118}
119
120#[derive(Debug, thiserror::Error)]
121enum ParseFormatError {
122    #[error("unrecognized format() type specifier \"{0}\"")]
123    UnrecognizedSpecifierType(char),
124    #[error("unterminated format() type specifier")]
125    UnterminatedSpecifier,
126}
127
128impl FromStr for Formatter {
129    type Err = ParseFormatError;
130
131    /// Parse the format string into a high-efficient representation.
132    /// <https://www.postgresql.org/docs/current/functions-string.html#FUNCTIONS-STRING-FORMAT>
133    fn from_str(format: &str) -> std::result::Result<Self, ParseFormatError> {
134        // 8 is a good magic number here, it can cover an input like 'Testing %s, %s, %s, %%'.
135        let mut nodes = Vec::with_capacity(8);
136        let mut after_percent = false;
137        let mut literal = String::with_capacity(8);
138        for c in format.chars() {
139            if after_percent && c == '%' {
140                literal.push('%');
141                after_percent = false;
142            } else if after_percent {
143                // TODO: support position, flags and width.
144                if let Ok(ty) = SpecifierType::try_from(c) {
145                    if !literal.is_empty() {
146                        nodes.push(FormatterNode::Literal(std::mem::take(&mut literal)));
147                    }
148                    nodes.push(FormatterNode::Specifier(Specifier { ty }));
149                } else {
150                    return Err(ParseFormatError::UnrecognizedSpecifierType(c));
151                }
152                after_percent = false;
153            } else if c == '%' {
154                after_percent = true;
155            } else {
156                literal.push(c);
157            }
158        }
159
160        if after_percent {
161            return Err(ParseFormatError::UnterminatedSpecifier);
162        }
163
164        if !literal.is_empty() {
165            nodes.push(FormatterNode::Literal(literal));
166        }
167
168        Ok(Formatter { nodes })
169    }
170}
171
#[cfg(test)]
mod tests {
    use risingwave_common::array::DataChunk;
    use risingwave_common::row::Row;
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::build_from_pretty;

    /// Exercises `format` through both the vectorized (`eval`) and row-based (`eval_row`) paths.
    ///
    /// Input columns are the format string followed by two arguments, and the last column is
    /// the expected output; `.` denotes null. Besides plain `%s` and `%I`, the cases cover:
    /// - `%%` escaping;
    /// - a null argument to `%s` (formatted as an empty string);
    /// - a null format string (which yields null).
    #[tokio::test]
    async fn test_format() {
        let format = build_from_pretty("(format:varchar $0:varchar $1:varchar $2:varchar)");
        let (input, expected) = DataChunk::from_pretty(
            "T          T       T       T
             Hello%s    World   .       HelloWorld
             %s%s       Hello   World   HelloWorld
             %I         &&      .       \"&&\"
             100%%      .       .       100%
             %s|%s      a       .       a|
             .          a       b       .",
        )
        .split_column_at(3);

        // test eval
        let output = format.eval(&input).await.unwrap();
        assert_eq!(&output, expected.column_at(0));

        // test eval_row
        for (row, expected) in input.rows().zip_eq_debug(expected.rows()) {
            let result = format.eval_row(&row.to_owned_row()).await.unwrap();
            assert_eq!(result, expected.datum_at(0).to_owned_datum());
        }
    }
}