Skip to main content

risingwave_frontend/optimizer/rule/
index_selection_rule.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! # Index selection cost matrix
16//!
17//! |`column_idx`| 0   |  1 | 2  | 3  | 4  | remark |
18//! |-----------|-----|----|----|----|----|---|
19//! |Equal      | 1   | 1  | 1  | 1  | 1  | |
20//! |In         | 10  | 8  | 5  | 5  | 5  | take the minimum value with actual in number |
21//! |Range(Two) | 600 | 50 | 20 | 10 | 10 | `RangeTwoSideBound` like a between 1 and 2 |
22//! |Range(One) | 1400| 70 | 25 | 15 | 10 | `RangeOneSideBound` like a > 1, a >= 1, a < 1|
//! |All        | 4000| 100| 30 | 20 | 20 | |
24//!
25//! ```text
26//! index cost = cost(match type of 0 idx)
27//! * cost(match type of 1 idx)
28//! * ... cost(match type of the last idx)
29//! ```
30//!
31//! ## Example
32//!
33//! Given index order key (a, b, c)
34//!
35//! - For `a = 1 and b = 1 and c = 1`, its cost is 1 = Equal0 * Equal1 * Equal2 = 1
36//! - For `a in (xxx) and b = 1 and c = 1`, its cost is In0 * Equal1 * Equal2 = 10
//! - For `a = 1 and b in (xxx)`, its cost is Equal0 * In1 * All2 = 1 * 8 * 30 = 240
38//! - For `a between xxx and yyy`, its cost is Range(Two)0 = 600
39//! - For `a = 1 and b between xxx and yyy`, its cost is Equal0 * Range(Two)1 = 50
40//! - For `a = 1 and b > 1`, its cost is Equal0 * Range(One)1 = 70
41//! - For `a = 1`, its cost is 100 = Equal0 * All1 = 100
42//! - For no condition, its cost is All0 = 4000
43//!
//! With the assumption that the most effective part of an index is its prefix,
//! the cost decreases as `column_idx` increases.
46//!
47//! For index order key length > 5, we just ignore the rest.
48
49use std::cmp::min;
50use std::collections::hash_map::Entry::{Occupied, Vacant};
51use std::collections::{BTreeMap, HashMap};
52use std::sync::Arc;
53
54use itertools::Itertools;
55use risingwave_common::array::VectorDistanceType;
56use risingwave_common::catalog::Schema;
57use risingwave_common::types::{
58    DataType, Date, Decimal, Int256, Interval, Serial, Time, Timestamp, Timestamptz,
59};
60use risingwave_common::util::iter_util::ZipEqFast;
61use risingwave_pb::plan_common::JoinType;
62use risingwave_sqlparser::ast::AsOf;
63
64use super::prelude::{PlanRef, *};
65use crate::catalog::index_catalog::TableIndex;
66use crate::expr::{
67    Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, to_conjunctions,
68    to_disjunctions,
69};
70use crate::optimizer::optimizer_context::OptimizerContextRef;
71use crate::optimizer::plan_node::generic::GenericPlanRef;
72use crate::optimizer::plan_node::{
73    ColumnPruningContext, LogicalJoin, LogicalScan, LogicalUnion, PlanTreeNode, PlanTreeNodeBinary,
74    PredicatePushdown, PredicatePushdownContext, generic,
75};
76use crate::utils::Condition;
77
/// Number of leading order-key columns that take part in the cost estimate.
const INDEX_MAX_LEN: usize = 5;
/// Cost per match type (row) and per order-key position (column).
///
/// Row order: `Equal`, `In`, `RangeTwoSideBound`, `RangeOneSideBound`, `All`.
const INDEX_COST_MATRIX: [[usize; INDEX_MAX_LEN]; 5] = [
    // Equal
    [1, 1, 1, 1, 1],
    // In
    [10, 8, 5, 5, 5],
    // Range with both bounds
    [600, 50, 20, 10, 10],
    // Range with a single bound
    [1400, 70, 25, 15, 10],
    // All (no usable condition on the column)
    [4000, 100, 30, 20, 20],
];
/// Multiplier applied to an index access cost when the primary table has to be looked up
/// afterwards.
const LOOKUP_COST_CONST: usize = 3;
/// Upper bound on the number of conjuncts combined into one candidate predicate.
const MAX_COMBINATION_SIZE: usize = 4;
/// Upper bound on the number of conjuncts considered when building combinations.
const MAX_CONJUNCTION_SIZE: usize = 8;
89
/// Optimizer rule that replaces a logical scan by a cheaper index-based plan (covering index
/// scan, index lookup join, or index merge) when the estimated cost is lower.
pub struct IndexSelectionRule {}
91
92impl Rule<Logical> for IndexSelectionRule {
93    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
94        let logical_scan: &LogicalScan = plan.as_logical_scan()?;
95        let indexes = logical_scan.table_indexes();
96        if indexes.is_empty() {
97            return None;
98        }
99        let primary_table_row_size = TableScanIoEstimator::estimate_row_size(logical_scan);
100        let primary_cost = min(
101            self.estimate_table_scan_cost(logical_scan, primary_table_row_size),
102            self.estimate_full_table_scan_cost(logical_scan, primary_table_row_size),
103        );
104
105        // If it is a primary lookup plan, avoid checking other indexes.
106        if primary_cost.primary_lookup {
107            return None;
108        }
109
110        let mut final_plan: PlanRef = logical_scan.clone().into();
111        let mut min_cost = primary_cost.clone();
112
113        for index in indexes {
114            if let Some(index_scan) = logical_scan.to_index_scan_if_index_covered(index) {
115                let index_cost = self.estimate_table_scan_cost(
116                    &index_scan,
117                    TableScanIoEstimator::estimate_row_size(&index_scan),
118                );
119
120                if index_cost.le(&min_cost) {
121                    min_cost = index_cost;
122                    final_plan = index_scan.into();
123                }
124            } else {
125                // non-covering index selection
126                let (index_lookup, lookup_cost) = self.gen_index_lookup(logical_scan, index);
127                if lookup_cost.le(&min_cost) {
128                    min_cost = lookup_cost;
129                    final_plan = index_lookup;
130                }
131            }
132        }
133
134        if let Some((merge_index, merge_index_cost)) = self.index_merge_selection(logical_scan)
135            && merge_index_cost.le(&min_cost)
136        {
137            min_cost = merge_index_cost;
138            final_plan = merge_index;
139        }
140
141        if min_cost == primary_cost {
142            None
143        } else {
144            Some(final_plan)
145        }
146    }
147}
148
149struct IndexPredicateRewriter<'a> {
150    p2s_mapping: &'a BTreeMap<usize, usize>,
151    function_mapping: &'a HashMap<FunctionCall, usize>,
152    offset: usize,
153    covered_by_index: bool,
154}
155
156impl<'a> IndexPredicateRewriter<'a> {
157    fn new(
158        p2s_mapping: &'a BTreeMap<usize, usize>,
159        function_mapping: &'a HashMap<FunctionCall, usize>,
160        offset: usize,
161    ) -> Self {
162        Self {
163            p2s_mapping,
164            function_mapping,
165            offset,
166            covered_by_index: true,
167        }
168    }
169
170    fn covered_by_index(&self) -> bool {
171        self.covered_by_index
172    }
173}
174
175impl ExprRewriter for IndexPredicateRewriter<'_> {
176    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
177        // transform primary predicate to index predicate if it can
178        if self.p2s_mapping.contains_key(&input_ref.index) {
179            InputRef::new(
180                *self.p2s_mapping.get(&input_ref.index()).unwrap(),
181                input_ref.return_type(),
182            )
183            .into()
184        } else {
185            self.covered_by_index = false;
186            InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
187        }
188    }
189
190    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
191        if let Some(index) = self.function_mapping.get(&func_call) {
192            return InputRef::new(*index, func_call.return_type()).into();
193        }
194
195        let (func_type, inputs, ret) = func_call.decompose();
196        let inputs = inputs
197            .into_iter()
198            .map(|expr| self.rewrite_expr(expr))
199            .collect();
200        FunctionCall::new_unchecked(func_type, inputs, ret).into()
201    }
202}
203
204impl IndexSelectionRule {
205    fn gen_index_lookup(
206        &self,
207        logical_scan: &LogicalScan,
208        index: &TableIndex,
209    ) -> (PlanRef, IndexCost) {
210        // 1. logical_scan ->  logical_join
211        //                      /        \
212        //                index_scan   primary_table_scan
213        let index_scan = LogicalScan::create(
214            index.index_table.clone(),
215            logical_scan.ctx(),
216            logical_scan.as_of(),
217        );
218        // We use `schema.len` instead of `index_item.len` here,
219        // because schema contains system columns like `_rw_timestamp` column which is not represented in the index item.
220        let offset = index_scan.table().columns().len();
221
222        let primary_table_scan = LogicalScan::create(
223            index.primary_table.clone(),
224            logical_scan.ctx(),
225            logical_scan.as_of(),
226        );
227
228        let predicate = logical_scan.predicate().clone();
229        let mut rewriter = IndexPredicateRewriter::new(
230            index.primary_to_secondary_mapping(),
231            index.function_mapping(),
232            offset,
233        );
234        let new_predicate = predicate.rewrite_expr(&mut rewriter);
235
236        let conjunctions = index
237            .primary_table_pk_ref_to_index_table()
238            .iter()
239            .zip_eq_fast(index.primary_table.pk.iter())
240            .map(|(x, y)| {
241                Self::create_null_safe_equal_expr(
242                    x.column_index,
243                    index.index_table.columns[x.column_index]
244                        .data_type()
245                        .clone(),
246                    y.column_index + offset,
247                    index.primary_table.columns[y.column_index]
248                        .data_type()
249                        .clone(),
250                )
251            })
252            .chain(new_predicate)
253            .collect_vec();
254        let on = Condition { conjunctions };
255        let join: PlanRef = LogicalJoin::new(
256            index_scan.into(),
257            primary_table_scan.into(),
258            JoinType::Inner,
259            on,
260        )
261        .into();
262
263        // 2. push down predicate, so we can calculate the cost of index lookup
264        let join_ref = join.predicate_pushdown(
265            Condition::true_cond(),
266            &mut PredicatePushdownContext::new(join.clone()),
267        );
268
269        let join_with_predicate_push_down =
270            join_ref.as_logical_join().expect("must be a logical join");
271        let new_join_left = join_with_predicate_push_down.left();
272        let index_scan_with_predicate: &LogicalScan = new_join_left
273            .as_logical_scan()
274            .expect("must be a logical scan");
275
276        // 3. calculate index cost, index lookup use primary table to estimate row size.
277        let index_cost = self.estimate_table_scan_cost(
278            index_scan_with_predicate,
279            TableScanIoEstimator::estimate_row_size(logical_scan),
280        );
281        // lookup cost = index cost * LOOKUP_COST_CONST
282        let lookup_cost = index_cost.mul(&IndexCost::new(LOOKUP_COST_CONST, false));
283
284        // 4. keep the same schema with original logical_scan
285        let scan_output_col_idx = logical_scan.output_col_idx();
286        let lookup_join = join_ref.prune_col(
287            &scan_output_col_idx
288                .iter()
289                .map(|&col_idx| col_idx + offset)
290                .collect_vec(),
291            &mut ColumnPruningContext::new(join_ref.clone()),
292        );
293
294        (lookup_join, lookup_cost)
295    }
296
297    /// Index Merge Selection
298    /// Deal with predicate like a = 1 or b = 1
299    /// Merge index scans from a table, currently merge is union semantic.
300    fn index_merge_selection(&self, logical_scan: &LogicalScan) -> Option<(PlanRef, IndexCost)> {
301        let predicate = logical_scan.predicate().clone();
302        // Index merge is kind of index lookup join so use primary table row size to estimate index
303        // cost.
304        let primary_table_row_size = TableScanIoEstimator::estimate_row_size(logical_scan);
305        // 1. choose lowest cost index merge path
306        let paths = self.gen_paths(
307            &predicate.conjunctions,
308            logical_scan,
309            primary_table_row_size,
310        );
311        let (index_access, index_access_cost) =
312            self.choose_min_cost_path(&paths, primary_table_row_size)?;
313
314        // 2. lookup primary table
315        // the schema of index_access is the order key of primary table .
316        let schema: &Schema = index_access.schema();
317        let index_access_len = schema.len();
318
319        let mut shift_input_ref_rewriter = ShiftInputRefRewriter {
320            offset: index_access_len,
321        };
322        let new_predicate = predicate.rewrite_expr(&mut shift_input_ref_rewriter);
323
324        let primary_table = logical_scan.table();
325
326        let primary_table_scan = LogicalScan::create(
327            logical_scan.table().clone(),
328            logical_scan.ctx(),
329            logical_scan.as_of(),
330        );
331
332        let conjunctions = primary_table
333            .pk
334            .iter()
335            .enumerate()
336            .map(|(x, y)| {
337                Self::create_null_safe_equal_expr(
338                    x,
339                    schema.fields[x].data_type.clone(),
340                    y.column_index + index_access_len,
341                    primary_table.columns[y.column_index].data_type.clone(),
342                )
343            })
344            .chain(new_predicate)
345            .collect_vec();
346
347        let on = Condition { conjunctions };
348        let join: PlanRef =
349            LogicalJoin::new(index_access, primary_table_scan.into(), JoinType::Inner, on).into();
350
351        // 3 push down predicate
352        let join_ref = join.predicate_pushdown(
353            Condition::true_cond(),
354            &mut PredicatePushdownContext::new(join.clone()),
355        );
356
357        // 4. keep the same schema with original logical_scan
358        let scan_output_col_idx = logical_scan.output_col_idx();
359        let lookup_join = join_ref.prune_col(
360            &scan_output_col_idx
361                .iter()
362                .map(|&col_idx| col_idx + index_access_len)
363                .collect_vec(),
364            &mut ColumnPruningContext::new(join_ref.clone()),
365        );
366
367        Some((
368            lookup_join,
369            index_access_cost.mul(&IndexCost::new(LOOKUP_COST_CONST, false)),
370        ))
371    }
372
373    /// Generate possible paths that can be used to access.
374    /// The schema of output is the order key of primary table, so it can be used to lookup primary
375    /// table later.
376    /// Method `gen_paths` handles the complex condition recursively which may contains nested `AND`
377    /// and `OR`. However, Method `gen_index_path` handles one arm of an OR clause which is a
378    /// basic unit for index selection.
379    fn gen_paths(
380        &self,
381        conjunctions: &[ExprImpl],
382        logical_scan: &LogicalScan,
383        primary_table_row_size: usize,
384    ) -> Vec<PlanRef> {
385        let mut result = vec![];
386
387        // split by OR clause, the not_or_conjunctions could be used to generate index path by combining with each arm of OR clause.
388        let (or_conjunctions, not_or_conjunctions): (Vec<ExprImpl>, Vec<ExprImpl>) =
389            conjunctions.iter().cloned().partition(|expr| {
390                if let ExprImpl::FunctionCall(function_call) = expr
391                    && function_call.func_type() == ExprType::Or
392                {
393                    true
394                } else {
395                    false
396                }
397            });
398        // Only consider eq ,in and cmp condition for not_or_conjunctions
399        let interest_conjunctions: Vec<ExprImpl> = not_or_conjunctions
400            .into_iter()
401            .filter(|expr| {
402                expr.as_eq_const().is_some()
403                    || expr.as_in_const_list().is_some()
404                    || expr.as_comparison_const().is_some()
405            })
406            .collect();
407
408        for expr in or_conjunctions {
409            // it must be OR clause!
410            let mut index_to_be_merged = vec![];
411
412            let disjunctions = to_disjunctions(expr.clone());
413
414            let extended_disjunctions = disjunctions
415                .into_iter()
416                .map(|expr| {
417                    if interest_conjunctions.is_empty() {
418                        expr
419                    } else {
420                        ExprImpl::FunctionCall(
421                            FunctionCall::new_unchecked(
422                                ExprType::And,
423                                vec![expr]
424                                    .into_iter()
425                                    .chain(interest_conjunctions.iter().cloned())
426                                    .collect(),
427                                DataType::Boolean,
428                            )
429                            .into(),
430                        )
431                    }
432                })
433                .collect_vec();
434
435            let (map, others) = self.clustering_disjunction(extended_disjunctions);
436            let iter = map
437                .into_iter()
438                .map(|(column_index, expr)| (Some(column_index), expr))
439                .chain(others.into_iter().map(|expr| (None, expr)));
440            for (column_index, expr) in iter {
441                let mut index_paths = vec![];
442                let conjunctions = to_conjunctions(expr);
443                index_paths.extend(self.gen_index_path(column_index, &conjunctions, logical_scan));
444                // complex condition, recursively gen paths
445                if conjunctions.len() > 1 {
446                    index_paths.extend(self.gen_paths(
447                        &conjunctions,
448                        logical_scan,
449                        primary_table_row_size,
450                    ));
451                }
452
453                match self.choose_min_cost_path(&index_paths, primary_table_row_size) {
454                    None => {
455                        // One arm of OR clause can't use index, bail out
456                        index_to_be_merged.clear();
457                        break;
458                    }
459                    Some((path, _)) => index_to_be_merged.push(path),
460                }
461            }
462
463            if let Some(path) = self.merge(index_to_be_merged) {
464                result.push(path)
465            }
466        }
467
468        result
469    }
470
471    /// Clustering disjunction or expr by column index. If expr is complex, classify them as others.
472    ///
473    /// a = 1, b = 2, b = 3 -> map: [a, (a = 1)], [b, (b = 2 or b = 3)], others: []
474    ///
475    /// a = 1, (b = 2 and c = 3) -> map: [a, (a = 1)], others:
476    ///
477    /// (a > 1 and a < 8) or (c > 1 and c < 8)
478    /// -> map: [], others: [(a > 1 and a < 8), (c > 1 and c < 8)]
479    fn clustering_disjunction(
480        &self,
481        disjunctions: Vec<ExprImpl>,
482    ) -> (HashMap<usize, ExprImpl>, Vec<ExprImpl>) {
483        let mut map: HashMap<usize, ExprImpl> = HashMap::new();
484        let mut others = vec![];
485        for expr in disjunctions {
486            let idx = {
487                if let Some((input_ref, _const_expr)) = expr.as_eq_const() {
488                    Some(input_ref.index)
489                } else if let Some((input_ref, _in_const_list)) = expr.as_in_const_list() {
490                    Some(input_ref.index)
491                } else if let Some((input_ref, _op, _const_expr)) = expr.as_comparison_const() {
492                    Some(input_ref.index)
493                } else {
494                    None
495                }
496            };
497
498            if let Some(idx) = idx {
499                match map.entry(idx) {
500                    Occupied(mut entry) => {
501                        let expr2: ExprImpl = entry.get().to_owned();
502                        let or_expr = ExprImpl::FunctionCall(
503                            FunctionCall::new_unchecked(
504                                ExprType::Or,
505                                vec![expr, expr2],
506                                DataType::Boolean,
507                            )
508                            .into(),
509                        );
510                        entry.insert(or_expr);
511                    }
512                    Vacant(entry) => {
513                        entry.insert(expr);
514                    }
515                };
516            } else {
517                others.push(expr);
518                continue;
519            }
520        }
521
522        (map, others)
523    }
524
525    /// Given a conjunctions from one arm of an OR clause (basic unit to index selection), generate
526    /// all matching index path (including primary index) for the relation.
527    /// `column_index` (refers to primary table) is a hint can be used to prune index.
528    /// Steps:
529    /// 1. Take the combination of `conjunctions` to extract the potential clauses.
530    /// 2. For each potential clauses, generate index path if it can.
531    fn gen_index_path(
532        &self,
533        column_index: Option<usize>,
534        conjunctions: &[ExprImpl],
535        logical_scan: &LogicalScan,
536    ) -> Vec<PlanRef> {
537        // Assumption: use at most `MAX_COMBINATION_SIZE` clauses, we can determine which is the
538        // best index.
539        let combinations = conjunctions
540            .iter()
541            .take(min(conjunctions.len(), MAX_CONJUNCTION_SIZE))
542            .combinations(min(conjunctions.len(), MAX_COMBINATION_SIZE))
543            .collect_vec();
544
545        let mut result = vec![];
546
547        for index in logical_scan.table_indexes() {
548            if let Some(column_index) = column_index {
549                assert_eq!(conjunctions.len(), 1);
550                let p2s_mapping = index.primary_to_secondary_mapping();
551                match p2s_mapping.get(&column_index) {
552                    None => continue, // not found, prune this index
553                    Some(&idx) => {
554                        if index.index_table.pk()[0].column_index != idx {
555                            // not match, prune this index
556                            continue;
557                        }
558                    }
559                }
560            }
561
562            // try secondary index
563            for conj in &combinations {
564                let condition = Condition {
565                    conjunctions: conj.iter().map(|&x| x.to_owned()).collect(),
566                };
567                if let Some(index_access) = self.build_index_access(
568                    index.clone(),
569                    condition,
570                    logical_scan.ctx().clone(),
571                    logical_scan.as_of().clone(),
572                ) {
573                    result.push(index_access);
574                }
575            }
576        }
577
578        // try primary index
579        let primary_table = logical_scan.table();
580        if let Some(idx) = column_index {
581            assert_eq!(conjunctions.len(), 1);
582            if primary_table.pk[0].column_index != idx {
583                return result;
584            }
585        }
586
587        let primary_access = generic::TableScan::new(
588            primary_table
589                .pk
590                .iter()
591                .map(|x| x.column_index)
592                .collect_vec(),
593            logical_scan.table().clone(),
594            vec![],
595            vec![],
596            logical_scan.ctx(),
597            Condition {
598                conjunctions: conjunctions.to_vec(),
599            },
600            logical_scan.as_of(),
601        );
602
603        result.push(primary_access.into());
604
605        result
606    }
607
608    /// build index access if predicate (refers to primary table) is covered by index
609    fn build_index_access(
610        &self,
611        index: Arc<TableIndex>,
612        predicate: Condition,
613        ctx: OptimizerContextRef,
614        as_of: Option<AsOf>,
615    ) -> Option<PlanRef> {
616        let mut rewriter = IndexPredicateRewriter::new(
617            index.primary_to_secondary_mapping(),
618            index.function_mapping(),
619            0,
620        );
621        let new_predicate = predicate.rewrite_expr(&mut rewriter);
622
623        // check condition is covered by index.
624        if !rewriter.covered_by_index() {
625            return None;
626        }
627
628        Some(
629            generic::TableScan::new(
630                index
631                    .primary_table_pk_ref_to_index_table()
632                    .iter()
633                    .map(|x| x.column_index)
634                    .collect_vec(),
635                index.index_table.clone(),
636                vec![],
637                vec![],
638                ctx,
639                new_predicate,
640                as_of,
641            )
642            .into(),
643        )
644    }
645
646    fn merge(&self, paths: Vec<PlanRef>) -> Option<PlanRef> {
647        if paths.is_empty() {
648            return None;
649        }
650
651        let new_paths = paths
652            .iter()
653            .flat_map(|path| {
654                if let Some(union) = path.as_logical_union() {
655                    union.inputs().to_vec()
656                } else if let Some(_scan) = path.as_logical_scan() {
657                    vec![path.clone()]
658                } else {
659                    unreachable!();
660                }
661            })
662            .sorted_by(|a, b| {
663                // sort inputs to make plan deterministic
664                a.as_logical_scan()
665                    .expect("expect to be a logical scan")
666                    .table_name()
667                    .cmp(
668                        b.as_logical_scan()
669                            .expect("expect to be a logical scan")
670                            .table_name(),
671                    )
672            })
673            .collect_vec();
674
675        Some(LogicalUnion::create(false, new_paths))
676    }
677
678    fn choose_min_cost_path(
679        &self,
680        paths: &[PlanRef],
681        primary_table_row_size: usize,
682    ) -> Option<(PlanRef, IndexCost)> {
683        paths
684            .iter()
685            .map(|path| {
686                if let Some(scan) = path.as_logical_scan() {
687                    let cost = self.estimate_table_scan_cost(scan, primary_table_row_size);
688                    (scan.clone().into(), cost)
689                } else if let Some(union) = path.as_logical_union() {
690                    let cost = union
691                        .inputs()
692                        .iter()
693                        .map(|input| {
694                            self.estimate_table_scan_cost(
695                                input.as_logical_scan().expect("expect to be a scan"),
696                                primary_table_row_size,
697                            )
698                        })
699                        .reduce(|a, b| a.add(&b))
700                        .unwrap();
701                    (union.clone().into(), cost)
702                } else {
703                    unreachable!()
704                }
705            })
706            .min_by(|(_, cost1), (_, cost2)| Ord::cmp(cost1, cost2))
707    }
708
709    pub(crate) fn estimate_table_scan_cost(
710        &self,
711        scan: &LogicalScan,
712        row_size: usize,
713    ) -> IndexCost {
714        let mut table_scan_io_estimator = TableScanIoEstimator::new(scan, row_size);
715        table_scan_io_estimator.estimate(scan.predicate())
716    }
717
718    pub(crate) fn estimate_full_table_scan_cost(
719        &self,
720        scan: &LogicalScan,
721        row_size: usize,
722    ) -> IndexCost {
723        let mut table_scan_io_estimator = TableScanIoEstimator::new(scan, row_size);
724        table_scan_io_estimator.estimate(&Condition::true_cond())
725    }
726
727    pub fn create_null_safe_equal_expr(
728        left: usize,
729        left_data_type: DataType,
730        right: usize,
731        right_data_type: DataType,
732    ) -> ExprImpl {
733        ExprImpl::FunctionCall(Box::new(FunctionCall::new_unchecked(
734            ExprType::IsNotDistinctFrom,
735            vec![
736                ExprImpl::InputRef(Box::new(InputRef::new(left, left_data_type))),
737                ExprImpl::InputRef(Box::new(InputRef::new(right, right_data_type))),
738            ],
739            DataType::Boolean,
740        )))
741    }
742}
743
744pub(crate) struct TableScanIoEstimator<'a> {
745    table_scan: &'a LogicalScan,
746    row_size: usize,
747    cost: Option<IndexCost>,
748}
749
750impl<'a> TableScanIoEstimator<'a> {
751    pub fn new(table_scan: &'a LogicalScan, row_size: usize) -> Self {
752        Self {
753            table_scan,
754            row_size,
755            cost: None,
756        }
757    }
758
759    pub fn estimate_row_size(table_scan: &LogicalScan) -> usize {
760        // 5 for table_id + 1 for vnode + 8 for epoch
761        let row_meta_field_estimate_size = 14_usize;
762        let table = table_scan.table();
763        row_meta_field_estimate_size
764            + table
765                .columns
766                .iter()
767                // add order key twice for its appearance both in key and value
768                .chain(table.pk.iter().map(|x| &table.columns[x.column_index]))
769                .map(|x| TableScanIoEstimator::estimate_data_type_size(&x.data_type))
770                .sum::<usize>()
771    }
772
773    fn estimate_data_type_size(data_type: &DataType) -> usize {
774        use std::mem::size_of;
775
776        match data_type {
777            DataType::Boolean => size_of::<bool>(),
778            DataType::Int16 => size_of::<i16>(),
779            DataType::Int32 => size_of::<i32>(),
780            DataType::Int64 => size_of::<i64>(),
781            DataType::Serial => size_of::<Serial>(),
782            DataType::Float32 => size_of::<f32>(),
783            DataType::Float64 => size_of::<f64>(),
784            DataType::Decimal => size_of::<Decimal>(),
785            DataType::Date => size_of::<Date>(),
786            DataType::Time => size_of::<Time>(),
787            DataType::Timestamp => size_of::<Timestamp>(),
788            DataType::Timestamptz => size_of::<Timestamptz>(),
789            DataType::Interval => size_of::<Interval>(),
790            DataType::Int256 => Int256::size(),
791            DataType::Varchar => 20,
792            DataType::Bytea => 20,
793            DataType::Jsonb => 20,
794            DataType::Struct { .. } => 20,
795            DataType::List { .. } => 20,
796            DataType::Map(_) => 20,
797            DataType::Vector(d) => d * size_of::<VectorDistanceType>(),
798        }
799    }
800
801    pub fn estimate(&mut self, predicate: &Condition) -> IndexCost {
802        // try to deal with OR condition
803        if predicate.conjunctions.len() == 1 {
804            self.visit_expr(&predicate.conjunctions[0]);
805            self.cost.take().unwrap_or_default()
806        } else {
807            self.estimate_conjunctions(&predicate.conjunctions)
808        }
809    }
810
811    fn estimate_conjunctions(&mut self, conjunctions: &[ExprImpl]) -> IndexCost {
812        let mut new_conjunctions = conjunctions.to_owned();
813
814        let mut match_item_vec = vec![];
815
816        for column_idx in self.table_scan.table().order_column_indices() {
817            let match_item = self.match_index_column(column_idx, &mut new_conjunctions);
818            // seeing range, we don't need to match anymore.
819            let should_break = match match_item {
820                MatchItem::Equal | MatchItem::In(_) => false,
821                MatchItem::RangeOneSideBound | MatchItem::RangeTwoSideBound | MatchItem::All => {
822                    true
823                }
824            };
825            match_item_vec.push(match_item);
826            if should_break {
827                break;
828            }
829        }
830
831        let index_cost = match_item_vec
832            .iter()
833            .enumerate()
834            .take(INDEX_MAX_LEN)
835            .map(|(i, match_item)| match match_item {
836                MatchItem::Equal => INDEX_COST_MATRIX[0][i],
837                MatchItem::In(num) => min(INDEX_COST_MATRIX[1][i], *num),
838                MatchItem::RangeTwoSideBound => INDEX_COST_MATRIX[2][i],
839                MatchItem::RangeOneSideBound => INDEX_COST_MATRIX[3][i],
840                MatchItem::All => INDEX_COST_MATRIX[4][i],
841            })
842            .reduce(|x, y| x * y)
843            .unwrap();
844
845        // If `index_cost` equals 1, it is a primary lookup
846        let primary_lookup = index_cost == 1;
847
848        IndexCost::new(index_cost, primary_lookup)
849            .mul(&IndexCost::new(self.row_size, primary_lookup))
850    }
851
852    fn match_index_column(
853        &mut self,
854        column_idx: usize,
855        conjunctions: &mut Vec<ExprImpl>,
856    ) -> MatchItem {
857        // Equal
858        for (i, expr) in conjunctions.iter().enumerate() {
859            if let Some((input_ref, _const_expr)) = expr.as_eq_const()
860                && input_ref.index == column_idx
861            {
862                conjunctions.remove(i);
863                return MatchItem::Equal;
864            }
865        }
866
867        // In
868        for (i, expr) in conjunctions.iter().enumerate() {
869            if let Some((input_ref, in_const_list)) = expr.as_in_const_list()
870                && input_ref.index == column_idx
871            {
872                conjunctions.remove(i);
873                return MatchItem::In(in_const_list.len());
874            }
875        }
876
877        // Range
878        let mut left_side_bound = false;
879        let mut right_side_bound = false;
880        let mut i = 0;
881        while i < conjunctions.len() {
882            let expr = &conjunctions[i];
883            if let Some((input_ref, op, _const_expr)) = expr.as_comparison_const()
884                && input_ref.index == column_idx
885            {
886                conjunctions.remove(i);
887                match op {
888                    ExprType::LessThan | ExprType::LessThanOrEqual => right_side_bound = true,
889                    ExprType::GreaterThan | ExprType::GreaterThanOrEqual => left_side_bound = true,
890                    _ => unreachable!(),
891                };
892            } else {
893                i += 1;
894            }
895        }
896
897        if left_side_bound && right_side_bound {
898            MatchItem::RangeTwoSideBound
899        } else if left_side_bound || right_side_bound {
900            MatchItem::RangeOneSideBound
901        } else {
902            MatchItem::All
903        }
904    }
905}
906
/// How a single index column is constrained by the conjunctions of a predicate.
enum MatchItem {
    /// The column is compared for equality with a constant.
    Equal,
    /// The column is tested against a constant list of the given length.
    In(usize),
    /// The column has both a lower and an upper bound.
    RangeTwoSideBound,
    /// The column has only a lower or only an upper bound.
    RangeOneSideBound,
    /// No usable condition on the column was found.
    All,
}
914
/// Estimated cost of a scan, capped at [`IndexCost::maximum`].
#[derive(PartialEq, Eq, Hash, Clone, Debug, PartialOrd, Ord)]
pub(crate) struct IndexCost {
    /// Cost value; constructors keep it at or below [`IndexCost::maximum`].
    cost: usize,
    /// Whether this cost was classified as a primary lookup.
    pub(crate) primary_lookup: bool,
}

