risingwave_frontend/binder/expr/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::slice;
16
17use itertools::Itertools;
18use risingwave_common::catalog::PG_CATALOG_SCHEMA_NAME;
19use risingwave_common::types::{DataType, MapType, StructType};
20use risingwave_common::util::iter_util::zip_eq_fast;
21use risingwave_common::{bail_no_function, bail_not_implemented, not_implemented};
22use risingwave_sqlparser::ast::{
23    Array, BinaryOperator, DataType as AstDataType, EscapeChar, Expr, Function, JsonPredicateType,
24    ObjectName, Query, TrimWhereField, UnaryOperator,
25};
26
27use crate::binder::Binder;
28use crate::binder::expr::function::is_sys_function_without_args;
29use crate::error::{ErrorCode, Result, RwError};
30use crate::expr::{
31    Expr as _, ExprImpl, ExprRewriter as _, ExprType, FunctionCall, InputRef,
32    InputRefDepthRewriter, Parameter, SubqueryKind,
33};
34
35mod binary_op;
36mod column;
37mod function;
38mod order_by;
39mod subquery;
40mod value;
41
/// The arm-count threshold for case-when expressions.
/// When the number of condition arms exceeds this limit,
/// we will try to optimize the case-when expression
/// into a `ConstantLookupExpression`.
/// Check `case.rs` for details.
47const CASE_WHEN_ARMS_OPTIMIZE_LIMIT: usize = 30;
48
49impl Binder {
50    /// Bind an expression with `bind_expr_inner`, attach the original expression
51    /// to the error message.
52    ///
53    /// This may only be called at the root of the expression tree or when crossing
54    /// the boundary of a subquery. Otherwise, the source chain might be too deep
55    /// and confusing to the user.
56    // TODO(error-handling): use a dedicated error type during binding to make it clear.
57    pub fn bind_expr(&mut self, expr: &Expr) -> Result<ExprImpl> {
58        self.bind_expr_inner(expr).map_err(|e| {
59            RwError::from(ErrorCode::BindErrorRoot {
60                expr: expr.to_string(),
61                error: Box::new(e),
62            })
63        })
64    }
65
66    fn bind_expr_inner(&mut self, expr: &Expr) -> Result<ExprImpl> {
67        match expr {
68            // literal
69            Expr::Value(v) => Ok(ExprImpl::Literal(Box::new(self.bind_value(v)?))),
70            Expr::TypedString { data_type, value } => {
71                let s: ExprImpl = self.bind_string(value)?.into();
72                s.cast_explicit(&bind_data_type(data_type)?)
73                    .map_err(Into::into)
74            }
75            Expr::Row(exprs) => self.bind_row(exprs),
76            // input ref
77            Expr::Identifier(ident) => {
78                if is_sys_function_without_args(ident) {
79                    // Rewrite a system variable to a function call, e.g. `SELECT current_schema;`
80                    // will be rewritten to `SELECT current_schema();`.
81                    // NOTE: Here we don't 100% follow the behavior of Postgres, as it doesn't
82                    // allow `session_user()` while we do.
83                    self.bind_function(&Function::no_arg(ObjectName(vec![ident.clone()])))
84                } else if let Some(ref lambda_args) = self.context.lambda_args {
85                    // We don't support capture, so if the expression is in the lambda context,
86                    // we'll not bind it for table columns.
87                    if let Some((arg_idx, arg_type)) = lambda_args.get(&ident.real_value()) {
88                        Ok(InputRef::new(*arg_idx, arg_type.clone()).into())
89                    } else {
90                        Err(
91                            ErrorCode::ItemNotFound(format!("Unknown arg: {}", ident.real_value()))
92                                .into(),
93                        )
94                    }
95                } else if let Some(ctx) = self.secure_compare_context.as_ref() {
96                    // Currently, the generated columns are not supported yet. So the ident here should only be one of the following
97                    // - `headers`
98                    // - secret name
99                    // - the name of the payload column
100                    // TODO(Kexiang): Generated columns or INCLUDE clause should be supported.
101                    if ident.real_value() == *"headers" {
102                        Ok(InputRef::new(0, DataType::Jsonb).into())
103                    } else if ctx.secret_name.is_some()
104                        && ident.real_value() == *ctx.secret_name.as_ref().unwrap()
105                    {
106                        Ok(InputRef::new(1, DataType::Varchar).into())
107                    } else if ident.real_value() == ctx.column_name {
108                        Ok(InputRef::new(2, DataType::Bytea).into())
109                    } else {
110                        Err(
111                            ErrorCode::ItemNotFound(format!("Unknown arg: {}", ident.real_value()))
112                                .into(),
113                        )
114                    }
115                } else {
116                    self.bind_column(slice::from_ref(ident))
117                }
118            }
119            Expr::CompoundIdentifier(idents) => self.bind_column(idents),
120            Expr::FieldIdentifier(field_expr, idents) => {
121                self.bind_single_field_column(field_expr, idents)
122            }
123            // operators & functions
124            Expr::UnaryOp { op, expr } => self.bind_unary_expr(op, expr),
125            Expr::BinaryOp { left, op, right } => self.bind_binary_op(left, op, right),
126            Expr::Nested(expr) => self.bind_expr_inner(expr),
127            Expr::Array(Array { elem: exprs, .. }) => self.bind_array(exprs),
128            Expr::Index { obj, index } => self.bind_index(obj, index),
129            Expr::ArrayRangeIndex { obj, start, end } => {
130                self.bind_array_range_index(obj, start.as_deref(), end.as_deref())
131            }
132            Expr::Function(f) => self.bind_function(f),
133            Expr::Subquery(q) => self.bind_subquery_expr(q, SubqueryKind::Scalar),
134            Expr::Exists(q) => self.bind_subquery_expr(q, SubqueryKind::Existential),
135            Expr::InSubquery {
136                expr,
137                subquery,
138                negated,
139            } => self.bind_in_subquery(expr, subquery, *negated),
140            // special syntax (except date/time or string)
141            Expr::Cast { expr, data_type } => self.bind_cast(expr, data_type),
142            Expr::IsNull(expr) => self.bind_is_operator(ExprType::IsNull, expr),
143            Expr::IsNotNull(expr) => self.bind_is_operator(ExprType::IsNotNull, expr),
144            Expr::IsTrue(expr) => self.bind_is_operator(ExprType::IsTrue, expr),
145            Expr::IsNotTrue(expr) => self.bind_is_operator(ExprType::IsNotTrue, expr),
146            Expr::IsFalse(expr) => self.bind_is_operator(ExprType::IsFalse, expr),
147            Expr::IsNotFalse(expr) => self.bind_is_operator(ExprType::IsNotFalse, expr),
148            Expr::IsUnknown(expr) => self.bind_is_unknown(ExprType::IsNull, expr),
149            Expr::IsNotUnknown(expr) => self.bind_is_unknown(ExprType::IsNotNull, expr),
150            Expr::IsDistinctFrom(left, right) => self.bind_distinct_from(left, right),
151            Expr::IsNotDistinctFrom(left, right) => self.bind_not_distinct_from(left, right),
152            Expr::IsJson {
153                expr,
154                negated,
155                item_type,
156                unique_keys: false,
157            } => self.bind_is_json(expr, *negated, *item_type),
158            Expr::Case {
159                operand,
160                conditions,
161                results,
162                else_result,
163            } => self.bind_case(
164                operand.as_deref(),
165                conditions,
166                results,
167                else_result.as_deref(),
168            ),
169            Expr::Between {
170                expr,
171                negated,
172                low,
173                high,
174            } => self.bind_between(expr, *negated, low, high),
175            Expr::Like {
176                negated,
177                expr,
178                pattern,
179                escape_char,
180            } => self.bind_like(ExprType::Like, expr, *negated, pattern, *escape_char),
181            Expr::ILike {
182                negated,
183                expr,
184                pattern,
185                escape_char,
186            } => self.bind_like(ExprType::ILike, expr, *negated, pattern, *escape_char),
187            Expr::SimilarTo {
188                expr,
189                negated,
190                pattern,
191                escape_char,
192            } => self.bind_similar_to(expr, *negated, pattern, *escape_char),
193            Expr::InList {
194                expr,
195                list,
196                negated,
197            } => self.bind_in_list(expr, list, *negated),
198            // special syntax for date/time
199            Expr::Extract { field, expr } => self.bind_extract(field, expr),
200            Expr::AtTimeZone {
201                timestamp,
202                time_zone,
203            } => self.bind_at_time_zone(timestamp, time_zone),
204            // special syntax for string
205            Expr::Trim {
206                expr,
207                trim_where,
208                trim_what,
209            } => self.bind_trim(expr, trim_where.as_ref(), trim_what.as_deref()),
210            Expr::Substring {
211                expr,
212                substring_from,
213                substring_for,
214            } => self.bind_substring(expr, substring_from.as_deref(), substring_for.as_deref()),
215            Expr::Position { substring, string } => self.bind_position(substring, string),
216            Expr::Overlay {
217                expr,
218                new_substring,
219                start,
220                count,
221            } => self.bind_overlay(expr, new_substring, start, count.as_deref()),
222            Expr::Parameter { index } => self.bind_parameter(*index),
223            Expr::Collate { expr, collation } => self.bind_collate(expr, collation),
224            Expr::ArraySubquery(q) => self.bind_subquery_expr(q, SubqueryKind::Array),
225            Expr::Map { entries } => self.bind_map(entries),
226            Expr::IsJson {
227                unique_keys: true, ..
228            }
229            | Expr::SomeOp(_)
230            | Expr::AllOp(_)
231            | Expr::TryCast { .. }
232            | Expr::GroupingSets(_)
233            | Expr::Cube(_)
234            | Expr::Rollup(_)
235            | Expr::LambdaFunction { .. } => {
236                bail_not_implemented!(issue = 112, "unsupported expression {:?}", expr)
237            }
238        }
239    }
240
241    pub(super) fn bind_extract(&mut self, field: &String, expr: &Expr) -> Result<ExprImpl> {
242        let arg = self.bind_expr_inner(expr)?;
243        let arg_type = arg.return_type();
244        Ok(FunctionCall::new(
245            ExprType::Extract,
246            vec![self.bind_string(field)?.into(), arg],
247        )
248        .map_err(|_| {
249            not_implemented!(
250                issue = 112,
251                "function extract({} from {:?}) doesn't exist",
252                field,
253                arg_type
254            )
255        })?
256        .into())
257    }
258
259    pub(super) fn bind_at_time_zone(&mut self, input: &Expr, time_zone: &Expr) -> Result<ExprImpl> {
260        let input = self.bind_expr_inner(input)?;
261        let time_zone = self.bind_expr_inner(time_zone)?;
262        FunctionCall::new(ExprType::AtTimeZone, vec![input, time_zone]).map(Into::into)
263    }
264
265    pub(super) fn bind_in_list(
266        &mut self,
267        expr: &Expr,
268        list: &[Expr],
269        negated: bool,
270    ) -> Result<ExprImpl> {
271        let left = self.bind_expr_inner(expr)?;
272        let mut bound_expr_list = vec![left.clone()];
273        let mut non_const_exprs = vec![];
274        for elem in list {
275            let expr = self.bind_expr_inner(elem)?;
276            match expr.is_const() {
277                true => bound_expr_list.push(expr),
278                false => non_const_exprs.push(expr),
279            }
280        }
281
282        let mut ret = if bound_expr_list.len() == 1 {
283            None
284        } else {
285            Some(FunctionCall::new(ExprType::In, bound_expr_list)?.into())
286        };
287        // Non-const exprs are not part of IN-expr in backend and rewritten into OR-Equal-exprs.
288        for expr in non_const_exprs {
289            if let Some(inner_ret) = ret {
290                ret = Some(
291                    FunctionCall::new(
292                        ExprType::Or,
293                        vec![
294                            inner_ret,
295                            FunctionCall::new(ExprType::Equal, vec![left.clone(), expr])?.into(),
296                        ],
297                    )?
298                    .into(),
299                );
300            } else {
301                ret = Some(FunctionCall::new(ExprType::Equal, vec![left.clone(), expr])?.into());
302            }
303        }
304        if negated {
305            Ok(
306                FunctionCall::new_unchecked(ExprType::Not, vec![ret.unwrap()], DataType::Boolean)
307                    .into(),
308            )
309        } else {
310            Ok(ret.unwrap())
311        }
312    }
313
314    pub(super) fn bind_in_subquery(
315        &mut self,
316        expr: &Expr,
317        subquery: &Query,
318        negated: bool,
319    ) -> Result<ExprImpl> {
320        let bound_expr = self.bind_expr_inner(expr)?;
321        let bound_subquery = self.bind_subquery_expr(subquery, SubqueryKind::In(bound_expr))?;
322        if negated {
323            Ok(
324                FunctionCall::new_unchecked(ExprType::Not, vec![bound_subquery], DataType::Boolean)
325                    .into(),
326            )
327        } else {
328            Ok(bound_subquery)
329        }
330    }
331
332    pub(super) fn bind_is_json(
333        &mut self,
334        expr: &Expr,
335        negated: bool,
336        item_type: JsonPredicateType,
337    ) -> Result<ExprImpl> {
338        let mut args = vec![self.bind_expr_inner(expr)?];
339        // Avoid `JsonPredicateType::to_string` so that we decouple sqlparser from expr execution
340        let type_symbol = match item_type {
341            JsonPredicateType::Value => None,
342            JsonPredicateType::Array => Some("ARRAY"),
343            JsonPredicateType::Object => Some("OBJECT"),
344            JsonPredicateType::Scalar => Some("SCALAR"),
345        };
346        if let Some(s) = type_symbol {
347            args.push(ExprImpl::literal_varchar(s.into()));
348        }
349
350        let is_json = FunctionCall::new(ExprType::IsJson, args)?.into();
351        if negated {
352            Ok(FunctionCall::new(ExprType::Not, vec![is_json])?.into())
353        } else {
354            Ok(is_json)
355        }
356    }
357
358    pub(super) fn bind_unary_expr(&mut self, op: &UnaryOperator, expr: &Expr) -> Result<ExprImpl> {
359        let func_type = match &op {
360            UnaryOperator::Not => ExprType::Not,
361            UnaryOperator::Minus => ExprType::Neg,
362            UnaryOperator::Plus => {
363                return self.rewrite_positive(expr);
364            }
365            UnaryOperator::Custom(name) => match name.as_str() {
366                "~" => ExprType::BitwiseNot,
367                "@" => ExprType::Abs,
368                "|/" => ExprType::Sqrt,
369                "||/" => ExprType::Cbrt,
370                _ => bail_not_implemented!(issue = 112, "unsupported unary expression: {:?}", op),
371            },
372            UnaryOperator::PGQualified(_) => {
373                bail_not_implemented!(issue = 112, "unsupported unary expression: {:?}", op)
374            }
375        };
376        let expr = self.bind_expr_inner(expr)?;
377        FunctionCall::new(func_type, vec![expr]).map(|f| f.into())
378    }
379
380    /// Directly returns the expression itself if it is a positive number.
381    fn rewrite_positive(&mut self, expr: &Expr) -> Result<ExprImpl> {
382        let expr = self.bind_expr_inner(expr)?;
383        let return_type = expr.return_type();
384        if return_type.is_numeric() {
385            return Ok(expr);
386        }
387        Err(ErrorCode::InvalidInputSyntax(format!("+ {:?}", return_type)).into())
388    }
389
390    pub(super) fn bind_trim(
391        &mut self,
392        expr: &Expr,
393        // BOTH | LEADING | TRAILING
394        trim_where: Option<&TrimWhereField>,
395        trim_what: Option<&Expr>,
396    ) -> Result<ExprImpl> {
397        let mut inputs = vec![self.bind_expr_inner(expr)?];
398        let func_type = match trim_where {
399            Some(TrimWhereField::Both) => ExprType::Trim,
400            Some(TrimWhereField::Leading) => ExprType::Ltrim,
401            Some(TrimWhereField::Trailing) => ExprType::Rtrim,
402            None => ExprType::Trim,
403        };
404        if let Some(t) = trim_what {
405            inputs.push(self.bind_expr_inner(t)?);
406        }
407        Ok(FunctionCall::new(func_type, inputs)?.into())
408    }
409
410    fn bind_substring(
411        &mut self,
412        expr: &Expr,
413        substring_from: Option<&Expr>,
414        substring_for: Option<&Expr>,
415    ) -> Result<ExprImpl> {
416        let mut args = vec![
417            self.bind_expr_inner(expr)?,
418            match substring_from {
419                Some(expr) => self.bind_expr_inner(expr)?,
420                None => ExprImpl::literal_int(1),
421            },
422        ];
423        if let Some(expr) = substring_for {
424            args.push(self.bind_expr_inner(expr)?);
425        }
426        FunctionCall::new(ExprType::Substr, args).map(|f| f.into())
427    }
428
429    fn bind_position(&mut self, substring: &Expr, string: &Expr) -> Result<ExprImpl> {
430        let args = vec![
431            // Note that we reverse the order of arguments.
432            self.bind_expr_inner(string)?,
433            self.bind_expr_inner(substring)?,
434        ];
435        FunctionCall::new(ExprType::Position, args).map(Into::into)
436    }
437
438    fn bind_overlay(
439        &mut self,
440        expr: &Expr,
441        new_substring: &Expr,
442        start: &Expr,
443        count: Option<&Expr>,
444    ) -> Result<ExprImpl> {
445        let mut args = vec![
446            self.bind_expr_inner(expr)?,
447            self.bind_expr_inner(new_substring)?,
448            self.bind_expr_inner(start)?,
449        ];
450        if let Some(count) = count {
451            args.push(self.bind_expr_inner(count)?);
452        }
453        FunctionCall::new(ExprType::Overlay, args).map(|f| f.into())
454    }
455
456    fn is_binding_inline_sql_udf(&self) -> bool {
457        self.context.sql_udf_arguments.is_some()
458    }
459
460    /// Returns whether we're binding SQL UDF by checking if any of the upper subquery context has
461    /// `sql_udf_arguments` set.
462    fn is_binding_subquery_sql_udf(&self) -> bool {
463        self.upper_subquery_contexts
464            .iter()
465            .any(|(context, _)| context.sql_udf_arguments.is_some())
466    }
467
468    /// Bind a parameter for SQL UDF.
469    fn bind_sql_udf_parameter(&mut self, name: &str) -> Result<ExprImpl> {
470        for (depth, context) in std::iter::once(&self.context)
471            .chain((self.upper_subquery_contexts.iter().rev()).map(|(context, _)| context))
472            .enumerate()
473        {
474            // Only lookup the first non-empty udf context. If the parameter is not found in the
475            // current context, we will continue to the upper context.
476            if let Some(args) = &context.sql_udf_arguments {
477                if let Some(expr) = args.get(name) {
478                    // The arguments recorded in the context is relative to the that context.
479                    // We need to shift the depth to the current context.
480                    let mut rewriter = InputRefDepthRewriter::new(depth);
481                    return Ok(rewriter.rewrite_expr(expr.clone()));
482                } else {
483                    // A UDF cannot access parameters from outer UDFs. Do not continue but directly
484                    // return an error.
485                    break;
486                }
487            }
488        }
489
490        Err(ErrorCode::BindError(format!(
491            "failed to find {} parameter {name}",
492            if name.starts_with('$') {
493                "unnamed"
494            } else {
495                "named"
496            }
497        ))
498        .into())
499    }
500
501    fn bind_parameter(&mut self, index: u64) -> Result<ExprImpl> {
502        // Special check for sql udf
503        // Note: This is specific to sql udf with unnamed parameters, since the
504        // parameters will be parsed and treated as `Parameter`.
505        // For detailed explanation, consider checking `bind_column`.
506        if self.is_binding_inline_sql_udf() || self.is_binding_subquery_sql_udf() {
507            let column_name = format!("${index}");
508            return self.bind_sql_udf_parameter(&column_name);
509        }
510
511        Ok(Parameter::new(index, self.param_types.clone()).into())
512    }
513
514    /// Bind `expr (not) between low and high`
515    pub(super) fn bind_between(
516        &mut self,
517        expr: &Expr,
518        negated: bool,
519        low: &Expr,
520        high: &Expr,
521    ) -> Result<ExprImpl> {
522        let expr = self.bind_expr_inner(expr)?;
523        let low = self.bind_expr_inner(low)?;
524        let high = self.bind_expr_inner(high)?;
525
526        let func_call = if negated {
527            // negated = true: expr < low or expr > high
528            FunctionCall::new_unchecked(
529                ExprType::Or,
530                vec![
531                    FunctionCall::new(ExprType::LessThan, vec![expr.clone(), low])?.into(),
532                    FunctionCall::new(ExprType::GreaterThan, vec![expr, high])?.into(),
533                ],
534                DataType::Boolean,
535            )
536        } else {
537            // negated = false: expr >= low and expr <= high
538            FunctionCall::new_unchecked(
539                ExprType::And,
540                vec![
541                    FunctionCall::new(ExprType::GreaterThanOrEqual, vec![expr.clone(), low])?
542                        .into(),
543                    FunctionCall::new(ExprType::LessThanOrEqual, vec![expr, high])?.into(),
544                ],
545                DataType::Boolean,
546            )
547        };
548
549        Ok(func_call.into())
550    }
551
552    fn bind_like(
553        &mut self,
554        expr_type: ExprType,
555        expr: &Expr,
556        negated: bool,
557        pattern: &Expr,
558        escape_char: Option<EscapeChar>,
559    ) -> Result<ExprImpl> {
560        if matches!(pattern, Expr::AllOp(_) | Expr::SomeOp(_)) {
561            if escape_char.is_some() {
562                // PostgreSQL also don't support the pattern due to the complexity of implementation.
563                // The SQL will failed on PostgreSQL 16.1:
564                // ```sql
565                // select 'a' like any(array[null]) escape '';
566                // ```
567                bail_not_implemented!(
568                    "LIKE with both ALL|ANY pattern and escape character is not supported"
569                )
570            }
571            // Use the `bind_binary_op` path to handle the ALL|ANY pattern.
572            let op = match (expr_type, negated) {
573                (ExprType::Like, false) => BinaryOperator::Custom("~~".to_owned()),
574                (ExprType::Like, true) => BinaryOperator::Custom("!~~".to_owned()),
575                (ExprType::ILike, false) => BinaryOperator::Custom("~~*".to_owned()),
576                (ExprType::ILike, true) => BinaryOperator::Custom("!~~*".to_owned()),
577                _ => unreachable!(),
578            };
579            return self.bind_binary_op(expr, &op, pattern);
580        }
581        let expr = self.bind_expr_inner(expr)?;
582        let pattern = self.bind_expr_inner(pattern)?;
583        match (expr.return_type(), pattern.return_type()) {
584            (DataType::Varchar, DataType::Varchar) => {}
585            (string_ty, pattern_ty) => match expr_type {
586                ExprType::Like => bail_no_function!("like({}, {})", string_ty, pattern_ty),
587                ExprType::ILike => bail_no_function!("ilike({}, {})", string_ty, pattern_ty),
588                _ => unreachable!(),
589            },
590        }
591        let args = match escape_char {
592            Some(escape_char) => {
593                let escape_char = ExprImpl::literal_varchar(escape_char.to_string());
594                vec![expr, pattern, escape_char]
595            }
596            None => vec![expr, pattern],
597        };
598        let func_call = FunctionCall::new_unchecked(expr_type, args, DataType::Boolean);
599        let func_call = if negated {
600            FunctionCall::new_unchecked(ExprType::Not, vec![func_call.into()], DataType::Boolean)
601        } else {
602            func_call
603        };
604        Ok(func_call.into())
605    }
606
607    /// Bind `<expr> [ NOT ] SIMILAR TO <pat> ESCAPE <esc_text>`
608    pub(super) fn bind_similar_to(
609        &mut self,
610        expr: &Expr,
611        negated: bool,
612        pattern: &Expr,
613        escape_char: Option<EscapeChar>,
614    ) -> Result<ExprImpl> {
615        let expr = self.bind_expr_inner(expr)?;
616        let pattern = self.bind_expr_inner(pattern)?;
617
618        let esc_inputs = if let Some(escape_char) = escape_char {
619            let escape_char = ExprImpl::literal_varchar(escape_char.to_string());
620            vec![pattern, escape_char]
621        } else {
622            vec![pattern]
623        };
624
625        let esc_call =
626            FunctionCall::new_unchecked(ExprType::SimilarToEscape, esc_inputs, DataType::Varchar);
627
628        let regex_call = FunctionCall::new_unchecked(
629            ExprType::RegexpEq,
630            vec![expr, esc_call.into()],
631            DataType::Boolean,
632        );
633        let func_call = if negated {
634            FunctionCall::new_unchecked(ExprType::Not, vec![regex_call.into()], DataType::Boolean)
635        } else {
636            regex_call
637        };
638
639        Ok(func_call.into())
640    }
641
642    /// The optimization check for the following case-when expression pattern
643    /// e.g., select case 1 when (...) then (...) else (...) end;
644    fn check_constant_case_when_optimization(
645        &mut self,
646        conditions: &[Expr],
647        results_expr: &[ExprImpl],
648        operand: Option<&Expr>,
649        fallback: Option<&ExprImpl>,
650        constant_case_when_eval_inputs: &mut Vec<ExprImpl>,
651    ) -> bool {
652        // The operand value to be compared later
653        let operand_value;
654
655        if let Some(operand) = operand {
656            let Ok(operand) = self.bind_expr_inner(operand) else {
657                return false;
658            };
659            if !operand.is_const() {
660                return false;
661            }
662            operand_value = operand;
663        } else {
664            return false;
665        }
666
667        for (condition, result) in zip_eq_fast(conditions, results_expr) {
668            if let Expr::Value(_) = condition.clone() {
669                let Ok(res) = self.bind_expr_inner(condition) else {
670                    return false;
671                };
672                // Found a match
673                if res == operand_value {
674                    constant_case_when_eval_inputs.push(result.clone());
675                    return true;
676                }
677            } else {
678                return false;
679            }
680        }
681
682        // Otherwise this will eventually go through fallback arm
683        debug_assert!(
684            constant_case_when_eval_inputs.is_empty(),
685            "expect `inputs` to be empty"
686        );
687
688        let Some(fallback) = fallback else {
689            return false;
690        };
691
692        constant_case_when_eval_inputs.push(fallback.clone());
693        true
694    }
695
696    /// Helper function to compare or set column identifier
697    /// used in `check_convert_simple_form`
698    fn compare_or_set(col_expr: &mut Option<Expr>, test_expr: &Expr) -> bool {
699        let Expr::Identifier(test_ident) = test_expr else {
700            return false;
701        };
702        if let Some(expr) = col_expr {
703            let Expr::Identifier(ident) = expr else {
704                return false;
705            };
706            if ident.real_value() != test_ident.real_value() {
707                return false;
708            }
709        } else {
710            *col_expr = Some(Expr::Identifier(test_ident.clone()));
711        }
712        true
713    }
714
715    /// left expression and right expression must be either:
716    /// `<constant> <Eq> <identifier>` or `<identifier> <Eq> <constant>`
717    /// used in `check_convert_simple_form`
718    fn check_invariant(left: &Expr, op: &BinaryOperator, right: &Expr) -> bool {
719        if op != &BinaryOperator::Eq {
720            return false;
721        }
722        if let Expr::Identifier(_) = left {
723            // <identifier> <Eq> <constant>
724            let Expr::Value(_) = right else {
725                return false;
726            };
727        } else {
728            // <constant> <Eq> <identifier>
729            let Expr::Value(_) = left else {
730                return false;
731            };
732            let Expr::Identifier(_) = right else {
733                return false;
734            };
735        }
736        true
737    }
738
739    /// Helper function to extract expression out and insert
740    /// the corresponding bound version to `inputs`
741    /// used in `check_convert_simple_form`
742    /// Note: this function will be invoked per arm
743    fn try_extract_simple_form(
744        &mut self,
745        ident_expr: &Expr,
746        constant_expr: &Expr,
747        column_expr: &mut Option<Expr>,
748        inputs: &mut Vec<ExprImpl>,
749    ) -> bool {
750        if !Self::compare_or_set(column_expr, ident_expr) {
751            return false;
752        }
753        let Ok(bound_expr) = self.bind_expr_inner(constant_expr) else {
754            return false;
755        };
756        inputs.push(bound_expr);
757        true
758    }
759
760    /// See if the case when expression in form
761    /// `select case when <expr_1 = constant> (...with same pattern...) else <constant> end;`
762    /// If so, this expression could also be converted to constant lookup
763    fn check_convert_simple_form(
764        &mut self,
765        conditions: &[Expr],
766        results_expr: &[ExprImpl],
767        fallback: Option<ExprImpl>,
768        constant_lookup_inputs: &mut Vec<ExprImpl>,
769    ) -> bool {
770        let mut column_expr = None;
771
772        for (condition, result) in zip_eq_fast(conditions, results_expr) {
773            if let Expr::BinaryOp { left, op, right } = condition {
774                if !Self::check_invariant(left, op, right) {
775                    return false;
776                }
777                if let Expr::Identifier(_) = &**left {
778                    if !self.try_extract_simple_form(
779                        left,
780                        right,
781                        &mut column_expr,
782                        constant_lookup_inputs,
783                    ) {
784                        return false;
785                    }
786                } else if !self.try_extract_simple_form(
787                    right,
788                    left,
789                    &mut column_expr,
790                    constant_lookup_inputs,
791                ) {
792                    return false;
793                }
794                constant_lookup_inputs.push(result.clone());
795            } else {
796                return false;
797            }
798        }
799
800        // Insert operand first
801        let Some(operand) = column_expr else {
802            return false;
803        };
804        let Ok(bound_operand) = self.bind_expr_inner(&operand) else {
805            return false;
806        };
807        constant_lookup_inputs.insert(0, bound_operand);
808
809        // fallback insertion
810        if let Some(expr) = fallback {
811            constant_lookup_inputs.push(expr);
812        }
813
814        true
815    }
816
817    /// The helper function to check if the current case-when
818    /// expression in `bind_case` could be optimized
819    /// into `ConstantLookupExpression`
820    fn check_bind_case_optimization(
821        &mut self,
822        conditions: &[Expr],
823        results_expr: &[ExprImpl],
824        operand: Option<&Expr>,
825        fallback: Option<ExprImpl>,
826        constant_lookup_inputs: &mut Vec<ExprImpl>,
827    ) -> bool {
828        if conditions.len() < CASE_WHEN_ARMS_OPTIMIZE_LIMIT {
829            return false;
830        }
831
832        if let Some(operand) = operand {
833            let Ok(operand) = self.bind_expr_inner(operand) else {
834                return false;
835            };
836            // This optimization should be done in subsequent optimization phase
837            // if the operand is const
838            // e.g., select case 1 when 1 then 114514 else 1919810 end;
839            if operand.is_const() {
840                return false;
841            }
842            constant_lookup_inputs.push(operand);
843        } else {
844            // Try converting to simple form
845            // see the example as illustrated in `check_convert_simple_form`
846            return self.check_convert_simple_form(
847                conditions,
848                results_expr,
849                fallback,
850                constant_lookup_inputs,
851            );
852        }
853
854        for (condition, result) in zip_eq_fast(conditions, results_expr) {
855            if let Expr::Value(_) = condition {
856                let Ok(input) = self.bind_expr_inner(condition) else {
857                    return false;
858                };
859                constant_lookup_inputs.push(input);
860            } else {
861                // If at least one condition is not in the simple form / not constant,
862                // we can NOT do the subsequent optimization pass
863                return false;
864            }
865
866            constant_lookup_inputs.push(result.clone());
867        }
868
869        // The fallback arm for case-when expression
870        if let Some(expr) = fallback {
871            constant_lookup_inputs.push(expr);
872        }
873
874        true
875    }
876
877    pub(super) fn bind_case(
878        &mut self,
879        operand: Option<&Expr>,
880        conditions: &[Expr],
881        results: &[Expr],
882        else_result: Option<&Expr>,
883    ) -> Result<ExprImpl> {
884        let mut inputs = Vec::new();
885        let results_expr: Vec<ExprImpl> = results
886            .iter()
887            .map(|expr| self.bind_expr_inner(expr))
888            .collect::<Result<_>>()?;
889        let else_result_expr = else_result
890            .map(|expr| self.bind_expr_inner(expr))
891            .transpose()?;
892
893        let mut constant_lookup_inputs = Vec::new();
894        let mut constant_case_when_eval_inputs = Vec::new();
895
896        let constant_case_when_flag = self.check_constant_case_when_optimization(
897            conditions,
898            &results_expr,
899            operand,
900            else_result_expr.as_ref(),
901            &mut constant_case_when_eval_inputs,
902        );
903
904        if constant_case_when_flag {
905            // Sanity check
906            if constant_case_when_eval_inputs.len() != 1 {
907                return Err(ErrorCode::BindError(
908                    "expect `constant_case_when_eval_inputs` only contains a single bound expression".to_owned()
909                )
910                    .into());
911            }
912            // Directly return the first element of the vector
913            return Ok(constant_case_when_eval_inputs[0].take());
914        }
915
916        // See if the case-when expression can be optimized
917        let optimize_flag = self.check_bind_case_optimization(
918            conditions,
919            &results_expr,
920            operand,
921            else_result_expr.clone(),
922            &mut constant_lookup_inputs,
923        );
924
925        if optimize_flag {
926            return Ok(FunctionCall::new(ExprType::ConstantLookup, constant_lookup_inputs)?.into());
927        }
928
929        for (condition, result) in zip_eq_fast(conditions, results_expr) {
930            let condition = condition.clone();
931            let condition = match operand {
932                Some(t) => Expr::BinaryOp {
933                    left: t.clone().into(),
934                    op: BinaryOperator::Eq,
935                    right: Box::new(condition),
936                },
937                None => condition,
938            };
939            inputs.push(
940                self.bind_expr_inner(&condition)
941                    .and_then(|expr| expr.enforce_bool_clause("CASE WHEN"))?,
942            );
943            inputs.push(result);
944        }
945
946        // The fallback arm for case-when expression
947        if let Some(expr) = else_result_expr {
948            inputs.push(expr);
949        }
950
951        if inputs.iter().any(ExprImpl::has_table_function) {
952            return Err(
953                ErrorCode::BindError("table functions are not allowed in CASE".into()).into(),
954            );
955        }
956
957        Ok(FunctionCall::new(ExprType::Case, inputs)?.into())
958    }
959
960    pub(super) fn bind_is_operator(
961        &mut self,
962        func_type: ExprType,
963        expr: &Expr,
964    ) -> Result<ExprImpl> {
965        let expr = self.bind_expr_inner(expr)?;
966        Ok(FunctionCall::new(func_type, vec![expr])?.into())
967    }
968
969    pub(super) fn bind_is_unknown(&mut self, func_type: ExprType, expr: &Expr) -> Result<ExprImpl> {
970        let expr = self
971            .bind_expr_inner(expr)?
972            .cast_implicit(&DataType::Boolean)?;
973        Ok(FunctionCall::new(func_type, vec![expr])?.into())
974    }
975
976    pub(super) fn bind_distinct_from(&mut self, left: &Expr, right: &Expr) -> Result<ExprImpl> {
977        let left = self.bind_expr_inner(left)?;
978        let right = self.bind_expr_inner(right)?;
979        let func_call = FunctionCall::new(ExprType::IsDistinctFrom, vec![left, right]);
980        Ok(func_call?.into())
981    }
982
983    pub(super) fn bind_not_distinct_from(&mut self, left: &Expr, right: &Expr) -> Result<ExprImpl> {
984        let left = self.bind_expr_inner(left)?;
985        let right = self.bind_expr_inner(right)?;
986        let func_call = FunctionCall::new(ExprType::IsNotDistinctFrom, vec![left, right]);
987        Ok(func_call?.into())
988    }
989
990    pub(super) fn bind_cast(&mut self, expr: &Expr, data_type: &AstDataType) -> Result<ExprImpl> {
991        match &data_type {
992            // Casting to Regclass type means getting the oid of expr.
993            // See https://www.postgresql.org/docs/current/datatype-oid.html.
994            AstDataType::Regclass => {
995                let input = self.bind_expr_inner(expr)?;
996                Ok(input.cast_to_regclass()?)
997            }
998            AstDataType::Regproc => {
999                let lhs = self.bind_expr_inner(expr)?;
1000                let lhs_ty = lhs.return_type();
1001                if lhs_ty == DataType::Varchar {
1002                    // FIXME: Currently, we only allow VARCHAR to be casted to Regproc.
1003                    // FIXME: Check whether it's a valid proc
1004                    // FIXME: The return type should be casted to Regproc, but we don't have this type.
1005                    Ok(lhs)
1006                } else {
1007                    Err(ErrorCode::BindError(format!("Can't cast {} to regproc", lhs_ty)).into())
1008                }
1009            }
1010            // Redirect cast char to varchar to make system like Metabase happy.
1011            // Char is not supported in RisingWave, but some ecosystem tools like Metabase will use it.
1012            // Notice that the behavior of `char` and `varchar` is different in PostgreSQL.
1013            // The following sql result should be different in PostgreSQL:
1014            // ```
1015            // select 'a'::char(2) = 'a '::char(2);
1016            // ----------
1017            // t
1018            //
1019            // select 'a'::varchar = 'a '::varchar;
1020            // ----------
1021            // f
1022            // ```
1023            AstDataType::Char(_) => self.bind_cast_inner(expr, &DataType::Varchar),
1024            _ => self.bind_cast_inner(expr, &bind_data_type(data_type)?),
1025        }
1026    }
1027
1028    pub fn bind_cast_inner(&mut self, expr: &Expr, data_type: &DataType) -> Result<ExprImpl> {
1029        match (expr, data_type) {
1030            (Expr::Array(Array { elem: expr, .. }), DataType::List(list_type)) => {
1031                self.bind_array_cast(expr, list_type.elem())
1032            }
1033            (Expr::Map { entries }, DataType::Map(m)) => self.bind_map_cast(entries, m),
1034            (expr, data_type) => {
1035                let lhs = self.bind_expr_inner(expr)?;
1036                lhs.cast_explicit(data_type).map_err(Into::into)
1037            }
1038        }
1039    }
1040
1041    pub fn bind_collate(&mut self, expr: &Expr, collation: &ObjectName) -> Result<ExprImpl> {
1042        if !["C", "POSIX"].contains(&collation.real_value().as_str()) {
1043            bail_not_implemented!("Collate collation other than `C` or `POSIX` is not implemented");
1044        }
1045
1046        let bound_inner = self.bind_expr_inner(expr)?;
1047        let ret_type = bound_inner.return_type();
1048
1049        match ret_type {
1050            DataType::Varchar => {}
1051            _ => {
1052                return Err(ErrorCode::NotSupported(
1053                    format!("{} is not a collatable data type", ret_type),
1054                    "The only built-in collatable data types are `varchar`, please check your type"
1055                        .into(),
1056                )
1057                .into());
1058            }
1059        }
1060
1061        Ok(bound_inner)
1062    }
1063}
1064
1065pub fn bind_data_type(data_type: &AstDataType) -> Result<DataType> {
1066    let new_err = || not_implemented!("unsupported data type: {:}", data_type);
1067    let data_type = match data_type {
1068        AstDataType::Boolean => DataType::Boolean,
1069        AstDataType::SmallInt => DataType::Int16,
1070        AstDataType::Int => DataType::Int32,
1071        AstDataType::BigInt => DataType::Int64,
1072        AstDataType::Real | AstDataType::Float(Some(1..=24)) => DataType::Float32,
1073        AstDataType::Double | AstDataType::Float(Some(25..=53) | None) => DataType::Float64,
1074        AstDataType::Float(Some(0 | 54..)) => unreachable!(),
1075        AstDataType::Decimal(None, None) => DataType::Decimal,
1076        AstDataType::Varchar | AstDataType::Text => DataType::Varchar,
1077        AstDataType::Date => DataType::Date,
1078        AstDataType::Time(false) => DataType::Time,
1079        AstDataType::Timestamp(false) => DataType::Timestamp,
1080        AstDataType::Timestamp(true) => DataType::Timestamptz,
1081        AstDataType::Interval => DataType::Interval,
1082        AstDataType::Array(datatype) => DataType::list(bind_data_type(datatype)?),
1083        AstDataType::Char(..) => {
1084            bail_not_implemented!("CHAR is not supported, please use VARCHAR instead")
1085        }
1086        AstDataType::Struct(types) => StructType::new(
1087            types
1088                .iter()
1089                .map(|f| Ok((f.name.real_value(), bind_data_type(&f.data_type)?)))
1090                .collect::<Result<Vec<_>>>()?,
1091        )
1092        .into(),
1093        AstDataType::Map(kv) => {
1094            let key = bind_data_type(&kv.0)?;
1095            let value = bind_data_type(&kv.1)?;
1096            DataType::Map(MapType::try_from_kv(key, value).map_err(ErrorCode::BindError)?)
1097        }
1098        AstDataType::Custom(qualified_type_name) => {
1099            let idents = qualified_type_name
1100                .0
1101                .iter()
1102                .map(|n| n.real_value())
1103                .collect_vec();
1104            let name = if idents.len() == 1 {
1105                idents[0].as_str() // `int2`
1106            } else if idents.len() == 2 && idents[0] == PG_CATALOG_SCHEMA_NAME {
1107                idents[1].as_str() // `pg_catalog.text`
1108            } else {
1109                return Err(new_err().into());
1110            };
1111
1112            // In PostgreSQL, these are non-keywords or non-reserved keywords but pre-defined
1113            // names that could be extended by `CREATE TYPE`.
1114            match name {
1115                "int2" => DataType::Int16,
1116                "int4" => DataType::Int32,
1117                "int8" => DataType::Int64,
1118                "rw_int256" => DataType::Int256,
1119                "float4" => DataType::Float32,
1120                "float8" => DataType::Float64,
1121                "timestamptz" => DataType::Timestamptz,
1122                "text" => DataType::Varchar,
1123                "serial" => {
1124                    return Err(ErrorCode::NotSupported(
1125                        "Column type SERIAL is not supported".into(),
1126                        "Please remove the SERIAL column".into(),
1127                    )
1128                    .into());
1129                }
1130                _ => return Err(new_err().into()),
1131            }
1132        }
1133        AstDataType::Bytea => DataType::Bytea,
1134        AstDataType::Jsonb => DataType::Jsonb,
1135        AstDataType::Vector(size) => match (1..=DataType::VEC_MAX_SIZE).contains(&(*size as _)) {
1136            true => DataType::Vector(*size as _),
1137            false => {
1138                return Err(ErrorCode::BindError(format!(
1139                    "vector size {} is out of range [1, {}]",
1140                    size,
1141                    DataType::VEC_MAX_SIZE
1142                ))
1143                .into());
1144            }
1145        },
1146        AstDataType::Regclass
1147        | AstDataType::Regproc
1148        | AstDataType::Uuid
1149        | AstDataType::Decimal(_, _)
1150        | AstDataType::Time(true) => return Err(new_err().into()),
1151    };
1152    Ok(data_type)
1153}