risingwave_frontend/binder/expr/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use risingwave_common::catalog::PG_CATALOG_SCHEMA_NAME;
17use risingwave_common::types::{DataType, MapType, StructType};
18use risingwave_common::util::iter_util::zip_eq_fast;
19use risingwave_common::{bail_no_function, bail_not_implemented, not_implemented};
20use risingwave_sqlparser::ast::{
21    Array, BinaryOperator, DataType as AstDataType, EscapeChar, Expr, Function, JsonPredicateType,
22    ObjectName, Query, TrimWhereField, UnaryOperator,
23};
24
25use crate::binder::Binder;
26use crate::binder::expr::function::is_sys_function_without_args;
27use crate::error::{ErrorCode, Result, RwError};
28use crate::expr::{Expr as _, ExprImpl, ExprType, FunctionCall, InputRef, Parameter, SubqueryKind};
29use crate::handler::create_sql_function::SQL_UDF_PATTERN;
30
31mod binary_op;
32mod column;
33mod function;
34mod order_by;
35mod subquery;
36mod value;
37
/// The arm-count threshold for case-when expressions.
/// When the number of condition arms reaches
/// this limit, we will try to optimize the case-when
/// expression into a `ConstantLookupExpression`.
/// Check `case.rs` for details.
43const CASE_WHEN_ARMS_OPTIMIZE_LIMIT: usize = 30;
44
45impl Binder {
46    /// Bind an expression with `bind_expr_inner`, attach the original expression
47    /// to the error message.
48    ///
49    /// This may only be called at the root of the expression tree or when crossing
50    /// the boundary of a subquery. Otherwise, the source chain might be too deep
51    /// and confusing to the user.
52    // TODO(error-handling): use a dedicated error type during binding to make it clear.
53    pub fn bind_expr(&mut self, expr: Expr) -> Result<ExprImpl> {
54        self.bind_expr_inner(expr.clone()).map_err(|e| {
55            RwError::from(ErrorCode::BindErrorRoot {
56                expr: expr.to_string(),
57                error: Box::new(e),
58            })
59        })
60    }
61
62    fn bind_expr_inner(&mut self, expr: Expr) -> Result<ExprImpl> {
63        match expr {
64            // literal
65            Expr::Value(v) => Ok(ExprImpl::Literal(Box::new(self.bind_value(v)?))),
66            Expr::TypedString { data_type, value } => {
67                let s: ExprImpl = self.bind_string(value)?.into();
68                s.cast_explicit(bind_data_type(&data_type)?)
69                    .map_err(Into::into)
70            }
71            Expr::Row(exprs) => self.bind_row(exprs),
72            // input ref
73            Expr::Identifier(ident) => {
74                if is_sys_function_without_args(&ident) {
75                    // Rewrite a system variable to a function call, e.g. `SELECT current_schema;`
76                    // will be rewritten to `SELECT current_schema();`.
77                    // NOTE: Here we don't 100% follow the behavior of Postgres, as it doesn't
78                    // allow `session_user()` while we do.
79                    self.bind_function(Function::no_arg(ObjectName(vec![ident])))
80                } else if let Some(ref lambda_args) = self.context.lambda_args {
81                    // We don't support capture, so if the expression is in the lambda context,
82                    // we'll not bind it for table columns.
83                    if let Some((arg_idx, arg_type)) = lambda_args.get(&ident.real_value()) {
84                        Ok(InputRef::new(*arg_idx, arg_type.clone()).into())
85                    } else {
86                        Err(
87                            ErrorCode::ItemNotFound(format!("Unknown arg: {}", ident.real_value()))
88                                .into(),
89                        )
90                    }
91                } else if let Some(ctx) = self.secure_compare_context.as_ref() {
92                    // Currently, the generated columns are not supported yet. So the ident here should only be one of the following
93                    // - `headers`
94                    // - secret name
95                    // - the name of the payload column
96                    // TODO(Kexiang): Generated columns or INCLUDE clause should be supported.
97                    if ident.real_value() == *"headers" {
98                        Ok(InputRef::new(0, DataType::Jsonb).into())
99                    } else if ctx.secret_name.is_some()
100                        && ident.real_value() == *ctx.secret_name.as_ref().unwrap()
101                    {
102                        Ok(InputRef::new(1, DataType::Varchar).into())
103                    } else if ident.real_value() == ctx.column_name {
104                        Ok(InputRef::new(2, DataType::Bytea).into())
105                    } else {
106                        Err(
107                            ErrorCode::ItemNotFound(format!("Unknown arg: {}", ident.real_value()))
108                                .into(),
109                        )
110                    }
111                } else {
112                    self.bind_column(&[ident])
113                }
114            }
115            Expr::CompoundIdentifier(idents) => self.bind_column(&idents),
116            Expr::FieldIdentifier(field_expr, idents) => {
117                self.bind_single_field_column(*field_expr, &idents)
118            }
119            // operators & functions
120            Expr::UnaryOp { op, expr } => self.bind_unary_expr(op, *expr),
121            Expr::BinaryOp { left, op, right } => self.bind_binary_op(*left, op, *right),
122            Expr::Nested(expr) => self.bind_expr_inner(*expr),
123            Expr::Array(Array { elem: exprs, .. }) => self.bind_array(exprs),
124            Expr::Index { obj, index } => self.bind_index(*obj, *index),
125            Expr::ArrayRangeIndex { obj, start, end } => {
126                self.bind_array_range_index(*obj, start, end)
127            }
128            Expr::Function(f) => self.bind_function(f),
129            Expr::Subquery(q) => self.bind_subquery_expr(*q, SubqueryKind::Scalar),
130            Expr::Exists(q) => self.bind_subquery_expr(*q, SubqueryKind::Existential),
131            Expr::InSubquery {
132                expr,
133                subquery,
134                negated,
135            } => self.bind_in_subquery(*expr, *subquery, negated),
136            // special syntax (except date/time or string)
137            Expr::Cast { expr, data_type } => self.bind_cast(*expr, data_type),
138            Expr::IsNull(expr) => self.bind_is_operator(ExprType::IsNull, *expr),
139            Expr::IsNotNull(expr) => self.bind_is_operator(ExprType::IsNotNull, *expr),
140            Expr::IsTrue(expr) => self.bind_is_operator(ExprType::IsTrue, *expr),
141            Expr::IsNotTrue(expr) => self.bind_is_operator(ExprType::IsNotTrue, *expr),
142            Expr::IsFalse(expr) => self.bind_is_operator(ExprType::IsFalse, *expr),
143            Expr::IsNotFalse(expr) => self.bind_is_operator(ExprType::IsNotFalse, *expr),
144            Expr::IsUnknown(expr) => self.bind_is_unknown(ExprType::IsNull, *expr),
145            Expr::IsNotUnknown(expr) => self.bind_is_unknown(ExprType::IsNotNull, *expr),
146            Expr::IsDistinctFrom(left, right) => self.bind_distinct_from(*left, *right),
147            Expr::IsNotDistinctFrom(left, right) => self.bind_not_distinct_from(*left, *right),
148            Expr::IsJson {
149                expr,
150                negated,
151                item_type,
152                unique_keys: false,
153            } => self.bind_is_json(*expr, negated, item_type),
154            Expr::Case {
155                operand,
156                conditions,
157                results,
158                else_result,
159            } => self.bind_case(operand, conditions, results, else_result),
160            Expr::Between {
161                expr,
162                negated,
163                low,
164                high,
165            } => self.bind_between(*expr, negated, *low, *high),
166            Expr::Like {
167                negated,
168                expr,
169                pattern,
170                escape_char,
171            } => self.bind_like(ExprType::Like, *expr, negated, *pattern, escape_char),
172            Expr::ILike {
173                negated,
174                expr,
175                pattern,
176                escape_char,
177            } => self.bind_like(ExprType::ILike, *expr, negated, *pattern, escape_char),
178            Expr::SimilarTo {
179                expr,
180                negated,
181                pattern,
182                escape_char,
183            } => self.bind_similar_to(*expr, negated, *pattern, escape_char),
184            Expr::InList {
185                expr,
186                list,
187                negated,
188            } => self.bind_in_list(*expr, list, negated),
189            // special syntax for date/time
190            Expr::Extract { field, expr } => self.bind_extract(field, *expr),
191            Expr::AtTimeZone {
192                timestamp,
193                time_zone,
194            } => self.bind_at_time_zone(*timestamp, *time_zone),
195            // special syntax for string
196            Expr::Trim {
197                expr,
198                trim_where,
199                trim_what,
200            } => self.bind_trim(*expr, trim_where, trim_what),
201            Expr::Substring {
202                expr,
203                substring_from,
204                substring_for,
205            } => self.bind_substring(*expr, substring_from, substring_for),
206            Expr::Position { substring, string } => self.bind_position(*substring, *string),
207            Expr::Overlay {
208                expr,
209                new_substring,
210                start,
211                count,
212            } => self.bind_overlay(*expr, *new_substring, *start, count),
213            Expr::Parameter { index } => self.bind_parameter(index),
214            Expr::Collate { expr, collation } => self.bind_collate(*expr, collation),
215            Expr::ArraySubquery(q) => self.bind_subquery_expr(*q, SubqueryKind::Array),
216            Expr::Map { entries } => self.bind_map(entries),
217            Expr::IsJson {
218                unique_keys: true, ..
219            }
220            | Expr::SomeOp(_)
221            | Expr::AllOp(_)
222            | Expr::TryCast { .. }
223            | Expr::GroupingSets(_)
224            | Expr::Cube(_)
225            | Expr::Rollup(_)
226            | Expr::LambdaFunction { .. } => {
227                bail_not_implemented!(issue = 112, "unsupported expression {:?}", expr)
228            }
229        }
230    }
231
232    pub(super) fn bind_extract(&mut self, field: String, expr: Expr) -> Result<ExprImpl> {
233        let arg = self.bind_expr_inner(expr)?;
234        let arg_type = arg.return_type();
235        Ok(FunctionCall::new(
236            ExprType::Extract,
237            vec![self.bind_string(field.clone())?.into(), arg],
238        )
239        .map_err(|_| {
240            not_implemented!(
241                issue = 112,
242                "function extract({} from {:?}) doesn't exist",
243                field,
244                arg_type
245            )
246        })?
247        .into())
248    }
249
250    pub(super) fn bind_at_time_zone(&mut self, input: Expr, time_zone: Expr) -> Result<ExprImpl> {
251        let input = self.bind_expr_inner(input)?;
252        let time_zone = self.bind_expr_inner(time_zone)?;
253        FunctionCall::new(ExprType::AtTimeZone, vec![input, time_zone]).map(Into::into)
254    }
255
256    pub(super) fn bind_in_list(
257        &mut self,
258        expr: Expr,
259        list: Vec<Expr>,
260        negated: bool,
261    ) -> Result<ExprImpl> {
262        let left = self.bind_expr_inner(expr)?;
263        let mut bound_expr_list = vec![left.clone()];
264        let mut non_const_exprs = vec![];
265        for elem in list {
266            let expr = self.bind_expr_inner(elem)?;
267            match expr.is_const() {
268                true => bound_expr_list.push(expr),
269                false => non_const_exprs.push(expr),
270            }
271        }
272        let mut ret = FunctionCall::new(ExprType::In, bound_expr_list)?.into();
273        // Non-const exprs are not part of IN-expr in backend and rewritten into OR-Equal-exprs.
274        for expr in non_const_exprs {
275            ret = FunctionCall::new(
276                ExprType::Or,
277                vec![
278                    ret,
279                    FunctionCall::new(ExprType::Equal, vec![left.clone(), expr])?.into(),
280                ],
281            )?
282            .into();
283        }
284        if negated {
285            Ok(FunctionCall::new_unchecked(ExprType::Not, vec![ret], DataType::Boolean).into())
286        } else {
287            Ok(ret)
288        }
289    }
290
291    pub(super) fn bind_in_subquery(
292        &mut self,
293        expr: Expr,
294        subquery: Query,
295        negated: bool,
296    ) -> Result<ExprImpl> {
297        let bound_expr = self.bind_expr_inner(expr)?;
298        let bound_subquery = self.bind_subquery_expr(subquery, SubqueryKind::In(bound_expr))?;
299        if negated {
300            Ok(
301                FunctionCall::new_unchecked(ExprType::Not, vec![bound_subquery], DataType::Boolean)
302                    .into(),
303            )
304        } else {
305            Ok(bound_subquery)
306        }
307    }
308
309    pub(super) fn bind_is_json(
310        &mut self,
311        expr: Expr,
312        negated: bool,
313        item_type: JsonPredicateType,
314    ) -> Result<ExprImpl> {
315        let mut args = vec![self.bind_expr_inner(expr)?];
316        // Avoid `JsonPredicateType::to_string` so that we decouple sqlparser from expr execution
317        let type_symbol = match item_type {
318            JsonPredicateType::Value => None,
319            JsonPredicateType::Array => Some("ARRAY"),
320            JsonPredicateType::Object => Some("OBJECT"),
321            JsonPredicateType::Scalar => Some("SCALAR"),
322        };
323        if let Some(s) = type_symbol {
324            args.push(ExprImpl::literal_varchar(s.into()));
325        }
326
327        let is_json = FunctionCall::new(ExprType::IsJson, args)?.into();
328        if negated {
329            Ok(FunctionCall::new(ExprType::Not, vec![is_json])?.into())
330        } else {
331            Ok(is_json)
332        }
333    }
334
335    pub(super) fn bind_unary_expr(&mut self, op: UnaryOperator, expr: Expr) -> Result<ExprImpl> {
336        let func_type = match &op {
337            UnaryOperator::Not => ExprType::Not,
338            UnaryOperator::Minus => ExprType::Neg,
339            UnaryOperator::Plus => {
340                return self.rewrite_positive(expr);
341            }
342            UnaryOperator::Custom(name) => match name.as_str() {
343                "~" => ExprType::BitwiseNot,
344                "@" => ExprType::Abs,
345                "|/" => ExprType::Sqrt,
346                "||/" => ExprType::Cbrt,
347                _ => bail_not_implemented!(issue = 112, "unsupported unary expression: {:?}", op),
348            },
349            UnaryOperator::PGQualified(_) => {
350                bail_not_implemented!(issue = 112, "unsupported unary expression: {:?}", op)
351            }
352        };
353        let expr = self.bind_expr_inner(expr)?;
354        FunctionCall::new(func_type, vec![expr]).map(|f| f.into())
355    }
356
357    /// Directly returns the expression itself if it is a positive number.
358    fn rewrite_positive(&mut self, expr: Expr) -> Result<ExprImpl> {
359        let expr = self.bind_expr_inner(expr)?;
360        let return_type = expr.return_type();
361        if return_type.is_numeric() {
362            return Ok(expr);
363        }
364        Err(ErrorCode::InvalidInputSyntax(format!("+ {:?}", return_type)).into())
365    }
366
367    pub(super) fn bind_trim(
368        &mut self,
369        expr: Expr,
370        // BOTH | LEADING | TRAILING
371        trim_where: Option<TrimWhereField>,
372        trim_what: Option<Box<Expr>>,
373    ) -> Result<ExprImpl> {
374        let mut inputs = vec![self.bind_expr_inner(expr)?];
375        let func_type = match trim_where {
376            Some(TrimWhereField::Both) => ExprType::Trim,
377            Some(TrimWhereField::Leading) => ExprType::Ltrim,
378            Some(TrimWhereField::Trailing) => ExprType::Rtrim,
379            None => ExprType::Trim,
380        };
381        if let Some(t) = trim_what {
382            inputs.push(self.bind_expr_inner(*t)?);
383        }
384        Ok(FunctionCall::new(func_type, inputs)?.into())
385    }
386
387    fn bind_substring(
388        &mut self,
389        expr: Expr,
390        substring_from: Option<Box<Expr>>,
391        substring_for: Option<Box<Expr>>,
392    ) -> Result<ExprImpl> {
393        let mut args = vec![
394            self.bind_expr_inner(expr)?,
395            match substring_from {
396                Some(expr) => self.bind_expr_inner(*expr)?,
397                None => ExprImpl::literal_int(1),
398            },
399        ];
400        if let Some(expr) = substring_for {
401            args.push(self.bind_expr_inner(*expr)?);
402        }
403        FunctionCall::new(ExprType::Substr, args).map(|f| f.into())
404    }
405
406    fn bind_position(&mut self, substring: Expr, string: Expr) -> Result<ExprImpl> {
407        let args = vec![
408            // Note that we reverse the order of arguments.
409            self.bind_expr_inner(string)?,
410            self.bind_expr_inner(substring)?,
411        ];
412        FunctionCall::new(ExprType::Position, args).map(Into::into)
413    }
414
415    fn bind_overlay(
416        &mut self,
417        expr: Expr,
418        new_substring: Expr,
419        start: Expr,
420        count: Option<Box<Expr>>,
421    ) -> Result<ExprImpl> {
422        let mut args = vec![
423            self.bind_expr_inner(expr)?,
424            self.bind_expr_inner(new_substring)?,
425            self.bind_expr_inner(start)?,
426        ];
427        if let Some(count) = count {
428            args.push(self.bind_expr_inner(*count)?);
429        }
430        FunctionCall::new(ExprType::Overlay, args).map(|f| f.into())
431    }
432
433    fn bind_parameter(&mut self, index: u64) -> Result<ExprImpl> {
434        // Special check for sql udf
435        // Note: This is specific to sql udf with unnamed parameters, since the
436        // parameters will be parsed and treated as `Parameter`.
437        // For detailed explanation, consider checking `bind_column`.
438        if self.udf_context.global_count() != 0 {
439            if let Some(expr) = self.udf_context.get_expr(&format!("${index}")) {
440                return Ok(expr.clone());
441            }
442            // Same as `bind_column`, the error message here
443            // help with hint display when invalid definition occurs
444            return Err(ErrorCode::BindError(format!(
445                "{SQL_UDF_PATTERN} failed to find unnamed parameter ${index}"
446            ))
447            .into());
448        }
449
450        Ok(Parameter::new(index, self.param_types.clone()).into())
451    }
452
453    /// Bind `expr (not) between low and high`
454    pub(super) fn bind_between(
455        &mut self,
456        expr: Expr,
457        negated: bool,
458        low: Expr,
459        high: Expr,
460    ) -> Result<ExprImpl> {
461        let expr = self.bind_expr_inner(expr)?;
462        let low = self.bind_expr_inner(low)?;
463        let high = self.bind_expr_inner(high)?;
464
465        let func_call = if negated {
466            // negated = true: expr < low or expr > high
467            FunctionCall::new_unchecked(
468                ExprType::Or,
469                vec![
470                    FunctionCall::new(ExprType::LessThan, vec![expr.clone(), low])?.into(),
471                    FunctionCall::new(ExprType::GreaterThan, vec![expr, high])?.into(),
472                ],
473                DataType::Boolean,
474            )
475        } else {
476            // negated = false: expr >= low and expr <= high
477            FunctionCall::new_unchecked(
478                ExprType::And,
479                vec![
480                    FunctionCall::new(ExprType::GreaterThanOrEqual, vec![expr.clone(), low])?
481                        .into(),
482                    FunctionCall::new(ExprType::LessThanOrEqual, vec![expr, high])?.into(),
483                ],
484                DataType::Boolean,
485            )
486        };
487
488        Ok(func_call.into())
489    }
490
491    fn bind_like(
492        &mut self,
493        expr_type: ExprType,
494        expr: Expr,
495        negated: bool,
496        pattern: Expr,
497        escape_char: Option<EscapeChar>,
498    ) -> Result<ExprImpl> {
499        if matches!(pattern, Expr::AllOp(_) | Expr::SomeOp(_)) {
500            if escape_char.is_some() {
501                // PostgreSQL also don't support the pattern due to the complexity of implementation.
502                // The SQL will failed on PostgreSQL 16.1:
503                // ```sql
504                // select 'a' like any(array[null]) escape '';
505                // ```
506                bail_not_implemented!(
507                    "LIKE with both ALL|ANY pattern and escape character is not supported"
508                )
509            }
510            // Use the `bind_binary_op` path to handle the ALL|ANY pattern.
511            let op = match (expr_type, negated) {
512                (ExprType::Like, false) => BinaryOperator::Custom("~~".to_owned()),
513                (ExprType::Like, true) => BinaryOperator::Custom("!~~".to_owned()),
514                (ExprType::ILike, false) => BinaryOperator::Custom("~~*".to_owned()),
515                (ExprType::ILike, true) => BinaryOperator::Custom("!~~*".to_owned()),
516                _ => unreachable!(),
517            };
518            return self.bind_binary_op(expr, op, pattern);
519        }
520        let expr = self.bind_expr_inner(expr)?;
521        let pattern = self.bind_expr_inner(pattern)?;
522        match (expr.return_type(), pattern.return_type()) {
523            (DataType::Varchar, DataType::Varchar) => {}
524            (string_ty, pattern_ty) => match expr_type {
525                ExprType::Like => bail_no_function!("like({}, {})", string_ty, pattern_ty),
526                ExprType::ILike => bail_no_function!("ilike({}, {})", string_ty, pattern_ty),
527                _ => unreachable!(),
528            },
529        }
530        let args = match escape_char {
531            Some(escape_char) => {
532                let escape_char = ExprImpl::literal_varchar(escape_char.to_string());
533                vec![expr, pattern, escape_char]
534            }
535            None => vec![expr, pattern],
536        };
537        let func_call = FunctionCall::new_unchecked(expr_type, args, DataType::Boolean);
538        let func_call = if negated {
539            FunctionCall::new_unchecked(ExprType::Not, vec![func_call.into()], DataType::Boolean)
540        } else {
541            func_call
542        };
543        Ok(func_call.into())
544    }
545
546    /// Bind `<expr> [ NOT ] SIMILAR TO <pat> ESCAPE <esc_text>`
547    pub(super) fn bind_similar_to(
548        &mut self,
549        expr: Expr,
550        negated: bool,
551        pattern: Expr,
552        escape_char: Option<EscapeChar>,
553    ) -> Result<ExprImpl> {
554        let expr = self.bind_expr_inner(expr)?;
555        let pattern = self.bind_expr_inner(pattern)?;
556
557        let esc_inputs = if let Some(escape_char) = escape_char {
558            let escape_char = ExprImpl::literal_varchar(escape_char.to_string());
559            vec![pattern, escape_char]
560        } else {
561            vec![pattern]
562        };
563
564        let esc_call =
565            FunctionCall::new_unchecked(ExprType::SimilarToEscape, esc_inputs, DataType::Varchar);
566
567        let regex_call = FunctionCall::new_unchecked(
568            ExprType::RegexpEq,
569            vec![expr, esc_call.into()],
570            DataType::Boolean,
571        );
572        let func_call = if negated {
573            FunctionCall::new_unchecked(ExprType::Not, vec![regex_call.into()], DataType::Boolean)
574        } else {
575            regex_call
576        };
577
578        Ok(func_call.into())
579    }
580
581    /// The optimization check for the following case-when expression pattern
582    /// e.g., select case 1 when (...) then (...) else (...) end;
583    fn check_constant_case_when_optimization(
584        &mut self,
585        conditions: Vec<Expr>,
586        results_expr: Vec<ExprImpl>,
587        operand: Option<Box<Expr>>,
588        fallback: Option<ExprImpl>,
589        constant_case_when_eval_inputs: &mut Vec<ExprImpl>,
590    ) -> bool {
591        // The operand value to be compared later
592        let operand_value;
593
594        if let Some(operand) = operand {
595            let Ok(operand) = self.bind_expr_inner(*operand) else {
596                return false;
597            };
598            if !operand.is_const() {
599                return false;
600            }
601            operand_value = operand;
602        } else {
603            return false;
604        }
605
606        for (condition, result) in zip_eq_fast(conditions, results_expr) {
607            if let Expr::Value(_) = condition.clone() {
608                let Ok(res) = self.bind_expr_inner(condition.clone()) else {
609                    return false;
610                };
611                // Found a match
612                if res == operand_value {
613                    constant_case_when_eval_inputs.push(result);
614                    return true;
615                }
616            } else {
617                return false;
618            }
619        }
620
621        // Otherwise this will eventually go through fallback arm
622        debug_assert!(
623            constant_case_when_eval_inputs.is_empty(),
624            "expect `inputs` to be empty"
625        );
626
627        let Some(fallback) = fallback else {
628            return false;
629        };
630
631        constant_case_when_eval_inputs.push(fallback);
632        true
633    }
634
635    /// Helper function to compare or set column identifier
636    /// used in `check_convert_simple_form`
637    fn compare_or_set(col_expr: &mut Option<Expr>, test_expr: Expr) -> bool {
638        let Expr::Identifier(test_ident) = test_expr else {
639            return false;
640        };
641        if let Some(expr) = col_expr {
642            let Expr::Identifier(ident) = expr else {
643                return false;
644            };
645            if ident.real_value() != test_ident.real_value() {
646                return false;
647            }
648        } else {
649            *col_expr = Some(Expr::Identifier(test_ident));
650        }
651        true
652    }
653
654    /// left expression and right expression must be either:
655    /// `<constant> <Eq> <identifier>` or `<identifier> <Eq> <constant>`
656    /// used in `check_convert_simple_form`
657    fn check_invariant(left: Expr, op: BinaryOperator, right: Expr) -> bool {
658        if op != BinaryOperator::Eq {
659            return false;
660        }
661        if let Expr::Identifier(_) = left {
662            // <identifier> <Eq> <constant>
663            let Expr::Value(_) = right else {
664                return false;
665            };
666        } else {
667            // <constant> <Eq> <identifier>
668            let Expr::Value(_) = left else {
669                return false;
670            };
671            let Expr::Identifier(_) = right else {
672                return false;
673            };
674        }
675        true
676    }
677
678    /// Helper function to extract expression out and insert
679    /// the corresponding bound version to `inputs`
680    /// used in `check_convert_simple_form`
681    /// Note: this function will be invoked per arm
682    fn try_extract_simple_form(
683        &mut self,
684        ident_expr: Expr,
685        constant_expr: Expr,
686        column_expr: &mut Option<Expr>,
687        inputs: &mut Vec<ExprImpl>,
688    ) -> bool {
689        if !Self::compare_or_set(column_expr, ident_expr) {
690            return false;
691        }
692        let Ok(bound_expr) = self.bind_expr_inner(constant_expr) else {
693            return false;
694        };
695        inputs.push(bound_expr);
696        true
697    }
698
699    /// See if the case when expression in form
700    /// `select case when <expr_1 = constant> (...with same pattern...) else <constant> end;`
701    /// If so, this expression could also be converted to constant lookup
702    fn check_convert_simple_form(
703        &mut self,
704        conditions: Vec<Expr>,
705        results_expr: Vec<ExprImpl>,
706        fallback: Option<ExprImpl>,
707        constant_lookup_inputs: &mut Vec<ExprImpl>,
708    ) -> bool {
709        let mut column_expr = None;
710
711        for (condition, result) in zip_eq_fast(conditions, results_expr) {
712            if let Expr::BinaryOp { left, op, right } = condition {
713                if !Self::check_invariant(*(left.clone()), op.clone(), *(right.clone())) {
714                    return false;
715                }
716                if let Expr::Identifier(_) = *(left.clone()) {
717                    if !self.try_extract_simple_form(
718                        *left,
719                        *right,
720                        &mut column_expr,
721                        constant_lookup_inputs,
722                    ) {
723                        return false;
724                    }
725                } else if !self.try_extract_simple_form(
726                    *right,
727                    *left,
728                    &mut column_expr,
729                    constant_lookup_inputs,
730                ) {
731                    return false;
732                }
733                constant_lookup_inputs.push(result);
734            } else {
735                return false;
736            }
737        }
738
739        // Insert operand first
740        let Some(operand) = column_expr else {
741            return false;
742        };
743        let Ok(bound_operand) = self.bind_expr_inner(operand) else {
744            return false;
745        };
746        constant_lookup_inputs.insert(0, bound_operand);
747
748        // fallback insertion
749        if let Some(expr) = fallback {
750            constant_lookup_inputs.push(expr);
751        }
752
753        true
754    }
755
756    /// The helper function to check if the current case-when
757    /// expression in `bind_case` could be optimized
758    /// into `ConstantLookupExpression`
759    fn check_bind_case_optimization(
760        &mut self,
761        conditions: Vec<Expr>,
762        results_expr: Vec<ExprImpl>,
763        operand: Option<Box<Expr>>,
764        fallback: Option<ExprImpl>,
765        constant_lookup_inputs: &mut Vec<ExprImpl>,
766    ) -> bool {
767        if conditions.len() < CASE_WHEN_ARMS_OPTIMIZE_LIMIT {
768            return false;
769        }
770
771        if let Some(operand) = operand {
772            let Ok(operand) = self.bind_expr_inner(*operand) else {
773                return false;
774            };
775            // This optimization should be done in subsequent optimization phase
776            // if the operand is const
777            // e.g., select case 1 when 1 then 114514 else 1919810 end;
778            if operand.is_const() {
779                return false;
780            }
781            constant_lookup_inputs.push(operand);
782        } else {
783            // Try converting to simple form
784            // see the example as illustrated in `check_convert_simple_form`
785            return self.check_convert_simple_form(
786                conditions,
787                results_expr,
788                fallback,
789                constant_lookup_inputs,
790            );
791        }
792
793        for (condition, result) in zip_eq_fast(conditions, results_expr) {
794            if let Expr::Value(_) = condition.clone() {
795                let Ok(input) = self.bind_expr_inner(condition.clone()) else {
796                    return false;
797                };
798                constant_lookup_inputs.push(input);
799            } else {
800                // If at least one condition is not in the simple form / not constant,
801                // we can NOT do the subsequent optimization pass
802                return false;
803            }
804
805            constant_lookup_inputs.push(result);
806        }
807
808        // The fallback arm for case-when expression
809        if let Some(expr) = fallback {
810            constant_lookup_inputs.push(expr);
811        }
812
813        true
814    }
815
816    pub(super) fn bind_case(
817        &mut self,
818        operand: Option<Box<Expr>>,
819        conditions: Vec<Expr>,
820        results: Vec<Expr>,
821        else_result: Option<Box<Expr>>,
822    ) -> Result<ExprImpl> {
823        let mut inputs = Vec::new();
824        let results_expr: Vec<ExprImpl> = results
825            .into_iter()
826            .map(|expr| self.bind_expr_inner(expr))
827            .collect::<Result<_>>()?;
828        let else_result_expr = else_result
829            .map(|expr| self.bind_expr_inner(*expr))
830            .transpose()?;
831
832        let mut constant_lookup_inputs = Vec::new();
833        let mut constant_case_when_eval_inputs = Vec::new();
834
835        let constant_case_when_flag = self.check_constant_case_when_optimization(
836            conditions.clone(),
837            results_expr.clone(),
838            operand.clone(),
839            else_result_expr.clone(),
840            &mut constant_case_when_eval_inputs,
841        );
842
843        if constant_case_when_flag {
844            // Sanity check
845            if constant_case_when_eval_inputs.len() != 1 {
846                return Err(ErrorCode::BindError(
847                    "expect `constant_case_when_eval_inputs` only contains a single bound expression".to_owned()
848                )
849                .into());
850            }
851            // Directly return the first element of the vector
852            return Ok(constant_case_when_eval_inputs[0].take());
853        }
854
855        // See if the case-when expression can be optimized
856        let optimize_flag = self.check_bind_case_optimization(
857            conditions.clone(),
858            results_expr.clone(),
859            operand.clone(),
860            else_result_expr.clone(),
861            &mut constant_lookup_inputs,
862        );
863
864        if optimize_flag {
865            return Ok(FunctionCall::new(ExprType::ConstantLookup, constant_lookup_inputs)?.into());
866        }
867
868        for (condition, result) in zip_eq_fast(conditions, results_expr) {
869            let condition = match operand {
870                Some(ref t) => Expr::BinaryOp {
871                    left: t.clone(),
872                    op: BinaryOperator::Eq,
873                    right: Box::new(condition),
874                },
875                None => condition,
876            };
877            inputs.push(
878                self.bind_expr_inner(condition)
879                    .and_then(|expr| expr.enforce_bool_clause("CASE WHEN"))?,
880            );
881            inputs.push(result);
882        }
883
884        // The fallback arm for case-when expression
885        if let Some(expr) = else_result_expr {
886            inputs.push(expr);
887        }
888
889        if inputs.iter().any(ExprImpl::has_table_function) {
890            return Err(
891                ErrorCode::BindError("table functions are not allowed in CASE".into()).into(),
892            );
893        }
894
895        Ok(FunctionCall::new(ExprType::Case, inputs)?.into())
896    }
897
898    pub(super) fn bind_is_operator(&mut self, func_type: ExprType, expr: Expr) -> Result<ExprImpl> {
899        let expr = self.bind_expr_inner(expr)?;
900        Ok(FunctionCall::new(func_type, vec![expr])?.into())
901    }
902
903    pub(super) fn bind_is_unknown(&mut self, func_type: ExprType, expr: Expr) -> Result<ExprImpl> {
904        let expr = self
905            .bind_expr_inner(expr)?
906            .cast_implicit(DataType::Boolean)?;
907        Ok(FunctionCall::new(func_type, vec![expr])?.into())
908    }
909
910    pub(super) fn bind_distinct_from(&mut self, left: Expr, right: Expr) -> Result<ExprImpl> {
911        let left = self.bind_expr_inner(left)?;
912        let right = self.bind_expr_inner(right)?;
913        let func_call = FunctionCall::new(ExprType::IsDistinctFrom, vec![left, right]);
914        Ok(func_call?.into())
915    }
916
917    pub(super) fn bind_not_distinct_from(&mut self, left: Expr, right: Expr) -> Result<ExprImpl> {
918        let left = self.bind_expr_inner(left)?;
919        let right = self.bind_expr_inner(right)?;
920        let func_call = FunctionCall::new(ExprType::IsNotDistinctFrom, vec![left, right]);
921        Ok(func_call?.into())
922    }
923
924    pub(super) fn bind_cast(&mut self, expr: Expr, data_type: AstDataType) -> Result<ExprImpl> {
925        match &data_type {
926            // Casting to Regclass type means getting the oid of expr.
927            // See https://www.postgresql.org/docs/current/datatype-oid.html.
928            AstDataType::Regclass => {
929                let input = self.bind_expr_inner(expr)?;
930                Ok(input.cast_to_regclass()?)
931            }
932            AstDataType::Regproc => {
933                let lhs = self.bind_expr_inner(expr)?;
934                let lhs_ty = lhs.return_type();
935                if lhs_ty == DataType::Varchar {
936                    // FIXME: Currently, we only allow VARCHAR to be casted to Regproc.
937                    // FIXME: Check whether it's a valid proc
938                    // FIXME: The return type should be casted to Regproc, but we don't have this type.
939                    Ok(lhs)
940                } else {
941                    Err(ErrorCode::BindError(format!("Can't cast {} to regproc", lhs_ty)).into())
942                }
943            }
944            // Redirect cast char to varchar to make system like Metabase happy.
945            // Char is not supported in RisingWave, but some ecosystem tools like Metabase will use it.
946            // Notice that the behavior of `char` and `varchar` is different in PostgreSQL.
947            // The following sql result should be different in PostgreSQL:
948            // ```
949            // select 'a'::char(2) = 'a '::char(2);
950            // ----------
951            // t
952            //
953            // select 'a'::varchar = 'a '::varchar;
954            // ----------
955            // f
956            // ```
957            AstDataType::Char(_) => self.bind_cast_inner(expr, DataType::Varchar),
958            _ => self.bind_cast_inner(expr, bind_data_type(&data_type)?),
959        }
960    }
961
962    pub fn bind_cast_inner(&mut self, expr: Expr, data_type: DataType) -> Result<ExprImpl> {
963        match (expr, data_type) {
964            (Expr::Array(Array { elem: ref expr, .. }), DataType::List(element_type)) => {
965                self.bind_array_cast(expr.clone(), element_type)
966            }
967            (Expr::Map { entries }, DataType::Map(m)) => self.bind_map_cast(entries, m),
968            (expr, data_type) => {
969                let lhs = self.bind_expr_inner(expr)?;
970                lhs.cast_explicit(data_type).map_err(Into::into)
971            }
972        }
973    }
974
975    pub fn bind_collate(&mut self, expr: Expr, collation: ObjectName) -> Result<ExprImpl> {
976        if !["C", "POSIX"].contains(&collation.real_value().as_str()) {
977            bail_not_implemented!("Collate collation other than `C` or `POSIX` is not implemented");
978        }
979
980        let bound_inner = self.bind_expr_inner(expr)?;
981        let ret_type = bound_inner.return_type();
982
983        match ret_type {
984            DataType::Varchar => {}
985            _ => {
986                return Err(ErrorCode::NotSupported(
987                    format!("{} is not a collatable data type", ret_type),
988                    "The only built-in collatable data types are `varchar`, please check your type"
989                        .into(),
990                )
991                .into());
992            }
993        }
994
995        Ok(bound_inner)
996    }
997}
998
999pub fn bind_data_type(data_type: &AstDataType) -> Result<DataType> {
1000    let new_err = || not_implemented!("unsupported data type: {:}", data_type);
1001    let data_type = match data_type {
1002        AstDataType::Boolean => DataType::Boolean,
1003        AstDataType::SmallInt => DataType::Int16,
1004        AstDataType::Int => DataType::Int32,
1005        AstDataType::BigInt => DataType::Int64,
1006        AstDataType::Real | AstDataType::Float(Some(1..=24)) => DataType::Float32,
1007        AstDataType::Double | AstDataType::Float(Some(25..=53) | None) => DataType::Float64,
1008        AstDataType::Float(Some(0 | 54..)) => unreachable!(),
1009        AstDataType::Decimal(None, None) => DataType::Decimal,
1010        AstDataType::Varchar | AstDataType::Text => DataType::Varchar,
1011        AstDataType::Date => DataType::Date,
1012        AstDataType::Time(false) => DataType::Time,
1013        AstDataType::Timestamp(false) => DataType::Timestamp,
1014        AstDataType::Timestamp(true) => DataType::Timestamptz,
1015        AstDataType::Interval => DataType::Interval,
1016        AstDataType::Array(datatype) => DataType::List(Box::new(bind_data_type(datatype)?)),
1017        AstDataType::Char(..) => {
1018            bail_not_implemented!("CHAR is not supported, please use VARCHAR instead")
1019        }
1020        AstDataType::Struct(types) => StructType::new(
1021            types
1022                .iter()
1023                .map(|f| Ok((f.name.real_value(), bind_data_type(&f.data_type)?)))
1024                .collect::<Result<Vec<_>>>()?,
1025        )
1026        .into(),
1027        AstDataType::Map(kv) => {
1028            let key = bind_data_type(&kv.0)?;
1029            let value = bind_data_type(&kv.1)?;
1030            DataType::Map(MapType::try_from_kv(key, value).map_err(ErrorCode::BindError)?)
1031        }
1032        AstDataType::Custom(qualified_type_name) => {
1033            let idents = qualified_type_name
1034                .0
1035                .iter()
1036                .map(|n| n.real_value())
1037                .collect_vec();
1038            let name = if idents.len() == 1 {
1039                idents[0].as_str() // `int2`
1040            } else if idents.len() == 2 && idents[0] == PG_CATALOG_SCHEMA_NAME {
1041                idents[1].as_str() // `pg_catalog.text`
1042            } else {
1043                return Err(new_err().into());
1044            };
1045
1046            // In PostgreSQL, these are non-keywords or non-reserved keywords but pre-defined
1047            // names that could be extended by `CREATE TYPE`.
1048            match name {
1049                "int2" => DataType::Int16,
1050                "int4" => DataType::Int32,
1051                "int8" => DataType::Int64,
1052                "rw_int256" => DataType::Int256,
1053                "float4" => DataType::Float32,
1054                "float8" => DataType::Float64,
1055                "timestamptz" => DataType::Timestamptz,
1056                "text" => DataType::Varchar,
1057                "serial" => {
1058                    return Err(ErrorCode::NotSupported(
1059                        "Column type SERIAL is not supported".into(),
1060                        "Please remove the SERIAL column".into(),
1061                    )
1062                    .into());
1063                }
1064                _ => return Err(new_err().into()),
1065            }
1066        }
1067        AstDataType::Bytea => DataType::Bytea,
1068        AstDataType::Jsonb => DataType::Jsonb,
1069        AstDataType::Vector(size) => DataType::Vector(*size as _),
1070        AstDataType::Regclass
1071        | AstDataType::Regproc
1072        | AstDataType::Uuid
1073        | AstDataType::Decimal(_, _)
1074        | AstDataType::Time(true) => return Err(new_err().into()),
1075    };
1076    Ok(data_type)
1077}