1use risingwave_common::bail_not_implemented;
16use risingwave_common::types::{
17 DataType, DateTimeField, Decimal, Interval, MapType, ScalarImpl, StructType,
18};
19use risingwave_sqlparser::ast::{DateTimeField as AstDateTimeField, Expr, Value};
20use thiserror_ext::AsReport;
21
22use crate::binder::Binder;
23use crate::error::{ErrorCode, Result};
24use crate::expr::{Expr as _, ExprImpl, ExprType, FunctionCall, Literal, align_types};
25
26impl Binder {
27 pub fn bind_value(&mut self, value: &Value) -> Result<Literal> {
28 match value {
29 Value::Number(s) => self.bind_number(s.clone()),
30 Value::SingleQuotedString(s) => self.bind_string(s),
31 Value::CstyleEscapedString(s) => self.bind_string(&s.value),
32 Value::Boolean(b) => self.bind_bool(*b),
33 Value::Null => Ok(Literal::new_untyped(None)),
36 Value::Interval {
37 value,
38 leading_field,
39 leading_precision: None,
41 last_field: None,
42 fractional_seconds_precision: None,
43 } => self.bind_interval(value, *leading_field),
44 _ => bail_not_implemented!("value: {:?}", value),
45 }
46 }
47
48 pub(super) fn bind_string(&mut self, s: &str) -> Result<Literal> {
49 Ok(Literal::new_untyped(Some(s.to_owned())))
50 }
51
52 fn bind_bool(&mut self, b: bool) -> Result<Literal> {
53 Ok(Literal::new(Some(ScalarImpl::Bool(b)), DataType::Boolean))
54 }
55
56 fn bind_number(&mut self, mut s: String) -> Result<Literal> {
57 let prefix_start = match s.starts_with('-') {
58 true => 1,
59 false => 0,
60 };
61 let base = match prefix_start + 2 <= s.len() {
62 true => match &s[prefix_start..prefix_start + 2] {
63 "0x" => 16,
65 "0o" => 8,
66 "0b" => 2,
67 _ => 10,
68 },
69 false => 10,
70 };
71 if base != 10 {
72 s.replace_range(prefix_start..prefix_start + 2, "");
73 }
74
75 let (data, data_type) = if let Ok(int_32) = i32::from_str_radix(&s, base) {
76 (Some(ScalarImpl::Int32(int_32)), DataType::Int32)
77 } else if let Ok(int_64) = i64::from_str_radix(&s, base) {
78 (Some(ScalarImpl::Int64(int_64)), DataType::Int64)
79 } else if let Ok(decimal) = Decimal::from_str_radix(&s, base) {
80 (Some(ScalarImpl::Decimal(decimal)), DataType::Decimal)
82 } else if let Some(scientific) = Decimal::from_scientific(&s) {
83 (Some(ScalarImpl::Decimal(scientific)), DataType::Decimal)
84 } else {
85 return Err(ErrorCode::BindError(format!("Number {s} overflows")).into());
86 };
87 Ok(Literal::new(data, data_type))
88 }
89
90 fn bind_interval(
91 &mut self,
92 s: &str,
93 leading_field: Option<AstDateTimeField>,
94 ) -> Result<Literal> {
95 let interval =
96 Interval::parse_with_fields(s, leading_field.map(Self::bind_date_time_field))
97 .map_err(|e| ErrorCode::BindError(e.to_report_string()))?;
98 let datum = Some(ScalarImpl::Interval(interval));
99 let literal = Literal::new(datum, DataType::Interval);
100
101 Ok(literal)
102 }
103
104 pub(crate) fn bind_date_time_field(field: AstDateTimeField) -> DateTimeField {
105 match field {
108 AstDateTimeField::Year => DateTimeField::Year,
109 AstDateTimeField::Month => DateTimeField::Month,
110 AstDateTimeField::Day => DateTimeField::Day,
111 AstDateTimeField::Hour => DateTimeField::Hour,
112 AstDateTimeField::Minute => DateTimeField::Minute,
113 AstDateTimeField::Second => DateTimeField::Second,
114 }
115 }
116
117 pub(super) fn bind_array(&mut self, exprs: &[Expr]) -> Result<ExprImpl> {
119 if exprs.is_empty() {
120 return Err(ErrorCode::BindError("cannot determine type of empty array\nHINT: Explicitly cast to the desired type, for example ARRAY[]::integer[].".into()).into());
121 }
122 let mut exprs = exprs
123 .iter()
124 .map(|e| self.bind_expr_inner(e))
125 .collect::<Result<Vec<ExprImpl>>>()?;
126 let element_type = align_types(exprs.iter_mut())?;
127 let expr: ExprImpl =
128 FunctionCall::new_unchecked(ExprType::Array, exprs, DataType::list(element_type))
129 .into();
130 Ok(expr)
131 }
132
133 pub(super) fn bind_map(&mut self, entries: &[(Expr, Expr)]) -> Result<ExprImpl> {
134 if entries.is_empty() {
135 return Err(ErrorCode::BindError("cannot determine type of empty map\nHINT: Explicitly cast to the desired type, for example MAP{}::map(int,int).".into()).into());
136 }
137 let mut keys = Vec::with_capacity(entries.len());
138 let mut values = Vec::with_capacity(entries.len());
139 for (k, v) in entries {
140 keys.push(self.bind_expr_inner(k)?);
141 values.push(self.bind_expr_inner(v)?);
142 }
143 let key_type = align_types(keys.iter_mut())?;
144 let value_type = align_types(values.iter_mut())?;
145
146 let keys: ExprImpl =
147 FunctionCall::new_unchecked(ExprType::Array, keys, DataType::list(key_type.clone()))
148 .into();
149 let values: ExprImpl = FunctionCall::new_unchecked(
150 ExprType::Array,
151 values,
152 DataType::list(value_type.clone()),
153 )
154 .into();
155
156 let expr: ExprImpl = FunctionCall::new_unchecked(
157 ExprType::MapFromKeyValues,
158 vec![keys, values],
159 DataType::Map(MapType::from_kv(key_type, value_type)),
160 )
161 .into();
162 Ok(expr)
163 }
164
165 pub(super) fn bind_array_cast(
166 &mut self,
167 exprs: &[Expr],
168 element_type: &DataType,
169 ) -> Result<ExprImpl> {
170 let exprs = exprs
171 .iter()
172 .map(|e| self.bind_cast_inner(e, element_type))
173 .collect::<Result<Vec<ExprImpl>>>()?;
174
175 let expr: ExprImpl = FunctionCall::new_unchecked(
176 ExprType::Array,
177 exprs,
178 DataType::list(element_type.clone()),
179 )
180 .into();
181 Ok(expr)
182 }
183
184 pub(super) fn bind_map_cast(
185 &mut self,
186 entries: &[(Expr, Expr)],
187 map_type: &MapType,
188 ) -> Result<ExprImpl> {
189 let mut keys = Vec::with_capacity(entries.len());
190 let mut values = Vec::with_capacity(entries.len());
191 for (k, v) in entries {
192 keys.push(self.bind_cast_inner(k, map_type.key())?);
193 values.push(self.bind_cast_inner(v, map_type.value())?);
194 }
195
196 let keys: ExprImpl = FunctionCall::new_unchecked(
197 ExprType::Array,
198 keys,
199 DataType::list(map_type.key().clone()),
200 )
201 .into();
202 let values: ExprImpl = FunctionCall::new_unchecked(
203 ExprType::Array,
204 values,
205 DataType::list(map_type.value().clone()),
206 )
207 .into();
208
209 let expr: ExprImpl = FunctionCall::new_unchecked(
210 ExprType::MapFromKeyValues,
211 vec![keys, values],
212 DataType::Map(map_type.clone()),
213 )
214 .into();
215 Ok(expr)
216 }
217
218 pub(super) fn bind_index(&mut self, obj: &Expr, index: &Expr) -> Result<ExprImpl> {
219 let obj = self.bind_expr_inner(obj)?;
220 match obj.return_type() {
221 DataType::List(l) => Ok(FunctionCall::new_unchecked(
222 ExprType::ArrayAccess,
223 vec![obj, self.bind_expr_inner(index)?],
224 l.into_elem(),
225 )
226 .into()),
227 DataType::Map(m) => Ok(FunctionCall::new_unchecked(
228 ExprType::MapAccess,
229 vec![obj, self.bind_expr_inner(index)?],
230 m.value().clone(),
231 )
232 .into()),
233 data_type => Err(ErrorCode::BindError(format!(
234 "index operator applied to type {}, which is not a list or map",
235 data_type
236 ))
237 .into()),
238 }
239 }
240
241 pub(super) fn bind_array_range_index(
242 &mut self,
243 obj: &Expr,
244 start: Option<&Expr>,
245 end: Option<&Expr>,
246 ) -> Result<ExprImpl> {
247 let obj = self.bind_expr_inner(obj)?;
248 let start = match start {
249 None => ExprImpl::literal_int(1),
250 Some(expr) => self
251 .bind_expr_inner(expr)?
252 .cast_implicit(&DataType::Int32)?,
253 };
254 let end = match end {
257 None => ExprImpl::literal_int(i32::MAX),
258 Some(expr) => self
259 .bind_expr_inner(expr)?
260 .cast_implicit(&DataType::Int32)?,
261 };
262 match obj.return_type() {
263 t @ DataType::List(_) => Ok(FunctionCall::new_unchecked(
264 ExprType::ArrayRangeAccess,
265 vec![obj, start, end],
266 t,
267 )
268 .into()),
269 data_type => Err(ErrorCode::BindError(format!(
270 "array range index applied to type {}, which is not a list",
271 data_type
272 ))
273 .into()),
274 }
275 }
276
277 pub(super) fn bind_row(&mut self, exprs: &[Expr]) -> Result<ExprImpl> {
279 let exprs = exprs
280 .iter()
281 .map(|e| self.bind_expr_inner(e))
282 .collect::<Result<Vec<ExprImpl>>>()?;
283 let data_type = StructType::row_expr_type(exprs.iter().map(|e| e.return_type())).into();
284 let expr: ExprImpl = FunctionCall::new_unchecked(ExprType::Row, exprs, data_type).into();
285 Ok(expr)
286 }
287}
288
289#[cfg(test)]
290mod tests {
291 use risingwave_common::types::test_utils::IntervalTestExt;
292 use risingwave_expr::expr::build_from_prost;
293 use risingwave_sqlparser::ast::Value::Number;
294
295 use super::*;
296 use crate::binder::test_utils::mock_binder;
297 use crate::expr::Expr;
298
299 #[tokio::test]
300 async fn test_bind_value() {
301 use std::str::FromStr;
302
303 let mut binder = mock_binder();
304 let values = [
305 "1",
306 "111111111111111",
307 "111111111.111111",
308 "111111111111111111111111",
309 "0.111111",
310 "-0.01",
311 ];
312 let data = [
313 Some(ScalarImpl::Int32(1)),
314 Some(ScalarImpl::Int64(111111111111111)),
315 Some(ScalarImpl::Decimal(
316 Decimal::from_str("111111111.111111").unwrap(),
317 )),
318 Some(ScalarImpl::Decimal(
319 Decimal::from_str("111111111111111111111111").unwrap(),
320 )),
321 Some(ScalarImpl::Decimal(Decimal::from_str("0.111111").unwrap())),
322 Some(ScalarImpl::Decimal(Decimal::from_str("-0.01").unwrap())),
323 ];
324 let data_type = [
325 DataType::Int32,
326 DataType::Int64,
327 DataType::Decimal,
328 DataType::Decimal,
329 DataType::Decimal,
330 DataType::Decimal,
331 ];
332
333 for i in 0..values.len() {
334 let value = Value::Number(String::from(values[i]));
335 let res = binder.bind_value(&value).unwrap();
336 let ans = Literal::new(data[i].clone(), data_type[i].clone());
337 assert_eq!(res, ans);
338 }
339 }
340
341 #[tokio::test]
342 async fn test_bind_radix() {
343 let mut binder = mock_binder();
344
345 for (input, expected) in [
346 ("0x42e3", ScalarImpl::Int32(0x42e3)),
347 ("-0x40", ScalarImpl::Int32(-0x40)),
348 ("0b1101", ScalarImpl::Int32(0b1101)),
349 ("-0b101", ScalarImpl::Int32(-0b101)),
350 ("0o664", ScalarImpl::Int32(0o664)),
351 ("-0o755", ScalarImpl::Int32(-0o755)),
352 ("2147483647", ScalarImpl::Int32(2147483647)),
353 ("2147483648", ScalarImpl::Int64(2147483648)),
354 ("-2147483648", ScalarImpl::Int32(-2147483648)),
355 ("0x7fffffff", ScalarImpl::Int32(0x7fffffff)),
356 ("0x80000000", ScalarImpl::Int64(0x80000000)),
357 ("-0x80000000", ScalarImpl::Int32(-0x80000000)),
358 ] {
359 let lit = binder.bind_number(input.into()).unwrap();
360 assert_eq!(lit.get_data().as_ref().unwrap(), &expected);
361 }
362 }
363
364 #[tokio::test]
365 async fn test_bind_scientific_number() {
366 use std::str::FromStr;
367
368 let mut binder = mock_binder();
369 let values = [
370 ("1e6"),
371 ("1.25e6"),
372 ("1.25e1"),
373 ("1e-2"),
374 ("1.25e-2"),
375 ("1e15"),
376 ];
377 let data = [
378 Some(ScalarImpl::Decimal(Decimal::from_str("1000000").unwrap())),
379 Some(ScalarImpl::Decimal(Decimal::from_str("1250000").unwrap())),
380 Some(ScalarImpl::Decimal(Decimal::from_str("12.5").unwrap())),
381 Some(ScalarImpl::Decimal(Decimal::from_str("0.01").unwrap())),
382 Some(ScalarImpl::Decimal(Decimal::from_str("0.0125").unwrap())),
383 Some(ScalarImpl::Decimal(
384 Decimal::from_str("1000000000000000").unwrap(),
385 )),
386 ];
387 let data_type = [
388 DataType::Decimal,
389 DataType::Decimal,
390 DataType::Decimal,
391 DataType::Decimal,
392 DataType::Decimal,
393 DataType::Decimal,
394 ];
395
396 for i in 0..values.len() {
397 let res = binder.bind_value(&Number(values[i].to_owned())).unwrap();
398 let ans = Literal::new(data[i].clone(), data_type[i].clone());
399 assert_eq!(res, ans);
400 }
401 }
402
/// An `Array` call keeps its `Int32` list return type after a round trip
/// through the protobuf representation.
#[test]
fn test_array_expr() {
    let array: ExprImpl = FunctionCall::new_unchecked(
        ExprType::Array,
        vec![ExprImpl::literal_int(11)],
        DataType::Int32.list(),
    )
    .into();

    let proto = array.to_expr_proto();
    let rebuilt = build_from_prost(&proto).unwrap();

    if let DataType::List(list) = rebuilt.return_type() {
        assert_eq!(list.into_elem(), DataType::Int32);
    } else {
        panic!("unexpected type");
    }
}
420
/// An `ArrayAccess` call over an `Int32` array keeps its `Int32` return type
/// after a round trip through the protobuf representation.
#[test]
fn test_array_index_expr() {
    let elements = vec![ExprImpl::literal_int(11), ExprImpl::literal_int(22)];
    let array: ExprImpl =
        FunctionCall::new_unchecked(ExprType::Array, elements, DataType::Int32.list()).into();

    let access: ExprImpl = FunctionCall::new_unchecked(
        ExprType::ArrayAccess,
        vec![array, ExprImpl::literal_int(1)],
        DataType::Int32,
    )
    .into();

    let proto = access.to_expr_proto();
    let rebuilt = build_from_prost(&proto).unwrap();
    assert_eq!(rebuilt.return_type(), DataType::Int32);
}
441
442 #[tokio::test]
443 async fn test_bind_interval() {
444 let mut binder = mock_binder();
445 let values = [
446 "1 hour",
447 "1 h",
448 "1 year",
449 "6 second",
450 "2 minutes",
451 "1 month",
452 ];
453 let data = [
454 Literal::new(
455 Some(ScalarImpl::Interval(Interval::from_minutes(60))),
456 DataType::Interval,
457 ),
458 Literal::new(
459 Some(ScalarImpl::Interval(Interval::from_minutes(60))),
460 DataType::Interval,
461 ),
462 Literal::new(
463 Some(ScalarImpl::Interval(Interval::from_ymd(1, 0, 0))),
464 DataType::Interval,
465 ),
466 Literal::new(
467 Some(ScalarImpl::Interval(Interval::from_millis(6 * 1000))),
468 DataType::Interval,
469 ),
470 Literal::new(
471 Some(ScalarImpl::Interval(Interval::from_minutes(2))),
472 DataType::Interval,
473 ),
474 Literal::new(
475 Some(ScalarImpl::Interval(Interval::from_month(1))),
476 DataType::Interval,
477 ),
478 ];
479
480 for i in 0..values.len() {
481 let value = Value::Interval {
482 value: values[i].to_owned(),
483 leading_field: None,
484 leading_precision: None,
485 last_field: None,
486 fractional_seconds_precision: None,
487 };
488 assert_eq!(binder.bind_value(&value).unwrap(), data[i]);
489 }
490 }
491}