1use std::collections::VecDeque;
16
17use fixedbitset::FixedBitSet;
18use risingwave_common::types::{DataType, ScalarImpl};
19use risingwave_pb::expr::expr_node::Type;
20
21use super::now::RewriteNowToProcTime;
22use super::{Expr, ExprImpl, ExprRewriter, ExprVisitor, FunctionCall, InputRef};
23use crate::expr::ExprType;
24
25fn split_expr_by(expr: ExprImpl, op: ExprType, rets: &mut Vec<ExprImpl>) {
26 match expr {
27 ExprImpl::FunctionCall(func_call) if func_call.func_type() == op => {
28 let (_, exprs, _) = func_call.decompose();
29 for expr in exprs {
30 split_expr_by(expr, op, rets);
31 }
32 }
33 _ => rets.push(expr),
34 }
35}
36
37pub(super) fn merge_expr_by_logical<I>(exprs: I, op: ExprType, identity_elem: ExprImpl) -> ExprImpl
41where
42 I: IntoIterator<Item = ExprImpl>,
43{
44 let mut exprs: VecDeque<_> = exprs.into_iter().map(|e| (0usize, e)).collect();
45
46 while exprs.len() > 1 {
47 let (level, lhs) = exprs.pop_front().unwrap();
48 let rhs_level = exprs.front().unwrap().0;
49
50 if level < rhs_level {
52 exprs.push_back((level, lhs));
53 } else {
54 let rhs = exprs.pop_front().unwrap().1;
55 let new_expr = FunctionCall::new(op, vec![lhs, rhs]).unwrap().into();
56 exprs.push_back((level + 1, new_expr));
57 }
58 }
59
60 exprs.pop_front().map(|(_, e)| e).unwrap_or(identity_elem)
61}
62
63pub fn to_conjunctions(expr: ExprImpl) -> Vec<ExprImpl> {
66 let mut rets = vec![];
67 split_expr_by(expr, ExprType::And, &mut rets);
68 rets
69}
70
71pub fn to_disjunctions(expr: ExprImpl) -> Vec<ExprImpl> {
74 let mut rets = vec![];
75 split_expr_by(expr, ExprType::Or, &mut rets);
76 rets
77}
78
79pub fn fold_boolean_constant(expr: ExprImpl) -> ExprImpl {
82 let mut rewriter = BooleanConstantFolding {};
83 rewriter.rewrite_expr(expr)
84}
85
86pub fn column_self_eq_eliminate(expr: ExprImpl) -> ExprImpl {
88 ColumnSelfEqualRewriter::rewrite(expr)
89}
90
91pub struct ColumnSelfEqualRewriter {}
97
98impl ColumnSelfEqualRewriter {
99 fn extract_column(expr: ExprImpl, columns: &mut Vec<ExprImpl>) {
101 match expr.clone() {
102 ExprImpl::FunctionCall(func_call) => {
103 if Self::is_not_null(func_call.func_type()) {
105 return;
106 }
107 for sub_expr in func_call.inputs() {
108 Self::extract_column(sub_expr.clone(), columns);
109 }
110 }
111 ExprImpl::InputRef(_) if !columns.contains(&expr) => {
112 columns.push(expr);
114 }
115 _ => (),
116 }
117 }
118
119 fn is_not_null(func_type: ExprType) -> bool {
121 func_type == ExprType::IsNull
122 || func_type == ExprType::IsNotNull
123 || func_type == ExprType::IsTrue
124 || func_type == ExprType::IsFalse
125 || func_type == ExprType::IsNotTrue
126 || func_type == ExprType::IsNotFalse
127 }
128
129 pub fn rewrite(expr: ExprImpl) -> ExprImpl {
130 let mut columns = vec![];
131 Self::extract_column(expr.clone(), &mut columns);
132 if columns.len() > 1 {
133 return expr;
135 }
136
137 let ExprImpl::FunctionCall(func_call) = expr.clone() else {
139 return expr;
140 };
141 if func_call.func_type() != ExprType::Equal || func_call.inputs().len() != 2 {
142 return expr;
143 }
144 assert_eq!(func_call.return_type(), DataType::Boolean);
145 let inputs = func_call.inputs();
146 let e1 = inputs[0].clone();
147 let e2 = inputs[1].clone();
148
149 if e1 == e2 {
150 if columns.is_empty() {
151 return ExprImpl::literal_bool(true);
152 }
153 let Ok(ret) = FunctionCall::new(ExprType::IsNotNull, vec![columns[0].clone()]) else {
154 return expr;
155 };
156 ret.into()
157 } else {
158 expr
159 }
160 }
161}
162
163struct BooleanConstantFolding {}
165
166impl ExprRewriter for BooleanConstantFolding {
167 fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
173 let (func_type, inputs, ret) = func_call.decompose();
174 let inputs: Vec<_> = inputs
175 .into_iter()
176 .map(|expr| self.rewrite_expr(expr))
177 .collect();
178 let bool_constant_values: Vec<Option<bool>> =
179 inputs.iter().map(try_get_bool_constant).collect();
180 let contains_bool_constant = bool_constant_values.iter().any(|x| x.is_some());
181 let prepare_binary_function_inputs = |mut inputs: Vec<ExprImpl>| -> (ExprImpl, ExprImpl) {
183 assert_eq!(inputs.len(), 2);
185 let rhs = inputs.pop().unwrap();
186 let lhs = inputs.pop().unwrap();
187 if bool_constant_values[0].is_some() {
188 (lhs, rhs)
189 } else {
190 (rhs, lhs)
191 }
192 };
193 match func_type {
194 Type::Not => {
196 let input = inputs.first().unwrap();
197 if let Some(v) = try_get_bool_constant(input) {
198 return ExprImpl::literal_bool(!v);
199 }
200 }
201 Type::IsFalse => {
202 let input = inputs.first().unwrap();
203 if input.is_null() {
204 return ExprImpl::literal_bool(false);
205 }
206 if let Some(v) = try_get_bool_constant(input) {
207 return ExprImpl::literal_bool(!v);
208 }
209 }
210 Type::IsTrue => {
211 let input = inputs.first().unwrap();
212 if input.is_null() {
213 return ExprImpl::literal_bool(false);
214 }
215 if let Some(v) = try_get_bool_constant(input) {
216 return ExprImpl::literal_bool(v);
217 }
218 }
219 Type::IsNull => {
220 let input = inputs.first().unwrap();
221 if input.is_null() {
222 return ExprImpl::literal_bool(true);
223 }
224 }
225 Type::IsNotTrue => {
226 let input = inputs.first().unwrap();
227 if input.is_null() {
228 return ExprImpl::literal_bool(true);
229 }
230 if let Some(v) = try_get_bool_constant(input) {
231 return ExprImpl::literal_bool(!v);
232 }
233 }
234 Type::IsNotFalse => {
235 let input = inputs.first().unwrap();
236 if input.is_null() {
237 return ExprImpl::literal_bool(true);
238 }
239 if let Some(v) = try_get_bool_constant(input) {
240 return ExprImpl::literal_bool(v);
241 }
242 }
243 Type::IsNotNull => {
244 let input = inputs.first().unwrap();
245 if let ExprImpl::Literal(lit) = input {
246 return ExprImpl::literal_bool(lit.get_data().is_some());
247 }
248 }
249 Type::And if contains_bool_constant => {
251 let (constant_lhs, rhs) = prepare_binary_function_inputs(inputs);
252 return boolean_constant_fold_and(constant_lhs, rhs);
253 }
254 Type::Or if contains_bool_constant => {
255 let (constant_lhs, rhs) = prepare_binary_function_inputs(inputs);
256 return boolean_constant_fold_or(constant_lhs, rhs);
257 }
258 _ => {}
259 }
260 FunctionCall::new_unchecked(func_type, inputs, ret).into()
261 }
262}
263
264pub fn try_get_bool_constant(expr: &ExprImpl) -> Option<bool> {
268 if let ExprImpl::Literal(l) = expr
269 && let Some(ScalarImpl::Bool(v)) = l.get_data()
270 {
271 return Some(*v);
272 }
273 None
274}
275
276fn boolean_constant_fold_and(constant_lhs: ExprImpl, rhs: ExprImpl) -> ExprImpl {
279 if try_get_bool_constant(&constant_lhs).unwrap() {
280 rhs
282 } else {
283 constant_lhs
285 }
286}
287
288fn boolean_constant_fold_or(constant_lhs: ExprImpl, rhs: ExprImpl) -> ExprImpl {
291 if try_get_bool_constant(&constant_lhs).unwrap() {
292 constant_lhs
294 } else {
295 rhs
297 }
298}
299
300pub fn push_down_not(expr: ExprImpl) -> ExprImpl {
303 let mut not_push_down = NotPushDown {};
304 not_push_down.rewrite_expr(expr)
305}
306
307struct NotPushDown {}
308
309impl ExprRewriter for NotPushDown {
310 fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
313 let (func_type, mut inputs, ret) = func_call.decompose();
314
315 if func_type != Type::Not {
316 let inputs = inputs
317 .into_iter()
318 .map(|expr| self.rewrite_expr(expr))
319 .collect();
320 FunctionCall::new_unchecked(func_type, inputs, ret).into()
321 } else {
322 assert_eq!(inputs.len(), 1);
326
327 let input = inputs.pop().unwrap();
328 let rewritten_not_expr = match input {
329 ExprImpl::FunctionCall(func) => {
330 let (func_type, mut inputs, ret) = func.decompose();
331 match func_type {
332 Type::Not => {
334 assert_eq!(inputs.len(), 1);
336 Ok(inputs.pop().unwrap())
337 }
338 Type::And => {
340 assert_eq!(inputs.len(), 2);
342 let rhs = inputs.pop().unwrap();
343 let lhs = inputs.pop().unwrap();
344 let rhs_not: ExprImpl =
345 FunctionCall::new(Type::Not, vec![rhs]).unwrap().into();
346 let lhs_not: ExprImpl =
347 FunctionCall::new(Type::Not, vec![lhs]).unwrap().into();
348 Ok(FunctionCall::new(Type::Or, vec![lhs_not, rhs_not])
349 .unwrap()
350 .into())
351 }
352 Type::Or => {
353 assert_eq!(inputs.len(), 2);
355 let rhs = inputs.pop().unwrap();
356 let lhs = inputs.pop().unwrap();
357 let rhs_not: ExprImpl =
358 FunctionCall::new(Type::Not, vec![rhs]).unwrap().into();
359 let lhs_not: ExprImpl =
360 FunctionCall::new(Type::Not, vec![lhs]).unwrap().into();
361 Ok(FunctionCall::new(Type::And, vec![lhs_not, rhs_not])
362 .unwrap()
363 .into())
364 }
365 _ => Err(FunctionCall::new_unchecked(func_type, inputs, ret).into()),
366 }
367 }
368 _ => Err(input),
369 };
370 match rewritten_not_expr {
371 Ok(res) => self.rewrite_expr(res),
373 Err(input) => FunctionCall::new(Type::Not, vec![self.rewrite_expr(input)])
376 .unwrap()
377 .into(),
378 }
379 }
380 }
381}
382
383pub fn factorization_expr(expr: ExprImpl) -> Vec<ExprImpl> {
384 let disjunctions: Vec<ExprImpl> = to_disjunctions(expr);
386
387 if disjunctions.len() == 1 {
389 return disjunctions;
390 }
391
392 let mut disjunctions: Vec<Vec<_>> = disjunctions
395 .into_iter()
396 .map(|x| to_conjunctions(x).into_iter().collect())
397 .collect();
398 let (last, remaining) = disjunctions.split_last_mut().unwrap();
399 let greatest_common_divider: Vec<_> = last
401 .extract_if(.., |factor| {
402 remaining.iter().all(|expr| expr.contains(factor))
403 })
404 .collect();
405 for disjunction in remaining {
406 disjunction.retain(|factor| !greatest_common_divider.contains(factor));
408 }
409 let remaining = ExprImpl::or(disjunctions.into_iter().map(ExprImpl::and));
411 greatest_common_divider
414 .into_iter()
415 .chain(std::iter::once(remaining))
416 .map(fold_boolean_constant)
417 .collect()
418}
419
/// Expands to `let _ = $e.collect_input_refs($n);`, discarding the result.
///
/// NOTE(review): the "assert" presumably comes from `collect_input_refs`
/// panicking when `$e` references a column index `>= $n`. That method is not
/// defined here, so confirm.
macro_rules! assert_input_ref {
    ($e:expr, $n:expr) => {
        let _ = $e.collect_input_refs($n);
    };
}
pub(crate) use assert_input_ref;
428
429#[derive(Clone)]
434pub struct CollectInputRef {
435 input_bits: FixedBitSet,
437}
438
439impl ExprVisitor for CollectInputRef {
440 fn visit_input_ref(&mut self, expr: &InputRef) {
441 self.input_bits.insert(expr.index());
442 }
443}
444
445impl CollectInputRef {
446 pub fn new(initial_input_bits: FixedBitSet) -> Self {
448 CollectInputRef {
449 input_bits: initial_input_bits,
450 }
451 }
452
453 pub fn with_capacity(capacity: usize) -> Self {
455 CollectInputRef {
456 input_bits: FixedBitSet::with_capacity(capacity),
457 }
458 }
459}
460
461impl From<CollectInputRef> for FixedBitSet {
462 fn from(s: CollectInputRef) -> Self {
463 s.input_bits
464 }
465}
466
467impl Extend<usize> for CollectInputRef {
468 fn extend<T: IntoIterator<Item = usize>>(&mut self, iter: T) {
469 self.input_bits.extend(iter);
470 }
471}
472
473pub fn collect_input_refs<'a>(
478 input_col_num: usize,
479 exprs: impl IntoIterator<Item = &'a ExprImpl>,
480) -> FixedBitSet {
481 let mut input_ref_collector = CollectInputRef::with_capacity(input_col_num);
482 for expr in exprs {
483 input_ref_collector.visit_expr(expr);
484 }
485 input_ref_collector.into()
486}
487
488#[derive(Clone, Default)]
490pub struct CountNow {
491 count: usize,
492}
493
494impl CountNow {
495 pub fn count(&self) -> usize {
496 self.count
497 }
498}
499
500impl ExprVisitor for CountNow {
501 fn visit_now(&mut self, _: &super::Now) {
502 self.count += 1;
503 }
504}
505
506pub fn rewrite_now_to_proctime(expr: ExprImpl) -> ExprImpl {
507 let mut r = RewriteNowToProcTime;
508 r.rewrite_expr(expr)
509}
510
#[cfg(test)]
mod tests {
    use risingwave_common::types::{DataType, ScalarImpl};
    use risingwave_pb::expr::expr_node::Type;

