risingwave_frontend/utils/
condition.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::cmp::Ordering;
16use std::collections::{BTreeMap, HashSet};
17use std::fmt::{self, Debug};
18use std::ops::Bound;
19use std::rc::Rc;
20use std::sync::LazyLock;
21
22use fixedbitset::FixedBitSet;
23use itertools::Itertools;
24use risingwave_common::catalog::{Schema, TableDesc};
25use risingwave_common::types::{DataType, DefaultOrd, ScalarImpl};
26use risingwave_common::util::iter_util::ZipEqFast;
27use risingwave_common::util::scan_range::{ScanRange, is_full_range};
28use risingwave_common::util::sort_util::{OrderType, cmp_rows};
29
30use crate::error::Result;
31use crate::expr::{
32    ExprDisplay, ExprImpl, ExprMutator, ExprRewriter, ExprType, ExprVisitor, FunctionCall,
33    InequalityInputPair, InputRef, collect_input_refs, column_self_eq_eliminate,
34    factorization_expr, fold_boolean_constant, push_down_not, to_conjunctions,
35    try_get_bool_constant,
36};
37use crate::utils::condition::cast_compare::{ResultForCmp, ResultForEq};
38
39#[derive(Debug, Clone, PartialEq, Eq, Hash)]
40pub struct Condition {
41    /// Condition expressions in conjunction form (combined with `AND`)
42    pub conjunctions: Vec<ExprImpl>,
43}
44
45impl IntoIterator for Condition {
46    type IntoIter = std::vec::IntoIter<ExprImpl>;
47    type Item = ExprImpl;
48
49    fn into_iter(self) -> Self::IntoIter {
50        self.conjunctions.into_iter()
51    }
52}
53
54impl fmt::Display for Condition {
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        let mut conjunctions = self.conjunctions.iter();
57        if let Some(expr) = conjunctions.next() {
58            write!(f, "{:?}", expr)?;
59        }
60        if self.always_true() {
61            write!(f, "true")?;
62        } else {
63            for expr in conjunctions {
64                write!(f, " AND {:?}", expr)?;
65            }
66        }
67        Ok(())
68    }
69}
70
71impl Condition {
72    pub fn with_expr(expr: ExprImpl) -> Self {
73        let conjunctions = to_conjunctions(expr);
74
75        Self { conjunctions }.simplify()
76    }
77
78    pub fn true_cond() -> Self {
79        Self {
80            conjunctions: vec![],
81        }
82    }
83
84    pub fn false_cond() -> Self {
85        Self {
86            conjunctions: vec![ExprImpl::literal_bool(false)],
87        }
88    }
89
90    pub fn always_true(&self) -> bool {
91        self.conjunctions.is_empty()
92    }
93
94    pub fn always_false(&self) -> bool {
95        static FALSE: LazyLock<ExprImpl> = LazyLock::new(|| ExprImpl::literal_bool(false));
96        // There is at least one conjunction that is false.
97        !self.conjunctions.is_empty() && self.conjunctions.contains(&*FALSE)
98    }
99
100    /// Convert condition to an expression. If always true, return `None`.
101    pub fn as_expr_unless_true(&self) -> Option<ExprImpl> {
102        if self.always_true() {
103            None
104        } else {
105            Some(self.clone().into())
106        }
107    }
108
109    #[must_use]
110    pub fn and(self, other: Self) -> Self {
111        let mut ret = self;
112        ret.conjunctions.extend(other.conjunctions);
113        ret.simplify()
114    }
115
116    #[must_use]
117    pub fn or(self, other: Self) -> Self {
118        let or_expr = ExprImpl::FunctionCall(
119            FunctionCall::new_unchecked(
120                ExprType::Or,
121                vec![self.into(), other.into()],
122                DataType::Boolean,
123            )
124            .into(),
125        );
126        let ret = Self::with_expr(or_expr);
127        ret.simplify()
128    }
129
130    /// Split the condition expressions into 3 groups: left, right and others
131    #[must_use]
132    pub fn split(self, left_col_num: usize, right_col_num: usize) -> (Self, Self, Self) {
133        let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
134        let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
135
136        self.group_by::<_, 3>(|expr| {
137            let input_bits = expr.collect_input_refs(left_col_num + right_col_num);
138            if input_bits.is_subset(&left_bit_map) {
139                0
140            } else if input_bits.is_subset(&right_bit_map) {
141                1
142            } else {
143                2
144            }
145        })
146        .into_iter()
147        .next_tuple()
148        .unwrap()
149    }
150
151    /// Collect all `InputRef`s' indexes in the expressions.
152    ///
153    /// # Panics
154    /// Panics if `input_ref >= input_col_num`.
155    pub fn collect_input_refs(&self, input_col_num: usize) -> FixedBitSet {
156        collect_input_refs(input_col_num, &self.conjunctions)
157    }
158
159    /// Split the condition expressions into (N choose 2) + 1 groups: those containing two columns
160    /// from different buckets (and optionally, needing an equal condition between them), and
161    /// others.
162    ///
163    /// `input_num_cols` are the number of columns in each of the input buckets. For instance, with
164    /// bucket0: col0, col1, col2 | bucket1: col3, col4 | bucket2: col5
165    /// `input_num_cols` = [3, 2, 1]
166    ///
167    /// Returns hashmap with keys of the form (col1, col2) where col1 < col2 in terms of their col
168    /// index.
169    ///
170    /// `only_eq`: whether to only split those conditions with an eq condition predicate between two
171    /// buckets.
172    #[must_use]
173    pub fn split_by_input_col_nums(
174        self,
175        input_col_nums: &[usize],
176        only_eq: bool,
177    ) -> (BTreeMap<(usize, usize), Self>, Self) {
178        let mut bitmaps = Vec::with_capacity(input_col_nums.len());
179        let mut cols_seen = 0;
180        for cols in input_col_nums {
181            bitmaps.push(FixedBitSet::from_iter(cols_seen..cols_seen + cols));
182            cols_seen += cols;
183        }
184
185        let mut pairwise_conditions = BTreeMap::new();
186        let mut non_eq_join = vec![];
187
188        for expr in self.conjunctions {
189            let input_bits = expr.collect_input_refs(cols_seen);
190            let mut subset_indices = Vec::with_capacity(input_col_nums.len());
191            for (idx, bitmap) in bitmaps.iter().enumerate() {
192                if !input_bits.is_disjoint(bitmap) {
193                    subset_indices.push(idx);
194                }
195            }
196            if subset_indices.len() != 2 || (only_eq && expr.as_eq_cond().is_none()) {
197                non_eq_join.push(expr);
198            } else {
199                // The key has the canonical ordering (lower, higher)
200                let key = if subset_indices[0] < subset_indices[1] {
201                    (subset_indices[0], subset_indices[1])
202                } else {
203                    (subset_indices[1], subset_indices[0])
204                };
205                let e = pairwise_conditions
206                    .entry(key)
207                    .or_insert_with(Condition::true_cond);
208                e.conjunctions.push(expr);
209            }
210        }
211        (
212            pairwise_conditions,
213            Condition {
214                conjunctions: non_eq_join,
215            },
216        )
217    }
218
219    #[must_use]
220    /// For [`EqJoinPredicate`], separate equality conditions which connect left columns and right
221    /// columns from other conditions.
222    ///
223    /// The equality conditions are transformed into `(left_col_id, right_col_id)` pairs.
224    ///
225    /// [`EqJoinPredicate`]: crate::optimizer::plan_node::EqJoinPredicate
226    pub fn split_eq_keys(
227        self,
228        left_col_num: usize,
229        right_col_num: usize,
230    ) -> (Vec<(InputRef, InputRef, bool)>, Self) {
231        let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
232        let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
233
234        let (mut eq_keys, mut others) = (vec![], vec![]);
235        self.conjunctions.into_iter().for_each(|expr| {
236            let input_bits = expr.collect_input_refs(left_col_num + right_col_num);
237            if input_bits.is_disjoint(&left_bit_map) || input_bits.is_disjoint(&right_bit_map) {
238                others.push(expr)
239            } else if let Some(columns) = expr.as_eq_cond() {
240                eq_keys.push((columns.0, columns.1, false));
241            } else if let Some(columns) = expr.as_is_not_distinct_from_cond() {
242                eq_keys.push((columns.0, columns.1, true));
243            } else {
244                others.push(expr)
245            }
246        });
247
248        (
249            eq_keys,
250            Condition {
251                conjunctions: others,
252            },
253        )
254    }
255
256    /// For [`EqJoinPredicate`], extract inequality conditions which connect left columns and right
257    /// columns from other conditions.
258    ///
259    /// The inequality conditions are transformed into `(left_col_id, right_col_id, offset)` pairs.
260    ///
261    /// [`EqJoinPredicate`]: crate::optimizer::plan_node::EqJoinPredicate
262    pub(crate) fn extract_inequality_keys(
263        &self,
264        left_col_num: usize,
265        right_col_num: usize,
266    ) -> Vec<(usize, InequalityInputPair)> {
267        let left_bit_map = FixedBitSet::from_iter(0..left_col_num);
268        let right_bit_map = FixedBitSet::from_iter(left_col_num..left_col_num + right_col_num);
269
270        self.conjunctions
271            .iter()
272            .enumerate()
273            .filter_map(|(conjunction_idx, expr)| {
274                let input_bits = expr.collect_input_refs(left_col_num + right_col_num);
275                if input_bits.is_disjoint(&left_bit_map) || input_bits.is_disjoint(&right_bit_map) {
276                    None
277                } else {
278                    expr.as_input_comparison_cond()
279                        .map(|inequality_pair| (conjunction_idx, inequality_pair))
280                }
281            })
282            .collect_vec()
283    }
284
285    /// Split the condition expressions into 2 groups: those referencing `columns` and others which
286    /// are disjoint with columns.
287    #[must_use]
288    pub fn split_disjoint(self, columns: &FixedBitSet) -> (Self, Self) {
289        self.group_by::<_, 2>(|expr| {
290            let input_bits = expr.collect_input_refs(columns.len());
291            input_bits.is_disjoint(columns) as usize
292        })
293        .into_iter()
294        .next_tuple()
295        .unwrap()
296    }
297
298    /// Generate range scans from each arm of `OR` clause and merge them.
299    /// Currently, only support equal type range scans.
300    /// Keep in mind that range scans can not overlap, otherwise duplicate rows will occur.
301    fn disjunctions_to_scan_ranges(
302        table_desc: Rc<TableDesc>,
303        max_split_range_gap: u64,
304        disjunctions: Vec<ExprImpl>,
305    ) -> Result<Option<(Vec<ScanRange>, bool)>> {
306        let disjunctions_result: Result<Vec<(Vec<ScanRange>, Self)>> = disjunctions
307            .into_iter()
308            .map(|x| {
309                Condition {
310                    conjunctions: to_conjunctions(x),
311                }
312                .split_to_scan_ranges(table_desc.clone(), max_split_range_gap)
313            })
314            .collect();
315
316        // If any arm of `OR` clause fails, bail out.
317        let disjunctions_result = disjunctions_result?;
318
319        // If all arms of `OR` clause scan ranges are simply equal condition type, merge all
320        // of them.
321        let all_equal = disjunctions_result
322            .iter()
323            .all(|(scan_ranges, other_condition)| {
324                other_condition.always_true()
325                    && scan_ranges
326                        .iter()
327                        .all(|x| !x.eq_conds.is_empty() && is_full_range(&x.range))
328            });
329
330        if all_equal {
331            // Think about the case (a = 1) or (a = 1 and b = 2).
332            // We should only keep the large one range scan a = 1, because a = 1 overlaps with
333            // (a = 1 and b = 2).
334            let scan_ranges = disjunctions_result
335                .into_iter()
336                .flat_map(|(scan_ranges, _)| scan_ranges)
337                // sort, large one first
338                .sorted_by(|a, b| a.eq_conds.len().cmp(&b.eq_conds.len()))
339                .collect_vec();
340            // Make sure each range never overlaps with others, that's what scan range mean.
341            let mut non_overlap_scan_ranges: Vec<ScanRange> = vec![];
342            for s1 in &scan_ranges {
343                let overlap = non_overlap_scan_ranges.iter().any(|s2| {
344                    #[allow(clippy::disallowed_methods)]
345                    s1.eq_conds
346                        .iter()
347                        .zip(s2.eq_conds.iter())
348                        .all(|(a, b)| a == b)
349                });
350                // If overlap happens, keep the large one and large one always in
351                // `non_overlap_scan_ranges`.
352                // Otherwise, put s1 into `non_overlap_scan_ranges`.
353                if !overlap {
354                    non_overlap_scan_ranges.push(s1.clone());
355                }
356            }
357
358            Ok(Some((non_overlap_scan_ranges, false)))
359        } else {
360            let mut scan_ranges = vec![];
361            for (scan_ranges_chunk, _) in disjunctions_result {
362                if scan_ranges_chunk.is_empty() {
363                    // full scan range
364                    return Ok(None);
365                }
366
367                scan_ranges.extend(scan_ranges_chunk);
368            }
369
370            let order_types = table_desc
371                .pk
372                .iter()
373                .cloned()
374                .map(|x| {
375                    if x.order_type.is_descending() {
376                        x.order_type.reverse()
377                    } else {
378                        x.order_type
379                    }
380                })
381                .collect_vec();
382            scan_ranges.sort_by(|left, right| {
383                let (left_start, _left_end) = &left.convert_to_range();
384                let (right_start, _right_end) = &right.convert_to_range();
385
386                let left_start_vec = match &left_start {
387                    Bound::Included(vec) | Bound::Excluded(vec) => vec,
388                    _ => &vec![],
389                };
390                let right_start_vec = match &right_start {
391                    Bound::Included(vec) | Bound::Excluded(vec) => vec,
392                    _ => &vec![],
393                };
394
395                if left_start_vec.is_empty() && right_start_vec.is_empty() {
396                    return Ordering::Less;
397                }
398
399                if left_start_vec.is_empty() {
400                    return Ordering::Less;
401                }
402
403                if right_start_vec.is_empty() {
404                    return Ordering::Greater;
405                }
406
407                let cmp_column_len = left_start_vec.len().min(right_start_vec.len());
408                cmp_rows(
409                    &left_start_vec[0..cmp_column_len],
410                    &right_start_vec[0..cmp_column_len],
411                    &order_types[0..cmp_column_len],
412                )
413            });
414
415            if scan_ranges.is_empty() {
416                return Ok(None);
417            }
418
419            if scan_ranges.len() == 1 {
420                return Ok(Some((scan_ranges, true)));
421            }
422
423            let mut output_scan_ranges: Vec<ScanRange> = vec![];
424            output_scan_ranges.push(scan_ranges[0].clone());
425            let mut idx = 1;
426            loop {
427                if idx >= scan_ranges.len() {
428                    break;
429                }
430
431                let scan_range_left = output_scan_ranges.last_mut().unwrap();
432                let scan_range_right = &scan_ranges[idx];
433
434                if scan_range_left.eq_conds == scan_range_right.eq_conds {
435                    // range merge
436
437                    if !ScanRange::is_overlap(scan_range_left, scan_range_right, &order_types) {
438                        // not merge
439                        output_scan_ranges.push(scan_range_right.clone());
440                        idx += 1;
441                        continue;
442                    }
443
444                    // merge range
445                    fn merge_bound(
446                        left_scan_range: &Bound<Vec<Option<ScalarImpl>>>,
447                        right_scan_range: &Bound<Vec<Option<ScalarImpl>>>,
448                        order_types: &[OrderType],
449                        left_bound: bool,
450                    ) -> Bound<Vec<Option<ScalarImpl>>> {
451                        let left_scan_range = match left_scan_range {
452                            Bound::Included(vec) | Bound::Excluded(vec) => vec,
453                            Bound::Unbounded => return Bound::Unbounded,
454                        };
455
456                        let right_scan_range = match right_scan_range {
457                            Bound::Included(vec) | Bound::Excluded(vec) => vec,
458                            Bound::Unbounded => return Bound::Unbounded,
459                        };
460
461                        let cmp_len = left_scan_range.len().min(right_scan_range.len());
462
463                        let cmp = cmp_rows(
464                            &left_scan_range[..cmp_len],
465                            &right_scan_range[..cmp_len],
466                            &order_types[..cmp_len],
467                        );
468
469                        let bound = {
470                            if (cmp.is_le() && left_bound) || (cmp.is_ge() && !left_bound) {
471                                left_scan_range.to_vec()
472                            } else {
473                                right_scan_range.to_vec()
474                            }
475                        };
476
477                        // Included Bound just for convenience, the correctness will be guaranteed by the upper level filter.
478                        Bound::Included(bound)
479                    }
480
481                    scan_range_left.range.0 = merge_bound(
482                        &scan_range_left.range.0,
483                        &scan_range_right.range.0,
484                        &order_types,
485                        true,
486                    );
487
488                    scan_range_left.range.1 = merge_bound(
489                        &scan_range_left.range.1,
490                        &scan_range_right.range.1,
491                        &order_types,
492                        false,
493                    );
494
495                    if scan_range_left.is_full_table_scan() {
496                        return Ok(None);
497                    }
498                } else {
499                    output_scan_ranges.push(scan_range_right.clone());
500                }
501
502                idx += 1;
503            }
504
505            Ok(Some((output_scan_ranges, true)))
506        }
507    }
508
509    fn split_row_cmp_to_scan_ranges(
510        &self,
511        table_desc: Rc<TableDesc>,
512    ) -> Result<Option<(Vec<ScanRange>, Self)>> {
513        let (mut row_conjunctions, row_conjunctions_without_struct): (Vec<_>, Vec<_>) =
514            self.conjunctions.clone().into_iter().partition(|expr| {
515                if let Some(f) = expr.as_function_call() {
516                    if let Some(left_input) = f.inputs().get(0)
517                        && let Some(left_input) = left_input.as_function_call()
518                        && matches!(left_input.func_type(), ExprType::Row)
519                        && left_input.inputs().iter().all(|x| x.is_input_ref())
520                        && let Some(right_input) = f.inputs().get(1)
521                        && right_input.is_literal()
522                    {
523                        true
524                    } else {
525                        false
526                    }
527                } else {
528                    false
529                }
530            });
531
532        // optimize for single row conjunctions. More optimisations may come later
533        // For example, (v1,v2,v3) > (1, 2, 3) means all data from (1, 2, 3).
534        // Suppose v1 v2 v3 are both pk, we can push (v1,v2,v3)> (1,2,3) down to scan
535        // Suppose v1 v2 are both pk, we can push (v1,v2)> (1,2) down to scan and add (v1,v2,v3) > (1,2,3) in filter, it is still possible to reduce the value of scan
536        if row_conjunctions.len() == 1 {
537            let row_conjunction = row_conjunctions.pop().unwrap();
538            let row_left_inputs = row_conjunction
539                .as_function_call()
540                .unwrap()
541                .inputs()
542                .get(0)
543                .unwrap()
544                .as_function_call()
545                .unwrap()
546                .inputs();
547            let row_right_literal = row_conjunction
548                .as_function_call()
549                .unwrap()
550                .inputs()
551                .get(1)
552                .unwrap()
553                .as_literal()
554                .unwrap();
555            if !matches!(row_right_literal.get_data(), Some(ScalarImpl::Struct(_))) {
556                return Ok(None);
557            }
558            let row_right_literal_data = row_right_literal.get_data().clone().unwrap();
559            let right_iter = row_right_literal_data.as_struct().fields();
560            let func_type = row_conjunction.as_function_call().unwrap().func_type();
561            if row_left_inputs.len() > 1
562                && (matches!(func_type, ExprType::LessThan)
563                    || matches!(func_type, ExprType::GreaterThan))
564            {
565                let mut pk_struct = vec![];
566                let mut order_type = None;
567                let mut all_added = true;
568                let mut iter = row_left_inputs.iter().zip_eq_fast(right_iter);
569                for column_order in &table_desc.pk {
570                    if let Some((left_expr, right_expr)) = iter.next() {
571                        if left_expr.as_input_ref().unwrap().index != column_order.column_index {
572                            all_added = false;
573                            break;
574                        }
575                        match order_type {
576                            Some(o) => {
577                                if o != column_order.order_type {
578                                    all_added = false;
579                                    break;
580                                }
581                            }
582                            None => order_type = Some(column_order.order_type),
583                        }
584                        pk_struct.push(right_expr.clone());
585                    }
586                }
587
588                // Here it is necessary to determine whether all of row is included in the `ScanRanges`, if so, the data for eq is not needed
589                if !pk_struct.is_empty() {
590                    if !all_added {
591                        let scan_range = ScanRange {
592                            eq_conds: vec![],
593                            range: match func_type {
594                                ExprType::GreaterThan => {
595                                    (Bound::Included(pk_struct), Bound::Unbounded)
596                                }
597                                ExprType::LessThan => {
598                                    (Bound::Unbounded, Bound::Included(pk_struct))
599                                }
600                                _ => unreachable!(),
601                            },
602                        };
603                        return Ok(Some((
604                            vec![scan_range],
605                            Condition {
606                                conjunctions: self.conjunctions.clone(),
607                            },
608                        )));
609                    } else {
610                        let scan_range = ScanRange {
611                            eq_conds: vec![],
612                            range: match func_type {
613                                ExprType::GreaterThan => {
614                                    (Bound::Excluded(pk_struct), Bound::Unbounded)
615                                }
616                                ExprType::LessThan => {
617                                    (Bound::Unbounded, Bound::Excluded(pk_struct))
618                                }
619                                _ => unreachable!(),
620                            },
621                        };
622                        return Ok(Some((
623                            vec![scan_range],
624                            Condition {
625                                conjunctions: row_conjunctions_without_struct,
626                            },
627                        )));
628                    }
629                }
630            }
631        }
632        Ok(None)
633    }
634
635    /// x = 1 AND y = 2 AND z = 3 => [x, y, z]
636    pub fn get_eq_const_input_refs(&self) -> Vec<InputRef> {
637        self.conjunctions
638            .iter()
639            .filter_map(|expr| expr.as_eq_const().map(|(input_ref, _)| input_ref))
640            .collect()
641    }
642
643    /// See also [`ScanRange`](risingwave_pb::batch_plan::ScanRange).
644    pub fn split_to_scan_ranges(
645        self,
646        table_desc: Rc<TableDesc>,
647        max_split_range_gap: u64,
648    ) -> Result<(Vec<ScanRange>, Self)> {
649        fn false_cond() -> (Vec<ScanRange>, Condition) {
650            (vec![], Condition::false_cond())
651        }
652
653        // It's an OR.
654        if self.conjunctions.len() == 1 {
655            if let Some(disjunctions) = self.conjunctions[0].as_or_disjunctions() {
656                if let Some((scan_ranges, maintaining_condition)) =
657                    Self::disjunctions_to_scan_ranges(
658                        table_desc,
659                        max_split_range_gap,
660                        disjunctions,
661                    )?
662                {
663                    if maintaining_condition {
664                        return Ok((scan_ranges, self));
665                    } else {
666                        return Ok((scan_ranges, Condition::true_cond()));
667                    }
668                } else {
669                    return Ok((vec![], self));
670                }
671            }
672        }
673        if let Some((scan_ranges, other_condition)) =
674            self.split_row_cmp_to_scan_ranges(table_desc.clone())?
675        {
676            return Ok((scan_ranges, other_condition));
677        }
678
679        let mut groups = Self::classify_conjunctions_by_pk(self.conjunctions, &table_desc);
680        let mut other_conds = groups.pop().unwrap();
681
682        // Analyze each group and use result to update scan range.
683        let mut scan_range = ScanRange::full_table_scan();
684        for i in 0..table_desc.order_column_indices().len() {
685            let group = std::mem::take(&mut groups[i]);
686            if group.is_empty() {
687                groups.push(other_conds);
688                return Ok((
689                    if scan_range.is_full_table_scan() {
690                        vec![]
691                    } else {
692                        vec![scan_range]
693                    },
694                    Self {
695                        conjunctions: groups[i + 1..].concat(),
696                    },
697                ));
698            }
699
700            let Some((
701                lower_bound_conjunctions,
702                upper_bound_conjunctions,
703                eq_conds,
704                part_of_other_conds,
705            )) = Self::analyze_group(group)?
706            else {
707                return Ok(false_cond());
708            };
709            other_conds.extend(part_of_other_conds.into_iter());
710
711            let lower_bound = Self::merge_lower_bound_conjunctions(lower_bound_conjunctions);
712            let upper_bound = Self::merge_upper_bound_conjunctions(upper_bound_conjunctions);
713
714            if Self::is_invalid_range(&lower_bound, &upper_bound) {
715                return Ok(false_cond());
716            }
717
718            // update scan_range
719            match eq_conds.len() {
720                1 => {
721                    let eq_conds =
722                        Self::extract_eq_conds_within_range(eq_conds, &upper_bound, &lower_bound);
723                    if eq_conds.is_empty() {
724                        return Ok(false_cond());
725                    }
726                    scan_range.eq_conds.extend(eq_conds.into_iter());
727                }
728                0 => {
729                    let convert = |bound| match bound {
730                        Bound::Included(l) => Bound::Included(vec![Some(l)]),
731                        Bound::Excluded(l) => Bound::Excluded(vec![Some(l)]),
732                        Bound::Unbounded => Bound::Unbounded,
733                    };
734                    scan_range.range = (convert(lower_bound), convert(upper_bound));
735                    other_conds.extend(groups[i + 1..].iter().flatten().cloned());
736                    break;
737                }
738                _ => {
739                    // currently we will split IN list to multiple scan ranges immediately
740                    // i.e., a = 1 AND b in (1,2) is handled
741                    // TODO:
742                    // a in (1,2) AND b = 1
743                    // a in (1,2) AND b in (1,2)
744                    // a in (1,2) AND b > 1
745                    let eq_conds =
746                        Self::extract_eq_conds_within_range(eq_conds, &upper_bound, &lower_bound);
747                    if eq_conds.is_empty() {
748                        return Ok(false_cond());
749                    }
750                    other_conds.extend(groups[i + 1..].iter().flatten().cloned());
751                    let scan_ranges = eq_conds
752                        .into_iter()
753                        .map(|lit| {
754                            let mut scan_range = scan_range.clone();
755                            scan_range.eq_conds.push(lit);
756                            scan_range
757                        })
758                        .collect();
759                    return Ok((
760                        scan_ranges,
761                        Self {
762                            conjunctions: other_conds,
763                        },
764                    ));
765                }
766            }
767        }
768
769        Ok((
770            if scan_range.is_full_table_scan() {
771                vec![]
772            } else if table_desc.columns[table_desc.order_column_indices()[0]]
773                .data_type
774                .is_int()
775            {
776                match scan_range.split_small_range(max_split_range_gap) {
777                    Some(scan_ranges) => scan_ranges,
778                    None => vec![scan_range],
779                }
780            } else {
781                vec![scan_range]
782            },
783            Self {
784                conjunctions: other_conds,
785            },
786        ))
787    }
788
789    /// classify conjunctions into groups:
790    /// The i-th group has exprs that only reference the i-th PK column.
791    /// The last group contains all the other exprs.
792    fn classify_conjunctions_by_pk(
793        conjunctions: Vec<ExprImpl>,
794        table_desc: &Rc<TableDesc>,
795    ) -> Vec<Vec<ExprImpl>> {
796        let pk_column_ids = &table_desc.order_column_indices();
797        let pk_cols_num = pk_column_ids.len();
798        let cols_num = table_desc.columns.len();
799
800        let mut col_idx_to_pk_idx = vec![None; cols_num];
801        pk_column_ids.iter().enumerate().for_each(|(idx, pk_idx)| {
802            col_idx_to_pk_idx[*pk_idx] = Some(idx);
803        });
804
805        let mut groups = vec![vec![]; pk_cols_num + 1];
806        for (key, group) in &conjunctions.into_iter().chunk_by(|expr| {
807            let input_bits = expr.collect_input_refs(cols_num);
808            if input_bits.count_ones(..) == 1 {
809                let col_idx = input_bits.ones().next().unwrap();
810                col_idx_to_pk_idx[col_idx].unwrap_or(pk_cols_num)
811            } else {
812                pk_cols_num
813            }
814        }) {
815            groups[key].extend(group);
816        }
817
818        groups
819    }
820
    /// Extract the following information from a group of conjunctions:
    /// 1. lower bound conjunctions (from `>` / `>=` comparisons with a constant)
    /// 2. upper bound conjunctions (from `<` / `<=` comparisons with a constant)
    /// 3. eq conditions: the candidate values of the column, OR'ed together; a `None` entry
    ///    stands for `IS NULL`
    /// 4. other conditions, i.e. conjunctions that could not be turned into 1-3
    ///
    /// The group is expected to constrain a single column (see `classify_conjunctions_by_pk`);
    /// the input refs of the expressions are not re-checked here.
    ///
    /// Returns `Ok(None)` when the conjunctions can never all be true, e.g. `col = NULL`,
    /// `col = 1 AND col = 2`, `col = 1 AND col IS NULL`, or an IN list containing only NULLs.
    ///
    /// # Errors
    /// Propagates errors from constant folding (`fold_const`).
    #[allow(clippy::type_complexity)]
    fn analyze_group(
        group: Vec<ExprImpl>,
    ) -> Result<
        Option<(
            Vec<Bound<ScalarImpl>>,
            Vec<Bound<ScalarImpl>>,
            Vec<Option<ScalarImpl>>,
            Vec<ExprImpl>,
        )>,
    > {
        let mut lower_bound_conjunctions = vec![];
        let mut upper_bound_conjunctions = vec![];
        // values in eq_cond are OR'ed; `None` represents `IS NULL`
        let mut eq_conds = vec![];
        let mut other_conds = vec![];

        // analyze exprs in the group. scan_range is not updated
        for expr in group {
            if let Some((input_ref, const_expr)) = expr.as_eq_const() {
                // `col = const`: try an implicit cast of the constant to the column type first,
                // then fall back to a narrowing integral cast.
                let new_expr = if let Ok(expr) = const_expr
                    .clone()
                    .cast_implicit(input_ref.data_type.clone())
                {
                    expr
                } else {
                    match self::cast_compare::cast_compare_for_eq(const_expr, input_ref.data_type) {
                        Ok(ResultForEq::Success(expr)) => expr,
                        // The constant is outside the column type's range.
                        Ok(ResultForEq::NeverEqual) => {
                            return Ok(None);
                        }
                        // Unsupported cast: keep the predicate as a residual condition.
                        Err(_) => {
                            other_conds.push(expr);
                            continue;
                        }
                    }
                };

                let Some(new_cond) = new_expr.fold_const()? else {
                    // column = NULL, the result is always NULL.
                    return Ok(None);
                };
                if Self::mutual_exclusive_with_eq_conds(&new_cond, &eq_conds) {
                    return Ok(None);
                }
                // The new value is compatible with every previously collected candidate, so the
                // AND narrows the candidates down to this single value.
                eq_conds = vec![Some(new_cond)];
            } else if expr.as_is_null().is_some() {
                // `col IS NULL` contradicts an earlier equality / IN list on non-null values.
                if !eq_conds.is_empty() && eq_conds.into_iter().all(|l| l.is_some()) {
                    return Ok(None);
                }
                eq_conds = vec![None];
            } else if let Some((input_ref, in_const_list)) = expr.as_in_const_list() {
                let mut scalars = HashSet::new();
                for const_expr in in_const_list {
                    // The cast should succeed, because otherwise the input_ref is casted
                    // and thus `as_in_const_list` returns None.
                    let const_expr = const_expr
                        .cast_implicit(input_ref.data_type.clone())
                        .unwrap();
                    let value = const_expr.fold_const()?;
                    // NULL items in the list can never match, so they are skipped.
                    let Some(value) = value else {
                        continue;
                    };
                    scalars.insert(Some(value));
                }
                if scalars.is_empty() {
                    // There're only NULLs in the in-list
                    return Ok(None);
                }
                if !eq_conds.is_empty() {
                    // AND with earlier candidates: keep only the values present in both.
                    scalars = scalars
                        .intersection(&HashSet::from_iter(eq_conds))
                        .cloned()
                        .collect();
                    if scalars.is_empty() {
                        return Ok(None);
                    }
                }
                // Sort to ensure a deterministic result for planner test.
                eq_conds = scalars
                    .into_iter()
                    .sorted_by(DefaultOrd::default_cmp)
                    .collect();
            } else if let Some((input_ref, op, const_expr)) = expr.as_comparison_const() {
                let new_expr = if let Ok(expr) = const_expr
                    .clone()
                    .cast_implicit(input_ref.data_type.clone())
                {
                    expr
                } else {
                    match self::cast_compare::cast_compare_for_cmp(
                        const_expr,
                        input_ref.data_type,
                        op,
                    ) {
                        Ok(ResultForCmp::Success(expr)) => expr,
                        // Unsupported casts and out-of-range constants are not turned into
                        // bounds; the predicate is kept as a residual condition instead.
                        _ => {
                            other_conds.push(expr);
                            continue;
                        }
                    }
                };
                let Some(value) = new_expr.fold_const()? else {
                    // column compare with NULL, the result is always NULL.
                    return Ok(None);
                };
                match op {
                    ExprType::LessThan => {
                        upper_bound_conjunctions.push(Bound::Excluded(value));
                    }
                    ExprType::LessThanOrEqual => {
                        upper_bound_conjunctions.push(Bound::Included(value));
                    }
                    ExprType::GreaterThan => {
                        lower_bound_conjunctions.push(Bound::Excluded(value));
                    }
                    ExprType::GreaterThanOrEqual => {
                        lower_bound_conjunctions.push(Bound::Included(value));
                    }
                    _ => unreachable!(),
                }
            } else {
                other_conds.push(expr);
            }
        }
        Ok(Some((
            lower_bound_conjunctions,
            upper_bound_conjunctions,
            eq_conds,
            other_conds,
        )))
    }
960
961    fn mutual_exclusive_with_eq_conds(
962        new_conds: &ScalarImpl,
963        eq_conds: &[Option<ScalarImpl>],
964    ) -> bool {
965        !eq_conds.is_empty()
966            && eq_conds.iter().all(|l| {
967                if let Some(l) = l {
968                    l != new_conds
969                } else {
970                    true
971                }
972            })
973    }
974
975    fn merge_lower_bound_conjunctions(lb: Vec<Bound<ScalarImpl>>) -> Bound<ScalarImpl> {
976        lb.into_iter()
977            .max_by(|a, b| {
978                // For lower bound, Unbounded means -inf
979                match (a, b) {
980                    (Bound::Included(_), Bound::Unbounded) => std::cmp::Ordering::Greater,
981                    (Bound::Excluded(_), Bound::Unbounded) => std::cmp::Ordering::Greater,
982                    (Bound::Unbounded, Bound::Included(_)) => std::cmp::Ordering::Less,
983                    (Bound::Unbounded, Bound::Excluded(_)) => std::cmp::Ordering::Less,
984                    (Bound::Unbounded, Bound::Unbounded) => std::cmp::Ordering::Equal,
985                    (Bound::Included(a), Bound::Included(b)) => a.default_cmp(b),
986                    (Bound::Excluded(a), Bound::Excluded(b)) => a.default_cmp(b),
987                    // excluded bound is strict than included bound so we assume it more greater.
988                    (Bound::Included(a), Bound::Excluded(b)) => match a.default_cmp(b) {
989                        std::cmp::Ordering::Equal => std::cmp::Ordering::Less,
990                        other => other,
991                    },
992                    (Bound::Excluded(a), Bound::Included(b)) => match a.default_cmp(b) {
993                        std::cmp::Ordering::Equal => std::cmp::Ordering::Greater,
994                        other => other,
995                    },
996                }
997            })
998            .unwrap_or(Bound::Unbounded)
999    }
1000
1001    fn merge_upper_bound_conjunctions(ub: Vec<Bound<ScalarImpl>>) -> Bound<ScalarImpl> {
1002        ub.into_iter()
1003            .min_by(|a, b| {
1004                // For upper bound, Unbounded means +inf
1005                match (a, b) {
1006                    (Bound::Included(_), Bound::Unbounded) => std::cmp::Ordering::Less,
1007                    (Bound::Excluded(_), Bound::Unbounded) => std::cmp::Ordering::Less,
1008                    (Bound::Unbounded, Bound::Included(_)) => std::cmp::Ordering::Greater,
1009                    (Bound::Unbounded, Bound::Excluded(_)) => std::cmp::Ordering::Greater,
1010                    (Bound::Unbounded, Bound::Unbounded) => std::cmp::Ordering::Equal,
1011                    (Bound::Included(a), Bound::Included(b)) => a.default_cmp(b),
1012                    (Bound::Excluded(a), Bound::Excluded(b)) => a.default_cmp(b),
1013                    // excluded bound is strict than included bound so we assume it more greater.
1014                    (Bound::Included(a), Bound::Excluded(b)) => match a.default_cmp(b) {
1015                        std::cmp::Ordering::Equal => std::cmp::Ordering::Greater,
1016                        other => other,
1017                    },
1018                    (Bound::Excluded(a), Bound::Included(b)) => match a.default_cmp(b) {
1019                        std::cmp::Ordering::Equal => std::cmp::Ordering::Less,
1020                        other => other,
1021                    },
1022                }
1023            })
1024            .unwrap_or(Bound::Unbounded)
1025    }
1026
1027    fn is_invalid_range(lower_bound: &Bound<ScalarImpl>, upper_bound: &Bound<ScalarImpl>) -> bool {
1028        match (lower_bound, upper_bound) {
1029            (Bound::Included(l), Bound::Included(u)) => l.default_cmp(u).is_gt(), // l > u
1030            (Bound::Included(l), Bound::Excluded(u)) => l.default_cmp(u).is_ge(), // l >= u
1031            (Bound::Excluded(l), Bound::Included(u)) => l.default_cmp(u).is_ge(), // l >= u
1032            (Bound::Excluded(l), Bound::Excluded(u)) => l.default_cmp(u).is_ge(), // l >= u
1033            _ => false,
1034        }
1035    }
1036
1037    fn extract_eq_conds_within_range(
1038        eq_conds: Vec<Option<ScalarImpl>>,
1039        upper_bound: &Bound<ScalarImpl>,
1040        lower_bound: &Bound<ScalarImpl>,
1041    ) -> Vec<Option<ScalarImpl>> {
1042        // defensive programming: for now we will guarantee that the range is valid before calling
1043        // this function
1044        if Self::is_invalid_range(lower_bound, upper_bound) {
1045            return vec![];
1046        }
1047
1048        let is_extract_null = upper_bound == &Bound::Unbounded && lower_bound == &Bound::Unbounded;
1049
1050        eq_conds
1051            .into_iter()
1052            .filter(|cond| {
1053                if let Some(cond) = cond {
1054                    match lower_bound {
1055                        Bound::Included(val) => {
1056                            if cond.default_cmp(val).is_lt() {
1057                                // cond < val
1058                                return false;
1059                            }
1060                        }
1061                        Bound::Excluded(val) => {
1062                            if cond.default_cmp(val).is_le() {
1063                                // cond <= val
1064                                return false;
1065                            }
1066                        }
1067                        Bound::Unbounded => {}
1068                    }
1069                    match upper_bound {
1070                        Bound::Included(val) => {
1071                            if cond.default_cmp(val).is_gt() {
1072                                // cond > val
1073                                return false;
1074                            }
1075                        }
1076                        Bound::Excluded(val) => {
1077                            if cond.default_cmp(val).is_ge() {
1078                                // cond >= val
1079                                return false;
1080                            }
1081                        }
1082                        Bound::Unbounded => {}
1083                    }
1084                    true
1085                } else {
1086                    is_extract_null
1087                }
1088            })
1089            .collect()
1090    }
1091
1092    /// Split the condition expressions into `N` groups.
1093    /// An expression `expr` is in the `i`-th group if `f(expr)==i`.
1094    ///
1095    /// # Panics
1096    /// Panics if `f(expr)>=N`.
1097    #[must_use]
1098    pub fn group_by<F, const N: usize>(self, f: F) -> [Self; N]
1099    where
1100        F: Fn(&ExprImpl) -> usize,
1101    {
1102        const EMPTY: Vec<ExprImpl> = vec![];
1103        let mut groups = [EMPTY; N];
1104        for (key, group) in &self.conjunctions.into_iter().chunk_by(|expr| {
1105            // i-th group
1106            let i = f(expr);
1107            assert!(i < N);
1108            i
1109        }) {
1110            groups[key].extend(group);
1111        }
1112
1113        groups.map(|group| Condition {
1114            conjunctions: group,
1115        })
1116    }
1117
1118    #[must_use]
1119    pub fn rewrite_expr(self, rewriter: &mut (impl ExprRewriter + ?Sized)) -> Self {
1120        Self {
1121            conjunctions: self
1122                .conjunctions
1123                .into_iter()
1124                .map(|expr| rewriter.rewrite_expr(expr))
1125                .collect(),
1126        }
1127        .simplify()
1128    }
1129
1130    pub fn visit_expr<V: ExprVisitor + ?Sized>(&self, visitor: &mut V) {
1131        self.conjunctions
1132            .iter()
1133            .for_each(|expr| visitor.visit_expr(expr));
1134    }
1135
1136    pub fn visit_expr_mut(&mut self, mutator: &mut (impl ExprMutator + ?Sized)) {
1137        self.conjunctions
1138            .iter_mut()
1139            .for_each(|expr| mutator.visit_expr(expr))
1140    }
1141
1142    /// Simplify conditions
1143    /// It simplify conditions by applying constant folding and removing unnecessary conjunctions
1144    fn simplify(self) -> Self {
1145        // boolean constant folding
1146        let conjunctions: Vec<_> = self
1147            .conjunctions
1148            .into_iter()
1149            .map(push_down_not)
1150            .map(fold_boolean_constant)
1151            .map(column_self_eq_eliminate)
1152            .flat_map(to_conjunctions)
1153            .collect();
1154        let mut res: Vec<ExprImpl> = Vec::new();
1155        let mut visited: HashSet<ExprImpl> = HashSet::new();
1156        for expr in conjunctions {
1157            // factorization_expr requires hash-able ExprImpl
1158            if !expr.has_subquery() {
1159                let results_of_factorization = factorization_expr(expr);
1160                res.extend(
1161                    results_of_factorization
1162                        .clone()
1163                        .into_iter()
1164                        .filter(|expr| !visited.contains(expr)),
1165                );
1166                visited.extend(results_of_factorization);
1167            } else {
1168                // for subquery, simply give up factorization
1169                res.push(expr);
1170            }
1171        }
1172        // remove all constant boolean `true`
1173        res.retain(|expr| {
1174            if let Some(v) = try_get_bool_constant(expr)
1175                && v
1176            {
1177                false
1178            } else {
1179                true
1180            }
1181        });
1182        // if there is a `false` in conjunctions, the whole condition will be `false`
1183        for expr in &mut res {
1184            if let Some(v) = try_get_bool_constant(expr) {
1185                if !v {
1186                    res.clear();
1187                    res.push(ExprImpl::literal_bool(false));
1188                    break;
1189                }
1190            }
1191        }
1192        Self { conjunctions: res }
1193    }
1194}
1195
1196pub struct ConditionDisplay<'a> {
1197    pub condition: &'a Condition,
1198    pub input_schema: &'a Schema,
1199}
1200
1201impl ConditionDisplay<'_> {
1202    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1203        if self.condition.always_true() {
1204            write!(f, "true")
1205        } else {
1206            write!(
1207                f,
1208                "{}",
1209                self.condition
1210                    .conjunctions
1211                    .iter()
1212                    .format_with(" AND ", |expr, f| {
1213                        f(&ExprDisplay {
1214                            expr,
1215                            input_schema: self.input_schema,
1216                        })
1217                    })
1218            )
1219        }
1220    }
1221}
1222
1223impl fmt::Display for ConditionDisplay<'_> {
1224    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1225        self.fmt(f)
1226    }
1227}
1228
1229impl fmt::Debug for ConditionDisplay<'_> {
1230    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1231        self.fmt(f)
1232    }
1233}
1234
1235/// `cast_compare` can be summarized as casting to target type which can be compared but can't be
1236/// cast implicitly to, like:
1237/// 1. bigger range -> smaller range in same type, e.g. int64 -> int32
1238/// 2. different type, e.g. float type -> integral type
1239mod cast_compare {
1240    use risingwave_common::types::DataType;
1241
1242    use crate::expr::{Expr, ExprImpl, ExprType};
1243
1244    enum ShrinkResult {
1245        OutUpperBound,
1246        OutLowerBound,
1247        InRange(ExprImpl),
1248    }
1249
1250    pub enum ResultForEq {
1251        Success(ExprImpl),
1252        NeverEqual,
1253    }
1254
1255    pub enum ResultForCmp {
1256        Success(ExprImpl),
1257        OutUpperBound,
1258        OutLowerBound,
1259    }
1260
1261    pub fn cast_compare_for_eq(const_expr: ExprImpl, target: DataType) -> Result<ResultForEq, ()> {
1262        match (const_expr.return_type(), &target) {
1263            (DataType::Int64, DataType::Int32)
1264            | (DataType::Int64, DataType::Int16)
1265            | (DataType::Int32, DataType::Int16) => match shrink_integral(const_expr, target)? {
1266                ShrinkResult::InRange(expr) => Ok(ResultForEq::Success(expr)),
1267                ShrinkResult::OutUpperBound | ShrinkResult::OutLowerBound => {
1268                    Ok(ResultForEq::NeverEqual)
1269                }
1270            },
1271            _ => Err(()),
1272        }
1273    }
1274
1275    pub fn cast_compare_for_cmp(
1276        const_expr: ExprImpl,
1277        target: DataType,
1278        _op: ExprType,
1279    ) -> Result<ResultForCmp, ()> {
1280        match (const_expr.return_type(), &target) {
1281            (DataType::Int64, DataType::Int32)
1282            | (DataType::Int64, DataType::Int16)
1283            | (DataType::Int32, DataType::Int16) => match shrink_integral(const_expr, target)? {
1284                ShrinkResult::InRange(expr) => Ok(ResultForCmp::Success(expr)),
1285                ShrinkResult::OutUpperBound => Ok(ResultForCmp::OutUpperBound),
1286                ShrinkResult::OutLowerBound => Ok(ResultForCmp::OutLowerBound),
1287            },
1288            _ => Err(()),
1289        }
1290    }
1291
1292    fn shrink_integral(const_expr: ExprImpl, target: DataType) -> Result<ShrinkResult, ()> {
1293        let (upper_bound, lower_bound) = match (const_expr.return_type(), &target) {
1294            (DataType::Int64, DataType::Int32) => (i32::MAX as i64, i32::MIN as i64),
1295            (DataType::Int64, DataType::Int16) | (DataType::Int32, DataType::Int16) => {
1296                (i16::MAX as i64, i16::MIN as i64)
1297            }
1298            _ => unreachable!(),
1299        };
1300        match const_expr.fold_const().map_err(|_| ())? {
1301            Some(scalar) => {
1302                let value = scalar.as_integral();
1303                if value > upper_bound {
1304                    Ok(ShrinkResult::OutUpperBound)
1305                } else if value < lower_bound {
1306                    Ok(ShrinkResult::OutLowerBound)
1307                } else {
1308                    Ok(ShrinkResult::InRange(
1309                        const_expr.cast_explicit(target).unwrap(),
1310                    ))
1311                }
1312            }
1313            None => Ok(ShrinkResult::InRange(
1314                const_expr.cast_explicit(target).unwrap(),
1315            )),
1316        }
1317    }
1318}
1319
#[cfg(test)]
mod tests {
    use rand::Rng;

