use itertools::Itertools;
use risingwave_pb::catalog::Table;
use risingwave_pb::stream_plan::stream_fragment_graph::StreamFragment;
use risingwave_pb::stream_plan::stream_node::NodeBody;
use risingwave_pb::stream_plan::{agg_call_state, StreamNode};
/// Calls `f` on the body of `stream_node` and then on the body of every node beneath it.
///
/// Nodes are visited in pre-order: a node is handed to `f` before any of its inputs, and the
/// inputs are visited left to right (in `input` order).
///
/// # Panics
///
/// Panics if any visited node has no `node_body`.
pub fn visit_stream_node<F>(stream_node: &mut StreamNode, mut f: F)
where
    F: FnMut(&mut NodeBody),
{
    // Explicit work list instead of recursion. Children are pushed in reverse so that popping
    // yields them in their original left-to-right order, reproducing a pre-order walk.
    let mut pending = vec![stream_node];
    while let Some(node) = pending.pop() {
        f(node.node_body.as_mut().unwrap());
        pending.extend(node.input.iter_mut().rev());
    }
}
/// Walks the tree rooted at `stream_node` in pre-order, handing each node to `f` mutably.
///
/// `f` decides whether to descend: returning `true` continues into the node's inputs (left to
/// right), returning `false` skips that node's whole subtree. The inputs are read only after
/// `f` returns, so any changes `f` makes to `node.input` are reflected in the walk.
pub fn visit_stream_node_cont_mut<F>(stream_node: &mut StreamNode, mut f: F)
where
    F: FnMut(&mut StreamNode) -> bool,
{
    // Stack-based pre-order traversal; reversed push keeps left-to-right child order.
    let mut pending = vec![stream_node];
    while let Some(node) = pending.pop() {
        let descend = f(node);
        if descend {
            pending.extend(node.input.iter_mut().rev());
        }
    }
}
/// Read-only pre-order walk over the tree rooted at `stream_node`.
///
/// `f` is called on each reached node. Returning `true` continues into that node's inputs
/// (left to right). Returning `false` skips the node's entire subtree.
pub fn visit_stream_node_cont<F>(stream_node: &StreamNode, mut f: F)
where
    F: FnMut(&StreamNode) -> bool,
{
    // Stack-based pre-order traversal; reversed push keeps left-to-right child order.
    let mut pending = vec![stream_node];
    while let Some(node) = pending.pop() {
        if f(node) {
            pending.extend(node.input.iter().rev());
        }
    }
}
/// Applies `f` to the body of every node in `fragment`'s stream-node tree, in pre-order.
///
/// # Panics
///
/// Panics if the fragment has no root `node`, or if any node in the tree lacks a body.
pub fn visit_fragment<F>(fragment: &mut StreamFragment, f: F)
where
    F: FnMut(&mut NodeBody),
{
    let root = fragment.node.as_mut().unwrap();
    visit_stream_node(root, f)
}
pub fn visit_stream_node_tables_inner<F>(
stream_node: &mut StreamNode,
internal_tables_only: bool,
visit_child_recursively: bool,
mut f: F,
) where
F: FnMut(&mut Table, &str),
{
macro_rules! always {
($table:expr, $name:expr) => {{
let table = $table
.as_mut()
.unwrap_or_else(|| panic!("internal table {} should always exist", $name));
f(table, $name);
}};
}
macro_rules! optional {
($table:expr, $name:expr) => {
if let Some(table) = &mut $table {
f(table, $name);
}
};
}
macro_rules! repeated {
($tables:expr, $name:expr) => {
for table in &mut $tables {
f(table, $name);
}
};
}
let mut visit_body = |body: &mut NodeBody| {
match body {
NodeBody::HashJoin(node) => {
always!(node.left_table, "HashJoinLeft");
always!(node.left_degree_table, "HashJoinDegreeLeft");
always!(node.right_table, "HashJoinRight");
always!(node.right_degree_table, "HashJoinDegreeRight");
}
NodeBody::TemporalJoin(node) => {
optional!(node.memo_table, "TemporalJoinMemo");
}
NodeBody::DynamicFilter(node) => {
always!(node.left_table, "DynamicFilterLeft");
always!(node.right_table, "DynamicFilterRight");
}
NodeBody::HashAgg(node) => {
assert_eq!(node.agg_call_states.len(), node.agg_calls.len());
always!(node.intermediate_state_table, "HashAggState");
for (call_idx, state) in node.agg_call_states.iter_mut().enumerate() {
match state.inner.as_mut().unwrap() {
agg_call_state::Inner::ValueState(_) => {}
agg_call_state::Inner::MaterializedInputState(s) => {
always!(s.table, &format!("HashAggCall{}", call_idx));
}
}
}
for (distinct_col, dedup_table) in node
.distinct_dedup_tables
.iter_mut()
.sorted_by_key(|(i, _)| *i)
{
f(dedup_table, &format!("HashAggDedupForCol{}", distinct_col));
}
}
NodeBody::SimpleAgg(node) => {
assert_eq!(node.agg_call_states.len(), node.agg_calls.len());
always!(node.intermediate_state_table, "SimpleAggState");
for (call_idx, state) in node.agg_call_states.iter_mut().enumerate() {
match state.inner.as_mut().unwrap() {
agg_call_state::Inner::ValueState(_) => {}
agg_call_state::Inner::MaterializedInputState(s) => {
always!(s.table, &format!("SimpleAggCall{}", call_idx));
}
}
}
for (distinct_col, dedup_table) in node
.distinct_dedup_tables
.iter_mut()
.sorted_by_key(|(i, _)| *i)
{
f(
dedup_table,
&format!("SimpleAggDedupForCol{}", distinct_col),
);
}
}
NodeBody::AppendOnlyTopN(node) => {
always!(node.table, "AppendOnlyTopN");
}
NodeBody::TopN(node) => {
always!(node.table, "TopN");
}
NodeBody::AppendOnlyGroupTopN(node) => {
always!(node.table, "AppendOnlyGroupTopN");
}
NodeBody::GroupTopN(node) => {
always!(node.table, "GroupTopN");
}
NodeBody::Source(node) => {
if let Some(source) = &mut node.source_inner {
always!(source.state_table, "Source");
}
}
NodeBody::StreamFsFetch(node) => {
if let Some(source) = &mut node.node_inner {
always!(source.state_table, "FsFetch");
}
}
NodeBody::SourceBackfill(node) => {
always!(node.state_table, "SourceBackfill")
}
NodeBody::Sink(node) => {
optional!(node.table, "Sink")
}
NodeBody::Now(node) => {
always!(node.state_table, "Now");
}
NodeBody::WatermarkFilter(node) => {
assert!(!node.tables.is_empty());
repeated!(node.tables, "WatermarkFilter");
}
NodeBody::Arrange(node) => {
always!(node.table, "Arrange");
}
NodeBody::AppendOnlyDedup(node) => {
always!(node.state_table, "AppendOnlyDedup");
}
NodeBody::EowcOverWindow(node) => {
always!(node.state_table, "EowcOverWindow");
}
NodeBody::OverWindow(node) => {
always!(node.state_table, "OverWindow");
}
NodeBody::Sort(node) => {
always!(node.state_table, "Sort");
}
NodeBody::StreamScan(node) => {
optional!(node.state_table, "StreamScan")
}
NodeBody::StreamCdcScan(node) => {
always!(node.state_table, "StreamCdcScan")
}
NodeBody::Materialize(node) if !internal_tables_only => {
always!(node.table, "Materialize")
}
NodeBody::GlobalApproxPercentile(node) => {
always!(node.bucket_state_table, "GlobalApproxPercentileBucketState");
always!(node.count_state_table, "GlobalApproxPercentileCountState");
}
NodeBody::AsOfJoin(node) => {
always!(node.left_table, "AsOfJoinLeft");
always!(node.right_table, "AsOfJoinRight");
}
_ => {}
}
};
if visit_child_recursively {
visit_stream_node(stream_node, visit_body)
} else {
visit_body(stream_node.node_body.as_mut().unwrap())
}
}
/// Visits every internal state table in the subtree rooted at `stream_node`.
///
/// Same as `visit_stream_node_tables`, except that a `Materialize` node's table is skipped.
/// `f` receives each table together with its descriptive label.
pub fn visit_stream_node_internal_tables<F>(stream_node: &mut StreamNode, f: F)
where
    F: FnMut(&mut Table, &str),
{
    let internal_tables_only = true;
    let visit_child_recursively = true;
    visit_stream_node_tables_inner(stream_node, internal_tables_only, visit_child_recursively, f)
}
/// Visits every state table in the subtree rooted at `stream_node`, including a `Materialize`
/// node's table. `f` receives each table together with its descriptive label.
pub fn visit_stream_node_tables<F>(stream_node: &mut StreamNode, f: F)
where
    F: FnMut(&mut Table, &str),
{
    let internal_tables_only = false;
    let visit_child_recursively = true;
    visit_stream_node_tables_inner(stream_node, internal_tables_only, visit_child_recursively, f)
}
/// Visits every internal state table in `fragment`'s node tree, skipping `Materialize` tables.
///
/// # Panics
///
/// Panics if the fragment has no root `node`.
pub fn visit_internal_tables<F>(fragment: &mut StreamFragment, f: F)
where
    F: FnMut(&mut Table, &str),
{
    let root = fragment.node.as_mut().unwrap();
    visit_stream_node_internal_tables(root, f)
}
/// Visits every state table in `fragment`'s node tree, including `Materialize` tables.
///
/// # Panics
///
/// Panics if the fragment has no root `node`.
pub fn visit_tables<F>(fragment: &mut StreamFragment, f: F)
where
    F: FnMut(&mut Table, &str),
{
    let root = fragment.node.as_mut().unwrap();
    visit_stream_node_tables(root, f)
}