impl Default for IndexCost {
    /// The default is the most expensive cost, not marked as a primary lookup.
    fn default() -> Self {
        Self {
            cost: Self::maximum(),
            primary_lookup: false,
        }
    }
}

impl IndexCost {
    /// Builds a cost, clamping `cost` to [`IndexCost::maximum`].
    fn new(cost: usize, primary_lookup: bool) -> IndexCost {
        let cost = cost.min(Self::maximum());
        Self {
            cost,
            primary_lookup,
        }
    }

    /// Upper limit of any cost value.
    fn maximum() -> usize {
        10000000
    }

    /// Sums two costs, clamped to the maximum.
    ///
    /// The result is a primary lookup only if both operands are.
    fn add(&self, other: &IndexCost) -> IndexCost {
        // Saturating on overflow is enough: `new` clamps to the maximum anyway.
        let total = self.cost.saturating_add(other.cost);
        Self::new(total, self.primary_lookup && other.primary_lookup)
    }

    /// Multiplies two costs, clamped to the maximum.
    ///
    /// The result is a primary lookup only if both operands are.
    fn mul(&self, other: &IndexCost) -> IndexCost {
        // Saturating on overflow is enough: `new` clamps to the maximum anyway.
        let scaled = self.cost.saturating_mul(other.cost);
        Self::new(scaled, self.primary_lookup && other.primary_lookup)
    }

