risingwave_frontend/optimizer/plan_rewriter/
share_source_rewriter.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16
17use crate::catalog::SourceId;
18use crate::optimizer::PlanVisitor;
19use crate::optimizer::plan_node::{
20    Logical, LogicalPlanRef as PlanRef, LogicalShare, LogicalSource,
21};
22use crate::optimizer::plan_rewriter::PlanRewriter;
23use crate::optimizer::plan_visitor::{DefaultBehavior, DefaultValue, LogicalPlanVisitor};
24
25#[derive(Debug, Clone, Default)]
26pub struct ShareSourceRewriter {
27    /// Source id need to be shared.
28    share_ids: HashSet<SourceId>,
29    /// Source id to share node.
30    share_source: HashMap<SourceId, PlanRef>,
31}
32
33#[derive(Debug, Clone, Default)]
34struct SourceCounter {
35    /// Source id to count.
36    source_counter: HashMap<SourceId, usize>,
37}
38
39impl ShareSourceRewriter {
40    pub fn share_source(plan: PlanRef) -> PlanRef {
41        // Find which sources occurred more than once.
42        let mut source_counter = SourceCounter::default();
43        source_counter.visit(plan.clone());
44
45        let mut share_source_rewriter = ShareSourceRewriter {
46            share_ids: source_counter
47                .source_counter
48                .into_iter()
49                .filter(|(_, v)| *v > 1)
50                .map(|(k, _)| k)
51                .collect(),
52            share_source: Default::default(),
53        };
54        // Rewrite source to share source
55        plan.rewrite_with(&mut share_source_rewriter)
56    }
57
58    fn rewrite_logical_source(&mut self, source: &LogicalSource) -> PlanRef {
59        let source_id = match &source.core.catalog {
60            Some(s) => s.id,
61            None => {
62                return source.clone().into();
63            }
64        };
65        if !self.share_ids.contains(&source_id) {
66            let source_ref = source.clone().into();
67            return source_ref;
68        }
69        match self.share_source.get(&source_id) {
70            None => {
71                let source_ref = source.clone().into();
72                let share_source = LogicalShare::create(source_ref);
73                self.share_source.insert(source_id, share_source.clone());
74                share_source
75            }
76            Some(share_source) => share_source.clone(),
77        }
78    }
79}
80
81impl PlanRewriter<Logical> for ShareSourceRewriter {
82    fn rewrite_with_inputs(&mut self, plan: &PlanRef, inputs: Vec<PlanRef>) -> PlanRef {
83        if let Some(source) = plan.as_logical_source() {
84            self.rewrite_logical_source(source)
85        } else {
86            plan.clone_root_with_inputs(&inputs)
87        }
88    }
89}
90
91impl LogicalPlanVisitor for SourceCounter {
92    type Result = ();
93
94    type DefaultBehavior = impl DefaultBehavior<Self::Result>;
95
96    fn default_behavior() -> Self::DefaultBehavior {
97        DefaultValue
98    }
99
100    fn visit_logical_source(&mut self, source: &LogicalSource) {
101        if let Some(source) = &source.core.catalog {
102            self.source_counter
103                .entry(source.id)
104                .and_modify(|count| *count += 1)
105                .or_insert(1);
106        }
107    }
108}