risingwave_common/util/
stream_graph_visitor.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use itertools::Itertools;
16use risingwave_pb::catalog::Table;
17use risingwave_pb::stream_plan::stream_fragment_graph::StreamFragment;
18use risingwave_pb::stream_plan::stream_node::NodeBody;
19use risingwave_pb::stream_plan::{StreamNode, agg_call_state};
20
21/// A utility for visiting and mutating the [`NodeBody`] of the [`StreamNode`]s recursively.
22pub fn visit_stream_node_mut(stream_node: &mut StreamNode, mut f: impl FnMut(&mut NodeBody)) {
23    visit_stream_node_cont_mut(stream_node, |stream_node| {
24        f(stream_node.node_body.as_mut().unwrap());
25        true
26    })
27}
28
29/// A utility for to accessing the [`StreamNode`] mutably. The returned bool is used to determine whether the access needs to continue.
30pub fn visit_stream_node_cont_mut<F>(stream_node: &mut StreamNode, mut f: F)
31where
32    F: FnMut(&mut StreamNode) -> bool,
33{
34    fn visit_inner<F>(stream_node: &mut StreamNode, f: &mut F)
35    where
36        F: FnMut(&mut StreamNode) -> bool,
37    {
38        if !f(stream_node) {
39            return;
40        }
41        for input in &mut stream_node.input {
42            visit_inner(input, f);
43        }
44    }
45
46    visit_inner(stream_node, &mut f)
47}
48
49/// A utility for visiting the [`NodeBody`] of the [`StreamNode`]s recursively.
50pub fn visit_stream_node(stream_node: &StreamNode, mut f: impl FnMut(&NodeBody)) {
51    visit_stream_node_cont(stream_node, |stream_node| {
52        f(stream_node.node_body.as_ref().unwrap());
53        true
54    })
55}
56
57/// A utility for to accessing the [`StreamNode`] immutably. The returned bool is used to determine whether the access needs to continue.
58pub fn visit_stream_node_cont<F>(stream_node: &StreamNode, mut f: F)
59where
60    F: FnMut(&StreamNode) -> bool,
61{
62    fn visit_inner<F>(stream_node: &StreamNode, f: &mut F)
63    where
64        F: FnMut(&StreamNode) -> bool,
65    {
66        if !f(stream_node) {
67            return;
68        }
69        for input in &stream_node.input {
70            visit_inner(input, f);
71        }
72    }
73
74    visit_inner(stream_node, &mut f)
75}
76
77/// A utility for visiting and mutating the [`NodeBody`] of the [`StreamNode`]s in a
78/// [`StreamFragment`] recursively.
79pub fn visit_fragment_mut(fragment: &mut StreamFragment, f: impl FnMut(&mut NodeBody)) {
80    visit_stream_node_mut(fragment.node.as_mut().unwrap(), f)
81}
82
83/// A utility for visiting the [`NodeBody`] of the [`StreamNode`]s in a
84/// [`StreamFragment`] recursively.
85pub fn visit_fragment(fragment: &StreamFragment, f: impl FnMut(&NodeBody)) {
86    visit_stream_node(fragment.node.as_ref().unwrap(), f)
87}
88
89/// Visit the tables of a [`StreamNode`].
90pub fn visit_stream_node_tables_inner<F>(
91    stream_node: &mut StreamNode,
92    internal_tables_only: bool,
93    visit_child_recursively: bool,
94    mut f: F,
95) where
96    F: FnMut(&mut Table, &str),
97{
98    macro_rules! always {
99        ($table:expr, $name:expr) => {{
100            let table = $table
101                .as_mut()
102                .unwrap_or_else(|| panic!("internal table {} should always exist", $name));
103            f(table, $name);
104        }};
105    }
106
107    macro_rules! optional {
108        ($table:expr, $name:expr) => {
109            if let Some(table) = &mut $table {
110                f(table, $name);
111            }
112        };
113    }
114
115    macro_rules! repeated {
116        ($tables:expr, $name:expr) => {
117            for table in &mut $tables {
118                f(table, $name);
119            }
120        };
121    }
122
123    let mut visit_body = |body: &mut NodeBody| {
124        match body {
125            // Join
126            NodeBody::HashJoin(node) => {
127                // TODO: make the degree table optional
128                always!(node.left_table, "HashJoinLeft");
129                always!(node.left_degree_table, "HashJoinDegreeLeft");
130                always!(node.right_table, "HashJoinRight");
131                always!(node.right_degree_table, "HashJoinDegreeRight");
132            }
133            NodeBody::TemporalJoin(node) => {
134                optional!(node.memo_table, "TemporalJoinMemo");
135            }
136            NodeBody::DynamicFilter(node) => {
137                always!(node.left_table, "DynamicFilterLeft");
138                always!(node.right_table, "DynamicFilterRight");
139            }
140
141            // Aggregation
142            NodeBody::HashAgg(node) => {
143                assert_eq!(node.agg_call_states.len(), node.agg_calls.len());
144                always!(node.intermediate_state_table, "HashAggState");
145                for (call_idx, state) in node.agg_call_states.iter_mut().enumerate() {
146                    match state.inner.as_mut().unwrap() {
147                        agg_call_state::Inner::ValueState(_) => {}
148                        agg_call_state::Inner::MaterializedInputState(s) => {
149                            always!(s.table, &format!("HashAggCall{}", call_idx));
150                        }
151                    }
152                }
153                for (distinct_col, dedup_table) in node
154                    .distinct_dedup_tables
155                    .iter_mut()
156                    .sorted_by_key(|(i, _)| *i)
157                {
158                    f(dedup_table, &format!("HashAggDedupForCol{}", distinct_col));
159                }
160            }
161            NodeBody::SimpleAgg(node) => {
162                assert_eq!(node.agg_call_states.len(), node.agg_calls.len());
163                always!(node.intermediate_state_table, "SimpleAggState");
164                for (call_idx, state) in node.agg_call_states.iter_mut().enumerate() {
165                    match state.inner.as_mut().unwrap() {
166                        agg_call_state::Inner::ValueState(_) => {}
167                        agg_call_state::Inner::MaterializedInputState(s) => {
168                            always!(s.table, &format!("SimpleAggCall{}", call_idx));
169                        }
170                    }
171                }
172                for (distinct_col, dedup_table) in node
173                    .distinct_dedup_tables
174                    .iter_mut()
175                    .sorted_by_key(|(i, _)| *i)
176                {
177                    f(
178                        dedup_table,
179                        &format!("SimpleAggDedupForCol{}", distinct_col),
180                    );
181                }
182            }
183
184            // Top-N
185            NodeBody::AppendOnlyTopN(node) => {
186                always!(node.table, "AppendOnlyTopN");
187            }
188            NodeBody::TopN(node) => {
189                always!(node.table, "TopN");
190            }
191            NodeBody::AppendOnlyGroupTopN(node) => {
192                always!(node.table, "AppendOnlyGroupTopN");
193            }
194            NodeBody::GroupTopN(node) => {
195                always!(node.table, "GroupTopN");
196            }
197
198            // Source
199            NodeBody::Source(node) => {
200                if let Some(source) = &mut node.source_inner {
201                    always!(source.state_table, "Source");
202                }
203            }
204            NodeBody::StreamFsFetch(node) => {
205                if let Some(source) = &mut node.node_inner {
206                    always!(source.state_table, "FsFetch");
207                }
208            }
209            NodeBody::SourceBackfill(node) => {
210                always!(node.state_table, "SourceBackfill")
211            }
212
213            // Sink
214            NodeBody::Sink(node) => {
215                // A sink with a kv log store should have a state table.
216                optional!(node.table, "Sink")
217            }
218
219            // Now
220            NodeBody::Now(node) => {
221                always!(node.state_table, "Now");
222            }
223
224            // Watermark filter
225            NodeBody::WatermarkFilter(node) => {
226                assert!(!node.tables.is_empty());
227                repeated!(node.tables, "WatermarkFilter");
228            }
229
230            // Shared arrangement
231            NodeBody::Arrange(node) => {
232                always!(node.table, "Arrange");
233            }
234
235            // Dedup
236            NodeBody::AppendOnlyDedup(node) => {
237                always!(node.state_table, "AppendOnlyDedup");
238            }
239
240            // EOWC over window
241            NodeBody::EowcOverWindow(node) => {
242                always!(node.state_table, "EowcOverWindow");
243            }
244
245            NodeBody::OverWindow(node) => {
246                always!(node.state_table, "OverWindow");
247            }
248
249            // Sort
250            NodeBody::Sort(node) => {
251                always!(node.state_table, "Sort");
252            }
253
254            // Stream Scan
255            NodeBody::StreamScan(node) => {
256                optional!(node.state_table, "StreamScan")
257            }
258
259            // Stream Cdc Scan
260            NodeBody::StreamCdcScan(node) => {
261                always!(node.state_table, "StreamCdcScan")
262            }
263
264            // Note: add internal tables for new nodes here.
265            NodeBody::Materialize(node) if !internal_tables_only => {
266                always!(node.table, "Materialize")
267            }
268
269            // Global Approx Percentile
270            NodeBody::GlobalApproxPercentile(node) => {
271                always!(node.bucket_state_table, "GlobalApproxPercentileBucketState");
272                always!(node.count_state_table, "GlobalApproxPercentileCountState");
273            }
274
275            // AsOf join
276            NodeBody::AsOfJoin(node) => {
277                always!(node.left_table, "AsOfJoinLeft");
278                always!(node.right_table, "AsOfJoinRight");
279            }
280
281            // Synced Log Store
282            NodeBody::SyncLogStore(node) => {
283                always!(node.log_store_table, "StreamSyncLogStore");
284            }
285            _ => {}
286        }
287    };
288    if visit_child_recursively {
289        visit_stream_node_mut(stream_node, visit_body)
290    } else {
291        visit_body(stream_node.node_body.as_mut().unwrap())
292    }
293}
294
295pub fn visit_stream_node_internal_tables<F>(stream_node: &mut StreamNode, f: F)
296where
297    F: FnMut(&mut Table, &str),
298{
299    visit_stream_node_tables_inner(stream_node, true, true, f)
300}
301
302pub fn visit_stream_node_tables<F>(stream_node: &mut StreamNode, f: F)
303where
304    F: FnMut(&mut Table, &str),
305{
306    visit_stream_node_tables_inner(stream_node, false, true, f)
307}
308
309/// Visit the internal tables of a [`StreamFragment`].
310pub fn visit_internal_tables<F>(fragment: &mut StreamFragment, f: F)
311where
312    F: FnMut(&mut Table, &str),
313{
314    visit_stream_node_internal_tables(fragment.node.as_mut().unwrap(), f)
315}
316
317/// Visit the tables of a [`StreamFragment`].
318///
319/// Compared to [`visit_internal_tables`], this function also visits the table of `Materialize` node.
320pub fn visit_tables<F>(fragment: &mut StreamFragment, f: F)
321where
322    F: FnMut(&mut Table, &str),
323{
324    visit_stream_node_tables(fragment.node.as_mut().unwrap(), f)
325}