risingwave_frontend/optimizer/plan_node/
merge_eq_nodes.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashMap;
16use std::hash::Hash;
17
18use super::generic::GenericPlanRef;
19use super::{EndoPlan, LogicalShare, PlanNodeId, PlanRef, PlanTreeNodeUnary, VisitPlan};
20use crate::optimizer::plan_visitor;
21use crate::utils::{Endo, Visit};
22
/// Produces a hashable key describing what a plan computes.
///
/// [`PlanRef::common_subplan_sharing`] uses this key to decide which subplans are
/// interchangeable: plans whose keys compare equal are routed through one shared node.
pub trait Semantics<V>
where
    V: Hash + Eq,
{
    /// Returns the key that identifies this plan's semantics.
    fn semantics(&self) -> V;
}
26
27impl Semantics<PlanRef> for PlanRef {
28    fn semantics(&self) -> PlanRef {
29        self.clone()
30    }
31}
32
33impl PlanRef {
34    pub fn common_subplan_sharing<V: Hash + Eq>(self) -> PlanRef
35    where
36        PlanRef: Semantics<V>,
37    {
38        Merger::default().apply(self)
39    }
40}
41
42struct Merger<V: Hash + Eq> {
43    cache: HashMap<V, LogicalShare>,
44}
45
46impl<V: Hash + Eq> Default for Merger<V> {
47    fn default() -> Self {
48        Merger {
49            cache: Default::default(),
50        }
51    }
52}
53
54impl<V: Hash + Eq> Endo<PlanRef> for Merger<V>
55where
56    PlanRef: Semantics<V>,
57{
58    fn apply(&mut self, t: PlanRef) -> PlanRef {
59        let semantics = t.semantics();
60        let share = self.cache.get(&semantics).cloned().unwrap_or_else(|| {
61            let share = LogicalShare::new(self.tree_apply(t));
62            self.cache.entry(semantics).or_insert(share).clone()
63        });
64        share.into()
65    }
66}
67
68impl PlanRef {
69    pub fn prune_share(&self) -> PlanRef {
70        let mut counter = Counter::default();
71        counter.visit(self);
72        counter.to_pruner().apply(self.clone())
73    }
74}
75
76#[derive(Default)]
77struct Counter {
78    counts: HashMap<PlanNodeId, u64>,
79}
80
81impl Counter {
82    fn to_pruner(&self) -> Pruner<'_> {
83        Pruner {
84            counts: &self.counts,
85            cache: HashMap::new(),
86        }
87    }
88}
89
90impl VisitPlan for Counter {
91    fn visited<F>(&mut self, plan: &PlanRef, mut f: F)
92    where
93        F: FnMut(&mut Self),
94    {
95        if self.counts.get(&plan.id()).is_none_or(|c| *c <= 1) {
96            f(self);
97        }
98    }
99}
100
101impl Visit<PlanRef> for Counter {
102    fn visit(&mut self, t: &PlanRef) {
103        if let Some(s) = t.as_logical_share() {
104            self.counts
105                .entry(s.id())
106                .and_modify(|c| *c += 1)
107                .or_insert(1);
108        }
109        self.dag_visit(t);
110    }
111}
112
113struct Pruner<'a> {
114    counts: &'a HashMap<PlanNodeId, u64>,
115    cache: HashMap<PlanNodeId, PlanRef>,
116}
117
118impl EndoPlan for Pruner<'_> {
119    fn cached<F>(&mut self, plan: PlanRef, mut f: F) -> PlanRef
120    where
121        F: FnMut(&mut Self) -> PlanRef,
122    {
123        self.cache.get(&plan.id()).cloned().unwrap_or_else(|| {
124            let res = f(self);
125            self.cache.entry(plan.id()).or_insert(res).clone()
126        })
127    }
128}
129
130impl Endo<PlanRef> for Pruner<'_> {
131    fn pre(&mut self, t: PlanRef) -> PlanRef {
132        let prunable = |s: &&LogicalShare| {
133            // Prune if share node has only one parent
134            // or it just shares a scan
135            // or it doesn't share any scan or source.
136            *self.counts.get(&s.id()).expect("Unprocessed shared node.") == 1
137                || s.input().as_logical_scan().is_some()
138                || !(plan_visitor::has_logical_scan(s.input())
139                    || plan_visitor::has_logical_source(s.input()))
140        };
141        t.as_logical_share()
142            .filter(prunable)
143            .map_or(t.clone(), |s| self.pre(s.input()))
144    }
145
146    fn apply(&mut self, t: PlanRef) -> PlanRef {
147        self.dag_apply(t)
148    }
149}