risingwave_frontend/optimizer/plan_node/merge_eq_nodes.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::hash::Hash;
17
18use super::generic::GenericPlanRef;
19use super::{
20    EndoPlan, LogicalPlanRef as PlanRef, LogicalShare, PlanNodeId, PlanTreeNodeUnary, VisitPlan,
21};
22use crate::optimizer::plan_visitor;
23use crate::utils::{Endo, Visit};
24
/// Maps a value to a hashable key used to recognise semantically equivalent plans.
///
/// Values whose keys compare equal are treated as interchangeable by the
/// common-subplan merger, which caches one share node per key.
pub trait Semantics<V: Hash + Eq> {
    /// Returns the deduplication key for `self`.
    fn semantics(&self) -> V;
}
28
29impl Semantics<PlanRef> for PlanRef {
30    fn semantics(&self) -> PlanRef {
31        self.clone()
32    }
33}
34
35impl PlanRef {
36    pub fn common_subplan_sharing<V: Hash + Eq>(self) -> PlanRef
37    where
38        PlanRef: Semantics<V>,
39    {
40        Merger::default().apply(self)
41    }
42}
43
44struct Merger<V: Hash + Eq> {
45    cache: HashMap<V, LogicalShare>,
46}
47
48impl<V: Hash + Eq> Default for Merger<V> {
49    fn default() -> Self {
50        Merger {
51            cache: Default::default(),
52        }
53    }
54}
55
56impl<V: Hash + Eq> Endo<PlanRef> for Merger<V>
57where
58    PlanRef: Semantics<V>,
59{
60    fn apply(&mut self, t: PlanRef) -> PlanRef {
61        let semantics = t.semantics();
62        let share = self.cache.get(&semantics).cloned().unwrap_or_else(|| {
63            let share = LogicalShare::new(self.tree_apply(t));
64            self.cache.entry(semantics).or_insert(share).clone()
65        });
66        share.into()
67    }
68}
69
70impl PlanRef {
71    pub fn prune_share(&self) -> PlanRef {
72        let mut counter = Counter::default();
73        counter.visit(self);
74        counter.to_pruner().apply(self.clone())
75    }
76}
77
78#[derive(Default)]
79struct Counter {
80    counts: HashMap<PlanNodeId, u64>,
81}
82
83impl Counter {
84    fn to_pruner(&self) -> Pruner<'_> {
85        Pruner {
86            counts: &self.counts,
87            cache: HashMap::new(),
88        }
89    }
90}
91
92impl VisitPlan for Counter {
93    fn visited<F>(&mut self, plan: &PlanRef, mut f: F)
94    where
95        F: FnMut(&mut Self),
96    {
97        if self.counts.get(&plan.id()).is_none_or(|c| *c <= 1) {
98            f(self);
99        }
100    }
101}
102
103impl Visit<PlanRef> for Counter {
104    fn visit(&mut self, t: &PlanRef) {
105        if let Some(s) = t.as_logical_share() {
106            self.counts
107                .entry(s.id())
108                .and_modify(|c| *c += 1)
109                .or_insert(1);
110        }
111        self.dag_visit(t);
112    }
113}
114
115struct Pruner<'a> {
116    counts: &'a HashMap<PlanNodeId, u64>,
117    cache: HashMap<PlanNodeId, PlanRef>,
118}
119
120impl EndoPlan for Pruner<'_> {
121    fn cached<F>(&mut self, plan: PlanRef, mut f: F) -> PlanRef
122    where
123        F: FnMut(&mut Self) -> PlanRef,
124    {
125        self.cache.get(&plan.id()).cloned().unwrap_or_else(|| {
126            let res = f(self);
127            self.cache.entry(plan.id()).or_insert(res).clone()
128        })
129    }
130}
131
132impl Endo<PlanRef> for Pruner<'_> {
133    fn pre(&mut self, t: PlanRef) -> PlanRef {
134        let prunable = |s: &&LogicalShare| {
135            // Prune if share node has only one parent
136            // or it just shares a scan
137            // or it doesn't share any scan or source.
138            *self.counts.get(&s.id()).expect("Unprocessed shared node.") == 1
139                || s.input().as_logical_scan().is_some()
140                || !(plan_visitor::has_logical_scan(s.input())
141                    || plan_visitor::has_logical_source(s.input()))
142        };
143        t.as_logical_share()
144            .filter(prunable)
145            .map_or(t.clone(), |s| self.pre(s.input()))
146    }
147
148    fn apply(&mut self, t: PlanRef) -> PlanRef {
149        self.dag_apply(t)
150    }
151}