1use std::cmp::Ordering;
16use std::collections::{BTreeMap, HashSet};
17use std::fmt::{self, Debug};
18use std::ops::Bound;
19use std::sync::LazyLock;
20
21use fixedbitset::FixedBitSet;
22use itertools::Itertools;
23use risingwave_common::catalog::Schema;
24use risingwave_common::types::{DataType, DefaultOrd, ScalarImpl};
25use risingwave_common::util::iter_util::ZipEqFast;
26use risingwave_common::util::scan_range::{ScanRange, is_full_range};
27use risingwave_common::util::sort_util::{OrderType, cmp_rows};
28
29use crate::TableCatalog;
30use crate::error::Result;
31use crate::expr::{
32 ExprDisplay, ExprImpl, ExprMutator, ExprRewriter, ExprType, ExprVisitor, FunctionCall,
33 InequalityInputPair, InputRef, collect_input_refs, column_self_eq_eliminate,
34 factorization_expr, fold_boolean_constant, push_down_not, to_conjunctions,
35 try_get_bool_constant,
36};
37use crate::utils::condition::cast_compare::{ResultForCmp, ResultForEq};
38
/// A boolean predicate stored as a list of conjuncts.
///
/// The predicate is the logical `AND` of every expression in `conjunctions`. An empty list is
/// treated as the constant `true` (see [`Condition::always_true`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Condition {
    /// The expressions that are implicitly combined with `AND`.
    pub conjunctions: Vec<ExprImpl>,
}
44
45impl IntoIterator for Condition {
46 type IntoIter = std::vec::IntoIter<ExprImpl>;
47 type Item = ExprImpl;
48
49 fn into_iter(self) -> Self::IntoIter {
50 self.conjunctions.into_iter()
51 }
52}
53
54impl fmt::Display for Condition {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 let mut conjunctions = self.conjunctions.iter();
57 if let Some(expr) = conjunctions.next() {
58 write!(f, "{:?}", expr)?;
59 }
60 if self.always_true() {
61 write!(f, "true")?;
62 } else {
63 for expr in conjunctions {
64 write!(f, " AND {:?}", expr)?;
65 }
66 }
67 Ok(())
68 }
69}
70
71impl Condition {
72 pub fn with_expr(expr: ExprImpl) -> Self {
73 let conjunctions = to_conjunctions(expr);
74
75 Self { conjunctions }.simplify()
76 }
77
78 pub fn true_cond() -> Self {
79 Self {
80 conjunctions: vec![],
81 }
82 }
83
84 pub fn false_cond() -> Self {
85 Self {
86 conjunctions: vec![ExprImpl::literal_bool(false)],
87 }
88 }
89
90 pub fn always_true(&self) -> bool {
91 self.conjunctions.is_empty()
92 }
93
94 pub fn always_false(&self) -> bool {
95 static FALSE: LazyLock<ExprImpl> = LazyLock::new(|| ExprImpl::literal_bool(false));
96 !self.conjunctions.is_empty() && self.conjunctions.contains(&*FALSE)
98 }
99
100 pub fn as_expr_unless_true(&self) -> Option<ExprImpl> {
102 if self.always_true() {
103 None
104 } else {
105 Some(self.clone().into())
106 }
107 }
108
109 #[must_use]
110 pub fn and(self, other: Self) -> Self {
111 let mut ret = self;
112 ret.conjunctions.extend(other.conjunctions);
113 ret.simplify()
114 }
115
116 #[must_use]
117 pub fn or(self, other: Self) -> Self {
118 let or_expr = ExprImpl::FunctionCall(
119 FunctionCall::new_unchecked(
120 ExprType::Or,
121 vec![self.into(), other.into()],
122 DataType::Boolean,
123 )
124 .into(),
125 );
126 let ret = Self::with_expr(or_expr);
127 ret.simplify()
128 }
129
130 #[must_use]
132 pub fn split(self, left_col_num: usize, right_col_num: usize) -> (Self, Self, Self) {
133 let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
134 let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
135
136 self.group_by::<_, 3>(|expr| {
137 let input_bits = expr.collect_input_refs(left_col_num + right_col_num);
138 if input_bits.is_subset(&left_bit_map) {
139 0
140 } else if input_bits.is_subset(&right_bit_map) {
141 1
142 } else {
143 2
144 }
145 })
146 .into_iter()
147 .next_tuple()
148 .unwrap()
149 }
150
151 pub fn collect_input_refs(&self, input_col_num: usize) -> FixedBitSet {
156 collect_input_refs(input_col_num, &self.conjunctions)
157 }
158
159 #[must_use]
173 pub fn split_by_input_col_nums(
174 self,
175 input_col_nums: &[usize],
176 only_eq: bool,
177 ) -> (BTreeMap<(usize, usize), Self>, Self) {
178 let mut bitmaps = Vec::with_capacity(input_col_nums.len());
179 let mut cols_seen = 0;
180 for cols in input_col_nums {
181 bitmaps.push(FixedBitSet::from_iter(cols_seen..cols_seen + cols));
182 cols_seen += cols;
183 }
184
185 let mut pairwise_conditions = BTreeMap::new();
186 let mut non_eq_join = vec![];
187
188 for expr in self.conjunctions {
189 let input_bits = expr.collect_input_refs(cols_seen);
190 let mut subset_indices = Vec::with_capacity(input_col_nums.len());
191 for (idx, bitmap) in bitmaps.iter().enumerate() {
192 if !input_bits.is_disjoint(bitmap) {
193 subset_indices.push(idx);
194 }
195 }
196 if subset_indices.len() != 2 || (only_eq && expr.as_eq_cond().is_none()) {
197 non_eq_join.push(expr);
198 } else {
199 let key = if subset_indices[0] < subset_indices[1] {
201 (subset_indices[0], subset_indices[1])
202 } else {
203 (subset_indices[1], subset_indices[0])
204 };
205 let e = pairwise_conditions
206 .entry(key)
207 .or_insert_with(Condition::true_cond);
208 e.conjunctions.push(expr);
209 }
210 }
211 (
212 pairwise_conditions,
213 Condition {
214 conjunctions: non_eq_join,
215 },
216 )
217 }
218
219 #[must_use]
220 pub fn split_eq_keys(
227 self,
228 left_col_num: usize,
229 right_col_num: usize,
230 ) -> (Vec<(InputRef, InputRef, bool)>, Self) {
231 let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
232 let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
233
234 let (mut eq_keys, mut others) = (vec![], vec![]);
235 self.conjunctions.into_iter().for_each(|expr| {
236 let input_bits = expr.collect_input_refs(left_col_num + right_col_num);
237 if input_bits.is_disjoint(&left_bit_map) || input_bits.is_disjoint(&right_bit_map) {
238 others.push(expr)
239 } else if let Some(columns) = expr.as_eq_cond() {
240 eq_keys.push((columns.0, columns.1, false));
241 } else if let Some(columns) = expr.as_is_not_distinct_from_cond() {
242 eq_keys.push((columns.0, columns.1, true));
243 } else {
244 others.push(expr)
245 }
246 });
247
248 (
249 eq_keys,
250 Condition {
251 conjunctions: others,
252 },
253 )
254 }
255
256 pub(crate) fn extract_inequality_keys(
267 &self,
268 left_col_num: usize,
269 right_col_num: usize,
270 ) -> Vec<(usize, InequalityInputPair)> {
271 let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
272 let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
273
274 self.conjunctions
275 .iter()
276 .enumerate()
277 .filter_map(|(conjunction_idx, expr)| {
278 let input_bits = expr.collect_input_refs(left_col_num + right_col_num);
279 if input_bits.is_disjoint(&left_bit_map) || input_bits.is_disjoint(&right_bit_map) {
280 return None;
281 }
282
283 let (left_input, op, right_input) = expr.as_comparison_cond()?;
285
286 if left_input.index() < left_col_num
289 && right_input.index() >= left_col_num
290 && right_input.index() < left_col_num + right_col_num
291 {
292 Some((
293 conjunction_idx,
294 InequalityInputPair::new(
295 left_input.index(),
296 right_input.index() - left_col_num, op,
298 ),
299 ))
300 } else {
301 None
302 }
303 })
304 .collect_vec()
305 }
306
307 #[must_use]
310 pub fn split_disjoint(self, columns: &FixedBitSet) -> (Self, Self) {
311 self.group_by::<_, 2>(|expr| {
312 let input_bits = expr.collect_input_refs(columns.len());
313 input_bits.is_disjoint(columns) as usize
314 })
315 .into_iter()
316 .next_tuple()
317 .unwrap()
318 }
319
320 fn disjunctions_to_scan_ranges(
324 table: &TableCatalog,
325 max_split_range_gap: u64,
326 disjunctions: Vec<ExprImpl>,
327 ) -> Result<Option<(Vec<ScanRange>, bool)>> {
328 let disjunctions_result: Result<Vec<(Vec<ScanRange>, Self)>> = disjunctions
329 .into_iter()
330 .map(|x| {
331 Condition {
332 conjunctions: to_conjunctions(x),
333 }
334 .split_to_scan_ranges(table, max_split_range_gap)
335 })
336 .collect();
337
338 let disjunctions_result = disjunctions_result?;
340
341 let all_equal = disjunctions_result
344 .iter()
345 .all(|(scan_ranges, other_condition)| {
346 other_condition.always_true()
347 && scan_ranges
348 .iter()
349 .all(|x| !x.eq_conds.is_empty() && is_full_range(&x.range))
350 });
351
352 if all_equal {
353 let scan_ranges = disjunctions_result
357 .into_iter()
358 .flat_map(|(scan_ranges, _)| scan_ranges)
359 .sorted_by(|a, b| a.eq_conds.len().cmp(&b.eq_conds.len()))
361 .collect_vec();
362 let mut non_overlap_scan_ranges: Vec<ScanRange> = vec![];
364 for s1 in &scan_ranges {
365 let overlap = non_overlap_scan_ranges.iter().any(|s2| {
366 #[allow(clippy::disallowed_methods)]
367 s1.eq_conds
368 .iter()
369 .zip(s2.eq_conds.iter())
370 .all(|(a, b)| a == b)
371 });
372 if !overlap {
376 non_overlap_scan_ranges.push(s1.clone());
377 }
378 }
379
380 Ok(Some((non_overlap_scan_ranges, false)))
381 } else {
382 let mut scan_ranges = vec![];
383 for (scan_ranges_chunk, _) in disjunctions_result {
384 if scan_ranges_chunk.is_empty() {
385 return Ok(None);
387 }
388
389 scan_ranges.extend(scan_ranges_chunk);
390 }
391
392 let order_types = table
393 .pk
394 .iter()
395 .cloned()
396 .map(|x| {
397 if x.order_type.is_descending() {
398 x.order_type.reverse()
399 } else {
400 x.order_type
401 }
402 })
403 .collect_vec();
404 scan_ranges.sort_by(|left, right| {
405 let (left_start, _left_end) = &left.convert_to_range();
406 let (right_start, _right_end) = &right.convert_to_range();
407
408 let left_start_vec = match &left_start {
409 Bound::Included(vec) | Bound::Excluded(vec) => vec,
410 _ => &vec![],
411 };
412 let right_start_vec = match &right_start {
413 Bound::Included(vec) | Bound::Excluded(vec) => vec,
414 _ => &vec![],
415 };
416
417 if left_start_vec.is_empty() && right_start_vec.is_empty() {
418 return Ordering::Less;
419 }
420
421 if left_start_vec.is_empty() {
422 return Ordering::Less;
423 }
424
425 if right_start_vec.is_empty() {
426 return Ordering::Greater;
427 }
428
429 let cmp_column_len = left_start_vec.len().min(right_start_vec.len());
430 cmp_rows(
431 &left_start_vec[0..cmp_column_len],
432 &right_start_vec[0..cmp_column_len],
433 &order_types[0..cmp_column_len],
434 )
435 });
436
437 if scan_ranges.is_empty() {
438 return Ok(None);
439 }
440
441 if scan_ranges.len() == 1 {
442 return Ok(Some((scan_ranges, true)));
443 }
444
445 let mut output_scan_ranges: Vec<ScanRange> = vec![];
446 output_scan_ranges.push(scan_ranges[0].clone());
447 let mut idx = 1;
448 loop {
449 if idx >= scan_ranges.len() {
450 break;
451 }
452
453 let scan_range_left = output_scan_ranges.last_mut().unwrap();
454 let scan_range_right = &scan_ranges[idx];
455
456 if scan_range_left.eq_conds == scan_range_right.eq_conds {
457 if !ScanRange::is_overlap(scan_range_left, scan_range_right, &order_types) {
460 output_scan_ranges.push(scan_range_right.clone());
462 idx += 1;
463 continue;
464 }
465
466 fn merge_bound(
468 left_scan_range: &Bound<Vec<Option<ScalarImpl>>>,
469 right_scan_range: &Bound<Vec<Option<ScalarImpl>>>,
470 order_types: &[OrderType],
471 left_bound: bool,
472 ) -> Bound<Vec<Option<ScalarImpl>>> {
473 let left_scan_range = match left_scan_range {
474 Bound::Included(vec) | Bound::Excluded(vec) => vec,
475 Bound::Unbounded => return Bound::Unbounded,
476 };
477
478 let right_scan_range = match right_scan_range {
479 Bound::Included(vec) | Bound::Excluded(vec) => vec,
480 Bound::Unbounded => return Bound::Unbounded,
481 };
482
483 let cmp_len = left_scan_range.len().min(right_scan_range.len());
484
485 let cmp = cmp_rows(
486 &left_scan_range[..cmp_len],
487 &right_scan_range[..cmp_len],
488 &order_types[..cmp_len],
489 );
490
491 let bound = {
492 if (cmp.is_le() && left_bound) || (cmp.is_ge() && !left_bound) {
493 left_scan_range.clone()
494 } else {
495 right_scan_range.clone()
496 }
497 };
498
499 Bound::Included(bound)
501 }
502
503 scan_range_left.range.0 = merge_bound(
504 &scan_range_left.range.0,
505 &scan_range_right.range.0,
506 &order_types,
507 true,
508 );
509
510 scan_range_left.range.1 = merge_bound(
511 &scan_range_left.range.1,
512 &scan_range_right.range.1,
513 &order_types,
514 false,
515 );
516
517 if scan_range_left.is_full_table_scan() {
518 return Ok(None);
519 }
520 } else {
521 output_scan_ranges.push(scan_range_right.clone());
522 }
523
524 idx += 1;
525 }
526
527 Ok(Some((output_scan_ranges, true)))
528 }
529 }
530
531 fn split_row_cmp_to_scan_ranges(
532 &self,
533 table: &TableCatalog,
534 ) -> Result<Option<(Vec<ScanRange>, Self)>> {
535 let (mut row_conjunctions, row_conjunctions_without_struct): (Vec<_>, Vec<_>) =
536 self.conjunctions.clone().into_iter().partition(|expr| {
537 if let Some(f) = expr.as_function_call() {
538 if let Some(left_input) = f.inputs().get(0)
539 && let Some(left_input) = left_input.as_function_call()
540 && matches!(left_input.func_type(), ExprType::Row)
541 && left_input.inputs().iter().all(|x| x.is_input_ref())
542 && let Some(right_input) = f.inputs().get(1)
543 && right_input.is_literal()
544 {
545 true
546 } else {
547 false
548 }
549 } else {
550 false
551 }
552 });
553 if row_conjunctions.len() == 1 {
558 let row_conjunction = row_conjunctions.pop().unwrap();
559 let row_left_inputs = row_conjunction
560 .as_function_call()
561 .unwrap()
562 .inputs()
563 .get(0)
564 .unwrap()
565 .as_function_call()
566 .unwrap()
567 .inputs();
568 let row_right_literal = row_conjunction
569 .as_function_call()
570 .unwrap()
571 .inputs()
572 .get(1)
573 .unwrap()
574 .as_literal()
575 .unwrap();
576 if !matches!(row_right_literal.get_data(), Some(ScalarImpl::Struct(_))) {
577 return Ok(None);
578 }
579 let row_right_literal_data = row_right_literal.get_data().clone().unwrap();
580 let right_iter = row_right_literal_data.as_struct().fields();
581 let func_type = row_conjunction.as_function_call().unwrap().func_type();
582 if row_left_inputs.len() > 1
583 && (matches!(func_type, ExprType::LessThan)
584 || matches!(func_type, ExprType::GreaterThan))
585 {
586 let mut pk_struct = vec![];
587 let mut order_type = None;
588 let mut all_added = true;
589 let mut iter = row_left_inputs.iter().zip_eq_fast(right_iter);
590 for column_order in &table.pk {
591 if let Some((left_expr, right_expr)) = iter.next() {
592 if left_expr.as_input_ref().unwrap().index != column_order.column_index {
593 all_added = false;
594 break;
595 }
596 match order_type {
597 Some(o) => {
598 if o != column_order.order_type {
599 all_added = false;
600 break;
601 }
602 }
603 None => order_type = Some(column_order.order_type),
604 }
605 pk_struct.push(right_expr.clone());
606 }
607 }
608
609 if !pk_struct.is_empty() {
611 if !all_added {
612 let scan_range = ScanRange {
613 eq_conds: vec![],
614 range: match func_type {
615 ExprType::GreaterThan => {
616 (Bound::Included(pk_struct), Bound::Unbounded)
617 }
618 ExprType::LessThan => {
619 (Bound::Unbounded, Bound::Included(pk_struct))
620 }
621 _ => unreachable!(),
622 },
623 };
624 return Ok(Some((
625 vec![scan_range],
626 Condition {
627 conjunctions: self.conjunctions.clone(),
628 },
629 )));
630 } else {
631 let scan_range = ScanRange {
632 eq_conds: vec![],
633 range: match func_type {
634 ExprType::GreaterThan => {
635 (Bound::Excluded(pk_struct), Bound::Unbounded)
636 }
637 ExprType::LessThan => {
638 (Bound::Unbounded, Bound::Excluded(pk_struct))
639 }
640 _ => unreachable!(),
641 },
642 };
643 return Ok(Some((
644 vec![scan_range],
645 Condition {
646 conjunctions: row_conjunctions_without_struct,
647 },
648 )));
649 }
650 }
651 }
652 }
653 Ok(None)
654 }
655
656 pub fn get_eq_const_input_refs(&self) -> Vec<InputRef> {
658 self.conjunctions
659 .iter()
660 .filter_map(|expr| expr.as_eq_const().map(|(input_ref, _)| input_ref))
661 .collect()
662 }
663
    /// Splits the condition into scan ranges over the table's primary key plus a residual
    /// condition.
    ///
    /// Returns `(scan_ranges, remaining_condition)`. An empty `scan_ranges` together with
    /// `Condition::false_cond()` is returned when the analysis finds the condition can never
    /// hold; an empty `scan_ranges` with a non-false condition means no range was derived.
    ///
    /// `max_split_range_gap` is forwarded to `ScanRange::split_small_range` and to the
    /// disjunction handling.
    pub fn split_to_scan_ranges(
        self,
        table: &TableCatalog,
        max_split_range_gap: u64,
    ) -> Result<(Vec<ScanRange>, Self)> {
        // Result used whenever the condition is found to be unsatisfiable.
        fn false_cond() -> (Vec<ScanRange>, Condition) {
            (vec![], Condition::false_cond())
        }

        // A single conjunct that is an `OR` is handled branch by branch.
        if self.conjunctions.len() == 1
            && let Some(disjunctions) = self.conjunctions[0].as_or_disjunctions()
        {
            if let Some((scan_ranges, maintaining_condition)) =
                Self::disjunctions_to_scan_ranges(table, max_split_range_gap, disjunctions)?
            {
                if maintaining_condition {
                    // Keep the original condition as a filter on top of the ranges.
                    return Ok((scan_ranges, self));
                } else {
                    return Ok((scan_ranges, Condition::true_cond()));
                }
            } else {
                return Ok((vec![], self));
            }
        }
        // Row comparison such as `(a, b) > (1, 2)`.
        if let Some((scan_ranges, other_condition)) = self.split_row_cmp_to_scan_ranges(table)? {
            return Ok((scan_ranges, other_condition));
        }

        // One group per pk column, plus a trailing group with all other conjuncts.
        let mut groups = Self::classify_conjunctions_by_pk(self.conjunctions, table);
        let mut other_conds = groups.pop().unwrap();

        let mut scan_range = ScanRange::full_table_scan();
        // Walk the pk columns in order, extending the scan range one column at a time.
        for i in 0..table.pk.len() {
            let group = std::mem::take(&mut groups[i]);
            if group.is_empty() {
                // No condition on this pk column: stop here. Conditions on later pk
                // columns and the other conditions become the residual filter.
                groups.push(other_conds);
                return Ok((
                    if scan_range.is_full_table_scan() {
                        vec![]
                    } else {
                        vec![scan_range]
                    },
                    Self {
                        conjunctions: groups[i + 1..].concat(),
                    },
                ));
            }

            let Some((
                lower_bound_conjunctions,
                upper_bound_conjunctions,
                eq_conds,
                part_of_other_conds,
            )) = Self::analyze_group(group)?
            else {
                return Ok(false_cond());
            };
            other_conds.extend(part_of_other_conds.into_iter());

            let lower_bound = Self::merge_lower_bound_conjunctions(lower_bound_conjunctions);
            let upper_bound = Self::merge_upper_bound_conjunctions(upper_bound_conjunctions);

            if Self::is_invalid_range(&lower_bound, &upper_bound) {
                return Ok(false_cond());
            }

            match eq_conds.len() {
                // Exactly one equality value: extend the eq prefix and continue with the
                // next pk column.
                1 => {
                    let eq_conds =
                        Self::extract_eq_conds_within_range(eq_conds, &upper_bound, &lower_bound);
                    if eq_conds.is_empty() {
                        return Ok(false_cond());
                    }
                    scan_range.eq_conds.extend(eq_conds.into_iter());
                }
                // No equality value: use the merged bounds as the range on this column and
                // stop; conditions on later pk columns go to the residual filter.
                0 => {
                    let convert = |bound| match bound {
                        Bound::Included(l) => Bound::Included(vec![Some(l)]),
                        Bound::Excluded(l) => Bound::Excluded(vec![Some(l)]),
                        Bound::Unbounded => Bound::Unbounded,
                    };
                    scan_range.range = (convert(lower_bound), convert(upper_bound));
                    other_conds.extend(groups[i + 1..].iter().flatten().cloned());
                    break;
                }
                // Several equality values: emit one scan range per value that lies within
                // the bounds and stop.
                _ => {
                    let eq_conds =
                        Self::extract_eq_conds_within_range(eq_conds, &upper_bound, &lower_bound);
                    if eq_conds.is_empty() {
                        return Ok(false_cond());
                    }
                    other_conds.extend(groups[i + 1..].iter().flatten().cloned());
                    let scan_ranges = eq_conds
                        .into_iter()
                        .map(|lit| {
                            let mut scan_range = scan_range.clone();
                            scan_range.eq_conds.push(lit);
                            scan_range
                        })
                        .collect();
                    return Ok((
                        scan_ranges,
                        Self {
                            conjunctions: other_conds,
                        },
                    ));
                }
            }
        }

        Ok((
            if scan_range.is_full_table_scan() {
                vec![]
            } else if table.columns[table.pk[0].column_index].data_type.is_int() {
                // When the first pk column is an integer type, try to split the range via
                // `split_small_range`; fall back to the single range otherwise.
                match scan_range.split_small_range(max_split_range_gap) {
                    Some(scan_ranges) => scan_ranges,
                    None => vec![scan_range],
                }
            } else {
                vec![scan_range]
            },
            Self {
                conjunctions: other_conds,
            },
        ))
    }
800
801 fn classify_conjunctions_by_pk(
805 conjunctions: Vec<ExprImpl>,
806 table: &TableCatalog,
807 ) -> Vec<Vec<ExprImpl>> {
808 let pk_cols_num = table.pk.len();
809 let cols_num = table.columns.len();
810
811 let mut col_idx_to_pk_idx = vec![None; cols_num];
812 table
813 .order_column_indices()
814 .enumerate()
815 .for_each(|(idx, pk_idx)| {
816 col_idx_to_pk_idx[pk_idx] = Some(idx);
817 });
818
819 let mut groups = vec![vec![]; pk_cols_num + 1];
820 for (key, group) in &conjunctions.into_iter().chunk_by(|expr| {
821 let input_bits = expr.collect_input_refs(cols_num);
822 if input_bits.count_ones(..) == 1 {
823 let col_idx = input_bits.ones().next().unwrap();
824 col_idx_to_pk_idx[col_idx].unwrap_or(pk_cols_num)
825 } else {
826 pk_cols_num
827 }
828 }) {
829 groups[key].extend(group);
830 }
831
832 groups
833 }
834
    /// Analyses the conjuncts that constrain a single column.
    ///
    /// Returns `Ok(None)` when the group is found to be unsatisfiable. Otherwise returns
    /// `(lower_bounds, upper_bounds, eq_conds, other_conds)`:
    /// - `lower_bounds` / `upper_bounds`: bounds collected from `>`/`>=` and `<`/`<=`
    ///   comparisons with a constant;
    /// - `eq_conds`: candidate equality values (`None` stands for `IS NULL`);
    /// - `other_conds`: conjuncts that could not be classified and must stay as filters.
    #[allow(clippy::type_complexity)]
    fn analyze_group(
        group: Vec<ExprImpl>,
    ) -> Result<
        Option<(
            Vec<Bound<ScalarImpl>>,
            Vec<Bound<ScalarImpl>>,
            Vec<Option<ScalarImpl>>,
            Vec<ExprImpl>,
        )>,
    > {
        let mut lower_bound_conjunctions = vec![];
        let mut upper_bound_conjunctions = vec![];
        let mut eq_conds = vec![];
        let mut other_conds = vec![];

        for expr in group {
            if let Some((input_ref, const_expr)) = expr.as_eq_const() {
                // `col = const`: cast the constant to the column type, falling back to the
                // range-aware integer shrink in `cast_compare` when the implicit cast fails.
                let new_expr = if let Ok(expr) =
                    const_expr.clone().cast_implicit(&input_ref.data_type)
                {
                    expr
                } else {
                    match self::cast_compare::cast_compare_for_eq(const_expr, input_ref.data_type) {
                        Ok(ResultForEq::Success(expr)) => expr,
                        Ok(ResultForEq::NeverEqual) => {
                            return Ok(None);
                        }
                        Err(_) => {
                            other_conds.push(expr);
                            continue;
                        }
                    }
                };

                // `fold_const` yielding `None` is treated as unsatisfiable (presumably the
                // constant folded to NULL — confirm against `fold_const`).
                let Some(new_cond) = new_expr.fold_const()? else {
                    return Ok(None);
                };
                if Self::mutual_exclusive_with_eq_conds(&new_cond, &eq_conds) {
                    return Ok(None);
                }
                eq_conds = vec![Some(new_cond)];
            } else if expr.as_is_null().is_some() {
                // `col IS NULL` conflicts with existing equality values that are all non-null.
                if !eq_conds.is_empty() && eq_conds.into_iter().all(|l| l.is_some()) {
                    return Ok(None);
                }
                eq_conds = vec![None];
            } else if let Some((input_ref, in_const_list)) = expr.as_in_const_list() {
                // `col IN (...)`: collect the distinct folded values, skipping entries for
                // which `fold_const` returns `None`.
                let mut scalars = HashSet::new();
                for const_expr in in_const_list {
                    let const_expr = const_expr.cast_implicit(&input_ref.data_type).unwrap();
                    let value = const_expr.fold_const()?;
                    let Some(value) = value else {
                        continue;
                    };
                    scalars.insert(Some(value));
                }
                if scalars.is_empty() {
                    return Ok(None);
                }
                if !eq_conds.is_empty() {
                    // Intersect with the equality values collected so far.
                    scalars = scalars
                        .intersection(&HashSet::from_iter(eq_conds))
                        .cloned()
                        .collect();
                    if scalars.is_empty() {
                        return Ok(None);
                    }
                }
                // Sorted so the resulting order does not depend on hash iteration order.
                eq_conds = scalars
                    .into_iter()
                    .sorted_by(DefaultOrd::default_cmp)
                    .collect();
            } else if let Some((input_ref, op, const_expr)) = expr.as_comparison_const() {
                // `col <op> const` with a comparison operator: becomes a lower or upper bound.
                let new_expr =
                    if let Ok(expr) = const_expr.clone().cast_implicit(&input_ref.data_type) {
                        expr
                    } else {
                        match self::cast_compare::cast_compare_for_cmp(
                            const_expr,
                            input_ref.data_type,
                            op,
                        ) {
                            Ok(ResultForCmp::Success(expr)) => expr,
                            // Out-of-range constants and cast errors are kept as filters.
                            _ => {
                                other_conds.push(expr);
                                continue;
                            }
                        }
                    };
                let Some(value) = new_expr.fold_const()? else {
                    return Ok(None);
                };
                match op {
                    ExprType::LessThan => {
                        upper_bound_conjunctions.push(Bound::Excluded(value));
                    }
                    ExprType::LessThanOrEqual => {
                        upper_bound_conjunctions.push(Bound::Included(value));
                    }
                    ExprType::GreaterThan => {
                        lower_bound_conjunctions.push(Bound::Excluded(value));
                    }
                    ExprType::GreaterThanOrEqual => {
                        lower_bound_conjunctions.push(Bound::Included(value));
                    }
                    _ => unreachable!(),
                }
            } else {
                other_conds.push(expr);
            }
        }
        Ok(Some((
            lower_bound_conjunctions,
            upper_bound_conjunctions,
            eq_conds,
            other_conds,
        )))
    }
969
970 fn mutual_exclusive_with_eq_conds(
971 new_conds: &ScalarImpl,
972 eq_conds: &[Option<ScalarImpl>],
973 ) -> bool {
974 !eq_conds.is_empty()
975 && eq_conds.iter().all(|l| {
976 if let Some(l) = l {
977 l != new_conds
978 } else {
979 true
980 }
981 })
982 }
983
984 fn merge_lower_bound_conjunctions(lb: Vec<Bound<ScalarImpl>>) -> Bound<ScalarImpl> {
985 lb.into_iter()
986 .max_by(|a, b| {
987 match (a, b) {
989 (Bound::Included(_), Bound::Unbounded) => std::cmp::Ordering::Greater,
990 (Bound::Excluded(_), Bound::Unbounded) => std::cmp::Ordering::Greater,
991 (Bound::Unbounded, Bound::Included(_)) => std::cmp::Ordering::Less,
992 (Bound::Unbounded, Bound::Excluded(_)) => std::cmp::Ordering::Less,
993 (Bound::Unbounded, Bound::Unbounded) => std::cmp::Ordering::Equal,
994 (Bound::Included(a), Bound::Included(b)) => a.default_cmp(b),
995 (Bound::Excluded(a), Bound::Excluded(b)) => a.default_cmp(b),
996 (Bound::Included(a), Bound::Excluded(b)) => match a.default_cmp(b) {
998 std::cmp::Ordering::Equal => std::cmp::Ordering::Less,
999 other => other,
1000 },
1001 (Bound::Excluded(a), Bound::Included(b)) => match a.default_cmp(b) {
1002 std::cmp::Ordering::Equal => std::cmp::Ordering::Greater,
1003 other => other,
1004 },
1005 }
1006 })
1007 .unwrap_or(Bound::Unbounded)
1008 }
1009
1010 fn merge_upper_bound_conjunctions(ub: Vec<Bound<ScalarImpl>>) -> Bound<ScalarImpl> {
1011 ub.into_iter()
1012 .min_by(|a, b| {
1013 match (a, b) {
1015 (Bound::Included(_), Bound::Unbounded) => std::cmp::Ordering::Less,
1016 (Bound::Excluded(_), Bound::Unbounded) => std::cmp::Ordering::Less,
1017 (Bound::Unbounded, Bound::Included(_)) => std::cmp::Ordering::Greater,
1018 (Bound::Unbounded, Bound::Excluded(_)) => std::cmp::Ordering::Greater,
1019 (Bound::Unbounded, Bound::Unbounded) => std::cmp::Ordering::Equal,
1020 (Bound::Included(a), Bound::Included(b)) => a.default_cmp(b),
1021 (Bound::Excluded(a), Bound::Excluded(b)) => a.default_cmp(b),
1022 (Bound::Included(a), Bound::Excluded(b)) => match a.default_cmp(b) {
1024 std::cmp::Ordering::Equal => std::cmp::Ordering::Greater,
1025 other => other,
1026 },
1027 (Bound::Excluded(a), Bound::Included(b)) => match a.default_cmp(b) {
1028 std::cmp::Ordering::Equal => std::cmp::Ordering::Less,
1029 other => other,
1030 },
1031 }
1032 })
1033 .unwrap_or(Bound::Unbounded)
1034 }
1035
1036 fn is_invalid_range(lower_bound: &Bound<ScalarImpl>, upper_bound: &Bound<ScalarImpl>) -> bool {
1037 match (lower_bound, upper_bound) {
1038 (Bound::Included(l), Bound::Included(u)) => l.default_cmp(u).is_gt(), (Bound::Included(l), Bound::Excluded(u)) => l.default_cmp(u).is_ge(), (Bound::Excluded(l), Bound::Included(u)) => l.default_cmp(u).is_ge(), (Bound::Excluded(l), Bound::Excluded(u)) => l.default_cmp(u).is_ge(), _ => false,
1043 }
1044 }
1045
1046 fn extract_eq_conds_within_range(
1047 eq_conds: Vec<Option<ScalarImpl>>,
1048 upper_bound: &Bound<ScalarImpl>,
1049 lower_bound: &Bound<ScalarImpl>,
1050 ) -> Vec<Option<ScalarImpl>> {
1051 if Self::is_invalid_range(lower_bound, upper_bound) {
1054 return vec![];
1055 }
1056
1057 let is_extract_null = upper_bound == &Bound::Unbounded && lower_bound == &Bound::Unbounded;
1058
1059 eq_conds
1060 .into_iter()
1061 .filter(|cond| {
1062 if let Some(cond) = cond {
1063 match lower_bound {
1064 Bound::Included(val) => {
1065 if cond.default_cmp(val).is_lt() {
1066 return false;
1068 }
1069 }
1070 Bound::Excluded(val) => {
1071 if cond.default_cmp(val).is_le() {
1072 return false;
1074 }
1075 }
1076 Bound::Unbounded => {}
1077 }
1078 match upper_bound {
1079 Bound::Included(val) => {
1080 if cond.default_cmp(val).is_gt() {
1081 return false;
1083 }
1084 }
1085 Bound::Excluded(val) => {
1086 if cond.default_cmp(val).is_ge() {
1087 return false;
1089 }
1090 }
1091 Bound::Unbounded => {}
1092 }
1093 true
1094 } else {
1095 is_extract_null
1096 }
1097 })
1098 .collect()
1099 }
1100
1101 #[must_use]
1107 pub fn group_by<F, const N: usize>(self, f: F) -> [Self; N]
1108 where
1109 F: Fn(&ExprImpl) -> usize,
1110 {
1111 const EMPTY: Vec<ExprImpl> = vec![];
1112 let mut groups = [EMPTY; N];
1113 for (key, group) in &self.conjunctions.into_iter().chunk_by(|expr| {
1114 let i = f(expr);
1116 assert!(i < N);
1117 i
1118 }) {
1119 groups[key].extend(group);
1120 }
1121
1122 groups.map(|group| Condition {
1123 conjunctions: group,
1124 })
1125 }
1126
1127 #[must_use]
1128 pub fn rewrite_expr(self, rewriter: &mut (impl ExprRewriter + ?Sized)) -> Self {
1129 Self {
1130 conjunctions: self
1131 .conjunctions
1132 .into_iter()
1133 .map(|expr| rewriter.rewrite_expr(expr))
1134 .collect(),
1135 }
1136 .simplify()
1137 }
1138
1139 pub fn visit_expr<V: ExprVisitor + ?Sized>(&self, visitor: &mut V) {
1140 self.conjunctions
1141 .iter()
1142 .for_each(|expr| visitor.visit_expr(expr));
1143 }
1144
1145 pub fn visit_expr_mut(&mut self, mutator: &mut (impl ExprMutator + ?Sized)) {
1146 self.conjunctions
1147 .iter_mut()
1148 .for_each(|expr| mutator.visit_expr(expr))
1149 }
1150
1151 fn simplify(self) -> Self {
1154 let conjunctions: Vec<_> = self
1156 .conjunctions
1157 .into_iter()
1158 .map(push_down_not)
1159 .map(fold_boolean_constant)
1160 .map(column_self_eq_eliminate)
1161 .flat_map(to_conjunctions)
1162 .collect();
1163 let mut res: Vec<ExprImpl> = Vec::new();
1164 let mut visited: HashSet<ExprImpl> = HashSet::new();
1165 for expr in conjunctions {
1166 if !expr.has_subquery() {
1168 let results_of_factorization = factorization_expr(expr);
1169 res.extend(
1170 results_of_factorization
1171 .clone()
1172 .into_iter()
1173 .filter(|expr| !visited.contains(expr)),
1174 );
1175 visited.extend(results_of_factorization);
1176 } else {
1177 res.push(expr);
1179 }
1180 }
1181 res.retain(|expr| {
1183 if let Some(v) = try_get_bool_constant(expr)
1184 && v
1185 {
1186 false
1187 } else {
1188 true
1189 }
1190 });
1191 for expr in &mut res {
1193 if let Some(v) = try_get_bool_constant(expr)
1194 && !v
1195 {
1196 res.clear();
1197 res.push(ExprImpl::literal_bool(false));
1198 break;
1199 }
1200 }
1201 Self { conjunctions: res }
1202 }
1203}
1204
/// Formatting helper that pairs a [`Condition`] with the schema of its input.
///
/// Each conjunct is rendered through `ExprDisplay` with `input_schema`.
pub struct ConditionDisplay<'a> {
    /// The condition to render.
    pub condition: &'a Condition,
    /// The schema passed to `ExprDisplay` for every conjunct.
    pub input_schema: &'a Schema,
}
1209
1210impl ConditionDisplay<'_> {
1211 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1212 if self.condition.always_true() {
1213 write!(f, "true")
1214 } else {
1215 write!(
1216 f,
1217 "{}",
1218 self.condition
1219 .conjunctions
1220 .iter()
1221 .format_with(" AND ", |expr, f| {
1222 f(&ExprDisplay {
1223 expr,
1224 input_schema: self.input_schema,
1225 })
1226 })
1227 )
1228 }
1229 }
1230}
1231
1232impl fmt::Display for ConditionDisplay<'_> {
1233 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1234 self.fmt(f)
1235 }
1236}
1237
1238impl fmt::Debug for ConditionDisplay<'_> {
1239 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1240 self.fmt(f)
1241 }
1242}
1243
1244mod cast_compare {
1249 use risingwave_common::types::DataType;
1250
1251 use crate::expr::{Expr, ExprImpl, ExprType};
1252
1253 enum ShrinkResult {
1254 OutUpperBound,
1255 OutLowerBound,
1256 InRange(ExprImpl),
1257 }
1258
1259 pub enum ResultForEq {
1260 Success(ExprImpl),
1261 NeverEqual,
1262 }
1263
1264 pub enum ResultForCmp {
1265 Success(ExprImpl),
1266 OutUpperBound,
1267 OutLowerBound,
1268 }
1269
1270 pub fn cast_compare_for_eq(const_expr: ExprImpl, target: DataType) -> Result<ResultForEq, ()> {
1271 match (const_expr.return_type(), &target) {
1272 (DataType::Int64, DataType::Int32)
1273 | (DataType::Int64, DataType::Int16)
1274 | (DataType::Int32, DataType::Int16) => match shrink_integral(const_expr, target)? {
1275 ShrinkResult::InRange(expr) => Ok(ResultForEq::Success(expr)),
1276 ShrinkResult::OutUpperBound | ShrinkResult::OutLowerBound => {
1277 Ok(ResultForEq::NeverEqual)
1278 }
1279 },
1280 _ => Err(()),
1281 }
1282 }
1283
1284 pub fn cast_compare_for_cmp(
1285 const_expr: ExprImpl,
1286 target: DataType,
1287 _op: ExprType,
1288 ) -> Result<ResultForCmp, ()> {
1289 match (const_expr.return_type(), &target) {
1290 (DataType::Int64, DataType::Int32)
1291 | (DataType::Int64, DataType::Int16)
1292 | (DataType::Int32, DataType::Int16) => match shrink_integral(const_expr, target)? {
1293 ShrinkResult::InRange(expr) => Ok(ResultForCmp::Success(expr)),
1294 ShrinkResult::OutUpperBound => Ok(ResultForCmp::OutUpperBound),
1295 ShrinkResult::OutLowerBound => Ok(ResultForCmp::OutLowerBound),
1296 },
1297 _ => Err(()),
1298 }
1299 }
1300
1301 fn shrink_integral(const_expr: ExprImpl, target: DataType) -> Result<ShrinkResult, ()> {
1302 let (upper_bound, lower_bound) = match (const_expr.return_type(), &target) {
1303 (DataType::Int64, DataType::Int32) => (i32::MAX as i64, i32::MIN as i64),
1304 (DataType::Int64, DataType::Int16) | (DataType::Int32, DataType::Int16) => {
1305 (i16::MAX as i64, i16::MIN as i64)
1306 }
1307 _ => unreachable!(),
1308 };
1309 match const_expr.fold_const().map_err(|_| ())? {
1310 Some(scalar) => {
1311 let value = scalar.as_integral();
1312 if value > upper_bound {
1313 Ok(ShrinkResult::OutUpperBound)
1314 } else if value < lower_bound {
1315 Ok(ShrinkResult::OutLowerBound)
1316 } else {
1317 Ok(ShrinkResult::InRange(
1318 const_expr.cast_explicit(&target).unwrap(),
1319 ))
1320 }
1321 }
1322 None => Ok(ShrinkResult::InRange(
1323 const_expr.cast_explicit(&target).unwrap(),
1324 )),
1325 }
1326 }
1327}
1328
#[cfg(test)]
mod tests {
    use rand::Rng;