    /// Returns `true` when this cost is strictly lower than `other`'s.
    ///
    /// Despite the name, equal costs yield `false`, and `primary_lookup` is not
    /// considered. In method-call syntax this inherent method is the one that
    /// gets picked, not the derived `PartialOrd::le`.
    pub(crate) fn le(&self, other: &IndexCost) -> bool {
        other.cost > self.cost
    }
}
964
965impl ExprVisitor for TableScanIoEstimator<'_> {
966    fn visit_function_call(&mut self, func_call: &FunctionCall) {
967        let cost = match func_call.func_type() {
968            ExprType::Or => func_call
969                .inputs()
970                .iter()
971                .map(|x| {
972                    let mut estimator = TableScanIoEstimator::new(self.table_scan, self.row_size);
973                    estimator.visit_expr(x);
974                    estimator.cost.take().unwrap_or_default()
975                })
976                .reduce(|x, y| x.add(&y))
977                .unwrap(),
978            ExprType::And => self.estimate_conjunctions(func_call.inputs()),
979            _ => {
980                let single = vec![ExprImpl::FunctionCall(func_call.clone().into())];
981                self.estimate_conjunctions(&single)
982            }
983        };
984        self.cost = Some(cost);
985    }
986}
987
988struct ShiftInputRefRewriter {
989    offset: usize,
990}
991impl ExprRewriter for ShiftInputRefRewriter {
992    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
993        InputRef::new(input_ref.index() + self.offset, input_ref.return_type()).into()
994    }
995}
996
997impl IndexSelectionRule {
998    pub fn create() -> BoxedRule {
999        Box::new(IndexSelectionRule {})
1000    }
1001}