// risingwave_frontend/optimizer/rule/index_selection_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! # Index selection cost matrix
16//!
17//! |`column_idx`| 0   |  1 | 2  | 3  | 4  | remark |
18//! |-----------|-----|----|----|----|----|---|
19//! |Equal      | 1   | 1  | 1  | 1  | 1  | |
//! |In         | 10  | 8  | 5  | 5  | 5  | capped at the number of items in the `IN` list |
21//! |Range(Two) | 600 | 50 | 20 | 10 | 10 | `RangeTwoSideBound` like a between 1 and 2 |
22//! |Range(One) | 1400| 70 | 25 | 15 | 10 | `RangeOneSideBound` like a > 1, a >= 1, a < 1|
23//! |All        | 4000| 100| 30 | 20 | 10 | |
24//!
25//! ```text
26//! index cost = cost(match type of 0 idx)
27//! * cost(match type of 1 idx)
28//! * ... cost(match type of the last idx)
29//! ```
30//!
31//! ## Example
32//!
33//! Given index order key (a, b, c)
34//!
35//! - For `a = 1 and b = 1 and c = 1`, its cost is 1 = Equal0 * Equal1 * Equal2 = 1
36//! - For `a in (xxx) and b = 1 and c = 1`, its cost is In0 * Equal1 * Equal2 = 10
//! - For `a = 1 and b in (xxx)` (with at least 8 `IN` items), its cost is Equal0 * In1 * All2 = 1 * 8 * 30 = 240
38//! - For `a between xxx and yyy`, its cost is Range(Two)0 = 600
39//! - For `a = 1 and b between xxx and yyy`, its cost is Equal0 * Range(Two)1 = 50
40//! - For `a = 1 and b > 1`, its cost is Equal0 * Range(One)1 = 70
41//! - For `a = 1`, its cost is 100 = Equal0 * All1 = 100
42//! - For no condition, its cost is All0 = 4000
43//!
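//! With the assumption that the most effective part of an index is its prefix,
//! cost decreases as `column_idx` increases.
//!
//! Matching stops at the first order-key column that has a range condition or no
//! condition at all, and the resulting product is multiplied by the estimated row size.
//!
//! For index order key length > 5, we just ignore the rest.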
48
49use std::cmp::min;
50use std::collections::hash_map::Entry::{Occupied, Vacant};
51use std::collections::{BTreeMap, HashMap};
52use std::sync::Arc;
53
54use itertools::Itertools;
55use risingwave_common::catalog::Schema;
56use risingwave_common::types::{
57    DataType, Date, Decimal, Int256, Interval, Serial, Time, Timestamp, Timestamptz,
58};
59use risingwave_common::util::iter_util::ZipEqFast;
60use risingwave_pb::plan_common::JoinType;
61use risingwave_sqlparser::ast::AsOf;
62
63use super::prelude::{PlanRef, *};
64use crate::catalog::index_catalog::TableIndex;
65use crate::expr::{
66    Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, to_conjunctions,
67    to_disjunctions,
68};
69use crate::optimizer::optimizer_context::OptimizerContextRef;
70use crate::optimizer::plan_node::generic::GenericPlanRef;
71use crate::optimizer::plan_node::{
72    ColumnPruningContext, LogicalJoin, LogicalScan, LogicalUnion, PlanTreeNode, PlanTreeNodeBinary,
73    PredicatePushdown, PredicatePushdownContext, generic,
74};
75use crate::utils::Condition;
76
/// Number of leading order-key columns that contribute to a cost estimate.
const INDEX_MAX_LEN: usize = 5;
/// Cost factor per match kind and order-key column position.
///
/// Rows are, in order: `Equal`, `In`, `RangeTwoSideBound`, `RangeOneSideBound`, `All`;
/// columns are order-key positions `0..INDEX_MAX_LEN`. This must agree with the cost matrix
/// in the module documentation.
const INDEX_COST_MATRIX: [[usize; INDEX_MAX_LEN]; 5] = [
    [1, 1, 1, 1, 1],
    [10, 8, 5, 5, 5],
    [600, 50, 20, 10, 10],
    [1400, 70, 25, 15, 10],
    // `All` at column 4 is 10, as in the documented matrix.
    [4000, 100, 30, 20, 10],
];
/// Multiplier applied to an index access cost when the plan must additionally look up the
/// primary table (non-covering index lookup and index merge).
const LOOKUP_COST_CONST: usize = 3;
/// Maximum number of conjuncts combined together when trying an index for one OR arm.
const MAX_COMBINATION_SIZE: usize = 3;
/// Maximum number of leading conjuncts of one OR arm considered for those combinations.
const MAX_CONJUNCTION_SIZE: usize = 8;

/// Optimizer rule that rewrites a `LogicalScan` to read through a secondary index (covering
/// index scan, index lookup join, or index merge) when that is estimated to be cheaper.
pub struct IndexSelectionRule {}
90
91impl Rule<Logical> for IndexSelectionRule {
92    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
93        let logical_scan: &LogicalScan = plan.as_logical_scan()?;
94        let indexes = logical_scan.table_indexes();
95        if indexes.is_empty() {
96            return None;
97        }
98        let primary_table_row_size = TableScanIoEstimator::estimate_row_size(logical_scan);
99        let primary_cost = min(
100            self.estimate_table_scan_cost(logical_scan, primary_table_row_size),
101            self.estimate_full_table_scan_cost(logical_scan, primary_table_row_size),
102        );
103
104        // If it is a primary lookup plan, avoid checking other indexes.
105        if primary_cost.primary_lookup {
106            return None;
107        }
108
109        let mut final_plan: PlanRef = logical_scan.clone().into();
110        let mut min_cost = primary_cost.clone();
111
112        for index in indexes {
113            if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
114                let index_cost = self.estimate_table_scan_cost(
115                    &index_scan,
116                    TableScanIoEstimator::estimate_row_size(&index_scan),
117                );
118
119                if index_cost.le(&min_cost) {
120                    min_cost = index_cost;
121                    final_plan = index_scan.into();
122                }
123            } else {
124                // non-covering index selection
125                let (index_lookup, lookup_cost) = self.gen_index_lookup(logical_scan, index);
126                if lookup_cost.le(&min_cost) {
127                    min_cost = lookup_cost;
128                    final_plan = index_lookup;
129                }
130            }
131        }
132
133        if let Some((merge_index, merge_index_cost)) = self.index_merge_selection(logical_scan)
134            && merge_index_cost.le(&min_cost)
135        {
136            min_cost = merge_index_cost;
137            final_plan = merge_index;
138        }
139
140        if min_cost == primary_cost {
141            None
142        } else {
143            Some(final_plan)
144        }
145    }
146}
147
148struct IndexPredicateRewriter<'a> {
149    p2s_mapping: &'a BTreeMap<usize, usize>,
150    function_mapping: &'a HashMap<FunctionCall, usize>,
151    offset: usize,
152    covered_by_index: bool,
153}
154
155impl<'a> IndexPredicateRewriter<'a> {
156    fn new(
157        p2s_mapping: &'a BTreeMap<usize, usize>,
158        function_mapping: &'a HashMap<FunctionCall, usize>,
159        offset: usize,
160    ) -> Self {
161        Self {
162            p2s_mapping,
163            function_mapping,
164            offset,
165            covered_by_index: true,
166        }
167    }
168
169    fn covered_by_index(&self) -> bool {
170        self.covered_by_index
171    }
172}
173
174impl ExprRewriter for IndexPredicateRewriter<'_> {
175    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
176        // transform primary predicate to index predicate if it can
177        if self.p2s_mapping.contains_key(&input_ref.index) {
178            InputRef::new(
179                *self.p2s_mapping.get(&input_ref.index()).unwrap(),
180                input_ref.return_type(),
181            )
182            .into()
183        } else {
184            self.covered_by_index = false;
185            InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
186        }
187    }
188
189    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
190        if let Some(index) = self.function_mapping.get(&func_call) {
191            return InputRef::new(*index, func_call.return_type()).into();
192        }
193
194        let (func_type, inputs, ret) = func_call.decompose();
195        let inputs = inputs
196            .into_iter()
197            .map(|expr| self.rewrite_expr(expr))
198            .collect();
199        FunctionCall::new_unchecked(func_type, inputs, ret).into()
200    }
201}
202
203impl IndexSelectionRule {
204    fn gen_index_lookup(
205        &self,
206        logical_scan: &LogicalScan,
207        index: &TableIndex,
208    ) -> (PlanRef, IndexCost) {
209        // 1. logical_scan ->  logical_join
210        //                      /        \
211        //                index_scan   primary_table_scan
212        let index_scan = LogicalScan::create(
213            index.index_table.clone(),
214            logical_scan.ctx(),
215            logical_scan.as_of().clone(),
216        );
217        // We use `schema.len` instead of `index_item.len` here,
218        // because schema contains system columns like `_rw_timestamp` column which is not represented in the index item.
219        let offset = index_scan.table().columns().len();
220
221        let primary_table_scan = LogicalScan::create(
222            index.primary_table.clone(),
223            logical_scan.ctx(),
224            logical_scan.as_of().clone(),
225        );
226
227        let predicate = logical_scan.predicate().clone();
228        let mut rewriter = IndexPredicateRewriter::new(
229            index.primary_to_secondary_mapping(),
230            index.function_mapping(),
231            offset,
232        );
233        let new_predicate = predicate.rewrite_expr(&mut rewriter);
234
235        let conjunctions = index
236            .primary_table_pk_ref_to_index_table()
237            .iter()
238            .zip_eq_fast(index.primary_table.pk.iter())
239            .map(|(x, y)| {
240                Self::create_null_safe_equal_expr(
241                    x.column_index,
242                    index.index_table.columns[x.column_index]
243                        .data_type()
244                        .clone(),
245                    y.column_index + offset,
246                    index.primary_table.columns[y.column_index]
247                        .data_type()
248                        .clone(),
249                )
250            })
251            .chain(new_predicate)
252            .collect_vec();
253        let on = Condition { conjunctions };
254        let join: PlanRef = LogicalJoin::new(
255            index_scan.into(),
256            primary_table_scan.into(),
257            JoinType::Inner,
258            on,
259        )
260        .into();
261
262        // 2. push down predicate, so we can calculate the cost of index lookup
263        let join_ref = join.predicate_pushdown(
264            Condition::true_cond(),
265            &mut PredicatePushdownContext::new(join.clone()),
266        );
267
268        let join_with_predicate_push_down =
269            join_ref.as_logical_join().expect("must be a logical join");
270        let new_join_left = join_with_predicate_push_down.left();
271        let index_scan_with_predicate: &LogicalScan = new_join_left
272            .as_logical_scan()
273            .expect("must be a logical scan");
274
275        // 3. calculate index cost, index lookup use primary table to estimate row size.
276        let index_cost = self.estimate_table_scan_cost(
277            index_scan_with_predicate,
278            TableScanIoEstimator::estimate_row_size(logical_scan),
279        );
280        // lookup cost = index cost * LOOKUP_COST_CONST
281        let lookup_cost = index_cost.mul(&IndexCost::new(LOOKUP_COST_CONST, false));
282
283        // 4. keep the same schema with original logical_scan
284        let scan_output_col_idx = logical_scan.output_col_idx();
285        let lookup_join = join_ref.prune_col(
286            &scan_output_col_idx
287                .iter()
288                .map(|&col_idx| col_idx + offset)
289                .collect_vec(),
290            &mut ColumnPruningContext::new(join_ref.clone()),
291        );
292
293        (lookup_join, lookup_cost)
294    }
295
296    /// Index Merge Selection
297    /// Deal with predicate like a = 1 or b = 1
298    /// Merge index scans from a table, currently merge is union semantic.
299    fn index_merge_selection(&self, logical_scan: &LogicalScan) -> Option<(PlanRef, IndexCost)> {
300        let predicate = logical_scan.predicate().clone();
301        // Index merge is kind of index lookup join so use primary table row size to estimate index
302        // cost.
303        let primary_table_row_size = TableScanIoEstimator::estimate_row_size(logical_scan);
304        // 1. choose lowest cost index merge path
305        let paths = self.gen_paths(
306            &predicate.conjunctions,
307            logical_scan,
308            primary_table_row_size,
309        );
310        let (index_access, index_access_cost) =
311            self.choose_min_cost_path(&paths, primary_table_row_size)?;
312
313        // 2. lookup primary table
314        // the schema of index_access is the order key of primary table .
315        let schema: &Schema = index_access.schema();
316        let index_access_len = schema.len();
317
318        let mut shift_input_ref_rewriter = ShiftInputRefRewriter {
319            offset: index_access_len,
320        };
321        let new_predicate = predicate.rewrite_expr(&mut shift_input_ref_rewriter);
322
323        let primary_table = logical_scan.table();
324
325        let primary_table_scan = LogicalScan::create(
326            logical_scan.table().clone(),
327            logical_scan.ctx(),
328            logical_scan.as_of().clone(),
329        );
330
331        let conjunctions = primary_table
332            .pk
333            .iter()
334            .enumerate()
335            .map(|(x, y)| {
336                Self::create_null_safe_equal_expr(
337                    x,
338                    schema.fields[x].data_type.clone(),
339                    y.column_index + index_access_len,
340                    primary_table.columns[y.column_index].data_type.clone(),
341                )
342            })
343            .chain(new_predicate)
344            .collect_vec();
345
346        let on = Condition { conjunctions };
347        let join: PlanRef =
348            LogicalJoin::new(index_access, primary_table_scan.into(), JoinType::Inner, on).into();
349
350        // 3 push down predicate
351        let join_ref = join.predicate_pushdown(
352            Condition::true_cond(),
353            &mut PredicatePushdownContext::new(join.clone()),
354        );
355
356        // 4. keep the same schema with original logical_scan
357        let scan_output_col_idx = logical_scan.output_col_idx();
358        let lookup_join = join_ref.prune_col(
359            &scan_output_col_idx
360                .iter()
361                .map(|&col_idx| col_idx + index_access_len)
362                .collect_vec(),
363            &mut ColumnPruningContext::new(join_ref.clone()),
364        );
365
366        Some((
367            lookup_join,
368            index_access_cost.mul(&IndexCost::new(LOOKUP_COST_CONST, false)),
369        ))
370    }
371
372    /// Generate possible paths that can be used to access.
373    /// The schema of output is the order key of primary table, so it can be used to lookup primary
374    /// table later.
375    /// Method `gen_paths` handles the complex condition recursively which may contains nested `AND`
376    /// and `OR`. However, Method `gen_index_path` handles one arm of an OR clause which is a
377    /// basic unit for index selection.
378    fn gen_paths(
379        &self,
380        conjunctions: &[ExprImpl],
381        logical_scan: &LogicalScan,
382        primary_table_row_size: usize,
383    ) -> Vec<PlanRef> {
384        let mut result = vec![];
385        for expr in conjunctions {
386            // it's OR clause!
387            if let ExprImpl::FunctionCall(function_call) = expr
388                && function_call.func_type() == ExprType::Or
389            {
390                let mut index_to_be_merged = vec![];
391
392                let disjunctions = to_disjunctions(expr.clone());
393                let (map, others) = self.clustering_disjunction(disjunctions);
394                let iter = map
395                    .into_iter()
396                    .map(|(column_index, expr)| (Some(column_index), expr))
397                    .chain(others.into_iter().map(|expr| (None, expr)));
398                for (column_index, expr) in iter {
399                    let mut index_paths = vec![];
400                    let conjunctions = to_conjunctions(expr);
401                    index_paths.extend(
402                        self.gen_index_path(column_index, &conjunctions, logical_scan)
403                            .into_iter(),
404                    );
405                    // complex condition, recursively gen paths
406                    if conjunctions.len() > 1 {
407                        index_paths.extend(
408                            self.gen_paths(&conjunctions, logical_scan, primary_table_row_size)
409                                .into_iter(),
410                        );
411                    }
412
413                    match self.choose_min_cost_path(&index_paths, primary_table_row_size) {
414                        None => {
415                            // One arm of OR clause can't use index, bail out
416                            index_to_be_merged.clear();
417                            break;
418                        }
419                        Some((path, _)) => index_to_be_merged.push(path),
420                    }
421                }
422
423                if let Some(path) = self.merge(index_to_be_merged) {
424                    result.push(path)
425                }
426            }
427        }
428
429        result
430    }
431
432    /// Clustering disjunction or expr by column index. If expr is complex, classify them as others.
433    ///
434    /// a = 1, b = 2, b = 3 -> map: [a, (a = 1)], [b, (b = 2 or b = 3)], others: []
435    ///
436    /// a = 1, (b = 2 and c = 3) -> map: [a, (a = 1)], others:
437    ///
438    /// (a > 1 and a < 8) or (c > 1 and c < 8)
439    /// -> map: [], others: [(a > 1 and a < 8), (c > 1 and c < 8)]
440    fn clustering_disjunction(
441        &self,
442        disjunctions: Vec<ExprImpl>,
443    ) -> (HashMap<usize, ExprImpl>, Vec<ExprImpl>) {
444        let mut map: HashMap<usize, ExprImpl> = HashMap::new();
445        let mut others = vec![];
446        for expr in disjunctions {
447            let idx = {
448                if let Some((input_ref, _const_expr)) = expr.as_eq_const() {
449                    Some(input_ref.index)
450                } else if let Some((input_ref, _in_const_list)) = expr.as_in_const_list() {
451                    Some(input_ref.index)
452                } else if let Some((input_ref, _op, _const_expr)) = expr.as_comparison_const() {
453                    Some(input_ref.index)
454                } else {
455                    None
456                }
457            };
458
459            if let Some(idx) = idx {
460                match map.entry(idx) {
461                    Occupied(mut entry) => {
462                        let expr2: ExprImpl = entry.get().to_owned();
463                        let or_expr = ExprImpl::FunctionCall(
464                            FunctionCall::new_unchecked(
465                                ExprType::Or,
466                                vec![expr, expr2],
467                                DataType::Boolean,
468                            )
469                            .into(),
470                        );
471                        entry.insert(or_expr);
472                    }
473                    Vacant(entry) => {
474                        entry.insert(expr);
475                    }
476                };
477            } else {
478                others.push(expr);
479                continue;
480            }
481        }
482
483        (map, others)
484    }
485
486    /// Given a conjunctions from one arm of an OR clause (basic unit to index selection), generate
487    /// all matching index path (including primary index) for the relation.
488    /// `column_index` (refers to primary table) is a hint can be used to prune index.
489    /// Steps:
490    /// 1. Take the combination of `conjunctions` to extract the potential clauses.
491    /// 2. For each potential clauses, generate index path if it can.
492    fn gen_index_path(
493        &self,
494        column_index: Option<usize>,
495        conjunctions: &[ExprImpl],
496        logical_scan: &LogicalScan,
497    ) -> Vec<PlanRef> {
498        // Assumption: use at most `MAX_COMBINATION_SIZE` clauses, we can determine which is the
499        // best index.
500        let mut combinations = vec![];
501        for i in 1..min(conjunctions.len(), MAX_COMBINATION_SIZE) + 1 {
502            combinations.extend(
503                conjunctions
504                    .iter()
505                    .take(min(conjunctions.len(), MAX_CONJUNCTION_SIZE))
506                    .combinations(i),
507            );
508        }
509
510        let mut result = vec![];
511
512        for index in logical_scan.table_indexes() {
513            if let Some(column_index) = column_index {
514                assert_eq!(conjunctions.len(), 1);
515                let p2s_mapping = index.primary_to_secondary_mapping();
516                match p2s_mapping.get(&column_index) {
517                    None => continue, // not found, prune this index
518                    Some(&idx) => {
519                        if index.index_table.pk()[0].column_index != idx {
520                            // not match, prune this index
521                            continue;
522                        }
523                    }
524                }
525            }
526
527            // try secondary index
528            for conj in &combinations {
529                let condition = Condition {
530                    conjunctions: conj.iter().map(|&x| x.to_owned()).collect(),
531                };
532                if let Some(index_access) = self.build_index_access(
533                    index.clone(),
534                    condition,
535                    logical_scan.ctx().clone(),
536                    logical_scan.as_of().clone(),
537                ) {
538                    result.push(index_access);
539                }
540            }
541        }
542
543        // try primary index
544        let primary_table = logical_scan.table();
545        if let Some(idx) = column_index {
546            assert_eq!(conjunctions.len(), 1);
547            if primary_table.pk[0].column_index != idx {
548                return result;
549            }
550        }
551
552        let primary_access = generic::TableScan::new(
553            primary_table
554                .pk
555                .iter()
556                .map(|x| x.column_index)
557                .collect_vec(),
558            logical_scan.table().clone(),
559            vec![],
560            logical_scan.ctx(),
561            Condition {
562                conjunctions: conjunctions.to_vec(),
563            },
564            logical_scan.as_of().clone(),
565        );
566
567        result.push(primary_access.into());
568
569        result
570    }
571
572    /// build index access if predicate (refers to primary table) is covered by index
573    fn build_index_access(
574        &self,
575        index: Arc<TableIndex>,
576        predicate: Condition,
577        ctx: OptimizerContextRef,
578        as_of: Option<AsOf>,
579    ) -> Option<PlanRef> {
580        let mut rewriter = IndexPredicateRewriter::new(
581            index.primary_to_secondary_mapping(),
582            index.function_mapping(),
583            0,
584        );
585        let new_predicate = predicate.rewrite_expr(&mut rewriter);
586
587        // check condition is covered by index.
588        if !rewriter.covered_by_index() {
589            return None;
590        }
591
592        Some(
593            generic::TableScan::new(
594                index
595                    .primary_table_pk_ref_to_index_table()
596                    .iter()
597                    .map(|x| x.column_index)
598                    .collect_vec(),
599                index.index_table.clone(),
600                vec![],
601                ctx,
602                new_predicate,
603                as_of,
604            )
605            .into(),
606        )
607    }
608
609    fn merge(&self, paths: Vec<PlanRef>) -> Option<PlanRef> {
610        if paths.is_empty() {
611            return None;
612        }
613
614        let new_paths = paths
615            .iter()
616            .flat_map(|path| {
617                if let Some(union) = path.as_logical_union() {
618                    union.inputs().to_vec()
619                } else if let Some(_scan) = path.as_logical_scan() {
620                    vec![path.clone()]
621                } else {
622                    unreachable!();
623                }
624            })
625            .sorted_by(|a, b| {
626                // sort inputs to make plan deterministic
627                a.as_logical_scan()
628                    .expect("expect to be a logical scan")
629                    .table_name()
630                    .cmp(
631                        b.as_logical_scan()
632                            .expect("expect to be a logical scan")
633                            .table_name(),
634                    )
635            })
636            .collect_vec();
637
638        Some(LogicalUnion::create(false, new_paths))
639    }
640
641    fn choose_min_cost_path(
642        &self,
643        paths: &[PlanRef],
644        primary_table_row_size: usize,
645    ) -> Option<(PlanRef, IndexCost)> {
646        paths
647            .iter()
648            .map(|path| {
649                if let Some(scan) = path.as_logical_scan() {
650                    let cost = self.estimate_table_scan_cost(scan, primary_table_row_size);
651                    (scan.clone().into(), cost)
652                } else if let Some(union) = path.as_logical_union() {
653                    let cost = union
654                        .inputs()
655                        .iter()
656                        .map(|input| {
657                            self.estimate_table_scan_cost(
658                                input.as_logical_scan().expect("expect to be a scan"),
659                                primary_table_row_size,
660                            )
661                        })
662                        .reduce(|a, b| a.add(&b))
663                        .unwrap();
664                    (union.clone().into(), cost)
665                } else {
666                    unreachable!()
667                }
668            })
669            .min_by(|(_, cost1), (_, cost2)| Ord::cmp(cost1, cost2))
670    }
671
672    fn estimate_table_scan_cost(&self, scan: &LogicalScan, row_size: usize) -> IndexCost {
673        let mut table_scan_io_estimator = TableScanIoEstimator::new(scan, row_size);
674        table_scan_io_estimator.estimate(scan.predicate())
675    }
676
677    fn estimate_full_table_scan_cost(&self, scan: &LogicalScan, row_size: usize) -> IndexCost {
678        let mut table_scan_io_estimator = TableScanIoEstimator::new(scan, row_size);
679        table_scan_io_estimator.estimate(&Condition::true_cond())
680    }
681
682    fn create_null_safe_equal_expr(
683        left: usize,
684        left_data_type: DataType,
685        right: usize,
686        right_data_type: DataType,
687    ) -> ExprImpl {
688        ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
689            ExprType::IsNotDistinctFrom,
690            vec![
691                ExprImpl::InputRef(Box::new(InputRef::new(left, left_data_type))),
692                ExprImpl::InputRef(Box::new(InputRef::new(right, right_data_type))),
693            ],
694            DataType::Boolean,
695        )))
696    }
697}
698
699struct TableScanIoEstimator<'a> {
700    table_scan: &'a LogicalScan,
701    row_size: usize,
702    cost: Option<IndexCost>,
703}
704
705impl<'a> TableScanIoEstimator<'a> {
706    pub fn new(table_scan: &'a LogicalScan, row_size: usize) -> Self {
707        Self {
708            table_scan,
709            row_size,
710            cost: None,
711        }
712    }
713
714    pub fn estimate_row_size(table_scan: &LogicalScan) -> usize {
715        // 5 for table_id + 1 for vnode + 8 for epoch
716        let row_meta_field_estimate_size = 14_usize;
717        let table = table_scan.table();
718        row_meta_field_estimate_size
719            + table
720                .columns
721                .iter()
722                // add order key twice for its appearance both in key and value
723                .chain(table.pk.iter().map(|x| &table.columns[x.column_index]))
724                .map(|x| TableScanIoEstimator::estimate_data_type_size(&x.data_type))
725                .sum::<usize>()
726    }
727
728    fn estimate_data_type_size(data_type: &DataType) -> usize {
729        use std::mem::size_of;
730
731        match data_type {
732            DataType::Boolean => size_of::<bool>(),
733            DataType::Int16 => size_of::<i16>(),
734            DataType::Int32 => size_of::<i32>(),
735            DataType::Int64 => size_of::<i64>(),
736            DataType::Serial => size_of::<Serial>(),
737            DataType::Float32 => size_of::<f32>(),
738            DataType::Float64 => size_of::<f64>(),
739            DataType::Decimal => size_of::<Decimal>(),
740            DataType::Date => size_of::<Date>(),
741            DataType::Time => size_of::<Time>(),
742            DataType::Timestamp => size_of::<Timestamp>(),
743            DataType::Timestamptz => size_of::<Timestamptz>(),
744            DataType::Interval => size_of::<Interval>(),
745            DataType::Int256 => Int256::size(),
746            DataType::Varchar => 20,
747            DataType::Bytea => 20,
748            DataType::Jsonb => 20,
749            DataType::Struct { .. } => 20,
750            DataType::List { .. } => 20,
751            DataType::Map(_) => 20,
752            DataType::Vector(_) => todo!("VECTOR_PLACEHOLDER"),
753        }
754    }
755
756    pub fn estimate(&mut self, predicate: &Condition) -> IndexCost {
757        // try to deal with OR condition
758        if predicate.conjunctions.len() == 1 {
759            self.visit_expr(&predicate.conjunctions[0]);
760            self.cost.take().unwrap_or_default()
761        } else {
762            self.estimate_conjunctions(&predicate.conjunctions)
763        }
764    }
765
766    fn estimate_conjunctions(&mut self, conjunctions: &[ExprImpl]) -> IndexCost {
767        let mut new_conjunctions = conjunctions.to_owned();
768
769        let mut match_item_vec = vec![];
770
771        for column_idx in self.table_scan.table().order_column_indices() {
772            let match_item = self.match_index_column(column_idx, &mut new_conjunctions);
773            // seeing range, we don't need to match anymore.
774            let should_break = match match_item {
775                MatchItem::Equal | MatchItem::In(_) => false,
776                MatchItem::RangeOneSideBound | MatchItem::RangeTwoSideBound | MatchItem::All => {
777                    true
778                }
779            };
780            match_item_vec.push(match_item);
781            if should_break {
782                break;
783            }
784        }
785
786        let index_cost = match_item_vec
787            .iter()
788            .enumerate()
789            .take(INDEX_MAX_LEN)
790            .map(|(i, match_item)| match match_item {
791                MatchItem::Equal => INDEX_COST_MATRIX[0][i],
792                MatchItem::In(num) => min(INDEX_COST_MATRIX[1][i], *num),
793                MatchItem::RangeTwoSideBound => INDEX_COST_MATRIX[2][i],
794                MatchItem::RangeOneSideBound => INDEX_COST_MATRIX[3][i],
795                MatchItem::All => INDEX_COST_MATRIX[4][i],
796            })
797            .reduce(|x, y| x * y)
798            .unwrap();
799
800        // If `index_cost` equals 1, it is a primary lookup
801        let primary_lookup = index_cost == 1;
802
803        IndexCost::new(index_cost, primary_lookup)
804            .mul(&IndexCost::new(self.row_size, primary_lookup))
805    }
806
807    fn match_index_column(
808        &mut self,
809        column_idx: usize,
810        conjunctions: &mut Vec<ExprImpl>,
811    ) -> MatchItem {
812        // Equal
813        for (i, expr) in conjunctions.iter().enumerate() {
814            if let Some((input_ref, _const_expr)) = expr.as_eq_const()
815                && input_ref.index == column_idx
816            {
817                conjunctions.remove(i);
818                return MatchItem::Equal;
819            }
820        }
821
822        // In
823        for (i, expr) in conjunctions.iter().enumerate() {
824            if let Some((input_ref, in_const_list)) = expr.as_in_const_list()
825                && input_ref.index == column_idx
826            {
827                conjunctions.remove(i);
828                return MatchItem::In(in_const_list.len());
829            }
830        }
831
832        // Range
833        let mut left_side_bound = false;
834        let mut right_side_bound = false;
835        let mut i = 0;
836        while i < conjunctions.len() {
837            let expr = &conjunctions[i];
838            if let Some((input_ref, op, _const_expr)) = expr.as_comparison_const()
839                && input_ref.index == column_idx
840            {
841                conjunctions.remove(i);
842                match op {
843                    ExprType::LessThan | ExprType::LessThanOrEqual => right_side_bound = true,
844                    ExprType::GreaterThan | ExprType::GreaterThanOrEqual => left_side_bound = true,
845                    _ => unreachable!(),
846                };
847            } else {
848                i += 1;
849            }
850        }
851
852        if left_side_bound && right_side_bound {
853            MatchItem::RangeTwoSideBound
854        } else if left_side_bound || right_side_bound {
855            MatchItem::RangeOneSideBound
856        } else {
857            MatchItem::All
858        }
859    }
860}
861
/// How the predicates constrain a single order-key column, as classified by
/// `match_index_column`. Each variant selects one row of `INDEX_COST_MATRIX`
/// in `estimate_conjunctions`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MatchItem {
    /// The column is compared for equality with a constant.
    Equal,
    /// The column is tested against a constant `IN` list; carries the list length.
    In(usize),
    /// The column has both a lower (`>`/`>=`) and an upper (`<`/`<=`) constant bound.
    RangeTwoSideBound,
    /// The column has only a lower or only an upper constant bound.
    RangeOneSideBound,
    /// No matched predicate constrains the column.
    All,
}
869
/// Estimated I/O cost of a scan, plus whether the scan is a primary lookup.
///
/// The derived ordering compares `cost` first and then `primary_lookup`, because
/// fields are compared in declaration order.
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct IndexCost {
    /// Relative cost value.
    cost: usize,
    /// `true` when the estimate was classified as a primary lookup.
    primary_lookup: bool,
}
875
876impl Default for IndexCost {
877    fn default() -> Self {
878        Self {
879            cost: IndexCost::maximum(),
880            primary_lookup: false,
881        }
882    }
883}
884
885impl IndexCost {
886    fn new(cost: usize, primary_lookup: bool) -> IndexCost {
887        Self {
888            cost: min(cost, IndexCost::maximum()),
889            primary_lookup,
890        }
891    }
892
893    fn maximum() -> usize {
894        10000000
895    }
896
897    fn add(&self, other: &IndexCost) -> IndexCost {
898        IndexCost::new(
899            self.cost
900                .checked_add(other.cost)
901                .unwrap_or_else(IndexCost::maximum),
902            self.primary_lookup && other.primary_lookup,
903        )
904    }
905
906    fn mul(&self, other: &IndexCost) -> IndexCost {
907        IndexCost::new(
908            self.cost
909                .checked_mul(other.cost)
910                .unwrap_or_else(IndexCost::maximum),
911            self.primary_lookup && other.primary_lookup,
912        )
913    }
914
915    fn le(&self, other: &IndexCost) -> bool {
916        self.cost < other.cost
917    }
918}
919
920impl ExprVisitor for TableScanIoEstimator<'_> {
921    fn visit_function_call(&mut self, func_call: &FunctionCall) {
922        let cost = match func_call.func_type() {
923            ExprType::Or => func_call
924                .inputs()
925                .iter()
926                .map(|x| {
927                    let mut estimator = TableScanIoEstimator::new(self.table_scan, self.row_size);
928                    estimator.visit_expr(x);
929                    estimator.cost.take().unwrap_or_default()
930                })
931                .reduce(|x, y| x.add(&y))
932                .unwrap(),
933            ExprType::And => self.estimate_conjunctions(func_call.inputs()),
934            _ => {
935                let single = vec![ExprImpl::FunctionCall(func_call.clone().into())];
936                self.estimate_conjunctions(&single)
937            }
938        };
939        self.cost = Some(cost);
940    }
941}
942
943struct ShiftInputRefRewriter {
944    offset: usize,
945}
946impl ExprRewriter for ShiftInputRefRewriter {
947    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
948        InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
949    }
950}
951
952impl IndexSelectionRule {
953    pub fn create() -> BoxedRule {
954        Box::new(IndexSelectionRule {})
955    }
956}