risingwave_frontend/utils/
condition.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::Ordering;
16use std::collections::{BTreeMap, HashSet};
17use std::fmt::{self, Debug};
18use std::ops::Bound;
19use std::sync::LazyLock;
20
21use fixedbitset::FixedBitSet;
22use itertools::Itertools;
23use risingwave_common::catalog::Schema;
24use risingwave_common::types::{DataType, DefaultOrd, ScalarImpl};
25use risingwave_common::util::iter_util::ZipEqFast;
26use risingwave_common::util::scan_range::{ScanRange, is_full_range};
27use risingwave_common::util::sort_util::{OrderType, cmp_rows};
28
29use crate::TableCatalog;
30use crate::error::Result;
31use crate::expr::{
32    ExprDisplay, ExprImpl, ExprMutator, ExprRewriter, ExprType, ExprVisitor, FunctionCall,
33    InequalityInputPair, InputRef, collect_input_refs, column_self_eq_eliminate,
34    factorization_expr, fold_boolean_constant, push_down_not, to_conjunctions,
35    try_get_bool_constant,
36};
37use crate::utils::condition::cast_compare::{ResultForCmp, ResultForEq};
38
39#[derive(Debug, Clone, PartialEq, Eq, Hash)]
40pub struct Condition {
41    /// Condition expressions in conjunction form (combined with `AND`)
42    pub conjunctions: Vec<ExprImpl>,
43}
44
45impl IntoIterator for Condition {
46    type IntoIter = std::vec::IntoIter<ExprImpl>;
47    type Item = ExprImpl;
48
49    fn into_iter(self) -> Self::IntoIter {
50        self.conjunctions.into_iter()
51    }
52}
53
54impl fmt::Display for Condition {
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        let mut conjunctions = self.conjunctions.iter();
57        if let Some(expr) = conjunctions.next() {
58            write!(f, "{:?}", expr)?;
59        }
60        if self.always_true() {
61            write!(f, "true")?;
62        } else {
63            for expr in conjunctions {
64                write!(f, " AND {:?}", expr)?;
65            }
66        }
67        Ok(())
68    }
69}
70
71impl Condition {
72    pub fn with_expr(expr: ExprImpl) -> Self {
73        let conjunctions = to_conjunctions(expr);
74
75        Self { conjunctions }.simplify()
76    }
77
78    pub fn true_cond() -> Self {
79        Self {
80            conjunctions: vec![],
81        }
82    }
83
84    pub fn false_cond() -> Self {
85        Self {
86            conjunctions: vec![ExprImpl::literal_bool(false)],
87        }
88    }
89
90    pub fn always_true(&self) -> bool {
91        self.conjunctions.is_empty()
92    }
93
94    pub fn always_false(&self) -> bool {
95        static FALSE: LazyLock<ExprImpl> = LazyLock::new(|| ExprImpl::literal_bool(false));
96        // There is at least one conjunction that is false.
97        !self.conjunctions.is_empty() && self.conjunctions.contains(&*FALSE)
98    }
99
100    /// Convert condition to an expression. If always true, return `None`.
101    pub fn as_expr_unless_true(&self) -> Option<ExprImpl> {
102        if self.always_true() {
103            None
104        } else {
105            Some(self.clone().into())
106        }
107    }
108
109    #[must_use]
110    pub fn and(self, other: Self) -> Self {
111        let mut ret = self;
112        ret.conjunctions.extend(other.conjunctions);
113        ret.simplify()
114    }
115
116    #[must_use]
117    pub fn or(self, other: Self) -> Self {
118        let or_expr = ExprImpl::FunctionCall(
119            FunctionCall::new_unchecked(
120                ExprType::Or,
121                vec![self.into(), other.into()],
122                DataType::Boolean,
123            )
124            .into(),
125        );
126        let ret = Self::with_expr(or_expr);
127        ret.simplify()
128    }
129
130    /// Split the condition expressions into 3 groups: left, right and others
131    #[must_use]
132    pub fn split(self, left_col_num: usize, right_col_num: usize) -> (Self, Self, Self) {
133        let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
134        let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
135
136        self.group_by::<_, 3>(|expr| {
137            let input_bits = expr.collect_input_refs(left_col_num + right_col_num);
138            if input_bits.is_subset(&left_bit_map) {
139                0
140            } else if input_bits.is_subset(&right_bit_map) {
141                1
142            } else {
143                2
144            }
145        })
146        .into_iter()
147        .next_tuple()
148        .unwrap()
149    }
150
151    /// Collect all `InputRef`s' indexes in the expressions.
152    ///
153    /// # Panics
154    /// Panics if `input_ref >= input_col_num`.
155    pub fn collect_input_refs(&self, input_col_num: usize) -> FixedBitSet {
156        collect_input_refs(input_col_num, &self.conjunctions)
157    }
158
159    /// Split the condition expressions into (N choose 2) + 1 groups: those containing two columns
160    /// from different buckets (and optionally, needing an equal condition between them), and
161    /// others.
162    ///
163    /// `input_num_cols` are the number of columns in each of the input buckets. For instance, with
164    /// bucket0: col0, col1, col2 | bucket1: col3, col4 | bucket2: col5
165    /// `input_num_cols` = [3, 2, 1]
166    ///
167    /// Returns hashmap with keys of the form (col1, col2) where col1 < col2 in terms of their col
168    /// index.
169    ///
170    /// `only_eq`: whether to only split those conditions with an eq condition predicate between two
171    /// buckets.
172    #[must_use]
173    pub fn split_by_input_col_nums(
174        self,
175        input_col_nums: &[usize],
176        only_eq: bool,
177    ) -> (BTreeMap<(usize, usize), Self>, Self) {
178        let mut bitmaps = Vec::with_capacity(input_col_nums.len());
179        let mut cols_seen = 0;
180        for cols in input_col_nums {
181            bitmaps.push(FixedBitSet::from_iter(cols_seen..cols_seen + cols));
182            cols_seen += cols;
183        }
184
185        let mut pairwise_conditions = BTreeMap::new();
186        let mut non_eq_join = vec![];
187
188        for expr in self.conjunctions {
189            let input_bits = expr.collect_input_refs(cols_seen);
190            let mut subset_indices = Vec::with_capacity(input_col_nums.len());
191            for (idx, bitmap) in bitmaps.iter().enumerate() {
192                if !input_bits.is_disjoint(bitmap) {
193                    subset_indices.push(idx);
194                }
195            }
196            if subset_indices.len() != 2 || (only_eq && expr.as_eq_cond().is_none()) {
197                non_eq_join.push(expr);
198            } else {
199                // The key has the canonical ordering (lower, higher)
200                let key = if subset_indices[0] < subset_indices[1] {
201                    (subset_indices[0], subset_indices[1])
202                } else {
203                    (subset_indices[1], subset_indices[0])
204                };
205                let e = pairwise_conditions
206                    .entry(key)
207                    .or_insert_with(Condition::true_cond);
208                e.conjunctions.push(expr);
209            }
210        }
211        (
212            pairwise_conditions,
213            Condition {
214                conjunctions: non_eq_join,
215            },
216        )
217    }
218
219    #[must_use]
220    /// For [`EqJoinPredicate`], separate equality conditions which connect left columns and right
221    /// columns from other conditions.
222    ///
223    /// The equality conditions are transformed into `(left_col_id, right_col_id, null_eq_null)` tuples.
224    ///
225    /// [`EqJoinPredicate`]: crate::optimizer::plan_node::EqJoinPredicate
226    pub fn split_eq_keys(
227        self,
228        left_col_num: usize,
229        right_col_num: usize,
230    ) -> (Vec<(InputRef, InputRef, bool)>, Self) {
231        let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
232        let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
233
234        let (mut eq_keys, mut others) = (vec![], vec![]);
235        self.conjunctions.into_iter().for_each(|expr| {
236            let input_bits = expr.collect_input_refs(left_col_num + right_col_num);
237            if input_bits.is_disjoint(&left_bit_map) || input_bits.is_disjoint(&right_bit_map) {
238                others.push(expr)
239            } else if let Some(columns) = expr.as_eq_cond() {
240                eq_keys.push((columns.0, columns.1, false));
241            } else if let Some(columns) = expr.as_is_not_distinct_from_cond() {
242                eq_keys.push((columns.0, columns.1, true));
243            } else {
244                others.push(expr)
245            }
246        });
247
248        (
249            eq_keys,
250            Condition {
251                conjunctions: others,
252            },
253        )
254    }
255
256    /// For [`EqJoinPredicate`], extract inequality conditions which connect left columns and right
257    /// columns from other conditions.
258    ///
259    /// Returns a list of `(conjunction_index, InequalityInputPair)` where the pair contains
260    /// the left column index, right column index (NOT offset by `left_col_num`), and the comparison
261    /// operator.
262    ///
263    /// Only pure `InputRef <op> InputRef` conditions are extracted (no offsets like `+ INTERVAL`).
264    ///
265    /// [`EqJoinPredicate`]: crate::optimizer::plan_node::EqJoinPredicate
266    pub(crate) fn extract_inequality_keys(
267        &self,
268        left_col_num: usize,
269        right_col_num: usize,
270    ) -> Vec<(usize, InequalityInputPair)> {
271        let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
272        let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
273
274        self.conjunctions
275            .iter()
276            .enumerate()
277            .filter_map(|(conjunction_idx, expr)| {
278                let input_bits = expr.collect_input_refs(left_col_num + right_col_num);
279                if input_bits.is_disjoint(&left_bit_map) || input_bits.is_disjoint(&right_bit_map) {
280                    return None;
281                }
282
283                // Use as_comparison_cond which only matches pure InputRef <op> InputRef
284                let (left_input, op, right_input) = expr.as_comparison_cond()?;
285
286                // Ensure left is from left input and right is from right input
287                // as_comparison_cond normalizes to left.index < right.index
288                if left_input.index() < left_col_num
289                    && right_input.index() >= left_col_num
290                    && right_input.index() < left_col_num + right_col_num
291                {
292                    Some((
293                        conjunction_idx,
294                        InequalityInputPair::new(
295                            left_input.index(),
296                            right_input.index() - left_col_num, // Convert to right input index
297                            op,
298                        ),
299                    ))
300                } else {
301                    None
302                }
303            })
304            .collect_vec()
305    }
306
307    /// Split the condition expressions into 2 groups: those referencing `columns` and others which
308    /// are disjoint with columns.
309    #[must_use]
310    pub fn split_disjoint(self, columns: &FixedBitSet) -> (Self, Self) {
311        self.group_by::<_, 2>(|expr| {
312            let input_bits = expr.collect_input_refs(columns.len());
313            input_bits.is_disjoint(columns) as usize
314        })
315        .into_iter()
316        .next_tuple()
317        .unwrap()
318    }
319
320    /// Generate range scans from each arm of `OR` clause and merge them.
321    /// Currently, only support equal type range scans.
322    /// Keep in mind that range scans can not overlap, otherwise duplicate rows will occur.
323    fn disjunctions_to_scan_ranges(
324        table: &TableCatalog,
325        max_split_range_gap: u64,
326        disjunctions: Vec<ExprImpl>,
327    ) -> Result<Option<(Vec<ScanRange>, bool)>> {
328        let disjunctions_result: Result<Vec<(Vec<ScanRange>, Self)>> = disjunctions
329            .into_iter()
330            .map(|x| {
331                Condition {
332                    conjunctions: to_conjunctions(x),
333                }
334                .split_to_scan_ranges(table, max_split_range_gap)
335            })
336            .collect();
337
338        // If any arm of `OR` clause fails, bail out.
339        let disjunctions_result = disjunctions_result?;
340
341        // If all arms of `OR` clause scan ranges are simply equal condition type, merge all
342        // of them.
343        let all_equal = disjunctions_result
344            .iter()
345            .all(|(scan_ranges, other_condition)| {
346                other_condition.always_true()
347                    && scan_ranges
348                        .iter()
349                        .all(|x| !x.eq_conds.is_empty() && is_full_range(&x.range))
350            });
351
352        if all_equal {
353            // Think about the case (a = 1) or (a = 1 and b = 2).
354            // We should only keep the large one range scan a = 1, because a = 1 overlaps with
355            // (a = 1 and b = 2).
356            let scan_ranges = disjunctions_result
357                .into_iter()
358                .flat_map(|(scan_ranges, _)| scan_ranges)
359                // sort, large one first
360                .sorted_by(|a, b| a.eq_conds.len().cmp(&b.eq_conds.len()))
361                .collect_vec();
362            // Make sure each range never overlaps with others, that's what scan range mean.
363            let mut non_overlap_scan_ranges: Vec<ScanRange> = vec![];
364            for s1 in &scan_ranges {
365                let overlap = non_overlap_scan_ranges.iter().any(|s2| {
366                    #[allow(clippy::disallowed_methods)]
367                    s1.eq_conds
368                        .iter()
369                        .zip(s2.eq_conds.iter())
370                        .all(|(a, b)| a == b)
371                });
372                // If overlap happens, keep the large one and large one always in
373                // `non_overlap_scan_ranges`.
374                // Otherwise, put s1 into `non_overlap_scan_ranges`.
375                if !overlap {
376                    non_overlap_scan_ranges.push(s1.clone());
377                }
378            }
379
380            Ok(Some((non_overlap_scan_ranges, false)))
381        } else {
382            let mut scan_ranges = vec![];
383            for (scan_ranges_chunk, _) in disjunctions_result {
384                if scan_ranges_chunk.is_empty() {
385                    // full scan range
386                    return Ok(None);
387                }
388
389                scan_ranges.extend(scan_ranges_chunk);
390            }
391
392            let order_types = table
393                .pk
394                .iter()
395                .cloned()
396                .map(|x| {
397                    if x.order_type.is_descending() {
398                        x.order_type.reverse()
399                    } else {
400                        x.order_type
401                    }
402                })
403                .collect_vec();
404            scan_ranges.sort_by(|left, right| {
405                let (left_start, _left_end) = &left.convert_to_range();
406                let (right_start, _right_end) = &right.convert_to_range();
407
408                let left_start_vec = match &left_start {
409                    Bound::Included(vec) | Bound::Excluded(vec) => vec,
410                    _ => &vec![],
411                };
412                let right_start_vec = match &right_start {
413                    Bound::Included(vec) | Bound::Excluded(vec) => vec,
414                    _ => &vec![],
415                };
416
417                if left_start_vec.is_empty() && right_start_vec.is_empty() {
418                    return Ordering::Less;
419                }
420
421                if left_start_vec.is_empty() {
422                    return Ordering::Less;
423                }
424
425                if right_start_vec.is_empty() {
426                    return Ordering::Greater;
427                }
428
429                let cmp_column_len = left_start_vec.len().min(right_start_vec.len());
430                cmp_rows(
431                    &left_start_vec[0..cmp_column_len],
432                    &right_start_vec[0..cmp_column_len],
433                    &order_types[0..cmp_column_len],
434                )
435            });
436
437            if scan_ranges.is_empty() {
438                return Ok(None);
439            }
440
441            if scan_ranges.len() == 1 {
442                return Ok(Some((scan_ranges, true)));
443            }
444
445            let mut output_scan_ranges: Vec<ScanRange> = vec![];
446            output_scan_ranges.push(scan_ranges[0].clone());
447            let mut idx = 1;
448            loop {
449                if idx >= scan_ranges.len() {
450                    break;
451                }
452
453                let scan_range_left = output_scan_ranges.last_mut().unwrap();
454                let scan_range_right = &scan_ranges[idx];
455
456                if scan_range_left.eq_conds == scan_range_right.eq_conds {
457                    // range merge
458
459                    if !ScanRange::is_overlap(scan_range_left, scan_range_right, &order_types) {
460                        // not merge
461                        output_scan_ranges.push(scan_range_right.clone());
462                        idx += 1;
463                        continue;
464                    }
465
466                    // merge range
467                    fn merge_bound(
468                        left_scan_range: &Bound<Vec<Option<ScalarImpl>>>,
469                        right_scan_range: &Bound<Vec<Option<ScalarImpl>>>,
470                        order_types: &[OrderType],
471                        left_bound: bool,
472                    ) -> Bound<Vec<Option<ScalarImpl>>> {
473                        let left_scan_range = match left_scan_range {
474                            Bound::Included(vec) | Bound::Excluded(vec) => vec,
475                            Bound::Unbounded => return Bound::Unbounded,
476                        };
477
478                        let right_scan_range = match right_scan_range {
479                            Bound::Included(vec) | Bound::Excluded(vec) => vec,
480                            Bound::Unbounded => return Bound::Unbounded,
481                        };
482
483                        let cmp_len = left_scan_range.len().min(right_scan_range.len());
484
485                        let cmp = cmp_rows(
486                            &left_scan_range[..cmp_len],
487                            &right_scan_range[..cmp_len],
488                            &order_types[..cmp_len],
489                        );
490
491                        let bound = {
492                            if (cmp.is_le() && left_bound) || (cmp.is_ge() && !left_bound) {
493                                left_scan_range.clone()
494                            } else {
495                                right_scan_range.clone()
496                            }
497                        };
498
499                        // Included Bound just for convenience, the correctness will be guaranteed by the upper level filter.
500                        Bound::Included(bound)
501                    }
502
503                    scan_range_left.range.0 = merge_bound(
504                        &scan_range_left.range.0,
505                        &scan_range_right.range.0,
506                        &order_types,
507                        true,
508                    );
509
510                    scan_range_left.range.1 = merge_bound(
511                        &scan_range_left.range.1,
512                        &scan_range_right.range.1,
513                        &order_types,
514                        false,
515                    );
516
517                    if scan_range_left.is_full_table_scan() {
518                        return Ok(None);
519                    }
520                } else {
521                    output_scan_ranges.push(scan_range_right.clone());
522                }
523
524                idx += 1;
525            }
526
527            Ok(Some((output_scan_ranges, true)))
528        }
529    }
530
531    fn split_row_cmp_to_scan_ranges(
532        &self,
533        table: &TableCatalog,
534    ) -> Result<Option<(Vec<ScanRange>, Self)>> {
535        let (mut row_conjunctions, row_conjunctions_without_struct): (Vec<_>, Vec<_>) =
536            self.conjunctions.clone().into_iter().partition(|expr| {
537                if let Some(f) = expr.as_function_call() {
538                    if let Some(left_input) = f.inputs().get(0)
539                        && let Some(left_input) = left_input.as_function_call()
540                        && matches!(left_input.func_type(), ExprType::Row)
541                        && left_input.inputs().iter().all(|x| x.is_input_ref())
542                        && let Some(right_input) = f.inputs().get(1)
543                        && right_input.is_literal()
544                    {
545                        true
546                    } else {
547                        false
548                    }
549                } else {
550                    false
551                }
552            });
553        // optimize for single row conjunctions. More optimisations may come later
554        // For example, (v1,v2,v3) > (1, 2, 3) means all data from (1, 2, 3).
555        // Suppose v1 v2 v3 are both pk, we can push (v1,v2,v3)> (1,2,3) down to scan
556        // Suppose v1 v2 are both pk, we can push (v1,v2)> (1,2) down to scan and add (v1,v2,v3) > (1,2,3) in filter, it is still possible to reduce the value of scan
557        if row_conjunctions.len() == 1 {
558            let row_conjunction = row_conjunctions.pop().unwrap();
559            let row_left_inputs = row_conjunction
560                .as_function_call()
561                .unwrap()
562                .inputs()
563                .get(0)
564                .unwrap()
565                .as_function_call()
566                .unwrap()
567                .inputs();
568            let row_right_literal = row_conjunction
569                .as_function_call()
570                .unwrap()
571                .inputs()
572                .get(1)
573                .unwrap()
574                .as_literal()
575                .unwrap();
576            if !matches!(row_right_literal.get_data(), Some(ScalarImpl::Struct(_))) {
577                return Ok(None);
578            }
579            let row_right_literal_data = row_right_literal.get_data().clone().unwrap();
580            let right_iter = row_right_literal_data.as_struct().fields();
581            let func_type = row_conjunction.as_function_call().unwrap().func_type();
582            if row_left_inputs.len() > 1
583                && (matches!(func_type, ExprType::LessThan)
584                    || matches!(func_type, ExprType::GreaterThan))
585            {
586                let mut pk_struct = vec![];
587                let mut order_type = None;
588                let mut all_added = true;
589                let mut iter = row_left_inputs.iter().zip_eq_fast(right_iter);
590                for column_order in &table.pk {
591                    if let Some((left_expr, right_expr)) = iter.next() {
592                        if left_expr.as_input_ref().unwrap().index != column_order.column_index {
593                            all_added = false;
594                            break;
595                        }
596                        match order_type {
597                            Some(o) => {
598                                if o != column_order.order_type {
599                                    all_added = false;
600                                    break;
601                                }
602                            }
603                            None => order_type = Some(column_order.order_type),
604                        }
605                        pk_struct.push(right_expr.clone());
606                    }
607                }
608
609                // Here it is necessary to determine whether all of row is included in the `ScanRanges`, if so, the data for eq is not needed
610                if !pk_struct.is_empty() {
611                    if !all_added {
612                        let scan_range = ScanRange {
613                            eq_conds: vec![],
614                            range: match func_type {
615                                ExprType::GreaterThan => {
616                                    (Bound::Included(pk_struct), Bound::Unbounded)
617                                }
618                                ExprType::LessThan => {
619                                    (Bound::Unbounded, Bound::Included(pk_struct))
620                                }
621                                _ => unreachable!(),
622                            },
623                        };
624                        return Ok(Some((
625                            vec![scan_range],
626                            Condition {
627                                conjunctions: self.conjunctions.clone(),
628                            },
629                        )));
630                    } else {
631                        let scan_range = ScanRange {
632                            eq_conds: vec![],
633                            range: match func_type {
634                                ExprType::GreaterThan => {
635                                    (Bound::Excluded(pk_struct), Bound::Unbounded)
636                                }
637                                ExprType::LessThan => {
638                                    (Bound::Unbounded, Bound::Excluded(pk_struct))
639                                }
640                                _ => unreachable!(),
641                            },
642                        };
643                        return Ok(Some((
644                            vec![scan_range],
645                            Condition {
646                                conjunctions: row_conjunctions_without_struct,
647                            },
648                        )));
649                    }
650                }
651            }
652        }
653        Ok(None)
654    }
655
656    /// x = 1 AND y = 2 AND z = 3 => [x, y, z]
657    pub fn get_eq_const_input_refs(&self) -> Vec<InputRef> {
658        self.conjunctions
659            .iter()
660            .filter_map(|expr| expr.as_eq_const().map(|(input_ref, _)| input_ref))
661            .collect()
662    }
663
664    /// See also [`ScanRange`](risingwave_pb::batch_plan::ScanRange).
665    pub fn split_to_scan_ranges(
666        self,
667        table: &TableCatalog,
668        max_split_range_gap: u64,
669    ) -> Result<(Vec<ScanRange>, Self)> {
670        fn false_cond() -> (Vec<ScanRange>, Condition) {
671            (vec![], Condition::false_cond())
672        }
673
674        // It's an OR.
675        if self.conjunctions.len() == 1
676            && let Some(disjunctions) = self.conjunctions[0].as_or_disjunctions()
677        {
678            if let Some((scan_ranges, maintaining_condition)) =
679                Self::disjunctions_to_scan_ranges(table, max_split_range_gap, disjunctions)?
680            {
681                if maintaining_condition {
682                    return Ok((scan_ranges, self));
683                } else {
684                    return Ok((scan_ranges, Condition::true_cond()));
685                }
686            } else {
687                return Ok((vec![], self));
688            }
689        }
690        if let Some((scan_ranges, other_condition)) = self.split_row_cmp_to_scan_ranges(table)? {
691            return Ok((scan_ranges, other_condition));
692        }
693
694        let mut groups = Self::classify_conjunctions_by_pk(self.conjunctions, table);
695        let mut other_conds = groups.pop().unwrap();
696
697        // Analyze each group and use result to update scan range.
698        let mut scan_range = ScanRange::full_table_scan();
699        for i in 0..table.pk.len() {
700            let group = std::mem::take(&mut groups[i]);
701            if group.is_empty() {
702                groups.push(other_conds);
703                return Ok((
704                    if scan_range.is_full_table_scan() {
705                        vec![]
706                    } else {
707                        vec![scan_range]
708                    },
709                    Self {
710                        conjunctions: groups[i + 1..].concat(),
711                    },
712                ));
713            }
714
715            let Some((
716                lower_bound_conjunctions,
717                upper_bound_conjunctions,
718                eq_conds,
719                part_of_other_conds,
720            )) = Self::analyze_group(group)?
721            else {
722                return Ok(false_cond());
723            };
724            other_conds.extend(part_of_other_conds.into_iter());
725
726            let lower_bound = Self::merge_lower_bound_conjunctions(lower_bound_conjunctions);
727            let upper_bound = Self::merge_upper_bound_conjunctions(upper_bound_conjunctions);
728
729            if Self::is_invalid_range(&lower_bound, &upper_bound) {
730                return Ok(false_cond());
731            }
732
733            // update scan_range
734            match eq_conds.len() {
735                1 => {
736                    let eq_conds =
737                        Self::extract_eq_conds_within_range(eq_conds, &upper_bound, &lower_bound);
738                    if eq_conds.is_empty() {
739                        return Ok(false_cond());
740                    }
741                    scan_range.eq_conds.extend(eq_conds.into_iter());
742                }
743                0 => {
744                    let convert = |bound| match bound {
745                        Bound::Included(l) => Bound::Included(vec![Some(l)]),
746                        Bound::Excluded(l) => Bound::Excluded(vec![Some(l)]),
747                        Bound::Unbounded => Bound::Unbounded,
748                    };
749                    scan_range.range = (convert(lower_bound), convert(upper_bound));
750                    other_conds.extend(groups[i + 1..].iter().flatten().cloned());
751                    break;
752                }
753                _ => {
754                    // currently we will split IN list to multiple scan ranges immediately
755                    // i.e., a = 1 AND b in (1,2) is handled
756                    // TODO:
757                    // a in (1,2) AND b = 1
758                    // a in (1,2) AND b in (1,2)
759                    // a in (1,2) AND b > 1
760                    let eq_conds =
761                        Self::extract_eq_conds_within_range(eq_conds, &upper_bound, &lower_bound);
762                    if eq_conds.is_empty() {
763                        return Ok(false_cond());
764                    }
765                    other_conds.extend(groups[i + 1..].iter().flatten().cloned());
766                    let scan_ranges = eq_conds
767                        .into_iter()
768                        .map(|lit| {
769                            let mut scan_range = scan_range.clone();
770                            scan_range.eq_conds.push(lit);
771                            scan_range
772                        })
773                        .collect();
774                    return Ok((
775                        scan_ranges,
776                        Self {
777                            conjunctions: other_conds,
778                        },
779                    ));
780                }
781            }
782        }
783
784        Ok((
785            if scan_range.is_full_table_scan() {
786                vec![]
787            } else if table.columns[table.pk[0].column_index].data_type.is_int() {
788                match scan_range.split_small_range(max_split_range_gap) {
789                    Some(scan_ranges) => scan_ranges,
790                    None => vec![scan_range],
791                }
792            } else {
793                vec![scan_range]
794            },
795            Self {
796                conjunctions: other_conds,
797            },
798        ))
799    }
800
801    /// classify conjunctions into groups:
802    /// The i-th group has exprs that only reference the i-th PK column.
803    /// The last group contains all the other exprs.
804    fn classify_conjunctions_by_pk(
805        conjunctions: Vec<ExprImpl>,
806        table: &TableCatalog,
807    ) -> Vec<Vec<ExprImpl>> {
808        let pk_cols_num = table.pk.len();
809        let cols_num = table.columns.len();
810
811        let mut col_idx_to_pk_idx = vec![None; cols_num];
812        table
813            .order_column_indices()
814            .enumerate()
815            .for_each(|(idx, pk_idx)| {
816                col_idx_to_pk_idx[pk_idx] = Some(idx);
817            });
818
819        let mut groups = vec![vec![]; pk_cols_num + 1];
820        for (key, group) in &conjunctions.into_iter().chunk_by(|expr| {
821            let input_bits = expr.collect_input_refs(cols_num);
822            if input_bits.count_ones(..) == 1 {
823                let col_idx = input_bits.ones().next().unwrap();
824                col_idx_to_pk_idx[col_idx].unwrap_or(pk_cols_num)
825            } else {
826                pk_cols_num
827            }
828        }) {
829            groups[key].extend(group);
830        }
831
832        groups
833    }
834
835    /// Extract the following information in a group of conjunctions:
836    /// 1. lower bound conjunctions
837    /// 2. upper bound conjunctions
838    /// 3. eq conditions
839    /// 4. other conditions
840    ///
841    /// return None indicates that this conjunctions is always false
842    #[allow(clippy::type_complexity)]
843    fn analyze_group(
844        group: Vec<ExprImpl>,
845    ) -> Result<
846        Option<(
847            Vec<Bound<ScalarImpl>>,
848            Vec<Bound<ScalarImpl>>,
849            Vec<Option<ScalarImpl>>,
850            Vec<ExprImpl>,
851        )>,
852    > {
853        let mut lower_bound_conjunctions = vec![];
854        let mut upper_bound_conjunctions = vec![];
855        // values in eq_cond are OR'ed
856        let mut eq_conds = vec![];
857        let mut other_conds = vec![];
858
859        // analyze exprs in the group. scan_range is not updated
860        for expr in group {
861            if let Some((input_ref, const_expr)) = expr.as_eq_const() {
862                let new_expr = if let Ok(expr) =
863                    const_expr.clone().cast_implicit(&input_ref.data_type)
864                {
865                    expr
866                } else {
867                    match self::cast_compare::cast_compare_for_eq(const_expr, input_ref.data_type) {
868                        Ok(ResultForEq::Success(expr)) => expr,
869                        Ok(ResultForEq::NeverEqual) => {
870                            return Ok(None);
871                        }
872                        Err(_) => {
873                            other_conds.push(expr);
874                            continue;
875                        }
876                    }
877                };
878
879                let Some(new_cond) = new_expr.fold_const()? else {
880                    // column = NULL, the result is always NULL.
881                    return Ok(None);
882                };
883                if Self::mutual_exclusive_with_eq_conds(&new_cond, &eq_conds) {
884                    return Ok(None);
885                }
886                eq_conds = vec![Some(new_cond)];
887            } else if expr.as_is_null().is_some() {
888                if !eq_conds.is_empty() && eq_conds.into_iter().all(|l| l.is_some()) {
889                    return Ok(None);
890                }
891                eq_conds = vec![None];
892            } else if let Some((input_ref, in_const_list)) = expr.as_in_const_list() {
893                let mut scalars = HashSet::new();
894                for const_expr in in_const_list {
895                    // The cast should succeed, because otherwise the input_ref is casted
896                    // and thus `as_in_const_list` returns None.
897                    let const_expr = const_expr.cast_implicit(&input_ref.data_type).unwrap();
898                    let value = const_expr.fold_const()?;
899                    let Some(value) = value else {
900                        continue;
901                    };
902                    scalars.insert(Some(value));
903                }
904                if scalars.is_empty() {
905                    // There're only NULLs in the in-list
906                    return Ok(None);
907                }
908                if !eq_conds.is_empty() {
909                    scalars = scalars
910                        .intersection(&HashSet::from_iter(eq_conds))
911                        .cloned()
912                        .collect();
913                    if scalars.is_empty() {
914                        return Ok(None);
915                    }
916                }
917                // Sort to ensure a deterministic result for planner test.
918                eq_conds = scalars
919                    .into_iter()
920                    .sorted_by(DefaultOrd::default_cmp)
921                    .collect();
922            } else if let Some((input_ref, op, const_expr)) = expr.as_comparison_const() {
923                let new_expr =
924                    if let Ok(expr) = const_expr.clone().cast_implicit(&input_ref.data_type) {
925                        expr
926                    } else {
927                        match self::cast_compare::cast_compare_for_cmp(
928                            const_expr,
929                            input_ref.data_type,
930                            op,
931                        ) {
932                            Ok(ResultForCmp::Success(expr)) => expr,
933                            _ => {
934                                other_conds.push(expr);
935                                continue;
936                            }
937                        }
938                    };
939                let Some(value) = new_expr.fold_const()? else {
940                    // column compare with NULL, the result is always  NULL.
941                    return Ok(None);
942                };
943                match op {
944                    ExprType::LessThan => {
945                        upper_bound_conjunctions.push(Bound::Excluded(value));
946                    }
947                    ExprType::LessThanOrEqual => {
948                        upper_bound_conjunctions.push(Bound::Included(value));
949                    }
950                    ExprType::GreaterThan => {
951                        lower_bound_conjunctions.push(Bound::Excluded(value));
952                    }
953                    ExprType::GreaterThanOrEqual => {
954                        lower_bound_conjunctions.push(Bound::Included(value));
955                    }
956                    _ => unreachable!(),
957                }
958            } else {
959                other_conds.push(expr);
960            }
961        }
962        Ok(Some((
963            lower_bound_conjunctions,
964            upper_bound_conjunctions,
965            eq_conds,
966            other_conds,
967        )))
968    }
969
970    fn mutual_exclusive_with_eq_conds(
971        new_conds: &ScalarImpl,
972        eq_conds: &[Option<ScalarImpl>],
973    ) -> bool {
974        !eq_conds.is_empty()
975            && eq_conds.iter().all(|l| {
976                if let Some(l) = l {
977                    l != new_conds
978                } else {
979                    true
980                }
981            })
982    }
983
984    fn merge_lower_bound_conjunctions(lb: Vec<Bound<ScalarImpl>>) -> Bound<ScalarImpl> {
985        lb.into_iter()
986            .max_by(|a, b| {
987                // For lower bound, Unbounded means -inf
988                match (a, b) {
989                    (Bound::Included(_), Bound::Unbounded) => std::cmp::Ordering::Greater,
990                    (Bound::Excluded(_), Bound::Unbounded) => std::cmp::Ordering::Greater,
991                    (Bound::Unbounded, Bound::Included(_)) => std::cmp::Ordering::Less,
992                    (Bound::Unbounded, Bound::Excluded(_)) => std::cmp::Ordering::Less,
993                    (Bound::Unbounded, Bound::Unbounded) => std::cmp::Ordering::Equal,
994                    (Bound::Included(a), Bound::Included(b)) => a.default_cmp(b),
995                    (Bound::Excluded(a), Bound::Excluded(b)) => a.default_cmp(b),
996                    // excluded bound is strict than included bound so we assume it more greater.
997                    (Bound::Included(a), Bound::Excluded(b)) => match a.default_cmp(b) {
998                        std::cmp::Ordering::Equal => std::cmp::Ordering::Less,
999                        other => other,
1000                    },
1001                    (Bound::Excluded(a), Bound::Included(b)) => match a.default_cmp(b) {
1002                        std::cmp::Ordering::Equal => std::cmp::Ordering::Greater,
1003                        other => other,
1004                    },
1005                }
1006            })
1007            .unwrap_or(Bound::Unbounded)
1008    }
1009
1010    fn merge_upper_bound_conjunctions(ub: Vec<Bound<ScalarImpl>>) -> Bound<ScalarImpl> {
1011        ub.into_iter()
1012            .min_by(|a, b| {
1013                // For upper bound, Unbounded means +inf
1014                match (a, b) {
1015                    (Bound::Included(_), Bound::Unbounded) => std::cmp::Ordering::Less,
1016                    (Bound::Excluded(_), Bound::Unbounded) => std::cmp::Ordering::Less,
1017                    (Bound::Unbounded, Bound::Included(_)) => std::cmp::Ordering::Greater,
1018                    (Bound::Unbounded, Bound::Excluded(_)) => std::cmp::Ordering::Greater,
1019                    (Bound::Unbounded, Bound::Unbounded) => std::cmp::Ordering::Equal,
1020                    (Bound::Included(a), Bound::Included(b)) => a.default_cmp(b),
1021                    (Bound::Excluded(a), Bound::Excluded(b)) => a.default_cmp(b),
1022                    // excluded bound is strict than included bound so we assume it more greater.
1023                    (Bound::Included(a), Bound::Excluded(b)) => match a.default_cmp(b) {
1024                        std::cmp::Ordering::Equal => std::cmp::Ordering::Greater,
1025                        other => other,
1026                    },
1027                    (Bound::Excluded(a), Bound::Included(b)) => match a.default_cmp(b) {
1028                        std::cmp::Ordering::Equal => std::cmp::Ordering::Less,
1029                        other => other,
1030                    },
1031                }
1032            })
1033            .unwrap_or(Bound::Unbounded)
1034    }
1035
1036    fn is_invalid_range(lower_bound: &Bound<ScalarImpl>, upper_bound: &Bound<ScalarImpl>) -> bool {
1037        match (lower_bound, upper_bound) {
1038            (Bound::Included(l), Bound::Included(u)) => l.default_cmp(u).is_gt(), // l > u
1039            (Bound::Included(l), Bound::Excluded(u)) => l.default_cmp(u).is_ge(), // l >= u
1040            (Bound::Excluded(l), Bound::Included(u)) => l.default_cmp(u).is_ge(), // l >= u
1041            (Bound::Excluded(l), Bound::Excluded(u)) => l.default_cmp(u).is_ge(), // l >= u
1042            _ => false,
1043        }
1044    }
1045
1046    fn extract_eq_conds_within_range(
1047        eq_conds: Vec<Option<ScalarImpl>>,
1048        upper_bound: &Bound<ScalarImpl>,
1049        lower_bound: &Bound<ScalarImpl>,
1050    ) -> Vec<Option<ScalarImpl>> {
1051        // defensive programming: for now we will guarantee that the range is valid before calling
1052        // this function
1053        if Self::is_invalid_range(lower_bound, upper_bound) {
1054            return vec![];
1055        }
1056
1057        let is_extract_null = upper_bound == &Bound::Unbounded && lower_bound == &Bound::Unbounded;
1058
1059        eq_conds
1060            .into_iter()
1061            .filter(|cond| {
1062                if let Some(cond) = cond {
1063                    match lower_bound {
1064                        Bound::Included(val) => {
1065                            if cond.default_cmp(val).is_lt() {
1066                                // cond < val
1067                                return false;
1068                            }
1069                        }
1070                        Bound::Excluded(val) => {
1071                            if cond.default_cmp(val).is_le() {
1072                                // cond <= val
1073                                return false;
1074                            }
1075                        }
1076                        Bound::Unbounded => {}
1077                    }
1078                    match upper_bound {
1079                        Bound::Included(val) => {
1080                            if cond.default_cmp(val).is_gt() {
1081                                // cond > val
1082                                return false;
1083                            }
1084                        }
1085                        Bound::Excluded(val) => {
1086                            if cond.default_cmp(val).is_ge() {
1087                                // cond >= val
1088                                return false;
1089                            }
1090                        }
1091                        Bound::Unbounded => {}
1092                    }
1093                    true
1094                } else {
1095                    is_extract_null
1096                }
1097            })
1098            .collect()
1099    }
1100
1101    /// Split the condition expressions into `N` groups.
1102    /// An expression `expr` is in the `i`-th group if `f(expr)==i`.
1103    ///
1104    /// # Panics
1105    /// Panics if `f(expr)>=N`.
1106    #[must_use]
1107    pub fn group_by<F, const N: usize>(self, f: F) -> [Self; N]
1108    where
1109        F: Fn(&ExprImpl) -> usize,
1110    {
1111        const EMPTY: Vec<ExprImpl> = vec![];
1112        let mut groups = [EMPTY; N];
1113        for (key, group) in &self.conjunctions.into_iter().chunk_by(|expr| {
1114            // i-th group
1115            let i = f(expr);
1116            assert!(i < N);
1117            i
1118        }) {
1119            groups[key].extend(group);
1120        }
1121
1122        groups.map(|group| Condition {
1123            conjunctions: group,
1124        })
1125    }
1126
1127    #[must_use]
1128    pub fn rewrite_expr(self, rewriter: &mut (impl ExprRewriter + ?Sized)) -> Self {
1129        Self {
1130            conjunctions: self
1131                .conjunctions
1132                .into_iter()
1133                .map(|expr| rewriter.rewrite_expr(expr))
1134                .collect(),
1135        }
1136        .simplify()
1137    }
1138
1139    pub fn visit_expr<V: ExprVisitor + ?Sized>(&self, visitor: &mut V) {
1140        self.conjunctions
1141            .iter()
1142            .for_each(|expr| visitor.visit_expr(expr));
1143    }
1144
1145    pub fn visit_expr_mut(&mut self, mutator: &mut (impl ExprMutator + ?Sized)) {
1146        self.conjunctions
1147            .iter_mut()
1148            .for_each(|expr| mutator.visit_expr(expr))
1149    }
1150
1151    /// Simplify conditions
1152    /// It simplify conditions by applying constant folding and removing unnecessary conjunctions
1153    fn simplify(self) -> Self {
1154        // boolean constant folding
1155        let conjunctions: Vec<_> = self
1156            .conjunctions
1157            .into_iter()
1158            .map(push_down_not)
1159            .map(fold_boolean_constant)
1160            .map(column_self_eq_eliminate)
1161            .flat_map(to_conjunctions)
1162            .collect();
1163        let mut res: Vec<ExprImpl> = Vec::new();
1164        let mut visited: HashSet<ExprImpl> = HashSet::new();
1165        for expr in conjunctions {
1166            // factorization_expr requires hash-able ExprImpl
1167            if !expr.has_subquery() {
1168                let results_of_factorization = factorization_expr(expr);
1169                res.extend(
1170                    results_of_factorization
1171                        .clone()
1172                        .into_iter()
1173                        .filter(|expr| !visited.contains(expr)),
1174                );
1175                visited.extend(results_of_factorization);
1176            } else {
1177                // for subquery, simply give up factorization
1178                res.push(expr);
1179            }
1180        }
1181        // remove all constant boolean `true`
1182        res.retain(|expr| {
1183            if let Some(v) = try_get_bool_constant(expr)
1184                && v
1185            {
1186                false
1187            } else {
1188                true
1189            }
1190        });
1191        // if there is a `false` in conjunctions, the whole condition will be `false`
1192        for expr in &mut res {
1193            if let Some(v) = try_get_bool_constant(expr)
1194                && !v
1195            {
1196                res.clear();
1197                res.push(ExprImpl::literal_bool(false));
1198                break;
1199            }
1200        }
1201        Self { conjunctions: res }
1202    }
1203}
1204
1205pub struct ConditionDisplay<'a> {
1206    pub condition: &'a Condition,
1207    pub input_schema: &'a Schema,
1208}
1209
1210impl ConditionDisplay<'_> {
1211    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1212        if self.condition.always_true() {
1213            write!(f, "true")
1214        } else {
1215            write!(
1216                f,
1217                "{}",
1218                self.condition
1219                    .conjunctions
1220                    .iter()
1221                    .format_with(" AND ", |expr, f| {
1222                        f(&ExprDisplay {
1223                            expr,
1224                            input_schema: self.input_schema,
1225                        })
1226                    })
1227            )
1228        }
1229    }
1230}
1231
1232impl fmt::Display for ConditionDisplay<'_> {
1233    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1234        self.fmt(f)
1235    }
1236}
1237
1238impl fmt::Debug for ConditionDisplay<'_> {
1239    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1240        self.fmt(f)
1241    }
1242}
1243
1244/// `cast_compare` can be summarized as casting to target type which can be compared but can't be
1245/// cast implicitly to, like:
1246/// 1. bigger range -> smaller range in same type, e.g. int64 -> int32
1247/// 2. different type, e.g. float type -> integral type
1248mod cast_compare {
1249    use risingwave_common::types::DataType;
1250
1251    use crate::expr::{Expr, ExprImpl, ExprType};
1252
1253    enum ShrinkResult {
1254        OutUpperBound,
1255        OutLowerBound,
1256        InRange(ExprImpl),
1257    }
1258
1259    pub enum ResultForEq {
1260        Success(ExprImpl),
1261        NeverEqual,
1262    }
1263
1264    pub enum ResultForCmp {
1265        Success(ExprImpl),
1266        OutUpperBound,
1267        OutLowerBound,
1268    }
1269
1270    pub fn cast_compare_for_eq(const_expr: ExprImpl, target: DataType) -> Result<ResultForEq, ()> {
1271        match (const_expr.return_type(), &target) {
1272            (DataType::Int64, DataType::Int32)
1273            | (DataType::Int64, DataType::Int16)
1274            | (DataType::Int32, DataType::Int16) => match shrink_integral(const_expr, target)? {
1275                ShrinkResult::InRange(expr) => Ok(ResultForEq::Success(expr)),
1276                ShrinkResult::OutUpperBound | ShrinkResult::OutLowerBound => {
1277                    Ok(ResultForEq::NeverEqual)
1278                }
1279            },
1280            _ => Err(()),
1281        }
1282    }
1283
1284    pub fn cast_compare_for_cmp(
1285        const_expr: ExprImpl,
1286        target: DataType,
1287        _op: ExprType,
1288    ) -> Result<ResultForCmp, ()> {
1289        match (const_expr.return_type(), &target) {
1290            (DataType::Int64, DataType::Int32)
1291            | (DataType::Int64, DataType::Int16)
1292            | (DataType::Int32, DataType::Int16) => match shrink_integral(const_expr, target)? {
1293                ShrinkResult::InRange(expr) => Ok(ResultForCmp::Success(expr)),
1294                ShrinkResult::OutUpperBound => Ok(ResultForCmp::OutUpperBound),
1295                ShrinkResult::OutLowerBound => Ok(ResultForCmp::OutLowerBound),
1296            },
1297            _ => Err(()),
1298        }
1299    }
1300
1301    fn shrink_integral(const_expr: ExprImpl, target: DataType) -> Result<ShrinkResult, ()> {
1302        let (upper_bound, lower_bound) = match (const_expr.return_type(), &target) {
1303            (DataType::Int64, DataType::Int32) => (i32::MAX as i64, i32::MIN as i64),
1304            (DataType::Int64, DataType::Int16) | (DataType::Int32, DataType::Int16) => {
1305                (i16::MAX as i64, i16::MIN as i64)
1306            }
1307            _ => unreachable!(),
1308        };
1309        match const_expr.fold_const().map_err(|_| ())? {
1310            Some(scalar) => {
1311                let value = scalar.as_integral();
1312                if value > upper_bound {
1313                    Ok(ShrinkResult::OutUpperBound)
1314                } else if value < lower_bound {
1315                    Ok(ShrinkResult::OutLowerBound)
1316                } else {
1317                    Ok(ShrinkResult::InRange(
1318                        const_expr.cast_explicit(&target).unwrap(),
1319                    ))
1320                }
1321            }
1322            None => Ok(ShrinkResult::InRange(
1323                const_expr.cast_explicit(&target).unwrap(),
1324            )),
1325        }
1326    }
1327}
1328
#[cfg(test)]
mod tests {
    use rand::Rng;

