risingwave_frontend/optimizer/rule/
index_selection_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! # Index selection cost matrix
16//!
17//! |`column_idx`| 0   |  1 | 2  | 3  | 4  | remark |
18//! |-----------|-----|----|----|----|----|---|
19//! |Equal      | 1   | 1  | 1  | 1  | 1  | |
//! |In         | 10  | 8  | 5  | 5  | 5  | capped at the number of items in the `IN` list |
21//! |Range(Two) | 600 | 50 | 20 | 10 | 10 | `RangeTwoSideBound` like a between 1 and 2 |
22//! |Range(One) | 1400| 70 | 25 | 15 | 10 | `RangeOneSideBound` like a > 1, a >= 1, a < 1|
//! |All        | 4000| 100| 30 | 20 | 20 | |
24//!
25//! ```text
26//! index cost = cost(match type of 0 idx)
27//! * cost(match type of 1 idx)
28//! * ... cost(match type of the last idx)
29//! ```
30//!
31//! ## Example
32//!
33//! Given index order key (a, b, c)
34//!
35//! - For `a = 1 and b = 1 and c = 1`, its cost is 1 = Equal0 * Equal1 * Equal2 = 1
36//! - For `a in (xxx) and b = 1 and c = 1`, its cost is In0 * Equal1 * Equal2 = 10
//! - For `a = 1 and b in (xxx)`, its cost is Equal0 * In1 * All2 = 1 * 8 * 30 = 240
38//! - For `a between xxx and yyy`, its cost is Range(Two)0 = 600
39//! - For `a = 1 and b between xxx and yyy`, its cost is Equal0 * Range(Two)1 = 50
40//! - For `a = 1 and b > 1`, its cost is Equal0 * Range(One)1 = 70
41//! - For `a = 1`, its cost is 100 = Equal0 * All1 = 100
42//! - For no condition, its cost is All0 = 4000
43//!
//! Assuming that the most selective part of an index is its prefix,
//! costs decrease as `column_idx` increases.
46//!
47//! For index order key length > 5, we just ignore the rest.
48
49use std::cmp::min;
50use std::collections::hash_map::Entry::{Occupied, Vacant};
51use std::collections::{BTreeMap, HashMap};
52use std::sync::Arc;
53
54use itertools::Itertools;
55use risingwave_common::array::VectorDistanceType;
56use risingwave_common::catalog::Schema;
57use risingwave_common::types::{
58    DataType, Date, Decimal, Int256, Interval, Serial, Time, Timestamp, Timestamptz,
59};
60use risingwave_common::util::iter_util::ZipEqFast;
61use risingwave_pb::plan_common::JoinType;
62use risingwave_sqlparser::ast::AsOf;
63
64use super::prelude::{PlanRef, *};
65use crate::catalog::index_catalog::TableIndex;
66use crate::expr::{
67    Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, to_conjunctions,
68    to_disjunctions,
69};
70use crate::optimizer::optimizer_context::OptimizerContextRef;
71use crate::optimizer::plan_node::generic::GenericPlanRef;
72use crate::optimizer::plan_node::{
73    ColumnPruningContext, LogicalJoin, LogicalScan, LogicalUnion, PlanTreeNode, PlanTreeNodeBinary,
74    PredicatePushdown, PredicatePushdownContext, generic,
75};
76use crate::utils::Condition;
77
/// Number of leading order-key columns that contribute to an index cost; later ones are ignored.
const INDEX_MAX_LEN: usize = 5;
/// Per-column cost factors, indexed as `INDEX_COST_MATRIX[match_kind][column_idx]`.
///
/// Row order: Equal, In, Range(Two), Range(One), All (see the module-level table).
const INDEX_COST_MATRIX: [[usize; INDEX_MAX_LEN]; 5] = [
    /* Equal      */ [1, 1, 1, 1, 1],
    /* In         */ [10, 8, 5, 5, 5],
    /* Range(Two) */ [600, 50, 20, 10, 10],
    /* Range(One) */ [1_400, 70, 25, 15, 10],
    /* All        */ [4_000, 100, 30, 20, 20],
];
/// Multiplier applied to an index scan cost when rows must be fetched from the primary table.
const LOOKUP_COST_CONST: usize = 3;
/// Largest number of conjuncts combined into one candidate predicate for an index path.
const MAX_COMBINATION_SIZE: usize = 3;
/// Only this many leading conjuncts are considered when building those combinations.
const MAX_CONJUNCTION_SIZE: usize = 8;
89
90pub struct IndexSelectionRule {}
91
92impl Rule<Logical> for IndexSelectionRule {
93    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
94        let logical_scan: &LogicalScan = plan.as_logical_scan()?;
95        let indexes = logical_scan.table_indexes();
96        if indexes.is_empty() {
97            return None;
98        }
99        let primary_table_row_size = TableScanIoEstimator::estimate_row_size(logical_scan);
100        let primary_cost = min(
101            self.estimate_table_scan_cost(logical_scan, primary_table_row_size),
102            self.estimate_full_table_scan_cost(logical_scan, primary_table_row_size),
103        );
104
105        // If it is a primary lookup plan, avoid checking other indexes.
106        if primary_cost.primary_lookup {
107            return None;
108        }
109
110        let mut final_plan: PlanRef = logical_scan.clone().into();
111        let mut min_cost = primary_cost.clone();
112
113        for index in indexes {
114            if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
115                let index_cost = self.estimate_table_scan_cost(
116                    &index_scan,
117                    TableScanIoEstimator::estimate_row_size(&index_scan),
118                );
119
120                if index_cost.le(&min_cost) {
121                    min_cost = index_cost;
122                    final_plan = index_scan.into();
123                }
124            } else {
125                // non-covering index selection
126                let (index_lookup, lookup_cost) = self.gen_index_lookup(logical_scan, index);
127                if lookup_cost.le(&min_cost) {
128                    min_cost = lookup_cost;
129                    final_plan = index_lookup;
130                }
131            }
132        }
133
134        if let Some((merge_index, merge_index_cost)) = self.index_merge_selection(logical_scan)
135            && merge_index_cost.le(&min_cost)
136        {
137            min_cost = merge_index_cost;
138            final_plan = merge_index;
139        }
140
141        if min_cost == primary_cost {
142            None
143        } else {
144            Some(final_plan)
145        }
146    }
147}
148
149struct IndexPredicateRewriter<'a> {
150    p2s_mapping: &'a BTreeMap<usize, usize>,
151    function_mapping: &'a HashMap<FunctionCall, usize>,
152    offset: usize,
153    covered_by_index: bool,
154}
155
156impl<'a> IndexPredicateRewriter<'a> {
157    fn new(
158        p2s_mapping: &'a BTreeMap<usize, usize>,
159        function_mapping: &'a HashMap<FunctionCall, usize>,
160        offset: usize,
161    ) -> Self {
162        Self {
163            p2s_mapping,
164            function_mapping,
165            offset,
166            covered_by_index: true,
167        }
168    }
169
170    fn covered_by_index(&self) -> bool {
171        self.covered_by_index
172    }
173}
174
175impl ExprRewriter for IndexPredicateRewriter<'_> {
176    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
177        // transform primary predicate to index predicate if it can
178        if self.p2s_mapping.contains_key(&input_ref.index) {
179            InputRef::new(
180                *self.p2s_mapping.get(&input_ref.index()).unwrap(),
181                input_ref.return_type(),
182            )
183            .into()
184        } else {
185            self.covered_by_index = false;
186            InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
187        }
188    }
189
190    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
191        if let Some(index) = self.function_mapping.get(&func_call) {
192            return InputRef::new(*index, func_call.return_type()).into();
193        }
194
195        let (func_type, inputs, ret) = func_call.decompose();
196        let inputs = inputs
197            .into_iter()
198            .map(|expr| self.rewrite_expr(expr))
199            .collect();
200        FunctionCall::new_unchecked(func_type, inputs, ret).into()
201    }
202}
203
204impl IndexSelectionRule {
205    fn gen_index_lookup(
206        &self,
207        logical_scan: &LogicalScan,
208        index: &TableIndex,
209    ) -> (PlanRef, IndexCost) {
210        // 1. logical_scan ->  logical_join
211        //                      /        \
212        //                index_scan   primary_table_scan
213        let index_scan = LogicalScan::create(
214            index.index_table.clone(),
215            logical_scan.ctx(),
216            logical_scan.as_of().clone(),
217        );
218        // We use `schema.len` instead of `index_item.len` here,
219        // because schema contains system columns like `_rw_timestamp` column which is not represented in the index item.
220        let offset = index_scan.table().columns().len();
221
222        let primary_table_scan = LogicalScan::create(
223            index.primary_table.clone(),
224            logical_scan.ctx(),
225            logical_scan.as_of().clone(),
226        );
227
228        let predicate = logical_scan.predicate().clone();
229        let mut rewriter = IndexPredicateRewriter::new(
230            index.primary_to_secondary_mapping(),
231            index.function_mapping(),
232            offset,
233        );
234        let new_predicate = predicate.rewrite_expr(&mut rewriter);
235
236        let conjunctions = index
237            .primary_table_pk_ref_to_index_table()
238            .iter()
239            .zip_eq_fast(index.primary_table.pk.iter())
240            .map(|(x, y)| {
241                Self::create_null_safe_equal_expr(
242                    x.column_index,
243                    index.index_table.columns[x.column_index]
244                        .data_type()
245                        .clone(),
246                    y.column_index + offset,
247                    index.primary_table.columns[y.column_index]
248                        .data_type()
249                        .clone(),
250                )
251            })
252            .chain(new_predicate)
253            .collect_vec();
254        let on = Condition { conjunctions };
255        let join: PlanRef = LogicalJoin::new(
256            index_scan.into(),
257            primary_table_scan.into(),
258            JoinType::Inner,
259            on,
260        )
261        .into();
262
263        // 2. push down predicate, so we can calculate the cost of index lookup
264        let join_ref = join.predicate_pushdown(
265            Condition::true_cond(),
266            &mut PredicatePushdownContext::new(join.clone()),
267        );
268
269        let join_with_predicate_push_down =
270            join_ref.as_logical_join().expect("must be a logical join");
271        let new_join_left = join_with_predicate_push_down.left();
272        let index_scan_with_predicate: &LogicalScan = new_join_left
273            .as_logical_scan()
274            .expect("must be a logical scan");
275
276        // 3. calculate index cost, index lookup use primary table to estimate row size.
277        let index_cost = self.estimate_table_scan_cost(
278            index_scan_with_predicate,
279            TableScanIoEstimator::estimate_row_size(logical_scan),
280        );
281        // lookup cost = index cost * LOOKUP_COST_CONST
282        let lookup_cost = index_cost.mul(&IndexCost::new(LOOKUP_COST_CONST, false));
283
284        // 4. keep the same schema with original logical_scan
285        let scan_output_col_idx = logical_scan.output_col_idx();
286        let lookup_join = join_ref.prune_col(
287            &scan_output_col_idx
288                .iter()
289                .map(|&col_idx| col_idx + offset)
290                .collect_vec(),
291            &mut ColumnPruningContext::new(join_ref.clone()),
292        );
293
294        (lookup_join, lookup_cost)
295    }
296
297    /// Index Merge Selection
298    /// Deal with predicate like a = 1 or b = 1
299    /// Merge index scans from a table, currently merge is union semantic.
300    fn index_merge_selection(&self, logical_scan: &LogicalScan) -> Option<(PlanRef, IndexCost)> {
301        let predicate = logical_scan.predicate().clone();
302        // Index merge is kind of index lookup join so use primary table row size to estimate index
303        // cost.
304        let primary_table_row_size = TableScanIoEstimator::estimate_row_size(logical_scan);
305        // 1. choose lowest cost index merge path
306        let paths = self.gen_paths(
307            &predicate.conjunctions,
308            logical_scan,
309            primary_table_row_size,
310        );
311        let (index_access, index_access_cost) =
312            self.choose_min_cost_path(&paths, primary_table_row_size)?;
313
314        // 2. lookup primary table
315        // the schema of index_access is the order key of primary table .
316        let schema: &Schema = index_access.schema();
317        let index_access_len = schema.len();
318
319        let mut shift_input_ref_rewriter = ShiftInputRefRewriter {
320            offset: index_access_len,
321        };
322        let new_predicate = predicate.rewrite_expr(&mut shift_input_ref_rewriter);
323
324        let primary_table = logical_scan.table();
325
326        let primary_table_scan = LogicalScan::create(
327            logical_scan.table().clone(),
328            logical_scan.ctx(),
329            logical_scan.as_of().clone(),
330        );
331
332        let conjunctions = primary_table
333            .pk
334            .iter()
335            .enumerate()
336            .map(|(x, y)| {
337                Self::create_null_safe_equal_expr(
338                    x,
339                    schema.fields[x].data_type.clone(),
340                    y.column_index + index_access_len,
341                    primary_table.columns[y.column_index].data_type.clone(),
342                )
343            })
344            .chain(new_predicate)
345            .collect_vec();
346
347        let on = Condition { conjunctions };
348        let join: PlanRef =
349            LogicalJoin::new(index_access, primary_table_scan.into(), JoinType::Inner, on).into();
350
351        // 3 push down predicate
352        let join_ref = join.predicate_pushdown(
353            Condition::true_cond(),
354            &mut PredicatePushdownContext::new(join.clone()),
355        );
356
357        // 4. keep the same schema with original logical_scan
358        let scan_output_col_idx = logical_scan.output_col_idx();
359        let lookup_join = join_ref.prune_col(
360            &scan_output_col_idx
361                .iter()
362                .map(|&col_idx| col_idx + index_access_len)
363                .collect_vec(),
364            &mut ColumnPruningContext::new(join_ref.clone()),
365        );
366
367        Some((
368            lookup_join,
369            index_access_cost.mul(&IndexCost::new(LOOKUP_COST_CONST, false)),
370        ))
371    }
372
373    /// Generate possible paths that can be used to access.
374    /// The schema of output is the order key of primary table, so it can be used to lookup primary
375    /// table later.
376    /// Method `gen_paths` handles the complex condition recursively which may contains nested `AND`
377    /// and `OR`. However, Method `gen_index_path` handles one arm of an OR clause which is a
378    /// basic unit for index selection.
379    fn gen_paths(
380        &self,
381        conjunctions: &[ExprImpl],
382        logical_scan: &LogicalScan,
383        primary_table_row_size: usize,
384    ) -> Vec<PlanRef> {
385        let mut result = vec![];
386        for expr in conjunctions {
387            // it's OR clause!
388            if let ExprImpl::FunctionCall(function_call) = expr
389                && function_call.func_type() == ExprType::Or
390            {
391                let mut index_to_be_merged = vec![];
392
393                let disjunctions = to_disjunctions(expr.clone());
394                let (map, others) = self.clustering_disjunction(disjunctions);
395                let iter = map
396                    .into_iter()
397                    .map(|(column_index, expr)| (Some(column_index), expr))
398                    .chain(others.into_iter().map(|expr| (None, expr)));
399                for (column_index, expr) in iter {
400                    let mut index_paths = vec![];
401                    let conjunctions = to_conjunctions(expr);
402                    index_paths.extend(
403                        self.gen_index_path(column_index, &conjunctions, logical_scan)
404                            .into_iter(),
405                    );
406                    // complex condition, recursively gen paths
407                    if conjunctions.len() > 1 {
408                        index_paths.extend(
409                            self.gen_paths(&conjunctions, logical_scan, primary_table_row_size)
410                                .into_iter(),
411                        );
412                    }
413
414                    match self.choose_min_cost_path(&index_paths, primary_table_row_size) {
415                        None => {
416                            // One arm of OR clause can't use index, bail out
417                            index_to_be_merged.clear();
418                            break;
419                        }
420                        Some((path, _)) => index_to_be_merged.push(path),
421                    }
422                }
423
424                if let Some(path) = self.merge(index_to_be_merged) {
425                    result.push(path)
426                }
427            }
428        }
429
430        result
431    }
432
433    /// Clustering disjunction or expr by column index. If expr is complex, classify them as others.
434    ///
435    /// a = 1, b = 2, b = 3 -> map: [a, (a = 1)], [b, (b = 2 or b = 3)], others: []
436    ///
437    /// a = 1, (b = 2 and c = 3) -> map: [a, (a = 1)], others:
438    ///
439    /// (a > 1 and a < 8) or (c > 1 and c < 8)
440    /// -> map: [], others: [(a > 1 and a < 8), (c > 1 and c < 8)]
441    fn clustering_disjunction(
442        &self,
443        disjunctions: Vec<ExprImpl>,
444    ) -> (HashMap<usize, ExprImpl>, Vec<ExprImpl>) {
445        let mut map: HashMap<usize, ExprImpl> = HashMap::new();
446        let mut others = vec![];
447        for expr in disjunctions {
448            let idx = {
449                if let Some((input_ref, _const_expr)) = expr.as_eq_const() {
450                    Some(input_ref.index)
451                } else if let Some((input_ref, _in_const_list)) = expr.as_in_const_list() {
452                    Some(input_ref.index)
453                } else if let Some((input_ref, _op, _const_expr)) = expr.as_comparison_const() {
454                    Some(input_ref.index)
455                } else {
456                    None
457                }
458            };
459
460            if let Some(idx) = idx {
461                match map.entry(idx) {
462                    Occupied(mut entry) => {
463                        let expr2: ExprImpl = entry.get().to_owned();
464                        let or_expr = ExprImpl::FunctionCall(
465                            FunctionCall::new_unchecked(
466                                ExprType::Or,
467                                vec![expr, expr2],
468                                DataType::Boolean,
469                            )
470                            .into(),
471                        );
472                        entry.insert(or_expr);
473                    }
474                    Vacant(entry) => {
475                        entry.insert(expr);
476                    }
477                };
478            } else {
479                others.push(expr);
480                continue;
481            }
482        }
483
484        (map, others)
485    }
486
487    /// Given a conjunctions from one arm of an OR clause (basic unit to index selection), generate
488    /// all matching index path (including primary index) for the relation.
489    /// `column_index` (refers to primary table) is a hint can be used to prune index.
490    /// Steps:
491    /// 1. Take the combination of `conjunctions` to extract the potential clauses.
492    /// 2. For each potential clauses, generate index path if it can.
493    fn gen_index_path(
494        &self,
495        column_index: Option<usize>,
496        conjunctions: &[ExprImpl],
497        logical_scan: &LogicalScan,
498    ) -> Vec<PlanRef> {
499        // Assumption: use at most `MAX_COMBINATION_SIZE` clauses, we can determine which is the
500        // best index.
501        let mut combinations = vec![];
502        for i in 1..min(conjunctions.len(), MAX_COMBINATION_SIZE) + 1 {
503            combinations.extend(
504                conjunctions
505                    .iter()
506                    .take(min(conjunctions.len(), MAX_CONJUNCTION_SIZE))
507                    .combinations(i),
508            );
509        }
510
511        let mut result = vec![];
512
513        for index in logical_scan.table_indexes() {
514            if let Some(column_index) = column_index {
515                assert_eq!(conjunctions.len(), 1);
516                let p2s_mapping = index.primary_to_secondary_mapping();
517                match p2s_mapping.get(&column_index) {
518                    None => continue, // not found, prune this index
519                    Some(&idx) => {
520                        if index.index_table.pk()[0].column_index != idx {
521                            // not match, prune this index
522                            continue;
523                        }
524                    }
525                }
526            }
527
528            // try secondary index
529            for conj in &combinations {
530                let condition = Condition {
531                    conjunctions: conj.iter().map(|&x| x.to_owned()).collect(),
532                };
533                if let Some(index_access) = self.build_index_access(
534                    index.clone(),
535                    condition,
536                    logical_scan.ctx().clone(),
537                    logical_scan.as_of().clone(),
538                ) {
539                    result.push(index_access);
540                }
541            }
542        }
543
544        // try primary index
545        let primary_table = logical_scan.table();
546        if let Some(idx) = column_index {
547            assert_eq!(conjunctions.len(), 1);
548            if primary_table.pk[0].column_index != idx {
549                return result;
550            }
551        }
552
553        let primary_access = generic::TableScan::new(
554            primary_table
555                .pk
556                .iter()
557                .map(|x| x.column_index)
558                .collect_vec(),
559            logical_scan.table().clone(),
560            vec![],
561            vec![],
562            logical_scan.ctx(),
563            Condition {
564                conjunctions: conjunctions.to_vec(),
565            },
566            logical_scan.as_of().clone(),
567        );
568
569        result.push(primary_access.into());
570
571        result
572    }
573
574    /// build index access if predicate (refers to primary table) is covered by index
575    fn build_index_access(
576        &self,
577        index: Arc<TableIndex>,
578        predicate: Condition,
579        ctx: OptimizerContextRef,
580        as_of: Option<AsOf>,
581    ) -> Option<PlanRef> {
582        let mut rewriter = IndexPredicateRewriter::new(
583            index.primary_to_secondary_mapping(),
584            index.function_mapping(),
585            0,
586        );
587        let new_predicate = predicate.rewrite_expr(&mut rewriter);
588
589        // check condition is covered by index.
590        if !rewriter.covered_by_index() {
591            return None;
592        }
593
594        Some(
595            generic::TableScan::new(
596                index
597                    .primary_table_pk_ref_to_index_table()
598                    .iter()
599                    .map(|x| x.column_index)
600                    .collect_vec(),
601                index.index_table.clone(),
602                vec![],
603                vec![],
604                ctx,
605                new_predicate,
606                as_of,
607            )
608            .into(),
609        )
610    }
611
612    fn merge(&self, paths: Vec<PlanRef>) -> Option<PlanRef> {
613        if paths.is_empty() {
614            return None;
615        }
616
617        let new_paths = paths
618            .iter()
619            .flat_map(|path| {
620                if let Some(union) = path.as_logical_union() {
621                    union.inputs().to_vec()
622                } else if let Some(_scan) = path.as_logical_scan() {
623                    vec![path.clone()]
624                } else {
625                    unreachable!();
626                }
627            })
628            .sorted_by(|a, b| {
629                // sort inputs to make plan deterministic
630                a.as_logical_scan()
631                    .expect("expect to be a logical scan")
632                    .table_name()
633                    .cmp(
634                        b.as_logical_scan()
635                            .expect("expect to be a logical scan")
636                            .table_name(),
637                    )
638            })
639            .collect_vec();
640
641        Some(LogicalUnion::create(false, new_paths))
642    }
643
644    fn choose_min_cost_path(
645        &self,
646        paths: &[PlanRef],
647        primary_table_row_size: usize,
648    ) -> Option<(PlanRef, IndexCost)> {
649        paths
650            .iter()
651            .map(|path| {
652                if let Some(scan) = path.as_logical_scan() {
653                    let cost = self.estimate_table_scan_cost(scan, primary_table_row_size);
654                    (scan.clone().into(), cost)
655                } else if let Some(union) = path.as_logical_union() {
656                    let cost = union
657                        .inputs()
658                        .iter()
659                        .map(|input| {
660                            self.estimate_table_scan_cost(
661                                input.as_logical_scan().expect("expect to be a scan"),
662                                primary_table_row_size,
663                            )
664                        })
665                        .reduce(|a, b| a.add(&b))
666                        .unwrap();
667                    (union.clone().into(), cost)
668                } else {
669                    unreachable!()
670                }
671            })
672            .min_by(|(_, cost1), (_, cost2)| Ord::cmp(cost1, cost2))
673    }
674
675    fn estimate_table_scan_cost(&self, scan: &LogicalScan, row_size: usize) -> IndexCost {
676        let mut table_scan_io_estimator = TableScanIoEstimator::new(scan, row_size);
677        table_scan_io_estimator.estimate(scan.predicate())
678    }
679
680    fn estimate_full_table_scan_cost(&self, scan: &LogicalScan, row_size: usize) -> IndexCost {
681        let mut table_scan_io_estimator = TableScanIoEstimator::new(scan, row_size);
682        table_scan_io_estimator.estimate(&Condition::true_cond())
683    }
684
685    pub fn create_null_safe_equal_expr(
686        left: usize,
687        left_data_type: DataType,
688        right: usize,
689        right_data_type: DataType,
690    ) -> ExprImpl {
691        ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
692            ExprType::IsNotDistinctFrom,
693            vec![
694                ExprImpl::InputRef(Box::new(InputRef::new(left, left_data_type))),
695                ExprImpl::InputRef(Box::new(InputRef::new(right, right_data_type))),
696            ],
697            DataType::Boolean,
698        )))
699    }
700}
701
702struct TableScanIoEstimator<'a> {
703    table_scan: &'a LogicalScan,
704    row_size: usize,
705    cost: Option<IndexCost>,
706}
707
708impl<'a> TableScanIoEstimator<'a> {
709    pub fn new(table_scan: &'a LogicalScan, row_size: usize) -> Self {
710        Self {
711            table_scan,
712            row_size,
713            cost: None,
714        }
715    }
716
717    pub fn estimate_row_size(table_scan: &LogicalScan) -> usize {
718        // 5 for table_id + 1 for vnode + 8 for epoch
719        let row_meta_field_estimate_size = 14_usize;
720        let table = table_scan.table();
721        row_meta_field_estimate_size
722            + table
723                .columns
724                .iter()
725                // add order key twice for its appearance both in key and value
726                .chain(table.pk.iter().map(|x| &table.columns[x.column_index]))
727                .map(|x| TableScanIoEstimator::estimate_data_type_size(&x.data_type))
728                .sum::<usize>()
729    }
730
731    fn estimate_data_type_size(data_type: &DataType) -> usize {
732        use std::mem::size_of;
733
734        match data_type {
735            DataType::Boolean => size_of::<bool>(),
736            DataType::Int16 => size_of::<i16>(),
737            DataType::Int32 => size_of::<i32>(),
738            DataType::Int64 => size_of::<i64>(),
739            DataType::Serial => size_of::<Serial>(),
740            DataType::Float32 => size_of::<f32>(),
741            DataType::Float64 => size_of::<f64>(),
742            DataType::Decimal => size_of::<Decimal>(),
743            DataType::Date => size_of::<Date>(),
744            DataType::Time => size_of::<Time>(),
745            DataType::Timestamp => size_of::<Timestamp>(),
746            DataType::Timestamptz => size_of::<Timestamptz>(),
747            DataType::Interval => size_of::<Interval>(),
748            DataType::Int256 => Int256::size(),
749            DataType::Varchar => 20,
750            DataType::Bytea => 20,
751            DataType::Jsonb => 20,
752            DataType::Struct { .. } => 20,
753            DataType::List { .. } => 20,
754            DataType::Map(_) => 20,
755            DataType::Vector(d) => d * size_of::<VectorDistanceType>(),
756        }
757    }
758
759    pub fn estimate(&mut self, predicate: &Condition) -> IndexCost {
760        // try to deal with OR condition
761        if predicate.conjunctions.len() == 1 {
762            self.visit_expr(&predicate.conjunctions[0]);
763            self.cost.take().unwrap_or_default()
764        } else {
765            self.estimate_conjunctions(&predicate.conjunctions)
766        }
767    }
768
769    fn estimate_conjunctions(&mut self, conjunctions: &[ExprImpl]) -> IndexCost {
770        let mut new_conjunctions = conjunctions.to_owned();
771
772        let mut match_item_vec = vec![];
773
774        for column_idx in self.table_scan.table().order_column_indices() {
775            let match_item = self.match_index_column(column_idx, &mut new_conjunctions);
776            // seeing range, we don't need to match anymore.
777            let should_break = match match_item {
778                MatchItem::Equal | MatchItem::In(_) => false,
779                MatchItem::RangeOneSideBound | MatchItem::RangeTwoSideBound | MatchItem::All => {
780                    true
781                }
782            };
783            match_item_vec.push(match_item);
784            if should_break {
785                break;
786            }
787        }
788
789        let index_cost = match_item_vec
790            .iter()
791            .enumerate()
792            .take(INDEX_MAX_LEN)
793            .map(|(i, match_item)| match match_item {
794                MatchItem::Equal => INDEX_COST_MATRIX[0][i],
795                MatchItem::In(num) => min(INDEX_COST_MATRIX[1][i], *num),
796                MatchItem::RangeTwoSideBound => INDEX_COST_MATRIX[2][i],
797                MatchItem::RangeOneSideBound => INDEX_COST_MATRIX[3][i],
798                MatchItem::All => INDEX_COST_MATRIX[4][i],
799            })
800            .reduce(|x, y| x * y)
801            .unwrap();
802
803        // If `index_cost` equals 1, it is a primary lookup
804        let primary_lookup = index_cost == 1;
805
806        IndexCost::new(index_cost, primary_lookup)
807            .mul(&IndexCost::new(self.row_size, primary_lookup))
808    }
809
810    fn match_index_column(
811        &mut self,
812        column_idx: usize,
813        conjunctions: &mut Vec<ExprImpl>,
814    ) -> MatchItem {
815        // Equal
816        for (i, expr) in conjunctions.iter().enumerate() {
817            if let Some((input_ref, _const_expr)) = expr.as_eq_const()
818                && input_ref.index == column_idx
819            {
820                conjunctions.remove(i);
821                return MatchItem::Equal;
822            }
823        }
824
825        // In
826        for (i, expr) in conjunctions.iter().enumerate() {
827            if let Some((input_ref, in_const_list)) = expr.as_in_const_list()
828                && input_ref.index == column_idx
829            {
830                conjunctions.remove(i);
831                return MatchItem::In(in_const_list.len());
832            }
833        }
834
835        // Range
836        let mut left_side_bound = false;
837        let mut right_side_bound = false;
838        let mut i = 0;
839        while i < conjunctions.len() {
840            let expr = &conjunctions[i];
841            if let Some((input_ref, op, _const_expr)) = expr.as_comparison_const()
842                && input_ref.index == column_idx
843            {
844                conjunctions.remove(i);
845                match op {
846                    ExprType::LessThan | ExprType::LessThanOrEqual => right_side_bound = true,
847                    ExprType::GreaterThan | ExprType::GreaterThanOrEqual => left_side_bound = true,
848                    _ => unreachable!(),
849                };
850            } else {
851                i += 1;
852            }
853        }
854
855        if left_side_bound && right_side_bound {
856            MatchItem::RangeTwoSideBound
857        } else if left_side_bound || right_side_bound {
858            MatchItem::RangeOneSideBound
859        } else {
860            MatchItem::All
861        }
862    }
863}
864
/// How a single order-key column is constrained by a predicate, as determined by
/// `match_index_column`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MatchItem {
    /// Matched by `col = const`.
    Equal,
    /// Matched by `col IN (...)`; the payload is the number of listed values.
    In(usize),
    /// Bounded on both sides by constant comparisons (e.g. `col > a AND col < b`).
    RangeTwoSideBound,
    /// Bounded on one side only (e.g. `col > a`).
    RangeOneSideBound,
    /// No usable condition on this column.
    All,
}
872
/// Estimated IO cost of a scan, as produced by `TableScanIoEstimator`.
///
/// Field order matters: the derived `PartialOrd`/`Ord` compare `cost` first and
/// `primary_lookup` second.
#[derive(PartialEq, Eq, Hash, Clone, Debug, PartialOrd, Ord)]
struct IndexCost {
    /// Abstract cost value; `IndexCost::new` saturates it at `IndexCost::maximum()`.
    cost: usize,
    /// Whether the scan is a primary lookup.
    primary_lookup: bool,
}