    use super::{fold_boolean_constant, push_down_not};
    use crate::expr::{ExprImpl, FunctionCall, InputRef};

    /// Boolean reference to input column `index`.
    fn bool_col(index: usize) -> ExprImpl {
        InputRef::new(index, DataType::Boolean).into()
    }

    /// Type-checked function call; panics if the arguments do not type-check.
    fn call(func_type: Type, args: Vec<ExprImpl>) -> ExprImpl {
        FunctionCall::new(func_type, args).unwrap().into()
    }

    /// Asserts that `expr` is a reference to input column `index`.
    fn check_input_ref(expr: &ExprImpl, index: usize) {
        assert!(expr.as_input_ref().is_some());
        assert_eq!(expr.as_input_ref().unwrap().index(), index);
    }

    /// Asserts that `expr` is the boolean literal `value`.
    fn check_bool_literal(expr: &ExprImpl, value: bool) {
        assert!(expr.as_literal().is_some());
        let literal = expr.as_literal().unwrap();
        assert_eq!(*literal.get_data(), Some(ScalarImpl::Bool(value)));
    }

    /// Splits a binary function call into `(type, lhs, rhs)`.
    fn binary_parts(expr: &ExprImpl) -> (Type, ExprImpl, ExprImpl) {
        assert!(expr.as_function_call().is_some());
        expr.as_function_call().unwrap().clone().decompose_as_binary()
    }

