1use itertools::Itertools;
16use risingwave_common::bail_not_implemented;
17use risingwave_common::types::{
18 DataType, DateTimeField, Decimal, Interval, MapType, ScalarImpl, StructType,
19};
20use risingwave_sqlparser::ast::{DateTimeField as AstDateTimeField, Expr, Value};
21use thiserror_ext::AsReport;
22
23use crate::binder::Binder;
24use crate::error::{ErrorCode, Result};
25use crate::expr::{Expr as _, ExprImpl, ExprType, FunctionCall, Literal, align_types};
26
27impl Binder {
28 pub fn bind_value(&mut self, value: &Value) -> Result<Literal> {
29 match value {
30 Value::Number(s) => self.bind_number(s.clone()),
31 Value::SingleQuotedString(s) => self.bind_string(s),
32 Value::CstyleEscapedString(s) => self.bind_string(&s.value),
33 Value::Boolean(b) => self.bind_bool(*b),
34 Value::Null => Ok(Literal::new_untyped(None)),
37 Value::Interval {
38 value,
39 leading_field,
40 leading_precision: None,
42 last_field: None,
43 fractional_seconds_precision: None,
44 } => self.bind_interval(value, *leading_field),
45 _ => bail_not_implemented!("value: {:?}", value),
46 }
47 }
48
49 pub(super) fn bind_string(&mut self, s: &str) -> Result<Literal> {
50 Ok(Literal::new_untyped(Some(s.to_owned())))
51 }
52
53 fn bind_bool(&mut self, b: bool) -> Result<Literal> {
54 Ok(Literal::new(Some(ScalarImpl::Bool(b)), DataType::Boolean))
55 }
56
57 fn bind_number(&mut self, mut s: String) -> Result<Literal> {
58 let prefix_start = match s.starts_with('-') {
59 true => 1,
60 false => 0,
61 };
62 let base = match prefix_start + 2 <= s.len() {
63 true => match &s[prefix_start..prefix_start + 2] {
64 "0x" => 16,
66 "0o" => 8,
67 "0b" => 2,
68 _ => 10,
69 },
70 false => 10,
71 };
72 if base != 10 {
73 s.replace_range(prefix_start..prefix_start + 2, "");
74 }
75
76 let (data, data_type) = if let Ok(int_32) = i32::from_str_radix(&s, base) {
77 (Some(ScalarImpl::Int32(int_32)), DataType::Int32)
78 } else if let Ok(int_64) = i64::from_str_radix(&s, base) {
79 (Some(ScalarImpl::Int64(int_64)), DataType::Int64)
80 } else if let Ok(decimal) = Decimal::from_str_radix(&s, base) {
81 (Some(ScalarImpl::Decimal(decimal)), DataType::Decimal)
83 } else if let Some(scientific) = Decimal::from_scientific(&s) {
84 (Some(ScalarImpl::Decimal(scientific)), DataType::Decimal)
85 } else {
86 return Err(ErrorCode::BindError(format!("Number {s} overflows")).into());
87 };
88 Ok(Literal::new(data, data_type))
89 }
90
91 fn bind_interval(
92 &mut self,
93 s: &str,
94 leading_field: Option<AstDateTimeField>,
95 ) -> Result<Literal> {
96 let interval =
97 Interval::parse_with_fields(s, leading_field.map(Self::bind_date_time_field))
98 .map_err(|e| ErrorCode::BindError(e.to_report_string()))?;
99 let datum = Some(ScalarImpl::Interval(interval));
100 let literal = Literal::new(datum, DataType::Interval);
101
102 Ok(literal)
103 }
104
105 pub(crate) fn bind_date_time_field(field: AstDateTimeField) -> DateTimeField {
106 match field {
109 AstDateTimeField::Year => DateTimeField::Year,
110 AstDateTimeField::Month => DateTimeField::Month,
111 AstDateTimeField::Day => DateTimeField::Day,
112 AstDateTimeField::Hour => DateTimeField::Hour,
113 AstDateTimeField::Minute => DateTimeField::Minute,
114 AstDateTimeField::Second => DateTimeField::Second,
115 }
116 }
117
118 pub(super) fn bind_array(&mut self, exprs: &[Expr]) -> Result<ExprImpl> {
120 if exprs.is_empty() {
121 return Err(ErrorCode::BindError("cannot determine type of empty array\nHINT: Explicitly cast to the desired type, for example ARRAY[]::integer[].".into()).into());
122 }
123 let mut exprs = exprs
124 .iter()
125 .map(|e| self.bind_expr_inner(e))
126 .collect::<Result<Vec<ExprImpl>>>()?;
127 let element_type = align_types(exprs.iter_mut())?;
128 let expr: ExprImpl =
129 FunctionCall::new_unchecked(ExprType::Array, exprs, DataType::list(element_type))
130 .into();
131 Ok(expr)
132 }
133
134 pub(super) fn bind_map(&mut self, entries: &[(Expr, Expr)]) -> Result<ExprImpl> {
135 if entries.is_empty() {
136 return Err(ErrorCode::BindError("cannot determine type of empty map\nHINT: Explicitly cast to the desired type, for example MAP{}::map(int,int).".into()).into());
137 }
138 let mut keys = Vec::with_capacity(entries.len());
139 let mut values = Vec::with_capacity(entries.len());
140 for (k, v) in entries {
141 keys.push(self.bind_expr_inner(k)?);
142 values.push(self.bind_expr_inner(v)?);
143 }
144 let key_type = align_types(keys.iter_mut())?;
145 let value_type = align_types(values.iter_mut())?;
146
147 let keys: ExprImpl =
148 FunctionCall::new_unchecked(ExprType::Array, keys, DataType::list(key_type.clone()))
149 .into();
150 let values: ExprImpl = FunctionCall::new_unchecked(
151 ExprType::Array,
152 values,
153 DataType::list(value_type.clone()),
154 )
155 .into();
156
157 let expr: ExprImpl = FunctionCall::new_unchecked(
158 ExprType::MapFromKeyValues,
159 vec![keys, values],
160 DataType::Map(MapType::from_kv(key_type, value_type)),
161 )
162 .into();
163 Ok(expr)
164 }
165
166 pub(super) fn bind_array_cast(
167 &mut self,
168 exprs: &[Expr],
169 element_type: &DataType,
170 ) -> Result<ExprImpl> {
171 let exprs = exprs
172 .iter()
173 .map(|e| self.bind_cast_inner(e, element_type))
174 .collect::<Result<Vec<ExprImpl>>>()?;
175
176 let expr: ExprImpl = FunctionCall::new_unchecked(
177 ExprType::Array,
178 exprs,
179 DataType::list(element_type.clone()),
180 )
181 .into();
182 Ok(expr)
183 }
184
185 pub(super) fn bind_map_cast(
186 &mut self,
187 entries: &[(Expr, Expr)],
188 map_type: &MapType,
189 ) -> Result<ExprImpl> {
190 let mut keys = Vec::with_capacity(entries.len());
191 let mut values = Vec::with_capacity(entries.len());
192 for (k, v) in entries {
193 keys.push(self.bind_cast_inner(k, map_type.key())?);
194 values.push(self.bind_cast_inner(v, map_type.value())?);
195 }
196
197 let keys: ExprImpl = FunctionCall::new_unchecked(
198 ExprType::Array,
199 keys,
200 DataType::list(map_type.key().clone()),
201 )
202 .into();
203 let values: ExprImpl = FunctionCall::new_unchecked(
204 ExprType::Array,
205 values,
206 DataType::list(map_type.value().clone()),
207 )
208 .into();
209
210 let expr: ExprImpl = FunctionCall::new_unchecked(
211 ExprType::MapFromKeyValues,
212 vec![keys, values],
213 DataType::Map(map_type.clone()),
214 )
215 .into();
216 Ok(expr)
217 }
218
219 pub(super) fn bind_index(&mut self, obj: &Expr, index: &Expr) -> Result<ExprImpl> {
220 let obj = self.bind_expr_inner(obj)?;
221 match obj.return_type() {
222 DataType::List(l) => Ok(FunctionCall::new_unchecked(
223 ExprType::ArrayAccess,
224 vec![obj, self.bind_expr_inner(index)?],
225 l.into_elem(),
226 )
227 .into()),
228 DataType::Map(m) => Ok(FunctionCall::new_unchecked(
229 ExprType::MapAccess,
230 vec![obj, self.bind_expr_inner(index)?],
231 m.value().clone(),
232 )
233 .into()),
234 data_type => Err(ErrorCode::BindError(format!(
235 "index operator applied to type {}, which is not a list or map",
236 data_type
237 ))
238 .into()),
239 }
240 }
241
242 pub(super) fn bind_array_range_index(
243 &mut self,
244 obj: &Expr,
245 start: Option<&Expr>,
246 end: Option<&Expr>,
247 ) -> Result<ExprImpl> {
248 let obj = self.bind_expr_inner(obj)?;
249 let start = match start {
250 None => ExprImpl::literal_int(1),
251 Some(expr) => self
252 .bind_expr_inner(expr)?
253 .cast_implicit(&DataType::Int32)?,
254 };
255 let end = match end {
258 None => ExprImpl::literal_int(i32::MAX),
259 Some(expr) => self
260 .bind_expr_inner(expr)?
261 .cast_implicit(&DataType::Int32)?,
262 };
263 match obj.return_type() {
264 t @ DataType::List(_) => Ok(FunctionCall::new_unchecked(
265 ExprType::ArrayRangeAccess,
266 vec![obj, start, end],
267 t,
268 )
269 .into()),
270 data_type => Err(ErrorCode::BindError(format!(
271 "array range index applied to type {}, which is not a list",
272 data_type
273 ))
274 .into()),
275 }
276 }
277
278 pub(super) fn bind_row(&mut self, exprs: &[Expr]) -> Result<ExprImpl> {
280 let exprs = exprs
281 .iter()
282 .map(|e| self.bind_expr_inner(e))
283 .collect::<Result<Vec<ExprImpl>>>()?;
284 let data_type =
285 StructType::unnamed(exprs.iter().map(|e| e.return_type()).collect_vec()).into();
286 let expr: ExprImpl = FunctionCall::new_unchecked(ExprType::Row, exprs, data_type).into();
287 Ok(expr)
288 }
289}
290
#[cfg(test)]
mod tests {
    use risingwave_common::types::test_utils::IntervalTestExt;
    use risingwave_expr::expr::build_from_prost;
    use risingwave_sqlparser::ast::Value::Number;

