risingwave_frontend/expr/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use enum_as_inner::EnumAsInner;
16use fixedbitset::FixedBitSet;
17use futures::FutureExt;
18use paste::paste;
19use risingwave_common::array::ListValue;
20use risingwave_common::types::{DataType, Datum, JsonbVal, MapType, Scalar, ScalarImpl};
21use risingwave_expr::aggregate::PbAggKind;
22use risingwave_expr::expr::build_from_prost;
23use risingwave_pb::expr::expr_node::RexNode;
24use risingwave_pb::expr::{ExprNode, ProjectSetSelectItem};
25use user_defined_function::UserDefinedFunctionDisplay;
26
27use crate::error::{ErrorCode, Result as RwResult};
28
29mod agg_call;
30mod correlated_input_ref;
31mod function_call;
32mod function_call_with_lambda;
33mod input_ref;
34mod literal;
35mod now;
36mod parameter;
37mod pure;
38mod subquery;
39mod table_function;
40mod user_defined_function;
41mod window_function;
42
43mod order_by_expr;
44pub use order_by_expr::{OrderBy, OrderByExpr};
45
46mod expr_mutator;
47mod expr_rewriter;
48mod expr_visitor;
49pub mod function_impl;
50mod session_timezone;
51mod type_inference;
52mod utils;
53
54pub use agg_call::AggCall;
55pub use correlated_input_ref::{CorrelatedId, CorrelatedInputRef, Depth, InputRefDepthRewriter};
56pub use expr_mutator::ExprMutator;
57pub use expr_rewriter::{ExprRewriter, default_rewrite_expr};
58pub use expr_visitor::{ExprVisitor, default_visit_expr};
59pub use function_call::{FunctionCall, FunctionCallDisplay, is_row_function};
60pub use function_call_with_lambda::FunctionCallWithLambda;
61pub use input_ref::{InputRef, InputRefDisplay, input_ref_to_column_indices};
62pub use literal::Literal;
63pub use now::{InlineNowProcTime, Now, NowProcTimeFinder};
64pub use parameter::Parameter;
65pub use pure::*;
66pub use risingwave_pb::expr::expr_node::Type as ExprType;
67pub use session_timezone::{SessionTimezone, TimestamptzExprFinder};
68pub use subquery::{Subquery, SubqueryKind};
69pub use table_function::{TableFunction, TableFunctionType};
70pub use type_inference::*;
71pub use user_defined_function::UserDefinedFunction;
72pub use utils::*;
73pub use window_function::WindowFunction;
74
/// Expression nesting-depth threshold.
///
/// NOTE(review): presumably an expression nested deeper than this triggers
/// [`EXPR_TOO_DEEP_NOTICE`]; the check itself is not visible in this part of the file.
const EXPR_DEPTH_THRESHOLD: usize = 30;