    /// Splits a unary function call into `(type, operand)`.
    fn unary_parts(expr: &ExprImpl) -> (Type, ExprImpl) {
        assert!(expr.as_function_call().is_some());
        expr.as_function_call().unwrap().clone().decompose_as_unary()
    }

    #[test]
    fn constant_boolean_folding_basic_and() {
        // `$0 AND TRUE` keeps only the column.
        let res = fold_boolean_constant(call(
            Type::And,
            vec![bool_col(0), ExprImpl::literal_bool(true)],
        ));
        check_input_ref(&res, 0);

        // `$0 AND FALSE` is FALSE.
        let res = fold_boolean_constant(call(
            Type::And,
            vec![bool_col(0), ExprImpl::literal_bool(false)],
        ));
        check_bool_literal(&res, false);
    }

    #[test]
    fn constant_boolean_folding_basic_or() {
        // `$0 OR TRUE` is TRUE.
        let res = fold_boolean_constant(call(
            Type::Or,
            vec![bool_col(0), ExprImpl::literal_bool(true)],
        ));
        check_bool_literal(&res, true);

        // `$0 OR FALSE` keeps only the column.
        let res = fold_boolean_constant(call(
            Type::Or,
            vec![bool_col(0), ExprImpl::literal_bool(false)],
        ));
        check_input_ref(&res, 0);
    }