    use super::*;

    /// Builds an `Int32` input reference to a random column index drawn from `columns`.
    fn random_input_ref(rng: &mut impl Rng, columns: std::ops::Range<usize>) -> ExprImpl {
        InputRef::new(rng.random_range(columns), DataType::Int32).into()
    }

    /// Builds the binary function call `op(lhs, rhs)` as an expression.
    fn binary_call(op: ExprType, lhs: ExprImpl, rhs: ExprImpl) -> ExprImpl {
        FunctionCall::new(op, vec![lhs, rhs]).unwrap().into()
    }

    /// `split` must route an expression to the left, right or "other" side depending on
    /// which input columns it references.
    #[test]
    fn test_split() {
        let left_col_num = 3;
        let right_col_num = 2;
        let left_cols = 0..left_col_num;
        let right_cols = left_col_num..left_col_num + right_col_num;

        let mut rng = rand::rng();

        // References only left-side columns.
        let left_lhs = random_input_ref(&mut rng, left_cols.clone());
        let left_rhs = random_input_ref(&mut rng, left_cols.clone());
        let left = binary_call(ExprType::LessThanOrEqual, left_lhs, left_rhs);

        // References only right-side columns.
        let right_lhs = random_input_ref(&mut rng, right_cols.clone());
        let right_rhs = random_input_ref(&mut rng, right_cols.clone());
        let right = binary_call(ExprType::LessThan, right_lhs, right_rhs);

        // References one column from each side.
        let other_lhs = random_input_ref(&mut rng, left_cols.clone());
        let other_rhs = random_input_ref(&mut rng, right_cols.clone());
        let other = binary_call(ExprType::GreaterThan, other_lhs, other_rhs);

        let cond = Condition::with_expr(other.clone())
            .and(Condition::with_expr(right.clone()))
            .and(Condition::with_expr(left.clone()));

        let (left_cond, right_cond, other_cond) = cond.split(left_col_num, right_col_num);

        assert_eq!(left_cond.conjunctions, vec![left]);
        assert_eq!(right_cond.conjunctions, vec![right]);
        assert_eq!(other_cond.conjunctions, vec![other]);
    }

