risingwave_frontend/optimizer/rule/
index_selection_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! # Index selection cost matrix
16//!
17//! |`column_idx`| 0   |  1 | 2  | 3  | 4  | remark |
18//! |-----------|-----|----|----|----|----|---|
19//! |Equal      | 1   | 1  | 1  | 1  | 1  | |
20//! |In         | 10  | 8  | 5  | 5  | 5  | take the minimum value with actual in number |
21//! |Range(Two) | 600 | 50 | 20 | 10 | 10 | `RangeTwoSideBound` like a between 1 and 2 |
22//! |Range(One) | 1400| 70 | 25 | 15 | 10 | `RangeOneSideBound` like a > 1, a >= 1, a < 1|
23//! |All        | 4000| 100| 30 | 20 | 10 | |
24//!
25//! ```text
26//! index cost = cost(match type of 0 idx)
27//! * cost(match type of 1 idx)
28//! * ... cost(match type of the last idx)
29//! ```
30//!
31//! ## Example
32//!
33//! Given index order key (a, b, c)
34//!
//! - For `a = 1 and b = 1 and c = 1`, its cost is Equal0 * Equal1 * Equal2 = 1
//! - For `a in (xxx) and b = 1 and c = 1`, its cost is In0 * Equal1 * Equal2 = 10
//! - For `a = 1 and b in (xxx)`, its cost is Equal0 * In1 * All2 = 1 * 8 * 30 = 240
//! - For `a between xxx and yyy`, its cost is Range(Two)0 = 600
//! - For `a = 1 and b between xxx and yyy`, its cost is Equal0 * Range(Two)1 = 1 * 50 = 50
//! - For `a = 1 and b > 1`, its cost is Equal0 * Range(One)1 = 1 * 70 = 70
//! - For `a = 1`, its cost is Equal0 * All1 = 1 * 100 = 100
//! - For no condition, its cost is All0 = 4000
43//!
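//! Assuming that the most selective part of an index is its prefix, the cost of each match type
//! decreases (or stays the same) as `column_idx` increases. Matching stops at the first column
//! that is matched by a range or not matched at all; later columns do not contribute.
//!
//! For index order keys longer than 5 columns, the columns beyond the fifth are ignored.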
48
49use std::cmp::min;
50use std::collections::hash_map::Entry::{Occupied, Vacant};
51use std::collections::{BTreeMap, HashMap};
52use std::rc::Rc;
53
54use itertools::Itertools;
55use risingwave_common::catalog::Schema;
56use risingwave_common::types::{
57    DataType, Date, Decimal, Int256, Interval, Serial, Time, Timestamp, Timestamptz,
58};
59use risingwave_common::util::iter_util::ZipEqFast;
60use risingwave_pb::plan_common::JoinType;
61use risingwave_sqlparser::ast::AsOf;
62
63use super::{BoxedRule, Rule};
64use crate::catalog::IndexCatalog;
65use crate::expr::{
66    Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, to_conjunctions,
67    to_disjunctions,
68};
69use crate::optimizer::PlanRef;
70use crate::optimizer::optimizer_context::OptimizerContextRef;
71use crate::optimizer::plan_node::generic::GenericPlanRef;
72use crate::optimizer::plan_node::{
73    ColumnPruningContext, LogicalJoin, LogicalScan, LogicalUnion, PlanTreeNode, PlanTreeNodeBinary,
74    PredicatePushdown, PredicatePushdownContext, generic,
75};
76use crate::utils::Condition;
77
/// Maximum number of order-key columns that contribute to the index cost; later columns are
/// ignored.
const INDEX_MAX_LEN: usize = 5;
/// Cost per (match type, order-key column position). Rows are, in order: `Equal`, `In`,
/// `RangeTwoSideBound`, `RangeOneSideBound`, `All`. Must stay in sync with the cost matrix table
/// in the module documentation.
const INDEX_COST_MATRIX: [[usize; INDEX_MAX_LEN]; 5] = [
    [1, 1, 1, 1, 1],
    [10, 8, 5, 5, 5],
    [600, 50, 20, 10, 10],
    [1400, 70, 25, 15, 10],
    [4000, 100, 30, 20, 10],
];
/// Multiplier applied to an index access cost when rows must be looked up in the primary table.
const LOOKUP_COST_CONST: usize = 3;
/// Maximum number of conjunctions combined into one candidate index access condition.
const MAX_COMBINATION_SIZE: usize = 3;
/// Only the first `MAX_CONJUNCTION_SIZE` conjunctions are used when enumerating combinations.
const MAX_CONJUNCTION_SIZE: usize = 8;
89
90pub struct IndexSelectionRule {}
91
92impl Rule for IndexSelectionRule {
93    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
94        let logical_scan: &LogicalScan = plan.as_logical_scan()?;
95        let indexes = logical_scan.indexes();
96        if indexes.is_empty() {
97            return None;
98        }
99        let primary_table_row_size = TableScanIoEstimator::estimate_row_size(logical_scan);
100        let primary_cost = min(
101            self.estimate_table_scan_cost(logical_scan, primary_table_row_size),
102            self.estimate_full_table_scan_cost(logical_scan, primary_table_row_size),
103        );
104
105        // If it is a primary lookup plan, avoid checking other indexes.
106        if primary_cost.primary_lookup {
107            return None;
108        }
109
110        let mut final_plan: PlanRef = logical_scan.clone().into();
111        let mut min_cost = primary_cost.clone();
112
113        for index in indexes {
114            if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
115                let index_cost = self.estimate_table_scan_cost(
116                    &index_scan,
117                    TableScanIoEstimator::estimate_row_size(&index_scan),
118                );
119
120                if index_cost.le(&min_cost) {
121                    min_cost = index_cost;
122                    final_plan = index_scan.into();
123                }
124            } else {
125                // non-covering index selection
126                let (index_lookup, lookup_cost) = self.gen_index_lookup(logical_scan, index);
127                if lookup_cost.le(&min_cost) {
128                    min_cost = lookup_cost;
129                    final_plan = index_lookup;
130                }
131            }
132        }
133
134        if let Some((merge_index, merge_index_cost)) = self.index_merge_selection(logical_scan)
135            && merge_index_cost.le(&min_cost)
136        {
137            min_cost = merge_index_cost;
138            final_plan = merge_index;
139        }
140
141        if min_cost == primary_cost {
142            None
143        } else {
144            Some(final_plan)
145        }
146    }
147}
148
149struct IndexPredicateRewriter<'a> {
150    p2s_mapping: &'a BTreeMap<usize, usize>,
151    function_mapping: &'a HashMap<FunctionCall, usize>,
152    offset: usize,
153    covered_by_index: bool,
154}
155
156impl<'a> IndexPredicateRewriter<'a> {
157    fn new(
158        p2s_mapping: &'a BTreeMap<usize, usize>,
159        function_mapping: &'a HashMap<FunctionCall, usize>,
160        offset: usize,
161    ) -> Self {
162        Self {
163            p2s_mapping,
164            function_mapping,
165            offset,
166            covered_by_index: true,
167        }
168    }
169
170    fn covered_by_index(&self) -> bool {
171        self.covered_by_index
172    }
173}
174
175impl ExprRewriter for IndexPredicateRewriter<'_> {
176    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
177        // transform primary predicate to index predicate if it can
178        if self.p2s_mapping.contains_key(&input_ref.index) {
179            InputRef::new(
180                *self.p2s_mapping.get(&input_ref.index()).unwrap(),
181                input_ref.return_type(),
182            )
183            .into()
184        } else {
185            self.covered_by_index = false;
186            InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
187        }
188    }
189
190    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
191        if let Some(index) = self.function_mapping.get(&func_call) {
192            return InputRef::new(*index, func_call.return_type()).into();
193        }
194
195        let (func_type, inputs, ret) = func_call.decompose();
196        let inputs = inputs
197            .into_iter()
198            .map(|expr| self.rewrite_expr(expr))
199            .collect();
200        FunctionCall::new_unchecked(func_type, inputs, ret).into()
201    }
202}
203
204impl IndexSelectionRule {
205    fn gen_index_lookup(
206        &self,
207        logical_scan: &LogicalScan,
208        index: &IndexCatalog,
209    ) -> (PlanRef, IndexCost) {
210        // 1. logical_scan ->  logical_join
211        //                      /        \
212        //                index_scan   primary_table_scan
213        let index_scan = LogicalScan::create(
214            index.index_table.name.clone(),
215            index.index_table.clone(),
216            vec![],
217            logical_scan.ctx(),
218            logical_scan.as_of().clone(),
219            index.index_table.cardinality,
220        );
221        // We use `schema.len` instead of `index_item.len` here,
222        // because schema contains system columns like `_rw_timestamp` column which is not represented in the index item.
223        let offset = index_scan.table_catalog().columns().len();
224
225        let primary_table_scan = LogicalScan::create(
226            index.primary_table.name.clone(),
227            index.primary_table.clone(),
228            vec![],
229            logical_scan.ctx(),
230            logical_scan.as_of().clone(),
231            index.primary_table.cardinality,
232        );
233
234        let predicate = logical_scan.predicate().clone();
235        let mut rewriter = IndexPredicateRewriter::new(
236            index.primary_to_secondary_mapping(),
237            index.function_mapping(),
238            offset,
239        );
240        let new_predicate = predicate.rewrite_expr(&mut rewriter);
241
242        let conjunctions = index
243            .primary_table_pk_ref_to_index_table()
244            .iter()
245            .zip_eq_fast(index.primary_table.pk.iter())
246            .map(|(x, y)| {
247                Self::create_null_safe_equal_expr(
248                    x.column_index,
249                    index.index_table.columns[x.column_index]
250                        .data_type()
251                        .clone(),
252                    y.column_index + offset,
253                    index.primary_table.columns[y.column_index]
254                        .data_type()
255                        .clone(),
256                )
257            })
258            .chain(new_predicate)
259            .collect_vec();
260        let on = Condition { conjunctions };
261        let join: PlanRef = LogicalJoin::new(
262            index_scan.into(),
263            primary_table_scan.into(),
264            JoinType::Inner,
265            on,
266        )
267        .into();
268
269        // 2. push down predicate, so we can calculate the cost of index lookup
270        let join_ref = join.predicate_pushdown(
271            Condition::true_cond(),
272            &mut PredicatePushdownContext::new(join.clone()),
273        );
274
275        let join_with_predicate_push_down =
276            join_ref.as_logical_join().expect("must be a logical join");
277        let new_join_left = join_with_predicate_push_down.left();
278        let index_scan_with_predicate: &LogicalScan = new_join_left
279            .as_logical_scan()
280            .expect("must be a logical scan");
281
282        // 3. calculate index cost, index lookup use primary table to estimate row size.
283        let index_cost = self.estimate_table_scan_cost(
284            index_scan_with_predicate,
285            TableScanIoEstimator::estimate_row_size(logical_scan),
286        );
287        // lookup cost = index cost * LOOKUP_COST_CONST
288        let lookup_cost = index_cost.mul(&IndexCost::new(LOOKUP_COST_CONST, false));
289
290        // 4. keep the same schema with original logical_scan
291        let scan_output_col_idx = logical_scan.output_col_idx();
292        let lookup_join = join_ref.prune_col(
293            &scan_output_col_idx
294                .iter()
295                .map(|&col_idx| col_idx + offset)
296                .collect_vec(),
297            &mut ColumnPruningContext::new(join_ref.clone()),
298        );
299
300        (lookup_join, lookup_cost)
301    }
302
303    /// Index Merge Selection
304    /// Deal with predicate like a = 1 or b = 1
305    /// Merge index scans from a table, currently merge is union semantic.
306    fn index_merge_selection(&self, logical_scan: &LogicalScan) -> Option<(PlanRef, IndexCost)> {
307        let predicate = logical_scan.predicate().clone();
308        // Index merge is kind of index lookup join so use primary table row size to estimate index
309        // cost.
310        let primary_table_row_size = TableScanIoEstimator::estimate_row_size(logical_scan);
311        // 1. choose lowest cost index merge path
312        let paths = self.gen_paths(
313            &predicate.conjunctions,
314            logical_scan,
315            primary_table_row_size,
316        );
317        let (index_access, index_access_cost) =
318            self.choose_min_cost_path(&paths, primary_table_row_size)?;
319
320        // 2. lookup primary table
321        // the schema of index_access is the order key of primary table .
322        let schema: &Schema = index_access.schema();
323        let index_access_len = schema.len();
324
325        let mut shift_input_ref_rewriter = ShiftInputRefRewriter {
326            offset: index_access_len,
327        };
328        let new_predicate = predicate.rewrite_expr(&mut shift_input_ref_rewriter);
329
330        let primary_table_desc = logical_scan.table_desc();
331
332        let primary_table_scan = LogicalScan::create(
333            logical_scan.table_name().to_owned(),
334            logical_scan.table_catalog(),
335            vec![],
336            logical_scan.ctx(),
337            logical_scan.as_of().clone(),
338            logical_scan.table_cardinality(),
339        );
340
341        let conjunctions = primary_table_desc
342            .pk
343            .iter()
344            .enumerate()
345            .map(|(x, y)| {
346                Self::create_null_safe_equal_expr(
347                    x,
348                    schema.fields[x].data_type.clone(),
349                    y.column_index + index_access_len,
350                    primary_table_desc.columns[y.column_index].data_type.clone(),
351                )
352            })
353            .chain(new_predicate)
354            .collect_vec();
355
356        let on = Condition { conjunctions };
357        let join: PlanRef =
358            LogicalJoin::new(index_access, primary_table_scan.into(), JoinType::Inner, on).into();
359
360        // 3 push down predicate
361        let join_ref = join.predicate_pushdown(
362            Condition::true_cond(),
363            &mut PredicatePushdownContext::new(join.clone()),
364        );
365
366        // 4. keep the same schema with original logical_scan
367        let scan_output_col_idx = logical_scan.output_col_idx();
368        let lookup_join = join_ref.prune_col(
369            &scan_output_col_idx
370                .iter()
371                .map(|&col_idx| col_idx + index_access_len)
372                .collect_vec(),
373            &mut ColumnPruningContext::new(join_ref.clone()),
374        );
375
376        Some((
377            lookup_join,
378            index_access_cost.mul(&IndexCost::new(LOOKUP_COST_CONST, false)),
379        ))
380    }
381
382    /// Generate possible paths that can be used to access.
383    /// The schema of output is the order key of primary table, so it can be used to lookup primary
384    /// table later.
385    /// Method `gen_paths` handles the complex condition recursively which may contains nested `AND`
386    /// and `OR`. However, Method `gen_index_path` handles one arm of an OR clause which is a
387    /// basic unit for index selection.
388    fn gen_paths(
389        &self,
390        conjunctions: &[ExprImpl],
391        logical_scan: &LogicalScan,
392        primary_table_row_size: usize,
393    ) -> Vec<PlanRef> {
394        let mut result = vec![];
395        for expr in conjunctions {
396            // it's OR clause!
397            if let ExprImpl::FunctionCall(function_call) = expr
398                && function_call.func_type() == ExprType::Or
399            {
400                let mut index_to_be_merged = vec![];
401
402                let disjunctions = to_disjunctions(expr.clone());
403                let (map, others) = self.clustering_disjunction(disjunctions);
404                let iter = map
405                    .into_iter()
406                    .map(|(column_index, expr)| (Some(column_index), expr))
407                    .chain(others.into_iter().map(|expr| (None, expr)));
408                for (column_index, expr) in iter {
409                    let mut index_paths = vec![];
410                    let conjunctions = to_conjunctions(expr);
411                    index_paths.extend(
412                        self.gen_index_path(column_index, &conjunctions, logical_scan)
413                            .into_iter(),
414                    );
415                    // complex condition, recursively gen paths
416                    if conjunctions.len() > 1 {
417                        index_paths.extend(
418                            self.gen_paths(&conjunctions, logical_scan, primary_table_row_size)
419                                .into_iter(),
420                        );
421                    }
422
423                    match self.choose_min_cost_path(&index_paths, primary_table_row_size) {
424                        None => {
425                            // One arm of OR clause can't use index, bail out
426                            index_to_be_merged.clear();
427                            break;
428                        }
429                        Some((path, _)) => index_to_be_merged.push(path),
430                    }
431                }
432
433                if let Some(path) = self.merge(index_to_be_merged) {
434                    result.push(path)
435                }
436            }
437        }
438
439        result
440    }
441
442    /// Clustering disjunction or expr by column index. If expr is complex, classify them as others.
443    ///
444    /// a = 1, b = 2, b = 3 -> map: [a, (a = 1)], [b, (b = 2 or b = 3)], others: []
445    ///
446    /// a = 1, (b = 2 and c = 3) -> map: [a, (a = 1)], others:
447    ///
448    /// (a > 1 and a < 8) or (c > 1 and c < 8)
449    /// -> map: [], others: [(a > 1 and a < 8), (c > 1 and c < 8)]
450    fn clustering_disjunction(
451        &self,
452        disjunctions: Vec<ExprImpl>,
453    ) -> (HashMap<usize, ExprImpl>, Vec<ExprImpl>) {
454        let mut map: HashMap<usize, ExprImpl> = HashMap::new();
455        let mut others = vec![];
456        for expr in disjunctions {
457            let idx = {
458                if let Some((input_ref, _const_expr)) = expr.as_eq_const() {
459                    Some(input_ref.index)
460                } else if let Some((input_ref, _in_const_list)) = expr.as_in_const_list() {
461                    Some(input_ref.index)
462                } else if let Some((input_ref, _op, _const_expr)) = expr.as_comparison_const() {
463                    Some(input_ref.index)
464                } else {
465                    None
466                }
467            };
468
469            if let Some(idx) = idx {
470                match map.entry(idx) {
471                    Occupied(mut entry) => {
472                        let expr2: ExprImpl = entry.get().to_owned();
473                        let or_expr = ExprImpl::FunctionCall(
474                            FunctionCall::new_unchecked(
475                                ExprType::Or,
476                                vec![expr, expr2],
477                                DataType::Boolean,
478                            )
479                            .into(),
480                        );
481                        entry.insert(or_expr);
482                    }
483                    Vacant(entry) => {
484                        entry.insert(expr);
485                    }
486                };
487            } else {
488                others.push(expr);
489                continue;
490            }
491        }
492
493        (map, others)
494    }
495
496    /// Given a conjunctions from one arm of an OR clause (basic unit to index selection), generate
497    /// all matching index path (including primary index) for the relation.
498    /// `column_index` (refers to primary table) is a hint can be used to prune index.
499    /// Steps:
500    /// 1. Take the combination of `conjunctions` to extract the potential clauses.
501    /// 2. For each potential clauses, generate index path if it can.
502    fn gen_index_path(
503        &self,
504        column_index: Option<usize>,
505        conjunctions: &[ExprImpl],
506        logical_scan: &LogicalScan,
507    ) -> Vec<PlanRef> {
508        // Assumption: use at most `MAX_COMBINATION_SIZE` clauses, we can determine which is the
509        // best index.
510        let mut combinations = vec![];
511        for i in 1..min(conjunctions.len(), MAX_COMBINATION_SIZE) + 1 {
512            combinations.extend(
513                conjunctions
514                    .iter()
515                    .take(min(conjunctions.len(), MAX_CONJUNCTION_SIZE))
516                    .combinations(i),
517            );
518        }
519
520        let mut result = vec![];
521
522        for index in logical_scan.indexes() {
523            if let Some(column_index) = column_index {
524                assert_eq!(conjunctions.len(), 1);
525                let p2s_mapping = index.primary_to_secondary_mapping();
526                match p2s_mapping.get(&column_index) {
527                    None => continue, // not found, prune this index
528                    Some(&idx) => {
529                        if index.index_table.pk()[0].column_index != idx {
530                            // not match, prune this index
531                            continue;
532                        }
533                    }
534                }
535            }
536
537            // try secondary index
538            for conj in &combinations {
539                let condition = Condition {
540                    conjunctions: conj.iter().map(|&x| x.to_owned()).collect(),
541                };
542                if let Some(index_access) = self.build_index_access(
543                    index.clone(),
544                    condition,
545                    logical_scan.ctx().clone(),
546                    logical_scan.as_of().clone(),
547                ) {
548                    result.push(index_access);
549                }
550            }
551        }
552
553        // try primary index
554        let primary_table_desc = logical_scan.table_desc();
555        if let Some(idx) = column_index {
556            assert_eq!(conjunctions.len(), 1);
557            if primary_table_desc.pk[0].column_index != idx {
558                return result;
559            }
560        }
561
562        let primary_access = generic::TableScan::new(
563            logical_scan.table_name().to_owned(),
564            primary_table_desc
565                .pk
566                .iter()
567                .map(|x| x.column_index)
568                .collect_vec(),
569            logical_scan.table_catalog(),
570            vec![],
571            logical_scan.ctx(),
572            Condition {
573                conjunctions: conjunctions.to_vec(),
574            },
575            logical_scan.as_of().clone(),
576            logical_scan.table_cardinality(),
577        );
578
579        result.push(primary_access.into());
580
581        result
582    }
583
584    /// build index access if predicate (refers to primary table) is covered by index
585    fn build_index_access(
586        &self,
587        index: Rc<IndexCatalog>,
588        predicate: Condition,
589        ctx: OptimizerContextRef,
590        as_of: Option<AsOf>,
591    ) -> Option<PlanRef> {
592        let mut rewriter = IndexPredicateRewriter::new(
593            index.primary_to_secondary_mapping(),
594            index.function_mapping(),
595            0,
596        );
597        let new_predicate = predicate.rewrite_expr(&mut rewriter);
598
599        // check condition is covered by index.
600        if !rewriter.covered_by_index() {
601            return None;
602        }
603
604        Some(
605            generic::TableScan::new(
606                index.index_table.name.clone(),
607                index
608                    .primary_table_pk_ref_to_index_table()
609                    .iter()
610                    .map(|x| x.column_index)
611                    .collect_vec(),
612                index.index_table.clone(),
613                vec![],
614                ctx,
615                new_predicate,
616                as_of,
617                index.index_table.cardinality,
618            )
619            .into(),
620        )
621    }
622
623    fn merge(&self, paths: Vec<PlanRef>) -> Option<PlanRef> {
624        if paths.is_empty() {
625            return None;
626        }
627
628        let new_paths = paths
629            .iter()
630            .flat_map(|path| {
631                if let Some(union) = path.as_logical_union() {
632                    union.inputs().to_vec()
633                } else if let Some(_scan) = path.as_logical_scan() {
634                    vec![path.clone()]
635                } else {
636                    unreachable!();
637                }
638            })
639            .sorted_by(|a, b| {
640                // sort inputs to make plan deterministic
641                a.as_logical_scan()
642                    .expect("expect to be a logical scan")
643                    .table_name()
644                    .cmp(
645                        b.as_logical_scan()
646                            .expect("expect to be a logical scan")
647                            .table_name(),
648                    )
649            })
650            .collect_vec();
651
652        Some(LogicalUnion::create(false, new_paths))
653    }
654
655    fn choose_min_cost_path(
656        &self,
657        paths: &[PlanRef],
658        primary_table_row_size: usize,
659    ) -> Option<(PlanRef, IndexCost)> {
660        paths
661            .iter()
662            .map(|path| {
663                if let Some(scan) = path.as_logical_scan() {
664                    let cost = self.estimate_table_scan_cost(scan, primary_table_row_size);
665                    (scan.clone().into(), cost)
666                } else if let Some(union) = path.as_logical_union() {
667                    let cost = union
668                        .inputs()
669                        .iter()
670                        .map(|input| {
671                            self.estimate_table_scan_cost(
672                                input.as_logical_scan().expect("expect to be a scan"),
673                                primary_table_row_size,
674                            )
675                        })
676                        .reduce(|a, b| a.add(&b))
677                        .unwrap();
678                    (union.clone().into(), cost)
679                } else {
680                    unreachable!()
681                }
682            })
683            .min_by(|(_, cost1), (_, cost2)| Ord::cmp(cost1, cost2))
684    }
685
686    fn estimate_table_scan_cost(&self, scan: &LogicalScan, row_size: usize) -> IndexCost {
687        let mut table_scan_io_estimator = TableScanIoEstimator::new(scan, row_size);
688        table_scan_io_estimator.estimate(scan.predicate())
689    }
690
691    fn estimate_full_table_scan_cost(&self, scan: &LogicalScan, row_size: usize) -> IndexCost {
692        let mut table_scan_io_estimator = TableScanIoEstimator::new(scan, row_size);
693        table_scan_io_estimator.estimate(&Condition::true_cond())
694    }
695
696    fn create_null_safe_equal_expr(
697        left: usize,
698        left_data_type: DataType,
699        right: usize,
700        right_data_type: DataType,
701    ) -> ExprImpl {
702        ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
703            ExprType::IsNotDistinctFrom,
704            vec![
705                ExprImpl::InputRef(Box::new(InputRef::new(left, left_data_type))),
706                ExprImpl::InputRef(Box::new(InputRef::new(right, right_data_type))),
707            ],
708            DataType::Boolean,
709        )))
710    }
711}
712
/// Estimates the I/O cost of a `LogicalScan`'s predicate by matching it against the scan's
/// order key using `INDEX_COST_MATRIX`, scaled by an estimated row size.
struct TableScanIoEstimator<'a> {
    // The scan whose table order key is matched against the predicate.
    table_scan: &'a LogicalScan,
    // Estimated row size used to scale the matrix cost (may come from a different table than
    // `table_scan`, e.g. the primary table for index lookups).
    row_size: usize,
    // Cost produced while visiting a single-conjunction predicate; `estimate` takes it afterwards.
    cost: Option<IndexCost>,
}
718
719impl<'a> TableScanIoEstimator<'a> {
720    pub fn new(table_scan: &'a LogicalScan, row_size: usize) -> Self {
721        Self {
722            table_scan,
723            row_size,
724            cost: None,
725        }
726    }
727
728    pub fn estimate_row_size(table_scan: &LogicalScan) -> usize {
729        // 5 for table_id + 1 for vnode + 8 for epoch
730        let row_meta_field_estimate_size = 14_usize;
731        let table_desc = table_scan.table_desc();
732        row_meta_field_estimate_size
733            + table_desc
734                .columns
735                .iter()
736                // add order key twice for its appearance both in key and value
737                .chain(
738                    table_desc
739                        .pk
740                        .iter()
741                        .map(|x| &table_desc.columns[x.column_index]),
742                )
743                .map(|x| TableScanIoEstimator::estimate_data_type_size(&x.data_type))
744                .sum::<usize>()
745    }
746
747    fn estimate_data_type_size(data_type: &DataType) -> usize {
748        use std::mem::size_of;
749
750        match data_type {
751            DataType::Boolean => size_of::<bool>(),
752            DataType::Int16 => size_of::<i16>(),
753            DataType::Int32 => size_of::<i32>(),
754            DataType::Int64 => size_of::<i64>(),
755            DataType::Serial => size_of::<Serial>(),
756            DataType::Float32 => size_of::<f32>(),
757            DataType::Float64 => size_of::<f64>(),
758            DataType::Decimal => size_of::<Decimal>(),
759            DataType::Date => size_of::<Date>(),
760            DataType::Time => size_of::<Time>(),
761            DataType::Timestamp => size_of::<Timestamp>(),
762            DataType::Timestamptz => size_of::<Timestamptz>(),
763            DataType::Interval => size_of::<Interval>(),
764            DataType::Int256 => Int256::size(),
765            DataType::Varchar => 20,
766            DataType::Bytea => 20,
767            DataType::Jsonb => 20,
768            DataType::Struct { .. } => 20,
769            DataType::List { .. } => 20,
770            DataType::Map(_) => 20,
771            DataType::Vector(_) => todo!("VECTOR_PLACEHOLDER"),
772        }
773    }
774
775    pub fn estimate(&mut self, predicate: &Condition) -> IndexCost {
776        // try to deal with OR condition
777        if predicate.conjunctions.len() == 1 {
778            self.visit_expr(&predicate.conjunctions[0]);
779            self.cost.take().unwrap_or_default()
780        } else {
781            self.estimate_conjunctions(&predicate.conjunctions)
782        }
783    }
784
785    fn estimate_conjunctions(&mut self, conjunctions: &[ExprImpl]) -> IndexCost {
786        let order_column_indices = self.table_scan.table_desc().order_column_indices();
787
788        let mut new_conjunctions = conjunctions.to_owned();
789
790        let mut match_item_vec = vec![];
791
792        for column_idx in order_column_indices {
793            let match_item = self.match_index_column(column_idx, &mut new_conjunctions);
794            // seeing range, we don't need to match anymore.
795            let should_break = match match_item {
796                MatchItem::Equal | MatchItem::In(_) => false,
797                MatchItem::RangeOneSideBound | MatchItem::RangeTwoSideBound | MatchItem::All => {
798                    true
799                }
800            };
801            match_item_vec.push(match_item);
802            if should_break {
803                break;
804            }
805        }
806
807        let index_cost = match_item_vec
808            .iter()
809            .enumerate()
810            .take(INDEX_MAX_LEN)
811            .map(|(i, match_item)| match match_item {
812                MatchItem::Equal => INDEX_COST_MATRIX[0][i],
813                MatchItem::In(num) => min(INDEX_COST_MATRIX[1][i], *num),
814                MatchItem::RangeTwoSideBound => INDEX_COST_MATRIX[2][i],
815                MatchItem::RangeOneSideBound => INDEX_COST_MATRIX[3][i],
816                MatchItem::All => INDEX_COST_MATRIX[4][i],
817            })
818            .reduce(|x, y| x * y)
819            .unwrap();
820
821        // If `index_cost` equals 1, it is a primary lookup
822        let primary_lookup = index_cost == 1;
823
824        IndexCost::new(index_cost, primary_lookup)
825            .mul(&IndexCost::new(self.row_size, primary_lookup))
826    }
827
828    fn match_index_column(
829        &mut self,
830        column_idx: usize,
831        conjunctions: &mut Vec<ExprImpl>,
832    ) -> MatchItem {
833        // Equal
834        for (i, expr) in conjunctions.iter().enumerate() {
835            if let Some((input_ref, _const_expr)) = expr.as_eq_const()
836                && input_ref.index == column_idx
837            {
838                conjunctions.remove(i);
839                return MatchItem::Equal;
840            }
841        }
842
843        // In
844        for (i, expr) in conjunctions.iter().enumerate() {
845            if let Some((input_ref, in_const_list)) = expr.as_in_const_list()
846                && input_ref.index == column_idx
847            {
848                conjunctions.remove(i);
849                return MatchItem::In(in_const_list.len());
850            }
851        }
852
853        // Range
854        let mut left_side_bound = false;
855        let mut right_side_bound = false;
856        let mut i = 0;
857        while i < conjunctions.len() {
858            let expr = &conjunctions[i];
859            if let Some((input_ref, op, _const_expr)) = expr.as_comparison_const()
860                && input_ref.index == column_idx
861            {
862                conjunctions.remove(i);
863                match op {
864                    ExprType::LessThan | ExprType::LessThanOrEqual => right_side_bound = true,
865                    ExprType::GreaterThan | ExprType::GreaterThanOrEqual => left_side_bound = true,
866                    _ => unreachable!(),
867                };
868            } else {
869                i += 1;
870            }
871        }
872
873        if left_side_bound && right_side_bound {
874            MatchItem::RangeTwoSideBound
875        } else if left_side_bound || right_side_bound {
876            MatchItem::RangeOneSideBound
877        } else {
878            MatchItem::All
879        }
880    }
881}
882
/// How the predicates constrain a single order-key column of a scanned table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MatchItem {
    /// The column is pinned by `col = const`.
    Equal,
    /// The column is restricted by `col IN (...)`; the payload is the number of listed values.
    In(usize),
    /// The column has both a lower (`>`/`>=`) and an upper (`<`/`<=`) bound.
    RangeTwoSideBound,
    /// The column has only a lower bound or only an upper bound.
    RangeOneSideBound,
    /// The column is not constrained by any recognised predicate.
    All,
}
890
/// Estimated I/O cost of one way of scanning a table.
///
/// Costs built through [`IndexCost::new`] (and therefore through `add`/`mul`)
/// are capped at [`IndexCost::maximum`]. `primary_lookup` survives `add`/`mul`
/// only when both operands are primary lookups.
#[derive(PartialEq, Eq, Hash, Clone, Debug, PartialOrd, Ord)]
struct IndexCost {
    /// Estimated cost; at most `IndexCost::maximum()` when built via `new`.
    cost: usize,
    /// Whether this estimate describes a primary (point) lookup.
    primary_lookup: bool,
}

impl Default for IndexCost {
    /// The most expensive estimate possible, which is not a primary lookup.
    fn default() -> Self {
        Self::new(Self::maximum(), false)
    }
}

impl IndexCost {
    /// Builds an estimate, clamping `cost` to [`IndexCost::maximum`].
    fn new(cost: usize, primary_lookup: bool) -> IndexCost {
        let capped = cost.min(Self::maximum());
        IndexCost {
            cost: capped,
            primary_lookup,
        }
    }

    /// Upper bound for every clamped cost.
    fn maximum() -> usize {
        10_000_000
    }

    /// Sum of two estimates. Overflow saturates and the result is clamped to
    /// the maximum; the result is a primary lookup only if both inputs are.
    fn add(&self, other: &IndexCost) -> IndexCost {
        let total = self.cost.saturating_add(other.cost);
        let both_primary = self.primary_lookup && other.primary_lookup;
        Self::new(total, both_primary)
    }

    /// Product of two estimates. Overflow saturates and the result is clamped
    /// to the maximum; the result is a primary lookup only if both inputs are.
    fn mul(&self, other: &IndexCost) -> IndexCost {
        let product = self.cost.saturating_mul(other.cost);
        let both_primary = self.primary_lookup && other.primary_lookup;
        Self::new(product, both_primary)
    }

    /// Returns `true` when `self` is *strictly* cheaper than `other`.
    ///
    /// Despite its name, this compares with `<` and ignores `primary_lookup`.
    /// As an inherent method it shadows `PartialOrd::le` in method-call syntax,
    /// whereas the `<=` operator still uses the derived ordering.
    fn le(&self, other: &IndexCost) -> bool {
        other.cost > self.cost
    }
}
940
941impl ExprVisitor for TableScanIoEstimator<'_> {
942    fn visit_function_call(&mut self, func_call: &FunctionCall) {
943        let cost = match func_call.func_type() {
944            ExprType::Or => func_call
945                .inputs()
946                .iter()
947                .map(|x| {
948                    let mut estimator = TableScanIoEstimator::new(self.table_scan, self.row_size);
949                    estimator.visit_expr(x);
950                    estimator.cost.take().unwrap_or_default()
951                })
952                .reduce(|x, y| x.add(&y))
953                .unwrap(),
954            ExprType::And => self.estimate_conjunctions(func_call.inputs()),
955            _ => {
956                let single = vec![ExprImpl::FunctionCall(func_call.clone().into())];
957                self.estimate_conjunctions(&single)
958            }
959        };
960        self.cost = Some(cost);
961    }
962}
963
964struct ShiftInputRefRewriter {
965    offset: usize,
966}
967impl ExprRewriter for ShiftInputRefRewriter {
968    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
969        InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
970    }
971}
972
973impl IndexSelectionRule {
974    pub fn create() -> BoxedRule {
975        Box::new(IndexSelectionRule {})
976    }
977}