    #[test]
    fn constant_boolean_folding_complex() {
        // `(FALSE AND TRUE) AND (TRUE OR 1 = 2)` folds to FALSE.
        let one_eq_two = call(
            Type::Equal,
            vec![ExprImpl::literal_int(1), ExprImpl::literal_int(2)],
        );
        let lhs = call(
            Type::And,
            vec![ExprImpl::literal_bool(false), ExprImpl::literal_bool(true)],
        );
        let rhs = call(Type::Or, vec![ExprImpl::literal_bool(true), one_eq_two]);

        let res = fold_boolean_constant(call(Type::And, vec![lhs, rhs]));
        check_bool_literal(&res, false);
    }

    #[test]
    fn not_push_down_test() {
        // `NOT NOT $0` becomes `$0`.
        let res = push_down_not(call(Type::Not, vec![call(Type::Not, vec![bool_col(0)])]));
        assert!(res.as_input_ref().is_some());

        // `NOT ($0 AND NOT $1)` becomes `NOT $0 OR $1`.
        let conjunction = call(
            Type::And,
            vec![bool_col(0), call(Type::Not, vec![bool_col(1)])],
        );
        let res = push_down_not(call(Type::Not, vec![conjunction]));
        let (func, lhs, rhs) = binary_parts(&res);
        assert_eq!(func, Type::Or);
        assert!(rhs.as_input_ref().is_some());
        let (func, input) = unary_parts(&lhs);
        assert_eq!(func, Type::Not);
        assert!(input.as_input_ref().is_some());

        // `NOT ($0 OR $1)` becomes `NOT $0 AND NOT $1`.
        let disjunction = call(Type::Or, vec![bool_col(0), bool_col(1)]);
        let res = push_down_not(call(Type::Not, vec![disjunction]));
        let (func_type, lhs, rhs) = binary_parts(&res);
        assert_eq!(func_type, Type::And);
        for side in [lhs, rhs] {
            let (side_type, side_input) = unary_parts(&side);
            assert_eq!(side_type, Type::Not);
            assert!(side_input.as_input_ref().is_some());
        }
    }
}