risingwave_frontend/expr/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use enum_as_inner::EnumAsInner;
16use fixedbitset::FixedBitSet;
17use futures::FutureExt;
18use paste::paste;
19use risingwave_common::array::ListValue;
20use risingwave_common::types::{DataType, Datum, JsonbVal, MapType, Scalar, ScalarImpl};
21use risingwave_expr::aggregate::PbAggKind;
22use risingwave_expr::expr::build_from_prost;
23use risingwave_pb::expr::expr_node::RexNode;
24use risingwave_pb::expr::{ExprNode, ProjectSetSelectItem};
25use user_defined_function::UserDefinedFunctionDisplay;
26
27use crate::error::{ErrorCode, Result as RwResult};
28
29mod agg_call;
30mod correlated_input_ref;
31mod function_call;
32mod function_call_with_lambda;
33mod input_ref;
34mod literal;
35mod now;
36mod parameter;
37mod pure;
38mod subquery;
39mod table_function;
40mod user_defined_function;
41mod window_function;
42
43mod order_by_expr;
44pub use order_by_expr::{OrderBy, OrderByExpr};
45
46mod expr_mutator;
47mod expr_rewriter;
48mod expr_visitor;
49pub mod function_impl;
50mod session_timezone;
51mod type_inference;
52mod utils;
53
54pub use agg_call::AggCall;
55pub use correlated_input_ref::{CorrelatedId, CorrelatedInputRef, Depth};
56pub use expr_mutator::ExprMutator;
57pub use expr_rewriter::{ExprRewriter, default_rewrite_expr};
58pub use expr_visitor::{ExprVisitor, default_visit_expr};
59pub use function_call::{FunctionCall, FunctionCallDisplay, is_row_function};
60pub use function_call_with_lambda::FunctionCallWithLambda;
61pub use input_ref::{InputRef, InputRefDisplay, input_ref_to_column_indices};
62pub use literal::Literal;
63pub use now::{InlineNowProcTime, Now, NowProcTimeFinder};
64pub use parameter::Parameter;
65pub use pure::*;
66pub use risingwave_pb::expr::expr_node::Type as ExprType;
67pub use session_timezone::{SessionTimezone, TimestamptzExprFinder};
68pub use subquery::{Subquery, SubqueryKind};
69pub use table_function::{TableFunction, TableFunctionType};
70pub use type_inference::*;
71pub use user_defined_function::UserDefinedFunction;
72pub use utils::*;
73pub use window_function::WindowFunction;
74
/// Expression-nesting depth paired with `EXPR_TOO_DEEP_NOTICE`; presumably an expression nested
/// deeper than this triggers that notice (the usage is outside this view — TODO confirm).
const EXPR_DEPTH_THRESHOLD: usize = 30;
/// User-facing notice suggesting that an overly complicated query be simplified or split.
const EXPR_TOO_DEEP_NOTICE: &str = "Some expression is too complicated. \
Consider simplifying or splitting the query if you encounter any issues.";
78
79/// the trait of bound expressions
80pub trait Expr: Into<ExprImpl> {
81    /// Get the return type of the expr
82    fn return_type(&self) -> DataType;
83
84    /// Serialize the expression
85    fn to_expr_proto(&self) -> ExprNode;
86}
87
/// Defines the [`ExprImpl`] enum with one boxed variant per listed expression type, plus:
/// - `ExprImpl::variant_name`, returning the variant's name as a static string;
/// - a `From<T> for ExprImpl` conversion for every listed type `T`;
/// - an [`Expr`] implementation for `ExprImpl` that dispatches to the wrapped expression.
///
/// Every listed name must be followed by a comma (including the last one).
macro_rules! impl_expr_impl {
    ($($variant:ident,)*) => {
        #[derive(Clone, Eq, PartialEq, Hash, EnumAsInner)]
        pub enum ExprImpl {
            $($variant(Box<$variant>),)*
        }

        impl ExprImpl {
            /// Name of the enum variant held by this expression, e.g. `"InputRef"`.
            pub fn variant_name(&self) -> &'static str {
                match self {
                    $(Self::$variant(_) => stringify!($variant),)*
                }
            }
        }

        $(
            impl From<$variant> for ExprImpl {
                fn from(inner: $variant) -> Self {
                    Self::$variant(Box::new(inner))
                }
            }
        )*

        impl Expr for ExprImpl {
            fn return_type(&self) -> DataType {
                match self {
                    $(Self::$variant(inner) => inner.return_type(),)*
                }
            }

            fn to_expr_proto(&self) -> ExprNode {
                match self {
                    $(Self::$variant(inner) => inner.to_expr_proto(),)*
                }
            }
        }
    };
}
125
// Instantiates `ExprImpl` with one variant per bound expression kind listed below.
// NOTE: the listing order is the enum's variant order. The derived `Hash` hashes the variant
// discriminant, so reordering changes hash values; keep the order stable unless intended.
impl_expr_impl!(
    // BoundColumnRef, might be used in binder.
    CorrelatedInputRef,
    InputRef,
    Literal,
    FunctionCall,
    FunctionCallWithLambda,
    AggCall,
    Subquery,
    TableFunction,
    WindowFunction,
    UserDefinedFunction,
    Parameter,
    Now,
);
141
142impl ExprImpl {
143    /// A literal int value.
144    #[inline(always)]
145    pub fn literal_int(v: i32) -> Self {
146        Literal::new(Some(v.to_scalar_value()), DataType::Int32).into()
147    }
148
149    /// A literal bigint value
150    #[inline(always)]
151    pub fn literal_bigint(v: i64) -> Self {
152        Literal::new(Some(v.to_scalar_value()), DataType::Int64).into()
153    }
154
155    /// A literal float64 value.
156    #[inline(always)]
157    pub fn literal_f64(v: f64) -> Self {
158        Literal::new(Some(v.into()), DataType::Float64).into()
159    }
160
161    /// A literal boolean value.
162    #[inline(always)]
163    pub fn literal_bool(v: bool) -> Self {
164        Literal::new(Some(v.to_scalar_value()), DataType::Boolean).into()
165    }
166
167    /// A literal varchar value.
168    #[inline(always)]
169    pub fn literal_varchar(v: String) -> Self {
170        Literal::new(Some(v.into()), DataType::Varchar).into()
171    }
172
173    /// A literal null value.
174    #[inline(always)]
175    pub fn literal_null(element_type: DataType) -> Self {
176        Literal::new(None, element_type).into()
177    }
178
179    /// A literal jsonb value.
180    #[inline(always)]
181    pub fn literal_jsonb(v: JsonbVal) -> Self {
182        Literal::new(Some(v.into()), DataType::Jsonb).into()
183    }
184
185    /// A literal list value.
186    #[inline(always)]
187    pub fn literal_list(v: ListValue, element_type: DataType) -> Self {
188        Literal::new(
189            Some(v.to_scalar_value()),
190            DataType::List(Box::new(element_type)),
191        )
192        .into()
193    }
194
195    /// Takes the expression, leaving a literal null of the same type in its place.
196    pub fn take(&mut self) -> Self {
197        std::mem::replace(self, Self::literal_null(self.return_type()))
198    }
199
200    /// A `count(*)` aggregate function.
201    #[inline(always)]
202    pub fn count_star() -> Self {
203        AggCall::new(
204            PbAggKind::Count.into(),
205            vec![],
206            false,
207            OrderBy::any(),
208            Condition::true_cond(),
209            vec![],
210        )
211        .unwrap()
212        .into()
213    }
214
215    /// Create a new expression by merging the given expressions by `And`.
216    ///
217    /// If `exprs` is empty, return a literal `true`.
218    pub fn and(exprs: impl IntoIterator<Item = ExprImpl>) -> Self {
219        merge_expr_by_logical(exprs, ExprType::And, ExprImpl::literal_bool(true))
220    }
221
222    /// Create a new expression by merging the given expressions by `Or`.
223    ///
224    /// If `exprs` is empty, return a literal `false`.
225    pub fn or(exprs: impl IntoIterator<Item = ExprImpl>) -> Self {
226        merge_expr_by_logical(exprs, ExprType::Or, ExprImpl::literal_bool(false))
227    }
228
229    /// Collect all `InputRef`s' indexes in the expression.
230    ///
231    /// # Panics
232    /// Panics if `input_ref >= input_col_num`.
233    pub fn collect_input_refs(&self, input_col_num: usize) -> FixedBitSet {
234        collect_input_refs(input_col_num, [self])
235    }
236
237    /// Check if the expression has no side effects and output is deterministic
238    pub fn is_pure(&self) -> bool {
239        is_pure(self)
240    }
241
242    pub fn is_impure(&self) -> bool {
243        is_impure(self)
244    }
245
246    /// Count `Now`s in the expression.
247    pub fn count_nows(&self) -> usize {
248        let mut visitor = CountNow::default();
249        visitor.visit_expr(self);
250        visitor.count()
251    }
252
253    /// Check whether self is literal NULL.
254    pub fn is_null(&self) -> bool {
255        matches!(self, ExprImpl::Literal(literal) if literal.get_data().is_none())
256    }
257
258    /// Check whether self is a literal NULL or literal string.
259    pub fn is_untyped(&self) -> bool {
260        matches!(self, ExprImpl::Literal(literal) if literal.is_untyped())
261            || matches!(self, ExprImpl::Parameter(parameter) if !parameter.has_infer())
262    }
263
264    /// Shorthand to create cast expr to `target` type in implicit context.
265    pub fn cast_implicit(mut self, target: DataType) -> Result<ExprImpl, CastError> {
266        FunctionCall::cast_mut(&mut self, target, CastContext::Implicit)?;
267        Ok(self)
268    }
269
270    /// Shorthand to create cast expr to `target` type in assign context.
271    pub fn cast_assign(mut self, target: DataType) -> Result<ExprImpl, CastError> {
272        FunctionCall::cast_mut(&mut self, target, CastContext::Assign)?;
273        Ok(self)
274    }
275
276    /// Shorthand to create cast expr to `target` type in explicit context.
277    pub fn cast_explicit(mut self, target: DataType) -> Result<ExprImpl, CastError> {
278        FunctionCall::cast_mut(&mut self, target, CastContext::Explicit)?;
279        Ok(self)
280    }
281
282    /// Shorthand to inplace cast expr to `target` type in implicit context.
283    pub fn cast_implicit_mut(&mut self, target: DataType) -> Result<(), CastError> {
284        FunctionCall::cast_mut(self, target, CastContext::Implicit)
285    }
286
287    /// Shorthand to inplace cast expr to `target` type in explicit context.
288    pub fn cast_explicit_mut(&mut self, target: DataType) -> Result<(), CastError> {
289        FunctionCall::cast_mut(self, target, CastContext::Explicit)
290    }
291
292    /// Casting to Regclass type means getting the oid of expr.
293    /// See <https://www.postgresql.org/docs/current/datatype-oid.html>
294    pub fn cast_to_regclass(self) -> Result<ExprImpl, CastError> {
295        match self.return_type() {
296            DataType::Varchar => Ok(ExprImpl::FunctionCall(Box::new(
297                FunctionCall::new_unchecked(ExprType::CastRegclass, vec![self], DataType::Int32),
298            ))),
299            DataType::Int32 => Ok(self),
300            dt if dt.is_int() => Ok(self.cast_explicit(DataType::Int32)?),
301            _ => bail_cast_error!("unsupported input type"),
302        }
303    }
304
305    /// Shorthand to inplace cast expr to `regclass` type.
306    pub fn cast_to_regclass_mut(&mut self) -> Result<(), CastError> {
307        let owned = std::mem::replace(self, ExprImpl::literal_bool(false));
308        *self = owned.cast_to_regclass()?;
309        Ok(())
310    }
311
312    /// Ensure the return type of this expression is an array of some type.
313    pub fn ensure_array_type(&self) -> Result<(), ErrorCode> {
314        if self.is_untyped() {
315            return Err(ErrorCode::BindError(
316                "could not determine polymorphic type because input has type unknown".into(),
317            ));
318        }
319        match self.return_type() {
320            DataType::List(_) => Ok(()),
321            t => Err(ErrorCode::BindError(format!("expects array but got {t}"))),
322        }
323    }
324
325    /// Ensure the return type of this expression is a map of some type.
326    pub fn try_into_map_type(&self) -> Result<MapType, ErrorCode> {
327        if self.is_untyped() {
328            return Err(ErrorCode::BindError(
329                "could not determine polymorphic type because input has type unknown".into(),
330            ));
331        }
332        match self.return_type() {
333            DataType::Map(m) => Ok(m),
334            t => Err(ErrorCode::BindError(format!("expects map but got {t}"))),
335        }
336    }
337
338    /// Shorthand to enforce implicit cast to boolean
339    pub fn enforce_bool_clause(self, clause: &str) -> RwResult<ExprImpl> {
340        if self.is_untyped() {
341            let inner = self.cast_implicit(DataType::Boolean)?;
342            return Ok(inner);
343        }
344        let return_type = self.return_type();
345        if return_type != DataType::Boolean {
346            bail!(
347                "argument of {} must be boolean, not type {:?}",
348                clause,
349                return_type
350            )
351        }
352        Ok(self)
353    }
354
355    /// Create "cast" expr to string (`varchar`) type. This is different from a real cast, as
356    /// boolean is converted to a single char rather than full word.
357    ///
358    /// Choose between `cast_output` and `cast_{assign,explicit}(Varchar)` based on `PostgreSQL`'s
359    /// behavior on bools. For example, `concat(':', true)` is `:t` but `':' || true` is `:true`.
360    /// All other types have the same behavior when formatting to output and casting to string.
361    ///
362    /// References in `PostgreSQL`:
363    /// * [cast](https://github.com/postgres/postgres/blob/a3ff08e0b08dbfeb777ccfa8f13ebaa95d064c04/src/include/catalog/pg_cast.dat#L437-L444)
364    /// * [impl](https://github.com/postgres/postgres/blob/27b77ecf9f4d5be211900eda54d8155ada50d696/src/backend/utils/adt/bool.c#L204-L209)
365    pub fn cast_output(self) -> RwResult<ExprImpl> {
366        if self.return_type() == DataType::Boolean {
367            return Ok(FunctionCall::new(ExprType::BoolOut, vec![self])?.into());
368        }
369        // Use normal cast for other types. Both `assign` and `explicit` can pass the castability
370        // check and there is no difference.
371        self.cast_assign(DataType::Varchar)
372            .map_err(|err| err.into())
373    }
374
375    /// Evaluate the expression on the given input.
376    ///
377    /// TODO: This is a naive implementation. We should avoid proto ser/de.
378    /// Tracking issue: <https://github.com/risingwavelabs/risingwave/issues/3479>
379    pub async fn eval_row(&self, input: &OwnedRow) -> RwResult<Datum> {
380        let backend_expr = build_from_prost(&self.to_expr_proto())?;
381        Ok(backend_expr.eval_row(input).await?)
382    }
383
384    /// Try to evaluate an expression if it's a constant expression by `ExprImpl::is_const`.
385    ///
386    /// Returns...
387    /// - `None` if it's not a constant expression,
388    /// - `Some(Ok(_))` if constant evaluation succeeds,
389    /// - `Some(Err(_))` if there's an error while evaluating a constant expression.
390    pub fn try_fold_const(&self) -> Option<RwResult<Datum>> {
391        if self.is_const() {
392            self.eval_row(&OwnedRow::empty())
393                .now_or_never()
394                .expect("constant expression should not be async")
395                .into()
396        } else {
397            None
398        }
399    }
400
401    /// Similar to `ExprImpl::try_fold_const`, but panics if the expression is not constant.
402    pub fn fold_const(&self) -> RwResult<Datum> {
403        self.try_fold_const().expect("expression is not constant")
404    }
405}
406
/// For every listed expression type `T`, implements on [`ExprImpl`] a method
/// `has_<t>(&self) -> bool` (the name is the snake-cased type, e.g. `has_subquery`) that
/// reports whether a node of that type occurs anywhere in the expression.
///
/// The search is an [`ExprVisitor`] whose only override is `visit_<t>`; it does not look inside
/// subqueries.
macro_rules! impl_has_variant {
    ( $($kind:ty),* ) => {
        paste! {
            impl ExprImpl {
                $(
                    pub fn [<has_ $kind:snake>](&self) -> bool {
                        struct Finder(bool);

                        impl ExprVisitor for Finder {
                            fn [<visit_ $kind:snake>](&mut self, _: &$kind) {
                                self.0 = true;
                            }
                        }

                        let mut finder = Finder(false);
                        finder.visit_expr(self);
                        finder.0
                    }
                )*
            }
        }
    };
}
434
// Generates `has_input_ref`, `has_literal`, `has_function_call`, `has_function_call_with_lambda`,
// `has_agg_call`, `has_subquery`, `has_table_function`, `has_window_function`,
// `has_user_defined_function` and `has_now`. `CorrelatedInputRef` is intentionally not listed:
// see `ExprImpl::has_correlated_input_ref`.
impl_has_variant! {InputRef, Literal, FunctionCall, FunctionCallWithLambda, AggCall, Subquery, TableFunction, WindowFunction, UserDefinedFunction, Now}
436
/// Two input columns related by an inequality, as produced by
/// `ExprImpl::as_input_comparison_cond`.
///
/// The struct does not record whether the original comparison was strict (`>`) or not (`>=`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct InequalityInputPair {
    /// Input index of greater side of inequality.
    pub(crate) key_required_larger: usize,
    /// Input index of less side of inequality.
    pub(crate) key_required_smaller: usize,
    /// Optional `(op, operand)` such that `greater >= less <op> operand`.
    /// `as_input_comparison_cond` only produces `Add` or `Subtract` for `op`.
    pub(crate) delta_expression: Option<(ExprType, ExprImpl)>,
}
446
447impl InequalityInputPair {
448    fn new(
449        key_required_larger: usize,
450        key_required_smaller: usize,
451        delta_expression: Option<(ExprType, ExprImpl)>,
452    ) -> Self {
453        Self {
454            key_required_larger,
455            key_required_smaller,
456            delta_expression,
457        }
458    }
459}
460
461impl ExprImpl {
462    /// This function is not meant to be called. In most cases you would want
463    /// [`ExprImpl::has_correlated_input_ref_by_depth`].
464    ///
465    /// When an expr contains a [`CorrelatedInputRef`] with lower depth, the whole expr is still
466    /// considered to be uncorrelated, and can be checked with [`ExprImpl::has_subquery`] as well.
467    /// See examples on [`crate::binder::BoundQuery::is_correlated_by_depth`] for details.
468    ///
469    /// This is a placeholder to trigger a compiler error when a trivial implementation checking for
470    /// enum variant is generated by accident. It cannot be called either because you cannot pass
471    /// `Infallible` to it.
472    pub fn has_correlated_input_ref(&self, _: std::convert::Infallible) -> bool {
473        unreachable!()
474    }
475
476    /// Used to check whether the expression has [`CorrelatedInputRef`].
477    ///
478    /// This is the core logic that supports [`crate::binder::BoundQuery::is_correlated_by_depth`]. Check the
479    /// doc of it for examples of `depth` being equal, less or greater.
480    // We need to traverse inside subqueries.
481    pub fn has_correlated_input_ref_by_depth(&self, depth: Depth) -> bool {
482        struct Has {
483            depth: usize,
484            has: bool,
485        }
486
487        impl ExprVisitor for Has {
488            fn visit_correlated_input_ref(&mut self, correlated_input_ref: &CorrelatedInputRef) {
489                if correlated_input_ref.depth() == self.depth {
490                    self.has = true;
491                }
492            }
493
494            fn visit_subquery(&mut self, subquery: &Subquery) {
495                self.has |= subquery.is_correlated_by_depth(self.depth);
496            }
497        }
498
499        let mut visitor = Has { depth, has: false };
500        visitor.visit_expr(self);
501        visitor.has
502    }
503
504    pub fn has_correlated_input_ref_by_correlated_id(&self, correlated_id: CorrelatedId) -> bool {
505        struct Has {
506            correlated_id: CorrelatedId,
507            has: bool,
508        }
509
510        impl ExprVisitor for Has {
511            fn visit_correlated_input_ref(&mut self, correlated_input_ref: &CorrelatedInputRef) {
512                if correlated_input_ref.correlated_id() == self.correlated_id {
513                    self.has = true;
514                }
515            }
516
517            fn visit_subquery(&mut self, subquery: &Subquery) {
518                self.has |= subquery.is_correlated_by_correlated_id(self.correlated_id);
519            }
520        }
521
522        let mut visitor = Has {
523            correlated_id,
524            has: false,
525        };
526        visitor.visit_expr(self);
527        visitor.has
528    }
529
530    /// Collect `CorrelatedInputRef`s in `ExprImpl` by relative `depth`, return their indices, and
531    /// assign absolute `correlated_id` for them.
532    pub fn collect_correlated_indices_by_depth_and_assign_id(
533        &mut self,
534        depth: Depth,
535        correlated_id: CorrelatedId,
536    ) -> Vec<usize> {
537        struct Collector {
538            depth: Depth,
539            correlated_indices: Vec<usize>,
540            correlated_id: CorrelatedId,
541        }
542
543        impl ExprMutator for Collector {
544            fn visit_correlated_input_ref(
545                &mut self,
546                correlated_input_ref: &mut CorrelatedInputRef,
547            ) {
548                if correlated_input_ref.depth() == self.depth {
549                    self.correlated_indices.push(correlated_input_ref.index());
550                    correlated_input_ref.set_correlated_id(self.correlated_id);
551                }
552            }
553
554            fn visit_subquery(&mut self, subquery: &mut Subquery) {
555                self.correlated_indices.extend(
556                    subquery.collect_correlated_indices_by_depth_and_assign_id(
557                        self.depth,
558                        self.correlated_id,
559                    ),
560                );
561            }
562        }
563
564        let mut collector = Collector {
565            depth,
566            correlated_indices: vec![],
567            correlated_id,
568        };
569        collector.visit_expr(self);
570        collector.correlated_indices.sort();
571        collector.correlated_indices.dedup();
572        collector.correlated_indices
573    }
574
575    /// Checks whether this is a constant expr that can be evaluated over a dummy chunk.
576    ///
577    /// The expression tree should only consist of literals and **pure** function calls.
578    pub fn is_const(&self) -> bool {
579        let only_literal_and_func = {
580            struct HasOthers {
581                has_others: bool,
582            }
583
584            impl ExprVisitor for HasOthers {
585                fn visit_expr(&mut self, expr: &ExprImpl) {
586                    match expr {
587                        ExprImpl::CorrelatedInputRef(_)
588                        | ExprImpl::InputRef(_)
589                        | ExprImpl::AggCall(_)
590                        | ExprImpl::Subquery(_)
591                        | ExprImpl::TableFunction(_)
592                        | ExprImpl::WindowFunction(_)
593                        | ExprImpl::UserDefinedFunction(_)
594                        | ExprImpl::Parameter(_)
595                        | ExprImpl::Now(_) => self.has_others = true,
596                        ExprImpl::Literal(_inner) => {}
597                        ExprImpl::FunctionCall(inner) => {
598                            if !self.is_short_circuit(inner) {
599                                // only if the current `func_call` is *not* a short-circuit
600                                // expression, e.g., true or (...) | false and (...),
601                                // shall we proceed to visit it.
602                                self.visit_function_call(inner)
603                            }
604                        }
605                        ExprImpl::FunctionCallWithLambda(inner) => {
606                            self.visit_function_call_with_lambda(inner)
607                        }
608                    }
609                }
610            }
611
612            impl HasOthers {
613                fn is_short_circuit(&self, func_call: &FunctionCall) -> bool {
614                    /// evaluate the first parameter of `Or` or `And` function call
615                    fn eval_first(e: &ExprImpl, expect: bool) -> bool {
616                        if let ExprImpl::Literal(l) = e {
617                            *l.get_data() == Some(ScalarImpl::Bool(expect))
618                        } else {
619                            false
620                        }
621                    }
622
623                    match func_call.func_type {
624                        ExprType::Or => eval_first(&func_call.inputs()[0], true),
625                        ExprType::And => eval_first(&func_call.inputs()[0], false),
626                        _ => false,
627                    }
628                }
629            }
630
631            let mut visitor = HasOthers { has_others: false };
632            visitor.visit_expr(self);
633            !visitor.has_others
634        };
635
636        let is_pure = self.is_pure();
637
638        only_literal_and_func && is_pure
639    }
640
641    /// Returns the `InputRefs` of an Equality predicate if it matches
642    /// ordered by the canonical ordering (lower, higher), else returns None
643    pub fn as_eq_cond(&self) -> Option<(InputRef, InputRef)> {
644        if let ExprImpl::FunctionCall(function_call) = self
645            && function_call.func_type() == ExprType::Equal
646            && let (_, ExprImpl::InputRef(x), ExprImpl::InputRef(y)) =
647                function_call.clone().decompose_as_binary()
648        {
649            if x.index() < y.index() {
650                Some((*x, *y))
651            } else {
652                Some((*y, *x))
653            }
654        } else {
655            None
656        }
657    }
658
659    pub fn as_is_not_distinct_from_cond(&self) -> Option<(InputRef, InputRef)> {
660        if let ExprImpl::FunctionCall(function_call) = self
661            && function_call.func_type() == ExprType::IsNotDistinctFrom
662            && let (_, ExprImpl::InputRef(x), ExprImpl::InputRef(y)) =
663                function_call.clone().decompose_as_binary()
664        {
665            if x.index() < y.index() {
666                Some((*x, *y))
667            } else {
668                Some((*y, *x))
669            }
670        } else {
671            None
672        }
673    }
674
675    pub fn reverse_comparison(comparison: ExprType) -> ExprType {
676        match comparison {
677            ExprType::LessThan => ExprType::GreaterThan,
678            ExprType::LessThanOrEqual => ExprType::GreaterThanOrEqual,
679            ExprType::GreaterThan => ExprType::LessThan,
680            ExprType::GreaterThanOrEqual => ExprType::LessThanOrEqual,
681            ExprType::Equal | ExprType::IsNotDistinctFrom => comparison,
682            _ => unreachable!(),
683        }
684    }
685
686    pub fn as_comparison_cond(&self) -> Option<(InputRef, ExprType, InputRef)> {
687        if let ExprImpl::FunctionCall(function_call) = self {
688            match function_call.func_type() {
689                ty @ (ExprType::LessThan
690                | ExprType::LessThanOrEqual
691                | ExprType::GreaterThan
692                | ExprType::GreaterThanOrEqual) => {
693                    let (_, op1, op2) = function_call.clone().decompose_as_binary();
694                    if let (ExprImpl::InputRef(x), ExprImpl::InputRef(y)) = (op1, op2) {
695                        if x.index < y.index {
696                            Some((*x, ty, *y))
697                        } else {
698                            Some((*y, Self::reverse_comparison(ty), *x))
699                        }
700                    } else {
701                        None
702                    }
703                }
704                _ => None,
705            }
706        } else {
707            None
708        }
709    }
710
711    /// Accepts expressions of the form `input_expr cmp now_expr` or `now_expr cmp input_expr`,
712    /// where `input_expr` contains an `InputRef` and contains no `now()`, and `now_expr`
713    /// contains a `now()` but no `InputRef`.
714    ///
715    /// Canonicalizes to the first ordering and returns `(input_expr, cmp, now_expr)`
716    pub fn as_now_comparison_cond(&self) -> Option<(ExprImpl, ExprType, ExprImpl)> {
717        if let ExprImpl::FunctionCall(function_call) = self {
718            match function_call.func_type() {
719                ty @ (ExprType::Equal
720                | ExprType::LessThan
721                | ExprType::LessThanOrEqual
722                | ExprType::GreaterThan
723                | ExprType::GreaterThanOrEqual) => {
724                    let (_, op1, op2) = function_call.clone().decompose_as_binary();
725                    if !op1.has_now()
726                        && op1.has_input_ref()
727                        && op2.has_now()
728                        && !op2.has_input_ref()
729                    {
730                        Some((op1, ty, op2))
731                    } else if op1.has_now()
732                        && !op1.has_input_ref()
733                        && !op2.has_now()
734                        && op2.has_input_ref()
735                    {
736                        Some((op2, Self::reverse_comparison(ty), op1))
737                    } else {
738                        None
739                    }
740                }
741                _ => None,
742            }
743        } else {
744            None
745        }
746    }
747
748    /// Accepts expressions of the form `InputRef cmp InputRef [+- const_expr]` or
749    /// `InputRef [+- const_expr] cmp InputRef`.
750    pub(crate) fn as_input_comparison_cond(&self) -> Option<InequalityInputPair> {
751        if let ExprImpl::FunctionCall(function_call) = self {
752            match function_call.func_type() {
753                ty @ (ExprType::LessThan
754                | ExprType::LessThanOrEqual
755                | ExprType::GreaterThan
756                | ExprType::GreaterThanOrEqual) => {
757                    let (_, mut op1, mut op2) = function_call.clone().decompose_as_binary();
758                    if matches!(ty, ExprType::LessThan | ExprType::LessThanOrEqual) {
759                        std::mem::swap(&mut op1, &mut op2);
760                    }
761                    if let (Some((lft_input, lft_offset)), Some((rht_input, rht_offset))) =
762                        (op1.as_input_offset(), op2.as_input_offset())
763                    {
764                        match (lft_offset, rht_offset) {
765                            (Some(_), Some(_)) => None,
766                            (None, rht_offset @ Some(_)) => {
767                                Some(InequalityInputPair::new(lft_input, rht_input, rht_offset))
768                            }
769                            (Some((operator, operand)), None) => Some(InequalityInputPair::new(
770                                lft_input,
771                                rht_input,
772                                Some((
773                                    if operator == ExprType::Add {
774                                        ExprType::Subtract
775                                    } else {
776                                        ExprType::Add
777                                    },
778                                    operand,
779                                )),
780                            )),
781                            (None, None) => {
782                                Some(InequalityInputPair::new(lft_input, rht_input, None))
783                            }
784                        }
785                    } else {
786                        None
787                    }
788                }
789                _ => None,
790            }
791        } else {
792            None
793        }
794    }
795
796    /// Returns the `InputRef` and offset of a predicate if it matches
797    /// the form `InputRef [+- const_expr]`, else returns None.
798    fn as_input_offset(&self) -> Option<(usize, Option<(ExprType, ExprImpl)>)> {
799        match self {
800            ExprImpl::InputRef(input_ref) => Some((input_ref.index(), None)),
801            ExprImpl::FunctionCall(function_call) => {
802                let expr_type = function_call.func_type();
803                match expr_type {
804                    ExprType::Add | ExprType::Subtract => {
805                        let (_, lhs, rhs) = function_call.clone().decompose_as_binary();
806                        if let ExprImpl::InputRef(input_ref) = &lhs
807                            && rhs.is_const()
808                        {
809                            // Currently we will return `None` for non-literal because the result of the expression might be '1 day'. However, there will definitely exist false positives such as '1 second + 1 second'.
810                            // We will treat the expression as an input offset when rhs is `null`.
811                            if rhs.return_type() == DataType::Interval
812                                && rhs.as_literal().is_none_or(|literal| {
813                                    literal.get_data().as_ref().is_some_and(|scalar| {
814                                        let interval = scalar.as_interval();
815                                        interval.months() != 0 || interval.days() != 0
816                                    })
817                                })
818                            {
819                                None
820                            } else {
821                                Some((input_ref.index(), Some((expr_type, rhs))))
822                            }
823                        } else {
824                            None
825                        }
826                    }
827                    _ => None,
828                }
829            }
830            _ => None,
831        }
832    }
833
834    pub fn as_eq_const(&self) -> Option<(InputRef, ExprImpl)> {
835        if let ExprImpl::FunctionCall(function_call) = self
836            && function_call.func_type() == ExprType::Equal
837        {
838            match function_call.clone().decompose_as_binary() {
839                (_, ExprImpl::InputRef(x), y) if y.is_const() => Some((*x, y)),
840                (_, x, ExprImpl::InputRef(y)) if x.is_const() => Some((*y, x)),
841                _ => None,
842            }
843        } else {
844            None
845        }
846    }
847
848    pub fn as_eq_correlated_input_ref(&self) -> Option<(InputRef, CorrelatedInputRef)> {
849        if let ExprImpl::FunctionCall(function_call) = self
850            && function_call.func_type() == ExprType::Equal
851        {
852            match function_call.clone().decompose_as_binary() {
853                (_, ExprImpl::InputRef(x), ExprImpl::CorrelatedInputRef(y)) => Some((*x, *y)),
854                (_, ExprImpl::CorrelatedInputRef(x), ExprImpl::InputRef(y)) => Some((*y, *x)),
855                _ => None,
856            }
857        } else {
858            None
859        }
860    }
861
862    pub fn as_is_null(&self) -> Option<InputRef> {
863        if let ExprImpl::FunctionCall(function_call) = self
864            && function_call.func_type() == ExprType::IsNull
865        {
866            match function_call.clone().decompose_as_unary() {
867                (_, ExprImpl::InputRef(x)) => Some(*x),
868                _ => None,
869            }
870        } else {
871            None
872        }
873    }
874
875    pub fn as_comparison_const(&self) -> Option<(InputRef, ExprType, ExprImpl)> {
876        fn reverse_comparison(comparison: ExprType) -> ExprType {
877            match comparison {
878                ExprType::LessThan => ExprType::GreaterThan,
879                ExprType::LessThanOrEqual => ExprType::GreaterThanOrEqual,
880                ExprType::GreaterThan => ExprType::LessThan,
881                ExprType::GreaterThanOrEqual => ExprType::LessThanOrEqual,
882                _ => unreachable!(),
883            }
884        }
885
886        if let ExprImpl::FunctionCall(function_call) = self {
887            match function_call.func_type() {
888                ty @ (ExprType::LessThan
889                | ExprType::LessThanOrEqual
890                | ExprType::GreaterThan
891                | ExprType::GreaterThanOrEqual) => {
892                    let (_, op1, op2) = function_call.clone().decompose_as_binary();
893                    match (op1, op2) {
894                        (ExprImpl::InputRef(x), y) if y.is_const() => Some((*x, ty, y)),
895                        (x, ExprImpl::InputRef(y)) if x.is_const() => {
896                            Some((*y, reverse_comparison(ty), x))
897                        }
898                        _ => None,
899                    }
900                }
901                _ => None,
902            }
903        } else {
904            None
905        }
906    }
907
908    pub fn as_in_const_list(&self) -> Option<(InputRef, Vec<ExprImpl>)> {
909        if let ExprImpl::FunctionCall(function_call) = self
910            && function_call.func_type() == ExprType::In
911        {
912            let mut inputs = function_call.inputs().iter().cloned();
913            let input_ref = match inputs.next().unwrap() {
914                ExprImpl::InputRef(i) => *i,
915                _ => return None,
916            };
917            let list: Vec<_> = inputs
918                .inspect(|expr| {
919                    // Non constant IN will be bound to OR
920                    assert!(expr.is_const());
921                })
922                .collect();
923
924            Some((input_ref, list))
925        } else {
926            None
927        }
928    }
929
930    pub fn as_or_disjunctions(&self) -> Option<Vec<ExprImpl>> {
931        if let ExprImpl::FunctionCall(function_call) = self
932            && function_call.func_type() == ExprType::Or
933        {
934            Some(to_disjunctions(self.clone()))
935        } else {
936            None
937        }
938    }
939
940    pub fn to_project_set_select_item_proto(&self) -> ProjectSetSelectItem {
941        use risingwave_pb::expr::project_set_select_item::SelectItem::*;
942
943        ProjectSetSelectItem {
944            select_item: Some(match self {
945                ExprImpl::TableFunction(tf) => TableFunction(tf.to_protobuf()),
946                expr => Expr(expr.to_expr_proto()),
947            }),
948        }
949    }
950
951    pub fn from_expr_proto(proto: &ExprNode) -> RwResult<Self> {
952        let rex_node = proto.get_rex_node()?;
953        let ret_type = proto.get_return_type()?.into();
954
955        Ok(match rex_node {
956            RexNode::InputRef(column_index) => Self::InputRef(Box::new(InputRef::from_expr_proto(
957                *column_index as _,
958                ret_type,
959            )?)),
960            RexNode::Constant(_) => Self::Literal(Box::new(Literal::from_expr_proto(proto)?)),
961            RexNode::Udf(udf) => Self::UserDefinedFunction(Box::new(
962                UserDefinedFunction::from_expr_proto(udf, ret_type)?,
963            )),
964            RexNode::FuncCall(function_call) => {
965                Self::FunctionCall(Box::new(FunctionCall::from_expr_proto(
966                    function_call,
967                    proto.get_function_type()?, // only interpret if it's a function call
968                    ret_type,
969                )?))
970            }
971            RexNode::Now(_) => Self::Now(Box::new(Now {})),
972        })
973    }
974}
975
976impl From<Condition> for ExprImpl {
977    fn from(c: Condition) -> Self {
978        ExprImpl::and(c.conjunctions)
979    }
980}
981
982/// A custom Debug implementation that is more concise and suitable to use with
983/// [`std::fmt::Formatter::debug_list`] in plan nodes. If the verbose output is preferred, it is
984/// still available via `{:#?}`.
985impl std::fmt::Debug for ExprImpl {
986    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
987        if f.alternate() {
988            return match self {
989                Self::InputRef(arg0) => f.debug_tuple("InputRef").field(arg0).finish(),
990                Self::Literal(arg0) => f.debug_tuple("Literal").field(arg0).finish(),
991                Self::FunctionCall(arg0) => f.debug_tuple("FunctionCall").field(arg0).finish(),
992                Self::FunctionCallWithLambda(arg0) => {
993                    f.debug_tuple("FunctionCallWithLambda").field(arg0).finish()
994                }
995                Self::AggCall(arg0) => f.debug_tuple("AggCall").field(arg0).finish(),
996                Self::Subquery(arg0) => f.debug_tuple("Subquery").field(arg0).finish(),
997                Self::CorrelatedInputRef(arg0) => {
998                    f.debug_tuple("CorrelatedInputRef").field(arg0).finish()
999                }
1000                Self::TableFunction(arg0) => f.debug_tuple("TableFunction").field(arg0).finish(),
1001                Self::WindowFunction(arg0) => f.debug_tuple("WindowFunction").field(arg0).finish(),
1002                Self::UserDefinedFunction(arg0) => {
1003                    f.debug_tuple("UserDefinedFunction").field(arg0).finish()
1004                }
1005                Self::Parameter(arg0) => f.debug_tuple("Parameter").field(arg0).finish(),
1006                Self::Now(_) => f.debug_tuple("Now").finish(),
1007            };
1008        }
1009        match self {
1010            Self::InputRef(x) => write!(f, "{:?}", x),
1011            Self::Literal(x) => write!(f, "{:?}", x),
1012            Self::FunctionCall(x) => write!(f, "{:?}", x),
1013            Self::FunctionCallWithLambda(x) => write!(f, "{:?}", x),
1014            Self::AggCall(x) => write!(f, "{:?}", x),
1015            Self::Subquery(x) => write!(f, "{:?}", x),
1016            Self::CorrelatedInputRef(x) => write!(f, "{:?}", x),
1017            Self::TableFunction(x) => write!(f, "{:?}", x),
1018            Self::WindowFunction(x) => write!(f, "{:?}", x),
1019            Self::UserDefinedFunction(x) => write!(f, "{:?}", x),
1020            Self::Parameter(x) => write!(f, "{:?}", x),
1021            Self::Now(x) => write!(f, "{:?}", x),
1022        }
1023    }
1024}
1025
1026pub struct ExprDisplay<'a> {
1027    pub expr: &'a ExprImpl,
1028    pub input_schema: &'a Schema,
1029}
1030
1031impl std::fmt::Debug for ExprDisplay<'_> {
1032    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1033        let that = self.expr;
1034        match that {
1035            ExprImpl::InputRef(x) => write!(
1036                f,
1037                "{:?}",
1038                InputRefDisplay {
1039                    input_ref: x,
1040                    input_schema: self.input_schema
1041                }
1042            ),
1043            ExprImpl::Literal(x) => write!(f, "{:?}", x),
1044            ExprImpl::FunctionCall(x) => write!(
1045                f,
1046                "{:?}",
1047                FunctionCallDisplay {
1048                    function_call: x,
1049                    input_schema: self.input_schema
1050                }
1051            ),
1052            ExprImpl::FunctionCallWithLambda(x) => write!(
1053                f,
1054                "{:?}",
1055                FunctionCallDisplay {
1056                    function_call: &x.to_full_function_call(),
1057                    input_schema: self.input_schema
1058                }
1059            ),
1060            ExprImpl::AggCall(x) => write!(f, "{:?}", x),
1061            ExprImpl::Subquery(x) => write!(f, "{:?}", x),
1062            ExprImpl::CorrelatedInputRef(x) => write!(f, "{:?}", x),
1063            ExprImpl::TableFunction(x) => {
1064                // TODO: TableFunctionCallVerboseDisplay
1065                write!(f, "{:?}", x)
1066            }
1067            ExprImpl::WindowFunction(x) => {
1068                // TODO: WindowFunctionCallVerboseDisplay
1069                write!(f, "{:?}", x)
1070            }
1071            ExprImpl::UserDefinedFunction(x) => {
1072                write!(
1073                    f,
1074                    "{:?}",
1075                    UserDefinedFunctionDisplay {
1076                        func_call: x,
1077                        input_schema: self.input_schema
1078                    }
1079                )
1080            }
1081            ExprImpl::Parameter(x) => write!(f, "{:?}", x),
1082            ExprImpl::Now(x) => write!(f, "{:?}", x),
1083        }
1084    }
1085}
1086
1087impl std::fmt::Display for ExprDisplay<'_> {
1088    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1089        (self as &dyn std::fmt::Debug).fmt(f)
1090    }
1091}
1092
#[cfg(test)]
/// Asserts that the expression is an [`InputRef`] with the given index.
///
/// `$e` is evaluated exactly once. On mismatch, the panic message shows that same value
/// instead of re-evaluating the expression.
macro_rules! assert_eq_input_ref {
    ($e:expr, $index:expr) => {
        match $e {
            ExprImpl::InputRef(i) => assert_eq!(i.index(), $index),
            other => panic!("Expected input ref, found {:?}", other),
        }
    };
}
1103
1104#[cfg(test)]
1105pub(crate) use assert_eq_input_ref;
1106use risingwave_common::bail;
1107use risingwave_common::catalog::Schema;
1108use risingwave_common::row::OwnedRow;
1109
1110use crate::utils::Condition;
1111
#[cfg(test)]
mod tests {
    use super::*;

    /// `{:#?}` must produce the verbose form, which exposes nested fields such as the function
    /// call's `return_type`.
    #[test]
    fn test_expr_debug_alternate() {
        let column: ExprImpl = InputRef::new(1, DataType::Boolean).into();
        let negated: ExprImpl = FunctionCall::new(ExprType::Not, vec![column])
            .unwrap()
            .into();
        let verbose = format!("{:#?}", negated);
        assert!(verbose.contains("return_type: Boolean"))
    }
}