risingwave_expr_impl/scalar/
format.rs1use std::fmt::Write;
16use std::str::FromStr;
17
18use risingwave_common::row::Row;
19use risingwave_common::types::{ScalarRefImpl, ToText};
20use risingwave_expr::{ExprError, Result, function};
21
22use super::string::quote_ident;
23
24#[function(
40 "format(varchar, variadic anyarray) -> varchar",
41 prebuild = "Formatter::from_str($0).map_err(|e| ExprError::Parse(e.to_report_string().into()))?"
42)]
43fn format(row: impl Row, formatter: &Formatter, writer: &mut impl Write) -> Result<()> {
44 let mut args = row.iter();
45 for node in &formatter.nodes {
46 match node {
47 FormatterNode::Literal(literal) => writer.write_str(literal).unwrap(),
48 FormatterNode::Specifier(sp) => {
49 let arg = args.next().ok_or(ExprError::TooFewArguments)?;
50 match sp.ty {
51 SpecifierType::SimpleString => {
52 if let Some(scalar) = arg {
53 scalar.write(writer).unwrap();
54 }
55 }
56 SpecifierType::SqlIdentifier => match arg {
57 Some(ScalarRefImpl::Utf8(arg)) => quote_ident(arg, writer),
58 _ => {
59 return Err(ExprError::UnsupportedFunction(
60 "unsupported data for specifier type 'I'".to_owned(),
61 ));
62 }
63 },
64 SpecifierType::SqlLiteral => {
65 return Err(ExprError::UnsupportedFunction(
66 "unsupported specifier type 'L'".to_owned(),
67 ));
68 }
69 }
70 }
71 }
72 }
73 Ok(())
74}
75
/// The conversion requested by a `format()` specifier, i.e. the character after `%`.
///
/// See `format` for how each conversion is applied to its argument.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SpecifierType {
    /// `%s`: the argument's text form (`format` writes nothing for a NULL argument).
    SimpleString,
    /// `%I`: the argument quoted as an SQL identifier via `quote_ident`.
    SqlIdentifier,
    /// `%L`: in PostgreSQL, the argument quoted as an SQL literal, with NULL shown as bare
    /// `NULL` (equivalent to `quote_nullable`).
    SqlLiteral,
}
89
90impl TryFrom<char> for SpecifierType {
91 type Error = ();
92
93 fn try_from(c: char) -> std::result::Result<Self, Self::Error> {
94 match c {
95 's' => Ok(SpecifierType::SimpleString),
96 'I' => Ok(SpecifierType::SqlIdentifier),
97 'L' => Ok(SpecifierType::SqlLiteral),
98 _ => Err(()),
99 }
100 }
101}
102
103#[derive(Debug)]
104struct Specifier {
105 ty: SpecifierType,
107}
108
109#[derive(Debug)]
110enum FormatterNode {
111 Specifier(Specifier),
112 Literal(String),
113}
114
115#[derive(Debug)]
116struct Formatter {
117 nodes: Vec<FormatterNode>,
118}
119
120#[derive(Debug, thiserror::Error)]
121enum ParseFormatError {
122 #[error("unrecognized format() type specifier \"{0}\"")]
123 UnrecognizedSpecifierType(char),
124 #[error("unterminated format() type specifier")]
125 UnterminatedSpecifier,
126}
127
128impl FromStr for Formatter {
129 type Err = ParseFormatError;
130
131 fn from_str(format: &str) -> std::result::Result<Self, ParseFormatError> {
134 let mut nodes = Vec::with_capacity(8);
136 let mut after_percent = false;
137 let mut literal = String::with_capacity(8);
138 for c in format.chars() {
139 if after_percent && c == '%' {
140 literal.push('%');
141 after_percent = false;
142 } else if after_percent {
143 if let Ok(ty) = SpecifierType::try_from(c) {
145 if !literal.is_empty() {
146 nodes.push(FormatterNode::Literal(std::mem::take(&mut literal)));
147 }
148 nodes.push(FormatterNode::Specifier(Specifier { ty }));
149 } else {
150 return Err(ParseFormatError::UnrecognizedSpecifierType(c));
151 }
152 after_percent = false;
153 } else if c == '%' {
154 after_percent = true;
155 } else {
156 literal.push(c);
157 }
158 }
159
160 if after_percent {
161 return Err(ParseFormatError::UnterminatedSpecifier);
162 }
163
164 if !literal.is_empty() {
165 nodes.push(FormatterNode::Literal(literal));
166 }
167
168 Ok(Formatter { nodes })
169 }
170}
171
#[cfg(test)]
mod tests {
    use risingwave_common::array::DataChunk;
    use risingwave_common::row::Row;
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::ToOwnedDatum;
    use risingwave_common::util::iter_util::ZipEqDebug;
    use risingwave_expr::expr::build_from_pretty;

    use super::{Formatter, FormatterNode, ParseFormatError, SpecifierType};

    /// End-to-end evaluation through both the vectorized and the row-at-a-time paths.
    #[tokio::test]
    async fn test_format() {
        let format = build_from_pretty("(format:varchar $0:varchar $1:varchar $2:varchar)");
        let (input, expected) = DataChunk::from_pretty(
            "T          T       T       T
             Hello%s    World   .       HelloWorld
             %s%s       Hello   World   HelloWorld
             %I         &&      .       \"&&\"
             .          a       b       .",
        )
        .split_column_at(3);

        // Vectorized evaluation.
        let output = format.eval(&input).await.unwrap();
        assert_eq!(&output, expected.column_at(0));

        // Row-by-row evaluation must agree with the vectorized result.
        for (row, expected) in input.rows().zip_eq_debug(expected.rows()) {
            let result = format.eval_row(&row.to_owned_row()).await.unwrap();
            assert_eq!(result, expected.datum_at(0).to_owned_datum());
        }
    }

    /// Template parsing: `%%` escapes, literal/specifier splitting, and both error kinds.
    #[test]
    fn test_parse_formatter() {
        // `%%` collapses into the surrounding literal; a specifier splits the literal runs.
        let formatter: Formatter = "100%% %s!".parse().unwrap();
        assert!(matches!(
            formatter.nodes.as_slice(),
            [
                FormatterNode::Literal(head),
                FormatterNode::Specifier(sp),
                FormatterNode::Literal(tail),
            ] if head == "100% " && sp.ty == SpecifierType::SimpleString && tail == "!"
        ));

        // An empty template yields no nodes at all.
        assert!("".parse::<Formatter>().unwrap().nodes.is_empty());

        assert!(matches!(
            "abc%".parse::<Formatter>(),
            Err(ParseFormatError::UnterminatedSpecifier)
        ));
        assert!(matches!(
            "%d".parse::<Formatter>(),
            Err(ParseFormatError::UnrecognizedSpecifierType('d'))
        ));

        // The message text mirrors PostgreSQL's.
        assert_eq!(
            "%d".parse::<Formatter>().unwrap_err().to_string(),
            "unrecognized format() type specifier \"d\""
        );
    }
}