risingwave_frontend/optimizer/plan_rewriter/
share_source_rewriter.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16
17use itertools::Itertools;
18
19use crate::PlanRef;
20use crate::catalog::SourceId;
21use crate::optimizer::plan_node::generic::GenericPlanRef;
22use crate::optimizer::plan_node::{
23    LogicalShare, LogicalSource, PlanNodeId, PlanTreeNode, StreamShare,
24};
25use crate::optimizer::plan_visitor::{DefaultBehavior, DefaultValue};
26use crate::optimizer::{PlanRewriter, PlanVisitor};
27
28#[derive(Debug, Clone, Default)]
29pub struct ShareSourceRewriter {
30    /// Source id need to be shared.
31    share_ids: HashSet<SourceId>,
32    /// Source id to share node.
33    share_source: HashMap<SourceId, PlanRef>,
34    /// Original share node plan id to new share node.
35    /// Rewriter will rewrite all nodes, but we need to keep the shape of the DAG.
36    share_map: HashMap<PlanNodeId, PlanRef>,
37}
38
39#[derive(Debug, Clone, Default)]
40struct SourceCounter {
41    /// Source id to count.
42    source_counter: HashMap<SourceId, usize>,
43}
44
45impl ShareSourceRewriter {
46    pub fn share_source(plan: PlanRef) -> PlanRef {
47        // Find which sources occurred more than once.
48        let mut source_counter = SourceCounter::default();
49        source_counter.visit(plan.clone());
50
51        let mut share_source_rewriter = ShareSourceRewriter {
52            share_ids: source_counter
53                .source_counter
54                .into_iter()
55                .filter(|(_, v)| *v > 1)
56                .map(|(k, _)| k)
57                .collect(),
58            share_source: Default::default(),
59            share_map: Default::default(),
60        };
61        // Rewrite source to share source
62        share_source_rewriter.rewrite(plan)
63    }
64}
65
66impl PlanRewriter for ShareSourceRewriter {
67    fn rewrite_logical_source(&mut self, source: &LogicalSource) -> PlanRef {
68        let source_id = match &source.core.catalog {
69            Some(s) => s.id,
70            None => {
71                return source.clone().into();
72            }
73        };
74        if !self.share_ids.contains(&source_id) {
75            let source_ref = source.clone().into();
76            return source_ref;
77        }
78        match self.share_source.get(&source_id) {
79            None => {
80                let source_ref = source.clone().into();
81                let share_source = LogicalShare::create(source_ref);
82                self.share_source.insert(source_id, share_source.clone());
83                share_source
84            }
85            Some(share_source) => share_source.clone(),
86        }
87    }
88
89    fn rewrite_logical_share(&mut self, share: &LogicalShare) -> PlanRef {
90        // When we use the plan rewriter, we need to take care of the share operator,
91        // because our plan is a DAG rather than a tree.
92        match self.share_map.get(&share.id()) {
93            None => {
94                let new_inputs = share
95                    .inputs()
96                    .into_iter()
97                    .map(|input| self.rewrite(input))
98                    .collect_vec();
99                let new_share = share.clone_with_inputs(&new_inputs);
100                self.share_map.insert(share.id(), new_share.clone());
101                new_share
102            }
103            Some(new_share) => new_share.clone(),
104        }
105    }
106
107    fn rewrite_stream_share(&mut self, _share: &StreamShare) -> PlanRef {
108        // We only access logical node here, so stream share is unreachable.
109        unreachable!()
110    }
111}
112
113impl PlanVisitor for SourceCounter {
114    type Result = ();
115
116    type DefaultBehavior = impl DefaultBehavior<Self::Result>;
117
118    fn default_behavior() -> Self::DefaultBehavior {
119        DefaultValue
120    }
121
122    fn visit_logical_source(&mut self, source: &LogicalSource) {
123        if let Some(source) = &source.core.catalog {
124            self.source_counter
125                .entry(source.id)
126                .and_modify(|count| *count += 1)
127                .or_insert(1);
128        }
129    }
130}