1use std::cmp::Ordering;
16use std::collections::{BTreeMap, HashSet};
17use std::fmt::{self, Debug};
18use std::ops::Bound;
19use std::rc::Rc;
20use std::sync::LazyLock;
21
22use fixedbitset::FixedBitSet;
23use itertools::Itertools;
24use risingwave_common::catalog::{Schema, TableDesc};
25use risingwave_common::types::{DataType, DefaultOrd, ScalarImpl};
26use risingwave_common::util::iter_util::ZipEqFast;
27use risingwave_common::util::scan_range::{ScanRange, is_full_range};
28use risingwave_common::util::sort_util::{OrderType, cmp_rows};
29
30use crate::error::Result;
31use crate::expr::{
32 ExprDisplay, ExprImpl, ExprMutator, ExprRewriter, ExprType, ExprVisitor, FunctionCall,
33 InequalityInputPair, InputRef, collect_input_refs, column_self_eq_eliminate,
34 factorization_expr, fold_boolean_constant, push_down_not, to_conjunctions,
35 try_get_bool_constant,
36};
37use crate::utils::condition::cast_compare::{ResultForCmp, ResultForEq};
38
/// A boolean predicate represented as the logical AND of a list of terms.
///
/// The condition holds when every expression in `conjunctions` evaluates to
/// true. An empty list represents the constant `true`
/// (see [`Condition::always_true`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Condition {
    /// The AND-ed terms of the predicate.
    pub conjunctions: Vec<ExprImpl>,
}
44
45impl IntoIterator for Condition {
46 type IntoIter = std::vec::IntoIter<ExprImpl>;
47 type Item = ExprImpl;
48
49 fn into_iter(self) -> Self::IntoIter {
50 self.conjunctions.into_iter()
51 }
52}
53
54impl fmt::Display for Condition {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 let mut conjunctions = self.conjunctions.iter();
57 if let Some(expr) = conjunctions.next() {
58 write!(f, "{:?}", expr)?;
59 }
60 if self.always_true() {
61 write!(f, "true")?;
62 } else {
63 for expr in conjunctions {
64 write!(f, " AND {:?}", expr)?;
65 }
66 }
67 Ok(())
68 }
69}
70
71impl Condition {
72 pub fn with_expr(expr: ExprImpl) -> Self {
73 let conjunctions = to_conjunctions(expr);
74
75 Self { conjunctions }.simplify()
76 }
77
78 pub fn true_cond() -> Self {
79 Self {
80 conjunctions: vec![],
81 }
82 }
83
84 pub fn false_cond() -> Self {
85 Self {
86 conjunctions: vec![ExprImpl::literal_bool(false)],
87 }
88 }
89
90 pub fn always_true(&self) -> bool {
91 self.conjunctions.is_empty()
92 }
93
94 pub fn always_false(&self) -> bool {
95 static FALSE: LazyLock<ExprImpl> = LazyLock::new(|| ExprImpl::literal_bool(false));
96 !self.conjunctions.is_empty() && self.conjunctions.contains(&*FALSE)
98 }
99
100 pub fn as_expr_unless_true(&self) -> Option<ExprImpl> {
102 if self.always_true() {
103 None
104 } else {
105 Some(self.clone().into())
106 }
107 }
108
109 #[must_use]
110 pub fn and(self, other: Self) -> Self {
111 let mut ret = self;
112 ret.conjunctions.extend(other.conjunctions);
113 ret.simplify()
114 }
115
116 #[must_use]
117 pub fn or(self, other: Self) -> Self {
118 let or_expr = ExprImpl::FunctionCall(
119 FunctionCall::new_unchecked(
120 ExprType::Or,
121 vec![self.into(), other.into()],
122 DataType::Boolean,
123 )
124 .into(),
125 );
126 let ret = Self::with_expr(or_expr);
127 ret.simplify()
128 }
129
130 #[must_use]
132 pub fn split(self, left_col_num: usize, right_col_num: usize) -> (Self, Self, Self) {
133 let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
134 let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
135
136 self.group_by::<_, 3>(|expr| {
137 let input_bits = expr.collect_input_refs(left_col_num + right_col_num);
138 if input_bits.is_subset(&left_bit_map) {
139 0
140 } else if input_bits.is_subset(&right_bit_map) {
141 1
142 } else {
143 2
144 }
145 })
146 .into_iter()
147 .next_tuple()
148 .unwrap()
149 }
150
151 pub fn collect_input_refs(&self, input_col_num: usize) -> FixedBitSet {
156 collect_input_refs(input_col_num, &self.conjunctions)
157 }
158
159 #[must_use]
173 pub fn split_by_input_col_nums(
174 self,
175 input_col_nums: &[usize],
176 only_eq: bool,
177 ) -> (BTreeMap<(usize, usize), Self>, Self) {
178 let mut bitmaps = Vec::with_capacity(input_col_nums.len());
179 let mut cols_seen = 0;
180 for cols in input_col_nums {
181 bitmaps.push(FixedBitSet::from_iter(cols_seen..cols_seen + cols));
182 cols_seen += cols;
183 }
184
185 let mut pairwise_conditions = BTreeMap::new();
186 let mut non_eq_join = vec![];
187
188 for expr in self.conjunctions {
189 let input_bits = expr.collect_input_refs(cols_seen);
190 let mut subset_indices = Vec::with_capacity(input_col_nums.len());
191 for (idx, bitmap) in bitmaps.iter().enumerate() {
192 if !input_bits.is_disjoint(bitmap) {
193 subset_indices.push(idx);
194 }
195 }
196 if subset_indices.len() != 2 || (only_eq && expr.as_eq_cond().is_none()) {
197 non_eq_join.push(expr);
198 } else {
199 let key = if subset_indices[0] < subset_indices[1] {
201 (subset_indices[0], subset_indices[1])
202 } else {
203 (subset_indices[1], subset_indices[0])
204 };
205 let e = pairwise_conditions
206 .entry(key)
207 .or_insert_with(Condition::true_cond);
208 e.conjunctions.push(expr);
209 }
210 }
211 (
212 pairwise_conditions,
213 Condition {
214 conjunctions: non_eq_join,
215 },
216 )
217 }
218
219 #[must_use]
220 pub fn split_eq_keys(
227 self,
228 left_col_num: usize,
229 right_col_num: usize,
230 ) -> (Vec<(InputRef, InputRef, bool)>, Self) {
231 let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
232 let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
233
234 let (mut eq_keys, mut others) = (vec![], vec![]);
235 self.conjunctions.into_iter().for_each(|expr| {
236 let input_bits = expr.collect_input_refs(left_col_num + right_col_num);
237 if input_bits.is_disjoint(&left_bit_map) || input_bits.is_disjoint(&right_bit_map) {
238 others.push(expr)
239 } else if let Some(columns) = expr.as_eq_cond() {
240 eq_keys.push((columns.0, columns.1, false));
241 } else if let Some(columns) = expr.as_is_not_distinct_from_cond() {
242 eq_keys.push((columns.0, columns.1, true));
243 } else {
244 others.push(expr)
245 }
246 });
247
248 (
249 eq_keys,
250 Condition {
251 conjunctions: others,
252 },
253 )
254 }
255
256 pub(crate) fn extract_inequality_keys(
263 &self,
264 left_col_num: usize,
265 right_col_num: usize,
266 ) -> Vec<(usize, InequalityInputPair)> {
267 let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
268 let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
269
270 self.conjunctions
271 .iter()
272 .enumerate()
273 .filter_map(|(conjunction_idx, expr)| {
274 let input_bits = expr.collect_input_refs(left_col_num + right_col_num);
275 if input_bits.is_disjoint(&left_bit_map) || input_bits.is_disjoint(&right_bit_map) {
276 None
277 } else {
278 expr.as_input_comparison_cond()
279 .map(|inequality_pair| (conjunction_idx, inequality_pair))
280 }
281 })
282 .collect_vec()
283 }
284
285 #[must_use]
288 pub fn split_disjoint(self, columns: &FixedBitSet) -> (Self, Self) {
289 self.group_by::<_, 2>(|expr| {
290 let input_bits = expr.collect_input_refs(columns.len());
291 input_bits.is_disjoint(columns) as usize
292 })
293 .into_iter()
294 .next_tuple()
295 .unwrap()
296 }
297
298 fn disjunctions_to_scan_ranges(
302 table_desc: Rc<TableDesc>,
303 max_split_range_gap: u64,
304 disjunctions: Vec<ExprImpl>,
305 ) -> Result<Option<(Vec<ScanRange>, bool)>> {
306 let disjunctions_result: Result<Vec<(Vec<ScanRange>, Self)>> = disjunctions
307 .into_iter()
308 .map(|x| {
309 Condition {
310 conjunctions: to_conjunctions(x),
311 }
312 .split_to_scan_ranges(table_desc.clone(), max_split_range_gap)
313 })
314 .collect();
315
316 let disjunctions_result = disjunctions_result?;
318
319 let all_equal = disjunctions_result
322 .iter()
323 .all(|(scan_ranges, other_condition)| {
324 other_condition.always_true()
325 && scan_ranges
326 .iter()
327 .all(|x| !x.eq_conds.is_empty() && is_full_range(&x.range))
328 });
329
330 if all_equal {
331 let scan_ranges = disjunctions_result
335 .into_iter()
336 .flat_map(|(scan_ranges, _)| scan_ranges)
337 .sorted_by(|a, b| a.eq_conds.len().cmp(&b.eq_conds.len()))
339 .collect_vec();
340 let mut non_overlap_scan_ranges: Vec<ScanRange> = vec![];
342 for s1 in &scan_ranges {
343 let overlap = non_overlap_scan_ranges.iter().any(|s2| {
344 #[allow(clippy::disallowed_methods)]
345 s1.eq_conds
346 .iter()
347 .zip(s2.eq_conds.iter())
348 .all(|(a, b)| a == b)
349 });
350 if !overlap {
354 non_overlap_scan_ranges.push(s1.clone());
355 }
356 }
357
358 Ok(Some((non_overlap_scan_ranges, false)))
359 } else {
360 let mut scan_ranges = vec![];
361 for (scan_ranges_chunk, _) in disjunctions_result {
362 if scan_ranges_chunk.is_empty() {
363 return Ok(None);
365 }
366
367 scan_ranges.extend(scan_ranges_chunk);
368 }
369
370 let order_types = table_desc
371 .pk
372 .iter()
373 .cloned()
374 .map(|x| {
375 if x.order_type.is_descending() {
376 x.order_type.reverse()
377 } else {
378 x.order_type
379 }
380 })
381 .collect_vec();
382 scan_ranges.sort_by(|left, right| {
383 let (left_start, _left_end) = &left.convert_to_range();
384 let (right_start, _right_end) = &right.convert_to_range();
385
386 let left_start_vec = match &left_start {
387 Bound::Included(vec) | Bound::Excluded(vec) => vec,
388 _ => &vec![],
389 };
390 let right_start_vec = match &right_start {
391 Bound::Included(vec) | Bound::Excluded(vec) => vec,
392 _ => &vec![],
393 };
394
395 if left_start_vec.is_empty() && right_start_vec.is_empty() {
396 return Ordering::Less;
397 }
398
399 if left_start_vec.is_empty() {
400 return Ordering::Less;
401 }
402
403 if right_start_vec.is_empty() {
404 return Ordering::Greater;
405 }
406
407 let cmp_column_len = left_start_vec.len().min(right_start_vec.len());
408 cmp_rows(
409 &left_start_vec[0..cmp_column_len],
410 &right_start_vec[0..cmp_column_len],
411 &order_types[0..cmp_column_len],
412 )
413 });
414
415 if scan_ranges.is_empty() {
416 return Ok(None);
417 }
418
419 if scan_ranges.len() == 1 {
420 return Ok(Some((scan_ranges, true)));
421 }
422
423 let mut output_scan_ranges: Vec<ScanRange> = vec![];
424 output_scan_ranges.push(scan_ranges[0].clone());
425 let mut idx = 1;
426 loop {
427 if idx >= scan_ranges.len() {
428 break;
429 }
430
431 let scan_range_left = output_scan_ranges.last_mut().unwrap();
432 let scan_range_right = &scan_ranges[idx];
433
434 if scan_range_left.eq_conds == scan_range_right.eq_conds {
435 if !ScanRange::is_overlap(scan_range_left, scan_range_right, &order_types) {
438 output_scan_ranges.push(scan_range_right.clone());
440 idx += 1;
441 continue;
442 }
443
444 fn merge_bound(
446 left_scan_range: &Bound<Vec<Option<ScalarImpl>>>,
447 right_scan_range: &Bound<Vec<Option<ScalarImpl>>>,
448 order_types: &[OrderType],
449 left_bound: bool,
450 ) -> Bound<Vec<Option<ScalarImpl>>> {
451 let left_scan_range = match left_scan_range {
452 Bound::Included(vec) | Bound::Excluded(vec) => vec,
453 Bound::Unbounded => return Bound::Unbounded,
454 };
455
456 let right_scan_range = match right_scan_range {
457 Bound::Included(vec) | Bound::Excluded(vec) => vec,
458 Bound::Unbounded => return Bound::Unbounded,
459 };
460
461 let cmp_len = left_scan_range.len().min(right_scan_range.len());
462
463 let cmp = cmp_rows(
464 &left_scan_range[..cmp_len],
465 &right_scan_range[..cmp_len],
466 &order_types[..cmp_len],
467 );
468
469 let bound = {
470 if (cmp.is_le() && left_bound) || (cmp.is_ge() && !left_bound) {
471 left_scan_range.to_vec()
472 } else {
473 right_scan_range.to_vec()
474 }
475 };
476
477 Bound::Included(bound)
479 }
480
481 scan_range_left.range.0 = merge_bound(
482 &scan_range_left.range.0,
483 &scan_range_right.range.0,
484 &order_types,
485 true,
486 );
487
488 scan_range_left.range.1 = merge_bound(
489 &scan_range_left.range.1,
490 &scan_range_right.range.1,
491 &order_types,
492 false,
493 );
494
495 if scan_range_left.is_full_table_scan() {
496 return Ok(None);
497 }
498 } else {
499 output_scan_ranges.push(scan_range_right.clone());
500 }
501
502 idx += 1;
503 }
504
505 Ok(Some((output_scan_ranges, true)))
506 }
507 }
508
509 fn split_row_cmp_to_scan_ranges(
510 &self,
511 table_desc: Rc<TableDesc>,
512 ) -> Result<Option<(Vec<ScanRange>, Self)>> {
513 let (mut row_conjunctions, row_conjunctions_without_struct): (Vec<_>, Vec<_>) =
514 self.conjunctions.clone().into_iter().partition(|expr| {
515 if let Some(f) = expr.as_function_call() {
516 if let Some(left_input) = f.inputs().get(0)
517 && let Some(left_input) = left_input.as_function_call()
518 && matches!(left_input.func_type(), ExprType::Row)
519 && left_input.inputs().iter().all(|x| x.is_input_ref())
520 && let Some(right_input) = f.inputs().get(1)
521 && right_input.is_literal()
522 {
523 true
524 } else {
525 false
526 }
527 } else {
528 false
529 }
530 });
531
532 if row_conjunctions.len() == 1 {
537 let row_conjunction = row_conjunctions.pop().unwrap();
538 let row_left_inputs = row_conjunction
539 .as_function_call()
540 .unwrap()
541 .inputs()
542 .get(0)
543 .unwrap()
544 .as_function_call()
545 .unwrap()
546 .inputs();
547 let row_right_literal = row_conjunction
548 .as_function_call()
549 .unwrap()
550 .inputs()
551 .get(1)
552 .unwrap()
553 .as_literal()
554 .unwrap();
555 if !matches!(row_right_literal.get_data(), Some(ScalarImpl::Struct(_))) {
556 return Ok(None);
557 }
558 let row_right_literal_data = row_right_literal.get_data().clone().unwrap();
559 let right_iter = row_right_literal_data.as_struct().fields();
560 let func_type = row_conjunction.as_function_call().unwrap().func_type();
561 if row_left_inputs.len() > 1
562 && (matches!(func_type, ExprType::LessThan)
563 || matches!(func_type, ExprType::GreaterThan))
564 {
565 let mut pk_struct = vec![];
566 let mut order_type = None;
567 let mut all_added = true;
568 let mut iter = row_left_inputs.iter().zip_eq_fast(right_iter);
569 for column_order in &table_desc.pk {
570 if let Some((left_expr, right_expr)) = iter.next() {
571 if left_expr.as_input_ref().unwrap().index != column_order.column_index {
572 all_added = false;
573 break;
574 }
575 match order_type {
576 Some(o) => {
577 if o != column_order.order_type {
578 all_added = false;
579 break;
580 }
581 }
582 None => order_type = Some(column_order.order_type),
583 }
584 pk_struct.push(right_expr.clone());
585 }
586 }
587
588 if !pk_struct.is_empty() {
590 if !all_added {
591 let scan_range = ScanRange {
592 eq_conds: vec![],
593 range: match func_type {
594 ExprType::GreaterThan => {
595 (Bound::Included(pk_struct), Bound::Unbounded)
596 }
597 ExprType::LessThan => {
598 (Bound::Unbounded, Bound::Included(pk_struct))
599 }
600 _ => unreachable!(),
601 },
602 };
603 return Ok(Some((
604 vec![scan_range],
605 Condition {
606 conjunctions: self.conjunctions.clone(),
607 },
608 )));
609 } else {
610 let scan_range = ScanRange {
611 eq_conds: vec![],
612 range: match func_type {
613 ExprType::GreaterThan => {
614 (Bound::Excluded(pk_struct), Bound::Unbounded)
615 }
616 ExprType::LessThan => {
617 (Bound::Unbounded, Bound::Excluded(pk_struct))
618 }
619 _ => unreachable!(),
620 },
621 };
622 return Ok(Some((
623 vec![scan_range],
624 Condition {
625 conjunctions: row_conjunctions_without_struct,
626 },
627 )));
628 }
629 }
630 }
631 }
632 Ok(None)
633 }
634
635 pub fn get_eq_const_input_refs(&self) -> Vec<InputRef> {
637 self.conjunctions
638 .iter()
639 .filter_map(|expr| expr.as_eq_const().map(|(input_ref, _)| input_ref))
640 .collect()
641 }
642
    /// Splits the condition into primary-key scan ranges plus a residual
    /// condition that must still be evaluated on the scanned rows.
    ///
    /// Strategy, in order:
    /// 1. A condition made of a single OR term is handed to
    ///    [`Self::disjunctions_to_scan_ranges`]. If that yields nothing, no
    ///    ranges are produced and the whole condition is kept.
    /// 2. A single row comparison may be turned into a one-sided range by
    ///    [`Self::split_row_cmp_to_scan_ranges`].
    /// 3. Otherwise terms are grouped per primary-key column by
    ///    [`Self::classify_conjunctions_by_pk`] and consumed column by column:
    ///    - exactly one equality value extends the equality prefix and moves on
    ///      to the next key column;
    ///    - no equality turns the column's bounds into the range and stops;
    ///    - several equality candidates (e.g. from `IN`) fan out into one range
    ///      per value.
    ///
    /// An empty range list means "no restriction", unless the residual is the
    /// `false` condition, which signals the condition is unsatisfiable.
    /// `max_split_range_gap` is forwarded to `ScanRange::split_small_range`. That
    /// is only attempted when the first primary-key column has an integer type.
    pub fn split_to_scan_ranges(
        self,
        table_desc: Rc<TableDesc>,
        max_split_range_gap: u64,
    ) -> Result<(Vec<ScanRange>, Self)> {
        // "Unsatisfiable" result: no ranges and a `false` residual.
        fn false_cond() -> (Vec<ScanRange>, Condition) {
            (vec![], Condition::false_cond())
        }

        if self.conjunctions.len() == 1 {
            if let Some(disjunctions) = self.conjunctions[0].as_or_disjunctions() {
                if let Some((scan_ranges, maintaining_condition)) =
                    Self::disjunctions_to_scan_ranges(
                        table_desc,
                        max_split_range_gap,
                        disjunctions,
                    )?
                {
                    // `maintaining_condition == false` means the ranges describe
                    // the OR exactly, so the predicate can be dropped.
                    if maintaining_condition {
                        return Ok((scan_ranges, self));
                    } else {
                        return Ok((scan_ranges, Condition::true_cond()));
                    }
                } else {
                    return Ok((vec![], self));
                }
            }
        }
        if let Some((scan_ranges, other_condition)) =
            self.split_row_cmp_to_scan_ranges(table_desc.clone())?
        {
            return Ok((scan_ranges, other_condition));
        }

        // One group per primary-key column, plus a trailing group of terms that
        // are not tied to a single key column.
        let mut groups = Self::classify_conjunctions_by_pk(self.conjunctions, &table_desc);
        let mut other_conds = groups.pop().unwrap();

        let mut scan_range = ScanRange::full_table_scan();
        for i in 0..table_desc.order_column_indices().len() {
            let group = std::mem::take(&mut groups[i]);
            if group.is_empty() {
                // No predicate on this key column: the scan key ends here. The
                // remaining key groups and the other terms become the residual.
                groups.push(other_conds);
                return Ok((
                    if scan_range.is_full_table_scan() {
                        vec![]
                    } else {
                        vec![scan_range]
                    },
                    Self {
                        conjunctions: groups[i + 1..].concat(),
                    },
                ));
            }

            // `None` means the group's predicates can never be satisfied.
            let Some((
                lower_bound_conjunctions,
                upper_bound_conjunctions,
                eq_conds,
                part_of_other_conds,
            )) = Self::analyze_group(group)?
            else {
                return Ok(false_cond());
            };
            other_conds.extend(part_of_other_conds.into_iter());

            let lower_bound = Self::merge_lower_bound_conjunctions(lower_bound_conjunctions);
            let upper_bound = Self::merge_upper_bound_conjunctions(upper_bound_conjunctions);

            if Self::is_invalid_range(&lower_bound, &upper_bound) {
                return Ok(false_cond());
            }

            match eq_conds.len() {
                // One equality value (possibly NULL): extend the equality prefix
                // and continue with the next key column.
                1 => {
                    let eq_conds =
                        Self::extract_eq_conds_within_range(eq_conds, &upper_bound, &lower_bound);
                    if eq_conds.is_empty() {
                        return Ok(false_cond());
                    }
                    scan_range.eq_conds.extend(eq_conds.into_iter());
                }
                // No equality: this column's bounds become the range, and the scan
                // key stops here; later key groups become residual.
                0 => {
                    let convert = |bound| match bound {
                        Bound::Included(l) => Bound::Included(vec![Some(l)]),
                        Bound::Excluded(l) => Bound::Excluded(vec![Some(l)]),
                        Bound::Unbounded => Bound::Unbounded,
                    };
                    scan_range.range = (convert(lower_bound), convert(upper_bound));
                    other_conds.extend(groups[i + 1..].iter().flatten().cloned());
                    break;
                }
                // Several candidate values: emit one range per value that survives
                // the bounds; later key groups become residual.
                _ => {
                    let eq_conds =
                        Self::extract_eq_conds_within_range(eq_conds, &upper_bound, &lower_bound);
                    if eq_conds.is_empty() {
                        return Ok(false_cond());
                    }
                    other_conds.extend(groups[i + 1..].iter().flatten().cloned());
                    let scan_ranges = eq_conds
                        .into_iter()
                        .map(|lit| {
                            let mut scan_range = scan_range.clone();
                            scan_range.eq_conds.push(lit);
                            scan_range
                        })
                        .collect();
                    return Ok((
                        scan_ranges,
                        Self {
                            conjunctions: other_conds,
                        },
                    ));
                }
            }
        }

        Ok((
            if scan_range.is_full_table_scan() {
                vec![]
            } else if table_desc.columns[table_desc.order_column_indices()[0]]
                .data_type
                .is_int()
            {
                // Integer leading key: try to break a small range into points.
                match scan_range.split_small_range(max_split_range_gap) {
                    Some(scan_ranges) => scan_ranges,
                    None => vec![scan_range],
                }
            } else {
                vec![scan_range]
            },
            Self {
                conjunctions: other_conds,
            },
        ))
    }
788
789 fn classify_conjunctions_by_pk(
793 conjunctions: Vec<ExprImpl>,
794 table_desc: &Rc<TableDesc>,
795 ) -> Vec<Vec<ExprImpl>> {
796 let pk_column_ids = &table_desc.order_column_indices();
797 let pk_cols_num = pk_column_ids.len();
798 let cols_num = table_desc.columns.len();
799
800 let mut col_idx_to_pk_idx = vec![None; cols_num];
801 pk_column_ids.iter().enumerate().for_each(|(idx, pk_idx)| {
802 col_idx_to_pk_idx[*pk_idx] = Some(idx);
803 });
804
805 let mut groups = vec![vec![]; pk_cols_num + 1];
806 for (key, group) in &conjunctions.into_iter().chunk_by(|expr| {
807 let input_bits = expr.collect_input_refs(cols_num);
808 if input_bits.count_ones(..) == 1 {
809 let col_idx = input_bits.ones().next().unwrap();
810 col_idx_to_pk_idx[col_idx].unwrap_or(pk_cols_num)
811 } else {
812 pk_cols_num
813 }
814 }) {
815 groups[key].extend(group);
816 }
817
818 groups
819 }
820
    /// Classifies the predicates of one primary-key column group.
    ///
    /// Returns `Ok(None)` when the group can never be satisfied. Examples:
    /// contradictory equalities, an `IN` list with no usable values, a constant
    /// outside the column type's range, or a constant that folds to no value
    /// (presumably NULL). Otherwise returns
    /// `(lower_bounds, upper_bounds, eq_conds, other_conds)` where:
    /// - `lower_bounds` come from `col > const` / `col >= const`;
    /// - `upper_bounds` come from `col < const` / `col <= const`;
    /// - `eq_conds` holds the candidate equality values. `None` stands for an
    ///   `IS NULL` requirement, and values taken from an `IN` list are sorted;
    /// - `other_conds` are predicates that could not be turned into bounds.
    ///
    /// # Errors
    /// Propagates errors from constant folding.
    #[allow(clippy::type_complexity)]
    fn analyze_group(
        group: Vec<ExprImpl>,
    ) -> Result<
        Option<(
            Vec<Bound<ScalarImpl>>,
            Vec<Bound<ScalarImpl>>,
            Vec<Option<ScalarImpl>>,
            Vec<ExprImpl>,
        )>,
    > {
        let mut lower_bound_conjunctions = vec![];
        let mut upper_bound_conjunctions = vec![];
        let mut eq_conds = vec![];
        let mut other_conds = vec![];

        for expr in group {
            if let Some((input_ref, const_expr)) = expr.as_eq_const() {
                // Prefer an implicit cast; otherwise try to narrow a wider integer
                // constant to the column type.
                let new_expr = if let Ok(expr) = const_expr
                    .clone()
                    .cast_implicit(input_ref.data_type.clone())
                {
                    expr
                } else {
                    match self::cast_compare::cast_compare_for_eq(const_expr, input_ref.data_type) {
                        Ok(ResultForEq::Success(expr)) => expr,
                        // The constant is outside the column type's range.
                        Ok(ResultForEq::NeverEqual) => {
                            return Ok(None);
                        }
                        // Unsupported type pair: keep the term as a plain filter.
                        Err(_) => {
                            other_conds.push(expr);
                            continue;
                        }
                    }
                };

                let Some(new_cond) = new_expr.fold_const()? else {
                    return Ok(None);
                };
                // Earlier equality / IN / IS NULL terms already fixed the candidate
                // values, and this one is not among them.
                if Self::mutual_exclusive_with_eq_conds(&new_cond, &eq_conds) {
                    return Ok(None);
                }
                eq_conds = vec![Some(new_cond)];
            } else if expr.as_is_null().is_some() {
                // IS NULL contradicts earlier non-null equality candidates.
                if !eq_conds.is_empty() && eq_conds.into_iter().all(|l| l.is_some()) {
                    return Ok(None);
                }
                eq_conds = vec![None];
            } else if let Some((input_ref, in_const_list)) = expr.as_in_const_list() {
                let mut scalars = HashSet::new();
                for const_expr in in_const_list {
                    // NOTE(review): assumes `as_in_const_list` only matches lists
                    // castable to the column type, so this unwrap cannot fail.
                    let const_expr = const_expr
                        .cast_implicit(input_ref.data_type.clone())
                        .unwrap();
                    let value = const_expr.fold_const()?;
                    // Values that fold to nothing can never match; skip them.
                    let Some(value) = value else {
                        continue;
                    };
                    scalars.insert(Some(value));
                }
                if scalars.is_empty() {
                    return Ok(None);
                }
                // Intersect with candidates from earlier terms.
                if !eq_conds.is_empty() {
                    scalars = scalars
                        .intersection(&HashSet::from_iter(eq_conds))
                        .cloned()
                        .collect();
                    if scalars.is_empty() {
                        return Ok(None);
                    }
                }
                // Sort for a deterministic order of the resulting scan ranges.
                eq_conds = scalars
                    .into_iter()
                    .sorted_by(DefaultOrd::default_cmp)
                    .collect();
            } else if let Some((input_ref, op, const_expr)) = expr.as_comparison_const() {
                let new_expr = if let Ok(expr) = const_expr
                    .clone()
                    .cast_implicit(input_ref.data_type.clone())
                {
                    expr
                } else {
                    match self::cast_compare::cast_compare_for_cmp(
                        const_expr,
                        input_ref.data_type,
                        op,
                    ) {
                        Ok(ResultForCmp::Success(expr)) => expr,
                        // Out-of-range or unsupported: fall back to a plain filter.
                        _ => {
                            other_conds.push(expr);
                            continue;
                        }
                    }
                };
                let Some(value) = new_expr.fold_const()? else {
                    return Ok(None);
                };
                match op {
                    ExprType::LessThan => {
                        upper_bound_conjunctions.push(Bound::Excluded(value));
                    }
                    ExprType::LessThanOrEqual => {
                        upper_bound_conjunctions.push(Bound::Included(value));
                    }
                    ExprType::GreaterThan => {
                        lower_bound_conjunctions.push(Bound::Excluded(value));
                    }
                    ExprType::GreaterThanOrEqual => {
                        lower_bound_conjunctions.push(Bound::Included(value));
                    }
                    _ => unreachable!(),
                }
            } else {
                other_conds.push(expr);
            }
        }
        Ok(Some((
            lower_bound_conjunctions,
            upper_bound_conjunctions,
            eq_conds,
            other_conds,
        )))
    }
960
961 fn mutual_exclusive_with_eq_conds(
962 new_conds: &ScalarImpl,
963 eq_conds: &[Option<ScalarImpl>],
964 ) -> bool {
965 !eq_conds.is_empty()
966 && eq_conds.iter().all(|l| {
967 if let Some(l) = l {
968 l != new_conds
969 } else {
970 true
971 }
972 })
973 }
974
975 fn merge_lower_bound_conjunctions(lb: Vec<Bound<ScalarImpl>>) -> Bound<ScalarImpl> {
976 lb.into_iter()
977 .max_by(|a, b| {
978 match (a, b) {
980 (Bound::Included(_), Bound::Unbounded) => std::cmp::Ordering::Greater,
981 (Bound::Excluded(_), Bound::Unbounded) => std::cmp::Ordering::Greater,
982 (Bound::Unbounded, Bound::Included(_)) => std::cmp::Ordering::Less,
983 (Bound::Unbounded, Bound::Excluded(_)) => std::cmp::Ordering::Less,
984 (Bound::Unbounded, Bound::Unbounded) => std::cmp::Ordering::Equal,
985 (Bound::Included(a), Bound::Included(b)) => a.default_cmp(b),
986 (Bound::Excluded(a), Bound::Excluded(b)) => a.default_cmp(b),
987 (Bound::Included(a), Bound::Excluded(b)) => match a.default_cmp(b) {
989 std::cmp::Ordering::Equal => std::cmp::Ordering::Less,
990 other => other,
991 },
992 (Bound::Excluded(a), Bound::Included(b)) => match a.default_cmp(b) {
993 std::cmp::Ordering::Equal => std::cmp::Ordering::Greater,
994 other => other,
995 },
996 }
997 })
998 .unwrap_or(Bound::Unbounded)
999 }
1000
1001 fn merge_upper_bound_conjunctions(ub: Vec<Bound<ScalarImpl>>) -> Bound<ScalarImpl> {
1002 ub.into_iter()
1003 .min_by(|a, b| {
1004 match (a, b) {
1006 (Bound::Included(_), Bound::Unbounded) => std::cmp::Ordering::Less,
1007 (Bound::Excluded(_), Bound::Unbounded) => std::cmp::Ordering::Less,
1008 (Bound::Unbounded, Bound::Included(_)) => std::cmp::Ordering::Greater,
1009 (Bound::Unbounded, Bound::Excluded(_)) => std::cmp::Ordering::Greater,
1010 (Bound::Unbounded, Bound::Unbounded) => std::cmp::Ordering::Equal,
1011 (Bound::Included(a), Bound::Included(b)) => a.default_cmp(b),
1012 (Bound::Excluded(a), Bound::Excluded(b)) => a.default_cmp(b),
1013 (Bound::Included(a), Bound::Excluded(b)) => match a.default_cmp(b) {
1015 std::cmp::Ordering::Equal => std::cmp::Ordering::Greater,
1016 other => other,
1017 },
1018 (Bound::Excluded(a), Bound::Included(b)) => match a.default_cmp(b) {
1019 std::cmp::Ordering::Equal => std::cmp::Ordering::Less,
1020 other => other,
1021 },
1022 }
1023 })
1024 .unwrap_or(Bound::Unbounded)
1025 }
1026
1027 fn is_invalid_range(lower_bound: &Bound<ScalarImpl>, upper_bound: &Bound<ScalarImpl>) -> bool {
1028 match (lower_bound, upper_bound) {
1029 (Bound::Included(l), Bound::Included(u)) => l.default_cmp(u).is_gt(), (Bound::Included(l), Bound::Excluded(u)) => l.default_cmp(u).is_ge(), (Bound::Excluded(l), Bound::Included(u)) => l.default_cmp(u).is_ge(), (Bound::Excluded(l), Bound::Excluded(u)) => l.default_cmp(u).is_ge(), _ => false,
1034 }
1035 }
1036
1037 fn extract_eq_conds_within_range(
1038 eq_conds: Vec<Option<ScalarImpl>>,
1039 upper_bound: &Bound<ScalarImpl>,
1040 lower_bound: &Bound<ScalarImpl>,
1041 ) -> Vec<Option<ScalarImpl>> {
1042 if Self::is_invalid_range(lower_bound, upper_bound) {
1045 return vec![];
1046 }
1047
1048 let is_extract_null = upper_bound == &Bound::Unbounded && lower_bound == &Bound::Unbounded;
1049
1050 eq_conds
1051 .into_iter()
1052 .filter(|cond| {
1053 if let Some(cond) = cond {
1054 match lower_bound {
1055 Bound::Included(val) => {
1056 if cond.default_cmp(val).is_lt() {
1057 return false;
1059 }
1060 }
1061 Bound::Excluded(val) => {
1062 if cond.default_cmp(val).is_le() {
1063 return false;
1065 }
1066 }
1067 Bound::Unbounded => {}
1068 }
1069 match upper_bound {
1070 Bound::Included(val) => {
1071 if cond.default_cmp(val).is_gt() {
1072 return false;
1074 }
1075 }
1076 Bound::Excluded(val) => {
1077 if cond.default_cmp(val).is_ge() {
1078 return false;
1080 }
1081 }
1082 Bound::Unbounded => {}
1083 }
1084 true
1085 } else {
1086 is_extract_null
1087 }
1088 })
1089 .collect()
1090 }
1091
1092 #[must_use]
1098 pub fn group_by<F, const N: usize>(self, f: F) -> [Self; N]
1099 where
1100 F: Fn(&ExprImpl) -> usize,
1101 {
1102 const EMPTY: Vec<ExprImpl> = vec![];
1103 let mut groups = [EMPTY; N];
1104 for (key, group) in &self.conjunctions.into_iter().chunk_by(|expr| {
1105 let i = f(expr);
1107 assert!(i < N);
1108 i
1109 }) {
1110 groups[key].extend(group);
1111 }
1112
1113 groups.map(|group| Condition {
1114 conjunctions: group,
1115 })
1116 }
1117
1118 #[must_use]
1119 pub fn rewrite_expr(self, rewriter: &mut (impl ExprRewriter + ?Sized)) -> Self {
1120 Self {
1121 conjunctions: self
1122 .conjunctions
1123 .into_iter()
1124 .map(|expr| rewriter.rewrite_expr(expr))
1125 .collect(),
1126 }
1127 .simplify()
1128 }
1129
1130 pub fn visit_expr<V: ExprVisitor + ?Sized>(&self, visitor: &mut V) {
1131 self.conjunctions
1132 .iter()
1133 .for_each(|expr| visitor.visit_expr(expr));
1134 }
1135
1136 pub fn visit_expr_mut(&mut self, mutator: &mut (impl ExprMutator + ?Sized)) {
1137 self.conjunctions
1138 .iter_mut()
1139 .for_each(|expr| mutator.visit_expr(expr))
1140 }
1141
1142 fn simplify(self) -> Self {
1145 let conjunctions: Vec<_> = self
1147 .conjunctions
1148 .into_iter()
1149 .map(push_down_not)
1150 .map(fold_boolean_constant)
1151 .map(column_self_eq_eliminate)
1152 .flat_map(to_conjunctions)
1153 .collect();
1154 let mut res: Vec<ExprImpl> = Vec::new();
1155 let mut visited: HashSet<ExprImpl> = HashSet::new();
1156 for expr in conjunctions {
1157 if !expr.has_subquery() {
1159 let results_of_factorization = factorization_expr(expr);
1160 res.extend(
1161 results_of_factorization
1162 .clone()
1163 .into_iter()
1164 .filter(|expr| !visited.contains(expr)),
1165 );
1166 visited.extend(results_of_factorization);
1167 } else {
1168 res.push(expr);
1170 }
1171 }
1172 res.retain(|expr| {
1174 if let Some(v) = try_get_bool_constant(expr)
1175 && v
1176 {
1177 false
1178 } else {
1179 true
1180 }
1181 });
1182 for expr in &mut res {
1184 if let Some(v) = try_get_bool_constant(expr) {
1185 if !v {
1186 res.clear();
1187 res.push(ExprImpl::literal_bool(false));
1188 break;
1189 }
1190 }
1191 }
1192 Self { conjunctions: res }
1193 }
1194}
1195
/// Display adapter that prints a [`Condition`] with each term rendered through
/// `ExprDisplay` against `input_schema`.
pub struct ConditionDisplay<'a> {
    /// The condition to print.
    pub condition: &'a Condition,
    /// Schema handed to `ExprDisplay` when rendering each term.
    pub input_schema: &'a Schema,
}
1200
1201impl ConditionDisplay<'_> {
1202 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1203 if self.condition.always_true() {
1204 write!(f, "true")
1205 } else {
1206 write!(
1207 f,
1208 "{}",
1209 self.condition
1210 .conjunctions
1211 .iter()
1212 .format_with(" AND ", |expr, f| {
1213 f(&ExprDisplay {
1214 expr,
1215 input_schema: self.input_schema,
1216 })
1217 })
1218 )
1219 }
1220 }
1221}
1222
1223impl fmt::Display for ConditionDisplay<'_> {
1224 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1225 self.fmt(f)
1226 }
1227}
1228
1229impl fmt::Debug for ConditionDisplay<'_> {
1230 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1231 self.fmt(f)
1232 }
1233}
1234
1235mod cast_compare {
1240 use risingwave_common::types::DataType;
1241
1242 use crate::expr::{Expr, ExprImpl, ExprType};
1243
1244 enum ShrinkResult {
1245 OutUpperBound,
1246 OutLowerBound,
1247 InRange(ExprImpl),
1248 }
1249
1250 pub enum ResultForEq {
1251 Success(ExprImpl),
1252 NeverEqual,
1253 }
1254
1255 pub enum ResultForCmp {
1256 Success(ExprImpl),
1257 OutUpperBound,
1258 OutLowerBound,
1259 }
1260
1261 pub fn cast_compare_for_eq(const_expr: ExprImpl, target: DataType) -> Result<ResultForEq, ()> {
1262 match (const_expr.return_type(), &target) {
1263 (DataType::Int64, DataType::Int32)
1264 | (DataType::Int64, DataType::Int16)
1265 | (DataType::Int32, DataType::Int16) => match shrink_integral(const_expr, target)? {
1266 ShrinkResult::InRange(expr) => Ok(ResultForEq::Success(expr)),
1267 ShrinkResult::OutUpperBound | ShrinkResult::OutLowerBound => {
1268 Ok(ResultForEq::NeverEqual)
1269 }
1270 },
1271 _ => Err(()),
1272 }
1273 }
1274
1275 pub fn cast_compare_for_cmp(
1276 const_expr: ExprImpl,
1277 target: DataType,
1278 _op: ExprType,
1279 ) -> Result<ResultForCmp, ()> {
1280 match (const_expr.return_type(), &target) {
1281 (DataType::Int64, DataType::Int32)
1282 | (DataType::Int64, DataType::Int16)
1283 | (DataType::Int32, DataType::Int16) => match shrink_integral(const_expr, target)? {
1284 ShrinkResult::InRange(expr) => Ok(ResultForCmp::Success(expr)),
1285 ShrinkResult::OutUpperBound => Ok(ResultForCmp::OutUpperBound),
1286 ShrinkResult::OutLowerBound => Ok(ResultForCmp::OutLowerBound),
1287 },
1288 _ => Err(()),
1289 }
1290 }
1291
1292 fn shrink_integral(const_expr: ExprImpl, target: DataType) -> Result<ShrinkResult, ()> {
1293 let (upper_bound, lower_bound) = match (const_expr.return_type(), &target) {
1294 (DataType::Int64, DataType::Int32) => (i32::MAX as i64, i32::MIN as i64),
1295 (DataType::Int64, DataType::Int16) | (DataType::Int32, DataType::Int16) => {
1296 (i16::MAX as i64, i16::MIN as i64)
1297 }
1298 _ => unreachable!(),
1299 };
1300 match const_expr.fold_const().map_err(|_| ())? {
1301 Some(scalar) => {
1302 let value = scalar.as_integral();
1303 if value > upper_bound {
1304 Ok(ShrinkResult::OutUpperBound)
1305 } else if value < lower_bound {
1306 Ok(ShrinkResult::OutLowerBound)
1307 } else {
1308 Ok(ShrinkResult::InRange(
1309 const_expr.cast_explicit(target).unwrap(),
1310 ))
1311 }
1312 }
1313 None => Ok(ShrinkResult::InRange(
1314 const_expr.cast_explicit(target).unwrap(),
1315 )),
1316 }
1317 }
1318}
1319
#[cfg(test)]
mod tests {
    use rand::Rng;

