1use enum_as_inner::EnumAsInner;
16use fixedbitset::FixedBitSet;
17use futures::FutureExt;
18use paste::paste;
19use risingwave_common::array::ListValue;
20use risingwave_common::types::{DataType, Datum, JsonbVal, MapType, Scalar, ScalarImpl};
21use risingwave_expr::aggregate::PbAggKind;
22use risingwave_expr::expr::build_from_prost;
23use risingwave_pb::expr::expr_node::RexNode;
24use risingwave_pb::expr::{ExprNode, ProjectSetSelectItem};
25use user_defined_function::UserDefinedFunctionDisplay;
26
27use crate::error::{ErrorCode, Result as RwResult};
28
29mod agg_call;
30mod correlated_input_ref;
31mod function_call;
32mod function_call_with_lambda;
33mod input_ref;
34mod literal;
35mod now;
36mod parameter;
37mod pure;
38mod subquery;
39mod table_function;
40mod user_defined_function;
41mod window_function;
42
43mod order_by_expr;
44pub use order_by_expr::{OrderBy, OrderByExpr};
45
46mod expr_mutator;
47mod expr_rewriter;
48mod expr_visitor;
49pub mod function_impl;
50mod session_timezone;
51mod type_inference;
52mod utils;
53
54pub use agg_call::AggCall;
55pub use correlated_input_ref::{CorrelatedId, CorrelatedInputRef, Depth, InputRefDepthRewriter};
56pub use expr_mutator::ExprMutator;
57pub use expr_rewriter::{ExprRewriter, default_rewrite_expr};
58pub use expr_visitor::{ExprVisitor, default_visit_expr};
59pub use function_call::{FunctionCall, FunctionCallDisplay, is_row_function};
60pub use function_call_with_lambda::FunctionCallWithLambda;
61pub use input_ref::{InputRef, InputRefDisplay, input_ref_to_column_indices};
62pub use literal::Literal;
63pub use now::{InlineNowProcTime, Now, NowProcTimeFinder};
64pub use parameter::Parameter;
65pub use pure::*;
66pub use risingwave_pb::expr::expr_node::Type as ExprType;
67pub use session_timezone::{SessionTimezone, TimestamptzExprFinder};
68pub use subquery::{Subquery, SubqueryKind};
69pub use table_function::{TableFunction, TableFunctionType};
70pub use type_inference::*;
71pub use user_defined_function::UserDefinedFunction;
72pub use utils::*;
73pub use window_function::WindowFunction;
74
/// Expression-depth threshold associated with [`EXPR_TOO_DEEP_NOTICE`].
///
/// NOTE(review): the comparison against an actual expression depth happens outside this file;
/// presumably the notice is emitted once this depth is exceeded — confirm at the call site.
pub(crate) const EXPR_DEPTH_THRESHOLD: usize = 30;