    use super::*;
    use crate::binder::test_utils::mock_binder;
    use crate::expr::Expr;

    /// Plain numeric literals are bound to the narrowest type that holds them.
    #[tokio::test]
    async fn test_bind_value() {
        use std::str::FromStr;

        let decimal = |text: &str| Some(ScalarImpl::Decimal(Decimal::from_str(text).unwrap()));
        let cases = [
            ("1", Some(ScalarImpl::Int32(1)), DataType::Int32),
            (
                "111111111111111",
                Some(ScalarImpl::Int64(111111111111111)),
                DataType::Int64,
            ),
            (
                "111111111.111111",
                decimal("111111111.111111"),
                DataType::Decimal,
            ),
            (
                "111111111111111111111111",
                decimal("111111111111111111111111"),
                DataType::Decimal,
            ),
            ("0.111111", decimal("0.111111"), DataType::Decimal),
            ("-0.01", decimal("-0.01"), DataType::Decimal),
        ];

        let mut binder = mock_binder();
        for (text, datum, data_type) in cases {
            let bound = binder.bind_value(&Value::Number(text.to_owned())).unwrap();
            assert_eq!(bound, Literal::new(datum, data_type));
        }
    }

    /// Hex, binary and octal prefixes (optionally negated) and the Int32/Int64 boundary.
    #[tokio::test]
    async fn test_bind_radix() {
        let cases = [
            ("0x42e3", ScalarImpl::Int32(0x42e3)),
            ("-0x40", ScalarImpl::Int32(-0x40)),
            ("0b1101", ScalarImpl::Int32(0b1101)),
            ("-0b101", ScalarImpl::Int32(-0b101)),
            ("0o664", ScalarImpl::Int32(0o664)),
            ("-0o755", ScalarImpl::Int32(-0o755)),
            ("2147483647", ScalarImpl::Int32(2147483647)),
            ("2147483648", ScalarImpl::Int64(2147483648)),
            ("-2147483648", ScalarImpl::Int32(-2147483648)),
            ("0x7fffffff", ScalarImpl::Int32(0x7fffffff)),
            ("0x80000000", ScalarImpl::Int64(0x80000000)),
            ("-0x80000000", ScalarImpl::Int32(-0x80000000)),
        ];

        let mut binder = mock_binder();
        for (literal, expected) in cases {
            let bound = binder.bind_number(literal.to_owned()).unwrap();
            assert_eq!(bound.get_data().as_ref(), Some(&expected));
        }
    }

