risingwave_frontend/binder/expr/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::slice;
16
17use itertools::Itertools;
18use risingwave_common::catalog::PG_CATALOG_SCHEMA_NAME;
19use risingwave_common::types::{DataType, MapType, StructType};
20use risingwave_common::util::iter_util::zip_eq_fast;
21use risingwave_common::{bail_no_function, bail_not_implemented, not_implemented};
22use risingwave_sqlparser::ast::{
23    Array, BinaryOperator, DataType as AstDataType, EscapeChar, Expr, Function, JsonPredicateType,
24    ObjectName, Query, TrimWhereField, UnaryOperator,
25};
26
27use crate::binder::Binder;
28use crate::binder::expr::function::is_sys_function_without_args;
29use crate::error::{ErrorCode, Result, RwError};
30use crate::expr::{
31    Expr as _, ExprImpl, ExprRewriter as _, ExprType, FunctionCall, InputRef,
32    InputRefDepthRewriter, Parameter, SubqueryKind,
33};
34
35mod binary_op;
36mod column;
37mod function;
38mod order_by;
39mod subquery;
40mod value;
41
/// Minimum number of condition arms a case-when expression must have before
/// the binder tries to optimize it into a `ConstantLookupExpression`.
/// A case-when expression with fewer arms than this limit is never
/// considered for that optimization.
/// Check `case.rs` for details.
const CASE_WHEN_ARMS_OPTIMIZE_LIMIT: usize = 30;
48
49impl Binder {
50    /// Bind an expression with `bind_expr_inner`, attach the original expression
51    /// to the error message.
52    ///
53    /// This may only be called at the root of the expression tree or when crossing
54    /// the boundary of a subquery. Otherwise, the source chain might be too deep
55    /// and confusing to the user.
56    // TODO(error-handling): use a dedicated error type during binding to make it clear.
57    pub fn bind_expr(&mut self, expr: &Expr) -> Result<ExprImpl> {
58        self.bind_expr_inner(expr).map_err(|e| {
59            RwError::from(ErrorCode::BindErrorRoot {
60                expr: expr.to_string(),
61                error: Box::new(e),
62            })
63        })
64    }
65
66    fn bind_expr_inner(&mut self, expr: &Expr) -> Result<ExprImpl> {
67        match expr {
68            // literal
69            Expr::Value(v) => Ok(ExprImpl::Literal(Box::new(self.bind_value(v)?))),
70            Expr::TypedString { data_type, value } => {
71                let s: ExprImpl = self.bind_string(value)?.into();
72                s.cast_explicit(&bind_data_type(data_type)?)
73                    .map_err(Into::into)
74            }
75            Expr::Row(exprs) => self.bind_row(exprs),
76            // input ref
77            Expr::Identifier(ident) => {
78                if is_sys_function_without_args(ident) {
79                    // Rewrite a system variable to a function call, e.g. `SELECT current_schema;`
80                    // will be rewritten to `SELECT current_schema();`.
81                    // NOTE: Here we don't 100% follow the behavior of Postgres, as it doesn't
82                    // allow `session_user()` while we do.
83                    self.bind_function(&Function::no_arg(ObjectName(vec![ident.clone()])))
84                } else if let Some(ref lambda_args) = self.context.lambda_args {
85                    // We don't support capture, so if the expression is in the lambda context,
86                    // we'll not bind it for table columns.
87                    if let Some((arg_idx, arg_type)) = lambda_args.get(&ident.real_value()) {
88                        Ok(InputRef::new(*arg_idx, arg_type.clone()).into())
89                    } else {
90                        Err(
91                            ErrorCode::ItemNotFound(format!("Unknown arg: {}", ident.real_value()))
92                                .into(),
93                        )
94                    }
95                } else if let Some(ctx) = self.secure_compare_context.as_ref() {
96                    // Currently, the generated columns are not supported yet. So the ident here should only be one of the following
97                    // - `headers`
98                    // - secret name
99                    // - the name of the payload column
100                    // TODO(Kexiang): Generated columns or INCLUDE clause should be supported.
101                    if ident.real_value() == *"headers" {
102                        Ok(InputRef::new(0, DataType::Jsonb).into())
103                    } else if ctx.secret_name.is_some()
104                        && ident.real_value() == *ctx.secret_name.as_ref().unwrap()
105                    {
106                        Ok(InputRef::new(1, DataType::Varchar).into())
107                    } else if ident.real_value() == ctx.column_name {
108                        Ok(InputRef::new(2, DataType::Bytea).into())
109                    } else {
110                        Err(
111                            ErrorCode::ItemNotFound(format!("Unknown arg: {}", ident.real_value()))
112                                .into(),
113                        )
114                    }
115                } else {
116                    self.bind_column(slice::from_ref(ident))
117                }
118            }
119            Expr::CompoundIdentifier(idents) => self.bind_column(idents),
120            Expr::FieldIdentifier(field_expr, idents) => {
121                self.bind_single_field_column(field_expr, idents)
122            }
123            // operators & functions
124            Expr::UnaryOp { op, expr } => self.bind_unary_expr(op, expr),
125            Expr::BinaryOp { left, op, right } => self.bind_binary_op(left, op, right),
126            Expr::Nested(expr) => self.bind_expr_inner(expr),
127            Expr::Array(Array { elem: exprs, .. }) => self.bind_array(exprs),
128            Expr::Index { obj, index } => self.bind_index(obj, index),
129            Expr::ArrayRangeIndex { obj, start, end } => {
130                self.bind_array_range_index(obj, start.as_deref(), end.as_deref())
131            }
132            Expr::Function(f) => self.bind_function(f),
133            Expr::Subquery(q) => self.bind_subquery_expr(q, SubqueryKind::Scalar),
134            Expr::Exists(q) => self.bind_subquery_expr(q, SubqueryKind::Existential),
135            Expr::InSubquery {
136                expr,
137                subquery,
138                negated,
139            } => self.bind_in_subquery(expr, subquery, *negated),
140            // special syntax (except date/time or string)
141            Expr::Cast { expr, data_type } => self.bind_cast(expr, data_type),
142            Expr::IsNull(expr) => self.bind_is_operator(ExprType::IsNull, expr),
143            Expr::IsNotNull(expr) => self.bind_is_operator(ExprType::IsNotNull, expr),
144            Expr::IsTrue(expr) => self.bind_is_operator(ExprType::IsTrue, expr),
145            Expr::IsNotTrue(expr) => self.bind_is_operator(ExprType::IsNotTrue, expr),
146            Expr::IsFalse(expr) => self.bind_is_operator(ExprType::IsFalse, expr),
147            Expr::IsNotFalse(expr) => self.bind_is_operator(ExprType::IsNotFalse, expr),
148            Expr::IsUnknown(expr) => self.bind_is_unknown(ExprType::IsNull, expr),
149            Expr::IsNotUnknown(expr) => self.bind_is_unknown(ExprType::IsNotNull, expr),
150            Expr::IsDistinctFrom(left, right) => self.bind_distinct_from(left, right),
151            Expr::IsNotDistinctFrom(left, right) => self.bind_not_distinct_from(left, right),
152            Expr::IsJson {
153                expr,
154                negated,
155                item_type,
156                unique_keys: false,
157            } => self.bind_is_json(expr, *negated, *item_type),
158            Expr::Case {
159                operand,
160                conditions,
161                results,
162                else_result,
163            } => self.bind_case(
164                operand.as_deref(),
165                conditions,
166                results,
167                else_result.as_deref(),
168            ),
169            Expr::Between {
170                expr,
171                negated,
172                low,
173                high,
174            } => self.bind_between(expr, *negated, low, high),
175            Expr::Like {
176                negated,
177                expr,
178                pattern,
179                escape_char,
180            } => self.bind_like(ExprType::Like, expr, *negated, pattern, *escape_char),
181            Expr::ILike {
182                negated,
183                expr,
184                pattern,
185                escape_char,
186            } => self.bind_like(ExprType::ILike, expr, *negated, pattern, *escape_char),
187            Expr::SimilarTo {
188                expr,
189                negated,
190                pattern,
191                escape_char,
192            } => self.bind_similar_to(expr, *negated, pattern, *escape_char),
193            Expr::InList {
194                expr,
195                list,
196                negated,
197            } => self.bind_in_list(expr, list, *negated),
198            // special syntax for date/time
199            Expr::Extract { field, expr } => self.bind_extract(field, expr),
200            Expr::AtTimeZone {
201                timestamp,
202                time_zone,
203            } => self.bind_at_time_zone(timestamp, time_zone),
204            // special syntax for string
205            Expr::Trim {
206                expr,
207                trim_where,
208                trim_what,
209            } => self.bind_trim(expr, trim_where.as_ref(), trim_what.as_deref()),
210            Expr::Substring {
211                expr,
212                substring_from,
213                substring_for,
214            } => self.bind_substring(expr, substring_from.as_deref(), substring_for.as_deref()),
215            Expr::Position { substring, string } => self.bind_position(substring, string),
216            Expr::Overlay {
217                expr,
218                new_substring,
219                start,
220                count,
221            } => self.bind_overlay(expr, new_substring, start, count.as_deref()),
222            Expr::Parameter { index } => self.bind_parameter(*index),
223            Expr::Collate { expr, collation } => self.bind_collate(expr, collation),
224            Expr::ArraySubquery(q) => self.bind_subquery_expr(q, SubqueryKind::Array),
225            Expr::Map { entries } => self.bind_map(entries),
226            Expr::IsJson {
227                unique_keys: true, ..
228            }
229            | Expr::SomeOp(_)
230            | Expr::AllOp(_)
231            | Expr::TryCast { .. }
232            | Expr::GroupingSets(_)
233            | Expr::Cube(_)
234            | Expr::Rollup(_)
235            | Expr::LambdaFunction { .. } => {
236                bail_not_implemented!(issue = 112, "unsupported expression {:?}", expr)
237            }
238        }
239    }
240
241    pub(super) fn bind_extract(&mut self, field: &String, expr: &Expr) -> Result<ExprImpl> {
242        let arg = self.bind_expr_inner(expr)?;
243        let arg_type = arg.return_type();
244        Ok(FunctionCall::new(
245            ExprType::Extract,
246            vec![self.bind_string(field)?.into(), arg],
247        )
248        .map_err(|_| {
249            not_implemented!(
250                issue = 112,
251                "function extract({} from {:?}) doesn't exist",
252                field,
253                arg_type
254            )
255        })?
256        .into())
257    }
258
259    pub(super) fn bind_at_time_zone(&mut self, input: &Expr, time_zone: &Expr) -> Result<ExprImpl> {
260        let input = self.bind_expr_inner(input)?;
261        let time_zone = self.bind_expr_inner(time_zone)?;
262        FunctionCall::new(ExprType::AtTimeZone, vec![input, time_zone]).map(Into::into)
263    }
264
265    pub(super) fn bind_in_list(
266        &mut self,
267        expr: &Expr,
268        list: &[Expr],
269        negated: bool,
270    ) -> Result<ExprImpl> {
271        let left = self.bind_expr_inner(expr)?;
272        let mut bound_expr_list = vec![left.clone()];
273        let mut non_const_exprs = vec![];
274        for elem in list {
275            let expr = self.bind_expr_inner(elem)?;
276            match expr.is_const() {
277                true => bound_expr_list.push(expr),
278                false => non_const_exprs.push(expr),
279            }
280        }
281        let mut ret = FunctionCall::new(ExprType::In, bound_expr_list)?.into();
282        // Non-const exprs are not part of IN-expr in backend and rewritten into OR-Equal-exprs.
283        for expr in non_const_exprs {
284            ret = FunctionCall::new(
285                ExprType::Or,
286                vec![
287                    ret,
288                    FunctionCall::new(ExprType::Equal, vec![left.clone(), expr])?.into(),
289                ],
290            )?
291            .into();
292        }
293        if negated {
294            Ok(FunctionCall::new_unchecked(ExprType::Not, vec![ret], DataType::Boolean).into())
295        } else {
296            Ok(ret)
297        }
298    }
299
300    pub(super) fn bind_in_subquery(
301        &mut self,
302        expr: &Expr,
303        subquery: &Query,
304        negated: bool,
305    ) -> Result<ExprImpl> {
306        let bound_expr = self.bind_expr_inner(expr)?;
307        let bound_subquery = self.bind_subquery_expr(subquery, SubqueryKind::In(bound_expr))?;
308        if negated {
309            Ok(
310                FunctionCall::new_unchecked(ExprType::Not, vec![bound_subquery], DataType::Boolean)
311                    .into(),
312            )
313        } else {
314            Ok(bound_subquery)
315        }
316    }
317
318    pub(super) fn bind_is_json(
319        &mut self,
320        expr: &Expr,
321        negated: bool,
322        item_type: JsonPredicateType,
323    ) -> Result<ExprImpl> {
324        let mut args = vec![self.bind_expr_inner(expr)?];
325        // Avoid `JsonPredicateType::to_string` so that we decouple sqlparser from expr execution
326        let type_symbol = match item_type {
327            JsonPredicateType::Value => None,
328            JsonPredicateType::Array => Some("ARRAY"),
329            JsonPredicateType::Object => Some("OBJECT"),
330            JsonPredicateType::Scalar => Some("SCALAR"),
331        };
332        if let Some(s) = type_symbol {
333            args.push(ExprImpl::literal_varchar(s.into()));
334        }
335
336        let is_json = FunctionCall::new(ExprType::IsJson, args)?.into();
337        if negated {
338            Ok(FunctionCall::new(ExprType::Not, vec![is_json])?.into())
339        } else {
340            Ok(is_json)
341        }
342    }
343
344    pub(super) fn bind_unary_expr(&mut self, op: &UnaryOperator, expr: &Expr) -> Result<ExprImpl> {
345        let func_type = match &op {
346            UnaryOperator::Not => ExprType::Not,
347            UnaryOperator::Minus => ExprType::Neg,
348            UnaryOperator::Plus => {
349                return self.rewrite_positive(expr);
350            }
351            UnaryOperator::Custom(name) => match name.as_str() {
352                "~" => ExprType::BitwiseNot,
353                "@" => ExprType::Abs,
354                "|/" => ExprType::Sqrt,
355                "||/" => ExprType::Cbrt,
356                _ => bail_not_implemented!(issue = 112, "unsupported unary expression: {:?}", op),
357            },
358            UnaryOperator::PGQualified(_) => {
359                bail_not_implemented!(issue = 112, "unsupported unary expression: {:?}", op)
360            }
361        };
362        let expr = self.bind_expr_inner(expr)?;
363        FunctionCall::new(func_type, vec![expr]).map(|f| f.into())
364    }
365
366    /// Directly returns the expression itself if it is a positive number.
367    fn rewrite_positive(&mut self, expr: &Expr) -> Result<ExprImpl> {
368        let expr = self.bind_expr_inner(expr)?;
369        let return_type = expr.return_type();
370        if return_type.is_numeric() {
371            return Ok(expr);
372        }
373        Err(ErrorCode::InvalidInputSyntax(format!("+ {:?}", return_type)).into())
374    }
375
376    pub(super) fn bind_trim(
377        &mut self,
378        expr: &Expr,
379        // BOTH | LEADING | TRAILING
380        trim_where: Option<&TrimWhereField>,
381        trim_what: Option<&Expr>,
382    ) -> Result<ExprImpl> {
383        let mut inputs = vec![self.bind_expr_inner(expr)?];
384        let func_type = match trim_where {
385            Some(TrimWhereField::Both) => ExprType::Trim,
386            Some(TrimWhereField::Leading) => ExprType::Ltrim,
387            Some(TrimWhereField::Trailing) => ExprType::Rtrim,
388            None => ExprType::Trim,
389        };
390        if let Some(t) = trim_what {
391            inputs.push(self.bind_expr_inner(t)?);
392        }
393        Ok(FunctionCall::new(func_type, inputs)?.into())
394    }
395
396    fn bind_substring(
397        &mut self,
398        expr: &Expr,
399        substring_from: Option<&Expr>,
400        substring_for: Option<&Expr>,
401    ) -> Result<ExprImpl> {
402        let mut args = vec![
403            self.bind_expr_inner(expr)?,
404            match substring_from {
405                Some(expr) => self.bind_expr_inner(expr)?,
406                None => ExprImpl::literal_int(1),
407            },
408        ];
409        if let Some(expr) = substring_for {
410            args.push(self.bind_expr_inner(expr)?);
411        }
412        FunctionCall::new(ExprType::Substr, args).map(|f| f.into())
413    }
414
415    fn bind_position(&mut self, substring: &Expr, string: &Expr) -> Result<ExprImpl> {
416        let args = vec![
417            // Note that we reverse the order of arguments.
418            self.bind_expr_inner(string)?,
419            self.bind_expr_inner(substring)?,
420        ];
421        FunctionCall::new(ExprType::Position, args).map(Into::into)
422    }
423
424    fn bind_overlay(
425        &mut self,
426        expr: &Expr,
427        new_substring: &Expr,
428        start: &Expr,
429        count: Option<&Expr>,
430    ) -> Result<ExprImpl> {
431        let mut args = vec![
432            self.bind_expr_inner(expr)?,
433            self.bind_expr_inner(new_substring)?,
434            self.bind_expr_inner(start)?,
435        ];
436        if let Some(count) = count {
437            args.push(self.bind_expr_inner(count)?);
438        }
439        FunctionCall::new(ExprType::Overlay, args).map(|f| f.into())
440    }
441
442    fn is_binding_inline_sql_udf(&self) -> bool {
443        self.context.sql_udf_arguments.is_some()
444    }
445
446    /// Returns whether we're binding SQL UDF by checking if any of the upper subquery context has
447    /// `sql_udf_arguments` set.
448    fn is_binding_subquery_sql_udf(&self) -> bool {
449        self.upper_subquery_contexts
450            .iter()
451            .any(|(context, _)| context.sql_udf_arguments.is_some())
452    }
453
454    /// Bind a parameter for SQL UDF.
455    fn bind_sql_udf_parameter(&mut self, name: &str) -> Result<ExprImpl> {
456        for (depth, context) in std::iter::once(&self.context)
457            .chain((self.upper_subquery_contexts.iter().rev()).map(|(context, _)| context))
458            .enumerate()
459        {
460            // Only lookup the first non-empty udf context. If the parameter is not found in the
461            // current context, we will continue to the upper context.
462            if let Some(args) = &context.sql_udf_arguments {
463                if let Some(expr) = args.get(name) {
464                    // The arguments recorded in the context is relative to the that context.
465                    // We need to shift the depth to the current context.
466                    let mut rewriter = InputRefDepthRewriter::new(depth);
467                    return Ok(rewriter.rewrite_expr(expr.clone()));
468                } else {
469                    // A UDF cannot access parameters from outer UDFs. Do not continue but directly
470                    // return an error.
471                    break;
472                }
473            }
474        }
475
476        Err(ErrorCode::BindError(format!(
477            "failed to find {} parameter {name}",
478            if name.starts_with('$') {
479                "unnamed"
480            } else {
481                "named"
482            }
483        ))
484        .into())
485    }
486
487    fn bind_parameter(&mut self, index: u64) -> Result<ExprImpl> {
488        // Special check for sql udf
489        // Note: This is specific to sql udf with unnamed parameters, since the
490        // parameters will be parsed and treated as `Parameter`.
491        // For detailed explanation, consider checking `bind_column`.
492        if self.is_binding_inline_sql_udf() || self.is_binding_subquery_sql_udf() {
493            let column_name = format!("${index}");
494            return self.bind_sql_udf_parameter(&column_name);
495        }
496
497        Ok(Parameter::new(index, self.param_types.clone()).into())
498    }
499
500    /// Bind `expr (not) between low and high`
501    pub(super) fn bind_between(
502        &mut self,
503        expr: &Expr,
504        negated: bool,
505        low: &Expr,
506        high: &Expr,
507    ) -> Result<ExprImpl> {
508        let expr = self.bind_expr_inner(expr)?;
509        let low = self.bind_expr_inner(low)?;
510        let high = self.bind_expr_inner(high)?;
511
512        let func_call = if negated {
513            // negated = true: expr < low or expr > high
514            FunctionCall::new_unchecked(
515                ExprType::Or,
516                vec![
517                    FunctionCall::new(ExprType::LessThan, vec![expr.clone(), low])?.into(),
518                    FunctionCall::new(ExprType::GreaterThan, vec![expr, high])?.into(),
519                ],
520                DataType::Boolean,
521            )
522        } else {
523            // negated = false: expr >= low and expr <= high
524            FunctionCall::new_unchecked(
525                ExprType::And,
526                vec![
527                    FunctionCall::new(ExprType::GreaterThanOrEqual, vec![expr.clone(), low])?
528                        .into(),
529                    FunctionCall::new(ExprType::LessThanOrEqual, vec![expr, high])?.into(),
530                ],
531                DataType::Boolean,
532            )
533        };
534
535        Ok(func_call.into())
536    }
537
538    fn bind_like(
539        &mut self,
540        expr_type: ExprType,
541        expr: &Expr,
542        negated: bool,
543        pattern: &Expr,
544        escape_char: Option<EscapeChar>,
545    ) -> Result<ExprImpl> {
546        if matches!(pattern, Expr::AllOp(_) | Expr::SomeOp(_)) {
547            if escape_char.is_some() {
548                // PostgreSQL also don't support the pattern due to the complexity of implementation.
549                // The SQL will failed on PostgreSQL 16.1:
550                // ```sql
551                // select 'a' like any(array[null]) escape '';
552                // ```
553                bail_not_implemented!(
554                    "LIKE with both ALL|ANY pattern and escape character is not supported"
555                )
556            }
557            // Use the `bind_binary_op` path to handle the ALL|ANY pattern.
558            let op = match (expr_type, negated) {
559                (ExprType::Like, false) => BinaryOperator::Custom("~~".to_owned()),
560                (ExprType::Like, true) => BinaryOperator::Custom("!~~".to_owned()),
561                (ExprType::ILike, false) => BinaryOperator::Custom("~~*".to_owned()),
562                (ExprType::ILike, true) => BinaryOperator::Custom("!~~*".to_owned()),
563                _ => unreachable!(),
564            };
565            return self.bind_binary_op(expr, &op, pattern);
566        }
567        let expr = self.bind_expr_inner(expr)?;
568        let pattern = self.bind_expr_inner(pattern)?;
569        match (expr.return_type(), pattern.return_type()) {
570            (DataType::Varchar, DataType::Varchar) => {}
571            (string_ty, pattern_ty) => match expr_type {
572                ExprType::Like => bail_no_function!("like({}, {})", string_ty, pattern_ty),
573                ExprType::ILike => bail_no_function!("ilike({}, {})", string_ty, pattern_ty),
574                _ => unreachable!(),
575            },
576        }
577        let args = match escape_char {
578            Some(escape_char) => {
579                let escape_char = ExprImpl::literal_varchar(escape_char.to_string());
580                vec![expr, pattern, escape_char]
581            }
582            None => vec![expr, pattern],
583        };
584        let func_call = FunctionCall::new_unchecked(expr_type, args, DataType::Boolean);
585        let func_call = if negated {
586            FunctionCall::new_unchecked(ExprType::Not, vec![func_call.into()], DataType::Boolean)
587        } else {
588            func_call
589        };
590        Ok(func_call.into())
591    }
592
593    /// Bind `<expr> [ NOT ] SIMILAR TO <pat> ESCAPE <esc_text>`
594    pub(super) fn bind_similar_to(
595        &mut self,
596        expr: &Expr,
597        negated: bool,
598        pattern: &Expr,
599        escape_char: Option<EscapeChar>,
600    ) -> Result<ExprImpl> {
601        let expr = self.bind_expr_inner(expr)?;
602        let pattern = self.bind_expr_inner(pattern)?;
603
604        let esc_inputs = if let Some(escape_char) = escape_char {
605            let escape_char = ExprImpl::literal_varchar(escape_char.to_string());
606            vec![pattern, escape_char]
607        } else {
608            vec![pattern]
609        };
610
611        let esc_call =
612            FunctionCall::new_unchecked(ExprType::SimilarToEscape, esc_inputs, DataType::Varchar);
613
614        let regex_call = FunctionCall::new_unchecked(
615            ExprType::RegexpEq,
616            vec![expr, esc_call.into()],
617            DataType::Boolean,
618        );
619        let func_call = if negated {
620            FunctionCall::new_unchecked(ExprType::Not, vec![regex_call.into()], DataType::Boolean)
621        } else {
622            regex_call
623        };
624
625        Ok(func_call.into())
626    }
627
628    /// The optimization check for the following case-when expression pattern
629    /// e.g., select case 1 when (...) then (...) else (...) end;
630    fn check_constant_case_when_optimization(
631        &mut self,
632        conditions: &[Expr],
633        results_expr: &[ExprImpl],
634        operand: Option<&Expr>,
635        fallback: Option<&ExprImpl>,
636        constant_case_when_eval_inputs: &mut Vec<ExprImpl>,
637    ) -> bool {
638        // The operand value to be compared later
639        let operand_value;
640
641        if let Some(operand) = operand {
642            let Ok(operand) = self.bind_expr_inner(operand) else {
643                return false;
644            };
645            if !operand.is_const() {
646                return false;
647            }
648            operand_value = operand;
649        } else {
650            return false;
651        }
652
653        for (condition, result) in zip_eq_fast(conditions, results_expr) {
654            if let Expr::Value(_) = condition.clone() {
655                let Ok(res) = self.bind_expr_inner(condition) else {
656                    return false;
657                };
658                // Found a match
659                if res == operand_value {
660                    constant_case_when_eval_inputs.push(result.clone());
661                    return true;
662                }
663            } else {
664                return false;
665            }
666        }
667
668        // Otherwise this will eventually go through fallback arm
669        debug_assert!(
670            constant_case_when_eval_inputs.is_empty(),
671            "expect `inputs` to be empty"
672        );
673
674        let Some(fallback) = fallback else {
675            return false;
676        };
677
678        constant_case_when_eval_inputs.push(fallback.clone());
679        true
680    }
681
682    /// Helper function to compare or set column identifier
683    /// used in `check_convert_simple_form`
684    fn compare_or_set(col_expr: &mut Option<Expr>, test_expr: &Expr) -> bool {
685        let Expr::Identifier(test_ident) = test_expr else {
686            return false;
687        };
688        if let Some(expr) = col_expr {
689            let Expr::Identifier(ident) = expr else {
690                return false;
691            };
692            if ident.real_value() != test_ident.real_value() {
693                return false;
694            }
695        } else {
696            *col_expr = Some(Expr::Identifier(test_ident.clone()));
697        }
698        true
699    }
700
701    /// left expression and right expression must be either:
702    /// `<constant> <Eq> <identifier>` or `<identifier> <Eq> <constant>`
703    /// used in `check_convert_simple_form`
704    fn check_invariant(left: &Expr, op: &BinaryOperator, right: &Expr) -> bool {
705        if op != &BinaryOperator::Eq {
706            return false;
707        }
708        if let Expr::Identifier(_) = left {
709            // <identifier> <Eq> <constant>
710            let Expr::Value(_) = right else {
711                return false;
712            };
713        } else {
714            // <constant> <Eq> <identifier>
715            let Expr::Value(_) = left else {
716                return false;
717            };
718            let Expr::Identifier(_) = right else {
719                return false;
720            };
721        }
722        true
723    }
724
725    /// Helper function to extract expression out and insert
726    /// the corresponding bound version to `inputs`
727    /// used in `check_convert_simple_form`
728    /// Note: this function will be invoked per arm
729    fn try_extract_simple_form(
730        &mut self,
731        ident_expr: &Expr,
732        constant_expr: &Expr,
733        column_expr: &mut Option<Expr>,
734        inputs: &mut Vec<ExprImpl>,
735    ) -> bool {
736        if !Self::compare_or_set(column_expr, ident_expr) {
737            return false;
738        }
739        let Ok(bound_expr) = self.bind_expr_inner(constant_expr) else {
740            return false;
741        };
742        inputs.push(bound_expr);
743        true
744    }
745
746    /// See if the case when expression in form
747    /// `select case when <expr_1 = constant> (...with same pattern...) else <constant> end;`
748    /// If so, this expression could also be converted to constant lookup
749    fn check_convert_simple_form(
750        &mut self,
751        conditions: &[Expr],
752        results_expr: &[ExprImpl],
753        fallback: Option<ExprImpl>,
754        constant_lookup_inputs: &mut Vec<ExprImpl>,
755    ) -> bool {
756        let mut column_expr = None;
757
758        for (condition, result) in zip_eq_fast(conditions, results_expr) {
759            if let Expr::BinaryOp { left, op, right } = condition {
760                if !Self::check_invariant(left, op, right) {
761                    return false;
762                }
763                if let Expr::Identifier(_) = &**left {
764                    if !self.try_extract_simple_form(
765                        left,
766                        right,
767                        &mut column_expr,
768                        constant_lookup_inputs,
769                    ) {
770                        return false;
771                    }
772                } else if !self.try_extract_simple_form(
773                    right,
774                    left,
775                    &mut column_expr,
776                    constant_lookup_inputs,
777                ) {
778                    return false;
779                }
780                constant_lookup_inputs.push(result.clone());
781            } else {
782                return false;
783            }
784        }
785
786        // Insert operand first
787        let Some(operand) = column_expr else {
788            return false;
789        };
790        let Ok(bound_operand) = self.bind_expr_inner(&operand) else {
791            return false;
792        };
793        constant_lookup_inputs.insert(0, bound_operand);
794
795        // fallback insertion
796        if let Some(expr) = fallback {
797            constant_lookup_inputs.push(expr);
798        }
799
800        true
801    }
802
803    /// The helper function to check if the current case-when
804    /// expression in `bind_case` could be optimized
805    /// into `ConstantLookupExpression`
806    fn check_bind_case_optimization(
807        &mut self,
808        conditions: &[Expr],
809        results_expr: &[ExprImpl],
810        operand: Option<&Expr>,
811        fallback: Option<ExprImpl>,
812        constant_lookup_inputs: &mut Vec<ExprImpl>,
813    ) -> bool {
814        if conditions.len() < CASE_WHEN_ARMS_OPTIMIZE_LIMIT {
815            return false;
816        }
817
818        if let Some(operand) = operand {
819            let Ok(operand) = self.bind_expr_inner(operand) else {
820                return false;
821            };
822            // This optimization should be done in subsequent optimization phase
823            // if the operand is const
824            // e.g., select case 1 when 1 then 114514 else 1919810 end;
825            if operand.is_const() {
826                return false;
827            }
828            constant_lookup_inputs.push(operand);
829        } else {
830            // Try converting to simple form
831            // see the example as illustrated in `check_convert_simple_form`
832            return self.check_convert_simple_form(
833                conditions,
834                results_expr,
835                fallback,
836                constant_lookup_inputs,
837            );
838        }
839
840        for (condition, result) in zip_eq_fast(conditions, results_expr) {
841            if let Expr::Value(_) = condition {
842                let Ok(input) = self.bind_expr_inner(condition) else {
843                    return false;
844                };
845                constant_lookup_inputs.push(input);
846            } else {
847                // If at least one condition is not in the simple form / not constant,
848                // we can NOT do the subsequent optimization pass
849                return false;
850            }
851
852            constant_lookup_inputs.push(result.clone());
853        }
854
855        // The fallback arm for case-when expression
856        if let Some(expr) = fallback {
857            constant_lookup_inputs.push(expr);
858        }
859
860        true
861    }
862
863    pub(super) fn bind_case(
864        &mut self,
865        operand: Option<&Expr>,
866        conditions: &[Expr],
867        results: &[Expr],
868        else_result: Option<&Expr>,
869    ) -> Result<ExprImpl> {
870        let mut inputs = Vec::new();
871        let results_expr: Vec<ExprImpl> = results
872            .iter()
873            .map(|expr| self.bind_expr_inner(expr))
874            .collect::<Result<_>>()?;
875        let else_result_expr = else_result
876            .map(|expr| self.bind_expr_inner(expr))
877            .transpose()?;
878
879        let mut constant_lookup_inputs = Vec::new();
880        let mut constant_case_when_eval_inputs = Vec::new();
881
882        let constant_case_when_flag = self.check_constant_case_when_optimization(
883            conditions,
884            &results_expr,
885            operand,
886            else_result_expr.as_ref(),
887            &mut constant_case_when_eval_inputs,
888        );
889
890        if constant_case_when_flag {
891            // Sanity check
892            if constant_case_when_eval_inputs.len() != 1 {
893                return Err(ErrorCode::BindError(
894                    "expect `constant_case_when_eval_inputs` only contains a single bound expression".to_owned()
895                )
896                    .into());
897            }
898            // Directly return the first element of the vector
899            return Ok(constant_case_when_eval_inputs[0].take());
900        }
901
902        // See if the case-when expression can be optimized
903        let optimize_flag = self.check_bind_case_optimization(
904            conditions,
905            &results_expr,
906            operand,
907            else_result_expr.clone(),
908            &mut constant_lookup_inputs,
909        );
910
911        if optimize_flag {
912            return Ok(FunctionCall::new(ExprType::ConstantLookup, constant_lookup_inputs)?.into());
913        }
914
915        for (condition, result) in zip_eq_fast(conditions, results_expr) {
916            let condition = condition.clone();
917            let condition = match operand {
918                Some(t) => Expr::BinaryOp {
919                    left: t.clone().into(),
920                    op: BinaryOperator::Eq,
921                    right: Box::new(condition),
922                },
923                None => condition,
924            };
925            inputs.push(
926                self.bind_expr_inner(&condition)
927                    .and_then(|expr| expr.enforce_bool_clause("CASE WHEN"))?,
928            );
929            inputs.push(result);
930        }
931
932        // The fallback arm for case-when expression
933        if let Some(expr) = else_result_expr {
934            inputs.push(expr);
935        }
936
937        if inputs.iter().any(ExprImpl::has_table_function) {
938            return Err(
939                ErrorCode::BindError("table functions are not allowed in CASE".into()).into(),
940            );
941        }
942
943        Ok(FunctionCall::new(ExprType::Case, inputs)?.into())
944    }
945
946    pub(super) fn bind_is_operator(
947        &mut self,
948        func_type: ExprType,
949        expr: &Expr,
950    ) -> Result<ExprImpl> {
951        let expr = self.bind_expr_inner(expr)?;
952        Ok(FunctionCall::new(func_type, vec![expr])?.into())
953    }
954
955    pub(super) fn bind_is_unknown(&mut self, func_type: ExprType, expr: &Expr) -> Result<ExprImpl> {
956        let expr = self
957            .bind_expr_inner(expr)?
958            .cast_implicit(&DataType::Boolean)?;
959        Ok(FunctionCall::new(func_type, vec![expr])?.into())
960    }
961
962    pub(super) fn bind_distinct_from(&mut self, left: &Expr, right: &Expr) -> Result<ExprImpl> {
963        let left = self.bind_expr_inner(left)?;
964        let right = self.bind_expr_inner(right)?;
965        let func_call = FunctionCall::new(ExprType::IsDistinctFrom, vec![left, right]);
966        Ok(func_call?.into())
967    }
968
969    pub(super) fn bind_not_distinct_from(&mut self, left: &Expr, right: &Expr) -> Result<ExprImpl> {
970        let left = self.bind_expr_inner(left)?;
971        let right = self.bind_expr_inner(right)?;
972        let func_call = FunctionCall::new(ExprType::IsNotDistinctFrom, vec![left, right]);
973        Ok(func_call?.into())
974    }
975
976    pub(super) fn bind_cast(&mut self, expr: &Expr, data_type: &AstDataType) -> Result<ExprImpl> {
977        match &data_type {
978            // Casting to Regclass type means getting the oid of expr.
979            // See https://www.postgresql.org/docs/current/datatype-oid.html.
980            AstDataType::Regclass => {
981                let input = self.bind_expr_inner(expr)?;
982                Ok(input.cast_to_regclass()?)
983            }
984            AstDataType::Regproc => {
985                let lhs = self.bind_expr_inner(expr)?;
986                let lhs_ty = lhs.return_type();
987                if lhs_ty == DataType::Varchar {
988                    // FIXME: Currently, we only allow VARCHAR to be casted to Regproc.
989                    // FIXME: Check whether it's a valid proc
990                    // FIXME: The return type should be casted to Regproc, but we don't have this type.
991                    Ok(lhs)
992                } else {
993                    Err(ErrorCode::BindError(format!("Can't cast {} to regproc", lhs_ty)).into())
994                }
995            }
996            // Redirect cast char to varchar to make system like Metabase happy.
997            // Char is not supported in RisingWave, but some ecosystem tools like Metabase will use it.
998            // Notice that the behavior of `char` and `varchar` is different in PostgreSQL.
999            // The following sql result should be different in PostgreSQL:
1000            // ```
1001            // select 'a'::char(2) = 'a '::char(2);
1002            // ----------
1003            // t
1004            //
1005            // select 'a'::varchar = 'a '::varchar;
1006            // ----------
1007            // f
1008            // ```
1009            AstDataType::Char(_) => self.bind_cast_inner(expr, &DataType::Varchar),
1010            _ => self.bind_cast_inner(expr, &bind_data_type(data_type)?),
1011        }
1012    }
1013
1014    pub fn bind_cast_inner(&mut self, expr: &Expr, data_type: &DataType) -> Result<ExprImpl> {
1015        match (expr, data_type) {
1016            (Expr::Array(Array { elem: expr, .. }), DataType::List(element_type)) => {
1017                self.bind_array_cast(expr, element_type)
1018            }
1019            (Expr::Map { entries }, DataType::Map(m)) => self.bind_map_cast(entries, m),
1020            (expr, data_type) => {
1021                let lhs = self.bind_expr_inner(expr)?;
1022                lhs.cast_explicit(data_type).map_err(Into::into)
1023            }
1024        }
1025    }
1026
1027    pub fn bind_collate(&mut self, expr: &Expr, collation: &ObjectName) -> Result<ExprImpl> {
1028        if !["C", "POSIX"].contains(&collation.real_value().as_str()) {
1029            bail_not_implemented!("Collate collation other than `C` or `POSIX` is not implemented");
1030        }
1031
1032        let bound_inner = self.bind_expr_inner(expr)?;
1033        let ret_type = bound_inner.return_type();
1034
1035        match ret_type {
1036            DataType::Varchar => {}
1037            _ => {
1038                return Err(ErrorCode::NotSupported(
1039                    format!("{} is not a collatable data type", ret_type),
1040                    "The only built-in collatable data types are `varchar`, please check your type"
1041                        .into(),
1042                )
1043                .into());
1044            }
1045        }
1046
1047        Ok(bound_inner)
1048    }
1049}
1050
1051pub fn bind_data_type(data_type: &AstDataType) -> Result<DataType> {
1052    let new_err = || not_implemented!("unsupported data type: {:}", data_type);
1053    let data_type = match data_type {
1054        AstDataType::Boolean => DataType::Boolean,
1055        AstDataType::SmallInt => DataType::Int16,
1056        AstDataType::Int => DataType::Int32,
1057        AstDataType::BigInt => DataType::Int64,
1058        AstDataType::Real | AstDataType::Float(Some(1..=24)) => DataType::Float32,
1059        AstDataType::Double | AstDataType::Float(Some(25..=53) | None) => DataType::Float64,
1060        AstDataType::Float(Some(0 | 54..)) => unreachable!(),
1061        AstDataType::Decimal(None, None) => DataType::Decimal,
1062        AstDataType::Varchar | AstDataType::Text => DataType::Varchar,
1063        AstDataType::Date => DataType::Date,
1064        AstDataType::Time(false) => DataType::Time,
1065        AstDataType::Timestamp(false) => DataType::Timestamp,
1066        AstDataType::Timestamp(true) => DataType::Timestamptz,
1067        AstDataType::Interval => DataType::Interval,
1068        AstDataType::Array(datatype) => DataType::List(Box::new(bind_data_type(datatype)?)),
1069        AstDataType::Char(..) => {
1070            bail_not_implemented!("CHAR is not supported, please use VARCHAR instead")
1071        }
1072        AstDataType::Struct(types) => StructType::new(
1073            types
1074                .iter()
1075                .map(|f| Ok((f.name.real_value(), bind_data_type(&f.data_type)?)))
1076                .collect::<Result<Vec<_>>>()?,
1077        )
1078        .into(),
1079        AstDataType::Map(kv) => {
1080            let key = bind_data_type(&kv.0)?;
1081            let value = bind_data_type(&kv.1)?;
1082            DataType::Map(MapType::try_from_kv(key, value).map_err(ErrorCode::BindError)?)
1083        }
1084        AstDataType::Custom(qualified_type_name) => {
1085            let idents = qualified_type_name
1086                .0
1087                .iter()
1088                .map(|n| n.real_value())
1089                .collect_vec();
1090            let name = if idents.len() == 1 {
1091                idents[0].as_str() // `int2`
1092            } else if idents.len() == 2 && idents[0] == PG_CATALOG_SCHEMA_NAME {
1093                idents[1].as_str() // `pg_catalog.text`
1094            } else {
1095                return Err(new_err().into());
1096            };
1097
1098            // In PostgreSQL, these are non-keywords or non-reserved keywords but pre-defined
1099            // names that could be extended by `CREATE TYPE`.
1100            match name {
1101                "int2" => DataType::Int16,
1102                "int4" => DataType::Int32,
1103                "int8" => DataType::Int64,
1104                "rw_int256" => DataType::Int256,
1105                "float4" => DataType::Float32,
1106                "float8" => DataType::Float64,
1107                "timestamptz" => DataType::Timestamptz,
1108                "text" => DataType::Varchar,
1109                "serial" => {
1110                    return Err(ErrorCode::NotSupported(
1111                        "Column type SERIAL is not supported".into(),
1112                        "Please remove the SERIAL column".into(),
1113                    )
1114                    .into());
1115                }
1116                _ => return Err(new_err().into()),
1117            }
1118        }
1119        AstDataType::Bytea => DataType::Bytea,
1120        AstDataType::Jsonb => DataType::Jsonb,
1121        AstDataType::Vector(size) => match (1..=DataType::VEC_MAX_SIZE).contains(&(*size as _)) {
1122            true => DataType::Vector(*size as _),
1123            false => {
1124                return Err(ErrorCode::BindError(format!(
1125                    "vector size {} is out of range [1, {}]",
1126                    size,
1127                    DataType::VEC_MAX_SIZE
1128                ))
1129                .into());
1130            }
1131        },
1132        AstDataType::Regclass
1133        | AstDataType::Regproc
1134        | AstDataType::Uuid
1135        | AstDataType::Decimal(_, _)
1136        | AstDataType::Time(true) => return Err(new_err().into()),
1137    };
1138    Ok(data_type)
1139}