1use enum_as_inner::EnumAsInner;
16use fixedbitset::FixedBitSet;
17use futures::FutureExt;
18use paste::paste;
19use risingwave_common::array::ListValue;
20use risingwave_common::types::{DataType, Datum, JsonbVal, MapType, Scalar, ScalarImpl};
21use risingwave_expr::aggregate::PbAggKind;
22use risingwave_expr::expr::build_from_prost;
23use risingwave_pb::expr::expr_node::RexNode;
24use risingwave_pb::expr::{ExprNode, ProjectSetSelectItem};
25use user_defined_function::UserDefinedFunctionDisplay;
26
27use crate::error::{ErrorCode, Result as RwResult};
28
29mod agg_call;
30mod correlated_input_ref;
31mod function_call;
32mod function_call_with_lambda;
33mod input_ref;
34mod literal;
35mod now;
36mod parameter;
37mod pure;
38mod subquery;
39mod table_function;
40mod user_defined_function;
41mod window_function;
42
43mod order_by_expr;
44pub use order_by_expr::{OrderBy, OrderByExpr};
45
46mod expr_mutator;
47mod expr_rewriter;
48mod expr_visitor;
49pub mod function_impl;
50mod session_timezone;
51mod type_inference;
52mod utils;
53
54pub use agg_call::AggCall;
55pub use correlated_input_ref::{CorrelatedId, CorrelatedInputRef, Depth, InputRefDepthRewriter};
56pub use expr_mutator::ExprMutator;
57pub use expr_rewriter::{ExprRewriter, default_rewrite_expr};
58pub use expr_visitor::{ExprVisitor, default_visit_expr};
59pub use function_call::{FunctionCall, FunctionCallDisplay, is_row_function};
60pub use function_call_with_lambda::FunctionCallWithLambda;
61pub use input_ref::{InputRef, InputRefDisplay, input_ref_to_column_indices};
62pub use literal::Literal;
63pub use now::{InlineNowProcTime, Now, NowProcTimeFinder};
64pub use parameter::Parameter;
65pub use pure::*;
66pub use risingwave_pb::expr::expr_node::Type as ExprType;
67pub use session_timezone::{SessionTimezone, TimestamptzExprFinder};
68pub use subquery::{Subquery, SubqueryKind};
69pub use table_function::{TableFunction, TableFunctionType};
70pub use type_inference::*;
71pub use user_defined_function::UserDefinedFunction;
72pub use utils::*;
73pub use window_function::WindowFunction;
74
/// Depth limit for expressions.
///
/// NOTE(review): not referenced in this part of the file — presumably consulted by
/// submodules to decide when to emit [`EXPR_TOO_DEEP_NOTICE`]; confirm at the use sites.
const EXPR_DEPTH_THRESHOLD: usize = 30;
/// Notice text telling the user that an expression is too complicated.
const EXPR_TOO_DEEP_NOTICE: &str = concat!(
    "Some expression is too complicated. ",
    "Consider simplifying or splitting the query if you encounter any issues.",
);
78
79pub trait Expr: Into<ExprImpl> {
81 fn return_type(&self) -> DataType;
83
84 fn try_to_expr_proto(&self) -> Result<ExprNode, String>;
86
87 fn to_expr_proto(&self) -> ExprNode {
89 self.try_to_expr_proto()
90 .expect("failed to serialize expression to protobuf")
91 }
92}
93
94macro_rules! impl_expr_impl {
95 ($($t:ident,)*) => {
96 #[derive(Clone, Eq, PartialEq, Hash, EnumAsInner)]
97 pub enum ExprImpl {
98 $($t(Box<$t>),)*
99 }
100
101 impl ExprImpl {
102 pub fn variant_name(&self) -> &'static str {
103 match self {
104 $(ExprImpl::$t(_) => stringify!($t),)*
105 }
106 }
107 }
108
109 $(
110 impl From<$t> for ExprImpl {
111 fn from(o: $t) -> ExprImpl {
112 ExprImpl::$t(Box::new(o))
113 }
114 })*
115
116 impl Expr for ExprImpl {
117 fn return_type(&self) -> DataType {
118 match self {
119 $(ExprImpl::$t(expr) => expr.return_type(),)*
120 }
121 }
122
123 fn try_to_expr_proto(&self) -> Result<ExprNode, String> {
124 match self {
125 $(ExprImpl::$t(expr) => expr.try_to_expr_proto(),)*
126 }
127 }
128 }
129 };
130}
131
132impl_expr_impl!(
133 CorrelatedInputRef,
135 InputRef,
136 Literal,
137 FunctionCall,
138 FunctionCallWithLambda,
139 AggCall,
140 Subquery,
141 TableFunction,
142 WindowFunction,
143 UserDefinedFunction,
144 Parameter,
145 Now,
146);
147
148impl ExprImpl {
149 #[inline(always)]
151 pub fn literal_int(v: i32) -> Self {
152 Literal::new(Some(v.to_scalar_value()), DataType::Int32).into()
153 }
154
155 #[inline(always)]
157 pub fn literal_bigint(v: i64) -> Self {
158 Literal::new(Some(v.to_scalar_value()), DataType::Int64).into()
159 }
160
161 #[inline(always)]
163 pub fn literal_f64(v: f64) -> Self {
164 Literal::new(Some(v.into()), DataType::Float64).into()
165 }
166
167 #[inline(always)]
169 pub fn literal_bool(v: bool) -> Self {
170 Literal::new(Some(v.to_scalar_value()), DataType::Boolean).into()
171 }
172
173 #[inline(always)]
175 pub fn literal_varchar(v: String) -> Self {
176 Literal::new(Some(v.into()), DataType::Varchar).into()
177 }
178
179 #[inline(always)]
181 pub fn literal_null(element_type: DataType) -> Self {
182 Literal::new(None, element_type).into()
183 }
184
185 #[inline(always)]
187 pub fn literal_jsonb(v: JsonbVal) -> Self {
188 Literal::new(Some(v.into()), DataType::Jsonb).into()
189 }
190
191 #[inline(always)]
193 pub fn literal_list(v: ListValue, element_type: DataType) -> Self {
194 Literal::new(
195 Some(v.to_scalar_value()),
196 DataType::List(Box::new(element_type)),
197 )
198 .into()
199 }
200
201 pub fn take(&mut self) -> Self {
203 std::mem::replace(self, Self::literal_null(self.return_type()))
204 }
205
206 #[inline(always)]
208 pub fn count_star() -> Self {
209 AggCall::new(
210 PbAggKind::Count.into(),
211 vec![],
212 false,
213 OrderBy::any(),
214 Condition::true_cond(),
215 vec![],
216 )
217 .unwrap()
218 .into()
219 }
220
221 pub fn and(exprs: impl IntoIterator<Item = ExprImpl>) -> Self {
225 merge_expr_by_logical(exprs, ExprType::And, ExprImpl::literal_bool(true))
226 }
227
228 pub fn or(exprs: impl IntoIterator<Item = ExprImpl>) -> Self {
232 merge_expr_by_logical(exprs, ExprType::Or, ExprImpl::literal_bool(false))
233 }
234
235 pub fn collect_input_refs(&self, input_col_num: usize) -> FixedBitSet {
240 collect_input_refs(input_col_num, [self])
241 }
242
243 pub fn is_pure(&self) -> bool {
245 is_pure(self)
246 }
247
248 pub fn is_impure(&self) -> bool {
249 is_impure(self)
250 }
251
252 pub fn count_nows(&self) -> usize {
254 let mut visitor = CountNow::default();
255 visitor.visit_expr(self);
256 visitor.count()
257 }
258
259 pub fn is_null(&self) -> bool {
261 matches!(self, ExprImpl::Literal(literal) if literal.get_data().is_none())
262 }
263
264 pub fn is_untyped(&self) -> bool {
266 matches!(self, ExprImpl::Literal(literal) if literal.is_untyped())
267 || matches!(self, ExprImpl::Parameter(parameter) if !parameter.has_infer())
268 }
269
270 pub fn cast_implicit(mut self, target: &DataType) -> Result<ExprImpl, CastError> {
272 FunctionCall::cast_mut(&mut self, target, CastContext::Implicit)?;
273 Ok(self)
274 }
275
276 pub fn cast_assign(mut self, target: &DataType) -> Result<ExprImpl, CastError> {
278 FunctionCall::cast_mut(&mut self, target, CastContext::Assign)?;
279 Ok(self)
280 }
281
282 pub fn cast_explicit(mut self, target: &DataType) -> Result<ExprImpl, CastError> {
284 FunctionCall::cast_mut(&mut self, target, CastContext::Explicit)?;
285 Ok(self)
286 }
287
288 pub fn cast_implicit_mut(&mut self, target: &DataType) -> Result<(), CastError> {
290 FunctionCall::cast_mut(self, target, CastContext::Implicit)
291 }
292
293 pub fn cast_explicit_mut(&mut self, target: &DataType) -> Result<(), CastError> {
295 FunctionCall::cast_mut(self, target, CastContext::Explicit)
296 }
297
298 pub fn cast_to_regclass(self) -> Result<ExprImpl, CastError> {
301 match self.return_type() {
302 DataType::Varchar => Ok(ExprImpl::FunctionCall(Box::new(
303 FunctionCall::new_unchecked(ExprType::CastRegclass, vec![self], DataType::Int32),
304 ))),
305 DataType::Int32 => Ok(self),
306 dt if dt.is_int() => Ok(self.cast_explicit(&DataType::Int32)?),
307 _ => bail_cast_error!("unsupported input type"),
308 }
309 }
310
311 pub fn cast_to_regclass_mut(&mut self) -> Result<(), CastError> {
313 let owned = std::mem::replace(self, ExprImpl::literal_bool(false));
314 *self = owned.cast_to_regclass()?;
315 Ok(())
316 }
317
318 pub fn ensure_array_type(&self) -> Result<(), ErrorCode> {
320 if self.is_untyped() {
321 return Err(ErrorCode::BindError(
322 "could not determine polymorphic type because input has type unknown".into(),
323 ));
324 }
325 match self.return_type() {
326 DataType::List(_) => Ok(()),
327 t => Err(ErrorCode::BindError(format!("expects array but got {t}"))),
328 }
329 }
330
331 pub fn try_into_map_type(&self) -> Result<MapType, ErrorCode> {
333 if self.is_untyped() {
334 return Err(ErrorCode::BindError(
335 "could not determine polymorphic type because input has type unknown".into(),
336 ));
337 }
338 match self.return_type() {
339 DataType::Map(m) => Ok(m),
340 t => Err(ErrorCode::BindError(format!("expects map but got {t}"))),
341 }
342 }
343
344 pub fn enforce_bool_clause(self, clause: &str) -> RwResult<ExprImpl> {
346 if self.is_untyped() {
347 let inner = self.cast_implicit(&DataType::Boolean)?;
348 return Ok(inner);
349 }
350 let return_type = self.return_type();
351 if return_type != DataType::Boolean {
352 bail!(
353 "argument of {} must be boolean, not type {:?}",
354 clause,
355 return_type
356 )
357 }
358 Ok(self)
359 }
360
361 pub fn cast_output(self) -> RwResult<ExprImpl> {
372 if self.return_type() == DataType::Boolean {
373 return Ok(FunctionCall::new(ExprType::BoolOut, vec![self])?.into());
374 }
375 self.cast_assign(&DataType::Varchar)
378 .map_err(|err| err.into())
379 }
380
381 pub async fn eval_row(&self, input: &OwnedRow) -> RwResult<Datum> {
386 let backend_expr = build_from_prost(&self.to_expr_proto())?;
387 Ok(backend_expr.eval_row(input).await?)
388 }
389
390 pub fn try_fold_const(&self) -> Option<RwResult<Datum>> {
397 if self.is_const() {
398 self.eval_row(&OwnedRow::empty())
399 .now_or_never()
400 .expect("constant expression should not be async")
401 .into()
402 } else {
403 None
404 }
405 }
406
407 pub fn fold_const(&self) -> RwResult<Datum> {
409 self.try_fold_const().expect("expression is not constant")
410 }
411}
412
413macro_rules! impl_has_variant {
418 ( $($variant:ty),* ) => {
419 paste! {
420 impl ExprImpl {
421 $(
422 pub fn [<has_ $variant:snake>](&self) -> bool {
423 struct Has { has: bool }
424
425 impl ExprVisitor for Has {
426 fn [<visit_ $variant:snake>](&mut self, _: &$variant) {
427 self.has = true;
428 }
429 }
430
431 let mut visitor = Has { has: false };
432 visitor.visit_expr(self);
433 visitor.has
434 }
435 )*
436 }
437 }
438 };
439}
440
441impl_has_variant! {InputRef, Literal, FunctionCall, FunctionCallWithLambda, AggCall, Subquery, TableFunction, WindowFunction, UserDefinedFunction, Now}
442
443#[derive(Debug, Clone, PartialEq, Eq, Hash)]
444pub struct InequalityInputPair {
445 pub(crate) key_required_larger: usize,
447 pub(crate) key_required_smaller: usize,
449 pub(crate) delta_expression: Option<(ExprType, ExprImpl)>,
451}
452
453impl InequalityInputPair {
454 fn new(
455 key_required_larger: usize,
456 key_required_smaller: usize,
457 delta_expression: Option<(ExprType, ExprImpl)>,
458 ) -> Self {
459 Self {
460 key_required_larger,
461 key_required_smaller,
462 delta_expression,
463 }
464 }
465}
466
467impl ExprImpl {
468 pub fn has_correlated_input_ref(&self, _: std::convert::Infallible) -> bool {
479 unreachable!()
480 }
481
482 pub fn has_correlated_input_ref_by_depth(&self, depth: Depth) -> bool {
488 struct Has {
489 depth: usize,
490 has: bool,
491 }
492
493 impl ExprVisitor for Has {
494 fn visit_correlated_input_ref(&mut self, correlated_input_ref: &CorrelatedInputRef) {
495 if correlated_input_ref.depth() == self.depth {
496 self.has = true;
497 }
498 }
499
500 fn visit_subquery(&mut self, subquery: &Subquery) {
501 self.has |= subquery.is_correlated_by_depth(self.depth);
502 }
503 }
504
505 let mut visitor = Has { depth, has: false };
506 visitor.visit_expr(self);
507 visitor.has
508 }
509
510 pub fn has_correlated_input_ref_by_correlated_id(&self, correlated_id: CorrelatedId) -> bool {
511 struct Has {
512 correlated_id: CorrelatedId,
513 has: bool,
514 }
515
516 impl ExprVisitor for Has {
517 fn visit_correlated_input_ref(&mut self, correlated_input_ref: &CorrelatedInputRef) {
518 if correlated_input_ref.correlated_id() == self.correlated_id {
519 self.has = true;
520 }
521 }
522
523 fn visit_subquery(&mut self, subquery: &Subquery) {
524 self.has |= subquery.is_correlated_by_correlated_id(self.correlated_id);
525 }
526 }
527
528 let mut visitor = Has {
529 correlated_id,
530 has: false,
531 };
532 visitor.visit_expr(self);
533 visitor.has
534 }
535
536 pub fn collect_correlated_indices_by_depth_and_assign_id(
539 &mut self,
540 depth: Depth,
541 correlated_id: CorrelatedId,
542 ) -> Vec<usize> {
543 struct Collector {
544 depth: Depth,
545 correlated_indices: Vec<usize>,
546 correlated_id: CorrelatedId,
547 }
548
549 impl ExprMutator for Collector {
550 fn visit_correlated_input_ref(
551 &mut self,
552 correlated_input_ref: &mut CorrelatedInputRef,
553 ) {
554 if correlated_input_ref.depth() == self.depth {
555 self.correlated_indices.push(correlated_input_ref.index());
556 correlated_input_ref.set_correlated_id(self.correlated_id);
557 }
558 }
559
560 fn visit_subquery(&mut self, subquery: &mut Subquery) {
561 self.correlated_indices.extend(
562 subquery.collect_correlated_indices_by_depth_and_assign_id(
563 self.depth,
564 self.correlated_id,
565 ),
566 );
567 }
568 }
569
570 let mut collector = Collector {
571 depth,
572 correlated_indices: vec![],
573 correlated_id,
574 };
575 collector.visit_expr(self);
576 collector.correlated_indices.sort();
577 collector.correlated_indices.dedup();
578 collector.correlated_indices
579 }
580
581 pub fn is_const(&self) -> bool {
585 let only_literal_and_func = {
586 struct HasOthers {
587 has_others: bool,
588 }
589
590 impl ExprVisitor for HasOthers {
591 fn visit_expr(&mut self, expr: &ExprImpl) {
592 match expr {
593 ExprImpl::CorrelatedInputRef(_)
594 | ExprImpl::InputRef(_)
595 | ExprImpl::AggCall(_)
596 | ExprImpl::Subquery(_)
597 | ExprImpl::TableFunction(_)
598 | ExprImpl::WindowFunction(_)
599 | ExprImpl::UserDefinedFunction(_)
600 | ExprImpl::Parameter(_)
601 | ExprImpl::Now(_) => self.has_others = true,
602 ExprImpl::Literal(_inner) => {}
603 ExprImpl::FunctionCall(inner) => {
604 if !self.is_short_circuit(inner) {
605 self.visit_function_call(inner)
609 }
610 }
611 ExprImpl::FunctionCallWithLambda(inner) => {
612 self.visit_function_call_with_lambda(inner)
613 }
614 }
615 }
616 }
617
618 impl HasOthers {
619 fn is_short_circuit(&self, func_call: &FunctionCall) -> bool {
620 fn eval_first(e: &ExprImpl, expect: bool) -> bool {
622 if let ExprImpl::Literal(l) = e {
623 *l.get_data() == Some(ScalarImpl::Bool(expect))
624 } else {
625 false
626 }
627 }
628
629 match func_call.func_type {
630 ExprType::Or => eval_first(&func_call.inputs()[0], true),
631 ExprType::And => eval_first(&func_call.inputs()[0], false),
632 _ => false,
633 }
634 }
635 }
636
637 let mut visitor = HasOthers { has_others: false };
638 visitor.visit_expr(self);
639 !visitor.has_others
640 };
641
642 let is_pure = self.is_pure();
643
644 only_literal_and_func && is_pure
645 }
646
647 pub fn as_eq_cond(&self) -> Option<(InputRef, InputRef)> {
650 if let ExprImpl::FunctionCall(function_call) = self
651 && function_call.func_type() == ExprType::Equal
652 && let (_, ExprImpl::InputRef(x), ExprImpl::InputRef(y)) =
653 function_call.clone().decompose_as_binary()
654 {
655 if x.index() < y.index() {
656 Some((*x, *y))
657 } else {
658 Some((*y, *x))
659 }
660 } else {
661 None
662 }
663 }
664
665 pub fn as_is_not_distinct_from_cond(&self) -> Option<(InputRef, InputRef)> {
666 if let ExprImpl::FunctionCall(function_call) = self
667 && function_call.func_type() == ExprType::IsNotDistinctFrom
668 && let (_, ExprImpl::InputRef(x), ExprImpl::InputRef(y)) =
669 function_call.clone().decompose_as_binary()
670 {
671 if x.index() < y.index() {
672 Some((*x, *y))
673 } else {
674 Some((*y, *x))
675 }
676 } else {
677 None
678 }
679 }
680
681 pub fn reverse_comparison(comparison: ExprType) -> ExprType {
682 match comparison {
683 ExprType::LessThan => ExprType::GreaterThan,
684 ExprType::LessThanOrEqual => ExprType::GreaterThanOrEqual,
685 ExprType::GreaterThan => ExprType::LessThan,
686 ExprType::GreaterThanOrEqual => ExprType::LessThanOrEqual,
687 ExprType::Equal | ExprType::IsNotDistinctFrom => comparison,
688 _ => unreachable!(),
689 }
690 }
691
692 pub fn as_comparison_cond(&self) -> Option<(InputRef, ExprType, InputRef)> {
693 if let ExprImpl::FunctionCall(function_call) = self {
694 match function_call.func_type() {
695 ty @ (ExprType::LessThan
696 | ExprType::LessThanOrEqual
697 | ExprType::GreaterThan
698 | ExprType::GreaterThanOrEqual) => {
699 let (_, op1, op2) = function_call.clone().decompose_as_binary();
700 if let (ExprImpl::InputRef(x), ExprImpl::InputRef(y)) = (op1, op2) {
701 if x.index < y.index {
702 Some((*x, ty, *y))
703 } else {
704 Some((*y, Self::reverse_comparison(ty), *x))
705 }
706 } else {
707 None
708 }
709 }
710 _ => None,
711 }
712 } else {
713 None
714 }
715 }
716
717 pub fn as_now_comparison_cond(&self) -> Option<(ExprImpl, ExprType, ExprImpl)> {
723 if let ExprImpl::FunctionCall(function_call) = self {
724 match function_call.func_type() {
725 ty @ (ExprType::Equal
726 | ExprType::LessThan
727 | ExprType::LessThanOrEqual
728 | ExprType::GreaterThan
729 | ExprType::GreaterThanOrEqual) => {
730 let (_, op1, op2) = function_call.clone().decompose_as_binary();
731 if !op1.has_now()
732 && op1.has_input_ref()
733 && op2.has_now()
734 && !op2.has_input_ref()
735 {
736 Some((op1, ty, op2))
737 } else if op1.has_now()
738 && !op1.has_input_ref()
739 && !op2.has_now()
740 && op2.has_input_ref()
741 {
742 Some((op2, Self::reverse_comparison(ty), op1))
743 } else {
744 None
745 }
746 }
747 _ => None,
748 }
749 } else {
750 None
751 }
752 }
753
754 pub(crate) fn as_input_comparison_cond(&self) -> Option<InequalityInputPair> {
757 if let ExprImpl::FunctionCall(function_call) = self {
758 match function_call.func_type() {
759 ty @ (ExprType::LessThan
760 | ExprType::LessThanOrEqual
761 | ExprType::GreaterThan
762 | ExprType::GreaterThanOrEqual) => {
763 let (_, mut op1, mut op2) = function_call.clone().decompose_as_binary();
764 if matches!(ty, ExprType::LessThan | ExprType::LessThanOrEqual) {
765 std::mem::swap(&mut op1, &mut op2);
766 }
767 if let (Some((lft_input, lft_offset)), Some((rht_input, rht_offset))) =
768 (op1.as_input_offset(), op2.as_input_offset())
769 {
770 match (lft_offset, rht_offset) {
771 (Some(_), Some(_)) => None,
772 (None, rht_offset @ Some(_)) => {
773 Some(InequalityInputPair::new(lft_input, rht_input, rht_offset))
774 }
775 (Some((operator, operand)), None) => Some(InequalityInputPair::new(
776 lft_input,
777 rht_input,
778 Some((
779 if operator == ExprType::Add {
780 ExprType::Subtract
781 } else {
782 ExprType::Add
783 },
784 operand,
785 )),
786 )),
787 (None, None) => {
788 Some(InequalityInputPair::new(lft_input, rht_input, None))
789 }
790 }
791 } else {
792 None
793 }
794 }
795 _ => None,
796 }
797 } else {
798 None
799 }
800 }
801
802 fn as_input_offset(&self) -> Option<(usize, Option<(ExprType, ExprImpl)>)> {
805 match self {
806 ExprImpl::InputRef(input_ref) => Some((input_ref.index(), None)),
807 ExprImpl::FunctionCall(function_call) => {
808 let expr_type = function_call.func_type();
809 match expr_type {
810 ExprType::Add | ExprType::Subtract => {
811 let (_, lhs, rhs) = function_call.clone().decompose_as_binary();
812 if let ExprImpl::InputRef(input_ref) = &lhs
813 && rhs.is_const()
814 {
815 if rhs.return_type() == DataType::Interval
818 && rhs.as_literal().is_none_or(|literal| {
819 literal.get_data().as_ref().is_some_and(|scalar| {
820 let interval = scalar.as_interval();
821 interval.months() != 0 || interval.days() != 0
822 })
823 })
824 {
825 None
826 } else {
827 Some((input_ref.index(), Some((expr_type, rhs))))
828 }
829 } else {
830 None
831 }
832 }
833 _ => None,
834 }
835 }
836 _ => None,
837 }
838 }
839
840 pub fn as_eq_const(&self) -> Option<(InputRef, ExprImpl)> {
841 if let ExprImpl::FunctionCall(function_call) = self
842 && function_call.func_type() == ExprType::Equal
843 {
844 match function_call.clone().decompose_as_binary() {
845 (_, ExprImpl::InputRef(x), y) if y.is_const() => Some((*x, y)),
846 (_, x, ExprImpl::InputRef(y)) if x.is_const() => Some((*y, x)),
847 _ => None,
848 }
849 } else {
850 None
851 }
852 }
853
854 pub fn as_eq_correlated_input_ref(&self) -> Option<(InputRef, CorrelatedInputRef)> {
855 if let ExprImpl::FunctionCall(function_call) = self
856 && function_call.func_type() == ExprType::Equal
857 {
858 match function_call.clone().decompose_as_binary() {
859 (_, ExprImpl::InputRef(x), ExprImpl::CorrelatedInputRef(y)) => Some((*x, *y)),
860 (_, ExprImpl::CorrelatedInputRef(x), ExprImpl::InputRef(y)) => Some((*y, *x)),
861 _ => None,
862 }
863 } else {
864 None
865 }
866 }
867
868 pub fn as_is_null(&self) -> Option<InputRef> {
869 if let ExprImpl::FunctionCall(function_call) = self
870 && function_call.func_type() == ExprType::IsNull
871 {
872 match function_call.clone().decompose_as_unary() {
873 (_, ExprImpl::InputRef(x)) => Some(*x),
874 _ => None,
875 }
876 } else {
877 None
878 }
879 }
880
881 pub fn as_comparison_const(&self) -> Option<(InputRef, ExprType, ExprImpl)> {
882 fn reverse_comparison(comparison: ExprType) -> ExprType {
883 match comparison {
884 ExprType::LessThan => ExprType::GreaterThan,
885 ExprType::LessThanOrEqual => ExprType::GreaterThanOrEqual,
886 ExprType::GreaterThan => ExprType::LessThan,
887 ExprType::GreaterThanOrEqual => ExprType::LessThanOrEqual,
888 _ => unreachable!(),
889 }
890 }
891
892 if let ExprImpl::FunctionCall(function_call) = self {
893 match function_call.func_type() {
894 ty @ (ExprType::LessThan
895 | ExprType::LessThanOrEqual
896 | ExprType::GreaterThan
897 | ExprType::GreaterThanOrEqual) => {
898 let (_, op1, op2) = function_call.clone().decompose_as_binary();
899 match (op1, op2) {
900 (ExprImpl::InputRef(x), y) if y.is_const() => Some((*x, ty, y)),
901 (x, ExprImpl::InputRef(y)) if x.is_const() => {
902 Some((*y, reverse_comparison(ty), x))
903 }
904 _ => None,
905 }
906 }
907 _ => None,
908 }
909 } else {
910 None
911 }
912 }
913
914 pub fn as_in_const_list(&self) -> Option<(InputRef, Vec<ExprImpl>)> {
915 if let ExprImpl::FunctionCall(function_call) = self
916 && function_call.func_type() == ExprType::In
917 {
918 let mut inputs = function_call.inputs().iter().cloned();
919 let input_ref = match inputs.next().unwrap() {
920 ExprImpl::InputRef(i) => *i,
921 _ => return None,
922 };
923 let list: Vec<_> = inputs
924 .inspect(|expr| {
925 assert!(expr.is_const());
927 })
928 .collect();
929
930 Some((input_ref, list))
931 } else {
932 None
933 }
934 }
935
936 pub fn as_or_disjunctions(&self) -> Option<Vec<ExprImpl>> {
937 if let ExprImpl::FunctionCall(function_call) = self
938 && function_call.func_type() == ExprType::Or
939 {
940 Some(to_disjunctions(self.clone()))
941 } else {
942 None
943 }
944 }
945
946 pub fn to_project_set_select_item_proto(&self) -> ProjectSetSelectItem {
947 use risingwave_pb::expr::project_set_select_item::SelectItem::*;
948
949 ProjectSetSelectItem {
950 select_item: Some(match self {
951 ExprImpl::TableFunction(tf) => TableFunction(tf.to_protobuf()),
952 expr => Expr(expr.to_expr_proto()),
953 }),
954 }
955 }
956
957 pub fn from_expr_proto(proto: &ExprNode) -> RwResult<Self> {
958 let rex_node = proto.get_rex_node()?;
959 let ret_type = proto.get_return_type()?.into();
960
961 Ok(match rex_node {
962 RexNode::InputRef(column_index) => Self::InputRef(Box::new(InputRef::from_expr_proto(
963 *column_index as _,
964 ret_type,
965 )?)),
966 RexNode::Constant(_) => Self::Literal(Box::new(Literal::from_expr_proto(proto)?)),
967 RexNode::Udf(udf) => Self::UserDefinedFunction(Box::new(
968 UserDefinedFunction::from_expr_proto(udf, ret_type)?,
969 )),
970 RexNode::FuncCall(function_call) => {
971 Self::FunctionCall(Box::new(FunctionCall::from_expr_proto(
972 function_call,
973 proto.get_function_type()?, ret_type,
975 )?))
976 }
977 RexNode::Now(_) => Self::Now(Box::new(Now {})),
978 })
979 }
980}
981
982impl From<Condition> for ExprImpl {
983 fn from(c: Condition) -> Self {
984 ExprImpl::and(c.conjunctions)
985 }
986}
987
988impl std::fmt::Debug for ExprImpl {
992 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
993 if f.alternate() {
994 return match self {
995 Self::InputRef(arg0) => f.debug_tuple("InputRef").field(arg0).finish(),
996 Self::Literal(arg0) => f.debug_tuple("Literal").field(arg0).finish(),
997 Self::FunctionCall(arg0) => f.debug_tuple("FunctionCall").field(arg0).finish(),
998 Self::FunctionCallWithLambda(arg0) => {
999 f.debug_tuple("FunctionCallWithLambda").field(arg0).finish()
1000 }
1001 Self::AggCall(arg0) => f.debug_tuple("AggCall").field(arg0).finish(),
1002 Self::Subquery(arg0) => f.debug_tuple("Subquery").field(arg0).finish(),
1003 Self::CorrelatedInputRef(arg0) => {
1004 f.debug_tuple("CorrelatedInputRef").field(arg0).finish()
1005 }
1006 Self::TableFunction(arg0) => f.debug_tuple("TableFunction").field(arg0).finish(),
1007 Self::WindowFunction(arg0) => f.debug_tuple("WindowFunction").field(arg0).finish(),
1008 Self::UserDefinedFunction(arg0) => {
1009 f.debug_tuple("UserDefinedFunction").field(arg0).finish()
1010 }
1011 Self::Parameter(arg0) => f.debug_tuple("Parameter").field(arg0).finish(),
1012 Self::Now(_) => f.debug_tuple("Now").finish(),
1013 };
1014 }
1015 match self {
1016 Self::InputRef(x) => write!(f, "{:?}", x),
1017 Self::Literal(x) => write!(f, "{:?}", x),
1018 Self::FunctionCall(x) => write!(f, "{:?}", x),
1019 Self::FunctionCallWithLambda(x) => write!(f, "{:?}", x),
1020 Self::AggCall(x) => write!(f, "{:?}", x),
1021 Self::Subquery(x) => write!(f, "{:?}", x),
1022 Self::CorrelatedInputRef(x) => write!(f, "{:?}", x),
1023 Self::TableFunction(x) => write!(f, "{:?}", x),
1024 Self::WindowFunction(x) => write!(f, "{:?}", x),
1025 Self::UserDefinedFunction(x) => write!(f, "{:?}", x),
1026 Self::Parameter(x) => write!(f, "{:?}", x),
1027 Self::Now(x) => write!(f, "{:?}", x),
1028 }
1029 }
1030}
1031
1032pub struct ExprDisplay<'a> {
1033 pub expr: &'a ExprImpl,
1034 pub input_schema: &'a Schema,
1035}
1036
1037impl std::fmt::Debug for ExprDisplay<'_> {
1038 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1039 let that = self.expr;
1040 match that {
1041 ExprImpl::InputRef(x) => write!(
1042 f,
1043 "{:?}",
1044 InputRefDisplay {
1045 input_ref: x,
1046 input_schema: self.input_schema
1047 }
1048 ),
1049 ExprImpl::Literal(x) => write!(f, "{:?}", x),
1050 ExprImpl::FunctionCall(x) => write!(
1051 f,
1052 "{:?}",
1053 FunctionCallDisplay {
1054 function_call: x,
1055 input_schema: self.input_schema
1056 }
1057 ),
1058 ExprImpl::FunctionCallWithLambda(x) => write!(
1059 f,
1060 "{:?}",
1061 FunctionCallDisplay {
1062 function_call: &x.to_full_function_call(),
1063 input_schema: self.input_schema
1064 }
1065 ),
1066 ExprImpl::AggCall(x) => write!(f, "{:?}", x),
1067 ExprImpl::Subquery(x) => write!(f, "{:?}", x),
1068 ExprImpl::CorrelatedInputRef(x) => write!(f, "{:?}", x),
1069 ExprImpl::TableFunction(x) => {
1070 write!(f, "{:?}", x)
1072 }
1073 ExprImpl::WindowFunction(x) => {
1074 write!(f, "{:?}", x)
1076 }
1077 ExprImpl::UserDefinedFunction(x) => {
1078 write!(
1079 f,
1080 "{:?}",
1081 UserDefinedFunctionDisplay {
1082 func_call: x,
1083 input_schema: self.input_schema
1084 }
1085 )
1086 }
1087 ExprImpl::Parameter(x) => write!(f, "{:?}", x),
1088 ExprImpl::Now(x) => write!(f, "{:?}", x),
1089 }
1090 }
1091}
1092
1093impl std::fmt::Display for ExprDisplay<'_> {
1094 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1095 (self as &dyn std::fmt::Debug).fmt(f)
1096 }
1097}
1098
/// Test helper: asserts that `$e` is an `ExprImpl::InputRef` whose index equals
/// `$index`.
///
/// # Panics
/// Panics with `Expected input ref, found ...` when `$e` is any other variant,
/// and via `assert_eq!` when the index differs.
#[cfg(test)]
macro_rules! assert_eq_input_ref {
    ($e:expr, $index:expr) => {
        match $e {
            ExprImpl::InputRef(i) => assert_eq!(i.index(), $index),
            // Bind the value so `$e` is expanded only once, and fail with
            // `panic!` directly instead of asserting on a constant `false`.
            other => panic!("Expected input ref, found {:?}", other),
        }
    };
}
1109
1110#[cfg(test)]
1111pub(crate) use assert_eq_input_ref;
1112use risingwave_common::bail;
1113use risingwave_common::catalog::Schema;
1114use risingwave_common::row::OwnedRow;
1115
1116use crate::utils::Condition;
1117
#[cfg(test)]
mod tests {
    use super::*;

    /// The alternate `Debug` form (`{:#?}`) must print the verbose struct
    /// output, which includes the `return_type` field of the function call.
    #[test]
    fn test_expr_debug_alternate() {
        let input: ExprImpl = InputRef::new(1, DataType::Boolean).into();
        let negated: ExprImpl = FunctionCall::new(ExprType::Not, vec![input])
            .unwrap()
            .into();
        let rendered = format!("{:#?}", negated);
        assert!(rendered.contains("return_type: Boolean"))
    }
}
1129}