    use super::*;

    /// Builds the binary predicate `op(lhs, rhs)`.
    fn compare(op: ExprType, lhs: ExprImpl, rhs: ExprImpl) -> ExprImpl {
        FunctionCall::new(op, vec![lhs, rhs]).unwrap().into()
    }

    /// References a column whose index is drawn at random from `cols`.
    fn random_column(
        rng: &mut impl Rng,
        cols: std::ops::Range<usize>,
        ty: &DataType,
    ) -> ExprImpl {
        InputRef::new(rng.random_range(cols), ty.clone()).into()
    }

    #[test]
    fn test_split() {
        let left_col_num = 3;
        let right_col_num = 2;
        let left_cols = 0..left_col_num;
        let right_cols = left_col_num..left_col_num + right_col_num;

        let ty = DataType::Int32;
        let mut rng = rand::rng();

        // Only left-side columns.
        let left = compare(
            ExprType::LessThanOrEqual,
            random_column(&mut rng, left_cols.clone(), &ty),
            random_column(&mut rng, left_cols.clone(), &ty),
        );
        // Only right-side columns.
        let right = compare(
            ExprType::LessThan,
            random_column(&mut rng, right_cols.clone(), &ty),
            random_column(&mut rng, right_cols.clone(), &ty),
        );
        // One column from each side.
        let other = compare(
            ExprType::GreaterThan,
            random_column(&mut rng, left_cols, &ty),
            random_column(&mut rng, right_cols, &ty),
        );

        let cond = Condition::with_expr(other.clone())
            .and(Condition::with_expr(right.clone()))
            .and(Condition::with_expr(left.clone()));

        let (left_part, right_part, other_part) = cond.split(left_col_num, right_col_num);

        assert_eq!(left_part.conjunctions, vec![left]);
        assert_eq!(right_part.conjunctions, vec![right]);
        assert_eq!(other_part.conjunctions, vec![other]);
    }

