1use std::collections::VecDeque;
16
17use fixedbitset::FixedBitSet;
18use risingwave_common::types::{DataType, ScalarImpl};
19use risingwave_pb::expr::expr_node::Type;
20
21use super::now::RewriteNowToProcTime;
22use super::{Expr, ExprImpl, ExprRewriter, ExprVisitor, FunctionCall, InputRef};
23use crate::expr::ExprType;
24
25fn split_expr_by(expr: ExprImpl, op: ExprType, rets: &mut Vec<ExprImpl>) {
26 match expr {
27 ExprImpl::FunctionCall(func_call) if func_call.func_type() == op => {
28 let (_, exprs, _) = func_call.decompose();
29 for expr in exprs {
30 split_expr_by(expr, op, rets);
31 }
32 }
33 _ => rets.push(expr),
34 }
35}
36
37pub(super) fn merge_expr_by_logical<I>(exprs: I, op: ExprType, identity_elem: ExprImpl) -> ExprImpl
41where
42 I: IntoIterator<Item = ExprImpl>,
43{
44 let mut exprs: VecDeque<_> = exprs.into_iter().map(|e| (0usize, e)).collect();
45
46 while exprs.len() > 1 {
47 let (level, lhs) = exprs.pop_front().unwrap();
48 let rhs_level = exprs.front().unwrap().0;
49
50 if level < rhs_level {
52 exprs.push_back((level, lhs));
53 } else {
54 let rhs = exprs.pop_front().unwrap().1;
55 let new_expr = FunctionCall::new(op, vec![lhs, rhs]).unwrap().into();
56 exprs.push_back((level + 1, new_expr));
57 }
58 }
59
60 exprs.pop_front().map(|(_, e)| e).unwrap_or(identity_elem)
61}
62
63pub fn to_conjunctions(expr: ExprImpl) -> Vec<ExprImpl> {
66 let mut rets = vec![];
67 split_expr_by(expr, ExprType::And, &mut rets);
68 rets
69}
70
71pub fn to_disjunctions(expr: ExprImpl) -> Vec<ExprImpl> {
74 let mut rets = vec![];
75 split_expr_by(expr, ExprType::Or, &mut rets);
76 rets
77}
78
79pub fn fold_boolean_constant(expr: ExprImpl) -> ExprImpl {
82 let mut rewriter = BooleanConstantFolding {};
83 rewriter.rewrite_expr(expr)
84}
85
86pub fn column_self_eq_eliminate(expr: ExprImpl) -> ExprImpl {
88 ColumnSelfEqualRewriter::rewrite(expr)
89}
90
91pub struct ColumnSelfEqualRewriter {}
97
98impl ColumnSelfEqualRewriter {
99 fn extract_column(expr: ExprImpl, columns: &mut Vec<ExprImpl>) {
101 match expr.clone() {
102 ExprImpl::FunctionCall(func_call) => {
103 if Self::is_not_null(func_call.func_type()) {
105 return;
106 }
107 for sub_expr in func_call.inputs() {
108 Self::extract_column(sub_expr.clone(), columns);
109 }
110 }
111 ExprImpl::InputRef(_) => {
112 if !columns.contains(&expr) {
113 columns.push(expr);
115 }
116 }
117 _ => (),
118 }
119 }
120
121 fn is_not_null(func_type: ExprType) -> bool {
123 func_type == ExprType::IsNull
124 || func_type == ExprType::IsNotNull
125 || func_type == ExprType::IsTrue
126 || func_type == ExprType::IsFalse
127 || func_type == ExprType::IsNotTrue
128 || func_type == ExprType::IsNotFalse
129 }
130
131 pub fn rewrite(expr: ExprImpl) -> ExprImpl {
132 let mut columns = vec![];
133 Self::extract_column(expr.clone(), &mut columns);
134 if columns.len() > 1 {
135 return expr;
137 }
138
139 let ExprImpl::FunctionCall(func_call) = expr.clone() else {
141 return expr;
142 };
143 if func_call.func_type() != ExprType::Equal || func_call.inputs().len() != 2 {
144 return expr;
145 }
146 assert_eq!(func_call.return_type(), DataType::Boolean);
147 let inputs = func_call.inputs();
148 let e1 = inputs[0].clone();
149 let e2 = inputs[1].clone();
150
151 if e1 == e2 {
152 if columns.is_empty() {
153 return ExprImpl::literal_bool(true);
154 }
155 let Ok(ret) = FunctionCall::new(ExprType::IsNotNull, vec![columns[0].clone()]) else {
156 return expr;
157 };
158 ret.into()
159 } else {
160 expr
161 }
162 }
163}
164
165struct BooleanConstantFolding {}
167
168impl ExprRewriter for BooleanConstantFolding {
169 fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
175 let (func_type, inputs, ret) = func_call.decompose();
176 let inputs: Vec<_> = inputs
177 .into_iter()
178 .map(|expr| self.rewrite_expr(expr))
179 .collect();
180 let bool_constant_values: Vec<Option<bool>> =
181 inputs.iter().map(try_get_bool_constant).collect();
182 let contains_bool_constant = bool_constant_values.iter().any(|x| x.is_some());
183 let prepare_binary_function_inputs = |mut inputs: Vec<ExprImpl>| -> (ExprImpl, ExprImpl) {
185 assert_eq!(inputs.len(), 2);
187 let rhs = inputs.pop().unwrap();
188 let lhs = inputs.pop().unwrap();
189 if bool_constant_values[0].is_some() {
190 (lhs, rhs)
191 } else {
192 (rhs, lhs)
193 }
194 };
195 match func_type {
196 Type::Not => {
198 let input = inputs.first().unwrap();
199 if let Some(v) = try_get_bool_constant(input) {
200 return ExprImpl::literal_bool(!v);
201 }
202 }
203 Type::IsFalse => {
204 let input = inputs.first().unwrap();
205 if input.is_null() {
206 return ExprImpl::literal_bool(false);
207 }
208 if let Some(v) = try_get_bool_constant(input) {
209 return ExprImpl::literal_bool(!v);
210 }
211 }
212 Type::IsTrue => {
213 let input = inputs.first().unwrap();
214 if input.is_null() {
215 return ExprImpl::literal_bool(false);
216 }
217 if let Some(v) = try_get_bool_constant(input) {
218 return ExprImpl::literal_bool(v);
219 }
220 }
221 Type::IsNull => {
222 let input = inputs.first().unwrap();
223 if input.is_null() {
224 return ExprImpl::literal_bool(true);
225 }
226 }
227 Type::IsNotTrue => {
228 let input = inputs.first().unwrap();
229 if input.is_null() {
230 return ExprImpl::literal_bool(true);
231 }
232 if let Some(v) = try_get_bool_constant(input) {
233 return ExprImpl::literal_bool(!v);
234 }
235 }
236 Type::IsNotFalse => {
237 let input = inputs.first().unwrap();
238 if input.is_null() {
239 return ExprImpl::literal_bool(true);
240 }
241 if let Some(v) = try_get_bool_constant(input) {
242 return ExprImpl::literal_bool(v);
243 }
244 }
245 Type::IsNotNull => {
246 let input = inputs.first().unwrap();
247 if let ExprImpl::Literal(lit) = input {
248 return ExprImpl::literal_bool(lit.get_data().is_some());
249 }
250 }
251 Type::And if contains_bool_constant => {
253 let (constant_lhs, rhs) = prepare_binary_function_inputs(inputs);
254 return boolean_constant_fold_and(constant_lhs, rhs);
255 }
256 Type::Or if contains_bool_constant => {
257 let (constant_lhs, rhs) = prepare_binary_function_inputs(inputs);
258 return boolean_constant_fold_or(constant_lhs, rhs);
259 }
260 _ => {}
261 }
262 FunctionCall::new_unchecked(func_type, inputs, ret).into()
263 }
264}
265
266pub fn try_get_bool_constant(expr: &ExprImpl) -> Option<bool> {
270 if let ExprImpl::Literal(l) = expr {
271 if let Some(ScalarImpl::Bool(v)) = l.get_data() {
272 return Some(*v);
273 }
274 }
275 None
276}
277
278fn boolean_constant_fold_and(constant_lhs: ExprImpl, rhs: ExprImpl) -> ExprImpl {
281 if try_get_bool_constant(&constant_lhs).unwrap() {
282 rhs
284 } else {
285 constant_lhs
287 }
288}
289
290fn boolean_constant_fold_or(constant_lhs: ExprImpl, rhs: ExprImpl) -> ExprImpl {
293 if try_get_bool_constant(&constant_lhs).unwrap() {
294 constant_lhs
296 } else {
297 rhs
299 }
300}
301
302pub fn push_down_not(expr: ExprImpl) -> ExprImpl {
305 let mut not_push_down = NotPushDown {};
306 not_push_down.rewrite_expr(expr)
307}
308
309struct NotPushDown {}
310
311impl ExprRewriter for NotPushDown {
312 fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
315 let (func_type, mut inputs, ret) = func_call.decompose();
316
317 if func_type != Type::Not {
318 let inputs = inputs
319 .into_iter()
320 .map(|expr| self.rewrite_expr(expr))
321 .collect();
322 FunctionCall::new_unchecked(func_type, inputs, ret).into()
323 } else {
324 assert_eq!(inputs.len(), 1);
328
329 let input = inputs.pop().unwrap();
330 let rewritten_not_expr = match input {
331 ExprImpl::FunctionCall(func) => {
332 let (func_type, mut inputs, ret) = func.decompose();
333 match func_type {
334 Type::Not => {
336 assert_eq!(inputs.len(), 1);
338 Ok(inputs.pop().unwrap())
339 }
340 Type::And => {
342 assert_eq!(inputs.len(), 2);
344 let rhs = inputs.pop().unwrap();
345 let lhs = inputs.pop().unwrap();
346 let rhs_not: ExprImpl =
347 FunctionCall::new(Type::Not, vec![rhs]).unwrap().into();
348 let lhs_not: ExprImpl =
349 FunctionCall::new(Type::Not, vec![lhs]).unwrap().into();
350 Ok(FunctionCall::new(Type::Or, vec![lhs_not, rhs_not])
351 .unwrap()
352 .into())
353 }
354 Type::Or => {
355 assert_eq!(inputs.len(), 2);
357 let rhs = inputs.pop().unwrap();
358 let lhs = inputs.pop().unwrap();
359 let rhs_not: ExprImpl =
360 FunctionCall::new(Type::Not, vec![rhs]).unwrap().into();
361 let lhs_not: ExprImpl =
362 FunctionCall::new(Type::Not, vec![lhs]).unwrap().into();
363 Ok(FunctionCall::new(Type::And, vec![lhs_not, rhs_not])
364 .unwrap()
365 .into())
366 }
367 _ => Err(FunctionCall::new_unchecked(func_type, inputs, ret).into()),
368 }
369 }
370 _ => Err(input),
371 };
372 match rewritten_not_expr {
373 Ok(res) => self.rewrite_expr(res),
375 Err(input) => FunctionCall::new(Type::Not, vec![self.rewrite_expr(input)])
378 .unwrap()
379 .into(),
380 }
381 }
382 }
383}
384
385pub fn factorization_expr(expr: ExprImpl) -> Vec<ExprImpl> {
386 let disjunctions: Vec<ExprImpl> = to_disjunctions(expr);
388
389 if disjunctions.len() == 1 {
391 return disjunctions;
392 }
393
394 let mut disjunctions: Vec<Vec<_>> = disjunctions
397 .into_iter()
398 .map(|x| to_conjunctions(x).into_iter().collect())
399 .collect();
400 let (last, remaining) = disjunctions.split_last_mut().unwrap();
401 let greatest_common_divider: Vec<_> = last
403 .extract_if(.., |factor| {
404 remaining.iter().all(|expr| expr.contains(factor))
405 })
406 .collect();
407 for disjunction in remaining {
408 disjunction.retain(|factor| !greatest_common_divider.contains(factor));
410 }
411 let remaining = ExprImpl::or(disjunctions.into_iter().map(ExprImpl::and));
413 greatest_common_divider
416 .into_iter()
417 .chain(std::iter::once(remaining))
418 .map(fold_boolean_constant)
419 .collect()
420}
421
/// Validates the input refs of `$expr` against an input of `$input_col_num` columns by calling
/// `$expr.collect_input_refs($input_col_num)` and discarding the result.
///
/// NOTE(review): the check presumably relies on `collect_input_refs` panicking for an
/// `InputRef` index >= `$input_col_num`. The free function `collect_input_refs` in this file
/// inserts into a `FixedBitSet` of that capacity. Confirm the method behaves the same way.
/// The expansion is a `let` statement, so the macro must be used in statement position.
macro_rules! assert_input_ref {
    ($expr:expr, $input_col_num:expr) => {
        let _ = $expr.collect_input_refs($input_col_num);
    };
}
// Re-export so the macro can be used by path elsewhere in the crate.
pub(crate) use assert_input_ref;
430
431#[derive(Clone)]
436pub struct CollectInputRef {
437 input_bits: FixedBitSet,
439}
440
441impl ExprVisitor for CollectInputRef {
442 fn visit_input_ref(&mut self, expr: &InputRef) {
443 self.input_bits.insert(expr.index());
444 }
445}
446
447impl CollectInputRef {
448 pub fn new(initial_input_bits: FixedBitSet) -> Self {
450 CollectInputRef {
451 input_bits: initial_input_bits,
452 }
453 }
454
455 pub fn with_capacity(capacity: usize) -> Self {
457 CollectInputRef {
458 input_bits: FixedBitSet::with_capacity(capacity),
459 }
460 }
461}
462
463impl From<CollectInputRef> for FixedBitSet {
464 fn from(s: CollectInputRef) -> Self {
465 s.input_bits
466 }
467}
468
469impl Extend<usize> for CollectInputRef {
470 fn extend<T: IntoIterator<Item = usize>>(&mut self, iter: T) {
471 self.input_bits.extend(iter);
472 }
473}
474
475pub fn collect_input_refs<'a>(
480 input_col_num: usize,
481 exprs: impl IntoIterator<Item = &'a ExprImpl>,
482) -> FixedBitSet {
483 let mut input_ref_collector = CollectInputRef::with_capacity(input_col_num);
484 for expr in exprs {
485 input_ref_collector.visit_expr(expr);
486 }
487 input_ref_collector.into()
488}
489
490#[derive(Clone, Default)]
492pub struct CountNow {
493 count: usize,
494}
495
496impl CountNow {
497 pub fn count(&self) -> usize {
498 self.count
499 }
500}
501
502impl ExprVisitor for CountNow {
503 fn visit_now(&mut self, _: &super::Now) {
504 self.count += 1;
505 }
506}
507
508pub fn rewrite_now_to_proctime(expr: ExprImpl) -> ExprImpl {
509 let mut r = RewriteNowToProcTime;
510 r.rewrite_expr(expr)
511}
512
#[cfg(test)]
mod tests {
    use risingwave_common::types::{DataType, ScalarImpl};
    use risingwave_pb::expr::expr_node::Type;

