risingwave_frontend/expr/
mod.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use enum_as_inner::EnumAsInner;
16use fixedbitset::FixedBitSet;
17use futures::FutureExt;
18use paste::paste;
19use risingwave_common::array::ListValue;
20use risingwave_common::types::{DataType, Datum, JsonbVal, MapType, Scalar, ScalarImpl};
21use risingwave_expr::aggregate::PbAggKind;
22use risingwave_expr::expr::build_from_prost;
23use risingwave_pb::expr::expr_node::RexNode;
24use risingwave_pb::expr::{ExprNode, ProjectSetSelectItem};
25use user_defined_function::UserDefinedFunctionDisplay;
26
27use crate::error::{ErrorCode, Result as RwResult};
28
29mod agg_call;
30mod correlated_input_ref;
31mod function_call;
32mod function_call_with_lambda;
33mod input_ref;
34mod literal;
35mod now;
36mod parameter;
37mod pure;
38mod subquery;
39mod table_function;
40mod user_defined_function;
41mod window_function;
42
43mod order_by_expr;
44pub use order_by_expr::{OrderBy, OrderByExpr};
45
46mod expr_mutator;
47mod expr_rewriter;
48mod expr_visitor;
49pub mod function_impl;
50mod session_timezone;
51mod type_inference;
52mod utils;
53
54pub use agg_call::AggCall;
55pub use correlated_input_ref::{CorrelatedId, CorrelatedInputRef, Depth, InputRefDepthRewriter};
56pub use expr_mutator::ExprMutator;
57pub use expr_rewriter::{ExprRewriter, default_rewrite_expr};
58pub use expr_visitor::{ExprVisitor, default_visit_expr};
59pub use function_call::{FunctionCall, FunctionCallDisplay, is_row_function};
60pub use function_call_with_lambda::FunctionCallWithLambda;
61pub use input_ref::{InputRef, InputRefDisplay, input_ref_to_column_indices};
62pub use literal::Literal;
63pub use now::{InlineNowProcTime, Now, NowProcTimeFinder};
64pub use parameter::Parameter;
65pub use pure::*;
66pub use risingwave_pb::expr::expr_node::Type as ExprType;
67pub use session_timezone::{SessionTimezone, TimestamptzExprFinder};
68pub use subquery::{Subquery, SubqueryKind};
69pub use table_function::{TableFunction, TableFunctionType};
70pub use type_inference::*;
71pub use user_defined_function::UserDefinedFunction;
72pub use utils::*;
73pub use window_function::WindowFunction;
74
/// Depth threshold for expressions.
///
/// NOTE(review): presumably the nesting depth beyond which [`EXPR_TOO_DEEP_NOTICE`] is reported;
/// the comparison itself is not in this part of the file — confirm at the use sites.
pub(crate) const EXPR_DEPTH_THRESHOLD: usize = 30;
/// Notice text advising the user to simplify or split a query whose expression is too complicated.
pub(crate) const EXPR_TOO_DEEP_NOTICE: &str = concat!(
    "Some expression is too complicated. ",
    "Consider simplifying or splitting the query if you encounter any issues.",
);
78
79/// the trait of bound expressions
80pub trait Expr: Into<ExprImpl> {
81    /// Get the return type of the expr
82    fn return_type(&self) -> DataType;
83
84    /// Try to serialize the expression, returning an error if it's impossible.
85    fn try_to_expr_proto(&self) -> Result<ExprNode, String>;
86
87    /// Serialize the expression. Panic if it's impossible.
88    fn to_expr_proto(&self) -> ExprNode {
89        self.try_to_expr_proto()
90            .expect("failed to serialize expression to protobuf")
91    }
92}
93
/// Generates the `ExprImpl` enum and its basic plumbing from a list of expression types.
///
/// For every listed type `T` the expansion contains:
/// * a variant `ExprImpl::T(Box<T>)`,
/// * an arm in `ExprImpl::variant_name` returning the type name as a string,
/// * an `impl From<T> for ExprImpl` that boxes the value,
/// * arms in `impl Expr for ExprImpl` that forward `return_type` and `try_to_expr_proto` to the
///   wrapped expression.
macro_rules! impl_expr_impl {
    ($($t:ident,)*) => {
        // One boxed variant per listed expression type.
        #[derive(Clone, Eq, PartialEq, Hash, EnumAsInner)]
        pub enum ExprImpl {
            $($t(Box<$t>),)*
        }

        impl ExprImpl {
            // Name of the active variant, i.e. the identifier passed to the macro.
            pub fn variant_name(&self) -> &'static str {
                match self {
                    $(ExprImpl::$t(_) => stringify!($t),)*
                }
            }
        }

        // Conversion from each concrete expression type into the matching variant.
        $(
        impl From<$t> for ExprImpl {
            fn from(o: $t) -> ExprImpl {
                ExprImpl::$t(Box::new(o))
            }
        })*

        // `Expr` is implemented by dispatching to whichever expression is wrapped.
        impl Expr for ExprImpl {
            fn return_type(&self) -> DataType {
                match self {
                    $(ExprImpl::$t(expr) => expr.return_type(),)*
                }
            }

            fn try_to_expr_proto(&self) -> Result<ExprNode, String> {
                match self {
                    $(ExprImpl::$t(expr) => expr.try_to_expr_proto(),)*
                }
            }
        }
    };
}

