1use itertools::Itertools;
16use risingwave_pb::catalog::Table;
17use risingwave_pb::stream_plan::stream_fragment_graph::StreamFragment;
18use risingwave_pb::stream_plan::stream_node::NodeBody;
19use risingwave_pb::stream_plan::{SourceBackfillNode, StreamNode, StreamScanNode, agg_call_state};
20
21pub fn visit_stream_node_mut(stream_node: &mut StreamNode, mut f: impl FnMut(&mut NodeBody)) {
23 visit_stream_node_cont_mut(stream_node, |stream_node| {
24 f(stream_node.node_body.as_mut().unwrap());
25 true
26 })
27}
28
29pub fn visit_stream_node_cont_mut<F>(stream_node: &mut StreamNode, mut f: F)
31where
32 F: FnMut(&mut StreamNode) -> bool,
33{
34 fn visit_inner<F>(stream_node: &mut StreamNode, f: &mut F)
35 where
36 F: FnMut(&mut StreamNode) -> bool,
37 {
38 if !f(stream_node) {
39 return;
40 }
41 for input in &mut stream_node.input {
42 visit_inner(input, f);
43 }
44 }
45
46 visit_inner(stream_node, &mut f)
47}
48
49pub fn visit_stream_node_body(stream_node: &StreamNode, mut f: impl FnMut(&NodeBody)) {
51 visit_stream_node_cont(stream_node, |stream_node| {
52 f(stream_node.node_body.as_ref().unwrap());
53 true
54 })
55}
56
57pub fn visit_stream_node(stream_node: &StreamNode, mut f: impl FnMut(&StreamNode)) {
58 visit_stream_node_cont(stream_node, |stream_node| {
59 f(stream_node);
60 true
61 })
62}
63
64pub fn visit_stream_node_cont<F>(stream_node: &StreamNode, mut f: F)
66where
67 F: FnMut(&StreamNode) -> bool,
68{
69 fn visit_inner<F>(stream_node: &StreamNode, f: &mut F)
70 where
71 F: FnMut(&StreamNode) -> bool,
72 {
73 if !f(stream_node) {
74 return;
75 }
76 for input in &stream_node.input {
77 visit_inner(input, f);
78 }
79 }
80
81 visit_inner(stream_node, &mut f)
82}
83
84pub fn visit_fragment_mut(fragment: &mut StreamFragment, f: impl FnMut(&mut NodeBody)) {
87 visit_stream_node_mut(fragment.node.as_mut().unwrap(), f)
88}
89
90pub fn visit_fragment(fragment: &StreamFragment, f: impl FnMut(&NodeBody)) {
93 visit_stream_node_body(fragment.node.as_ref().unwrap(), f)
94}
95
96pub fn visit_stream_node_tables_inner<F>(
98 stream_node: &mut StreamNode,
99 internal_tables_only: bool,
100 visit_child_recursively: bool,
101 mut f: F,
102) where
103 F: FnMut(&mut Table, &str),
104{
105 macro_rules! always {
106 ($table:expr, $name:expr) => {{
107 let table = $table
108 .as_mut()
109 .unwrap_or_else(|| panic!("internal table {} should always exist", $name));
110 f(table, $name);
111 }};
112 }
113
114 macro_rules! optional {
115 ($table:expr, $name:expr) => {
116 if let Some(table) = &mut $table {
117 f(table, $name);
118 }
119 };
120 }
121
122 macro_rules! repeated {
123 ($tables:expr, $name:expr) => {
124 for table in &mut $tables {
125 f(table, $name);
126 }
127 };
128 }
129
130 let mut visit_body = |body: &mut NodeBody| {
131 match body {
132 NodeBody::HashJoin(node) => {
134 always!(node.left_table, "HashJoinLeft");
136 always!(node.left_degree_table, "HashJoinDegreeLeft");
137 always!(node.right_table, "HashJoinRight");
138 always!(node.right_degree_table, "HashJoinDegreeRight");
139 }
140 NodeBody::TemporalJoin(node) => {
141 optional!(node.memo_table, "TemporalJoinMemo");
142 }
143 NodeBody::DynamicFilter(node) => {
144 always!(node.left_table, "DynamicFilterLeft");
145 always!(node.right_table, "DynamicFilterRight");
146 }
147
148 NodeBody::HashAgg(node) => {
150 assert_eq!(node.agg_call_states.len(), node.agg_calls.len());
151 always!(node.intermediate_state_table, "HashAggState");
152 for (call_idx, state) in node.agg_call_states.iter_mut().enumerate() {
153 match state.inner.as_mut().unwrap() {
154 agg_call_state::Inner::ValueState(_) => {}
155 agg_call_state::Inner::MaterializedInputState(s) => {
156 always!(s.table, &format!("HashAggCall{}", call_idx));
157 }
158 }
159 }
160 for (distinct_col, dedup_table) in node
161 .distinct_dedup_tables
162 .iter_mut()
163 .sorted_by_key(|(i, _)| *i)
164 {
165 f(dedup_table, &format!("HashAggDedupForCol{}", distinct_col));
166 }
167 }
168 NodeBody::SimpleAgg(node) => {
169 assert_eq!(node.agg_call_states.len(), node.agg_calls.len());
170 always!(node.intermediate_state_table, "SimpleAggState");
171 for (call_idx, state) in node.agg_call_states.iter_mut().enumerate() {
172 match state.inner.as_mut().unwrap() {
173 agg_call_state::Inner::ValueState(_) => {}
174 agg_call_state::Inner::MaterializedInputState(s) => {
175 always!(s.table, &format!("SimpleAggCall{}", call_idx));
176 }
177 }
178 }
179 for (distinct_col, dedup_table) in node
180 .distinct_dedup_tables
181 .iter_mut()
182 .sorted_by_key(|(i, _)| *i)
183 {
184 f(
185 dedup_table,
186 &format!("SimpleAggDedupForCol{}", distinct_col),
187 );
188 }
189 }
190
191 NodeBody::AppendOnlyTopN(node) => {
193 always!(node.table, "AppendOnlyTopN");
194 }
195 NodeBody::TopN(node) => {
196 always!(node.table, "TopN");
197 }
198 NodeBody::AppendOnlyGroupTopN(node) => {
199 always!(node.table, "AppendOnlyGroupTopN");
200 }
201 NodeBody::GroupTopN(node) => {
202 always!(node.table, "GroupTopN");
203 }
204
205 NodeBody::Source(node) => {
207 if let Some(source) = &mut node.source_inner {
208 always!(source.state_table, "Source");
209 }
210 }
211 NodeBody::StreamFsFetch(node) => {
212 if let Some(source) = &mut node.node_inner {
213 always!(source.state_table, "FsFetch");
214 }
215 }
216 NodeBody::SourceBackfill(node) => {
217 always!(node.state_table, "SourceBackfill")
218 }
219
220 NodeBody::Sink(node) => {
222 optional!(node.table, "Sink")
224 }
225
226 NodeBody::Now(node) => {
228 always!(node.state_table, "Now");
229 }
230
231 NodeBody::WatermarkFilter(node) => {
233 assert!(!node.tables.is_empty());
234 repeated!(node.tables, "WatermarkFilter");
235 }
236
237 NodeBody::Arrange(node) => {
239 always!(node.table, "Arrange");
240 }
241
242 NodeBody::AppendOnlyDedup(node) => {
244 always!(node.state_table, "AppendOnlyDedup");
245 }
246
247 NodeBody::EowcOverWindow(node) => {
249 always!(node.state_table, "EowcOverWindow");
250 }
251
252 NodeBody::OverWindow(node) => {
253 always!(node.state_table, "OverWindow");
254 }
255
256 NodeBody::Sort(node) => {
258 always!(node.state_table, "Sort");
259 }
260
261 NodeBody::StreamScan(node) => {
263 optional!(node.state_table, "StreamScan")
264 }
265
266 NodeBody::StreamCdcScan(node) => {
268 always!(node.state_table, "StreamCdcScan")
269 }
270
271 NodeBody::Materialize(node) if !internal_tables_only => {
273 always!(node.table, "Materialize")
278 }
279
280 NodeBody::GlobalApproxPercentile(node) => {
282 always!(node.bucket_state_table, "GlobalApproxPercentileBucketState");
283 always!(node.count_state_table, "GlobalApproxPercentileCountState");
284 }
285
286 NodeBody::AsOfJoin(node) => {
288 always!(node.left_table, "AsOfJoinLeft");
289 always!(node.right_table, "AsOfJoinRight");
290 }
291
292 NodeBody::SyncLogStore(node) => {
294 always!(node.log_store_table, "StreamSyncLogStore");
295 }
296
297 NodeBody::MaterializedExprs(node) => {
299 always!(node.state_table, "MaterializedExprs");
300 }
301
302 _ => {}
303 }
304 };
305 if visit_child_recursively {
306 visit_stream_node_mut(stream_node, visit_body)
307 } else {
308 visit_body(stream_node.node_body.as_mut().unwrap())
309 }
310}
311
312pub fn visit_stream_node_stream_scan(stream_node: &StreamNode, mut f: impl FnMut(&StreamScanNode)) {
313 visit_stream_node_body(stream_node, |body| {
314 if let NodeBody::StreamScan(node) = body {
315 f(node)
316 }
317 })
318}
319
320pub fn visit_stream_node_source_backfill(
321 stream_node: &StreamNode,
322 mut f: impl FnMut(&SourceBackfillNode),
323) {
324 visit_stream_node_body(stream_node, |body| {
325 if let NodeBody::SourceBackfill(node) = body {
326 f(node)
327 }
328 })
329}
330
331pub fn visit_stream_node_internal_tables<F>(stream_node: &mut StreamNode, f: F)
332where
333 F: FnMut(&mut Table, &str),
334{
335 visit_stream_node_tables_inner(stream_node, true, true, f)
336}
337
338pub fn visit_stream_node_tables<F>(stream_node: &mut StreamNode, f: F)
339where
340 F: FnMut(&mut Table, &str),
341{
342 visit_stream_node_tables_inner(stream_node, false, true, f)
343}
344
345pub fn visit_internal_tables<F>(fragment: &mut StreamFragment, f: F)
347where
348 F: FnMut(&mut Table, &str),
349{
350 visit_stream_node_internal_tables(fragment.node.as_mut().unwrap(), f)
351}
352
353pub fn visit_tables<F>(fragment: &mut StreamFragment, f: F)
357where
358 F: FnMut(&mut Table, &str),
359{
360 visit_stream_node_tables(fragment.node.as_mut().unwrap(), f)
361}