impl Default for IndexCost {
    /// The worst cost: `IndexCost::maximum()`, not a primary lookup.
    fn default() -> Self {
        IndexCost::new(IndexCost::maximum(), false)
    }
}

impl IndexCost {
    /// Builds a cost, saturating `cost` at `IndexCost::maximum()`.
    fn new(cost: usize, primary_lookup: bool) -> IndexCost {
        Self {
            cost: cost.min(Self::maximum()),
            primary_lookup,
        }
    }

    /// Upper bound for every cost value; also the cost of `IndexCost::default()`.
    fn maximum() -> usize {
        10_000_000
    }

    /// Applies the checked arithmetic `op` to both costs, saturating to `maximum()` on
    /// overflow. The result is a primary lookup only if both operands are.
    fn combine(&self, other: &IndexCost, op: fn(usize, usize) -> Option<usize>) -> IndexCost {
        let combined = op(self.cost, other.cost).unwrap_or(Self::maximum());
        IndexCost::new(combined, self.primary_lookup && other.primary_lookup)
    }

    /// Saturating sum of two costs.
    fn add(&self, other: &IndexCost) -> IndexCost {
        self.combine(other, usize::checked_add)
    }

    /// Saturating product of two costs.
    fn mul(&self, other: &IndexCost) -> IndexCost {
        self.combine(other, usize::checked_mul)
    }