/// Notice text suggesting that the user simplify or split an overly complicated expression.
const EXPR_TOO_DEEP_NOTICE: &str = concat!(
    "Some expression is too complicated. ",
    "Consider simplifying or splitting the query if you encounter any issues."
);
78
79/// the trait of bound expressions
80pub trait Expr: Into<ExprImpl> {
81    /// Get the return type of the expr
82    fn return_type(&self) -> DataType;
83
84    /// Try to serialize the expression, returning an error if it's impossible.
85    fn try_to_expr_proto(&self) -> Result<ExprNode, String>;
86
87    /// Serialize the expression. Panic if it's impossible.
88    fn to_expr_proto(&self) -> ExprNode {
89        self.try_to_expr_proto()
90            .expect("failed to serialize expression to protobuf")
91    }
92}
93
/// Defines [`ExprImpl`] with one boxed variant per listed expression type, together with:
/// - `ExprImpl::variant_name`, returning the active variant's name as a static string;
/// - a `From<T> for ExprImpl` impl for every listed type `T`;
/// - an [`Expr`] impl for `ExprImpl` that forwards each method to the wrapped expression.
///
/// Every entry in the invocation must be followed by a comma.
macro_rules! impl_expr_impl {
    ($($t:ident,)*) => {
        #[derive(Clone, Eq, PartialEq, Hash, EnumAsInner)]
        pub enum ExprImpl {
            $($t(Box<$t>),)*
        }

        impl Expr for ExprImpl {
            fn return_type(&self) -> DataType {
                match self {
                    $(Self::$t(inner) => inner.return_type(),)*
                }
            }

            fn try_to_expr_proto(&self) -> Result<ExprNode, String> {
                match self {
                    $(Self::$t(inner) => inner.try_to_expr_proto(),)*
                }
            }
        }

        impl ExprImpl {
            /// Returns the name of the active variant, e.g. `"Literal"`.
            pub fn variant_name(&self) -> &'static str {
                match self {
                    $(Self::$t(_) => stringify!($t),)*
                }
            }
        }

        $(
            impl From<$t> for ExprImpl {
                fn from(value: $t) -> Self {
                    Self::$t(Box::new(value))
                }
            }
        )*
    };
}
131
// Instantiate `ExprImpl`: each type below becomes a boxed variant of the enum, gets a
// `From<T> for ExprImpl` impl, and is dispatched to by `ExprImpl`'s `Expr` impl.
// Variants are declared in list order, so reordering this list changes the enum's
// discriminant order. Every entry needs a trailing comma (the macro's matcher requires it).
impl_expr_impl!(
    // BoundColumnRef, might be used in binder.
    // NOTE(review): the line above reads like a placeholder for a possible future
    // `BoundColumnRef` variant; confirm whether it is still relevant.
    CorrelatedInputRef,
    InputRef,
    Literal,
    FunctionCall,
    FunctionCallWithLambda,
    AggCall,
    Subquery,
    TableFunction,
    WindowFunction,
    UserDefinedFunction,
    Parameter,
    Now,
);
147
148impl ExprImpl {
149    /// A literal int value.
150    #[inline(always)]
151    pub fn literal_int(v: i32) -> Self {
152        Literal::new(Some(v.to_scalar_value()), DataType::Int32).into()
153    }
154
155    /// A literal bigint value
156    #[inline(always)]
157    pub fn literal_bigint(v: i64) -> Self {
158        Literal::new(Some(v.to_scalar_value()), DataType::Int64).into()
159    }
160
161    /// A literal float64 value.
162    #[inline(always)]
163    pub fn literal_f64(v: f64) -> Self {
164        Literal::new(Some(v.into()), DataType::Float64).into()
165    }
166
167    /// A literal boolean value.
168    #[inline(always)]
169    pub fn literal_bool(v: bool) -> Self {
170        Literal::new(Some(v.to_scalar_value()), DataType::Boolean).into()
171    }
172
173    /// A literal varchar value.
174    #[inline(always)]
175    pub fn literal_varchar(v: String) -> Self {
176        Literal::new(Some(v.into()), DataType::Varchar).into()
177    }
178
179    /// A literal null value.
180    #[inline(always)]
181    pub fn literal_null(element_type: DataType) -> Self {
182        Literal::new(None, element_type).into()
183    }
184
185    /// A literal jsonb value.
186    #[inline(always)]
187    pub fn literal_jsonb(v: JsonbVal) -> Self {
188        Literal::new(Some(v.into()), DataType::Jsonb).into()
189    }
190
191    /// A literal list value.
192    #[inline(always)]
193    pub fn literal_list(v: ListValue, element_type: DataType) -> Self {
194        Literal::new(
195            Some(v.to_scalar_value()),
196            DataType::List(Box::new(element_type)),
197        )
198        .into()
199    }
200
201    /// Takes the expression, leaving a literal null of the same type in its place.
202    pub fn take(&mut self) -> Self {
203        std::mem::replace(self, Self::literal_null(self.return_type()))
204    }
205
206    /// A `count(*)` aggregate function.
207    #[inline(always)]
208    pub fn count_star() -> Self {
209        AggCall::new(
210            PbAggKind::Count.into(),
211            vec![],
212            false,
213            OrderBy::any(),
214            Condition::true_cond(),
215            vec![],
216        )
217        .unwrap()
218        .into()
219    }
220
221    /// Create a new expression by merging the given expressions by `And`.
222    ///
223    /// If `exprs` is empty, return a literal `true`.
224    pub fn and(exprs: impl IntoIterator<Item = ExprImpl>) -> Self {
225        merge_expr_by_logical(exprs, ExprType::And, ExprImpl::literal_bool(true))
226    }
227
228    /// Create a new expression by merging the given expressions by `Or`.
229    ///
230    /// If `exprs` is empty, return a literal `false`.
231    pub fn or(exprs: impl IntoIterator<Item = ExprImpl>) -> Self {
232        merge_expr_by_logical(exprs, ExprType::Or, ExprImpl::literal_bool(false))
233    }
234
235    /// Collect all `InputRef`s' indexes in the expression.
236    ///
237    /// # Panics
238    /// Panics if `input_ref >= input_col_num`.
239    pub fn collect_input_refs(&self, input_col_num: usize) -> FixedBitSet {
240        collect_input_refs(input_col_num, [self])
241    }
242
243    /// Check if the expression has no side effects and output is deterministic
244    pub fn is_pure(&self) -> bool {
245        is_pure(self)
246    }
247
248    pub fn is_impure(&self) -> bool {
249        is_impure(self)
250    }
251
252    /// Count `Now`s in the expression.
253    pub fn count_nows(&self) -> usize {
254        let mut visitor = CountNow::default();
255        visitor.visit_expr(self);
256        visitor.count()
257    }
258
259    /// Check whether self is literal NULL.
260    pub fn is_null(&self) -> bool {
261        matches!(self, ExprImpl::Literal(literal) if literal.get_data().is_none())
262    }
263
264    /// Check whether self is a literal NULL or literal string.
265    pub fn is_untyped(&self) -> bool {
266        matches!(self, ExprImpl::Literal(literal) if literal.is_untyped())
267            || matches!(self, ExprImpl::Parameter(parameter) if !parameter.has_infer())
268    }
269
270    /// Shorthand to create cast expr to `target` type in implicit context.
271    pub fn cast_implicit(mut self, target: &DataType) -> Result<ExprImpl, CastError> {
272        FunctionCall::cast_mut(&mut self, target, CastContext::Implicit)?;
273        Ok(self)
274    }
275
276    /// Shorthand to create cast expr to `target` type in assign context.
277    pub fn cast_assign(mut self, target: &DataType) -> Result<ExprImpl, CastError> {
278        FunctionCall::cast_mut(&mut self, target, CastContext::Assign)?;
279        Ok(self)
280    }
281
282    /// Shorthand to create cast expr to `target` type in explicit context.
283    pub fn cast_explicit(mut self, target: &DataType) -> Result<ExprImpl, CastError> {
284        FunctionCall::cast_mut(&mut self, target, CastContext::Explicit)?;
285        Ok(self)
286    }
287
288    /// Shorthand to inplace cast expr to `target` type in implicit context.
289    pub fn cast_implicit_mut(&mut self, target: &DataType) -> Result<(), CastError> {
290        FunctionCall::cast_mut(self, target, CastContext::Implicit)
291    }
292
293    /// Shorthand to inplace cast expr to `target` type in explicit context.
294    pub fn cast_explicit_mut(&mut self, target: &DataType) -> Result<(), CastError> {
295        FunctionCall::cast_mut(self, target, CastContext::Explicit)
296    }
297
298    /// Casting to Regclass type means getting the oid of expr.
299    /// See <https://www.postgresql.org/docs/current/datatype-oid.html>
300    pub fn cast_to_regclass(self) -> Result<ExprImpl, CastError> {
301        match self.return_type() {
302            DataType::Varchar => Ok(ExprImpl::FunctionCall(Box::new(
303                FunctionCall::new_unchecked(ExprType::CastRegclass, vec![self], DataType::Int32),
304            ))),
305            DataType::Int32 => Ok(self),
306            dt if dt.is_int() => Ok(self.cast_explicit(&DataType::Int32)?),
307            _ => bail_cast_error!("unsupported input type"),
308        }
309    }
310
311    /// Shorthand to inplace cast expr to `regclass` type.
312    pub fn cast_to_regclass_mut(&mut self) -> Result<(), CastError> {
313        let owned = std::mem::replace(self, ExprImpl::literal_bool(false));
314        *self = owned.cast_to_regclass()?;
315        Ok(())
316    }
317
318    /// Ensure the return type of this expression is an array of some type.
319    pub fn ensure_array_type(&self) -> Result<(), ErrorCode> {
320        if self.is_untyped() {
321            return Err(ErrorCode::BindError(
322                "could not determine polymorphic type because input has type unknown".into(),
323            ));
324        }
325        match self.return_type() {
326            DataType::List(_) => Ok(()),
327            t => Err(ErrorCode::BindError(format!("expects array but got {t}"))),
328        }
329    }
330
331    /// Ensure the return type of this expression is a map of some type.
332    pub fn try_into_map_type(&self) -> Result<MapType, ErrorCode> {
333        if self.is_untyped() {
334            return Err(ErrorCode::BindError(
335                "could not determine polymorphic type because input has type unknown".into(),
336            ));
337        }
338        match self.return_type() {
339            DataType::Map(m) => Ok(m),
340            t => Err(ErrorCode::BindError(format!("expects map but got {t}"))),
341        }
342    }
343
344    /// Shorthand to enforce implicit cast to boolean
345    pub fn enforce_bool_clause(self, clause: &str) -> RwResult<ExprImpl> {
346        if self.is_untyped() {
347            let inner = self.cast_implicit(&DataType::Boolean)?;
348            return Ok(inner);
349        }
350        let return_type = self.return_type();
351        if return_type != DataType::Boolean {
352            bail!(
353                "argument of {} must be boolean, not type {:?}",
354                clause,
355                return_type
356            )
357        }
358        Ok(self)
359    }
360
361    /// Create "cast" expr to string (`varchar`) type. This is different from a real cast, as
362    /// boolean is converted to a single char rather than full word.
363    ///
364    /// Choose between `cast_output` and `cast_{assign,explicit}(Varchar)` based on `PostgreSQL`'s
365    /// behavior on bools. For example, `concat(':', true)` is `:t` but `':' || true` is `:true`.
366    /// All other types have the same behavior when formatting to output and casting to string.
367    ///
368    /// References in `PostgreSQL`:
369    /// * [cast](https://github.com/postgres/postgres/blob/a3ff08e0b08dbfeb777ccfa8f13ebaa95d064c04/src/include/catalog/pg_cast.dat#L437-L444)
370    /// * [impl](https://github.com/postgres/postgres/blob/27b77ecf9f4d5be211900eda54d8155ada50d696/src/backend/utils/adt/bool.c#L204-L209)
371    pub fn cast_output(self) -> RwResult<ExprImpl> {
372        if self.return_type() == DataType::Boolean {
373            return Ok(FunctionCall::new(ExprType::BoolOut, vec![self])?.into());
374        }
375        // Use normal cast for other types. Both `assign` and `explicit` can pass the castability
376        // check and there is no difference.
377        self.cast_assign(&DataType::Varchar)
378            .map_err(|err| err.into())
379    }
380
381    /// Evaluate the expression on the given input.
382    ///
383    /// TODO: This is a naive implementation. We should avoid proto ser/de.
384    /// Tracking issue: <https://github.com/risingwavelabs/risingwave/issues/3479>
385    pub async fn eval_row(&self, input: &OwnedRow) -> RwResult<Datum> {
386        let backend_expr = build_from_prost(&self.to_expr_proto())?;
387        Ok(backend_expr.eval_row(input).await?)
388    }
389
390    /// Try to evaluate an expression if it's a constant expression by `ExprImpl::is_const`.
391    ///
392    /// Returns...
393    /// - `None` if it's not a constant expression,
394    /// - `Some(Ok(_))` if constant evaluation succeeds,
395    /// - `Some(Err(_))` if there's an error while evaluating a constant expression.
396    pub fn try_fold_const(&self) -> Option<RwResult<Datum>> {
397        if self.is_const() {
398            self.eval_row(&OwnedRow::empty())
399                .now_or_never()
400                .expect("constant expression should not be async")
401                .into()
402        } else {
403            None
404        }
405    }
406
407    /// Similar to `ExprImpl::try_fold_const`, but panics if the expression is not constant.
408    pub fn fold_const(&self) -> RwResult<Datum> {
409        self.try_fold_const().expect("expression is not constant")
410    }
411}
412
/// Generates, for each listed variant type `T`, a method `ExprImpl::has_<t>(&self) -> bool`
/// (the type name converted to snake case), e.g. `has_subquery`, reporting whether a node of
/// that type occurs anywhere in the expression.
///
/// It will not traverse inside subqueries.
macro_rules! impl_has_variant {
    ( $($variant:ty),* ) => {
        paste! {
            impl ExprImpl {
                $(
                    pub fn [<has_ $variant:snake>](&self) -> bool {
                        // Visitor that flips `found` as soon as a node of the target type is seen.
                        struct Finder {
                            found: bool,
                        }

                        impl ExprVisitor for Finder {
                            fn [<visit_ $variant:snake>](&mut self, _: &$variant) {
                                self.found = true;
                            }
                        }

                        let mut finder = Finder { found: false };
                        finder.visit_expr(self);
                        finder.found
                    }
                )*
            }
        }
    };
}
440
441impl_has_variant! {InputRef, Literal, FunctionCall, FunctionCallWithLambda, AggCall, Subquery, TableFunction, WindowFunction, UserDefinedFunction, Now}
442
443#[derive(Debug, Clone, PartialEq, Eq, Hash)]
444pub struct InequalityInputPair {
445    /// Input index of greater side of inequality.
446    pub(crate) key_required_larger: usize,
447    /// Input index of less side of inequality.
448    pub(crate) key_required_smaller: usize,
449    /// greater >= less + `delta_expression`
450    pub(crate) delta_expression: Option<(ExprType, ExprImpl)>,
451}
452
453impl InequalityInputPair {
454    fn new(
455        key_required_larger: usize,
456        key_required_smaller: usize,
457        delta_expression: Option<(ExprType, ExprImpl)>,
458    ) -> Self {
459        Self {
460            key_required_larger,
461            key_required_smaller,
462            delta_expression,
463        }
464    }
465}
466
467impl ExprImpl {
468    /// This function is not meant to be called. In most cases you would want
469    /// [`ExprImpl::has_correlated_input_ref_by_depth`].
470    ///
471    /// When an expr contains a [`CorrelatedInputRef`] with lower depth, the whole expr is still
472    /// considered to be uncorrelated, and can be checked with [`ExprImpl::has_subquery`] as well.
473    /// See examples on [`crate::binder::BoundQuery::is_correlated_by_depth`] for details.
474    ///
475    /// This is a placeholder to trigger a compiler error when a trivial implementation checking for
476    /// enum variant is generated by accident. It cannot be called either because you cannot pass
477    /// `Infallible` to it.
478    pub fn has_correlated_input_ref(&self, _: std::convert::Infallible) -> bool {
479        unreachable!()
480    }
481
482    /// Used to check whether the expression has [`CorrelatedInputRef`].
483    ///
484    /// This is the core logic that supports [`crate::binder::BoundQuery::is_correlated_by_depth`]. Check the
485    /// doc of it for examples of `depth` being equal, less or greater.
486    // We need to traverse inside subqueries.
487    pub fn has_correlated_input_ref_by_depth(&self, depth: Depth) -> bool {
488        struct Has {
489            depth: usize,
490            has: bool,
491        }
492
493        impl ExprVisitor for Has {
494            fn visit_correlated_input_ref(&mut self, correlated_input_ref: &CorrelatedInputRef) {
495                if correlated_input_ref.depth() == self.depth {
496                    self.has = true;
497                }
498            }
499
500            fn visit_subquery(&mut self, subquery: &Subquery) {
501                self.has |= subquery.is_correlated_by_depth(self.depth);
502            }
503        }
504
505        let mut visitor = Has { depth, has: false };
506        visitor.visit_expr(self);
507        visitor.has
508    }
509
510    pub fn has_correlated_input_ref_by_correlated_id(&self, correlated_id: CorrelatedId) -> bool {
511        struct Has {
512            correlated_id: CorrelatedId,
513            has: bool,
514        }
515
516        impl ExprVisitor for Has {
517            fn visit_correlated_input_ref(&mut self, correlated_input_ref: &CorrelatedInputRef) {
518                if correlated_input_ref.correlated_id() == self.correlated_id {
519                    self.has = true;
520                }
521            }
522
523            fn visit_subquery(&mut self, subquery: &Subquery) {
524                self.has |= subquery.is_correlated_by_correlated_id(self.correlated_id);
525            }
526        }
527
528        let mut visitor = Has {
529            correlated_id,
530            has: false,
531        };
532        visitor.visit_expr(self);
533        visitor.has
534    }
535
536    /// Collect `CorrelatedInputRef`s in `ExprImpl` by relative `depth`, return their indices, and
537    /// assign absolute `correlated_id` for them.
538    pub fn collect_correlated_indices_by_depth_and_assign_id(
539        &mut self,
540        depth: Depth,
541        correlated_id: CorrelatedId,
542    ) -> Vec<usize> {
543        struct Collector {
544            depth: Depth,
545            correlated_indices: Vec<usize>,
546            correlated_id: CorrelatedId,
547        }
548
549        impl ExprMutator for Collector {
550            fn visit_correlated_input_ref(
551                &mut self,
552                correlated_input_ref: &mut CorrelatedInputRef,
553            ) {
554                if correlated_input_ref.depth() == self.depth {
555                    self.correlated_indices.push(correlated_input_ref.index());
556                    correlated_input_ref.set_correlated_id(self.correlated_id);
557                }
558            }
559
560            fn visit_subquery(&mut self, subquery: &mut Subquery) {
561                self.correlated_indices.extend(
562                    subquery.collect_correlated_indices_by_depth_and_assign_id(
563                        self.depth,
564                        self.correlated_id,
565                    ),
566                );
567            }
568        }
569
570        let mut collector = Collector {
571            depth,
572            correlated_indices: vec![],
573            correlated_id,
574        };
575        collector.visit_expr(self);
576        collector.correlated_indices.sort();
577        collector.correlated_indices.dedup();
578        collector.correlated_indices
579    }
580
581    /// Checks whether this is a constant expr that can be evaluated over a dummy chunk.
582    ///
583    /// The expression tree should only consist of literals and **pure** function calls.
584    pub fn is_const(&self) -> bool {
585        let only_literal_and_func = {
586            struct HasOthers {
587                has_others: bool,
588            }
589
590            impl ExprVisitor for HasOthers {
591                fn visit_expr(&mut self, expr: &ExprImpl) {
592                    match expr {
593                        ExprImpl::CorrelatedInputRef(_)
594                        | ExprImpl::InputRef(_)
595                        | ExprImpl::AggCall(_)
596                        | ExprImpl::Subquery(_)
597                        | ExprImpl::TableFunction(_)
598                        | ExprImpl::WindowFunction(_)
599                        | ExprImpl::UserDefinedFunction(_)
600                        | ExprImpl::Parameter(_)
601                        | ExprImpl::Now(_) => self.has_others = true,
602                        ExprImpl::Literal(_inner) => {}
603                        ExprImpl::FunctionCall(inner) => {
604                            if !self.is_short_circuit(inner) {
605                                // only if the current `func_call` is *not* a short-circuit
606                                // expression, e.g., true or (...) | false and (...),
607                                // shall we proceed to visit it.
608                                self.visit_function_call(inner)
609                            }
610                        }
611                        ExprImpl::FunctionCallWithLambda(inner) => {
612                            self.visit_function_call_with_lambda(inner)
613                        }
614                    }
615                }
616            }
617
618            impl HasOthers {
619                fn is_short_circuit(&self, func_call: &FunctionCall) -> bool {
620                    /// evaluate the first parameter of `Or` or `And` function call
621                    fn eval_first(e: &ExprImpl, expect: bool) -> bool {
622                        if let ExprImpl::Literal(l) = e {
623                            *l.get_data() == Some(ScalarImpl::Bool(expect))
624                        } else {
625                            false
626                        }
627                    }
628
629                    match func_call.func_type {
630                        ExprType::Or => eval_first(&func_call.inputs()[0], true),
631                        ExprType::And => eval_first(&func_call.inputs()[0], false),
632                        _ => false,
633                    }
634                }
635            }
636
637            let mut visitor = HasOthers { has_others: false };
638            visitor.visit_expr(self);
639            !visitor.has_others
640        };
641
642        let is_pure = self.is_pure();
643
644        only_literal_and_func && is_pure
645    }
646
647    /// Returns the `InputRefs` of an Equality predicate if it matches
648    /// ordered by the canonical ordering (lower, higher), else returns None
649    pub fn as_eq_cond(&self) -> Option<(InputRef, InputRef)> {
650        if let ExprImpl::FunctionCall(function_call) = self
651            && function_call.func_type() == ExprType::Equal
652            && let (_, ExprImpl::InputRef(x), ExprImpl::InputRef(y)) =
653                function_call.clone().decompose_as_binary()
654        {
655            if x.index() < y.index() {
656                Some((*x, *y))
657            } else {
658                Some((*y, *x))
659            }
660        } else {
661            None
662        }
663    }
664
665    pub fn as_is_not_distinct_from_cond(&self) -> Option<(InputRef, InputRef)> {
666        if let ExprImpl::FunctionCall(function_call) = self
667            && function_call.func_type() == ExprType::IsNotDistinctFrom
668            && let (_, ExprImpl::InputRef(x), ExprImpl::InputRef(y)) =
669                function_call.clone().decompose_as_binary()
670        {
671            if x.index() < y.index() {
672                Some((*x, *y))
673            } else {
674                Some((*y, *x))
675            }
676        } else {
677            None
678        }
679    }
680
681    pub fn reverse_comparison(comparison: ExprType) -> ExprType {
682        match comparison {
683            ExprType::LessThan => ExprType::GreaterThan,
684            ExprType::LessThanOrEqual => ExprType::GreaterThanOrEqual,
685            ExprType::GreaterThan => ExprType::LessThan,
686            ExprType::GreaterThanOrEqual => ExprType::LessThanOrEqual,
687            ExprType::Equal | ExprType::IsNotDistinctFrom => comparison,
688            _ => unreachable!(),
689        }
690    }
691
692    pub fn as_comparison_cond(&self) -> Option<(InputRef, ExprType, InputRef)> {
693        if let ExprImpl::FunctionCall(function_call) = self {
694            match function_call.func_type() {
695                ty @ (ExprType::LessThan
696                | ExprType::LessThanOrEqual
697                | ExprType::GreaterThan
698                | ExprType::GreaterThanOrEqual) => {
699                    let (_, op1, op2) = function_call.clone().decompose_as_binary();
700                    if let (ExprImpl::InputRef(x), ExprImpl::InputRef(y)) = (op1, op2) {
701                        if x.index < y.index {
702                            Some((*x, ty, *y))
703                        } else {
704                            Some((*y, Self::reverse_comparison(ty), *x))
705                        }
706                    } else {
707                        None
708                    }
709                }
710                _ => None,
711            }
712        } else {
713            None
714        }
715    }
716
717    /// Accepts expressions of the form `input_expr cmp now_expr` or `now_expr cmp input_expr`,
718    /// where `input_expr` contains an `InputRef` and contains no `now()`, and `now_expr`
719    /// contains a `now()` but no `InputRef`.
720    ///
721    /// Canonicalizes to the first ordering and returns `(input_expr, cmp, now_expr)`
722    pub fn as_now_comparison_cond(&self) -> Option<(ExprImpl, ExprType, ExprImpl)> {
723        if let ExprImpl::FunctionCall(function_call) = self {
724            match function_call.func_type() {
725                ty @ (ExprType::Equal
726                | ExprType::LessThan
727                | ExprType::LessThanOrEqual
728                | ExprType::GreaterThan
729                | ExprType::GreaterThanOrEqual) => {
730                    let (_, op1, op2) = function_call.clone().decompose_as_binary();
731                    if !op1.has_now()
732                        && op1.has_input_ref()
733                        && op2.has_now()
734                        && !op2.has_input_ref()
735                    {
736                        Some((op1, ty, op2))
737                    } else if op1.has_now()
738                        && !op1.has_input_ref()
739                        && !op2.has_now()
740                        && op2.has_input_ref()
741                    {
742                        Some((op2, Self::reverse_comparison(ty), op1))
743                    } else {
744                        None
745                    }
746                }
747                _ => None,
748            }
749        } else {
750            None
751        }
752    }
753
754    /// Accepts expressions of the form `InputRef cmp InputRef [+- const_expr]` or
755    /// `InputRef [+- const_expr] cmp InputRef`.
756    pub(crate) fn as_input_comparison_cond(&self) -> Option<InequalityInputPair> {
757        if let ExprImpl::FunctionCall(function_call) = self {
758            match function_call.func_type() {
759                ty @ (ExprType::LessThan
760                | ExprType::LessThanOrEqual
761                | ExprType::GreaterThan
762                | ExprType::GreaterThanOrEqual) => {
763                    let (_, mut op1, mut op2) = function_call.clone().decompose_as_binary();
764                    if matches!(ty, ExprType::LessThan | ExprType::LessThanOrEqual) {
765                        std::mem::swap(&mut op1, &mut op2);
766                    }
767                    if let (Some((lft_input, lft_offset)), Some((rht_input, rht_offset))) =
768                        (op1.as_input_offset(), op2.as_input_offset())
769                    {
770                        match (lft_offset, rht_offset) {
771                            (Some(_), Some(_)) => None,
772                            (None, rht_offset @ Some(_)) => {
773                                Some(InequalityInputPair::new(lft_input, rht_input, rht_offset))
774                            }
775                            (Some((operator, operand)), None) => Some(InequalityInputPair::new(
776                                lft_input,
777                                rht_input,
778                                Some((
779                                    if operator == ExprType::Add {
780                                        ExprType::Subtract
781                                    } else {
782                                        ExprType::Add
783                                    },
784                                    operand,
785                                )),
786                            )),
787                            (None, None) => {
788                                Some(InequalityInputPair::new(lft_input, rht_input, None))
789                            }
790                        }
791                    } else {
792                        None
793                    }
794                }
795                _ => None,
796            }
797        } else {
798            None
799        }
800    }
801
802    /// Returns the `InputRef` and offset of a predicate if it matches
803    /// the form `InputRef [+- const_expr]`, else returns None.
804    fn as_input_offset(&self) -> Option<(usize, Option<(ExprType, ExprImpl)>)> {
805        match self {
806            ExprImpl::InputRef(input_ref) => Some((input_ref.index(), None)),
807            ExprImpl::FunctionCall(function_call) => {
808                let expr_type = function_call.func_type();
809                match expr_type {
810                    ExprType::Add | ExprType::Subtract => {
811                        let (_, lhs, rhs) = function_call.clone().decompose_as_binary();
812                        if let ExprImpl::InputRef(input_ref) = &lhs
813                            && rhs.is_const()
814                        {
815                            // Currently we will return `None` for non-literal because the result of the expression might be '1 day'. However, there will definitely exist false positives such as '1 second + 1 second'.
816                            // We will treat the expression as an input offset when rhs is `null`.
817                            if rhs.return_type() == DataType::Interval
818                                && rhs.as_literal().is_none_or(|literal| {
819                                    literal.get_data().as_ref().is_some_and(|scalar| {
820                                        let interval = scalar.as_interval();
821                                        interval.months() != 0 || interval.days() != 0
822                                    })
823                                })
824                            {
825                                None
826                            } else {
827                                Some((input_ref.index(), Some((expr_type, rhs))))
828                            }
829                        } else {
830                            None
831                        }
832                    }
833                    _ => None,
834                }
835            }
836            _ => None,
837        }
838    }
839
840    pub fn as_eq_const(&self) -> Option<(InputRef, ExprImpl)> {
841        if let ExprImpl::FunctionCall(function_call) = self
842            && function_call.func_type() == ExprType::Equal
843        {
844            match function_call.clone().decompose_as_binary() {
845                (_, ExprImpl::InputRef(x), y) if y.is_const() => Some((*x, y)),
846                (_, x, ExprImpl::InputRef(y)) if x.is_const() => Some((*y, x)),
847                _ => None,
848            }
849        } else {
850            None
851        }
852    }
853
854    pub fn as_eq_correlated_input_ref(&self) -> Option<(InputRef, CorrelatedInputRef)> {
855        if let ExprImpl::FunctionCall(function_call) = self
856            && function_call.func_type() == ExprType::Equal
857        {
858            match function_call.clone().decompose_as_binary() {
859                (_, ExprImpl::InputRef(x), ExprImpl::CorrelatedInputRef(y)) => Some((*x, *y)),
860                (_, ExprImpl::CorrelatedInputRef(x), ExprImpl::InputRef(y)) => Some((*y, *x)),
861                _ => None,
862            }
863        } else {
864            None
865        }
866    }
867
868    pub fn as_is_null(&self) -> Option<InputRef> {
869        if let ExprImpl::FunctionCall(function_call) = self
870            && function_call.func_type() == ExprType::IsNull
871        {
872            match function_call.clone().decompose_as_unary() {
873                (_, ExprImpl::InputRef(x)) => Some(*x),
874                _ => None,
875            }
876        } else {
877            None
878        }
879    }
880
881    pub fn as_comparison_const(&self) -> Option<(InputRef, ExprType, ExprImpl)> {
882        fn reverse_comparison(comparison: ExprType) -> ExprType {
883            match comparison {
884                ExprType::LessThan => ExprType::GreaterThan,
885                ExprType::LessThanOrEqual => ExprType::GreaterThanOrEqual,
886                ExprType::GreaterThan => ExprType::LessThan,
887                ExprType::GreaterThanOrEqual => ExprType::LessThanOrEqual,
888                _ => unreachable!(),
889            }
890        }
891
892        if let ExprImpl::FunctionCall(function_call) = self {
893            match function_call.func_type() {
894                ty @ (ExprType::LessThan
895                | ExprType::LessThanOrEqual
896                | ExprType::GreaterThan
897                | ExprType::GreaterThanOrEqual) => {
898                    let (_, op1, op2) = function_call.clone().decompose_as_binary();
899                    match (op1, op2) {
900                        (ExprImpl::InputRef(x), y) if y.is_const() => Some((*x, ty, y)),
901                        (x, ExprImpl::InputRef(y)) if x.is_const() => {
902                            Some((*y, reverse_comparison(ty), x))
903                        }
904                        _ => None,
905                    }
906                }
907                _ => None,
908            }
909        } else {
910            None
911        }
912    }
913
914    pub fn as_in_const_list(&self) -> Option<(InputRef, Vec<ExprImpl>)> {
915        if let ExprImpl::FunctionCall(function_call) = self
916            && function_call.func_type() == ExprType::In
917        {
918            let mut inputs = function_call.inputs().iter().cloned();
919            let input_ref = match inputs.next().unwrap() {
920                ExprImpl::InputRef(i) => *i,
921                _ => return None,
922            };
923            let list: Vec<_> = inputs
924                .inspect(|expr| {
925                    // Non constant IN will be bound to OR
926                    assert!(expr.is_const());
927                })
928                .collect();
929
930            Some((input_ref, list))
931        } else {
932            None
933        }
934    }
935
936    pub fn as_or_disjunctions(&self) -> Option<Vec<ExprImpl>> {
937        if let ExprImpl::FunctionCall(function_call) = self
938            && function_call.func_type() == ExprType::Or
939        {
940            Some(to_disjunctions(self.clone()))
941        } else {
942            None
943        }
944    }
945
946    pub fn to_project_set_select_item_proto(&self) -> ProjectSetSelectItem {
947        use risingwave_pb::expr::project_set_select_item::SelectItem::*;
948
949        ProjectSetSelectItem {
950            select_item: Some(match self {
951                ExprImpl::TableFunction(tf) => TableFunction(tf.to_protobuf()),
952                expr => Expr(expr.to_expr_proto()),
953            }),
954        }
955    }
956
957    pub fn from_expr_proto(proto: &ExprNode) -> RwResult<Self> {
958        let rex_node = proto.get_rex_node()?;
959        let ret_type = proto.get_return_type()?.into();
960
961        Ok(match rex_node {
962            RexNode::InputRef(column_index) => Self::InputRef(Box::new(InputRef::from_expr_proto(
963                *column_index as _,
964                ret_type,
965            )?)),
966            RexNode::Constant(_) => Self::Literal(Box::new(Literal::from_expr_proto(proto)?)),
967            RexNode::Udf(udf) => Self::UserDefinedFunction(Box::new(
968                UserDefinedFunction::from_expr_proto(udf, ret_type)?,
969            )),
970            RexNode::FuncCall(function_call) => {
971                Self::FunctionCall(Box::new(FunctionCall::from_expr_proto(
972                    function_call,
973                    proto.get_function_type()?, // only interpret if it's a function call
974                    ret_type,
975                )?))
976            }
977            RexNode::Now(_) => Self::Now(Box::new(Now {})),
978        })
979    }
980}
981
982impl From<Condition> for ExprImpl {
983    fn from(c: Condition) -> Self {
984        ExprImpl::and(c.conjunctions)
985    }
986}
987
988/// A custom Debug implementation that is more concise and suitable to use with
989/// [`std::fmt::Formatter::debug_list`] in plan nodes. If the verbose output is preferred, it is
990/// still available via `{:#?}`.
991impl std::fmt::Debug for ExprImpl {
992    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
993        if f.alternate() {
994            return match self {
995                Self::InputRef(arg0) => f.debug_tuple("InputRef").field(arg0).finish(),
996                Self::Literal(arg0) => f.debug_tuple("Literal").field(arg0).finish(),
997                Self::FunctionCall(arg0) => f.debug_tuple("FunctionCall").field(arg0).finish(),
998                Self::FunctionCallWithLambda(arg0) => {
999                    f.debug_tuple("FunctionCallWithLambda").field(arg0).finish()
1000                }
1001                Self::AggCall(arg0) => f.debug_tuple("AggCall").field(arg0).finish(),
1002                Self::Subquery(arg0) => f.debug_tuple("Subquery").field(arg0).finish(),
1003                Self::CorrelatedInputRef(arg0) => {
1004                    f.debug_tuple("CorrelatedInputRef").field(arg0).finish()
1005                }
1006                Self::TableFunction(arg0) => f.debug_tuple("TableFunction").field(arg0).finish(),
1007                Self::WindowFunction(arg0) => f.debug_tuple("WindowFunction").field(arg0).finish(),
1008                Self::UserDefinedFunction(arg0) => {
1009                    f.debug_tuple("UserDefinedFunction").field(arg0).finish()
1010                }
1011                Self::Parameter(arg0) => f.debug_tuple("Parameter").field(arg0).finish(),
1012                Self::Now(_) => f.debug_tuple("Now").finish(),
1013            };
1014        }
1015        match self {
1016            Self::InputRef(x) => write!(f, "{:?}", x),
1017            Self::Literal(x) => write!(f, "{:?}", x),
1018            Self::FunctionCall(x) => write!(f, "{:?}", x),
1019            Self::FunctionCallWithLambda(x) => write!(f, "{:?}", x),
1020            Self::AggCall(x) => write!(f, "{:?}", x),
1021            Self::Subquery(x) => write!(f, "{:?}", x),
1022            Self::CorrelatedInputRef(x) => write!(f, "{:?}", x),
1023            Self::TableFunction(x) => write!(f, "{:?}", x),
1024            Self::WindowFunction(x) => write!(f, "{:?}", x),
1025            Self::UserDefinedFunction(x) => write!(f, "{:?}", x),
1026            Self::Parameter(x) => write!(f, "{:?}", x),
1027            Self::Now(x) => write!(f, "{:?}", x),
1028        }
1029    }
1030}
1031
/// Pairs an expression with the schema of its input for printing.
///
/// The `Debug`/`Display` impls for this type use `input_schema` when rendering input
/// references, function calls and user-defined function calls.
pub struct ExprDisplay<'a> {
    /// The expression to print.
    pub expr: &'a ExprImpl,
    /// Schema of the input that `expr` refers to; passed to the schema-aware display
    /// wrappers.
    pub input_schema: &'a Schema,
}
1036
1037impl std::fmt::Debug for ExprDisplay<'_> {
1038    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1039        let that = self.expr;
1040        match that {
1041            ExprImpl::InputRef(x) => write!(
1042                f,
1043                "{:?}",
1044                InputRefDisplay {
1045                    input_ref: x,
1046                    input_schema: self.input_schema
1047                }
1048            ),
1049            ExprImpl::Literal(x) => write!(f, "{:?}", x),
1050            ExprImpl::FunctionCall(x) => write!(
1051                f,
1052                "{:?}",
1053                FunctionCallDisplay {
1054                    function_call: x,
1055                    input_schema: self.input_schema
1056                }
1057            ),
1058            ExprImpl::FunctionCallWithLambda(x) => write!(
1059                f,
1060                "{:?}",
1061                FunctionCallDisplay {
1062                    function_call: &x.to_full_function_call(),
1063                    input_schema: self.input_schema
1064                }
1065            ),
1066            ExprImpl::AggCall(x) => write!(f, "{:?}", x),
1067            ExprImpl::Subquery(x) => write!(f, "{:?}", x),
1068            ExprImpl::CorrelatedInputRef(x) => write!(f, "{:?}", x),
1069            ExprImpl::TableFunction(x) => {
1070                // TODO: TableFunctionCallVerboseDisplay
1071                write!(f, "{:?}", x)
1072            }
1073            ExprImpl::WindowFunction(x) => {
1074                // TODO: WindowFunctionCallVerboseDisplay
1075                write!(f, "{:?}", x)
1076            }
1077            ExprImpl::UserDefinedFunction(x) => {
1078                write!(
1079                    f,
1080                    "{:?}",
1081                    UserDefinedFunctionDisplay {
1082                        func_call: x,
1083                        input_schema: self.input_schema
1084                    }
1085                )
1086            }
1087            ExprImpl::Parameter(x) => write!(f, "{:?}", x),
1088            ExprImpl::Now(x) => write!(f, "{:?}", x),
1089        }
1090    }
1091}
1092
1093impl std::fmt::Display for ExprDisplay<'_> {
1094    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1095        (self as &dyn std::fmt::Debug).fmt(f)
1096    }
1097}
1098
#[cfg(test)]
/// Asserts that the expression is an [`InputRef`] with the given index.
///
/// `$e` may be either an owned `ExprImpl` or a reference to one. It is evaluated exactly
/// once and only borrowed, so a local passed by value stays usable afterwards.
///
/// Panics with the expression's `Debug` output if it is not an `InputRef`, or via
/// `assert_eq!` if the index differs.
macro_rules! assert_eq_input_ref {
    ($e:expr, $index:expr) => {
        // Borrow instead of matching by value: this avoids moving (or failing to move) the
        // scrutinee, and lets the fallback arm reuse the same evaluation of `$e`.
        match &$e {
            ExprImpl::InputRef(i) => assert_eq!(i.index(), $index),
            other => panic!("Expected input ref, found {:?}", other),
        }
    };
}
1108}
1109
1110#[cfg(test)]
1111pub(crate) use assert_eq_input_ref;
1112use risingwave_common::bail;
1113use risingwave_common::catalog::Schema;
1114use risingwave_common::row::OwnedRow;
1115
1116use crate::utils::Condition;
1117
#[cfg(test)]
mod tests {
    use super::*;

    /// `{:#?}` must produce the verbose, variant-wrapped output, which includes the inner
    /// function call's `return_type` field.
    #[test]
    fn test_expr_debug_alternate() {
        let input: ExprImpl = InputRef::new(1, DataType::Boolean).into();
        let e: ExprImpl = FunctionCall::new(ExprType::Not, vec![input]).unwrap().into();
        let s = format!("{e:#?}");
        assert!(s.contains("return_type: Boolean"))
    }
}