    /// `x = x` must be replaced by `x IS NOT NULL` before the condition is split.
    #[test]
    fn test_self_eq_eliminate() {
        let left_col_num = 3;
        let right_col_num = 2;
        let left_cols = 0..left_col_num;
        let right_cols = left_col_num..left_col_num + right_col_num;

        let mut rng = rand::rng();

        let x = random_input_ref(&mut rng, left_cols.clone());

        // Self-equality on a left-side column.
        let left = binary_call(ExprType::Equal, x.clone(), x.clone());

        let right_lhs = random_input_ref(&mut rng, right_cols.clone());
        let right_rhs = random_input_ref(&mut rng, right_cols.clone());
        let right = binary_call(ExprType::LessThan, right_lhs, right_rhs);

        let other_lhs = random_input_ref(&mut rng, left_cols.clone());
        let other_rhs = random_input_ref(&mut rng, right_cols.clone());
        let other = binary_call(ExprType::GreaterThan, other_lhs, other_rhs);

        let cond = Condition::with_expr(other.clone())
            .and(Condition::with_expr(right.clone()))
            .and(Condition::with_expr(left));

        let (left_cond, right_cond, other_cond) = cond.split(left_col_num, right_col_num);

        let left_res: ExprImpl = FunctionCall::new(ExprType::IsNotNull, vec![x])
            .unwrap()
            .into();

        assert_eq!(left_cond.conjunctions, vec![left_res]);
        assert_eq!(right_cond.conjunctions, vec![right]);
        assert_eq!(other_cond.conjunctions, vec![other]);
    }
}
1458}