risingwave_common/util/
stream_graph_visitor.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use risingwave_pb::catalog::Table;
17use risingwave_pb::stream_plan::stream_fragment_graph::StreamFragment;
18use risingwave_pb::stream_plan::stream_node::NodeBody;
19use risingwave_pb::stream_plan::{SourceBackfillNode, StreamNode, StreamScanNode, agg_call_state};
20
/// Dispatches on a stream [`NodeBody`] by forwarding to `risingwave_pb`'s
/// `__dispatch_stream_node_body!` macro.
///
/// Invocation shape: `dispatch_stream_node_body!(body_expr, BodyIdent, node_ident => call_expr)`.
///
/// NOTE(review): presumably the underlying macro matches each `NodeBody` variant, binds the inner
/// node to `node_ident` and evaluates `call_expr` — confirm against `risingwave_pb`.
#[macro_export]
macro_rules! dispatch_stream_node_body {
    ($target:expr, $body_ident:ident, $binding:ident => $action:expr) => {
        ::risingwave_pb::__dispatch_stream_node_body!($target, $body_ident, $binding => $action)
    };
}
27
28/// A utility for visiting and mutating the [`NodeBody`] of the [`StreamNode`]s recursively.
29pub fn visit_stream_node_mut(stream_node: &mut StreamNode, mut f: impl FnMut(&mut NodeBody)) {
30    visit_stream_node_cont_mut(stream_node, |stream_node| {
31        f(stream_node.node_body.as_mut().unwrap());
32        true
33    })
34}
35
36/// A utility for to accessing the [`StreamNode`] mutably. The returned bool is used to determine whether the access needs to continue.
37pub fn visit_stream_node_cont_mut<F>(stream_node: &mut StreamNode, mut f: F)
38where
39    F: FnMut(&mut StreamNode) -> bool,
40{
41    fn visit_inner<F>(stream_node: &mut StreamNode, f: &mut F)
42    where
43        F: FnMut(&mut StreamNode) -> bool,
44    {
45        if !f(stream_node) {
46            return;
47        }
48        for input in &mut stream_node.input {
49            visit_inner(input, f);
50        }
51    }
52
53    visit_inner(stream_node, &mut f)
54}
55
56/// A utility for visiting the [`NodeBody`] of the [`StreamNode`]s recursively.
57pub fn visit_stream_node_body(stream_node: &StreamNode, mut f: impl FnMut(&NodeBody)) {
58    visit_stream_node_cont(stream_node, |stream_node| {
59        f(stream_node.node_body.as_ref().unwrap());
60        true
61    })
62}
63
64pub fn visit_stream_node(stream_node: &StreamNode, mut f: impl FnMut(&StreamNode)) {
65    visit_stream_node_cont(stream_node, |stream_node| {
66        f(stream_node);
67        true
68    })
69}
70
71/// A utility for to accessing the [`StreamNode`] immutably. The returned bool is used to determine whether the access needs to continue.
72pub fn visit_stream_node_cont<'a, F>(stream_node: &'a StreamNode, mut f: F)
73where
74    F: FnMut(&'a StreamNode) -> bool,
75{
76    fn visit_inner<'a, F>(stream_node: &'a StreamNode, f: &mut F)
77    where
78        F: FnMut(&'a StreamNode) -> bool,
79    {
80        if !f(stream_node) {
81            return;
82        }
83        for input in &stream_node.input {
84            visit_inner(input, f);
85        }
86    }
87
88    visit_inner(stream_node, &mut f)
89}
90
91/// A utility for visiting and mutating the [`NodeBody`] of the [`StreamNode`]s in a
92/// [`StreamFragment`] recursively.
93pub fn visit_fragment_mut(fragment: &mut StreamFragment, f: impl FnMut(&mut NodeBody)) {
94    visit_stream_node_mut(fragment.node.as_mut().unwrap(), f)
95}
96
97/// A utility for visiting the [`NodeBody`] of the [`StreamNode`]s in a
98/// [`StreamFragment`] recursively.
99pub fn visit_fragment(fragment: &StreamFragment, f: impl FnMut(&NodeBody)) {
100    visit_stream_node_body(fragment.node.as_ref().unwrap(), f)
101}
102
/// Visit the tables of a [`StreamNode`], calling `f(table, name)` for each one.
///
/// `name` identifies the role of the table. It is either a fixed label (e.g. `"HashJoinLeft"`)
/// or a formatted label for tables that exist per call or per column (e.g. `"HashAggCall{idx}"`,
/// `"HashAggDedupForCol{col}"`).
///
/// - `internal_tables_only`: when `true`, the `table` of `Materialize` and `VectorIndexWrite`
///   nodes is skipped. Every other table handled below is visited either way, including
///   `Materialize`'s staging and refresh-progress tables.
/// - `visit_child_recursively`: when `true`, every node in the subtree rooted at `stream_node`
///   is visited (via [`visit_stream_node_mut`]). Otherwise only `stream_node` itself is visited.
///
/// Within a node, tables are visited in the statement order of the `match` below. Aggregation
/// distinct-dedup tables are first sorted by column index, which keeps the visit order
/// deterministic. NOTE(review): callers may depend on this order (e.g. when assigning ids), so
/// confirm before reordering arms or statements.
///
/// # Panics
///
/// - If a table visited through `always!` is unset.
/// - If a visited node has no `node_body`.
/// - If an aggregation's `agg_call_states` and `agg_calls` differ in length.
/// - If an agg call state has no `inner`.
/// - If a `WatermarkFilter` node has no tables.
pub fn visit_stream_node_tables_inner<F>(
    stream_node: &mut StreamNode,
    internal_tables_only: bool,
    visit_child_recursively: bool,
    mut f: F,
) where
    F: FnMut(&mut Table, &str),
{
    // Visits a table that must be present and panics with its name if it is `None`.
    // NOTE(review): the message says "internal table" even for the non-internal
    // `Materialize` / `VectorIndexWrite` tables.
    macro_rules! always {
        ($table:expr, $name:expr) => {{
            let table = $table
                .as_mut()
                .unwrap_or_else(|| panic!("internal table {} should always exist", $name));
            f(table, $name);
        }};
    }

    // Visits a table only if it is present.
    macro_rules! optional {
        ($table:expr, $name:expr) => {
            if let Some(table) = &mut $table {
                f(table, $name);
            }
        };
    }

    // Visits every table of a repeated field, all under the same name.
    macro_rules! repeated {
        ($tables:expr, $name:expr) => {
            for table in &mut $tables {
                f(table, $name);
            }
        };
    }

    // Visits the tables owned by a single node body. Node kinds without an arm here fall into
    // the catch-all at the end and contribute no tables.
    let mut visit_body = |body: &mut NodeBody| {
        match body {
            // Join
            NodeBody::HashJoin(node) => {
                // TODO: make the degree table optional
                always!(node.left_table, "HashJoinLeft");
                always!(node.left_degree_table, "HashJoinDegreeLeft");
                always!(node.right_table, "HashJoinRight");
                always!(node.right_degree_table, "HashJoinDegreeRight");
            }
            NodeBody::TemporalJoin(node) => {
                optional!(node.memo_table, "TemporalJoinMemo");
            }
            NodeBody::DynamicFilter(node) => {
                always!(node.left_table, "DynamicFilterLeft");
                always!(node.right_table, "DynamicFilterRight");
            }

            // Aggregation
            NodeBody::HashAgg(node) => {
                // Exactly one state entry is expected per agg call.
                assert_eq!(node.agg_call_states.len(), node.agg_calls.len());
                always!(node.intermediate_state_table, "HashAggState");
                // Only calls with a materialized input state own a table; value states do not.
                for (call_idx, state) in node.agg_call_states.iter_mut().enumerate() {
                    match state.inner.as_mut().unwrap() {
                        agg_call_state::Inner::ValueState(_) => {}
                        agg_call_state::Inner::MaterializedInputState(s) => {
                            always!(s.table, &format!("HashAggCall{}", call_idx));
                        }
                    }
                }
                // Sort by distinct column index so the visit order does not depend on the
                // map's iteration order.
                for (distinct_col, dedup_table) in node
                    .distinct_dedup_tables
                    .iter_mut()
                    .sorted_by_key(|(i, _)| *i)
                {
                    f(dedup_table, &format!("HashAggDedupForCol{}", distinct_col));
                }
            }
            NodeBody::SimpleAgg(node) => {
                // Same layout as `HashAgg` above, under `SimpleAgg*` names.
                assert_eq!(node.agg_call_states.len(), node.agg_calls.len());
                always!(node.intermediate_state_table, "SimpleAggState");
                for (call_idx, state) in node.agg_call_states.iter_mut().enumerate() {
                    match state.inner.as_mut().unwrap() {
                        agg_call_state::Inner::ValueState(_) => {}
                        agg_call_state::Inner::MaterializedInputState(s) => {
                            always!(s.table, &format!("SimpleAggCall{}", call_idx));
                        }
                    }
                }
                for (distinct_col, dedup_table) in node
                    .distinct_dedup_tables
                    .iter_mut()
                    .sorted_by_key(|(i, _)| *i)
                {
                    f(
                        dedup_table,
                        &format!("SimpleAggDedupForCol{}", distinct_col),
                    );
                }
            }

            // Top-N
            NodeBody::AppendOnlyTopN(node) => {
                always!(node.table, "AppendOnlyTopN");
            }
            NodeBody::TopN(node) => {
                always!(node.table, "TopN");
            }
            NodeBody::AppendOnlyGroupTopN(node) => {
                always!(node.table, "AppendOnlyGroupTopN");
            }
            NodeBody::GroupTopN(node) => {
                always!(node.table, "GroupTopN");
            }

            // Source
            NodeBody::Source(node) => {
                // The state table is required only when the inner source part is present.
                if let Some(source) = &mut node.source_inner {
                    always!(source.state_table, "Source");
                }
            }
            NodeBody::StreamFsFetch(node) => {
                if let Some(source) = &mut node.node_inner {
                    always!(source.state_table, "FsFetch");
                }
            }
            NodeBody::SourceBackfill(node) => {
                always!(node.state_table, "SourceBackfill")
            }

            // Sink
            NodeBody::Sink(node) => {
                // A sink with a kv log store should have a state table.
                optional!(node.table, "Sink")
            }

            // Now
            NodeBody::Now(node) => {
                always!(node.state_table, "Now");
            }

            // Watermark filter
            NodeBody::WatermarkFilter(node) => {
                assert!(!node.tables.is_empty());
                repeated!(node.tables, "WatermarkFilter");
            }

            // Shared arrangement
            NodeBody::Arrange(node) => {
                always!(node.table, "Arrange");
            }

            // Dedup
            NodeBody::AppendOnlyDedup(node) => {
                always!(node.state_table, "AppendOnlyDedup");
            }

            // EOWC over window
            NodeBody::EowcOverWindow(node) => {
                always!(node.state_table, "EowcOverWindow");
                optional!(
                    node.intermediate_state_table,
                    "EowcOverWindowIntermediateState"
                );
            }

            NodeBody::OverWindow(node) => {
                always!(node.state_table, "OverWindow");
            }

            // Sort
            NodeBody::Sort(node) => {
                always!(node.state_table, "Sort");
            }

            // Stream Scan
            NodeBody::StreamScan(node) => {
                optional!(node.state_table, "StreamScan")
            }

            // Stream Cdc Scan
            NodeBody::StreamCdcScan(node) => {
                always!(node.state_table, "StreamCdcScan")
            }

            NodeBody::EowcGapFill(node) => {
                always!(node.buffer_table, "EowcGapFillBufferTable");
                always!(node.prev_row_table, "EowcGapFillPrevRowTable");
            }

            NodeBody::GapFill(node) => {
                always!(node.state_table, "GapFillStateTable");
            }

            // Note: add internal tables for new nodes here.
            NodeBody::Materialize(node) => {
                if !internal_tables_only {
                    // Note: Plan directly generated by frontend may not have the table filled,
                    // as it is filled by the meta service during building fragments.
                    // However, if caller requests to visit this table via `!internal_tables_only`,
                    // we still assert on its existence to avoid undefined behavior.
                    always!(node.table, "Materialize");
                }
                // Also visit the staging table if it exists (for refreshable tables)
                optional!(node.staging_table, "MaterializeStaging");
                optional!(node.refresh_progress_table, "MaterializeRefreshProgress");
            }

            // Global Approx Percentile
            NodeBody::GlobalApproxPercentile(node) => {
                always!(node.bucket_state_table, "GlobalApproxPercentileBucketState");
                always!(node.count_state_table, "GlobalApproxPercentileCountState");
            }

            // AsOf join
            NodeBody::AsOfJoin(node) => {
                always!(node.left_table, "AsOfJoinLeft");
                always!(node.right_table, "AsOfJoinRight");
            }

            // Synced Log Store
            NodeBody::SyncLogStore(node) => {
                always!(node.log_store_table, "StreamSyncLogStore");
            }

            // MaterializedExprs
            NodeBody::MaterializedExprs(node) => {
                always!(node.state_table, "MaterializedExprs");
            }

            // Vector Index Write
            // Like `Materialize`'s `table`, this is visited only when `internal_tables_only` is
            // false. Otherwise the guard fails and the node falls through to the catch-all arm.
            NodeBody::VectorIndexWrite(node) if !internal_tables_only => {
                always!(node.table, "StreamVectorIndexWrite");
            }

            NodeBody::LocalityProvider(node) => {
                always!(node.state_table, "LocalityProviderState");
                always!(node.progress_table, "LocalityProviderProgress");
            }

            NodeBody::IcebergWithPkIndexWriter(node) => {
                always!(node.pk_index_table, "IcebergWithPkIndexWriter");
            }

            // Every other node kind owns no tables that this visitor reports.
            _ => {}
        }
    };
    // Walk either the whole subtree or only the root node's own body.
    if visit_child_recursively {
        visit_stream_node_mut(stream_node, visit_body)
    } else {
        visit_body(stream_node.node_body.as_mut().unwrap())
    }
}
350
351pub fn visit_stream_node_stream_scan(stream_node: &StreamNode, mut f: impl FnMut(&StreamScanNode)) {
352    visit_stream_node_body(stream_node, |body| {
353        if let NodeBody::StreamScan(node) = body {
354            f(node)
355        }
356    })
357}
358
359pub fn visit_stream_node_source_backfill(
360    stream_node: &StreamNode,
361    mut f: impl FnMut(&SourceBackfillNode),
362) {
363    visit_stream_node_body(stream_node, |body| {
364        if let NodeBody::SourceBackfill(node) = body {
365            f(node)
366        }
367    })
368}
369
370pub fn visit_stream_node_internal_tables<F>(stream_node: &mut StreamNode, f: F)
371where
372    F: FnMut(&mut Table, &str),
373{
374    visit_stream_node_tables_inner(stream_node, true, true, f)
375}
376
377pub fn visit_stream_node_tables<F>(stream_node: &mut StreamNode, f: F)
378where
379    F: FnMut(&mut Table, &str),
380{
381    visit_stream_node_tables_inner(stream_node, false, true, f)
382}
383
384/// Visit the internal tables of a [`StreamFragment`].
385pub fn visit_internal_tables<F>(fragment: &mut StreamFragment, f: F)
386where
387    F: FnMut(&mut Table, &str),
388{
389    visit_stream_node_internal_tables(fragment.node.as_mut().unwrap(), f)
390}
391
392/// Visit the tables of a [`StreamFragment`].
393///
394/// Compared to [`visit_internal_tables`], this function also visits the table of `Materialize` node.
395pub fn visit_tables<F>(fragment: &mut StreamFragment, f: F)
396where
397    F: FnMut(&mut Table, &str),
398{
399    visit_stream_node_tables(fragment.node.as_mut().unwrap(), f)
400}