// risingwave_frontend/optimizer/heuristic_optimizer.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashMap;
16use std::collections::hash_map::Entry;
17use std::fmt;
18
19use itertools::Itertools;
20
21use super::ApplyResult;
22#[cfg(debug_assertions)]
23use crate::Explain;
24use crate::error::Result;
25use crate::optimizer::plan_node::ConventionMarker;
26use crate::optimizer::rule::BoxedRule;
27use crate::optimizer::{PlanRef, PlanTreeNode};
28
/// Traverse order of [`HeuristicOptimizer`].
///
/// Decides whether the rules are tried on a node before or after the node's
/// inputs have been optimized.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApplyOrder {
    /// Try the rules on a node first, then recurse into its (possibly rewritten) inputs.
    TopDown,
    /// Recurse into a node's inputs first, then try the rules on the node itself.
    BottomUp,
}
34
35// TODO: we should have a builder of HeuristicOptimizer here
36/// A rule-based heuristic optimizer, which traverses every plan nodes and tries to
37/// apply each rule on them.
38pub struct HeuristicOptimizer<'a, C: ConventionMarker> {
39    apply_order: &'a ApplyOrder,
40    rules: &'a [BoxedRule<C>],
41    stats: Stats,
42}
43
44impl<'a, C: ConventionMarker> HeuristicOptimizer<'a, C> {
45    pub fn new(apply_order: &'a ApplyOrder, rules: &'a [BoxedRule<C>]) -> Self {
46        Self {
47            apply_order,
48            rules,
49            stats: Stats::new(),
50        }
51    }
52
53    fn optimize_node(&mut self, mut plan: PlanRef<C>) -> Result<PlanRef<C>> {
54        for rule in self.rules {
55            match rule.apply(plan.clone()) {
56                ApplyResult::Ok(applied) => {
57                    #[cfg(debug_assertions)]
58                    Self::check_equivalent_plan(rule.description(), &plan, &applied);
59
60                    plan = applied;
61                    self.stats.count_rule(rule);
62                }
63                ApplyResult::NotApplicable => {}
64                ApplyResult::Err(error) => return Err(error),
65            }
66        }
67        Ok(plan)
68    }
69
70    fn optimize_inputs(&mut self, plan: PlanRef<C>) -> Result<PlanRef<C>> {
71        let pre_applied = self.stats.total_applied();
72        let inputs: Vec<_> = plan
73            .inputs()
74            .into_iter()
75            .map(|sub_tree| self.optimize(sub_tree))
76            .try_collect()?;
77
78        Ok(if pre_applied != self.stats.total_applied() {
79            plan.clone_root_with_inputs(&inputs)
80        } else {
81            plan
82        })
83    }
84
85    pub fn optimize(&mut self, mut plan: PlanRef<C>) -> Result<PlanRef<C>> {
86        match self.apply_order {
87            ApplyOrder::TopDown => {
88                plan = self.optimize_node(plan)?;
89                self.optimize_inputs(plan)
90            }
91            ApplyOrder::BottomUp => {
92                plan = self.optimize_inputs(plan)?;
93                self.optimize_node(plan)
94            }
95        }
96    }
97
98    pub fn get_stats(&self) -> &Stats {
99        &self.stats
100    }
101
102    #[cfg(debug_assertions)]
103    pub fn check_equivalent_plan(
104        rule_desc: &str,
105        input_plan: &PlanRef<C>,
106        output_plan: &PlanRef<C>,
107    ) {
108        use crate::optimizer::plan_node::generic::GenericPlanRef;
109        if !input_plan.schema().type_eq(output_plan.schema()) {
110            panic!(
111                "{} fails to generate equivalent plan.\nInput schema: {:?}\nInput plan: \n{}\nOutput schema: {:?}\nOutput plan: \n{}\nSQL: {}",
112                rule_desc,
113                input_plan.schema(),
114                input_plan.explain_to_string(),
115                output_plan.schema(),
116                output_plan.explain_to_string(),
117                output_plan.ctx().sql()
118            );
119        }
120    }
121}
122
/// Counters describing which rules a heuristic optimizer pass applied.
///
/// `Default` yields the same empty state as `Stats::new()`: zero applications
/// and an empty per-rule map.
#[derive(Debug, Default)]
pub struct Stats {
    /// Number of successful rule applications, summed over all rules.
    total_applied: usize,
    /// Per-rule application counts, keyed by the rule's description.
    rule_counter: HashMap<String, u32>,
}
127
128impl Stats {
129    pub fn new() -> Self {
130        Self {
131            rule_counter: HashMap::new(),
132            total_applied: 0,
133        }
134    }
135
136    pub fn count_rule(&mut self, rule: &BoxedRule<impl ConventionMarker>) {
137        self.total_applied += 1;
138        match self.rule_counter.entry(rule.description().to_owned()) {
139            Entry::Occupied(mut entry) => {
140                *entry.get_mut() += 1;
141            }
142            Entry::Vacant(entry) => {
143                entry.insert(1);
144            }
145        }
146    }
147
148    pub fn has_applied_rule(&self) -> bool {
149        self.total_applied != 0
150    }
151
152    pub fn total_applied(&self) -> usize {
153        self.total_applied
154    }
155}
156
157impl fmt::Display for Stats {
158    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159        for (rule, count) in &self.rule_counter {
160            writeln!(f, "apply {} {} time(s)", rule, count)?;
161        }
162        Ok(())
163    }
164}