    /// Scientific notation is bound as a decimal equal to its expanded form.
    #[tokio::test]
    async fn test_bind_scientific_number() {
        use std::str::FromStr;

        let cases = [
            ("1e6", "1000000"),
            ("1.25e6", "1250000"),
            ("1.25e1", "12.5"),
            ("1e-2", "0.01"),
            ("1.25e-2", "0.0125"),
            ("1e15", "1000000000000000"),
        ];

        let mut binder = mock_binder();
        for (scientific, expanded) in cases {
            let expected = Literal::new(
                Some(ScalarImpl::Decimal(Decimal::from_str(expanded).unwrap())),
                DataType::Decimal,
            );
            let bound = binder.bind_value(&Number(scientific.to_owned())).unwrap();
            assert_eq!(bound, expected);
        }
    }

    /// An array expression round-trips through protobuf with its element type intact.
    #[test]
    fn test_array_expr() {
        let array: ExprImpl = FunctionCall::new_unchecked(
            ExprType::Array,
            vec![ExprImpl::literal_int(11)],
            DataType::Int32.list(),
        )
        .into();

        let built = build_from_prost(&array.to_expr_proto()).unwrap();
        let DataType::List(list) = built.return_type() else {
            panic!("unexpected type");
        };
        assert_eq!(list.into_elem(), DataType::Int32);
    }