    use super::*;

    /// Number of columns on the left side of the simulated join.
    const LEFT_COLS: usize = 3;
    /// Number of columns on the right side of the simulated join.
    const RIGHT_COLS: usize = 2;

    /// Builds `lhs <func> rhs` as a type-checked function call.
    fn binary_call(func: ExprType, lhs: ExprImpl, rhs: ExprImpl) -> ExprImpl {
        FunctionCall::new(func, vec![lhs, rhs]).unwrap().into()
    }

    /// Builds an `Int32` reference to input column `idx`.
    fn int_col(idx: usize) -> ExprImpl {
        InputRef::new(idx, DataType::Int32).into()
    }

    #[test]
    fn test_split() {
        let mut rng = rand::rng();
        let all_cols = LEFT_COLS + RIGHT_COLS;

        // Both operands from the left side.
        let left = binary_call(
            ExprType::LessThanOrEqual,
            int_col(rng.random_range(0..LEFT_COLS)),
            int_col(rng.random_range(0..LEFT_COLS)),
        );
        // Both operands from the right side.
        let right = binary_call(
            ExprType::LessThan,
            int_col(rng.random_range(LEFT_COLS..all_cols)),
            int_col(rng.random_range(LEFT_COLS..all_cols)),
        );
        // One operand from each side.
        let other = binary_call(
            ExprType::GreaterThan,
            int_col(rng.random_range(0..LEFT_COLS)),
            int_col(rng.random_range(LEFT_COLS..all_cols)),
        );

        let cond = Condition::with_expr(other.clone())
            .and(Condition::with_expr(right.clone()))
            .and(Condition::with_expr(left.clone()));

        let (left_part, right_part, other_part) = cond.split(LEFT_COLS, RIGHT_COLS);
        assert_eq!(left_part.conjunctions, vec![left]);
        assert_eq!(right_part.conjunctions, vec![right]);
        assert_eq!(other_part.conjunctions, vec![other]);
    }

