1use enum_as_inner::EnumAsInner;
16use fixedbitset::FixedBitSet;
17use futures::FutureExt;
18use paste::paste;
19use risingwave_common::array::ListValue;
20use risingwave_common::types::{DataType, Datum, JsonbVal, MapType, Scalar, ScalarImpl};
21use risingwave_expr::aggregate::PbAggKind;
22use risingwave_expr::expr::build_from_prost;
23use risingwave_pb::expr::expr_node::RexNode;
24use risingwave_pb::expr::{ExprNode, ProjectSetSelectItem};
25use user_defined_function::UserDefinedFunctionDisplay;
26
27use crate::error::{ErrorCode, Result as RwResult};
28
29mod agg_call;
30mod correlated_input_ref;
31mod function_call;
32mod function_call_with_lambda;
33mod input_ref;
34mod literal;
35mod now;
36mod parameter;
37mod pure;
38mod subquery;
39mod table_function;
40mod user_defined_function;
41mod window_function;
42
43mod order_by_expr;
44pub use order_by_expr::{OrderBy, OrderByExpr};
45
46mod expr_mutator;
47mod expr_rewriter;
48mod expr_visitor;
49pub mod function_impl;
50mod session_timezone;
51mod type_inference;
52mod utils;
53
54pub use agg_call::AggCall;
55pub use correlated_input_ref::{CorrelatedId, CorrelatedInputRef, Depth};
56pub use expr_mutator::ExprMutator;
57pub use expr_rewriter::{ExprRewriter, default_rewrite_expr};
58pub use expr_visitor::{ExprVisitor, default_visit_expr};
59pub use function_call::{FunctionCall, FunctionCallDisplay, is_row_function};
60pub use function_call_with_lambda::FunctionCallWithLambda;
61pub use input_ref::{InputRef, InputRefDisplay, input_ref_to_column_indices};
62pub use literal::Literal;
63pub use now::{InlineNowProcTime, Now, NowProcTimeFinder};
64pub use parameter::Parameter;
65pub use pure::*;
66pub use risingwave_pb::expr::expr_node::Type as ExprType;
67pub use session_timezone::{SessionTimezone, TimestamptzExprFinder};
68pub use subquery::{Subquery, SubqueryKind};
69pub use table_function::{TableFunction, TableFunctionType};
70pub use type_inference::*;
71pub use user_defined_function::UserDefinedFunction;
72pub use utils::*;
73pub use window_function::WindowFunction;
74
/// Expression nesting depth threshold.
///
/// Presumably the depth beyond which [`EXPR_TOO_DEEP_NOTICE`] is reported to the user; the
/// code that reads it is outside this chunk — TODO confirm.
const EXPR_DEPTH_THRESHOLD: usize = 30;

/// User-facing notice for overly complicated (deeply nested) expressions.
///
/// The trailing `\` continues the literal onto the next line; Rust skips the newline and any
/// leading whitespace, so the text reads as one sentence pair.
const EXPR_TOO_DEEP_NOTICE: &str = "Some expression is too complicated. \
    Consider simplifying or splitting the query if you encounter any issues.";