    #[test]
    fn test_self_eq_eliminate() {
        let left_col_num = 3;
        let right_col_num = 2;
        let left_cols = 0..left_col_num;
        let right_cols = left_col_num..left_col_num + right_col_num;

        let ty = DataType::Int32;
        let mut rng = rand::rng();

        let x = random_column(&mut rng, left_cols.clone(), &ty);

        // `x = x` should be rewritten to `x IS NOT NULL`.
        let left = compare(ExprType::Equal, x.clone(), x.clone());
        let right = compare(
            ExprType::LessThan,
            random_column(&mut rng, right_cols.clone(), &ty),
            random_column(&mut rng, right_cols.clone(), &ty),
        );
        let other = compare(
            ExprType::GreaterThan,
            random_column(&mut rng, left_cols, &ty),
            random_column(&mut rng, right_cols, &ty),
        );

        let cond = Condition::with_expr(other.clone())
            .and(Condition::with_expr(right.clone()))
            .and(Condition::with_expr(left));

        let (left_part, right_part, other_part) = cond.split(left_col_num, right_col_num);

        let expected_left: ExprImpl = FunctionCall::new(ExprType::IsNotNull, vec![x])
            .unwrap()
            .into();

        assert_eq!(left_part.conjunctions, vec![expected_left]);
        assert_eq!(right_part.conjunctions, vec![right]);
        assert_eq!(other_part.conjunctions, vec![other]);
    }
}