risingwave_frontend/optimizer/rule/
index_selection_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! # Index selection cost matrix
16//!
17//! |`column_idx`| 0   |  1 | 2  | 3  | 4  | remark |
18//! |-----------|-----|----|----|----|----|---|
19//! |Equal      | 1   | 1  | 1  | 1  | 1  | |
//! |In         | 10  | 8  | 5  | 5  | 5  | the smaller of this value and the number of IN-list items |
21//! |Range(Two) | 600 | 50 | 20 | 10 | 10 | `RangeTwoSideBound` like a between 1 and 2 |
22//! |Range(One) | 1400| 70 | 25 | 15 | 10 | `RangeOneSideBound` like a > 1, a >= 1, a < 1|
23//! |All        | 4000| 100| 30 | 20 | 10 | |
24//!
25//! ```text
26//! index cost = cost(match type of 0 idx)
27//! * cost(match type of 1 idx)
28//! * ... cost(match type of the last idx)
29//! ```
30//!
31//! ## Example
32//!
33//! Given index order key (a, b, c)
34//!
//! - For `a = 1 and b = 1 and c = 1`, its cost is Equal0 * Equal1 * Equal2 = 1
36//! - For `a in (xxx) and b = 1 and c = 1`, its cost is In0 * Equal1 * Equal2 = 10
//! - For `a = 1 and b in (xxx)`, its cost is Equal0 * In1 * All2 = 1 * 8 * 30 = 240
38//! - For `a between xxx and yyy`, its cost is Range(Two)0 = 600
39//! - For `a = 1 and b between xxx and yyy`, its cost is Equal0 * Range(Two)1 = 50
40//! - For `a = 1 and b > 1`, its cost is Equal0 * Range(One)1 = 70
//! - For `a = 1`, its cost is Equal0 * All1 = 100
42//! - For no condition, its cost is All0 = 4000
43//!
//! Assuming that the most selective part of an index is its prefix, the cost of every match
//! type decreases as `column_idx` increases.
46//!
//! Only the first 5 columns of an index order key are considered; the rest are ignored.
48
49use std::cmp::min;
50use std::collections::hash_map::Entry::{Occupied, Vacant};
51use std::collections::{BTreeMap, HashMap};
52use std::sync::Arc;
53
54use itertools::Itertools;
55use risingwave_common::array::VectorDistanceType;
56use risingwave_common::catalog::Schema;
57use risingwave_common::types::{
58    DataType, Date, Decimal, Int256, Interval, Serial, Time, Timestamp, Timestamptz,
59};
60use risingwave_common::util::iter_util::ZipEqFast;
61use risingwave_pb::plan_common::JoinType;
62use risingwave_sqlparser::ast::AsOf;
63
64use super::prelude::{PlanRef, *};
65use crate::catalog::index_catalog::TableIndex;
66use crate::expr::{
67    Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, to_conjunctions,
68    to_disjunctions,
69};
70use crate::optimizer::optimizer_context::OptimizerContextRef;
71use crate::optimizer::plan_node::generic::GenericPlanRef;
72use crate::optimizer::plan_node::{
73    ColumnPruningContext, LogicalJoin, LogicalScan, LogicalUnion, PlanTreeNode, PlanTreeNodeBinary,
74    PredicatePushdown, PredicatePushdownContext, generic,
75};
76use crate::utils::Condition;
77
/// Number of leading order-key columns that take part in cost estimation; key columns beyond
/// this position are ignored.
const INDEX_MAX_LEN: usize = 5;
/// Cost factor of each match type (row) at each order-key position (column).
///
/// Rows, in order: `Equal`, `In`, `Range(Two)`, `Range(One)`, `All`. The values must agree with
/// the cost matrix in the module documentation.
const INDEX_COST_MATRIX: [[usize; INDEX_MAX_LEN]; 5] = [
    [1, 1, 1, 1, 1],
    [10, 8, 5, 5, 5],
    [600, 50, 20, 10, 10],
    [1400, 70, 25, 15, 10],
    // The last entry is 10, as documented in the module-level cost matrix.
    [4000, 100, 30, 20, 10],
];
/// Multiplier applied to an index scan's cost when every hit must additionally be looked up in
/// the primary table (non-covering index lookup and index merge).
const LOOKUP_COST_CONST: usize = 3;
/// Maximum number of clauses combined into one candidate index predicate.
const MAX_COMBINATION_SIZE: usize = 4;
/// Maximum number of leading clauses considered when forming candidate combinations.
const MAX_CONJUNCTION_SIZE: usize = 8;
89
90pub struct IndexSelectionRule {}
91
92impl Rule<Logical> for IndexSelectionRule {
93    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
94        let logical_scan: &LogicalScan = plan.as_logical_scan()?;
95        let indexes = logical_scan.table_indexes();
96        if indexes.is_empty() {
97            return None;
98        }
99        let primary_table_row_size = TableScanIoEstimator::estimate_row_size(logical_scan);
100        let primary_cost = min(
101            self.estimate_table_scan_cost(logical_scan, primary_table_row_size),
102            self.estimate_full_table_scan_cost(logical_scan, primary_table_row_size),
103        );
104
105        // If it is a primary lookup plan, avoid checking other indexes.
106        if primary_cost.primary_lookup {
107            return None;
108        }
109
110        let mut final_plan: PlanRef = logical_scan.clone().into();
111        let mut min_cost = primary_cost.clone();
112
113        for index in indexes {
114            if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
115                let index_cost = self.estimate_table_scan_cost(
116                    &index_scan,
117                    TableScanIoEstimator::estimate_row_size(&index_scan),
118                );
119
120                if index_cost.le(&min_cost) {
121                    min_cost = index_cost;
122                    final_plan = index_scan.into();
123                }
124            } else {
125                // non-covering index selection
126                let (index_lookup, lookup_cost) = self.gen_index_lookup(logical_scan, index);
127                if lookup_cost.le(&min_cost) {
128                    min_cost = lookup_cost;
129                    final_plan = index_lookup;
130                }
131            }
132        }
133
134        if let Some((merge_index, merge_index_cost)) = self.index_merge_selection(logical_scan)
135            && merge_index_cost.le(&min_cost)
136        {
137            min_cost = merge_index_cost;
138            final_plan = merge_index;
139        }
140
141        if min_cost == primary_cost {
142            None
143        } else {
144            Some(final_plan)
145        }
146    }
147}
148
149struct IndexPredicateRewriter<'a> {
150    p2s_mapping: &'a BTreeMap<usize, usize>,
151    function_mapping: &'a HashMap<FunctionCall, usize>,
152    offset: usize,
153    covered_by_index: bool,
154}
155
156impl<'a> IndexPredicateRewriter<'a> {
157    fn new(
158        p2s_mapping: &'a BTreeMap<usize, usize>,
159        function_mapping: &'a HashMap<FunctionCall, usize>,
160        offset: usize,
161    ) -> Self {
162        Self {
163            p2s_mapping,
164            function_mapping,
165            offset,
166            covered_by_index: true,
167        }
168    }
169
170    fn covered_by_index(&self) -> bool {
171        self.covered_by_index
172    }
173}
174
175impl ExprRewriter for IndexPredicateRewriter<'_> {
176    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
177        // transform primary predicate to index predicate if it can
178        if self.p2s_mapping.contains_key(&input_ref.index) {
179            InputRef::new(
180                *self.p2s_mapping.get(&input_ref.index()).unwrap(),
181                input_ref.return_type(),
182            )
183            .into()
184        } else {
185            self.covered_by_index = false;
186            InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
187        }
188    }
189
190    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
191        if let Some(index) = self.function_mapping.get(&func_call) {
192            return InputRef::new(*index, func_call.return_type()).into();
193        }
194
195        let (func_type, inputs, ret) = func_call.decompose();
196        let inputs = inputs
197            .into_iter()
198            .map(|expr| self.rewrite_expr(expr))
199            .collect();
200        FunctionCall::new_unchecked(func_type, inputs, ret).into()
201    }
202}
203
204impl IndexSelectionRule {
205    fn gen_index_lookup(
206        &self,
207        logical_scan: &LogicalScan,
208        index: &TableIndex,
209    ) -> (PlanRef, IndexCost) {
210        // 1. logical_scan ->  logical_join
211        //                      /        \
212        //                index_scan   primary_table_scan
213        let index_scan = LogicalScan::create(
214            index.index_table.clone(),
215            logical_scan.ctx(),
216            logical_scan.as_of().clone(),
217        );
218        // We use `schema.len` instead of `index_item.len` here,
219        // because schema contains system columns like `_rw_timestamp` column which is not represented in the index item.
220        let offset = index_scan.table().columns().len();
221
222        let primary_table_scan = LogicalScan::create(
223            index.primary_table.clone(),
224            logical_scan.ctx(),
225            logical_scan.as_of().clone(),
226        );
227
228        let predicate = logical_scan.predicate().clone();
229        let mut rewriter = IndexPredicateRewriter::new(
230            index.primary_to_secondary_mapping(),
231            index.function_mapping(),
232            offset,
233        );
234        let new_predicate = predicate.rewrite_expr(&mut rewriter);
235
236        let conjunctions = index
237            .primary_table_pk_ref_to_index_table()
238            .iter()
239            .zip_eq_fast(index.primary_table.pk.iter())
240            .map(|(x, y)| {
241                Self::create_null_safe_equal_expr(
242                    x.column_index,
243                    index.index_table.columns[x.column_index]
244                        .data_type()
245                        .clone(),
246                    y.column_index + offset,
247                    index.primary_table.columns[y.column_index]
248                        .data_type()
249                        .clone(),
250                )
251            })
252            .chain(new_predicate)
253            .collect_vec();
254        let on = Condition { conjunctions };
255        let join: PlanRef = LogicalJoin::new(
256            index_scan.into(),
257            primary_table_scan.into(),
258            JoinType::Inner,
259            on,
260        )
261        .into();
262
263        // 2. push down predicate, so we can calculate the cost of index lookup
264        let join_ref = join.predicate_pushdown(
265            Condition::true_cond(),
266            &mut PredicatePushdownContext::new(join.clone()),
267        );
268
269        let join_with_predicate_push_down =
270            join_ref.as_logical_join().expect("must be a logical join");
271        let new_join_left = join_with_predicate_push_down.left();
272        let index_scan_with_predicate: &LogicalScan = new_join_left
273            .as_logical_scan()
274            .expect("must be a logical scan");
275
276        // 3. calculate index cost, index lookup use primary table to estimate row size.
277        let index_cost = self.estimate_table_scan_cost(
278            index_scan_with_predicate,
279            TableScanIoEstimator::estimate_row_size(logical_scan),
280        );
281        // lookup cost = index cost * LOOKUP_COST_CONST
282        let lookup_cost = index_cost.mul(&IndexCost::new(LOOKUP_COST_CONST, false));
283
284        // 4. keep the same schema with original logical_scan
285        let scan_output_col_idx = logical_scan.output_col_idx();
286        let lookup_join = join_ref.prune_col(
287            &scan_output_col_idx
288                .iter()
289                .map(|&col_idx| col_idx + offset)
290                .collect_vec(),
291            &mut ColumnPruningContext::new(join_ref.clone()),
292        );
293
294        (lookup_join, lookup_cost)
295    }
296
297    /// Index Merge Selection
298    /// Deal with predicate like a = 1 or b = 1
299    /// Merge index scans from a table, currently merge is union semantic.
300    fn index_merge_selection(&self, logical_scan: &LogicalScan) -> Option<(PlanRef, IndexCost)> {
301        let predicate = logical_scan.predicate().clone();
302        // Index merge is kind of index lookup join so use primary table row size to estimate index
303        // cost.
304        let primary_table_row_size = TableScanIoEstimator::estimate_row_size(logical_scan);
305        // 1. choose lowest cost index merge path
306        let paths = self.gen_paths(
307            &predicate.conjunctions,
308            logical_scan,
309            primary_table_row_size,
310        );
311        let (index_access, index_access_cost) =
312            self.choose_min_cost_path(&paths, primary_table_row_size)?;
313
314        // 2. lookup primary table
315        // the schema of index_access is the order key of primary table .
316        let schema: &Schema = index_access.schema();
317        let index_access_len = schema.len();
318
319        let mut shift_input_ref_rewriter = ShiftInputRefRewriter {
320            offset: index_access_len,
321        };
322        let new_predicate = predicate.rewrite_expr(&mut shift_input_ref_rewriter);
323
324        let primary_table = logical_scan.table();
325
326        let primary_table_scan = LogicalScan::create(
327            logical_scan.table().clone(),
328            logical_scan.ctx(),
329            logical_scan.as_of().clone(),
330        );
331
332        let conjunctions = primary_table
333            .pk
334            .iter()
335            .enumerate()
336            .map(|(x, y)| {
337                Self::create_null_safe_equal_expr(
338                    x,
339                    schema.fields[x].data_type.clone(),
340                    y.column_index + index_access_len,
341                    primary_table.columns[y.column_index].data_type.clone(),
342                )
343            })
344            .chain(new_predicate)
345            .collect_vec();
346
347        let on = Condition { conjunctions };
348        let join: PlanRef =
349            LogicalJoin::new(index_access, primary_table_scan.into(), JoinType::Inner, on).into();
350
351        // 3 push down predicate
352        let join_ref = join.predicate_pushdown(
353            Condition::true_cond(),
354            &mut PredicatePushdownContext::new(join.clone()),
355        );
356
357        // 4. keep the same schema with original logical_scan
358        let scan_output_col_idx = logical_scan.output_col_idx();
359        let lookup_join = join_ref.prune_col(
360            &scan_output_col_idx
361                .iter()
362                .map(|&col_idx| col_idx + index_access_len)
363                .collect_vec(),
364            &mut ColumnPruningContext::new(join_ref.clone()),
365        );
366
367        Some((
368            lookup_join,
369            index_access_cost.mul(&IndexCost::new(LOOKUP_COST_CONST, false)),
370        ))
371    }
372
373    /// Generate possible paths that can be used to access.
374    /// The schema of output is the order key of primary table, so it can be used to lookup primary
375    /// table later.
376    /// Method `gen_paths` handles the complex condition recursively which may contains nested `AND`
377    /// and `OR`. However, Method `gen_index_path` handles one arm of an OR clause which is a
378    /// basic unit for index selection.
379    fn gen_paths(
380        &self,
381        conjunctions: &[ExprImpl],
382        logical_scan: &LogicalScan,
383        primary_table_row_size: usize,
384    ) -> Vec<PlanRef> {
385        let mut result = vec![];
386
387        // split by OR clause, the not_or_conjunctions could be used to generate index path by combining with each arm of OR clause.
388        let (or_conjunctions, not_or_conjunctions): (Vec<ExprImpl>, Vec<ExprImpl>) =
389            conjunctions.iter().cloned().partition(|expr| {
390                if let ExprImpl::FunctionCall(function_call) = expr
391                    && function_call.func_type() == ExprType::Or
392                {
393                    true
394                } else {
395                    false
396                }
397            });
398        // Only consider eq ,in and cmp condition for not_or_conjunctions
399        let interest_conjunctions: Vec<ExprImpl> = not_or_conjunctions
400            .into_iter()
401            .filter(|expr| {
402                expr.as_eq_const().is_some()
403                    || expr.as_in_const_list().is_some()
404                    || expr.as_comparison_const().is_some()
405            })
406            .collect();
407
408        for expr in or_conjunctions {
409            // it must be OR clause!
410            let mut index_to_be_merged = vec![];
411
412            let disjunctions = to_disjunctions(expr.clone());
413
414            let extended_disjunctions = disjunctions
415                .into_iter()
416                .map(|expr| {
417                    if interest_conjunctions.is_empty() {
418                        expr
419                    } else {
420                        ExprImpl::FunctionCall(
421                            FunctionCall::new_unchecked(
422                                ExprType::And,
423                                vec![expr]
424                                    .into_iter()
425                                    .chain(interest_conjunctions.iter().cloned())
426                                    .collect(),
427                                DataType::Boolean,
428                            )
429                            .into(),
430                        )
431                    }
432                })
433                .collect_vec();
434
435            let (map, others) = self.clustering_disjunction(extended_disjunctions);
436            let iter = map
437                .into_iter()
438                .map(|(column_index, expr)| (Some(column_index), expr))
439                .chain(others.into_iter().map(|expr| (None, expr)));
440            for (column_index, expr) in iter {
441                let mut index_paths = vec![];
442                let conjunctions = to_conjunctions(expr);
443                index_paths.extend(
444                    self.gen_index_path(column_index, &conjunctions, logical_scan)
445                        .into_iter(),
446                );
447                // complex condition, recursively gen paths
448                if conjunctions.len() > 1 {
449                    index_paths.extend(
450                        self.gen_paths(&conjunctions, logical_scan, primary_table_row_size)
451                            .into_iter(),
452                    );
453                }
454
455                match self.choose_min_cost_path(&index_paths, primary_table_row_size) {
456                    None => {
457                        // One arm of OR clause can't use index, bail out
458                        index_to_be_merged.clear();
459                        break;
460                    }
461                    Some((path, _)) => index_to_be_merged.push(path),
462                }
463            }
464
465            if let Some(path) = self.merge(index_to_be_merged) {
466                result.push(path)
467            }
468        }
469
470        result
471    }
472
473    /// Clustering disjunction or expr by column index. If expr is complex, classify them as others.
474    ///
475    /// a = 1, b = 2, b = 3 -> map: [a, (a = 1)], [b, (b = 2 or b = 3)], others: []
476    ///
477    /// a = 1, (b = 2 and c = 3) -> map: [a, (a = 1)], others:
478    ///
479    /// (a > 1 and a < 8) or (c > 1 and c < 8)
480    /// -> map: [], others: [(a > 1 and a < 8), (c > 1 and c < 8)]
481    fn clustering_disjunction(
482        &self,
483        disjunctions: Vec<ExprImpl>,
484    ) -> (HashMap<usize, ExprImpl>, Vec<ExprImpl>) {
485        let mut map: HashMap<usize, ExprImpl> = HashMap::new();
486        let mut others = vec![];
487        for expr in disjunctions {
488            let idx = {
489                if let Some((input_ref, _const_expr)) = expr.as_eq_const() {
490                    Some(input_ref.index)
491                } else if let Some((input_ref, _in_const_list)) = expr.as_in_const_list() {
492                    Some(input_ref.index)
493                } else if let Some((input_ref, _op, _const_expr)) = expr.as_comparison_const() {
494                    Some(input_ref.index)
495                } else {
496                    None
497                }
498            };
499
500            if let Some(idx) = idx {
501                match map.entry(idx) {
502                    Occupied(mut entry) => {
503                        let expr2: ExprImpl = entry.get().to_owned();
504                        let or_expr = ExprImpl::FunctionCall(
505                            FunctionCall::new_unchecked(
506                                ExprType::Or,
507                                vec![expr, expr2],
508                                DataType::Boolean,
509                            )
510                            .into(),
511                        );
512                        entry.insert(or_expr);
513                    }
514                    Vacant(entry) => {
515                        entry.insert(expr);
516                    }
517                };
518            } else {
519                others.push(expr);
520                continue;
521            }
522        }
523
524        (map, others)
525    }
526
527    /// Given a conjunctions from one arm of an OR clause (basic unit to index selection), generate
528    /// all matching index path (including primary index) for the relation.
529    /// `column_index` (refers to primary table) is a hint can be used to prune index.
530    /// Steps:
531    /// 1. Take the combination of `conjunctions` to extract the potential clauses.
532    /// 2. For each potential clauses, generate index path if it can.
533    fn gen_index_path(
534        &self,
535        column_index: Option<usize>,
536        conjunctions: &[ExprImpl],
537        logical_scan: &LogicalScan,
538    ) -> Vec<PlanRef> {
539        // Assumption: use at most `MAX_COMBINATION_SIZE` clauses, we can determine which is the
540        // best index.
541        let combinations = conjunctions
542            .iter()
543            .take(min(conjunctions.len(), MAX_CONJUNCTION_SIZE))
544            .combinations(min(conjunctions.len(), MAX_COMBINATION_SIZE))
545            .collect_vec();
546
547        let mut result = vec![];
548
549        for index in logical_scan.table_indexes() {
550            if let Some(column_index) = column_index {
551                assert_eq!(conjunctions.len(), 1);
552                let p2s_mapping = index.primary_to_secondary_mapping();
553                match p2s_mapping.get(&column_index) {
554                    None => continue, // not found, prune this index
555                    Some(&idx) => {
556                        if index.index_table.pk()[0].column_index != idx {
557                            // not match, prune this index
558                            continue;
559                        }
560                    }
561                }
562            }
563
564            // try secondary index
565            for conj in &combinations {
566                let condition = Condition {
567                    conjunctions: conj.iter().map(|&x| x.to_owned()).collect(),
568                };
569                if let Some(index_access) = self.build_index_access(
570                    index.clone(),
571                    condition,
572                    logical_scan.ctx().clone(),
573                    logical_scan.as_of().clone(),
574                ) {
575                    result.push(index_access);
576                }
577            }
578        }
579
580        // try primary index
581        let primary_table = logical_scan.table();
582        if let Some(idx) = column_index {
583            assert_eq!(conjunctions.len(), 1);
584            if primary_table.pk[0].column_index != idx {
585                return result;
586            }
587        }
588
589        let primary_access = generic::TableScan::new(
590            primary_table
591                .pk
592                .iter()
593                .map(|x| x.column_index)
594                .collect_vec(),
595            logical_scan.table().clone(),
596            vec![],
597            vec![],
598            logical_scan.ctx(),
599            Condition {
600                conjunctions: conjunctions.to_vec(),
601            },
602            logical_scan.as_of().clone(),
603        );
604
605        result.push(primary_access.into());
606
607        result
608    }
609
610    /// build index access if predicate (refers to primary table) is covered by index
611    fn build_index_access(
612        &self,
613        index: Arc<TableIndex>,
614        predicate: Condition,
615        ctx: OptimizerContextRef,
616        as_of: Option<AsOf>,
617    ) -> Option<PlanRef> {
618        let mut rewriter = IndexPredicateRewriter::new(
619            index.primary_to_secondary_mapping(),
620            index.function_mapping(),
621            0,
622        );
623        let new_predicate = predicate.rewrite_expr(&mut rewriter);
624
625        // check condition is covered by index.
626        if !rewriter.covered_by_index() {
627            return None;
628        }
629
630        Some(
631            generic::TableScan::new(
632                index
633                    .primary_table_pk_ref_to_index_table()
634                    .iter()
635                    .map(|x| x.column_index)
636                    .collect_vec(),
637                index.index_table.clone(),
638                vec![],
639                vec![],
640                ctx,
641                new_predicate,
642                as_of,
643            )
644            .into(),
645        )
646    }
647
648    fn merge(&self, paths: Vec<PlanRef>) -> Option<PlanRef> {
649        if paths.is_empty() {
650            return None;
651        }
652
653        let new_paths = paths
654            .iter()
655            .flat_map(|path| {
656                if let Some(union) = path.as_logical_union() {
657                    union.inputs().to_vec()
658                } else if let Some(_scan) = path.as_logical_scan() {
659                    vec![path.clone()]
660                } else {
661                    unreachable!();
662                }
663            })
664            .sorted_by(|a, b| {
665                // sort inputs to make plan deterministic
666                a.as_logical_scan()
667                    .expect("expect to be a logical scan")
668                    .table_name()
669                    .cmp(
670                        b.as_logical_scan()
671                            .expect("expect to be a logical scan")
672                            .table_name(),
673                    )
674            })
675            .collect_vec();
676
677        Some(LogicalUnion::create(false, new_paths))
678    }
679
680    fn choose_min_cost_path(
681        &self,
682        paths: &[PlanRef],
683        primary_table_row_size: usize,
684    ) -> Option<(PlanRef, IndexCost)> {
685        paths
686            .iter()
687            .map(|path| {
688                if let Some(scan) = path.as_logical_scan() {
689                    let cost = self.estimate_table_scan_cost(scan, primary_table_row_size);
690                    (scan.clone().into(), cost)
691                } else if let Some(union) = path.as_logical_union() {
692                    let cost = union
693                        .inputs()
694                        .iter()
695                        .map(|input| {
696                            self.estimate_table_scan_cost(
697                                input.as_logical_scan().expect("expect to be a scan"),
698                                primary_table_row_size,
699                            )
700                        })
701                        .reduce(|a, b| a.add(&b))
702                        .unwrap();
703                    (union.clone().into(), cost)
704                } else {
705                    unreachable!()
706                }
707            })
708            .min_by(|(_, cost1), (_, cost2)| Ord::cmp(cost1, cost2))
709    }
710
711    fn estimate_table_scan_cost(&self, scan: &LogicalScan, row_size: usize) -> IndexCost {
712        let mut table_scan_io_estimator = TableScanIoEstimator::new(scan, row_size);
713        table_scan_io_estimator.estimate(scan.predicate())
714    }
715
716    fn estimate_full_table_scan_cost(&self, scan: &LogicalScan, row_size: usize) -> IndexCost {
717        let mut table_scan_io_estimator = TableScanIoEstimator::new(scan, row_size);
718        table_scan_io_estimator.estimate(&Condition::true_cond())
719    }
720
721    pub fn create_null_safe_equal_expr(
722        left: usize,
723        left_data_type: DataType,
724        right: usize,
725        right_data_type: DataType,
726    ) -> ExprImpl {
727        ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
728            ExprType::IsNotDistinctFrom,
729            vec![
730                ExprImpl::InputRef(Box::new(InputRef::new(left, left_data_type))),
731                ExprImpl::InputRef(Box::new(InputRef::new(right, right_data_type))),
732            ],
733            DataType::Boolean,
734        )))
735    }
736}
737
738struct TableScanIoEstimator<'a> {
739    table_scan: &'a LogicalScan,
740    row_size: usize,
741    cost: Option<IndexCost>,
742}
743
744impl<'a> TableScanIoEstimator<'a> {
745    pub fn new(table_scan: &'a LogicalScan, row_size: usize) -> Self {
746        Self {
747            table_scan,
748            row_size,
749            cost: None,
750        }
751    }
752
753    pub fn estimate_row_size(table_scan: &LogicalScan) -> usize {
754        // 5 for table_id + 1 for vnode + 8 for epoch
755        let row_meta_field_estimate_size = 14_usize;
756        let table = table_scan.table();
757        row_meta_field_estimate_size
758            + table
759                .columns
760                .iter()
761                // add order key twice for its appearance both in key and value
762                .chain(table.pk.iter().map(|x| &table.columns[x.column_index]))
763                .map(|x| TableScanIoEstimator::estimate_data_type_size(&x.data_type))
764                .sum::<usize>()
765    }
766
767    fn estimate_data_type_size(data_type: &DataType) -> usize {
768        use std::mem::size_of;
769
770        match data_type {
771            DataType::Boolean => size_of::<bool>(),
772            DataType::Int16 => size_of::<i16>(),
773            DataType::Int32 => size_of::<i32>(),
774            DataType::Int64 => size_of::<i64>(),
775            DataType::Serial => size_of::<Serial>(),
776            DataType::Float32 => size_of::<f32>(),
777            DataType::Float64 => size_of::<f64>(),
778            DataType::Decimal => size_of::<Decimal>(),
779            DataType::Date => size_of::<Date>(),
780            DataType::Time => size_of::<Time>(),
781            DataType::Timestamp => size_of::<Timestamp>(),
782            DataType::Timestamptz => size_of::<Timestamptz>(),
783            DataType::Interval => size_of::<Interval>(),
784            DataType::Int256 => Int256::size(),
785            DataType::Varchar => 20,
786            DataType::Bytea => 20,
787            DataType::Jsonb => 20,
788            DataType::Struct { .. } => 20,
789            DataType::List { .. } => 20,
790            DataType::Map(_) => 20,
791            DataType::Vector(d) => d * size_of::<VectorDistanceType>(),
792        }
793    }
794
795    pub fn estimate(&mut self, predicate: &Condition) -> IndexCost {
796        // try to deal with OR condition
797        if predicate.conjunctions.len() == 1 {
798            self.visit_expr(&predicate.conjunctions[0]);
799            self.cost.take().unwrap_or_default()
800        } else {
801            self.estimate_conjunctions(&predicate.conjunctions)
802        }
803    }
804
805    fn estimate_conjunctions(&mut self, conjunctions: &[ExprImpl]) -> IndexCost {
806        let mut new_conjunctions = conjunctions.to_owned();
807
808        let mut match_item_vec = vec![];
809
810        for column_idx in self.table_scan.table().order_column_indices() {
811            let match_item = self.match_index_column(column_idx, &mut new_conjunctions);
812            // seeing range, we don't need to match anymore.
813            let should_break = match match_item {
814                MatchItem::Equal | MatchItem::In(_) => false,
815                MatchItem::RangeOneSideBound | MatchItem::RangeTwoSideBound | MatchItem::All => {
816                    true
817                }
818            };
819            match_item_vec.push(match_item);
820            if should_break {
821                break;
822            }
823        }
824
825        let index_cost = match_item_vec
826            .iter()
827            .enumerate()
828            .take(INDEX_MAX_LEN)
829            .map(|(i, match_item)| match match_item {
830                MatchItem::Equal => INDEX_COST_MATRIX[0][i],
831                MatchItem::In(num) => min(INDEX_COST_MATRIX[1][i], *num),
832                MatchItem::RangeTwoSideBound => INDEX_COST_MATRIX[2][i],
833                MatchItem::RangeOneSideBound => INDEX_COST_MATRIX[3][i],
834                MatchItem::All => INDEX_COST_MATRIX[4][i],
835            })
836            .reduce(|x, y| x * y)
837            .unwrap();
838
839        // If `index_cost` equals 1, it is a primary lookup
840        let primary_lookup = index_cost == 1;
841
842        IndexCost::new(index_cost, primary_lookup)
843            .mul(&IndexCost::new(self.row_size, primary_lookup))
844    }
845
846    fn match_index_column(
847        &mut self,
848        column_idx: usize,
849        conjunctions: &mut Vec<ExprImpl>,
850    ) -> MatchItem {
851        // Equal
852        for (i, expr) in conjunctions.iter().enumerate() {
853            if let Some((input_ref, _const_expr)) = expr.as_eq_const()
854                && input_ref.index == column_idx
855            {
856                conjunctions.remove(i);
857                return MatchItem::Equal;
858            }
859        }
860
861        // In
862        for (i, expr) in conjunctions.iter().enumerate() {
863            if let Some((input_ref, in_const_list)) = expr.as_in_const_list()
864                && input_ref.index == column_idx
865            {
866                conjunctions.remove(i);
867                return MatchItem::In(in_const_list.len());
868            }
869        }
870
871        // Range
872        let mut left_side_bound = false;
873        let mut right_side_bound = false;
874        let mut i = 0;
875        while i < conjunctions.len() {
876            let expr = &conjunctions[i];
877            if let Some((input_ref, op, _const_expr)) = expr.as_comparison_const()
878                && input_ref.index == column_idx
879            {
880                conjunctions.remove(i);
881                match op {
882                    ExprType::LessThan | ExprType::LessThanOrEqual => right_side_bound = true,
883                    ExprType::GreaterThan | ExprType::GreaterThanOrEqual => left_side_bound = true,
884                    _ => unreachable!(),
885                };
886            } else {
887                i += 1;
888            }
889        }
890
891        if left_side_bound && right_side_bound {
892            MatchItem::RangeTwoSideBound
893        } else if left_side_bound || right_side_bound {
894            MatchItem::RangeOneSideBound
895        } else {
896            MatchItem::All
897        }
898    }
899}
900
/// How tightly the predicates constrain a single order-key column, as classified by
/// `match_index_column`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MatchItem {
    /// `column = constant`.
    Equal,
    /// `column IN (constants)`; carries the number of listed values.
    In(usize),
    /// Both a lower (`>` / `>=`) and an upper (`<` / `<=`) bound against constants.
    RangeTwoSideBound,
    /// Only a lower or only an upper bound against constants.
    RangeOneSideBound,
    /// No usable predicate on the column.
    All,
}
908
/// Estimated I/O cost of one scan strategy.
///
/// `cost` is clamped to [`IndexCost::maximum`] by the constructor and by every arithmetic
/// helper. `primary_lookup` survives `add`/`mul` only when both operands have it set.
#[derive(PartialEq, Eq, Hash, Clone, Debug, PartialOrd, Ord)]
struct IndexCost {
    cost: usize,
    primary_lookup: bool,
}