// The list of expression types that make up `ExprImpl`.
impl_expr_impl!(
    // BoundColumnRef, might be used in binder.
    CorrelatedInputRef,
    InputRef,
    Literal,
    FunctionCall,
    FunctionCallWithLambda,
    AggCall,
    Subquery,
    TableFunction,
    WindowFunction,
    UserDefinedFunction,
    Parameter,
    Now,
);
147
148impl ExprImpl {
149    /// A literal int value.
150    #[inline(always)]
151    pub fn literal_int(v: i32) -> Self {
152        Literal::new(Some(v.to_scalar_value()), DataType::Int32).into()
153    }
154
155    /// A literal bigint value
156    #[inline(always)]
157    pub fn literal_bigint(v: i64) -> Self {
158        Literal::new(Some(v.to_scalar_value()), DataType::Int64).into()
159    }
160
161    /// A literal float64 value.
162    #[inline(always)]
163    pub fn literal_f64(v: f64) -> Self {
164        Literal::new(Some(v.into()), DataType::Float64).into()
165    }
166
167    /// A literal boolean value.
168    #[inline(always)]
169    pub fn literal_bool(v: bool) -> Self {
170        Literal::new(Some(v.to_scalar_value()), DataType::Boolean).into()
171    }
172
173    /// A literal varchar value.
174    #[inline(always)]
175    pub fn literal_varchar(v: String) -> Self {
176        Literal::new(Some(v.into()), DataType::Varchar).into()
177    }
178
179    /// A literal null value.
180    #[inline(always)]
181    pub fn literal_null(element_type: DataType) -> Self {
182        Literal::new(None, element_type).into()
183    }
184
185    /// A literal jsonb value.
186    #[inline(always)]
187    pub fn literal_jsonb(v: JsonbVal) -> Self {
188        Literal::new(Some(v.into()), DataType::Jsonb).into()
189    }
190
191    /// A literal list value.
192    #[inline(always)]
193    pub fn literal_list(v: ListValue, element_type: DataType) -> Self {
194        Literal::new(Some(v.to_scalar_value()), DataType::list(element_type)).into()
195    }
196
197    /// Takes the expression, leaving a literal null of the same type in its place.
198    pub fn take(&mut self) -> Self {
199        std::mem::replace(self, Self::literal_null(self.return_type()))
200    }
201
202    /// A `count(*)` aggregate function.
203    #[inline(always)]
204    pub fn count_star() -> Self {
205        AggCall::new(
206            PbAggKind::Count.into(),
207            vec![],
208            false,
209            OrderBy::any(),
210            Condition::true_cond(),
211            vec![],
212        )
213        .unwrap()
214        .into()
215    }
216
217    /// Create a new expression by merging the given expressions by `And`.
218    ///
219    /// If `exprs` is empty, return a literal `true`.
220    pub fn and(exprs: impl IntoIterator<Item = ExprImpl>) -> Self {
221        merge_expr_by_logical(exprs, ExprType::And, ExprImpl::literal_bool(true))
222    }
223
224    /// Create a new expression by merging the given expressions by `Or`.
225    ///
226    /// If `exprs` is empty, return a literal `false`.
227    pub fn or(exprs: impl IntoIterator<Item = ExprImpl>) -> Self {
228        merge_expr_by_logical(exprs, ExprType::Or, ExprImpl::literal_bool(false))
229    }
230
    /// Collect all `InputRef`s' indexes in the expression.
    ///
    /// Method form of the free function `collect_input_refs`, called with this expression as the
    /// only input.
    ///
    /// # Panics
    /// Panics if `input_ref >= input_col_num`.
    pub fn collect_input_refs(&self, input_col_num: usize) -> FixedBitSet {
        collect_input_refs(input_col_num, [self])
    }

    /// Check if the expression has no side effects and output is deterministic
    ///
    /// Method form of the free function `is_pure`.
    pub fn is_pure(&self) -> bool {
        is_pure(self)
    }

    /// Method form of the free function `is_impure`.
    ///
    /// NOTE(review): presumably the negation of [`ExprImpl::is_pure`]; the helper is defined
    /// outside this part of the file — confirm.
    pub fn is_impure(&self) -> bool {
        is_impure(self)
    }
247
248    /// Count `Now`s in the expression.
249    pub fn count_nows(&self) -> usize {
250        let mut visitor = CountNow::default();
251        visitor.visit_expr(self);
252        visitor.count()
253    }
254
255    /// Check whether self is literal NULL.
256    pub fn is_null(&self) -> bool {
257        matches!(self, ExprImpl::Literal(literal) if literal.get_data().is_none())
258    }
259
260    /// Check whether self is a literal NULL or literal string.
261    pub fn is_untyped(&self) -> bool {
262        matches!(self, ExprImpl::Literal(literal) if literal.is_untyped())
263            || matches!(self, ExprImpl::Parameter(parameter) if !parameter.has_infer())
264    }
265
266    /// Shorthand to create cast expr to `target` type in implicit context.
267    pub fn cast_implicit(mut self, target: &DataType) -> Result<ExprImpl, CastError> {
268        FunctionCall::cast_mut(&mut self, target, CastContext::Implicit)?;
269        Ok(self)
270    }
271
272    /// Shorthand to create cast expr to `target` type in assign context.
273    pub fn cast_assign(mut self, target: &DataType) -> Result<ExprImpl, CastError> {
274        FunctionCall::cast_mut(&mut self, target, CastContext::Assign)?;
275        Ok(self)
276    }
277
278    /// Shorthand to create cast expr to `target` type in explicit context.
279    pub fn cast_explicit(mut self, target: &DataType) -> Result<ExprImpl, CastError> {
280        FunctionCall::cast_mut(&mut self, target, CastContext::Explicit)?;
281        Ok(self)
282    }
283
284    /// Shorthand to inplace cast expr to `target` type in implicit context.
285    pub fn cast_implicit_mut(&mut self, target: &DataType) -> Result<(), CastError> {
286        FunctionCall::cast_mut(self, target, CastContext::Implicit)
287    }
288
289    /// Shorthand to inplace cast expr to `target` type in explicit context.
290    pub fn cast_explicit_mut(&mut self, target: &DataType) -> Result<(), CastError> {
291        FunctionCall::cast_mut(self, target, CastContext::Explicit)
292    }
293
294    /// Casting to Regclass type means getting the oid of expr.
295    /// See <https://www.postgresql.org/docs/current/datatype-oid.html>
296    pub fn cast_to_regclass(self) -> Result<ExprImpl, CastError> {
297        match self.return_type() {
298            DataType::Varchar => Ok(ExprImpl::FunctionCall(Box::new(
299                FunctionCall::new_unchecked(ExprType::CastRegclass, vec![self], DataType::Int32),
300            ))),
301            DataType::Int32 => Ok(self),
302            dt if dt.is_int() => Ok(self.cast_explicit(&DataType::Int32)?),
303            _ => bail_cast_error!("unsupported input type"),
304        }
305    }
306
307    /// Shorthand to inplace cast expr to `regclass` type.
308    pub fn cast_to_regclass_mut(&mut self) -> Result<(), CastError> {
309        let owned = std::mem::replace(self, ExprImpl::literal_bool(false));
310        *self = owned.cast_to_regclass()?;
311        Ok(())
312    }
313
314    /// Ensure the return type of this expression is an array of some type.
315    pub fn ensure_array_type(&self) -> Result<(), ErrorCode> {
316        if self.is_untyped() {
317            return Err(ErrorCode::BindError(
318                "could not determine polymorphic type because input has type unknown".into(),
319            ));
320        }
321        match self.return_type() {
322            DataType::List(_) => Ok(()),
323            t => Err(ErrorCode::BindError(format!("expects array but got {t}"))),
324        }
325    }
326
327    /// Ensure the return type of this expression is a map of some type.
328    pub fn try_into_map_type(&self) -> Result<MapType, ErrorCode> {
329        if self.is_untyped() {
330            return Err(ErrorCode::BindError(
331                "could not determine polymorphic type because input has type unknown".into(),
332            ));
333        }
334        match self.return_type() {
335            DataType::Map(m) => Ok(m),
336            t => Err(ErrorCode::BindError(format!("expects map but got {t}"))),
337        }
338    }
339
340    /// Shorthand to enforce implicit cast to boolean
341    pub fn enforce_bool_clause(self, clause: &str) -> RwResult<ExprImpl> {
342        if self.is_untyped() {
343            let inner = self.cast_implicit(&DataType::Boolean)?;
344            return Ok(inner);
345        }
346        let return_type = self.return_type();
347        if return_type != DataType::Boolean {
348            bail!(
349                "argument of {} must be boolean, not type {:?}",
350                clause,
351                return_type
352            )
353        }
354        Ok(self)
355    }
356
357    /// Create "cast" expr to string (`varchar`) type. This is different from a real cast, as
358    /// boolean is converted to a single char rather than full word.
359    ///
360    /// Choose between `cast_output` and `cast_{assign,explicit}(Varchar)` based on `PostgreSQL`'s
361    /// behavior on bools. For example, `concat(':', true)` is `:t` but `':' || true` is `:true`.
362    /// All other types have the same behavior when formatting to output and casting to string.
363    ///
364    /// References in `PostgreSQL`:
365    /// * [cast](https://github.com/postgres/postgres/blob/a3ff08e0b08dbfeb777ccfa8f13ebaa95d064c04/src/include/catalog/pg_cast.dat#L437-L444)
366    /// * [impl](https://github.com/postgres/postgres/blob/27b77ecf9f4d5be211900eda54d8155ada50d696/src/backend/utils/adt/bool.c#L204-L209)
367    pub fn cast_output(self) -> RwResult<ExprImpl> {
368        if self.return_type() == DataType::Boolean {
369            return Ok(FunctionCall::new(ExprType::BoolOut, vec![self])?.into());
370        }
371        // Use normal cast for other types. Both `assign` and `explicit` can pass the castability
372        // check and there is no difference.
373        self.cast_assign(&DataType::Varchar)
374            .map_err(|err| err.into())
375    }
376
377    /// Evaluate the expression on the given input.
378    ///
379    /// TODO: This is a naive implementation. We should avoid proto ser/de.
380    /// Tracking issue: <https://github.com/risingwavelabs/risingwave/issues/3479>
381    pub async fn eval_row(&self, input: &OwnedRow) -> RwResult<Datum> {
382        let backend_expr = build_from_prost(&self.to_expr_proto())?;
383        Ok(backend_expr.eval_row(input).await?)
384    }
385
386    /// Try to evaluate an expression if it's a constant expression by `ExprImpl::is_const`.
387    ///
388    /// Returns...
389    /// - `None` if it's not a constant expression,
390    /// - `Some(Ok(_))` if constant evaluation succeeds,
391    /// - `Some(Err(_))` if there's an error while evaluating a constant expression.
392    pub fn try_fold_const(&self) -> Option<RwResult<Datum>> {
393        if self.is_const() {
394            self.eval_row(&OwnedRow::empty())
395                .now_or_never()
396                .expect("constant expression should not be async")
397                .into()
398        } else {
399            None
400        }
401    }
402
403    /// Similar to `ExprImpl::try_fold_const`, but panics if the expression is not constant.
404    pub fn fold_const(&self) -> RwResult<Datum> {
405        self.try_fold_const().expect("expression is not constant")
406    }
407}
408
/// Implements helper functions which recursively check whether a variant is included in the
/// expression, e.g., `has_subquery(&self) -> bool`.
///
/// For each listed variant type, a method named `has_` followed by the snake-case type name is
/// generated. It runs an `ExprVisitor` that overrides only the `visit_` method of that variant
/// and records whether it was ever called.
///
/// It will not traverse inside subqueries.
macro_rules! impl_has_variant {
    ( $($variant:ty),* ) => {
        paste! {
            impl ExprImpl {
                $(
                    pub fn [<has_ $variant:snake>](&self) -> bool {
                        // Visitor whose only state is the "seen" flag.
                        struct Has { has: bool }

                        impl ExprVisitor for Has {
                            fn [<visit_ $variant:snake>](&mut self, _: &$variant) {
                                self.has = true;
                            }
                        }

                        let mut visitor = Has { has: false };
                        visitor.visit_expr(self);
                        visitor.has
                    }
                )*
            }
        }
    };
}