    use super::{fold_boolean_constant, push_down_not};
    use crate::expr::{ExprImpl, FunctionCall, InputRef};

    /// Boolean column reference with index `idx`.
    fn bool_col(idx: usize) -> ExprImpl {
        InputRef::new(idx, DataType::Boolean).into()
    }

    /// Type-checked function call; panics if it does not type-check.
    fn call(func_type: Type, inputs: Vec<ExprImpl>) -> ExprImpl {
        FunctionCall::new(func_type, inputs).unwrap().into()
    }

    /// Asserts `expr` is a literal holding the boolean `expected`.
    fn assert_bool_literal(expr: &ExprImpl, expected: bool) {
        let lit = expr.as_literal().expect("expected a literal");
        assert_eq!(*lit.get_data(), Some(ScalarImpl::Bool(expected)));
    }

    /// Asserts `expr` is an input ref with index `idx`.
    fn assert_input_ref_index(expr: &ExprImpl, idx: usize) {
        let input_ref = expr.as_input_ref().expect("expected an input ref");
        assert_eq!(input_ref.index(), idx);
    }

    /// Asserts `expr` is `NOT <input ref>`.
    fn assert_negated_input_ref(expr: &ExprImpl) {
        let (func_type, operand) = expr
            .as_function_call()
            .expect("expected a function call")
            .clone()
            .decompose_as_unary();
        assert_eq!(func_type, Type::Not);
        assert!(operand.as_input_ref().is_some());
    }

