risingwave_expr_impl/scalar/
format.rs1use std::str::FromStr;
16
17use risingwave_common::row::Row;
18use risingwave_common::types::{ScalarRefImpl, ToText};
19use risingwave_expr::{ExprError, Result, function};
20
21use super::string::quote_ident;
22
23#[function(
39 "format(varchar, variadic anyarray) -> varchar",
40 prebuild = "Formatter::from_str($0).map_err(|e| ExprError::Parse(e.to_report_string().into()))?"
41)]
42fn format(row: impl Row, formatter: &Formatter, writer: &mut impl std::fmt::Write) -> Result<()> {
43 let mut args = row.iter();
44 for node in &formatter.nodes {
45 match node {
46 FormatterNode::Literal(literal) => writer.write_str(literal).unwrap(),
47 FormatterNode::Specifier(sp) => {
48 let arg = args.next().ok_or(ExprError::TooFewArguments)?;
49 match sp.ty {
50 SpecifierType::SimpleString => {
51 if let Some(scalar) = arg {
52 scalar.write(writer).unwrap();
53 }
54 }
55 SpecifierType::SqlIdentifier => match arg {
56 Some(ScalarRefImpl::Utf8(arg)) => quote_ident(arg, writer),
57 _ => {
58 return Err(ExprError::UnsupportedFunction(
59 "unsupported data for specifier type 'I'".to_owned(),
60 ));
61 }
62 },
63 SpecifierType::SqlLiteral => {
64 return Err(ExprError::UnsupportedFunction(
65 "unsupported specifier type 'L'".to_owned(),
66 ));
67 }
68 }
69 }
70 }
71 }
72 Ok(())
73}
74
/// The conversion selected by the character that follows `%` in a `format()` template.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SpecifierType {
    /// `s`: the argument is written in its text form; `format` writes nothing for a NULL
    /// argument.
    SimpleString,
    /// `I`: the argument is written as a quoted SQL identifier; `format` only accepts a
    /// non-NULL string argument for it.
    SqlIdentifier,
    /// `L`: recognized when parsing, but `format` currently rejects it as unsupported.
    SqlLiteral,
}

impl TryFrom<char> for SpecifierType {
    /// No detail is carried: the caller already knows which character was rejected.
    type Error = ();

    /// Maps a specifier character to its conversion type.
    ///
    /// The match is case-sensitive: only `'s'`, `'I'` and `'L'` are accepted; every other
    /// character yields `Err(())`.
    fn try_from(c: char) -> std::result::Result<Self, Self::Error> {
        match c {
            's' => Ok(SpecifierType::SimpleString),
            'I' => Ok(SpecifierType::SqlIdentifier),
            'L' => Ok(SpecifierType::SqlLiteral),
            _ => Err(()),
        }
    }
}
101
102#[derive(Debug)]
103struct Specifier {
104 ty: SpecifierType,
106}
107
108#[derive(Debug)]
109enum FormatterNode {
110 Specifier(Specifier),
111 Literal(String),
112}
113
114#[derive(Debug)]
115struct Formatter {
116 nodes: Vec<FormatterNode>,
117}
118
119#[derive(Debug, thiserror::Error)]
120enum ParseFormatError {
121 #[error("unrecognized format() type specifier \"{0}\"")]
122 UnrecognizedSpecifierType(char),
123 #[error("unterminated format() type specifier")]
124 UnterminatedSpecifier,
125}
126
127impl FromStr for Formatter {
128 type Err = ParseFormatError;
129
130 fn from_str(format: &str) -> std::result::Result<Self, ParseFormatError> {
133 let mut nodes = Vec::with_capacity(8);
135 let mut after_percent = false;
136 let mut literal = String::with_capacity(8);
137 for c in format.chars() {
138 if after_percent && c == '%' {
139 literal.push('%');
140 after_percent = false;
141 } else if after_percent {
142 if let Ok(ty) = SpecifierType::try_from(c) {
144 if !literal.is_empty() {
145 nodes.push(FormatterNode::Literal(std::mem::take(&mut literal)));
146 }
147 nodes.push(FormatterNode::Specifier(Specifier { ty }));
148 } else {
149 return Err(ParseFormatError::UnrecognizedSpecifierType(c));
150 }
151 after_percent = false;
152 } else if c == '%' {
153 after_percent = true;
154 } else {
155 literal.push(c);
156 }
157 }
158
159 if after_percent {
160 return Err(ParseFormatError::UnterminatedSpecifier);
161 }
162
163 if !literal.is_empty() {
164 nodes.push(FormatterNode::Literal(literal));
165 }
166
167 Ok(Formatter { nodes })
168 }
169}
170
#[cfg(test)]
mod tests {
    use risingwave_common::array::DataChunk;
    use risingwave_common::row::Row;
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::build_from_pretty;

    /// Evaluates `format` over a small chunk, both chunk-at-a-time and row-at-a-time, and
    /// compares against the expected fourth column.
    #[tokio::test]
    async fn test_format() {
        let format = build_from_pretty("(format:varchar $0:varchar $1:varchar $2:varchar)");
        // Columns: template, two arguments, expected output.
        // NOTE(review): `.` presumably denotes NULL in the pretty-chunk syntax; confirm.
        let (input, expected) = DataChunk::from_pretty(
            "T          T       T       T
             Hello%s    World   .       HelloWorld
             %s%s       Hello   World   HelloWorld
             %I         &&      .       \"&&\"
             .          a       b       .",
        )
        .split_column_at(3);

        // Whole-chunk evaluation.
        let output = format.eval(&input).await.unwrap();
        assert_eq!(&output, expected.column_at(0));

        // Row-by-row evaluation must agree with the same expectations.
        for (row, expected) in input.rows().zip_eq_debug(expected.rows()) {
            let result = format.eval_row(&row.to_owned_row()).await.unwrap();
            assert_eq!(result, expected.datum_at(0).to_owned_datum());
        }
    }
}