impl Default for IndexCost {
    /// The worst possible estimate: maximum cost and not a primary lookup.
    fn default() -> Self {
        IndexCost::new(IndexCost::maximum(), false)
    }
}

impl IndexCost {
    /// Builds a cost, clamping `cost` to [`IndexCost::maximum`].
    fn new(cost: usize, primary_lookup: bool) -> IndexCost {
        let cost = min(cost, Self::maximum());
        Self {
            cost,
            primary_lookup,
        }
    }

    /// Upper bound that every cost value is clamped to.
    fn maximum() -> usize {
        10_000_000
    }

    /// Sums two costs (saturating, then clamped by `new`).
    ///
    /// The result is a primary lookup only if both inputs are.
    fn add(&self, other: &IndexCost) -> IndexCost {
        let primary_lookup = self.primary_lookup && other.primary_lookup;
        IndexCost::new(self.cost.saturating_add(other.cost), primary_lookup)
    }

    /// Multiplies two costs (saturating, then clamped by `new`).
    ///
    /// The result is a primary lookup only if both inputs are.
    fn mul(&self, other: &IndexCost) -> IndexCost {
        let primary_lookup = self.primary_lookup && other.primary_lookup;
        IndexCost::new(self.cost.saturating_mul(other.cost), primary_lookup)
    }