    use super::*;

    /// Builds an `Int32` input reference whose index is drawn from `range`.
    fn random_col(rng: &mut impl Rng, range: std::ops::Range<usize>) -> ExprImpl {
        InputRef::new(rng.random_range(range), DataType::Int32).into()
    }

    /// Builds the binary function call `func_type(lhs, rhs)`.
    fn binary(func_type: ExprType, lhs: ExprImpl, rhs: ExprImpl) -> ExprImpl {
        FunctionCall::new(func_type, vec![lhs, rhs]).unwrap().into()
    }

    /// `split` must route left-only, right-only and mixed conjuncts to their own groups.
    #[test]
    fn test_split() {
        let left_col_num = 3;
        let right_col_num = 2;
        let left_cols = 0..left_col_num;
        let right_cols = left_col_num..left_col_num + right_col_num;

        let mut rng = rand::rng();

        // References only left columns.
        let left = binary(
            ExprType::LessThanOrEqual,
            random_col(&mut rng, left_cols.clone()),
            random_col(&mut rng, left_cols.clone()),
        );

        // References only right columns.
        let right = binary(
            ExprType::LessThan,
            random_col(&mut rng, right_cols.clone()),
            random_col(&mut rng, right_cols.clone()),
        );

        // References one column from each side.
        let other = binary(
            ExprType::GreaterThan,
            random_col(&mut rng, left_cols),
            random_col(&mut rng, right_cols),
        );

        let cond = Condition::with_expr(other.clone())
            .and(Condition::with_expr(right.clone()))
            .and(Condition::with_expr(left.clone()));

        let (left_cond, right_cond, other_cond) = cond.split(left_col_num, right_col_num);

        assert_eq!(left_cond.conjunctions, vec![left]);
        assert_eq!(right_cond.conjunctions, vec![right]);
        assert_eq!(other_cond.conjunctions, vec![other]);
    }

