risingwave_common/util/
stream_graph_visitor.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use risingwave_pb::catalog::Table;
17use risingwave_pb::stream_plan::stream_fragment_graph::StreamFragment;
18use risingwave_pb::stream_plan::stream_node::NodeBody;
19use risingwave_pb::stream_plan::{SourceBackfillNode, StreamNode, StreamScanNode, agg_call_state};
20
21/// A utility for visiting and mutating the [`NodeBody`] of the [`StreamNode`]s recursively.
22pub fn visit_stream_node_mut(stream_node: &mut StreamNode, mut f: impl FnMut(&mut NodeBody)) {
23    visit_stream_node_cont_mut(stream_node, |stream_node| {
24        f(stream_node.node_body.as_mut().unwrap());
25        true
26    })
27}
28
29/// A utility for to accessing the [`StreamNode`] mutably. The returned bool is used to determine whether the access needs to continue.
30pub fn visit_stream_node_cont_mut<F>(stream_node: &mut StreamNode, mut f: F)
31where
32    F: FnMut(&mut StreamNode) -> bool,
33{
34    fn visit_inner<F>(stream_node: &mut StreamNode, f: &mut F)
35    where
36        F: FnMut(&mut StreamNode) -> bool,
37    {
38        if !f(stream_node) {
39            return;
40        }
41        for input in &mut stream_node.input {
42            visit_inner(input, f);
43        }
44    }
45
46    visit_inner(stream_node, &mut f)
47}
48
49/// A utility for visiting the [`NodeBody`] of the [`StreamNode`]s recursively.
50pub fn visit_stream_node_body(stream_node: &StreamNode, mut f: impl FnMut(&NodeBody)) {
51    visit_stream_node_cont(stream_node, |stream_node| {
52        f(stream_node.node_body.as_ref().unwrap());
53        true
54    })
55}
56
57pub fn visit_stream_node(stream_node: &StreamNode, mut f: impl FnMut(&StreamNode)) {
58    visit_stream_node_cont(stream_node, |stream_node| {
59        f(stream_node);
60        true
61    })
62}
63
64/// A utility for to accessing the [`StreamNode`] immutably. The returned bool is used to determine whether the access needs to continue.
65pub fn visit_stream_node_cont<F>(stream_node: &StreamNode, mut f: F)
66where
67    F: FnMut(&StreamNode) -> bool,
68{
69    fn visit_inner<F>(stream_node: &StreamNode, f: &mut F)
70    where
71        F: FnMut(&StreamNode) -> bool,
72    {
73        if !f(stream_node) {
74            return;
75        }
76        for input in &stream_node.input {
77            visit_inner(input, f);
78        }
79    }
80
81    visit_inner(stream_node, &mut f)
82}
83
84/// A utility for visiting and mutating the [`NodeBody`] of the [`StreamNode`]s in a
85/// [`StreamFragment`] recursively.
86pub fn visit_fragment_mut(fragment: &mut StreamFragment, f: impl FnMut(&mut NodeBody)) {
87    visit_stream_node_mut(fragment.node.as_mut().unwrap(), f)
88}
89
90/// A utility for visiting the [`NodeBody`] of the [`StreamNode`]s in a
91/// [`StreamFragment`] recursively.
92pub fn visit_fragment(fragment: &StreamFragment, f: impl FnMut(&NodeBody)) {
93    visit_stream_node_body(fragment.node.as_ref().unwrap(), f)
94}
95
96/// Visit the tables of a [`StreamNode`].
97pub fn visit_stream_node_tables_inner<F>(
98    stream_node: &mut StreamNode,
99    internal_tables_only: bool,
100    visit_child_recursively: bool,
101    mut f: F,
102) where
103    F: FnMut(&mut Table, &str),
104{
105    macro_rules! always {
106        ($table:expr, $name:expr) => {{
107            let table = $table
108                .as_mut()
109                .unwrap_or_else(|| panic!("internal table {} should always exist", $name));
110            f(table, $name);
111        }};
112    }
113
114    macro_rules! optional {
115        ($table:expr, $name:expr) => {
116            if let Some(table) = &mut $table {
117                f(table, $name);
118            }
119        };
120    }
121
122    macro_rules! repeated {
123        ($tables:expr, $name:expr) => {
124            for table in &mut $tables {
125                f(table, $name);
126            }
127        };
128    }
129
130    let mut visit_body = |body: &mut NodeBody| {
131        match body {
132            // Join
133            NodeBody::HashJoin(node) => {
134                // TODO: make the degree table optional
135                always!(node.left_table, "HashJoinLeft");
136                always!(node.left_degree_table, "HashJoinDegreeLeft");
137                always!(node.right_table, "HashJoinRight");
138                always!(node.right_degree_table, "HashJoinDegreeRight");
139            }
140            NodeBody::TemporalJoin(node) => {
141                optional!(node.memo_table, "TemporalJoinMemo");
142            }
143            NodeBody::DynamicFilter(node) => {
144                always!(node.left_table, "DynamicFilterLeft");
145                always!(node.right_table, "DynamicFilterRight");
146            }
147
148            // Aggregation
149            NodeBody::HashAgg(node) => {
150                assert_eq!(node.agg_call_states.len(), node.agg_calls.len());
151                always!(node.intermediate_state_table, "HashAggState");
152                for (call_idx, state) in node.agg_call_states.iter_mut().enumerate() {
153                    match state.inner.as_mut().unwrap() {
154                        agg_call_state::Inner::ValueState(_) => {}
155                        agg_call_state::Inner::MaterializedInputState(s) => {
156                            always!(s.table, &format!("HashAggCall{}", call_idx));
157                        }
158                    }
159                }
160                for (distinct_col, dedup_table) in node
161                    .distinct_dedup_tables
162                    .iter_mut()
163                    .sorted_by_key(|(i, _)| *i)
164                {
165                    f(dedup_table, &format!("HashAggDedupForCol{}", distinct_col));
166                }
167            }
168            NodeBody::SimpleAgg(node) => {
169                assert_eq!(node.agg_call_states.len(), node.agg_calls.len());
170                always!(node.intermediate_state_table, "SimpleAggState");
171                for (call_idx, state) in node.agg_call_states.iter_mut().enumerate() {
172                    match state.inner.as_mut().unwrap() {
173                        agg_call_state::Inner::ValueState(_) => {}
174                        agg_call_state::Inner::MaterializedInputState(s) => {
175                            always!(s.table, &format!("SimpleAggCall{}", call_idx));
176                        }
177                    }
178                }
179                for (distinct_col, dedup_table) in node
180                    .distinct_dedup_tables
181                    .iter_mut()
182                    .sorted_by_key(|(i, _)| *i)
183                {
184                    f(
185                        dedup_table,
186                        &format!("SimpleAggDedupForCol{}", distinct_col),
187                    );
188                }
189            }
190
191            // Top-N
192            NodeBody::AppendOnlyTopN(node) => {
193                always!(node.table, "AppendOnlyTopN");
194            }
195            NodeBody::TopN(node) => {
196                always!(node.table, "TopN");
197            }
198            NodeBody::AppendOnlyGroupTopN(node) => {
199                always!(node.table, "AppendOnlyGroupTopN");
200            }
201            NodeBody::GroupTopN(node) => {
202                always!(node.table, "GroupTopN");
203            }
204
205            // Source
206            NodeBody::Source(node) => {
207                if let Some(source) = &mut node.source_inner {
208                    always!(source.state_table, "Source");
209                }
210            }
211            NodeBody::StreamFsFetch(node) => {
212                if let Some(source) = &mut node.node_inner {
213                    always!(source.state_table, "FsFetch");
214                }
215            }
216            NodeBody::SourceBackfill(node) => {
217                always!(node.state_table, "SourceBackfill")
218            }
219
220            // Sink
221            NodeBody::Sink(node) => {
222                // A sink with a kv log store should have a state table.
223                optional!(node.table, "Sink")
224            }
225
226            // Now
227            NodeBody::Now(node) => {
228                always!(node.state_table, "Now");
229            }
230
231            // Watermark filter
232            NodeBody::WatermarkFilter(node) => {
233                assert!(!node.tables.is_empty());
234                repeated!(node.tables, "WatermarkFilter");
235            }
236
237            // Shared arrangement
238            NodeBody::Arrange(node) => {
239                always!(node.table, "Arrange");
240            }
241
242            // Dedup
243            NodeBody::AppendOnlyDedup(node) => {
244                always!(node.state_table, "AppendOnlyDedup");
245            }
246
247            // EOWC over window
248            NodeBody::EowcOverWindow(node) => {
249                always!(node.state_table, "EowcOverWindow");
250            }
251
252            NodeBody::OverWindow(node) => {
253                always!(node.state_table, "OverWindow");
254            }
255
256            // Sort
257            NodeBody::Sort(node) => {
258                always!(node.state_table, "Sort");
259            }
260
261            // Stream Scan
262            NodeBody::StreamScan(node) => {
263                optional!(node.state_table, "StreamScan")
264            }
265
266            // Stream Cdc Scan
267            NodeBody::StreamCdcScan(node) => {
268                always!(node.state_table, "StreamCdcScan")
269            }
270
271            // Note: add internal tables for new nodes here.
272            NodeBody::Materialize(node) if !internal_tables_only => {
273                // Note: Plan directly generated by frontend may not have the table filled,
274                // as it is filled by the meta service during building fragments.
275                // However, if caller requests to visit this table via `!internal_tables_only`,
276                // we still assert on its existence to avoid undefined behavior.
277                always!(node.table, "Materialize")
278            }
279
280            // Global Approx Percentile
281            NodeBody::GlobalApproxPercentile(node) => {
282                always!(node.bucket_state_table, "GlobalApproxPercentileBucketState");
283                always!(node.count_state_table, "GlobalApproxPercentileCountState");
284            }
285
286            // AsOf join
287            NodeBody::AsOfJoin(node) => {
288                always!(node.left_table, "AsOfJoinLeft");
289                always!(node.right_table, "AsOfJoinRight");
290            }
291
292            // Synced Log Store
293            NodeBody::SyncLogStore(node) => {
294                always!(node.log_store_table, "StreamSyncLogStore");
295            }
296
297            // MaterializedExprs
298            NodeBody::MaterializedExprs(node) => {
299                always!(node.state_table, "MaterializedExprs");
300            }
301
302            _ => {}
303        }
304    };
305    if visit_child_recursively {
306        visit_stream_node_mut(stream_node, visit_body)
307    } else {
308        visit_body(stream_node.node_body.as_mut().unwrap())
309    }
310}
311
312pub fn visit_stream_node_stream_scan(stream_node: &StreamNode, mut f: impl FnMut(&StreamScanNode)) {
313    visit_stream_node_body(stream_node, |body| {
314        if let NodeBody::StreamScan(node) = body {
315            f(node)
316        }
317    })
318}
319
320pub fn visit_stream_node_source_backfill(
321    stream_node: &StreamNode,
322    mut f: impl FnMut(&SourceBackfillNode),
323) {
324    visit_stream_node_body(stream_node, |body| {
325        if let NodeBody::SourceBackfill(node) = body {
326            f(node)
327        }
328    })
329}
330
331pub fn visit_stream_node_internal_tables<F>(stream_node: &mut StreamNode, f: F)
332where
333    F: FnMut(&mut Table, &str),
334{
335    visit_stream_node_tables_inner(stream_node, true, true, f)
336}
337
338pub fn visit_stream_node_tables<F>(stream_node: &mut StreamNode, f: F)
339where
340    F: FnMut(&mut Table, &str),
341{
342    visit_stream_node_tables_inner(stream_node, false, true, f)
343}
344
345/// Visit the internal tables of a [`StreamFragment`].
346pub fn visit_internal_tables<F>(fragment: &mut StreamFragment, f: F)
347where
348    F: FnMut(&mut Table, &str),
349{
350    visit_stream_node_internal_tables(fragment.node.as_mut().unwrap(), f)
351}
352
353/// Visit the tables of a [`StreamFragment`].
354///
355/// Compared to [`visit_internal_tables`], this function also visits the table of `Materialize` node.
356pub fn visit_tables<F>(fragment: &mut StreamFragment, f: F)
357where
358    F: FnMut(&mut Table, &str),
359{
360    visit_stream_node_tables(fragment.node.as_mut().unwrap(), f)
361}