    /// Returns `true` when `self` is strictly cheaper than `other`.
    ///
    /// Despite its name, this is a strict `<` on `cost` alone; `primary_lookup` is ignored.
    /// Being an inherent method, it takes precedence over the derived `PartialOrd::le` in
    /// `.le(..)` method-call syntax.
    fn le(&self, other: &IndexCost) -> bool {
        other.cost > self.cost
    }
}
958
959impl ExprVisitor for TableScanIoEstimator<'_> {
960    fn visit_function_call(&mut self, func_call: &FunctionCall) {
961        let cost = match func_call.func_type() {
962            ExprType::Or => func_call
963                .inputs()
964                .iter()
965                .map(|x| {
966                    let mut estimator = TableScanIoEstimator::new(self.table_scan, self.row_size);
967                    estimator.visit_expr(x);
968                    estimator.cost.take().unwrap_or_default()
969                })
970                .reduce(|x, y| x.add(&y))
971                .unwrap(),
972            ExprType::And => self.estimate_conjunctions(func_call.inputs()),
973            _ => {
974                let single = vec![ExprImpl::FunctionCall(func_call.clone().into())];
975                self.estimate_conjunctions(&single)
976            }
977        };
978        self.cost = Some(cost);
979    }
980}
981
982struct ShiftInputRefRewriter {
983    offset: usize,
984}
985impl ExprRewriter for ShiftInputRefRewriter {
986    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
987        InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
988    }
989}
990
991impl IndexSelectionRule {
992    pub fn create() -> BoxedRule {
993        Box::new(IndexSelectionRule {})
994    }
995}