risingwave_frontend/optimizer/plan_node/
merge_eq_nodes.rs1use std::collections::HashMap;
16use std::hash::Hash;
17
18use super::generic::GenericPlanRef;
19use super::{
20 EndoPlan, LogicalPlanRef as PlanRef, LogicalShare, PlanNodeId, PlanTreeNodeUnary, VisitPlan,
21};
22use crate::optimizer::plan_visitor;
23use crate::utils::{Endo, Visit};
24
/// Maps a value to a key describing what it "means".
///
/// `common_subplan_sharing` merges plan nodes whose keys compare equal. The key
/// type `V` must be `Hash + Eq` because it is used as a `HashMap` key by that pass.
pub trait Semantics<V>
where
    V: Hash + Eq,
{
    /// Returns the semantic key for `self`.
    fn semantics(&self) -> V;
}
28
29impl Semantics<PlanRef> for PlanRef {
30 fn semantics(&self) -> PlanRef {
31 self.clone()
32 }
33}
34
35impl PlanRef {
36 pub fn common_subplan_sharing<V: Hash + Eq>(self) -> PlanRef
37 where
38 PlanRef: Semantics<V>,
39 {
40 Merger::default().apply(self)
41 }
42}
43
44struct Merger<V: Hash + Eq> {
45 cache: HashMap<V, LogicalShare>,
46}
47
48impl<V: Hash + Eq> Default for Merger<V> {
49 fn default() -> Self {
50 Merger {
51 cache: Default::default(),
52 }
53 }
54}
55
56impl<V: Hash + Eq> Endo<PlanRef> for Merger<V>
57where
58 PlanRef: Semantics<V>,
59{
60 fn apply(&mut self, t: PlanRef) -> PlanRef {
61 let semantics = t.semantics();
62 let share = self.cache.get(&semantics).cloned().unwrap_or_else(|| {
63 let share = LogicalShare::new(self.tree_apply(t));
64 self.cache.entry(semantics).or_insert(share).clone()
65 });
66 share.into()
67 }
68}
69
70impl PlanRef {
71 pub fn prune_share(&self) -> PlanRef {
72 let mut counter = Counter::default();
73 counter.visit(self);
74 counter.to_pruner().apply(self.clone())
75 }
76}
77
78#[derive(Default)]
79struct Counter {
80 counts: HashMap<PlanNodeId, u64>,
81}
82
83impl Counter {
84 fn to_pruner(&self) -> Pruner<'_> {
85 Pruner {
86 counts: &self.counts,
87 cache: HashMap::new(),
88 }
89 }
90}
91
92impl VisitPlan for Counter {
93 fn visited<F>(&mut self, plan: &PlanRef, mut f: F)
94 where
95 F: FnMut(&mut Self),
96 {
97 if self.counts.get(&plan.id()).is_none_or(|c| *c <= 1) {
98 f(self);
99 }
100 }
101}
102
103impl Visit<PlanRef> for Counter {
104 fn visit(&mut self, t: &PlanRef) {
105 if let Some(s) = t.as_logical_share() {
106 self.counts
107 .entry(s.id())
108 .and_modify(|c| *c += 1)
109 .or_insert(1);
110 }
111 self.dag_visit(t);
112 }
113}
114
115struct Pruner<'a> {
116 counts: &'a HashMap<PlanNodeId, u64>,
117 cache: HashMap<PlanNodeId, PlanRef>,
118}
119
120impl EndoPlan for Pruner<'_> {
121 fn cached<F>(&mut self, plan: PlanRef, mut f: F) -> PlanRef
122 where
123 F: FnMut(&mut Self) -> PlanRef,
124 {
125 self.cache.get(&plan.id()).cloned().unwrap_or_else(|| {
126 let res = f(self);
127 self.cache.entry(plan.id()).or_insert(res).clone()
128 })
129 }
130}
131
132impl Endo<PlanRef> for Pruner<'_> {
133 fn pre(&mut self, t: PlanRef) -> PlanRef {
134 let prunable = |s: &&LogicalShare| {
135 *self.counts.get(&s.id()).expect("Unprocessed shared node.") == 1
139 || s.input().as_logical_scan().is_some()
140 || !(plan_visitor::has_logical_scan(s.input())
141 || plan_visitor::has_logical_source(s.input()))
142 };
143 t.as_logical_share()
144 .filter(prunable)
145 .map_or(t.clone(), |s| self.pre(s.input()))
146 }
147
148 fn apply(&mut self, t: PlanRef) -> PlanRef {
149 self.dag_apply(t)
150 }
151}