1use itertools::Itertools;
16use risingwave_common::catalog::PG_CATALOG_SCHEMA_NAME;
17use risingwave_common::types::{DataType, MapType, StructType};
18use risingwave_common::util::iter_util::zip_eq_fast;
19use risingwave_common::{bail_no_function, bail_not_implemented, not_implemented};
20use risingwave_sqlparser::ast::{
21 Array, BinaryOperator, DataType as AstDataType, EscapeChar, Expr, Function, JsonPredicateType,
22 ObjectName, Query, TrimWhereField, UnaryOperator,
23};
24
25use crate::binder::Binder;
26use crate::binder::expr::function::is_sys_function_without_args;
27use crate::error::{ErrorCode, Result, RwError};
28use crate::expr::{Expr as _, ExprImpl, ExprType, FunctionCall, InputRef, Parameter, SubqueryKind};
29use crate::handler::create_sql_function::SQL_UDF_PATTERN;
30
31mod binary_op;
32mod column;
33mod function;
34mod order_by;
35mod subquery;
36mod value;
37
/// Minimum number of `WHEN` arms a `CASE` expression must have before `bind_case` tries to
/// rewrite it into a `ConstantLookup` call (see `check_bind_case_optimization`).
const CASE_WHEN_ARMS_OPTIMIZE_LIMIT: usize = 30_usize;
44
45impl Binder {
46 pub fn bind_expr(&mut self, expr: Expr) -> Result<ExprImpl> {
54 self.bind_expr_inner(expr.clone()).map_err(|e| {
55 RwError::from(ErrorCode::BindErrorRoot {
56 expr: expr.to_string(),
57 error: Box::new(e),
58 })
59 })
60 }
61
62 fn bind_expr_inner(&mut self, expr: Expr) -> Result<ExprImpl> {
63 match expr {
64 Expr::Value(v) => Ok(ExprImpl::Literal(Box::new(self.bind_value(v)?))),
66 Expr::TypedString { data_type, value } => {
67 let s: ExprImpl = self.bind_string(value)?.into();
68 s.cast_explicit(bind_data_type(&data_type)?)
69 .map_err(Into::into)
70 }
71 Expr::Row(exprs) => self.bind_row(exprs),
72 Expr::Identifier(ident) => {
74 if is_sys_function_without_args(&ident) {
75 self.bind_function(Function::no_arg(ObjectName(vec![ident])))
80 } else if let Some(ref lambda_args) = self.context.lambda_args {
81 if let Some((arg_idx, arg_type)) = lambda_args.get(&ident.real_value()) {
84 Ok(InputRef::new(*arg_idx, arg_type.clone()).into())
85 } else {
86 Err(
87 ErrorCode::ItemNotFound(format!("Unknown arg: {}", ident.real_value()))
88 .into(),
89 )
90 }
91 } else if let Some(ctx) = self.secure_compare_context.as_ref() {
92 if ident.real_value() == *"headers" {
98 Ok(InputRef::new(0, DataType::Jsonb).into())
99 } else if ctx.secret_name.is_some()
100 && ident.real_value() == *ctx.secret_name.as_ref().unwrap()
101 {
102 Ok(InputRef::new(1, DataType::Varchar).into())
103 } else if ident.real_value() == ctx.column_name {
104 Ok(InputRef::new(2, DataType::Bytea).into())
105 } else {
106 Err(
107 ErrorCode::ItemNotFound(format!("Unknown arg: {}", ident.real_value()))
108 .into(),
109 )
110 }
111 } else {
112 self.bind_column(&[ident])
113 }
114 }
115 Expr::CompoundIdentifier(idents) => self.bind_column(&idents),
116 Expr::FieldIdentifier(field_expr, idents) => {
117 self.bind_single_field_column(*field_expr, &idents)
118 }
119 Expr::UnaryOp { op, expr } => self.bind_unary_expr(op, *expr),
121 Expr::BinaryOp { left, op, right } => self.bind_binary_op(*left, op, *right),
122 Expr::Nested(expr) => self.bind_expr_inner(*expr),
123 Expr::Array(Array { elem: exprs, .. }) => self.bind_array(exprs),
124 Expr::Index { obj, index } => self.bind_index(*obj, *index),
125 Expr::ArrayRangeIndex { obj, start, end } => {
126 self.bind_array_range_index(*obj, start, end)
127 }
128 Expr::Function(f) => self.bind_function(f),
129 Expr::Subquery(q) => self.bind_subquery_expr(*q, SubqueryKind::Scalar),
130 Expr::Exists(q) => self.bind_subquery_expr(*q, SubqueryKind::Existential),
131 Expr::InSubquery {
132 expr,
133 subquery,
134 negated,
135 } => self.bind_in_subquery(*expr, *subquery, negated),
136 Expr::Cast { expr, data_type } => self.bind_cast(*expr, data_type),
138 Expr::IsNull(expr) => self.bind_is_operator(ExprType::IsNull, *expr),
139 Expr::IsNotNull(expr) => self.bind_is_operator(ExprType::IsNotNull, *expr),
140 Expr::IsTrue(expr) => self.bind_is_operator(ExprType::IsTrue, *expr),
141 Expr::IsNotTrue(expr) => self.bind_is_operator(ExprType::IsNotTrue, *expr),
142 Expr::IsFalse(expr) => self.bind_is_operator(ExprType::IsFalse, *expr),
143 Expr::IsNotFalse(expr) => self.bind_is_operator(ExprType::IsNotFalse, *expr),
144 Expr::IsUnknown(expr) => self.bind_is_unknown(ExprType::IsNull, *expr),
145 Expr::IsNotUnknown(expr) => self.bind_is_unknown(ExprType::IsNotNull, *expr),
146 Expr::IsDistinctFrom(left, right) => self.bind_distinct_from(*left, *right),
147 Expr::IsNotDistinctFrom(left, right) => self.bind_not_distinct_from(*left, *right),
148 Expr::IsJson {
149 expr,
150 negated,
151 item_type,
152 unique_keys: false,
153 } => self.bind_is_json(*expr, negated, item_type),
154 Expr::Case {
155 operand,
156 conditions,
157 results,
158 else_result,
159 } => self.bind_case(operand, conditions, results, else_result),
160 Expr::Between {
161 expr,
162 negated,
163 low,
164 high,
165 } => self.bind_between(*expr, negated, *low, *high),
166 Expr::Like {
167 negated,
168 expr,
169 pattern,
170 escape_char,
171 } => self.bind_like(ExprType::Like, *expr, negated, *pattern, escape_char),
172 Expr::ILike {
173 negated,
174 expr,
175 pattern,
176 escape_char,
177 } => self.bind_like(ExprType::ILike, *expr, negated, *pattern, escape_char),
178 Expr::SimilarTo {
179 expr,
180 negated,
181 pattern,
182 escape_char,
183 } => self.bind_similar_to(*expr, negated, *pattern, escape_char),
184 Expr::InList {
185 expr,
186 list,
187 negated,
188 } => self.bind_in_list(*expr, list, negated),
189 Expr::Extract { field, expr } => self.bind_extract(field, *expr),
191 Expr::AtTimeZone {
192 timestamp,
193 time_zone,
194 } => self.bind_at_time_zone(*timestamp, *time_zone),
195 Expr::Trim {
197 expr,
198 trim_where,
199 trim_what,
200 } => self.bind_trim(*expr, trim_where, trim_what),
201 Expr::Substring {
202 expr,
203 substring_from,
204 substring_for,
205 } => self.bind_substring(*expr, substring_from, substring_for),
206 Expr::Position { substring, string } => self.bind_position(*substring, *string),
207 Expr::Overlay {
208 expr,
209 new_substring,
210 start,
211 count,
212 } => self.bind_overlay(*expr, *new_substring, *start, count),
213 Expr::Parameter { index } => self.bind_parameter(index),
214 Expr::Collate { expr, collation } => self.bind_collate(*expr, collation),
215 Expr::ArraySubquery(q) => self.bind_subquery_expr(*q, SubqueryKind::Array),
216 Expr::Map { entries } => self.bind_map(entries),
217 Expr::IsJson {
218 unique_keys: true, ..
219 }
220 | Expr::SomeOp(_)
221 | Expr::AllOp(_)
222 | Expr::TryCast { .. }
223 | Expr::GroupingSets(_)
224 | Expr::Cube(_)
225 | Expr::Rollup(_)
226 | Expr::LambdaFunction { .. } => {
227 bail_not_implemented!(issue = 112, "unsupported expression {:?}", expr)
228 }
229 }
230 }
231
232 pub(super) fn bind_extract(&mut self, field: String, expr: Expr) -> Result<ExprImpl> {
233 let arg = self.bind_expr_inner(expr)?;
234 let arg_type = arg.return_type();
235 Ok(FunctionCall::new(
236 ExprType::Extract,
237 vec![self.bind_string(field.clone())?.into(), arg],
238 )
239 .map_err(|_| {
240 not_implemented!(
241 issue = 112,
242 "function extract({} from {:?}) doesn't exist",
243 field,
244 arg_type
245 )
246 })?
247 .into())
248 }
249
250 pub(super) fn bind_at_time_zone(&mut self, input: Expr, time_zone: Expr) -> Result<ExprImpl> {
251 let input = self.bind_expr_inner(input)?;
252 let time_zone = self.bind_expr_inner(time_zone)?;
253 FunctionCall::new(ExprType::AtTimeZone, vec![input, time_zone]).map(Into::into)
254 }
255
256 pub(super) fn bind_in_list(
257 &mut self,
258 expr: Expr,
259 list: Vec<Expr>,
260 negated: bool,
261 ) -> Result<ExprImpl> {
262 let left = self.bind_expr_inner(expr)?;
263 let mut bound_expr_list = vec![left.clone()];
264 let mut non_const_exprs = vec![];
265 for elem in list {
266 let expr = self.bind_expr_inner(elem)?;
267 match expr.is_const() {
268 true => bound_expr_list.push(expr),
269 false => non_const_exprs.push(expr),
270 }
271 }
272 let mut ret = FunctionCall::new(ExprType::In, bound_expr_list)?.into();
273 for expr in non_const_exprs {
275 ret = FunctionCall::new(
276 ExprType::Or,
277 vec![
278 ret,
279 FunctionCall::new(ExprType::Equal, vec![left.clone(), expr])?.into(),
280 ],
281 )?
282 .into();
283 }
284 if negated {
285 Ok(FunctionCall::new_unchecked(ExprType::Not, vec![ret], DataType::Boolean).into())
286 } else {
287 Ok(ret)
288 }
289 }
290
291 pub(super) fn bind_in_subquery(
292 &mut self,
293 expr: Expr,
294 subquery: Query,
295 negated: bool,
296 ) -> Result<ExprImpl> {
297 let bound_expr = self.bind_expr_inner(expr)?;
298 let bound_subquery = self.bind_subquery_expr(subquery, SubqueryKind::In(bound_expr))?;
299 if negated {
300 Ok(
301 FunctionCall::new_unchecked(ExprType::Not, vec![bound_subquery], DataType::Boolean)
302 .into(),
303 )
304 } else {
305 Ok(bound_subquery)
306 }
307 }
308
309 pub(super) fn bind_is_json(
310 &mut self,
311 expr: Expr,
312 negated: bool,
313 item_type: JsonPredicateType,
314 ) -> Result<ExprImpl> {
315 let mut args = vec![self.bind_expr_inner(expr)?];
316 let type_symbol = match item_type {
318 JsonPredicateType::Value => None,
319 JsonPredicateType::Array => Some("ARRAY"),
320 JsonPredicateType::Object => Some("OBJECT"),
321 JsonPredicateType::Scalar => Some("SCALAR"),
322 };
323 if let Some(s) = type_symbol {
324 args.push(ExprImpl::literal_varchar(s.into()));
325 }
326
327 let is_json = FunctionCall::new(ExprType::IsJson, args)?.into();
328 if negated {
329 Ok(FunctionCall::new(ExprType::Not, vec![is_json])?.into())
330 } else {
331 Ok(is_json)
332 }
333 }
334
335 pub(super) fn bind_unary_expr(&mut self, op: UnaryOperator, expr: Expr) -> Result<ExprImpl> {
336 let func_type = match op {
337 UnaryOperator::Not => ExprType::Not,
338 UnaryOperator::Minus => ExprType::Neg,
339 UnaryOperator::PGAbs => ExprType::Abs,
340 UnaryOperator::PGBitwiseNot => ExprType::BitwiseNot,
341 UnaryOperator::Plus => {
342 return self.rewrite_positive(expr);
343 }
344 UnaryOperator::PGSquareRoot => ExprType::Sqrt,
345 UnaryOperator::PGCubeRoot => ExprType::Cbrt,
346 _ => bail_not_implemented!(issue = 112, "unsupported unary expression: {:?}", op),
347 };
348 let expr = self.bind_expr_inner(expr)?;
349 FunctionCall::new(func_type, vec![expr]).map(|f| f.into())
350 }
351
352 fn rewrite_positive(&mut self, expr: Expr) -> Result<ExprImpl> {
354 let expr = self.bind_expr_inner(expr)?;
355 let return_type = expr.return_type();
356 if return_type.is_numeric() {
357 return Ok(expr);
358 }
359 Err(ErrorCode::InvalidInputSyntax(format!("+ {:?}", return_type)).into())
360 }
361
362 pub(super) fn bind_trim(
363 &mut self,
364 expr: Expr,
365 trim_where: Option<TrimWhereField>,
367 trim_what: Option<Box<Expr>>,
368 ) -> Result<ExprImpl> {
369 let mut inputs = vec![self.bind_expr_inner(expr)?];
370 let func_type = match trim_where {
371 Some(TrimWhereField::Both) => ExprType::Trim,
372 Some(TrimWhereField::Leading) => ExprType::Ltrim,
373 Some(TrimWhereField::Trailing) => ExprType::Rtrim,
374 None => ExprType::Trim,
375 };
376 if let Some(t) = trim_what {
377 inputs.push(self.bind_expr_inner(*t)?);
378 }
379 Ok(FunctionCall::new(func_type, inputs)?.into())
380 }
381
382 fn bind_substring(
383 &mut self,
384 expr: Expr,
385 substring_from: Option<Box<Expr>>,
386 substring_for: Option<Box<Expr>>,
387 ) -> Result<ExprImpl> {
388 let mut args = vec![
389 self.bind_expr_inner(expr)?,
390 match substring_from {
391 Some(expr) => self.bind_expr_inner(*expr)?,
392 None => ExprImpl::literal_int(1),
393 },
394 ];
395 if let Some(expr) = substring_for {
396 args.push(self.bind_expr_inner(*expr)?);
397 }
398 FunctionCall::new(ExprType::Substr, args).map(|f| f.into())
399 }
400
401 fn bind_position(&mut self, substring: Expr, string: Expr) -> Result<ExprImpl> {
402 let args = vec![
403 self.bind_expr_inner(string)?,
405 self.bind_expr_inner(substring)?,
406 ];
407 FunctionCall::new(ExprType::Position, args).map(Into::into)
408 }
409
410 fn bind_overlay(
411 &mut self,
412 expr: Expr,
413 new_substring: Expr,
414 start: Expr,
415 count: Option<Box<Expr>>,
416 ) -> Result<ExprImpl> {
417 let mut args = vec![
418 self.bind_expr_inner(expr)?,
419 self.bind_expr_inner(new_substring)?,
420 self.bind_expr_inner(start)?,
421 ];
422 if let Some(count) = count {
423 args.push(self.bind_expr_inner(*count)?);
424 }
425 FunctionCall::new(ExprType::Overlay, args).map(|f| f.into())
426 }
427
428 fn bind_parameter(&mut self, index: u64) -> Result<ExprImpl> {
429 if self.udf_context.global_count() != 0 {
434 if let Some(expr) = self.udf_context.get_expr(&format!("${index}")) {
435 return Ok(expr.clone());
436 }
437 return Err(ErrorCode::BindError(format!(
440 "{SQL_UDF_PATTERN} failed to find unnamed parameter ${index}"
441 ))
442 .into());
443 }
444
445 Ok(Parameter::new(index, self.param_types.clone()).into())
446 }
447
448 pub(super) fn bind_between(
450 &mut self,
451 expr: Expr,
452 negated: bool,
453 low: Expr,
454 high: Expr,
455 ) -> Result<ExprImpl> {
456 let expr = self.bind_expr_inner(expr)?;
457 let low = self.bind_expr_inner(low)?;
458 let high = self.bind_expr_inner(high)?;
459
460 let func_call = if negated {
461 FunctionCall::new_unchecked(
463 ExprType::Or,
464 vec![
465 FunctionCall::new(ExprType::LessThan, vec![expr.clone(), low])?.into(),
466 FunctionCall::new(ExprType::GreaterThan, vec![expr, high])?.into(),
467 ],
468 DataType::Boolean,
469 )
470 } else {
471 FunctionCall::new_unchecked(
473 ExprType::And,
474 vec![
475 FunctionCall::new(ExprType::GreaterThanOrEqual, vec![expr.clone(), low])?
476 .into(),
477 FunctionCall::new(ExprType::LessThanOrEqual, vec![expr, high])?.into(),
478 ],
479 DataType::Boolean,
480 )
481 };
482
483 Ok(func_call.into())
484 }
485
486 fn bind_like(
487 &mut self,
488 expr_type: ExprType,
489 expr: Expr,
490 negated: bool,
491 pattern: Expr,
492 escape_char: Option<EscapeChar>,
493 ) -> Result<ExprImpl> {
494 if matches!(pattern, Expr::AllOp(_) | Expr::SomeOp(_)) {
495 if escape_char.is_some() {
496 bail_not_implemented!(
502 "LIKE with both ALL|ANY pattern and escape character is not supported"
503 )
504 }
505 let op = match (expr_type, negated) {
507 (ExprType::Like, false) => BinaryOperator::PGLikeMatch,
508 (ExprType::Like, true) => BinaryOperator::PGNotLikeMatch,
509 (ExprType::ILike, false) => BinaryOperator::PGILikeMatch,
510 (ExprType::ILike, true) => BinaryOperator::PGNotILikeMatch,
511 _ => unreachable!(),
512 };
513 return self.bind_binary_op(expr, op, pattern);
514 }
515 let expr = self.bind_expr_inner(expr)?;
516 let pattern = self.bind_expr_inner(pattern)?;
517 match (expr.return_type(), pattern.return_type()) {
518 (DataType::Varchar, DataType::Varchar) => {}
519 (string_ty, pattern_ty) => match expr_type {
520 ExprType::Like => bail_no_function!("like({}, {})", string_ty, pattern_ty),
521 ExprType::ILike => bail_no_function!("ilike({}, {})", string_ty, pattern_ty),
522 _ => unreachable!(),
523 },
524 }
525 let args = match escape_char {
526 Some(escape_char) => {
527 let escape_char = ExprImpl::literal_varchar(escape_char.to_string());
528 vec![expr, pattern, escape_char]
529 }
530 None => vec![expr, pattern],
531 };
532 let func_call = FunctionCall::new_unchecked(expr_type, args, DataType::Boolean);
533 let func_call = if negated {
534 FunctionCall::new_unchecked(ExprType::Not, vec![func_call.into()], DataType::Boolean)
535 } else {
536 func_call
537 };
538 Ok(func_call.into())
539 }
540
541 pub(super) fn bind_similar_to(
543 &mut self,
544 expr: Expr,
545 negated: bool,
546 pattern: Expr,
547 escape_char: Option<EscapeChar>,
548 ) -> Result<ExprImpl> {
549 let expr = self.bind_expr_inner(expr)?;
550 let pattern = self.bind_expr_inner(pattern)?;
551
552 let esc_inputs = if let Some(escape_char) = escape_char {
553 let escape_char = ExprImpl::literal_varchar(escape_char.to_string());
554 vec![pattern, escape_char]
555 } else {
556 vec![pattern]
557 };
558
559 let esc_call =
560 FunctionCall::new_unchecked(ExprType::SimilarToEscape, esc_inputs, DataType::Varchar);
561
562 let regex_call = FunctionCall::new_unchecked(
563 ExprType::RegexpEq,
564 vec![expr, esc_call.into()],
565 DataType::Boolean,
566 );
567 let func_call = if negated {
568 FunctionCall::new_unchecked(ExprType::Not, vec![regex_call.into()], DataType::Boolean)
569 } else {
570 regex_call
571 };
572
573 Ok(func_call.into())
574 }
575
576 fn check_constant_case_when_optimization(
579 &mut self,
580 conditions: Vec<Expr>,
581 results_expr: Vec<ExprImpl>,
582 operand: Option<Box<Expr>>,
583 fallback: Option<ExprImpl>,
584 constant_case_when_eval_inputs: &mut Vec<ExprImpl>,
585 ) -> bool {
586 let operand_value;
588
589 if let Some(operand) = operand {
590 let Ok(operand) = self.bind_expr_inner(*operand) else {
591 return false;
592 };
593 if !operand.is_const() {
594 return false;
595 }
596 operand_value = operand;
597 } else {
598 return false;
599 }
600
601 for (condition, result) in zip_eq_fast(conditions, results_expr) {
602 if let Expr::Value(_) = condition.clone() {
603 let Ok(res) = self.bind_expr_inner(condition.clone()) else {
604 return false;
605 };
606 if res == operand_value {
608 constant_case_when_eval_inputs.push(result);
609 return true;
610 }
611 } else {
612 return false;
613 }
614 }
615
616 debug_assert!(
618 constant_case_when_eval_inputs.is_empty(),
619 "expect `inputs` to be empty"
620 );
621
622 let Some(fallback) = fallback else {
623 return false;
624 };
625
626 constant_case_when_eval_inputs.push(fallback);
627 true
628 }
629
630 fn compare_or_set(col_expr: &mut Option<Expr>, test_expr: Expr) -> bool {
633 let Expr::Identifier(test_ident) = test_expr else {
634 return false;
635 };
636 if let Some(expr) = col_expr {
637 let Expr::Identifier(ident) = expr else {
638 return false;
639 };
640 if ident.real_value() != test_ident.real_value() {
641 return false;
642 }
643 } else {
644 *col_expr = Some(Expr::Identifier(test_ident));
645 }
646 true
647 }
648
649 fn check_invariant(left: Expr, op: BinaryOperator, right: Expr) -> bool {
653 if op != BinaryOperator::Eq {
654 return false;
655 }
656 if let Expr::Identifier(_) = left {
657 let Expr::Value(_) = right else {
659 return false;
660 };
661 } else {
662 let Expr::Value(_) = left else {
664 return false;
665 };
666 let Expr::Identifier(_) = right else {
667 return false;
668 };
669 }
670 true
671 }
672
673 fn try_extract_simple_form(
678 &mut self,
679 ident_expr: Expr,
680 constant_expr: Expr,
681 column_expr: &mut Option<Expr>,
682 inputs: &mut Vec<ExprImpl>,
683 ) -> bool {
684 if !Self::compare_or_set(column_expr, ident_expr) {
685 return false;
686 }
687 let Ok(bound_expr) = self.bind_expr_inner(constant_expr) else {
688 return false;
689 };
690 inputs.push(bound_expr);
691 true
692 }
693
694 fn check_convert_simple_form(
698 &mut self,
699 conditions: Vec<Expr>,
700 results_expr: Vec<ExprImpl>,
701 fallback: Option<ExprImpl>,
702 constant_lookup_inputs: &mut Vec<ExprImpl>,
703 ) -> bool {
704 let mut column_expr = None;
705
706 for (condition, result) in zip_eq_fast(conditions, results_expr) {
707 if let Expr::BinaryOp { left, op, right } = condition {
708 if !Self::check_invariant(*(left.clone()), op.clone(), *(right.clone())) {
709 return false;
710 }
711 if let Expr::Identifier(_) = *(left.clone()) {
712 if !self.try_extract_simple_form(
713 *left,
714 *right,
715 &mut column_expr,
716 constant_lookup_inputs,
717 ) {
718 return false;
719 }
720 } else if !self.try_extract_simple_form(
721 *right,
722 *left,
723 &mut column_expr,
724 constant_lookup_inputs,
725 ) {
726 return false;
727 }
728 constant_lookup_inputs.push(result);
729 } else {
730 return false;
731 }
732 }
733
734 let Some(operand) = column_expr else {
736 return false;
737 };
738 let Ok(bound_operand) = self.bind_expr_inner(operand) else {
739 return false;
740 };
741 constant_lookup_inputs.insert(0, bound_operand);
742
743 if let Some(expr) = fallback {
745 constant_lookup_inputs.push(expr);
746 }
747
748 true
749 }
750
751 fn check_bind_case_optimization(
755 &mut self,
756 conditions: Vec<Expr>,
757 results_expr: Vec<ExprImpl>,
758 operand: Option<Box<Expr>>,
759 fallback: Option<ExprImpl>,
760 constant_lookup_inputs: &mut Vec<ExprImpl>,
761 ) -> bool {
762 if conditions.len() < CASE_WHEN_ARMS_OPTIMIZE_LIMIT {
763 return false;
764 }
765
766 if let Some(operand) = operand {
767 let Ok(operand) = self.bind_expr_inner(*operand) else {
768 return false;
769 };
770 if operand.is_const() {
774 return false;
775 }
776 constant_lookup_inputs.push(operand);
777 } else {
778 return self.check_convert_simple_form(
781 conditions,
782 results_expr,
783 fallback,
784 constant_lookup_inputs,
785 );
786 }
787
788 for (condition, result) in zip_eq_fast(conditions, results_expr) {
789 if let Expr::Value(_) = condition.clone() {
790 let Ok(input) = self.bind_expr_inner(condition.clone()) else {
791 return false;
792 };
793 constant_lookup_inputs.push(input);
794 } else {
795 return false;
798 }
799
800 constant_lookup_inputs.push(result);
801 }
802
803 if let Some(expr) = fallback {
805 constant_lookup_inputs.push(expr);
806 }
807
808 true
809 }
810
811 pub(super) fn bind_case(
812 &mut self,
813 operand: Option<Box<Expr>>,
814 conditions: Vec<Expr>,
815 results: Vec<Expr>,
816 else_result: Option<Box<Expr>>,
817 ) -> Result<ExprImpl> {
818 let mut inputs = Vec::new();
819 let results_expr: Vec<ExprImpl> = results
820 .into_iter()
821 .map(|expr| self.bind_expr_inner(expr))
822 .collect::<Result<_>>()?;
823 let else_result_expr = else_result
824 .map(|expr| self.bind_expr_inner(*expr))
825 .transpose()?;
826
827 let mut constant_lookup_inputs = Vec::new();
828 let mut constant_case_when_eval_inputs = Vec::new();
829
830 let constant_case_when_flag = self.check_constant_case_when_optimization(
831 conditions.clone(),
832 results_expr.clone(),
833 operand.clone(),
834 else_result_expr.clone(),
835 &mut constant_case_when_eval_inputs,
836 );
837
838 if constant_case_when_flag {
839 if constant_case_when_eval_inputs.len() != 1 {
841 return Err(ErrorCode::BindError(
842 "expect `constant_case_when_eval_inputs` only contains a single bound expression".to_owned()
843 )
844 .into());
845 }
846 return Ok(constant_case_when_eval_inputs[0].take());
848 }
849
850 let optimize_flag = self.check_bind_case_optimization(
852 conditions.clone(),
853 results_expr.clone(),
854 operand.clone(),
855 else_result_expr.clone(),
856 &mut constant_lookup_inputs,
857 );
858
859 if optimize_flag {
860 return Ok(FunctionCall::new(ExprType::ConstantLookup, constant_lookup_inputs)?.into());
861 }
862
863 for (condition, result) in zip_eq_fast(conditions, results_expr) {
864 let condition = match operand {
865 Some(ref t) => Expr::BinaryOp {
866 left: t.clone(),
867 op: BinaryOperator::Eq,
868 right: Box::new(condition),
869 },
870 None => condition,
871 };
872 inputs.push(
873 self.bind_expr_inner(condition)
874 .and_then(|expr| expr.enforce_bool_clause("CASE WHEN"))?,
875 );
876 inputs.push(result);
877 }
878
879 if let Some(expr) = else_result_expr {
881 inputs.push(expr);
882 }
883
884 if inputs.iter().any(ExprImpl::has_table_function) {
885 return Err(
886 ErrorCode::BindError("table functions are not allowed in CASE".into()).into(),
887 );
888 }
889
890 Ok(FunctionCall::new(ExprType::Case, inputs)?.into())
891 }
892
893 pub(super) fn bind_is_operator(&mut self, func_type: ExprType, expr: Expr) -> Result<ExprImpl> {
894 let expr = self.bind_expr_inner(expr)?;
895 Ok(FunctionCall::new(func_type, vec![expr])?.into())
896 }
897
898 pub(super) fn bind_is_unknown(&mut self, func_type: ExprType, expr: Expr) -> Result<ExprImpl> {
899 let expr = self
900 .bind_expr_inner(expr)?
901 .cast_implicit(DataType::Boolean)?;
902 Ok(FunctionCall::new(func_type, vec![expr])?.into())
903 }
904
905 pub(super) fn bind_distinct_from(&mut self, left: Expr, right: Expr) -> Result<ExprImpl> {
906 let left = self.bind_expr_inner(left)?;
907 let right = self.bind_expr_inner(right)?;
908 let func_call = FunctionCall::new(ExprType::IsDistinctFrom, vec![left, right]);
909 Ok(func_call?.into())
910 }
911
912 pub(super) fn bind_not_distinct_from(&mut self, left: Expr, right: Expr) -> Result<ExprImpl> {
913 let left = self.bind_expr_inner(left)?;
914 let right = self.bind_expr_inner(right)?;
915 let func_call = FunctionCall::new(ExprType::IsNotDistinctFrom, vec![left, right]);
916 Ok(func_call?.into())
917 }
918
919 pub(super) fn bind_cast(&mut self, expr: Expr, data_type: AstDataType) -> Result<ExprImpl> {
920 match &data_type {
921 AstDataType::Regclass => {
924 let input = self.bind_expr_inner(expr)?;
925 Ok(input.cast_to_regclass()?)
926 }
927 AstDataType::Regproc => {
928 let lhs = self.bind_expr_inner(expr)?;
929 let lhs_ty = lhs.return_type();
930 if lhs_ty == DataType::Varchar {
931 Ok(lhs)
935 } else {
936 Err(ErrorCode::BindError(format!("Can't cast {} to regproc", lhs_ty)).into())
937 }
938 }
939 AstDataType::Char(_) => self.bind_cast_inner(expr, DataType::Varchar),
953 _ => self.bind_cast_inner(expr, bind_data_type(&data_type)?),
954 }
955 }
956
957 pub fn bind_cast_inner(&mut self, expr: Expr, data_type: DataType) -> Result<ExprImpl> {
958 match (expr, data_type) {
959 (Expr::Array(Array { elem: ref expr, .. }), DataType::List(element_type)) => {
960 self.bind_array_cast(expr.clone(), element_type)
961 }
962 (Expr::Map { entries }, DataType::Map(m)) => self.bind_map_cast(entries, m),
963 (expr, data_type) => {
964 let lhs = self.bind_expr_inner(expr)?;
965 lhs.cast_explicit(data_type).map_err(Into::into)
966 }
967 }
968 }
969
970 pub fn bind_collate(&mut self, expr: Expr, collation: ObjectName) -> Result<ExprImpl> {
971 if !["C", "POSIX"].contains(&collation.real_value().as_str()) {
972 bail_not_implemented!("Collate collation other than `C` or `POSIX` is not implemented");
973 }
974
975 let bound_inner = self.bind_expr_inner(expr)?;
976 let ret_type = bound_inner.return_type();
977
978 match ret_type {
979 DataType::Varchar => {}
980 _ => {
981 return Err(ErrorCode::NotSupported(
982 format!("{} is not a collatable data type", ret_type),
983 "The only built-in collatable data types are `varchar`, please check your type"
984 .into(),
985 )
986 .into());
987 }
988 }
989
990 Ok(bound_inner)
991 }
992}
993
994pub fn bind_data_type(data_type: &AstDataType) -> Result<DataType> {
995 let new_err = || not_implemented!("unsupported data type: {:}", data_type);
996 let data_type = match data_type {
997 AstDataType::Boolean => DataType::Boolean,
998 AstDataType::SmallInt => DataType::Int16,
999 AstDataType::Int => DataType::Int32,
1000 AstDataType::BigInt => DataType::Int64,
1001 AstDataType::Real | AstDataType::Float(Some(1..=24)) => DataType::Float32,
1002 AstDataType::Double | AstDataType::Float(Some(25..=53) | None) => DataType::Float64,
1003 AstDataType::Float(Some(0 | 54..)) => unreachable!(),
1004 AstDataType::Decimal(None, None) => DataType::Decimal,
1005 AstDataType::Varchar | AstDataType::Text => DataType::Varchar,
1006 AstDataType::Date => DataType::Date,
1007 AstDataType::Time(false) => DataType::Time,
1008 AstDataType::Timestamp(false) => DataType::Timestamp,
1009 AstDataType::Timestamp(true) => DataType::Timestamptz,
1010 AstDataType::Interval => DataType::Interval,
1011 AstDataType::Array(datatype) => DataType::List(Box::new(bind_data_type(datatype)?)),
1012 AstDataType::Char(..) => {
1013 bail_not_implemented!("CHAR is not supported, please use VARCHAR instead")
1014 }
1015 AstDataType::Struct(types) => StructType::new(
1016 types
1017 .iter()
1018 .map(|f| Ok((f.name.real_value(), bind_data_type(&f.data_type)?)))
1019 .collect::<Result<Vec<_>>>()?,
1020 )
1021 .into(),
1022 AstDataType::Map(kv) => {
1023 let key = bind_data_type(&kv.0)?;
1024 let value = bind_data_type(&kv.1)?;
1025 DataType::Map(MapType::try_from_kv(key, value).map_err(ErrorCode::BindError)?)
1026 }
1027 AstDataType::Custom(qualified_type_name) => {
1028 let idents = qualified_type_name
1029 .0
1030 .iter()
1031 .map(|n| n.real_value())
1032 .collect_vec();
1033 let name = if idents.len() == 1 {
1034 idents[0].as_str() } else if idents.len() == 2 && idents[0] == PG_CATALOG_SCHEMA_NAME {
1036 idents[1].as_str() } else {
1038 return Err(new_err().into());
1039 };
1040
1041 match name {
1044 "int2" => DataType::Int16,
1045 "int4" => DataType::Int32,
1046 "int8" => DataType::Int64,
1047 "rw_int256" => DataType::Int256,
1048 "float4" => DataType::Float32,
1049 "float8" => DataType::Float64,
1050 "timestamptz" => DataType::Timestamptz,
1051 "text" => DataType::Varchar,
1052 "serial" => {
1053 return Err(ErrorCode::NotSupported(
1054 "Column type SERIAL is not supported".into(),
1055 "Please remove the SERIAL column".into(),
1056 )
1057 .into());
1058 }
1059 _ => return Err(new_err().into()),
1060 }
1061 }
1062 AstDataType::Bytea => DataType::Bytea,
1063 AstDataType::Jsonb => DataType::Jsonb,
1064 AstDataType::Regclass
1065 | AstDataType::Regproc
1066 | AstDataType::Uuid
1067 | AstDataType::Decimal(_, _)
1068 | AstDataType::Time(true) => return Err(new_err().into()),
1069 };
1070 Ok(data_type)
1071}