// risingwave_frontend/optimizer/heuristic_optimizer.rs
use std::collections::HashMap;
16use std::collections::hash_map::Entry;
17use std::fmt;
18
19use itertools::Itertools;
20
21use super::ApplyResult;
22#[cfg(debug_assertions)]
23use crate::Explain;
24use crate::error::Result;
25use crate::optimizer::plan_node::ConventionMarker;
26use crate::optimizer::rule::BoxedRule;
27use crate::optimizer::{PlanRef, PlanTreeNode};
28
/// Order in which a plan node is visited relative to its inputs during optimization.
///
/// The common traits are derived so the order can be debug-printed, compared and
/// passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApplyOrder {
    /// Apply the rules to a node first, then optimize its inputs.
    TopDown,
    /// Optimize the inputs of a node first, then apply the rules to the node itself.
    BottomUp,
}
34
35pub struct HeuristicOptimizer<'a, C: ConventionMarker> {
39 apply_order: &'a ApplyOrder,
40 rules: &'a [BoxedRule<C>],
41 stats: Stats,
42}
43
44impl<'a, C: ConventionMarker> HeuristicOptimizer<'a, C> {
45 pub fn new(apply_order: &'a ApplyOrder, rules: &'a [BoxedRule<C>]) -> Self {
46 Self {
47 apply_order,
48 rules,
49 stats: Stats::new(),
50 }
51 }
52
53 fn optimize_node(&mut self, mut plan: PlanRef<C>) -> Result<PlanRef<C>> {
54 for rule in self.rules {
55 match rule.apply(plan.clone()) {
56 ApplyResult::Ok(applied) => {
57 #[cfg(debug_assertions)]
58 Self::check_equivalent_plan(rule.description(), &plan, &applied);
59
60 plan = applied;
61 self.stats.count_rule(rule);
62 }
63 ApplyResult::NotApplicable => {}
64 ApplyResult::Err(error) => return Err(error),
65 }
66 }
67 Ok(plan)
68 }
69
70 fn optimize_inputs(&mut self, plan: PlanRef<C>) -> Result<PlanRef<C>> {
71 let pre_applied = self.stats.total_applied();
72 let inputs: Vec<_> = plan
73 .inputs()
74 .into_iter()
75 .map(|sub_tree| self.optimize(sub_tree))
76 .try_collect()?;
77
78 Ok(if pre_applied != self.stats.total_applied() {
79 plan.clone_root_with_inputs(&inputs)
80 } else {
81 plan
82 })
83 }
84
85 pub fn optimize(&mut self, mut plan: PlanRef<C>) -> Result<PlanRef<C>> {
86 match self.apply_order {
87 ApplyOrder::TopDown => {
88 plan = self.optimize_node(plan)?;
89 self.optimize_inputs(plan)
90 }
91 ApplyOrder::BottomUp => {
92 plan = self.optimize_inputs(plan)?;
93 self.optimize_node(plan)
94 }
95 }
96 }
97
98 pub fn get_stats(&self) -> &Stats {
99 &self.stats
100 }
101
102 #[cfg(debug_assertions)]
103 pub fn check_equivalent_plan(
104 rule_desc: &str,
105 input_plan: &PlanRef<C>,
106 output_plan: &PlanRef<C>,
107 ) {
108 use crate::optimizer::plan_node::generic::GenericPlanRef;
109 if !input_plan.schema().type_eq(output_plan.schema()) {
110 panic!(
111 "{} fails to generate equivalent plan.\nInput schema: {:?}\nInput plan: \n{}\nOutput schema: {:?}\nOutput plan: \n{}\nSQL: {}",
112 rule_desc,
113 input_plan.schema(),
114 input_plan.explain_to_string(),
115 output_plan.schema(),
116 output_plan.explain_to_string(),
117 output_plan.ctx().sql()
118 );
119 }
120 }
121}
122
/// Counters describing which rules have been applied.
///
/// `Default` yields the same empty state as `Stats::new()`, and `Debug` allows the
/// counters to be inspected in logs and assertions.
#[derive(Debug, Default)]
pub struct Stats {
    /// Total number of rule applications recorded, across all rules.
    total_applied: usize,
    /// Number of recorded applications per rule, keyed by the rule description.
    rule_counter: HashMap<String, u32>,
}
127
128impl Stats {
129 pub fn new() -> Self {
130 Self {
131 rule_counter: HashMap::new(),
132 total_applied: 0,
133 }
134 }
135
136 pub fn count_rule(&mut self, rule: &BoxedRule<impl ConventionMarker>) {
137 self.total_applied += 1;
138 match self.rule_counter.entry(rule.description().to_owned()) {
139 Entry::Occupied(mut entry) => {
140 *entry.get_mut() += 1;
141 }
142 Entry::Vacant(entry) => {
143 entry.insert(1);
144 }
145 }
146 }
147
148 pub fn has_applied_rule(&self) -> bool {
149 self.total_applied != 0
150 }
151
152 pub fn total_applied(&self) -> usize {
153 self.total_applied
154 }
155}
156
157impl fmt::Display for Stats {
158 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159 for (rule, count) in &self.rule_counter {
160 writeln!(f, "apply {} {} time(s)", rule, count)?;
161 }
162 Ok(())
163 }
164}