// `CorrelatedInputRef` and `Parameter` are not listed here, so no `has_*` helper is generated
// for them by this macro.
impl_has_variant! {InputRef, Literal, FunctionCall, FunctionCallWithLambda, AggCall, Subquery, TableFunction, WindowFunction, UserDefinedFunction, Now}
438
439#[derive(Debug, Clone, PartialEq, Eq, Hash)]
440pub struct InequalityInputPair {
441    /// Input index of greater side of inequality.
442    pub(crate) key_required_larger: usize,
443    /// Input index of less side of inequality.
444    pub(crate) key_required_smaller: usize,
445    /// greater >= less + `delta_expression`
446    pub(crate) delta_expression: Option<(ExprType, ExprImpl)>,
447}
448
449impl InequalityInputPair {
450    fn new(
451        key_required_larger: usize,
452        key_required_smaller: usize,
453        delta_expression: Option<(ExprType, ExprImpl)>,
454    ) -> Self {
455        Self {
456            key_required_larger,
457            key_required_smaller,
458            delta_expression,
459        }
460    }
461}
462
463impl ExprImpl {
464    /// This function is not meant to be called. In most cases you would want
465    /// [`ExprImpl::has_correlated_input_ref_by_depth`].
466    ///
467    /// When an expr contains a [`CorrelatedInputRef`] with lower depth, the whole expr is still
468    /// considered to be uncorrelated, and can be checked with [`ExprImpl::has_subquery`] as well.
469    /// See examples on [`crate::binder::BoundQuery::is_correlated_by_depth`] for details.
470    ///
471    /// This is a placeholder to trigger a compiler error when a trivial implementation checking for
472    /// enum variant is generated by accident. It cannot be called either because you cannot pass
473    /// `Infallible` to it.
474    pub fn has_correlated_input_ref(&self, _: std::convert::Infallible) -> bool {
475        unreachable!()
476    }
477
478    /// Used to check whether the expression has [`CorrelatedInputRef`].
479    ///
480    /// This is the core logic that supports [`crate::binder::BoundQuery::is_correlated_by_depth`]. Check the
481    /// doc of it for examples of `depth` being equal, less or greater.
482    // We need to traverse inside subqueries.
483    pub fn has_correlated_input_ref_by_depth(&self, depth: Depth) -> bool {
484        struct Has {
485            depth: usize,
486            has: bool,
487        }
488
489        impl ExprVisitor for Has {
490            fn visit_correlated_input_ref(&mut self, correlated_input_ref: &CorrelatedInputRef) {
491                if correlated_input_ref.depth() == self.depth {
492                    self.has = true;
493                }
494            }
495
496            fn visit_subquery(&mut self, subquery: &Subquery) {
497                self.has |= subquery.is_correlated_by_depth(self.depth);
498            }
499        }
500
501        let mut visitor = Has { depth, has: false };
502        visitor.visit_expr(self);
503        visitor.has
504    }
505
506    pub fn has_correlated_input_ref_by_correlated_id(&self, correlated_id: CorrelatedId) -> bool {
507        struct Has {
508            correlated_id: CorrelatedId,
509            has: bool,
510        }
511
512        impl ExprVisitor for Has {
513            fn visit_correlated_input_ref(&mut self, correlated_input_ref: &CorrelatedInputRef) {
514                if correlated_input_ref.correlated_id() == self.correlated_id {
515                    self.has = true;
516                }
517            }
518
519            fn visit_subquery(&mut self, subquery: &Subquery) {
520                self.has |= subquery.is_correlated_by_correlated_id(self.correlated_id);
521            }
522        }
523
524        let mut visitor = Has {
525            correlated_id,
526            has: false,
527        };
528        visitor.visit_expr(self);
529        visitor.has
530    }
531
532    /// Collect `CorrelatedInputRef`s in `ExprImpl` by relative `depth`, return their indices, and
533    /// assign absolute `correlated_id` for them.
534    pub fn collect_correlated_indices_by_depth_and_assign_id(
535        &mut self,
536        depth: Depth,
537        correlated_id: CorrelatedId,
538    ) -> Vec<usize> {
539        struct Collector {
540            depth: Depth,
541            correlated_indices: Vec<usize>,
542            correlated_id: CorrelatedId,
543        }
544
545        impl ExprMutator for Collector {
546            fn visit_correlated_input_ref(
547                &mut self,
548                correlated_input_ref: &mut CorrelatedInputRef,
549            ) {
550                if correlated_input_ref.depth() == self.depth {
551                    self.correlated_indices.push(correlated_input_ref.index());
552                    correlated_input_ref.set_correlated_id(self.correlated_id);
553                }
554            }
555
556            fn visit_subquery(&mut self, subquery: &mut Subquery) {
557                self.correlated_indices.extend(
558                    subquery.collect_correlated_indices_by_depth_and_assign_id(
559                        self.depth,
560                        self.correlated_id,
561                    ),
562                );
563            }
564        }
565
566        let mut collector = Collector {
567            depth,
568            correlated_indices: vec![],
569            correlated_id,
570        };
571        collector.visit_expr(self);
572        collector.correlated_indices.sort();
573        collector.correlated_indices.dedup();
574        collector.correlated_indices
575    }
576
577    pub fn only_literal_and_func(&self) -> bool {
578        {
579            struct HasOthers {
580                has_others: bool,
581            }
582
583            impl ExprVisitor for HasOthers {
584                fn visit_expr(&mut self, expr: &ExprImpl) {
585                    match expr {
586                        ExprImpl::CorrelatedInputRef(_)
587                        | ExprImpl::InputRef(_)
588                        | ExprImpl::AggCall(_)
589                        | ExprImpl::Subquery(_)
590                        | ExprImpl::TableFunction(_)
591                        | ExprImpl::WindowFunction(_)
592                        | ExprImpl::UserDefinedFunction(_)
593                        | ExprImpl::Parameter(_)
594                        | ExprImpl::Now(_) => self.has_others = true,
595                        ExprImpl::Literal(_inner) => {}
596                        ExprImpl::FunctionCall(inner) => {
597                            if !self.is_short_circuit(inner) {
598                                // only if the current `func_call` is *not* a short-circuit
599                                // expression, e.g., true or (...) | false and (...),
600                                // shall we proceed to visit it.
601                                self.visit_function_call(inner)
602                            }
603                        }
604                        ExprImpl::FunctionCallWithLambda(inner) => {
605                            self.visit_function_call_with_lambda(inner)
606                        }
607                    }
608                }
609            }
610
611            impl HasOthers {
612                fn is_short_circuit(&self, func_call: &FunctionCall) -> bool {
613                    /// evaluate the first parameter of `Or` or `And` function call
614                    fn eval_first(e: &ExprImpl, expect: bool) -> bool {
615                        if let ExprImpl::Literal(l) = e {
616                            *l.get_data() == Some(ScalarImpl::Bool(expect))
617                        } else {
618                            false
619                        }
620                    }
621
622                    match func_call.func_type {
623                        ExprType::Or => eval_first(&func_call.inputs()[0], true),
624                        ExprType::And => eval_first(&func_call.inputs()[0], false),
625                        _ => false,
626                    }
627                }
628            }
629
630            let mut visitor = HasOthers { has_others: false };
631            visitor.visit_expr(self);
632            !visitor.has_others
633        }
634    }
635
636    /// Checks whether this is a constant expr that can be evaluated over a dummy chunk.
637    ///
638    /// The expression tree should only consist of literals and **pure** function calls.
639    pub fn is_const(&self) -> bool {
640        self.only_literal_and_func() && self.is_pure()
641    }
642
643    /// Returns the `InputRefs` of an Equality predicate if it matches
644    /// ordered by the canonical ordering (lower, higher), else returns None
645    pub fn as_eq_cond(&self) -> Option<(InputRef, InputRef)> {
646        if let ExprImpl::FunctionCall(function_call) = self
647            && function_call.func_type() == ExprType::Equal
648            && let (_, ExprImpl::InputRef(x), ExprImpl::InputRef(y)) =
649                function_call.clone().decompose_as_binary()
650        {
651            if x.index() < y.index() {
652                Some((*x, *y))
653            } else {
654                Some((*y, *x))
655            }
656        } else {
657            None
658        }
659    }
660
661    pub fn as_is_not_distinct_from_cond(&self) -> Option<(InputRef, InputRef)> {
662        if let ExprImpl::FunctionCall(function_call) = self
663            && function_call.func_type() == ExprType::IsNotDistinctFrom
664            && let (_, ExprImpl::InputRef(x), ExprImpl::InputRef(y)) =
665                function_call.clone().decompose_as_binary()
666        {
667            if x.index() < y.index() {
668                Some((*x, *y))
669            } else {
670                Some((*y, *x))
671            }
672        } else {
673            None
674        }
675    }
676
677    pub fn reverse_comparison(comparison: ExprType) -> ExprType {
678        match comparison {
679            ExprType::LessThan => ExprType::GreaterThan,
680            ExprType::LessThanOrEqual => ExprType::GreaterThanOrEqual,
681            ExprType::GreaterThan => ExprType::LessThan,
682            ExprType::GreaterThanOrEqual => ExprType::LessThanOrEqual,
683            ExprType::Equal | ExprType::IsNotDistinctFrom => comparison,
684            _ => unreachable!(),
685        }
686    }
687
688    pub fn as_comparison_cond(&self) -> Option<(InputRef, ExprType, InputRef)> {
689        if let ExprImpl::FunctionCall(function_call) = self {
690            match function_call.func_type() {
691                ty @ (ExprType::LessThan
692                | ExprType::LessThanOrEqual
693                | ExprType::GreaterThan
694                | ExprType::GreaterThanOrEqual) => {
695                    let (_, op1, op2) = function_call.clone().decompose_as_binary();
696                    if let (ExprImpl::InputRef(x), ExprImpl::InputRef(y)) = (op1, op2) {
697                        if x.index < y.index {
698                            Some((*x, ty, *y))
699                        } else {
700                            Some((*y, Self::reverse_comparison(ty), *x))
701                        }
702                    } else {
703                        None
704                    }
705                }
706                _ => None,
707            }
708        } else {
709            None
710        }
711    }
712
713    /// Accepts expressions of the form `input_expr cmp now_expr` or `now_expr cmp input_expr`,
714    /// where `input_expr` contains an `InputRef` and contains no `now()`, and `now_expr`
715    /// contains a `now()` but no `InputRef`.
716    ///
717    /// Canonicalizes to the first ordering and returns `(input_expr, cmp, now_expr)`
718    pub fn as_now_comparison_cond(&self) -> Option<(ExprImpl, ExprType, ExprImpl)> {
719        if let ExprImpl::FunctionCall(function_call) = self {
720            match function_call.func_type() {
721                ty @ (ExprType::Equal
722                | ExprType::LessThan
723                | ExprType::LessThanOrEqual
724                | ExprType::GreaterThan
725                | ExprType::GreaterThanOrEqual) => {
726                    let (_, op1, op2) = function_call.clone().decompose_as_binary();
727                    if !op1.has_now()
728                        && op1.has_input_ref()
729                        && op2.has_now()
730                        && !op2.has_input_ref()
731                    {
732                        Some((op1, ty, op2))
733                    } else if op1.has_now()
734                        && !op1.has_input_ref()
735                        && !op2.has_now()
736                        && op2.has_input_ref()
737                    {
738                        Some((op2, Self::reverse_comparison(ty), op1))
739                    } else {
740                        None
741                    }
742                }
743                _ => None,
744            }
745        } else {
746            None
747        }
748    }
749
750    /// Accepts expressions of the form `InputRef cmp InputRef [+- const_expr]` or
751    /// `InputRef [+- const_expr] cmp InputRef`.
752    pub(crate) fn as_input_comparison_cond(&self) -> Option<InequalityInputPair> {
753        if let ExprImpl::FunctionCall(function_call) = self {
754            match function_call.func_type() {
755                ty @ (ExprType::LessThan
756                | ExprType::LessThanOrEqual
757                | ExprType::GreaterThan
758                | ExprType::GreaterThanOrEqual) => {
759                    let (_, mut op1, mut op2) = function_call.clone().decompose_as_binary();
760                    if matches!(ty, ExprType::LessThan | ExprType::LessThanOrEqual) {
761                        std::mem::swap(&mut op1, &mut op2);
762                    }
763                    if let (Some((lft_input, lft_offset)), Some((rht_input, rht_offset))) =
764                        (op1.as_input_offset(), op2.as_input_offset())
765                    {
766                        match (lft_offset, rht_offset) {
767                            (Some(_), Some(_)) => None,
768                            (None, rht_offset @ Some(_)) => {
769                                Some(InequalityInputPair::new(lft_input, rht_input, rht_offset))
770                            }
771                            (Some((operator, operand)), None) => Some(InequalityInputPair::new(
772                                lft_input,
773                                rht_input,
774                                Some((
775                                    if operator == ExprType::Add {
776                                        ExprType::Subtract
777                                    } else {
778                                        ExprType::Add
779                                    },
780                                    operand,
781                                )),
782                            )),
783                            (None, None) => {
784                                Some(InequalityInputPair::new(lft_input, rht_input, None))
785                            }
786                        }
787                    } else {
788                        None
789                    }
790                }
791                _ => None,
792            }
793        } else {
794            None
795        }
796    }
797
798    /// Returns the `InputRef` and offset of a predicate if it matches
799    /// the form `InputRef [+- const_expr]`, else returns None.
800    fn as_input_offset(&self) -> Option<(usize, Option<(ExprType, ExprImpl)>)> {
801        match self {
802            ExprImpl::InputRef(input_ref) => Some((input_ref.index(), None)),
803            ExprImpl::FunctionCall(function_call) => {
804                let expr_type = function_call.func_type();
805                match expr_type {
806                    ExprType::Add | ExprType::Subtract => {
807                        let (_, lhs, rhs) = function_call.clone().decompose_as_binary();
808                        if let ExprImpl::InputRef(input_ref) = &lhs
809                            && rhs.is_const()
810                        {
811                            // Currently we will return `None` for non-literal because the result of the expression might be '1 day'. However, there will definitely exist false positives such as '1 second + 1 second'.
812                            // We will treat the expression as an input offset when rhs is `null`.
813                            if rhs.return_type() == DataType::Interval
814                                && rhs.as_literal().is_none_or(|literal| {
815                                    literal.get_data().as_ref().is_some_and(|scalar| {
816                                        let interval = scalar.as_interval();
817                                        interval.months() != 0 || interval.days() != 0
818                                    })
819                                })
820                            {
821                                None
822                            } else {
823                                Some((input_ref.index(), Some((expr_type, rhs))))
824                            }
825                        } else {
826                            None
827                        }
828                    }
829                    _ => None,
830                }
831            }
832            _ => None,
833        }
834    }
835
836    pub fn as_eq_const(&self) -> Option<(InputRef, ExprImpl)> {
837        if let ExprImpl::FunctionCall(function_call) = self
838            && function_call.func_type() == ExprType::Equal
839        {
840            match function_call.clone().decompose_as_binary() {
841                (_, ExprImpl::InputRef(x), y) if y.is_const() => Some((*x, y)),
842                (_, x, ExprImpl::InputRef(y)) if x.is_const() => Some((*y, x)),
843                _ => None,
844            }
845        } else {
846            None
847        }
848    }
849
850    pub fn as_eq_correlated_input_ref(&self) -> Option<(InputRef, CorrelatedInputRef)> {
851        if let ExprImpl::FunctionCall(function_call) = self
852            && function_call.func_type() == ExprType::Equal
853        {
854            match function_call.clone().decompose_as_binary() {
855                (_, ExprImpl::InputRef(x), ExprImpl::CorrelatedInputRef(y)) => Some((*x, *y)),
856                (_, ExprImpl::CorrelatedInputRef(x), ExprImpl::InputRef(y)) => Some((*y, *x)),
857                _ => None,
858            }
859        } else {
860            None
861        }
862    }
863
864    pub fn as_is_null(&self) -> Option<InputRef> {
865        if let ExprImpl::FunctionCall(function_call) = self
866            && function_call.func_type() == ExprType::IsNull
867        {
868            match function_call.clone().decompose_as_unary() {
869                (_, ExprImpl::InputRef(x)) => Some(*x),
870                _ => None,
871            }
872        } else {
873            None
874        }
875    }
876
877    pub fn as_comparison_const(&self) -> Option<(InputRef, ExprType, ExprImpl)> {
878        fn reverse_comparison(comparison: ExprType) -> ExprType {
879            match comparison {
880                ExprType::LessThan => ExprType::GreaterThan,
881                ExprType::LessThanOrEqual => ExprType::GreaterThanOrEqual,
882                ExprType::GreaterThan => ExprType::LessThan,
883                ExprType::GreaterThanOrEqual => ExprType::LessThanOrEqual,
884                _ => unreachable!(),
885            }
886        }
887
888        if let ExprImpl::FunctionCall(function_call) = self {
889            match function_call.func_type() {
890                ty @ (ExprType::LessThan
891                | ExprType::LessThanOrEqual
892                | ExprType::GreaterThan
893                | ExprType::GreaterThanOrEqual) => {
894                    let (_, op1, op2) = function_call.clone().decompose_as_binary();
895                    match (op1, op2) {
896                        (ExprImpl::InputRef(x), y) if y.is_const() => Some((*x, ty, y)),
897                        (x, ExprImpl::InputRef(y)) if x.is_const() => {
898                            Some((*y, reverse_comparison(ty), x))
899                        }
900                        _ => None,
901                    }
902                }
903                _ => None,
904            }
905        } else {
906            None
907        }
908    }
909
910    pub fn as_in_const_list(&self) -> Option<(InputRef, Vec<ExprImpl>)> {
911        if let ExprImpl::FunctionCall(function_call) = self
912            && function_call.func_type() == ExprType::In
913        {
914            let mut inputs = function_call.inputs().iter().cloned();
915            let input_ref = match inputs.next().unwrap() {
916                ExprImpl::InputRef(i) => *i,
917                _ => return None,
918            };
919            let list: Vec<_> = inputs
920                .inspect(|expr| {
921                    // Non constant IN will be bound to OR
922                    assert!(expr.is_const());
923                })
924                .collect();
925
926            Some((input_ref, list))
927        } else {
928            None
929        }
930    }
931
932    pub fn as_or_disjunctions(&self) -> Option<Vec<ExprImpl>> {
933        if let ExprImpl::FunctionCall(function_call) = self
934            && function_call.func_type() == ExprType::Or
935        {
936            Some(to_disjunctions(self.clone()))
937        } else {
938            None
939        }
940    }
941
942    pub fn to_project_set_select_item_proto(&self) -> ProjectSetSelectItem {
943        use risingwave_pb::expr::project_set_select_item::SelectItem::*;
944
945        ProjectSetSelectItem {
946            select_item: Some(match self {
947                ExprImpl::TableFunction(tf) => TableFunction(tf.to_protobuf()),
948                expr => Expr(expr.to_expr_proto()),
949            }),
950        }
951    }
952
953    pub fn from_expr_proto(proto: &ExprNode) -> RwResult<Self> {
954        let rex_node = proto.get_rex_node()?;
955        let ret_type = proto.get_return_type()?.into();
956
957        Ok(match rex_node {
958            RexNode::InputRef(column_index) => Self::InputRef(Box::new(InputRef::from_expr_proto(
959                *column_index as _,
960                ret_type,
961            )?)),
962            RexNode::Constant(_) => Self::Literal(Box::new(Literal::from_expr_proto(proto)?)),
963            RexNode::Udf(udf) => Self::UserDefinedFunction(Box::new(
964                UserDefinedFunction::from_expr_proto(udf, ret_type)?,
965            )),
966            RexNode::FuncCall(function_call) => {
967                Self::FunctionCall(Box::new(FunctionCall::from_expr_proto(
968                    function_call,
969                    proto.get_function_type()?, // only interpret if it's a function call
970                    ret_type,
971                )?))
972            }
973            RexNode::Now(_) => Self::Now(Box::new(Now {})),
974        })
975    }
976}
977
978impl From<Condition> for ExprImpl {
979    fn from(c: Condition) -> Self {
980        ExprImpl::and(c.conjunctions)
981    }
982}
983
984/// A custom Debug implementation that is more concise and suitable to use with
985/// [`std::fmt::Formatter::debug_list`] in plan nodes. If the verbose output is preferred, it is
986/// still available via `{:#?}`.
987impl std::fmt::Debug for ExprImpl {
988    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
989        if f.alternate() {
990            return match self {
991                Self::InputRef(arg0) => f.debug_tuple("InputRef").field(arg0).finish(),
992                Self::Literal(arg0) => f.debug_tuple("Literal").field(arg0).finish(),
993                Self::FunctionCall(arg0) => f.debug_tuple("FunctionCall").field(arg0).finish(),
994                Self::FunctionCallWithLambda(arg0) => {
995                    f.debug_tuple("FunctionCallWithLambda").field(arg0).finish()
996                }
997                Self::AggCall(arg0) => f.debug_tuple("AggCall").field(arg0).finish(),
998                Self::Subquery(arg0) => f.debug_tuple("Subquery").field(arg0).finish(),
999                Self::CorrelatedInputRef(arg0) => {
1000                    f.debug_tuple("CorrelatedInputRef").field(arg0).finish()
1001                }
1002                Self::TableFunction(arg0) => f.debug_tuple("TableFunction").field(arg0).finish(),
1003                Self::WindowFunction(arg0) => f.debug_tuple("WindowFunction").field(arg0).finish(),
1004                Self::UserDefinedFunction(arg0) => {
1005                    f.debug_tuple("UserDefinedFunction").field(arg0).finish()
1006                }
1007                Self::Parameter(arg0) => f.debug_tuple("Parameter").field(arg0).finish(),
1008                Self::Now(_) => f.debug_tuple("Now").finish(),
1009            };
1010        }
1011        match self {
1012            Self::InputRef(x) => write!(f, "{:?}", x),
1013            Self::Literal(x) => write!(f, "{:?}", x),
1014            Self::FunctionCall(x) => write!(f, "{:?}", x),
1015            Self::FunctionCallWithLambda(x) => write!(f, "{:?}", x),
1016            Self::AggCall(x) => write!(f, "{:?}", x),
1017            Self::Subquery(x) => write!(f, "{:?}", x),
1018            Self::CorrelatedInputRef(x) => write!(f, "{:?}", x),
1019            Self::TableFunction(x) => write!(f, "{:?}", x),
1020            Self::WindowFunction(x) => write!(f, "{:?}", x),
1021            Self::UserDefinedFunction(x) => write!(f, "{:?}", x),
1022            Self::Parameter(x) => write!(f, "{:?}", x),
1023            Self::Now(x) => write!(f, "{:?}", x),
1024        }
1025    }
1026}
1027
/// Borrowed pairing of an expression with the schema of its input.
///
/// The `Debug` implementation of this wrapper forwards `input_schema` to
/// `InputRefDisplay`, `FunctionCallDisplay` and `UserDefinedFunctionDisplay` when
/// formatting the corresponding expression variants.
pub struct ExprDisplay<'a> {
    /// The expression to format.
    pub expr: &'a ExprImpl,
    /// Schema handed to the schema-aware display helpers.
    pub input_schema: &'a Schema,
}
1032
1033impl std::fmt::Debug for ExprDisplay<'_> {
1034    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1035        let that = self.expr;
1036        match that {
1037            ExprImpl::InputRef(x) => write!(
1038                f,
1039                "{:?}",
1040                InputRefDisplay {
1041                    input_ref: x,
1042                    input_schema: self.input_schema
1043                }
1044            ),
1045            ExprImpl::Literal(x) => write!(f, "{:?}", x),
1046            ExprImpl::FunctionCall(x) => write!(
1047                f,
1048                "{:?}",
1049                FunctionCallDisplay {
1050                    function_call: x,
1051                    input_schema: self.input_schema
1052                }
1053            ),
1054            ExprImpl::FunctionCallWithLambda(x) => write!(
1055                f,
1056                "{:?}",
1057                FunctionCallDisplay {
1058                    function_call: &x.to_full_function_call(),
1059                    input_schema: self.input_schema
1060                }
1061            ),
1062            ExprImpl::AggCall(x) => write!(f, "{:?}", x),
1063            ExprImpl::Subquery(x) => write!(f, "{:?}", x),
1064            ExprImpl::CorrelatedInputRef(x) => write!(f, "{:?}", x),
1065            ExprImpl::TableFunction(x) => {
1066                // TODO: TableFunctionCallVerboseDisplay
1067                write!(f, "{:?}", x)
1068            }
1069            ExprImpl::WindowFunction(x) => {
1070                // TODO: WindowFunctionCallVerboseDisplay
1071                write!(f, "{:?}", x)
1072            }
1073            ExprImpl::UserDefinedFunction(x) => {
1074                write!(
1075                    f,
1076                    "{:?}",
1077                    UserDefinedFunctionDisplay {
1078                        func_call: x,
1079                        input_schema: self.input_schema
1080                    }
1081                )
1082            }
1083            ExprImpl::Parameter(x) => write!(f, "{:?}", x),
1084            ExprImpl::Now(x) => write!(f, "{:?}", x),
1085        }
1086    }
1087}
1088
1089impl std::fmt::Display for ExprDisplay<'_> {
1090    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1091        (self as &dyn std::fmt::Debug).fmt(f)
1092    }
1093}
1094
#[cfg(test)]
/// Asserts that the expression is an [`InputRef`] with the given index.
///
/// `$e` is evaluated exactly once and only borrowed, so it may be an owned `ExprImpl`,
/// a reference to one, or an expression with side effects.
///
/// # Panics
///
/// Panics with `Expected input ref, found ...` when `$e` is not `ExprImpl::InputRef`, and
/// through `assert_eq!` when the index differs from `$index`.
macro_rules! assert_eq_input_ref {
    ($e:expr, $index:expr) => {
        // Match on a borrow so that `$e` is expanded only once; the failure arm reuses
        // the already-evaluated value instead of re-evaluating the expression.
        match &$e {
            ExprImpl::InputRef(i) => assert_eq!(i.index(), $index),
            other => panic!("Expected input ref, found {:?}", other),
        }
    };
}
1105
1106#[cfg(test)]
1107pub(crate) use assert_eq_input_ref;
1108use risingwave_common::bail;
1109use risingwave_common::catalog::Schema;
1110use risingwave_common::row::OwnedRow;
1111
1112use crate::utils::Condition;
1113
#[cfg(test)]
mod tests {
    use super::*;

    /// The alternate (`{:#?}`) form of an `ExprImpl` must keep the verbose output, which
    /// includes the `return_type` field of the wrapped function call.
    #[test]
    fn test_expr_debug_alternate() {
        let input: ExprImpl = InputRef::new(1, DataType::Boolean).into();
        let negated: ExprImpl = FunctionCall::new(ExprType::Not, vec![input])
            .unwrap()
            .into();
        let s = format!("{:#?}", negated);
        assert!(s.contains("return_type: Boolean"))
    }
}