    #[test]
    fn constant_boolean_folding_basic_and() {
        let res = fold_boolean_constant(call(
            Type::And,
            vec![bool_col(0), ExprImpl::literal_bool(true)],
        ));
        assert_input_ref_index(&res, 0);

        let res = fold_boolean_constant(call(
            Type::And,
            vec![bool_col(0), ExprImpl::literal_bool(false)],
        ));
        assert_bool_literal(&res, false);
    }

    #[test]
    fn constant_boolean_folding_basic_or() {
        let res = fold_boolean_constant(call(
            Type::Or,
            vec![bool_col(0), ExprImpl::literal_bool(true)],
        ));
        assert_bool_literal(&res, true);

        let res = fold_boolean_constant(call(
            Type::Or,
            vec![bool_col(0), ExprImpl::literal_bool(false)],
        ));
        assert_input_ref_index(&res, 0);
    }

    #[test]
    fn constant_boolean_folding_complex() {
        // (false AND true) AND (true OR 1 = 2)
        let lhs = call(
            Type::And,
            vec![ExprImpl::literal_bool(false), ExprImpl::literal_bool(true)],
        );
        let one_eq_two = call(
            Type::Equal,
            vec![ExprImpl::literal_int(1), ExprImpl::literal_int(2)],
        );
        let rhs = call(Type::Or, vec![ExprImpl::literal_bool(true), one_eq_two]);

        let res = fold_boolean_constant(call(Type::And, vec![lhs, rhs]));
        assert_bool_literal(&res, false);
    }