78
79pub trait Expr: Into<ExprImpl> {
81 fn return_type(&self) -> DataType;
83
84 fn to_expr_proto(&self) -> ExprNode;
86}
87
88macro_rules! impl_expr_impl {
89 ($($t:ident,)*) => {
90 #[derive(Clone, Eq, PartialEq, Hash, EnumAsInner)]
91 pub enum ExprImpl {
92 $($t(Box<$t>),)*
93 }
94
95 impl ExprImpl {
96 pub fn variant_name(&self) -> &'static str {
97 match self {
98 $(ExprImpl::$t(_) => stringify!($t),)*
99 }
100 }
101 }
102
103 $(
104 impl From<$t> for ExprImpl {
105 fn from(o: $t) -> ExprImpl {
106 ExprImpl::$t(Box::new(o))
107 }
108 })*
109
110 impl Expr for ExprImpl {
111 fn return_type(&self) -> DataType {
112 match self {
113 $(ExprImpl::$t(expr) => expr.return_type(),)*
114 }
115 }
116
117 fn to_expr_proto(&self) -> ExprNode {
118 match self {
119 $(ExprImpl::$t(expr) => expr.to_expr_proto(),)*
120 }
121 }
122 }
123 };
124}
125
126impl_expr_impl!(
127 CorrelatedInputRef,
129 InputRef,
130 Literal,
131 FunctionCall,
132 FunctionCallWithLambda,
133 AggCall,
134 Subquery,
135 TableFunction,
136 WindowFunction,
137 UserDefinedFunction,
138 Parameter,
139 Now,
140);
141
142impl ExprImpl {
143 #[inline(always)]
145 pub fn literal_int(v: i32) -> Self {
146 Literal::new(Some(v.to_scalar_value()), DataType::Int32).into()
147 }
148
149 #[inline(always)]
151 pub fn literal_bigint(v: i64) -> Self {
152 Literal::new(Some(v.to_scalar_value()), DataType::Int64).into()
153 }
154
155 #[inline(always)]
157 pub fn literal_f64(v: f64) -> Self {
158 Literal::new(Some(v.into()), DataType::Float64).into()
159 }
160
161 #[inline(always)]
163 pub fn literal_bool(v: bool) -> Self {
164 Literal::new(Some(v.to_scalar_value()), DataType::Boolean).into()
165 }
166
167 #[inline(always)]
169 pub fn literal_varchar(v: String) -> Self {
170 Literal::new(Some(v.into()), DataType::Varchar).into()
171 }
172
173 #[inline(always)]
175 pub fn literal_null(element_type: DataType) -> Self {
176 Literal::new(None, element_type).into()
177 }
178
179 #[inline(always)]
181 pub fn literal_jsonb(v: JsonbVal) -> Self {
182 Literal::new(Some(v.into()), DataType::Jsonb).into()
183 }
184
185 #[inline(always)]
187 pub fn literal_list(v: ListValue, element_type: DataType) -> Self {
188 Literal::new(
189 Some(v.to_scalar_value()),
190 DataType::List(Box::new(element_type)),
191 )
192 .into()
193 }
194
195 pub fn take(&mut self) -> Self {
197 std::mem::replace(self, Self::literal_null(self.return_type()))
198 }
199
200 #[inline(always)]
202 pub fn count_star() -> Self {
203 AggCall::new(
204 PbAggKind::Count.into(),
205 vec![],
206 false,
207 OrderBy::any(),
208 Condition::true_cond(),
209 vec![],
210 )
211 .unwrap()
212 .into()
213 }
214
215 pub fn and(exprs: impl IntoIterator<Item = ExprImpl>) -> Self {
219 merge_expr_by_logical(exprs, ExprType::And, ExprImpl::literal_bool(true))
220 }
221
222 pub fn or(exprs: impl IntoIterator<Item = ExprImpl>) -> Self {
226 merge_expr_by_logical(exprs, ExprType::Or, ExprImpl::literal_bool(false))
227 }
228
229 pub fn collect_input_refs(&self, input_col_num: usize) -> FixedBitSet {
234 collect_input_refs(input_col_num, [self])
235 }
236
237 pub fn is_pure(&self) -> bool {
239 is_pure(self)
240 }
241
242 pub fn is_impure(&self) -> bool {
243 is_impure(self)
244 }
245
246 pub fn count_nows(&self) -> usize {
248 let mut visitor = CountNow::default();
249 visitor.visit_expr(self);
250 visitor.count()
251 }
252
253 pub fn is_null(&self) -> bool {
255 matches!(self, ExprImpl::Literal(literal) if literal.get_data().is_none())
256 }
257
258 pub fn is_untyped(&self) -> bool {
260 matches!(self, ExprImpl::Literal(literal) if literal.is_untyped())
261 || matches!(self, ExprImpl::Parameter(parameter) if !parameter.has_infer())
262 }
263
264 pub fn cast_implicit(mut self, target: DataType) -> Result<ExprImpl, CastError> {
266 FunctionCall::cast_mut(&mut self, target, CastContext::Implicit)?;
267 Ok(self)
268 }
269
270 pub fn cast_assign(mut self, target: DataType) -> Result<ExprImpl, CastError> {
272 FunctionCall::cast_mut(&mut self, target, CastContext::Assign)?;
273 Ok(self)
274 }
275
276 pub fn cast_explicit(mut self, target: DataType) -> Result<ExprImpl, CastError> {
278 FunctionCall::cast_mut(&mut self, target, CastContext::Explicit)?;
279 Ok(self)
280 }
281
282 pub fn cast_implicit_mut(&mut self, target: DataType) -> Result<(), CastError> {
284 FunctionCall::cast_mut(self, target, CastContext::Implicit)
285 }
286
287 pub fn cast_explicit_mut(&mut self, target: DataType) -> Result<(), CastError> {
289 FunctionCall::cast_mut(self, target, CastContext::Explicit)
290 }
291
292 pub fn cast_to_regclass(self) -> Result<ExprImpl, CastError> {
295 match self.return_type() {
296 DataType::Varchar => Ok(ExprImpl::FunctionCall(Box::new(
297 FunctionCall::new_unchecked(ExprType::CastRegclass, vec![self], DataType::Int32),
298 ))),
299 DataType::Int32 => Ok(self),
300 dt if dt.is_int() => Ok(self.cast_explicit(DataType::Int32)?),
301 _ => bail_cast_error!("unsupported input type"),
302 }
303 }
304
305 pub fn cast_to_regclass_mut(&mut self) -> Result<(), CastError> {
307 let owned = std::mem::replace(self, ExprImpl::literal_bool(false));
308 *self = owned.cast_to_regclass()?;
309 Ok(())
310 }
311
312 pub fn ensure_array_type(&self) -> Result<(), ErrorCode> {
314 if self.is_untyped() {
315 return Err(ErrorCode::BindError(
316 "could not determine polymorphic type because input has type unknown".into(),
317 ));
318 }
319 match self.return_type() {
320 DataType::List(_) => Ok(()),
321 t => Err(ErrorCode::BindError(format!("expects array but got {t}"))),
322 }
323 }
324
325 pub fn try_into_map_type(&self) -> Result<MapType, ErrorCode> {
327 if self.is_untyped() {
328 return Err(ErrorCode::BindError(
329 "could not determine polymorphic type because input has type unknown".into(),
330 ));
331 }
332 match self.return_type() {
333 DataType::Map(m) => Ok(m),
334 t => Err(ErrorCode::BindError(format!("expects map but got {t}"))),
335 }
336 }
337
338 pub fn enforce_bool_clause(self, clause: &str) -> RwResult<ExprImpl> {
340 if self.is_untyped() {
341 let inner = self.cast_implicit(DataType::Boolean)?;
342 return Ok(inner);
343 }
344 let return_type = self.return_type();
345 if return_type != DataType::Boolean {
346 bail!(
347 "argument of {} must be boolean, not type {:?}",
348 clause,
349 return_type
350 )
351 }
352 Ok(self)
353 }
354
355 pub fn cast_output(self) -> RwResult<ExprImpl> {
366 if self.return_type() == DataType::Boolean {
367 return Ok(FunctionCall::new(ExprType::BoolOut, vec![self])?.into());
368 }
369 self.cast_assign(DataType::Varchar)
372 .map_err(|err| err.into())
373 }
374
375 pub async fn eval_row(&self, input: &OwnedRow) -> RwResult<Datum> {
380 let backend_expr = build_from_prost(&self.to_expr_proto())?;
381 Ok(backend_expr.eval_row(input).await?)
382 }
383
384 pub fn try_fold_const(&self) -> Option<RwResult<Datum>> {
391 if self.is_const() {
392 self.eval_row(&OwnedRow::empty())
393 .now_or_never()
394 .expect("constant expression should not be async")
395 .into()
396 } else {
397 None
398 }
399 }
400
401 pub fn fold_const(&self) -> RwResult<Datum> {
403 self.try_fold_const().expect("expression is not constant")
404 }
405}
406
407macro_rules! impl_has_variant {
412 ( $($variant:ty),* ) => {
413 paste! {
414 impl ExprImpl {
415 $(
416 pub fn [<has_ $variant:snake>](&self) -> bool {
417 struct Has { has: bool }
418
419 impl ExprVisitor for Has {
420 fn [<visit_ $variant:snake>](&mut self, _: &$variant) {
421 self.has = true;
422 }
423 }
424
425 let mut visitor = Has { has: false };
426 visitor.visit_expr(self);
427 visitor.has
428 }
429 )*
430 }
431 }
432 };
433}
434
435impl_has_variant! {InputRef, Literal, FunctionCall, FunctionCallWithLambda, AggCall, Subquery, TableFunction, WindowFunction, Now}
436
437#[derive(Debug, Clone, PartialEq, Eq, Hash)]
438pub struct InequalityInputPair {
439 pub(crate) key_required_larger: usize,
441 pub(crate) key_required_smaller: usize,
443 pub(crate) delta_expression: Option<(ExprType, ExprImpl)>,
445}
446
447impl InequalityInputPair {
448 fn new(
449 key_required_larger: usize,
450 key_required_smaller: usize,
451 delta_expression: Option<(ExprType, ExprImpl)>,
452 ) -> Self {
453 Self {
454 key_required_larger,
455 key_required_smaller,
456 delta_expression,
457 }
458 }
459}
460
461impl ExprImpl {
462 pub fn has_correlated_input_ref(&self, _: std::convert::Infallible) -> bool {
473 unreachable!()
474 }
475
476 pub fn has_correlated_input_ref_by_depth(&self, depth: Depth) -> bool {
482 struct Has {
483 depth: usize,
484 has: bool,
485 }
486
487 impl ExprVisitor for Has {
488 fn visit_correlated_input_ref(&mut self, correlated_input_ref: &CorrelatedInputRef) {
489 if correlated_input_ref.depth() == self.depth {
490 self.has = true;
491 }
492 }
493
494 fn visit_subquery(&mut self, subquery: &Subquery) {
495 self.depth += 1;
496 self.visit_bound_set_expr(&subquery.query.body);
497 self.depth -= 1;
498 }
499 }
500
501 impl Has {
502 fn visit_bound_set_expr(&mut self, set_expr: &BoundSetExpr) {
503 match set_expr {
504 BoundSetExpr::Select(select) => {
505 select.exprs().for_each(|expr| self.visit_expr(expr));
506 match select.from.as_ref() {
507 Some(from) => from.is_correlated(self.depth),
508 None => false,
509 };
510 }
511 BoundSetExpr::Values(values) => {
512 values.exprs().for_each(|expr| self.visit_expr(expr))
513 }
514 BoundSetExpr::Query(query) => {
515 self.depth += 1;
516 self.visit_bound_set_expr(&query.body);
517 self.depth -= 1;
518 }
519 BoundSetExpr::SetOperation { left, right, .. } => {
520 self.visit_bound_set_expr(left);
521 self.visit_bound_set_expr(right);
522 }
523 };
524 }
525 }
526
527 let mut visitor = Has { depth, has: false };
528 visitor.visit_expr(self);
529 visitor.has
530 }
531
532 pub fn has_correlated_input_ref_by_correlated_id(&self, correlated_id: CorrelatedId) -> bool {
533 struct Has {
534 correlated_id: CorrelatedId,
535 has: bool,
536 }
537
538 impl ExprVisitor for Has {
539 fn visit_correlated_input_ref(&mut self, correlated_input_ref: &CorrelatedInputRef) {
540 if correlated_input_ref.correlated_id() == self.correlated_id {
541 self.has = true;
542 }
543 }
544
545 fn visit_subquery(&mut self, subquery: &Subquery) {
546 self.visit_bound_set_expr(&subquery.query.body);
547 }
548 }
549
550 impl Has {
551 fn visit_bound_set_expr(&mut self, set_expr: &BoundSetExpr) {
552 match set_expr {
553 BoundSetExpr::Select(select) => {
554 select.exprs().for_each(|expr| self.visit_expr(expr))
555 }
556 BoundSetExpr::Values(values) => {
557 values.exprs().for_each(|expr| self.visit_expr(expr));
558 }
559 BoundSetExpr::Query(query) => self.visit_bound_set_expr(&query.body),
560 BoundSetExpr::SetOperation { left, right, .. } => {
561 self.visit_bound_set_expr(left);
562 self.visit_bound_set_expr(right);
563 }
564 }
565 }
566 }
567
568 let mut visitor = Has {
569 correlated_id,
570 has: false,
571 };
572 visitor.visit_expr(self);
573 visitor.has
574 }
575
576 pub fn collect_correlated_indices_by_depth_and_assign_id(
579 &mut self,
580 depth: Depth,
581 correlated_id: CorrelatedId,
582 ) -> Vec<usize> {
583 struct Collector {
584 depth: Depth,
585 correlated_indices: Vec<usize>,
586 correlated_id: CorrelatedId,
587 }
588
589 impl ExprMutator for Collector {
590 fn visit_correlated_input_ref(
591 &mut self,
592 correlated_input_ref: &mut CorrelatedInputRef,
593 ) {
594 if correlated_input_ref.depth() == self.depth {
595 self.correlated_indices.push(correlated_input_ref.index());
596 correlated_input_ref.set_correlated_id(self.correlated_id);
597 }
598 }
599
600 fn visit_subquery(&mut self, subquery: &mut Subquery) {
601 self.depth += 1;
602 self.visit_bound_set_expr(&mut subquery.query.body);
603 self.depth -= 1;
604 }
605 }
606
607 impl Collector {
608 fn visit_bound_set_expr(&mut self, set_expr: &mut BoundSetExpr) {
609 match set_expr {
610 BoundSetExpr::Select(select) => {
611 select.exprs_mut().for_each(|expr| self.visit_expr(expr));
612 if let Some(from) = select.from.as_mut() {
613 self.correlated_indices.extend(
614 from.collect_correlated_indices_by_depth_and_assign_id(
615 self.depth,
616 self.correlated_id,
617 ),
618 );
619 };
620 }
621 BoundSetExpr::Values(values) => {
622 values.exprs_mut().for_each(|expr| self.visit_expr(expr))
623 }
624 BoundSetExpr::Query(query) => {
625 self.depth += 1;
626 self.visit_bound_set_expr(&mut query.body);
627 self.depth -= 1;
628 }
629 BoundSetExpr::SetOperation { left, right, .. } => {
630 self.visit_bound_set_expr(&mut *left);
631 self.visit_bound_set_expr(&mut *right);
632 }
633 }
634 }
635 }
636
637 let mut collector = Collector {
638 depth,
639 correlated_indices: vec![],
640 correlated_id,
641 };
642 collector.visit_expr(self);
643 collector.correlated_indices
644 }
645
646 pub fn is_const(&self) -> bool {
650 let only_literal_and_func = {
651 struct HasOthers {
652 has_others: bool,
653 }
654
655 impl ExprVisitor for HasOthers {
656 fn visit_expr(&mut self, expr: &ExprImpl) {
657 match expr {
658 ExprImpl::CorrelatedInputRef(_)
659 | ExprImpl::InputRef(_)
660 | ExprImpl::AggCall(_)
661 | ExprImpl::Subquery(_)
662 | ExprImpl::TableFunction(_)
663 | ExprImpl::WindowFunction(_)
664 | ExprImpl::UserDefinedFunction(_)
665 | ExprImpl::Parameter(_)
666 | ExprImpl::Now(_) => self.has_others = true,
667 ExprImpl::Literal(_inner) => {}
668 ExprImpl::FunctionCall(inner) => {
669 if !self.is_short_circuit(inner) {
670 self.visit_function_call(inner)
674 }
675 }
676 ExprImpl::FunctionCallWithLambda(inner) => {
677 self.visit_function_call_with_lambda(inner)
678 }
679 }
680 }
681 }
682
683 impl HasOthers {
684 fn is_short_circuit(&self, func_call: &FunctionCall) -> bool {
685 fn eval_first(e: &ExprImpl, expect: bool) -> bool {
687 if let ExprImpl::Literal(l) = e {
688 *l.get_data() == Some(ScalarImpl::Bool(expect))
689 } else {
690 false
691 }
692 }
693
694 match func_call.func_type {
695 ExprType::Or => eval_first(&func_call.inputs()[0], true),
696 ExprType::And => eval_first(&func_call.inputs()[0], false),
697 _ => false,
698 }
699 }
700 }
701
702 let mut visitor = HasOthers { has_others: false };
703 visitor.visit_expr(self);
704 !visitor.has_others
705 };
706
707 let is_pure = self.is_pure();
708
709 only_literal_and_func && is_pure
710 }
711
712 pub fn as_eq_cond(&self) -> Option<(InputRef, InputRef)> {
715 if let ExprImpl::FunctionCall(function_call) = self
716 && function_call.func_type() == ExprType::Equal
717 && let (_, ExprImpl::InputRef(x), ExprImpl::InputRef(y)) =
718 function_call.clone().decompose_as_binary()
719 {
720 if x.index() < y.index() {
721 Some((*x, *y))
722 } else {
723 Some((*y, *x))
724 }
725 } else {
726 None
727 }
728 }
729
730 pub fn as_is_not_distinct_from_cond(&self) -> Option<(InputRef, InputRef)> {
731 if let ExprImpl::FunctionCall(function_call) = self
732 && function_call.func_type() == ExprType::IsNotDistinctFrom
733 && let (_, ExprImpl::InputRef(x), ExprImpl::InputRef(y)) =
734 function_call.clone().decompose_as_binary()
735 {
736 if x.index() < y.index() {
737 Some((*x, *y))
738 } else {
739 Some((*y, *x))
740 }
741 } else {
742 None
743 }
744 }
745
746 pub fn reverse_comparison(comparison: ExprType) -> ExprType {
747 match comparison {
748 ExprType::LessThan => ExprType::GreaterThan,
749 ExprType::LessThanOrEqual => ExprType::GreaterThanOrEqual,
750 ExprType::GreaterThan => ExprType::LessThan,
751 ExprType::GreaterThanOrEqual => ExprType::LessThanOrEqual,
752 ExprType::Equal | ExprType::IsNotDistinctFrom => comparison,
753 _ => unreachable!(),
754 }
755 }
756
757 pub fn as_comparison_cond(&self) -> Option<(InputRef, ExprType, InputRef)> {
758 if let ExprImpl::FunctionCall(function_call) = self {
759 match function_call.func_type() {
760 ty @ (ExprType::LessThan
761 | ExprType::LessThanOrEqual
762 | ExprType::GreaterThan
763 | ExprType::GreaterThanOrEqual) => {
764 let (_, op1, op2) = function_call.clone().decompose_as_binary();
765 if let (ExprImpl::InputRef(x), ExprImpl::InputRef(y)) = (op1, op2) {
766 if x.index < y.index {
767 Some((*x, ty, *y))
768 } else {
769 Some((*y, Self::reverse_comparison(ty), *x))
770 }
771 } else {
772 None
773 }
774 }
775 _ => None,
776 }
777 } else {
778 None
779 }
780 }
781
782 pub fn as_now_comparison_cond(&self) -> Option<(ExprImpl, ExprType, ExprImpl)> {
788 if let ExprImpl::FunctionCall(function_call) = self {
789 match function_call.func_type() {
790 ty @ (ExprType::Equal
791 | ExprType::LessThan
792 | ExprType::LessThanOrEqual
793 | ExprType::GreaterThan
794 | ExprType::GreaterThanOrEqual) => {
795 let (_, op1, op2) = function_call.clone().decompose_as_binary();
796 if !op1.has_now()
797 && op1.has_input_ref()
798 && op2.has_now()
799 && !op2.has_input_ref()
800 {
801 Some((op1, ty, op2))
802 } else if op1.has_now()
803 && !op1.has_input_ref()
804 && !op2.has_now()
805 && op2.has_input_ref()
806 {
807 Some((op2, Self::reverse_comparison(ty), op1))
808 } else {
809 None
810 }
811 }
812 _ => None,
813 }
814 } else {
815 None
816 }
817 }
818
819 pub(crate) fn as_input_comparison_cond(&self) -> Option<InequalityInputPair> {
822 if let ExprImpl::FunctionCall(function_call) = self {
823 match function_call.func_type() {
824 ty @ (ExprType::LessThan
825 | ExprType::LessThanOrEqual
826 | ExprType::GreaterThan
827 | ExprType::GreaterThanOrEqual) => {
828 let (_, mut op1, mut op2) = function_call.clone().decompose_as_binary();
829 if matches!(ty, ExprType::LessThan | ExprType::LessThanOrEqual) {
830 std::mem::swap(&mut op1, &mut op2);
831 }
832 if let (Some((lft_input, lft_offset)), Some((rht_input, rht_offset))) =
833 (op1.as_input_offset(), op2.as_input_offset())
834 {
835 match (lft_offset, rht_offset) {
836 (Some(_), Some(_)) => None,
837 (None, rht_offset @ Some(_)) => {
838 Some(InequalityInputPair::new(lft_input, rht_input, rht_offset))
839 }
840 (Some((operator, operand)), None) => Some(InequalityInputPair::new(
841 lft_input,
842 rht_input,
843 Some((
844 if operator == ExprType::Add {
845 ExprType::Subtract
846 } else {
847 ExprType::Add
848 },
849 operand,
850 )),
851 )),
852 (None, None) => {
853 Some(InequalityInputPair::new(lft_input, rht_input, None))
854 }
855 }
856 } else {
857 None
858 }
859 }
860 _ => None,
861 }
862 } else {
863 None
864 }
865 }
866
867 fn as_input_offset(&self) -> Option<(usize, Option<(ExprType, ExprImpl)>)> {
870 match self {
871 ExprImpl::InputRef(input_ref) => Some((input_ref.index(), None)),
872 ExprImpl::FunctionCall(function_call) => {
873 let expr_type = function_call.func_type();
874 match expr_type {
875 ExprType::Add | ExprType::Subtract => {
876 let (_, lhs, rhs) = function_call.clone().decompose_as_binary();
877 if let ExprImpl::InputRef(input_ref) = &lhs
878 && rhs.is_const()
879 {
880 if rhs.return_type() == DataType::Interval
883 && rhs.as_literal().is_none_or(|literal| {
884 literal.get_data().as_ref().is_some_and(|scalar| {
885 let interval = scalar.as_interval();
886 interval.months() != 0 || interval.days() != 0
887 })
888 })
889 {
890 None
891 } else {
892 Some((input_ref.index(), Some((expr_type, rhs))))
893 }
894 } else {
895 None
896 }
897 }
898 _ => None,
899 }
900 }
901 _ => None,
902 }
903 }
904
905 pub fn as_eq_const(&self) -> Option<(InputRef, ExprImpl)> {
906 if let ExprImpl::FunctionCall(function_call) = self
907 && function_call.func_type() == ExprType::Equal
908 {
909 match function_call.clone().decompose_as_binary() {
910 (_, ExprImpl::InputRef(x), y) if y.is_const() => Some((*x, y)),
911 (_, x, ExprImpl::InputRef(y)) if x.is_const() => Some((*y, x)),
912 _ => None,
913 }
914 } else {
915 None
916 }
917 }
918
919 pub fn as_eq_correlated_input_ref(&self) -> Option<(InputRef, CorrelatedInputRef)> {
920 if let ExprImpl::FunctionCall(function_call) = self
921 && function_call.func_type() == ExprType::Equal
922 {
923 match function_call.clone().decompose_as_binary() {
924 (_, ExprImpl::InputRef(x), ExprImpl::CorrelatedInputRef(y)) => Some((*x, *y)),
925 (_, ExprImpl::CorrelatedInputRef(x), ExprImpl::InputRef(y)) => Some((*y, *x)),
926 _ => None,
927 }
928 } else {
929 None
930 }
931 }
932
933 pub fn as_is_null(&self) -> Option<InputRef> {
934 if let ExprImpl::FunctionCall(function_call) = self
935 && function_call.func_type() == ExprType::IsNull
936 {
937 match function_call.clone().decompose_as_unary() {
938 (_, ExprImpl::InputRef(x)) => Some(*x),
939 _ => None,
940 }
941 } else {
942 None
943 }
944 }
945
946 pub fn as_comparison_const(&self) -> Option<(InputRef, ExprType, ExprImpl)> {
947 fn reverse_comparison(comparison: ExprType) -> ExprType {
948 match comparison {
949 ExprType::LessThan => ExprType::GreaterThan,
950 ExprType::LessThanOrEqual => ExprType::GreaterThanOrEqual,
951 ExprType::GreaterThan => ExprType::LessThan,
952 ExprType::GreaterThanOrEqual => ExprType::LessThanOrEqual,
953 _ => unreachable!(),
954 }
955 }
956
957 if let ExprImpl::FunctionCall(function_call) = self {
958 match function_call.func_type() {
959 ty @ (ExprType::LessThan
960 | ExprType::LessThanOrEqual
961 | ExprType::GreaterThan
962 | ExprType::GreaterThanOrEqual) => {
963 let (_, op1, op2) = function_call.clone().decompose_as_binary();
964 match (op1, op2) {
965 (ExprImpl::InputRef(x), y) if y.is_const() => Some((*x, ty, y)),
966 (x, ExprImpl::InputRef(y)) if x.is_const() => {
967 Some((*y, reverse_comparison(ty), x))
968 }
969 _ => None,
970 }
971 }
972 _ => None,
973 }
974 } else {
975 None
976 }
977 }
978
979 pub fn as_in_const_list(&self) -> Option<(InputRef, Vec<ExprImpl>)> {
980 if let ExprImpl::FunctionCall(function_call) = self
981 && function_call.func_type() == ExprType::In
982 {
983 let mut inputs = function_call.inputs().iter().cloned();
984 let input_ref = match inputs.next().unwrap() {
985 ExprImpl::InputRef(i) => *i,
986 _ => return None,
987 };
988 let list: Vec<_> = inputs
989 .inspect(|expr| {
990 assert!(expr.is_const());
992 })
993 .collect();
994
995 Some((input_ref, list))
996 } else {
997 None
998 }
999 }
1000
1001 pub fn as_or_disjunctions(&self) -> Option<Vec<ExprImpl>> {
1002 if let ExprImpl::FunctionCall(function_call) = self
1003 && function_call.func_type() == ExprType::Or
1004 {
1005 Some(to_disjunctions(self.clone()))
1006 } else {
1007 None
1008 }
1009 }
1010
1011 pub fn to_project_set_select_item_proto(&self) -> ProjectSetSelectItem {
1012 use risingwave_pb::expr::project_set_select_item::SelectItem::*;
1013
1014 ProjectSetSelectItem {
1015 select_item: Some(match self {
1016 ExprImpl::TableFunction(tf) => TableFunction(tf.to_protobuf()),
1017 expr => Expr(expr.to_expr_proto()),
1018 }),
1019 }
1020 }
1021
1022 pub fn from_expr_proto(proto: &ExprNode) -> RwResult<Self> {
1023 let rex_node = proto.get_rex_node()?;
1024 let ret_type = proto.get_return_type()?.into();
1025
1026 Ok(match rex_node {
1027 RexNode::InputRef(column_index) => Self::InputRef(Box::new(InputRef::from_expr_proto(
1028 *column_index as _,
1029 ret_type,
1030 )?)),
1031 RexNode::Constant(_) => Self::Literal(Box::new(Literal::from_expr_proto(proto)?)),
1032 RexNode::Udf(udf) => Self::UserDefinedFunction(Box::new(
1033 UserDefinedFunction::from_expr_proto(udf, ret_type)?,
1034 )),
1035 RexNode::FuncCall(function_call) => {
1036 Self::FunctionCall(Box::new(FunctionCall::from_expr_proto(
1037 function_call,
1038 proto.get_function_type()?, ret_type,
1040 )?))
1041 }
1042 RexNode::Now(_) => Self::Now(Box::new(Now {})),
1043 })
1044 }
1045}
1046
1047impl From<Condition> for ExprImpl {
1048 fn from(c: Condition) -> Self {
1049 ExprImpl::and(c.conjunctions)
1050 }
1051}
1052
1053impl std::fmt::Debug for ExprImpl {
1057 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1058 if f.alternate() {
1059 return match self {
1060 Self::InputRef(arg0) => f.debug_tuple("InputRef").field(arg0).finish(),
1061 Self::Literal(arg0) => f.debug_tuple("Literal").field(arg0).finish(),
1062 Self::FunctionCall(arg0) => f.debug_tuple("FunctionCall").field(arg0).finish(),
1063 Self::FunctionCallWithLambda(arg0) => {
1064 f.debug_tuple("FunctionCallWithLambda").field(arg0).finish()
1065 }
1066 Self::AggCall(arg0) => f.debug_tuple("AggCall").field(arg0).finish(),
1067 Self::Subquery(arg0) => f.debug_tuple("Subquery").field(arg0).finish(),
1068 Self::CorrelatedInputRef(arg0) => {
1069 f.debug_tuple("CorrelatedInputRef").field(arg0).finish()
1070 }
1071 Self::TableFunction(arg0) => f.debug_tuple("TableFunction").field(arg0).finish(),
1072 Self::WindowFunction(arg0) => f.debug_tuple("WindowFunction").field(arg0).finish(),
1073 Self::UserDefinedFunction(arg0) => {
1074 f.debug_tuple("UserDefinedFunction").field(arg0).finish()
1075 }
1076 Self::Parameter(arg0) => f.debug_tuple("Parameter").field(arg0).finish(),
1077 Self::Now(_) => f.debug_tuple("Now").finish(),
1078 };
1079 }
1080 match self {
1081 Self::InputRef(x) => write!(f, "{:?}", x),
1082 Self::Literal(x) => write!(f, "{:?}", x),
1083 Self::FunctionCall(x) => write!(f, "{:?}", x),
1084 Self::FunctionCallWithLambda(x) => write!(f, "{:?}", x),
1085 Self::AggCall(x) => write!(f, "{:?}", x),
1086 Self::Subquery(x) => write!(f, "{:?}", x),
1087 Self::CorrelatedInputRef(x) => write!(f, "{:?}", x),
1088 Self::TableFunction(x) => write!(f, "{:?}", x),
1089 Self::WindowFunction(x) => write!(f, "{:?}", x),
1090 Self::UserDefinedFunction(x) => write!(f, "{:?}", x),
1091 Self::Parameter(x) => write!(f, "{:?}", x),
1092 Self::Now(x) => write!(f, "{:?}", x),
1093 }
1094 }
1095}
1096
1097pub struct ExprDisplay<'a> {
1098 pub expr: &'a ExprImpl,
1099 pub input_schema: &'a Schema,
1100}
1101
1102impl std::fmt::Debug for ExprDisplay<'_> {
1103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1104 let that = self.expr;
1105 match that {
1106 ExprImpl::InputRef(x) => write!(
1107 f,
1108 "{:?}",
1109 InputRefDisplay {
1110 input_ref: x,
1111 input_schema: self.input_schema
1112 }
1113 ),
1114 ExprImpl::Literal(x) => write!(f, "{:?}", x),
1115 ExprImpl::FunctionCall(x) => write!(
1116 f,
1117 "{:?}",
1118 FunctionCallDisplay {
1119 function_call: x,
1120 input_schema: self.input_schema
1121 }
1122 ),
1123 ExprImpl::FunctionCallWithLambda(x) => write!(
1124 f,
1125 "{:?}",
1126 FunctionCallDisplay {
1127 function_call: &x.to_full_function_call(),
1128 input_schema: self.input_schema
1129 }
1130 ),
1131 ExprImpl::AggCall(x) => write!(f, "{:?}", x),
1132 ExprImpl::Subquery(x) => write!(f, "{:?}", x),
1133 ExprImpl::CorrelatedInputRef(x) => write!(f, "{:?}", x),
1134 ExprImpl::TableFunction(x) => {
1135 write!(f, "{:?}", x)
1137 }
1138 ExprImpl::WindowFunction(x) => {
1139 write!(f, "{:?}", x)
1141 }
1142 ExprImpl::UserDefinedFunction(x) => {
1143 write!(
1144 f,
1145 "{:?}",
1146 UserDefinedFunctionDisplay {
1147 func_call: x,
1148 input_schema: self.input_schema
1149 }
1150 )
1151 }
1152 ExprImpl::Parameter(x) => write!(f, "{:?}", x),
1153 ExprImpl::Now(x) => write!(f, "{:?}", x),
1154 }
1155 }
1156}
1157
1158impl std::fmt::Display for ExprDisplay<'_> {
1159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1160 (self as &dyn std::fmt::Debug).fmt(f)
1161 }
1162}
1163
/// Test helper: asserts that `$e` is an `ExprImpl::InputRef` whose column index is `$index`.
#[cfg(test)]
macro_rules! assert_eq_input_ref {
    ($e:expr, $index:expr) => {
        match $e {
            ExprImpl::InputRef(input_ref) => assert_eq!(input_ref.index(), $index),
            _ => panic!("Expected input ref, found {:?}", $e),
        }
    };
}

#[cfg(test)]
pub(crate) use assert_eq_input_ref;
1177use risingwave_common::bail;
1178use risingwave_common::catalog::Schema;
1179use risingwave_common::row::OwnedRow;
1180
1181use crate::binder::BoundSetExpr;
1182use crate::utils::Condition;
1183
#[cfg(test)]
mod tests {
    use super::*;

    /// The alternate (`{:#?}`) form must expose the node's fields, e.g. its return type.
    #[test]
    fn test_expr_debug_alternate() {
        let column: ExprImpl = InputRef::new(1, DataType::Boolean).into();
        let negated: ExprImpl = FunctionCall::new(ExprType::Not, vec![column])
            .unwrap()
            .into();
        let rendered = format!("{:#?}", negated);
        assert!(rendered.contains("return_type: Boolean"));
    }
}
1195}