risingwave_frontend/optimizer/
heuristic_optimizer.rs1use std::collections::HashMap;
16use std::collections::hash_map::Entry;
17use std::fmt;
18
19use itertools::Itertools;
20
21use super::ApplyResult;
22#[cfg(debug_assertions)]
23use crate::Explain;
24use crate::error::Result;
25use crate::optimizer::PlanRef;
26use crate::optimizer::plan_node::PlanTreeNode;
27use crate::optimizer::rule::BoxedRule;
28
/// Order in which [`HeuristicOptimizer`] visits the nodes of a plan tree.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApplyOrder {
    /// Rewrite a node with the rules first, then recurse into its (possibly rewritten) inputs.
    TopDown,
    /// Recurse into the inputs first, then rewrite the node rebuilt from the optimized inputs.
    BottomUp,
}
34
35pub struct HeuristicOptimizer<'a> {
39 apply_order: &'a ApplyOrder,
40 rules: &'a [BoxedRule],
41 stats: Stats,
42}
43
44impl<'a> HeuristicOptimizer<'a> {
45 pub fn new(apply_order: &'a ApplyOrder, rules: &'a [BoxedRule]) -> Self {
46 Self {
47 apply_order,
48 rules,
49 stats: Stats::new(),
50 }
51 }
52
53 fn optimize_node(&mut self, mut plan: PlanRef) -> Result<PlanRef> {
54 for rule in self.rules {
55 match rule.apply(plan.clone()) {
56 ApplyResult::Ok(applied) => {
57 #[cfg(debug_assertions)]
58 Self::check_equivalent_plan(rule.description(), &plan, &applied);
59
60 plan = applied;
61 self.stats.count_rule(rule);
62 }
63 ApplyResult::NotApplicable => {}
64 ApplyResult::Err(error) => return Err(error),
65 }
66 }
67 Ok(plan)
68 }
69
70 fn optimize_inputs(&mut self, plan: PlanRef) -> Result<PlanRef> {
71 let pre_applied = self.stats.total_applied();
72 let inputs: Vec<_> = plan
73 .inputs()
74 .into_iter()
75 .map(|sub_tree| self.optimize(sub_tree))
76 .try_collect()?;
77
78 Ok(if pre_applied != self.stats.total_applied() {
79 plan.clone_with_inputs(&inputs)
80 } else {
81 plan
82 })
83 }
84
85 pub fn optimize(&mut self, mut plan: PlanRef) -> Result<PlanRef> {
86 match self.apply_order {
87 ApplyOrder::TopDown => {
88 plan = self.optimize_node(plan)?;
89 self.optimize_inputs(plan)
90 }
91 ApplyOrder::BottomUp => {
92 plan = self.optimize_inputs(plan)?;
93 self.optimize_node(plan)
94 }
95 }
96 }
97
98 pub fn get_stats(&self) -> &Stats {
99 &self.stats
100 }
101
102 #[cfg(debug_assertions)]
103 pub fn check_equivalent_plan(rule_desc: &str, input_plan: &PlanRef, output_plan: &PlanRef) {
104 if !input_plan.schema().type_eq(output_plan.schema()) {
105 panic!(
106 "{} fails to generate equivalent plan.\nInput schema: {:?}\nInput plan: \n{}\nOutput schema: {:?}\nOutput plan: \n{}\nSQL: {}",
107 rule_desc,
108 input_plan.schema(),
109 input_plan.explain_to_string(),
110 output_plan.schema(),
111 output_plan.explain_to_string(),
112 output_plan.ctx().sql()
113 );
114 }
115 }
116}
117
/// Bookkeeping of how many times rules fired during heuristic optimization.
///
/// `Stats::default()` yields the same empty state as `Stats::new()`.
#[derive(Debug, Default)]
pub struct Stats {
    /// Total number of successful rule applications, across all rules.
    total_applied: usize,
    /// Per-rule application counts, keyed by the rule's description string.
    rule_counter: HashMap<String, u32>,
}
122
123impl Stats {
124 pub fn new() -> Self {
125 Self {
126 rule_counter: HashMap::new(),
127 total_applied: 0,
128 }
129 }
130
131 pub fn count_rule(&mut self, rule: &BoxedRule) {
132 self.total_applied += 1;
133 match self.rule_counter.entry(rule.description().to_owned()) {
134 Entry::Occupied(mut entry) => {
135 *entry.get_mut() += 1;
136 }
137 Entry::Vacant(entry) => {
138 entry.insert(1);
139 }
140 }
141 }
142
143 pub fn has_applied_rule(&self) -> bool {
144 self.total_applied != 0
145 }
146
147 pub fn total_applied(&self) -> usize {
148 self.total_applied
149 }
150}
151
152impl fmt::Display for Stats {
153 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
154 for (rule, count) in &self.rule_counter {
155 writeln!(f, "apply {} {} time(s)", rule, count)?;
156 }
157 Ok(())
158 }
159}