    #[test]
    fn test_self_eq_eliminate() {
        let mut rng = rand::rng();
        let all_cols = LEFT_COLS + RIGHT_COLS;

        let x = int_col(rng.random_range(0..LEFT_COLS));
        // `x = x` is expected to be simplified into `x IS NOT NULL`.
        let self_eq = binary_call(ExprType::Equal, x.clone(), x.clone());
        let right = binary_call(
            ExprType::LessThan,
            int_col(rng.random_range(LEFT_COLS..all_cols)),
            int_col(rng.random_range(LEFT_COLS..all_cols)),
        );
        let other = binary_call(
            ExprType::GreaterThan,
            int_col(rng.random_range(0..LEFT_COLS)),
            int_col(rng.random_range(LEFT_COLS..all_cols)),
        );

        let cond = Condition::with_expr(other.clone())
            .and(Condition::with_expr(right.clone()))
            .and(Condition::with_expr(self_eq));

        let (left_part, right_part, other_part) = cond.split(LEFT_COLS, RIGHT_COLS);

        let expected_left: ExprImpl = FunctionCall::new(ExprType::IsNotNull, vec![x])
            .unwrap()
            .into();
        assert_eq!(left_part.conjunctions, vec![expected_left]);
        assert_eq!(right_part.conjunctions, vec![right]);
        assert_eq!(other_part.conjunctions, vec![other]);
    }
}