    #[test]
    fn not_push_down_test() {
        // NOT NOT $0  =>  $0
        let res = push_down_not(call(Type::Not, vec![call(Type::Not, vec![bool_col(0)])]));
        assert!(res.as_input_ref().is_some());

        // NOT ($0 AND NOT $1)  =>  NOT $0 OR $1
        let inner_and = call(Type::And, vec![bool_col(0), call(Type::Not, vec![bool_col(1)])]);
        let res = push_down_not(call(Type::Not, vec![inner_and]));
        let (func_type, lhs, rhs) = res
            .as_function_call()
            .expect("expected a function call")
            .clone()
            .decompose_as_binary();
        assert_eq!(func_type, Type::Or);
        assert!(rhs.as_input_ref().is_some());
        assert_negated_input_ref(&lhs);

        // NOT ($0 OR $1)  =>  NOT $0 AND NOT $1
        let inner_or = call(Type::Or, vec![bool_col(0), bool_col(1)]);
        let res = push_down_not(call(Type::Not, vec![inner_or]));
        let (func_type, lhs, rhs) = res
            .as_function_call()
            .expect("expected a function call")
            .clone()
            .decompose_as_binary();
        assert_eq!(func_type, Type::And);
        assert_negated_input_ref(&lhs);
        assert_negated_input_ref(&rhs);
    }
}