risingwave_frontend/optimizer/plan_node/
stream_simple_agg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use pretty_xmlish::XmlNode;
17use risingwave_pb::stream_plan::stream_node::PbNodeBody;
18
19use super::generic::{self, PlanAggCall};
20use super::stream::prelude::*;
21use super::utils::{Distill, childless_record, plan_node_name};
22use super::{ExprRewritable, PlanBase, PlanTreeNodeUnary, StreamNode, StreamPlanRef as PlanRef};
23use crate::expr::{ExprRewriter, ExprVisitor};
24use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
25use crate::optimizer::property::{Distribution, MonotonicityMap, WatermarkColumns};
26use crate::stream_fragmenter::BuildFragmentGraphState;
27
28#[derive(Debug, Clone, PartialEq, Eq, Hash)]
29pub struct StreamSimpleAgg {
30    pub base: PlanBase<Stream>,
31    core: generic::Agg<PlanRef>,
32
33    /// The index of `count(*)` in `agg_calls`.
34    row_count_idx: usize,
35
36    // Required by the downstream `RowMerge`,
37    // currently only used by the `approx_percentile`'s two phase plan
38    must_output_per_barrier: bool,
39}
40
41impl StreamSimpleAgg {
42    pub fn new(
43        core: generic::Agg<PlanRef>,
44        row_count_idx: usize,
45        must_output_per_barrier: bool,
46    ) -> Result<Self> {
47        assert_eq!(core.agg_calls[row_count_idx], PlanAggCall::count_star());
48        reject_upsert_input!(core.input);
49
50        let input = core.input.clone();
51        let input_dist = input.distribution();
52        let dist = match input_dist {
53            Distribution::Single => Distribution::Single,
54            _ => panic!(),
55        };
56
57        // Empty because watermark column(s) must be in group key and simple agg have no group key.
58        let watermark_columns = WatermarkColumns::new();
59
60        // Simple agg executor might change the append-only behavior of the stream.
61        let base = PlanBase::new_stream_with_core(
62            &core,
63            dist,
64            StreamKind::Retract,
65            false,
66            watermark_columns,
67            MonotonicityMap::new(),
68        );
69
70        Ok(StreamSimpleAgg {
71            base,
72            core,
73            row_count_idx,
74            must_output_per_barrier,
75        })
76    }
77
78    pub fn agg_calls(&self) -> &[PlanAggCall] {
79        &self.core.agg_calls
80    }
81}
82
83impl Distill for StreamSimpleAgg {
84    fn distill<'a>(&self) -> XmlNode<'a> {
85        let name = plan_node_name!("StreamSimpleAgg",
86            { "append_only", self.input().append_only() },
87        );
88        let mut vec = self.core.fields_pretty();
89        if self.must_output_per_barrier {
90            vec.push(("must_output_per_barrier", "true".into()));
91        }
92        childless_record(name, vec)
93    }
94}
95
96impl PlanTreeNodeUnary<Stream> for StreamSimpleAgg {
97    fn input(&self) -> PlanRef {
98        self.core.input.clone()
99    }
100
101    fn clone_with_input(&self, input: PlanRef) -> Self {
102        let logical = generic::Agg {
103            input,
104            ..self.core.clone()
105        };
106        Self::new(logical, self.row_count_idx, self.must_output_per_barrier).unwrap()
107    }
108}
109impl_plan_tree_node_for_unary! { Stream, StreamSimpleAgg }
110
111impl StreamNode for StreamSimpleAgg {
112    fn to_stream_prost_body(&self, state: &mut BuildFragmentGraphState) -> PbNodeBody {
113        use risingwave_pb::stream_plan::*;
114        let (intermediate_state_table, agg_states, distinct_dedup_tables) =
115            self.core.infer_tables(&self.base, None, None);
116
117        PbNodeBody::SimpleAgg(Box::new(SimpleAggNode {
118            agg_calls: self
119                .agg_calls()
120                .iter()
121                .map(PlanAggCall::to_protobuf)
122                .collect(),
123            is_append_only: self.input().append_only(),
124            agg_call_states: agg_states
125                .into_iter()
126                .map(|s| s.into_prost(state))
127                .collect(),
128            intermediate_state_table: Some(
129                intermediate_state_table
130                    .with_id(state.gen_table_id_wrapped())
131                    .to_internal_table_prost(),
132            ),
133            distinct_dedup_tables: distinct_dedup_tables
134                .into_iter()
135                .sorted_by_key(|(i, _)| *i)
136                .map(|(key_idx, table)| {
137                    (
138                        key_idx as u32,
139                        table
140                            .with_id(state.gen_table_id_wrapped())
141                            .to_internal_table_prost(),
142                    )
143                })
144                .collect(),
145            row_count_index: self.row_count_idx as u32,
146            version: PbAggNodeVersion::LATEST as _,
147            must_output_per_barrier: self.must_output_per_barrier,
148        }))
149    }
150}
151
152impl ExprRewritable<Stream> for StreamSimpleAgg {
153    fn has_rewritable_expr(&self) -> bool {
154        true
155    }
156
157    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
158        let mut core = self.core.clone();
159        core.rewrite_exprs(r);
160        Self::new(core, self.row_count_idx, self.must_output_per_barrier)
161            .unwrap()
162            .into()
163    }
164}
165
166impl ExprVisitable for StreamSimpleAgg {
167    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
168        self.core.visit_exprs(v);
169    }
170}