    /// Returns `true` when `self` is strictly cheaper than `other`.
    ///
    /// Despite its name this is a strict `<` on `cost` alone; `primary_lookup` is ignored.
    /// As an inherent method it takes precedence over the derived `PartialOrd::le`.
    fn le(&self, other: &IndexCost) -> bool {
        self.cost < other.cost
    }
}
922
923impl ExprVisitor for TableScanIoEstimator<'_> {
924    fn visit_function_call(&mut self, func_call: &FunctionCall) {
925        let cost = match func_call.func_type() {
926            ExprType::Or => func_call
927                .inputs()
928                .iter()
929                .map(|x| {
930                    let mut estimator = TableScanIoEstimator::new(self.table_scan, self.row_size);
931                    estimator.visit_expr(x);
932                    estimator.cost.take().unwrap_or_default()
933                })
934                .reduce(|x, y| x.add(&y))
935                .unwrap(),
936            ExprType::And => self.estimate_conjunctions(func_call.inputs()),
937            _ => {
938                let single = vec![ExprImpl::FunctionCall(func_call.clone().into())];
939                self.estimate_conjunctions(&single)
940            }
941        };
942        self.cost = Some(cost);
943    }
944}
945
946struct ShiftInputRefRewriter {
947    offset: usize,
948}
949impl ExprRewriter for ShiftInputRefRewriter {
950    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
951        InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
952    }
953}
954
955impl IndexSelectionRule {
956    pub fn create() -> BoxedRule {
957        Box::new(IndexSelectionRule {})
958    }
959}