    /// `x = x` must be simplified to `x IS NOT NULL` before splitting.
    #[test]
    fn test_self_eq_eliminate() {
        let left_col_num = 3;
        let right_col_num = 2;
        let left_cols = 0..left_col_num;
        let right_cols = left_col_num..left_col_num + right_col_num;

        let mut rng = rand::rng();

        let x = random_col(&mut rng, left_cols.clone());

        // Self-equality on a left column.
        let left = binary(ExprType::Equal, x.clone(), x.clone());

        // References only right columns.
        let right = binary(
            ExprType::LessThan,
            random_col(&mut rng, right_cols.clone()),
            random_col(&mut rng, right_cols.clone()),
        );

        // References one column from each side.
        let other = binary(
            ExprType::GreaterThan,
            random_col(&mut rng, left_cols),
            random_col(&mut rng, right_cols),
        );

        let cond = Condition::with_expr(other.clone())
            .and(Condition::with_expr(right.clone()))
            .and(Condition::with_expr(left));

        let (left_cond, right_cond, other_cond) = cond.split(left_col_num, right_col_num);

        let left_res: ExprImpl = FunctionCall::new(ExprType::IsNotNull, vec![x])
            .unwrap()
            .into();

        assert_eq!(left_cond.conjunctions, vec![left_res]);
        assert_eq!(right_cond.conjunctions, vec![right]);
        assert_eq!(other_cond.conjunctions, vec![other]);
    }
}