/// User-facing notice for overly complicated (deeply nested) expressions.
pub(crate) const EXPR_TOO_DEEP_NOTICE: &str = concat!(
    "Some expression is too complicated. ",
    "Consider simplifying or splitting the query if you encounter any issues."
);
78
79pub trait Expr: Into<ExprImpl> {
81 fn return_type(&self) -> DataType;
83
84 fn try_to_expr_proto(&self) -> Result<ExprNode, String>;
86
87 fn to_expr_proto(&self) -> ExprNode {
89 self.try_to_expr_proto()
90 .expect("failed to serialize expression to protobuf")
91 }
92}
93
/// Generates the [`ExprImpl`] sum type from a list of expression node types.
///
/// For every listed type `T` this expands to:
/// - a variant `ExprImpl::T(Box<T>)` (the enum also derives `EnumAsInner`, which provides
///   accessors such as `as_literal`),
/// - an arm of `ExprImpl::variant_name` returning the type name as a `&'static str`,
/// - `impl From<T> for ExprImpl`, boxing the value,
/// - an arm of `impl Expr for ExprImpl` that delegates `return_type` and `try_to_expr_proto`
///   to the wrapped `T`.
macro_rules! impl_expr_impl {
    // Every type must be followed by a comma, including the last one.
    ($($t:ident,)*) => {
        #[derive(Clone, Eq, PartialEq, Hash, EnumAsInner)]
        pub enum ExprImpl {
            $($t(Box<$t>),)*
        }

        impl ExprImpl {
            /// Name of the active variant, identical to the wrapped type's name.
            pub fn variant_name(&self) -> &'static str {
                match self {
                    $(ExprImpl::$t(_) => stringify!($t),)*
                }
            }
        }

        $(
            impl From<$t> for ExprImpl {
                fn from(o: $t) -> ExprImpl {
                    ExprImpl::$t(Box::new(o))
                }
            })*

        // Pure delegation to the wrapped node.
        impl Expr for ExprImpl {
            fn return_type(&self) -> DataType {
                match self {
                    $(ExprImpl::$t(expr) => expr.return_type(),)*
                }
            }

            fn try_to_expr_proto(&self) -> Result<ExprNode, String> {
                match self {
                    $(ExprImpl::$t(expr) => expr.try_to_expr_proto(),)*
                }
            }
        }
    };
}
131
// Defines `ExprImpl` with one boxed variant per type listed below. Because of the derives on
// the generated enum, each type must be `Clone + Eq + PartialEq + Hash`. The generated `Expr`
// impl also requires each type to provide `return_type` and `try_to_expr_proto`.
impl_expr_impl!(
    CorrelatedInputRef,
    InputRef,
    Literal,
    FunctionCall,
    FunctionCallWithLambda,
    AggCall,
    Subquery,
    TableFunction,
    WindowFunction,
    UserDefinedFunction,
    Parameter,
    Now,
);
147
148impl ExprImpl {
149 #[inline(always)]
151 pub fn literal_int(v: i32) -> Self {
152 Literal::new(Some(v.to_scalar_value()), DataType::Int32).into()
153 }
154
155 #[inline(always)]
157 pub fn literal_bigint(v: i64) -> Self {
158 Literal::new(Some(v.to_scalar_value()), DataType::Int64).into()
159 }
160
161 #[inline(always)]
163 pub fn literal_f64(v: f64) -> Self {
164 Literal::new(Some(v.into()), DataType::Float64).into()
165 }
166
167 #[inline(always)]
169 pub fn literal_bool(v: bool) -> Self {
170 Literal::new(Some(v.to_scalar_value()), DataType::Boolean).into()
171 }
172
173 #[inline(always)]
175 pub fn literal_varchar(v: String) -> Self {
176 Literal::new(Some(v.into()), DataType::Varchar).into()
177 }
178
179 #[inline(always)]
181 pub fn literal_null(element_type: DataType) -> Self {
182 Literal::new(None, element_type).into()
183 }
184
185 #[inline(always)]
187 pub fn literal_jsonb(v: JsonbVal) -> Self {
188 Literal::new(Some(v.into()), DataType::Jsonb).into()
189 }
190
191 #[inline(always)]
193 pub fn literal_list(v: ListValue, element_type: DataType) -> Self {
194 Literal::new(Some(v.to_scalar_value()), DataType::list(element_type)).into()
195 }
196
197 pub fn take(&mut self) -> Self {
199 std::mem::replace(self, Self::literal_null(self.return_type()))
200 }
201
202 #[inline(always)]
204 pub fn count_star() -> Self {
205 AggCall::new(
206 PbAggKind::Count.into(),
207 vec![],
208 false,
209 OrderBy::any(),
210 Condition::true_cond(),
211 vec![],
212 )
213 .unwrap()
214 .into()
215 }
216
217 pub fn and(exprs: impl IntoIterator<Item = ExprImpl>) -> Self {
221 merge_expr_by_logical(exprs, ExprType::And, ExprImpl::literal_bool(true))
222 }
223
224 pub fn or(exprs: impl IntoIterator<Item = ExprImpl>) -> Self {
228 merge_expr_by_logical(exprs, ExprType::Or, ExprImpl::literal_bool(false))
229 }
230
231 pub fn collect_input_refs(&self, input_col_num: usize) -> FixedBitSet {
236 collect_input_refs(input_col_num, [self])
237 }
238
239 pub fn is_pure(&self) -> bool {
241 is_pure(self)
242 }
243
244 pub fn is_impure(&self) -> bool {
245 is_impure(self)
246 }
247
248 pub fn count_nows(&self) -> usize {
250 let mut visitor = CountNow::default();
251 visitor.visit_expr(self);
252 visitor.count()
253 }
254
255 pub fn is_null(&self) -> bool {
257 matches!(self, ExprImpl::Literal(literal) if literal.get_data().is_none())
258 }
259
260 pub fn is_untyped(&self) -> bool {
262 matches!(self, ExprImpl::Literal(literal) if literal.is_untyped())
263 || matches!(self, ExprImpl::Parameter(parameter) if !parameter.has_infer())
264 }
265
266 pub fn cast_implicit(mut self, target: &DataType) -> Result<ExprImpl, CastError> {
268 FunctionCall::cast_mut(&mut self, target, CastContext::Implicit)?;
269 Ok(self)
270 }
271
272 pub fn cast_assign(mut self, target: &DataType) -> Result<ExprImpl, CastError> {
274 FunctionCall::cast_mut(&mut self, target, CastContext::Assign)?;
275 Ok(self)
276 }
277
278 pub fn cast_explicit(mut self, target: &DataType) -> Result<ExprImpl, CastError> {
280 FunctionCall::cast_mut(&mut self, target, CastContext::Explicit)?;
281 Ok(self)
282 }
283
284 pub fn cast_implicit_mut(&mut self, target: &DataType) -> Result<(), CastError> {
286 FunctionCall::cast_mut(self, target, CastContext::Implicit)
287 }
288
289 pub fn cast_explicit_mut(&mut self, target: &DataType) -> Result<(), CastError> {
291 FunctionCall::cast_mut(self, target, CastContext::Explicit)
292 }
293
294 pub fn cast_to_regclass(self) -> Result<ExprImpl, CastError> {
297 match self.return_type() {
298 DataType::Varchar => Ok(ExprImpl::FunctionCall(Box::new(
299 FunctionCall::new_unchecked(ExprType::CastRegclass, vec![self], DataType::Int32),
300 ))),
301 DataType::Int32 => Ok(self),
302 dt if dt.is_int() => Ok(self.cast_explicit(&DataType::Int32)?),
303 _ => bail_cast_error!("unsupported input type"),
304 }
305 }
306
307 pub fn cast_to_regclass_mut(&mut self) -> Result<(), CastError> {
309 let owned = std::mem::replace(self, ExprImpl::literal_bool(false));
310 *self = owned.cast_to_regclass()?;
311 Ok(())
312 }
313
314 pub fn ensure_array_type(&self) -> Result<(), ErrorCode> {
316 if self.is_untyped() {
317 return Err(ErrorCode::BindError(
318 "could not determine polymorphic type because input has type unknown".into(),
319 ));
320 }
321 match self.return_type() {
322 DataType::List(_) => Ok(()),
323 t => Err(ErrorCode::BindError(format!("expects array but got {t}"))),
324 }
325 }
326
327 pub fn try_into_map_type(&self) -> Result<MapType, ErrorCode> {
329 if self.is_untyped() {
330 return Err(ErrorCode::BindError(
331 "could not determine polymorphic type because input has type unknown".into(),
332 ));
333 }
334 match self.return_type() {
335 DataType::Map(m) => Ok(m),
336 t => Err(ErrorCode::BindError(format!("expects map but got {t}"))),
337 }
338 }
339
340 pub fn enforce_bool_clause(self, clause: &str) -> RwResult<ExprImpl> {
342 if self.is_untyped() {
343 let inner = self.cast_implicit(&DataType::Boolean)?;
344 return Ok(inner);
345 }
346 let return_type = self.return_type();
347 if return_type != DataType::Boolean {
348 bail!(
349 "argument of {} must be boolean, not type {:?}",
350 clause,
351 return_type
352 )
353 }
354 Ok(self)
355 }
356
357 pub fn cast_output(self) -> RwResult<ExprImpl> {
368 if self.return_type() == DataType::Boolean {
369 return Ok(FunctionCall::new(ExprType::BoolOut, vec![self])?.into());
370 }
371 self.cast_assign(&DataType::Varchar)
374 .map_err(|err| err.into())
375 }
376
377 pub async fn eval_row(&self, input: &OwnedRow) -> RwResult<Datum> {
382 let backend_expr = build_from_prost(&self.to_expr_proto())?;
383 Ok(backend_expr.eval_row(input).await?)
384 }
385
386 pub fn try_fold_const(&self) -> Option<RwResult<Datum>> {
393 if self.is_const() {
394 self.eval_row(&OwnedRow::empty())
395 .now_or_never()
396 .expect("constant expression should not be async")
397 .into()
398 } else {
399 None
400 }
401 }
402
403 pub fn fold_const(&self) -> RwResult<Datum> {
405 self.try_fold_const().expect("expression is not constant")
406 }
407}
408
/// Generates one `ExprImpl::has_<variant>()` method per listed node type (name snake-cased
/// via `paste`).
///
/// Each method walks the expression with an `ExprVisitor` that overrides only
/// `visit_<variant>`. It returns whether that visitor method was invoked anywhere in the tree.
macro_rules! impl_has_variant {
    ( $($variant:ty),* ) => {
        paste! {
            impl ExprImpl {
                $(
                    pub fn [<has_ $variant:snake>](&self) -> bool {
                        struct Has { has: bool }

                        impl ExprVisitor for Has {
                            // Overriding this hook means the node's children are not visited;
                            // that is fine because one hit is enough.
                            fn [<visit_ $variant:snake>](&mut self, _: &$variant) {
                                self.has = true;
                            }
                        }

                        let mut visitor = Has { has: false };
                        visitor.visit_expr(self);
                        visitor.has
                    }
                )*
            }
        }
    };
}
436
// `CorrelatedInputRef` and `Parameter` are not listed here, so no `has_correlated_input_ref()`
// or `has_parameter()` method is generated by this invocation.
impl_has_variant! {InputRef, Literal, FunctionCall, FunctionCallWithLambda, AggCall, Subquery, TableFunction, WindowFunction, UserDefinedFunction, Now}
438
439#[derive(Debug, Clone, PartialEq, Eq, Hash)]
440pub struct InequalityInputPair {
441 pub(crate) key_required_larger: usize,
443 pub(crate) key_required_smaller: usize,
445 pub(crate) delta_expression: Option<(ExprType, ExprImpl)>,
447}
448
449impl InequalityInputPair {
450 fn new(
451 key_required_larger: usize,
452 key_required_smaller: usize,
453 delta_expression: Option<(ExprType, ExprImpl)>,
454 ) -> Self {
455 Self {
456 key_required_larger,
457 key_required_smaller,
458 delta_expression,
459 }
460 }
461}
462
463impl ExprImpl {
464 pub fn has_correlated_input_ref(&self, _: std::convert::Infallible) -> bool {
475 unreachable!()
476 }
477
478 pub fn has_correlated_input_ref_by_depth(&self, depth: Depth) -> bool {
484 struct Has {
485 depth: usize,
486 has: bool,
487 }
488
489 impl ExprVisitor for Has {
490 fn visit_correlated_input_ref(&mut self, correlated_input_ref: &CorrelatedInputRef) {
491 if correlated_input_ref.depth() == self.depth {
492 self.has = true;
493 }
494 }
495
496 fn visit_subquery(&mut self, subquery: &Subquery) {
497 self.has |= subquery.is_correlated_by_depth(self.depth);
498 }
499 }
500
501 let mut visitor = Has { depth, has: false };
502 visitor.visit_expr(self);
503 visitor.has
504 }
505
506 pub fn has_correlated_input_ref_by_correlated_id(&self, correlated_id: CorrelatedId) -> bool {
507 struct Has {
508 correlated_id: CorrelatedId,
509 has: bool,
510 }
511
512 impl ExprVisitor for Has {
513 fn visit_correlated_input_ref(&mut self, correlated_input_ref: &CorrelatedInputRef) {
514 if correlated_input_ref.correlated_id() == self.correlated_id {
515 self.has = true;
516 }
517 }
518
519 fn visit_subquery(&mut self, subquery: &Subquery) {
520 self.has |= subquery.is_correlated_by_correlated_id(self.correlated_id);
521 }
522 }
523
524 let mut visitor = Has {
525 correlated_id,
526 has: false,
527 };
528 visitor.visit_expr(self);
529 visitor.has
530 }
531
532 pub fn collect_correlated_indices_by_depth_and_assign_id(
535 &mut self,
536 depth: Depth,
537 correlated_id: CorrelatedId,
538 ) -> Vec<usize> {
539 struct Collector {
540 depth: Depth,
541 correlated_indices: Vec<usize>,
542 correlated_id: CorrelatedId,
543 }
544
545 impl ExprMutator for Collector {
546 fn visit_correlated_input_ref(
547 &mut self,
548 correlated_input_ref: &mut CorrelatedInputRef,
549 ) {
550 if correlated_input_ref.depth() == self.depth {
551 self.correlated_indices.push(correlated_input_ref.index());
552 correlated_input_ref.set_correlated_id(self.correlated_id);
553 }
554 }
555
556 fn visit_subquery(&mut self, subquery: &mut Subquery) {
557 self.correlated_indices.extend(
558 subquery.collect_correlated_indices_by_depth_and_assign_id(
559 self.depth,
560 self.correlated_id,
561 ),
562 );
563 }
564 }
565
566 let mut collector = Collector {
567 depth,
568 correlated_indices: vec![],
569 correlated_id,
570 };
571 collector.visit_expr(self);
572 collector.correlated_indices.sort();
573 collector.correlated_indices.dedup();
574 collector.correlated_indices
575 }
576
577 pub fn only_literal_and_func(&self) -> bool {
578 {
579 struct HasOthers {
580 has_others: bool,
581 }
582
583 impl ExprVisitor for HasOthers {
584 fn visit_expr(&mut self, expr: &ExprImpl) {
585 match expr {
586 ExprImpl::CorrelatedInputRef(_)
587 | ExprImpl::InputRef(_)
588 | ExprImpl::AggCall(_)
589 | ExprImpl::Subquery(_)
590 | ExprImpl::TableFunction(_)
591 | ExprImpl::WindowFunction(_)
592 | ExprImpl::UserDefinedFunction(_)
593 | ExprImpl::Parameter(_)
594 | ExprImpl::Now(_) => self.has_others = true,
595 ExprImpl::Literal(_inner) => {}
596 ExprImpl::FunctionCall(inner) => {
597 if !self.is_short_circuit(inner) {
598 self.visit_function_call(inner)
602 }
603 }
604 ExprImpl::FunctionCallWithLambda(inner) => {
605 self.visit_function_call_with_lambda(inner)
606 }
607 }
608 }
609 }
610
611 impl HasOthers {
612 fn is_short_circuit(&self, func_call: &FunctionCall) -> bool {
613 fn eval_first(e: &ExprImpl, expect: bool) -> bool {
615 if let ExprImpl::Literal(l) = e {
616 *l.get_data() == Some(ScalarImpl::Bool(expect))
617 } else {
618 false
619 }
620 }
621
622 match func_call.func_type {
623 ExprType::Or => eval_first(&func_call.inputs()[0], true),
624 ExprType::And => eval_first(&func_call.inputs()[0], false),
625 _ => false,
626 }
627 }
628 }
629
630 let mut visitor = HasOthers { has_others: false };
631 visitor.visit_expr(self);
632 !visitor.has_others
633 }
634 }
635
636 pub fn is_const(&self) -> bool {
640 self.only_literal_and_func() && self.is_pure()
641 }
642
643 pub fn as_eq_cond(&self) -> Option<(InputRef, InputRef)> {
646 if let ExprImpl::FunctionCall(function_call) = self
647 && function_call.func_type() == ExprType::Equal
648 && let (_, ExprImpl::InputRef(x), ExprImpl::InputRef(y)) =
649 function_call.clone().decompose_as_binary()
650 {
651 if x.index() < y.index() {
652 Some((*x, *y))
653 } else {
654 Some((*y, *x))
655 }
656 } else {
657 None
658 }
659 }
660
661 pub fn as_is_not_distinct_from_cond(&self) -> Option<(InputRef, InputRef)> {
662 if let ExprImpl::FunctionCall(function_call) = self
663 && function_call.func_type() == ExprType::IsNotDistinctFrom
664 && let (_, ExprImpl::InputRef(x), ExprImpl::InputRef(y)) =
665 function_call.clone().decompose_as_binary()
666 {
667 if x.index() < y.index() {
668 Some((*x, *y))
669 } else {
670 Some((*y, *x))
671 }
672 } else {
673 None
674 }
675 }
676
677 pub fn reverse_comparison(comparison: ExprType) -> ExprType {
678 match comparison {
679 ExprType::LessThan => ExprType::GreaterThan,
680 ExprType::LessThanOrEqual => ExprType::GreaterThanOrEqual,
681 ExprType::GreaterThan => ExprType::LessThan,
682 ExprType::GreaterThanOrEqual => ExprType::LessThanOrEqual,
683 ExprType::Equal | ExprType::IsNotDistinctFrom => comparison,
684 _ => unreachable!(),
685 }
686 }
687
688 pub fn as_comparison_cond(&self) -> Option<(InputRef, ExprType, InputRef)> {
689 if let ExprImpl::FunctionCall(function_call) = self {
690 match function_call.func_type() {
691 ty @ (ExprType::LessThan
692 | ExprType::LessThanOrEqual
693 | ExprType::GreaterThan
694 | ExprType::GreaterThanOrEqual) => {
695 let (_, op1, op2) = function_call.clone().decompose_as_binary();
696 if let (ExprImpl::InputRef(x), ExprImpl::InputRef(y)) = (op1, op2) {
697 if x.index < y.index {
698 Some((*x, ty, *y))
699 } else {
700 Some((*y, Self::reverse_comparison(ty), *x))
701 }
702 } else {
703 None
704 }
705 }
706 _ => None,
707 }
708 } else {
709 None
710 }
711 }
712
713 pub fn as_now_comparison_cond(&self) -> Option<(ExprImpl, ExprType, ExprImpl)> {
719 if let ExprImpl::FunctionCall(function_call) = self {
720 match function_call.func_type() {
721 ty @ (ExprType::Equal
722 | ExprType::LessThan
723 | ExprType::LessThanOrEqual
724 | ExprType::GreaterThan
725 | ExprType::GreaterThanOrEqual) => {
726 let (_, op1, op2) = function_call.clone().decompose_as_binary();
727 if !op1.has_now()
728 && op1.has_input_ref()
729 && op2.has_now()
730 && !op2.has_input_ref()
731 {
732 Some((op1, ty, op2))
733 } else if op1.has_now()
734 && !op1.has_input_ref()
735 && !op2.has_now()
736 && op2.has_input_ref()
737 {
738 Some((op2, Self::reverse_comparison(ty), op1))
739 } else {
740 None
741 }
742 }
743 _ => None,
744 }
745 } else {
746 None
747 }
748 }
749
750 pub(crate) fn as_input_comparison_cond(&self) -> Option<InequalityInputPair> {
753 if let ExprImpl::FunctionCall(function_call) = self {
754 match function_call.func_type() {
755 ty @ (ExprType::LessThan
756 | ExprType::LessThanOrEqual
757 | ExprType::GreaterThan
758 | ExprType::GreaterThanOrEqual) => {
759 let (_, mut op1, mut op2) = function_call.clone().decompose_as_binary();
760 if matches!(ty, ExprType::LessThan | ExprType::LessThanOrEqual) {
761 std::mem::swap(&mut op1, &mut op2);
762 }
763 if let (Some((lft_input, lft_offset)), Some((rht_input, rht_offset))) =
764 (op1.as_input_offset(), op2.as_input_offset())
765 {
766 match (lft_offset, rht_offset) {
767 (Some(_), Some(_)) => None,
768 (None, rht_offset @ Some(_)) => {
769 Some(InequalityInputPair::new(lft_input, rht_input, rht_offset))
770 }
771 (Some((operator, operand)), None) => Some(InequalityInputPair::new(
772 lft_input,
773 rht_input,
774 Some((
775 if operator == ExprType::Add {
776 ExprType::Subtract
777 } else {
778 ExprType::Add
779 },
780 operand,
781 )),
782 )),
783 (None, None) => {
784 Some(InequalityInputPair::new(lft_input, rht_input, None))
785 }
786 }
787 } else {
788 None
789 }
790 }
791 _ => None,
792 }
793 } else {
794 None
795 }
796 }
797
798 fn as_input_offset(&self) -> Option<(usize, Option<(ExprType, ExprImpl)>)> {
801 match self {
802 ExprImpl::InputRef(input_ref) => Some((input_ref.index(), None)),
803 ExprImpl::FunctionCall(function_call) => {
804 let expr_type = function_call.func_type();
805 match expr_type {
806 ExprType::Add | ExprType::Subtract => {
807 let (_, lhs, rhs) = function_call.clone().decompose_as_binary();
808 if let ExprImpl::InputRef(input_ref) = &lhs
809 && rhs.is_const()
810 {
811 if rhs.return_type() == DataType::Interval
814 && rhs.as_literal().is_none_or(|literal| {
815 literal.get_data().as_ref().is_some_and(|scalar| {
816 let interval = scalar.as_interval();
817 interval.months() != 0 || interval.days() != 0
818 })
819 })
820 {
821 None
822 } else {
823 Some((input_ref.index(), Some((expr_type, rhs))))
824 }
825 } else {
826 None
827 }
828 }
829 _ => None,
830 }
831 }
832 _ => None,
833 }
834 }
835
836 pub fn as_eq_const(&self) -> Option<(InputRef, ExprImpl)> {
837 if let ExprImpl::FunctionCall(function_call) = self
838 && function_call.func_type() == ExprType::Equal
839 {
840 match function_call.clone().decompose_as_binary() {
841 (_, ExprImpl::InputRef(x), y) if y.is_const() => Some((*x, y)),
842 (_, x, ExprImpl::InputRef(y)) if x.is_const() => Some((*y, x)),
843 _ => None,
844 }
845 } else {
846 None
847 }
848 }
849
850 pub fn as_eq_correlated_input_ref(&self) -> Option<(InputRef, CorrelatedInputRef)> {
851 if let ExprImpl::FunctionCall(function_call) = self
852 && function_call.func_type() == ExprType::Equal
853 {
854 match function_call.clone().decompose_as_binary() {
855 (_, ExprImpl::InputRef(x), ExprImpl::CorrelatedInputRef(y)) => Some((*x, *y)),
856 (_, ExprImpl::CorrelatedInputRef(x), ExprImpl::InputRef(y)) => Some((*y, *x)),
857 _ => None,
858 }
859 } else {
860 None
861 }
862 }
863
864 pub fn as_is_null(&self) -> Option<InputRef> {
865 if let ExprImpl::FunctionCall(function_call) = self
866 && function_call.func_type() == ExprType::IsNull
867 {
868 match function_call.clone().decompose_as_unary() {
869 (_, ExprImpl::InputRef(x)) => Some(*x),
870 _ => None,
871 }
872 } else {
873 None
874 }
875 }
876
877 pub fn as_comparison_const(&self) -> Option<(InputRef, ExprType, ExprImpl)> {
878 fn reverse_comparison(comparison: ExprType) -> ExprType {
879 match comparison {
880 ExprType::LessThan => ExprType::GreaterThan,
881 ExprType::LessThanOrEqual => ExprType::GreaterThanOrEqual,
882 ExprType::GreaterThan => ExprType::LessThan,
883 ExprType::GreaterThanOrEqual => ExprType::LessThanOrEqual,
884 _ => unreachable!(),
885 }
886 }
887
888 if let ExprImpl::FunctionCall(function_call) = self {
889 match function_call.func_type() {
890 ty @ (ExprType::LessThan
891 | ExprType::LessThanOrEqual
892 | ExprType::GreaterThan
893 | ExprType::GreaterThanOrEqual) => {
894 let (_, op1, op2) = function_call.clone().decompose_as_binary();
895 match (op1, op2) {
896 (ExprImpl::InputRef(x), y) if y.is_const() => Some((*x, ty, y)),
897 (x, ExprImpl::InputRef(y)) if x.is_const() => {
898 Some((*y, reverse_comparison(ty), x))
899 }
900 _ => None,
901 }
902 }
903 _ => None,
904 }
905 } else {
906 None
907 }
908 }
909
910 pub fn as_in_const_list(&self) -> Option<(InputRef, Vec<ExprImpl>)> {
911 if let ExprImpl::FunctionCall(function_call) = self
912 && function_call.func_type() == ExprType::In
913 {
914 let mut inputs = function_call.inputs().iter().cloned();
915 let input_ref = match inputs.next().unwrap() {
916 ExprImpl::InputRef(i) => *i,
917 _ => return None,
918 };
919 let list: Vec<_> = inputs
920 .inspect(|expr| {
921 assert!(expr.is_const());
923 })
924 .collect();
925
926 Some((input_ref, list))
927 } else {
928 None
929 }
930 }
931
932 pub fn as_or_disjunctions(&self) -> Option<Vec<ExprImpl>> {
933 if let ExprImpl::FunctionCall(function_call) = self
934 && function_call.func_type() == ExprType::Or
935 {
936 Some(to_disjunctions(self.clone()))
937 } else {
938 None
939 }
940 }
941
942 pub fn to_project_set_select_item_proto(&self) -> ProjectSetSelectItem {
943 use risingwave_pb::expr::project_set_select_item::SelectItem::*;
944
945 ProjectSetSelectItem {
946 select_item: Some(match self {
947 ExprImpl::TableFunction(tf) => TableFunction(tf.to_protobuf()),
948 expr => Expr(expr.to_expr_proto()),
949 }),
950 }
951 }
952
953 pub fn from_expr_proto(proto: &ExprNode) -> RwResult<Self> {
954 let rex_node = proto.get_rex_node()?;
955 let ret_type = proto.get_return_type()?.into();
956
957 Ok(match rex_node {
958 RexNode::InputRef(column_index) => Self::InputRef(Box::new(InputRef::from_expr_proto(
959 *column_index as _,
960 ret_type,
961 )?)),
962 RexNode::Constant(_) => Self::Literal(Box::new(Literal::from_expr_proto(proto)?)),
963 RexNode::Udf(udf) => Self::UserDefinedFunction(Box::new(
964 UserDefinedFunction::from_expr_proto(udf, ret_type)?,
965 )),
966 RexNode::FuncCall(function_call) => {
967 Self::FunctionCall(Box::new(FunctionCall::from_expr_proto(
968 function_call,
969 proto.get_function_type()?, ret_type,
971 )?))
972 }
973 RexNode::Now(_) => Self::Now(Box::new(Now {})),
974 })
975 }
976}
977
978impl From<Condition> for ExprImpl {
979 fn from(c: Condition) -> Self {
980 ExprImpl::and(c.conjunctions)
981 }
982}
983
984impl std::fmt::Debug for ExprImpl {
988 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
989 if f.alternate() {
990 return match self {
991 Self::InputRef(arg0) => f.debug_tuple("InputRef").field(arg0).finish(),
992 Self::Literal(arg0) => f.debug_tuple("Literal").field(arg0).finish(),
993 Self::FunctionCall(arg0) => f.debug_tuple("FunctionCall").field(arg0).finish(),
994 Self::FunctionCallWithLambda(arg0) => {
995 f.debug_tuple("FunctionCallWithLambda").field(arg0).finish()
996 }
997 Self::AggCall(arg0) => f.debug_tuple("AggCall").field(arg0).finish(),
998 Self::Subquery(arg0) => f.debug_tuple("Subquery").field(arg0).finish(),
999 Self::CorrelatedInputRef(arg0) => {
1000 f.debug_tuple("CorrelatedInputRef").field(arg0).finish()
1001 }
1002 Self::TableFunction(arg0) => f.debug_tuple("TableFunction").field(arg0).finish(),
1003 Self::WindowFunction(arg0) => f.debug_tuple("WindowFunction").field(arg0).finish(),
1004 Self::UserDefinedFunction(arg0) => {
1005 f.debug_tuple("UserDefinedFunction").field(arg0).finish()
1006 }
1007 Self::Parameter(arg0) => f.debug_tuple("Parameter").field(arg0).finish(),
1008 Self::Now(_) => f.debug_tuple("Now").finish(),
1009 };
1010 }
1011 match self {
1012 Self::InputRef(x) => write!(f, "{:?}", x),
1013 Self::Literal(x) => write!(f, "{:?}", x),
1014 Self::FunctionCall(x) => write!(f, "{:?}", x),
1015 Self::FunctionCallWithLambda(x) => write!(f, "{:?}", x),
1016 Self::AggCall(x) => write!(f, "{:?}", x),
1017 Self::Subquery(x) => write!(f, "{:?}", x),
1018 Self::CorrelatedInputRef(x) => write!(f, "{:?}", x),
1019 Self::TableFunction(x) => write!(f, "{:?}", x),
1020 Self::WindowFunction(x) => write!(f, "{:?}", x),
1021 Self::UserDefinedFunction(x) => write!(f, "{:?}", x),
1022 Self::Parameter(x) => write!(f, "{:?}", x),
1023 Self::Now(x) => write!(f, "{:?}", x),
1024 }
1025 }
1026}
1027
/// An expression paired with the schema of its input, for rendering.
///
/// The `Debug`/`Display` impls use `input_schema` when rendering `InputRef`, function call and
/// UDF nodes (through their `*Display` helpers).
pub struct ExprDisplay<'a> {
    /// The expression to render.
    pub expr: &'a ExprImpl,
    /// Schema that the expression's input references point into.
    pub input_schema: &'a Schema,
}
1032
1033impl std::fmt::Debug for ExprDisplay<'_> {
1034 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1035 let that = self.expr;
1036 match that {
1037 ExprImpl::InputRef(x) => write!(
1038 f,
1039 "{:?}",
1040 InputRefDisplay {
1041 input_ref: x,
1042 input_schema: self.input_schema
1043 }
1044 ),
1045 ExprImpl::Literal(x) => write!(f, "{:?}", x),
1046 ExprImpl::FunctionCall(x) => write!(
1047 f,
1048 "{:?}",
1049 FunctionCallDisplay {
1050 function_call: x,
1051 input_schema: self.input_schema
1052 }
1053 ),
1054 ExprImpl::FunctionCallWithLambda(x) => write!(
1055 f,
1056 "{:?}",
1057 FunctionCallDisplay {
1058 function_call: &x.to_full_function_call(),
1059 input_schema: self.input_schema
1060 }
1061 ),
1062 ExprImpl::AggCall(x) => write!(f, "{:?}", x),
1063 ExprImpl::Subquery(x) => write!(f, "{:?}", x),
1064 ExprImpl::CorrelatedInputRef(x) => write!(f, "{:?}", x),
1065 ExprImpl::TableFunction(x) => {
1066 write!(f, "{:?}", x)
1068 }
1069 ExprImpl::WindowFunction(x) => {
1070 write!(f, "{:?}", x)
1072 }
1073 ExprImpl::UserDefinedFunction(x) => {
1074 write!(
1075 f,
1076 "{:?}",
1077 UserDefinedFunctionDisplay {
1078 func_call: x,
1079 input_schema: self.input_schema
1080 }
1081 )
1082 }
1083 ExprImpl::Parameter(x) => write!(f, "{:?}", x),
1084 ExprImpl::Now(x) => write!(f, "{:?}", x),
1085 }
1086 }
1087}
1088
1089impl std::fmt::Display for ExprDisplay<'_> {
1090 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1091 (self as &dyn std::fmt::Debug).fmt(f)
1092 }
1093}
1094
/// Test helper: asserts that `$e` is an `ExprImpl::InputRef` whose index equals `$index`.
///
/// `$e` is evaluated exactly once. Any other variant panics with the value's `Debug` form.
#[cfg(test)]
macro_rules! assert_eq_input_ref {
    ($e:expr, $index:expr) => {
        match $e {
            ExprImpl::InputRef(i) => assert_eq!(i.index(), $index),
            // Bind the scrutinee instead of re-evaluating `$e` for the message.
            other => panic!("Expected input ref, found {:?}", other),
        }
    };
}
1105
1106#[cfg(test)]
1107pub(crate) use assert_eq_input_ref;
1108use risingwave_common::bail;
1109use risingwave_common::catalog::Schema;
1110use risingwave_common::row::OwnedRow;
1111
1112use crate::utils::Condition;
1113
#[cfg(test)]
mod tests {
    use super::*;

    /// The alternate (`{:#?}`) rendering must expose the nested node's fields.
    #[test]
    fn test_expr_debug_alternate() {
        let input: ExprImpl = InputRef::new(1, DataType::Boolean).into();
        let negated: ExprImpl = FunctionCall::new(ExprType::Not, vec![input])
            .unwrap()
            .into();
        let rendered = format!("{negated:#?}");
        assert!(rendered.contains("return_type: Boolean"));
    }
}