    /// Indexing into an int array round-trips through protobuf as an Int32 expression.
    #[test]
    fn test_array_index_expr() {
        let elements = vec![ExprImpl::literal_int(11), ExprImpl::literal_int(22)];
        let array_expr: ExprImpl =
            FunctionCall::new_unchecked(ExprType::Array, elements, DataType::Int32.list()).into();

        let access: ExprImpl = FunctionCall::new_unchecked(
            ExprType::ArrayAccess,
            vec![array_expr, ExprImpl::literal_int(1)],
            DataType::Int32,
        )
        .into();

        let built = build_from_prost(&access.to_expr_proto()).unwrap();
        assert_eq!(built.return_type(), DataType::Int32);
    }

    /// Unqualified interval strings in several unit spellings.
    #[tokio::test]
    async fn test_bind_interval() {
        let cases = [
            ("1 hour", Interval::from_minutes(60)),
            ("1 h", Interval::from_minutes(60)),
            ("1 year", Interval::from_ymd(1, 0, 0)),
            ("6 second", Interval::from_millis(6 * 1000)),
            ("2 minutes", Interval::from_minutes(2)),
            ("1 month", Interval::from_month(1)),
        ];

        let mut binder = mock_binder();
        for (text, interval) in cases {
            let value = Value::Interval {
                value: text.to_owned(),
                leading_field: None,
                leading_precision: None,
                last_field: None,
                fractional_seconds_precision: None,
            };
            let expected = Literal::new(Some(ScalarImpl::Interval(interval)), DataType::Interval);
            assert_eq!(binder.bind_value(&value).unwrap(), expected);
        }
    }
}