risingwave_frontend/optimizer/heuristic_optimizer.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashMap;
16use std::collections::hash_map::Entry;
17use std::fmt;
18
19use itertools::Itertools;
20
21use super::ApplyResult;
22#[cfg(debug_assertions)]
23use crate::Explain;
24use crate::error::Result;
25use crate::optimizer::PlanRef;
26use crate::optimizer::plan_node::PlanTreeNode;
27use crate::optimizer::rule::BoxedRule;
28
/// Traverse order of [`HeuristicOptimizer`].
///
/// Decides whether the rules are tried on a plan node before or after the node's inputs
/// have been optimized.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApplyOrder {
    /// Try the rules on a node first, then recurse into the inputs of the (possibly
    /// rewritten) node.
    TopDown,
    /// Optimize a node's inputs first, then try the rules on the node itself.
    BottomUp,
}
34
35// TODO: we should have a builder of HeuristicOptimizer here
36/// A rule-based heuristic optimizer, which traverses every plan nodes and tries to
37/// apply each rule on them.
38pub struct HeuristicOptimizer<'a> {
39    apply_order: &'a ApplyOrder,
40    rules: &'a [BoxedRule],
41    stats: Stats,
42}
43
44impl<'a> HeuristicOptimizer<'a> {
45    pub fn new(apply_order: &'a ApplyOrder, rules: &'a [BoxedRule]) -> Self {
46        Self {
47            apply_order,
48            rules,
49            stats: Stats::new(),
50        }
51    }
52
53    fn optimize_node(&mut self, mut plan: PlanRef) -> Result<PlanRef> {
54        for rule in self.rules {
55            match rule.apply(plan.clone()) {
56                ApplyResult::Ok(applied) => {
57                    #[cfg(debug_assertions)]
58                    Self::check_equivalent_plan(rule.description(), &plan, &applied);
59
60                    plan = applied;
61                    self.stats.count_rule(rule);
62                }
63                ApplyResult::NotApplicable => {}
64                ApplyResult::Err(error) => return Err(error),
65            }
66        }
67        Ok(plan)
68    }
69
70    fn optimize_inputs(&mut self, plan: PlanRef) -> Result<PlanRef> {
71        let pre_applied = self.stats.total_applied();
72        let inputs: Vec<_> = plan
73            .inputs()
74            .into_iter()
75            .map(|sub_tree| self.optimize(sub_tree))
76            .try_collect()?;
77
78        Ok(if pre_applied != self.stats.total_applied() {
79            plan.clone_with_inputs(&inputs)
80        } else {
81            plan
82        })
83    }
84
85    pub fn optimize(&mut self, mut plan: PlanRef) -> Result<PlanRef> {
86        match self.apply_order {
87            ApplyOrder::TopDown => {
88                plan = self.optimize_node(plan)?;
89                self.optimize_inputs(plan)
90            }
91            ApplyOrder::BottomUp => {
92                plan = self.optimize_inputs(plan)?;
93                self.optimize_node(plan)
94            }
95        }
96    }
97
98    pub fn get_stats(&self) -> &Stats {
99        &self.stats
100    }
101
102    #[cfg(debug_assertions)]
103    pub fn check_equivalent_plan(rule_desc: &str, input_plan: &PlanRef, output_plan: &PlanRef) {
104        if !input_plan.schema().type_eq(output_plan.schema()) {
105            panic!(
106                "{} fails to generate equivalent plan.\nInput schema: {:?}\nInput plan: \n{}\nOutput schema: {:?}\nOutput plan: \n{}\nSQL: {}",
107                rule_desc,
108                input_plan.schema(),
109                input_plan.explain_to_string(),
110                output_plan.schema(),
111                output_plan.explain_to_string(),
112                output_plan.ctx().sql()
113            );
114        }
115    }
116}
117
/// Counters describing how often the rules of a [`HeuristicOptimizer`] fired.
///
/// `Stats::default()` yields the same empty state as `Stats::new()`: zero applications
/// and an empty per-rule map.
#[derive(Debug, Default)]
pub struct Stats {
    /// Total number of recorded rule applications, summed over all rules.
    total_applied: usize,
    /// Recorded applications per rule, keyed by the rule's description.
    rule_counter: HashMap<String, u32>,
}
122
123impl Stats {
124    pub fn new() -> Self {
125        Self {
126            rule_counter: HashMap::new(),
127            total_applied: 0,
128        }
129    }
130
131    pub fn count_rule(&mut self, rule: &BoxedRule) {
132        self.total_applied += 1;
133        match self.rule_counter.entry(rule.description().to_owned()) {
134            Entry::Occupied(mut entry) => {
135                *entry.get_mut() += 1;
136            }
137            Entry::Vacant(entry) => {
138                entry.insert(1);
139            }
140        }
141    }
142
143    pub fn has_applied_rule(&self) -> bool {
144        self.total_applied != 0
145    }
146
147    pub fn total_applied(&self) -> usize {
148        self.total_applied
149    }
150}
151
152impl fmt::Display for Stats {
153    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
154        for (rule, count) in &self.rule_counter {
155            writeln!(f, "apply {} {} time(s)", rule, count)?;
156        }
157        Ok(())
158    }
159}