risingwave_frontend/optimizer/plan_node/stream_hash_agg.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use itertools::Itertools;
16use pretty_xmlish::XmlNode;
17use risingwave_pb::stream_plan::stream_node::PbNodeBody;
18
19use super::generic::{self, PlanAggCall};
20use super::stream::prelude::*;
21use super::utils::{Distill, childless_record, plan_node_name, watermark_pretty};
22use super::{ExprRewritable, PlanBase, PlanTreeNodeUnary, StreamNode, StreamPlanRef as PlanRef};
23use crate::error::Result;
24use crate::expr::{ExprRewriter, ExprVisitor};
25use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
26use crate::optimizer::property::{MonotonicityMap, WatermarkColumns};
27use crate::stream_fragmenter::BuildFragmentGraphState;
28use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, IndexSet};
29
30#[derive(Debug, Clone, PartialEq, Eq, Hash)]
31pub struct StreamHashAgg {
32    pub base: PlanBase<Stream>,
33    core: generic::Agg<PlanRef>,
34
35    /// An optional column index which is the vnode of each row computed by the input's consistent
36    /// hash distribution.
37    vnode_col_idx: Option<usize>,
38
39    /// The index of `count(*)` in `agg_calls`.
40    row_count_idx: usize,
41
42    /// Whether to emit output only when the window is closed by watermark.
43    emit_on_window_close: bool,
44
45    /// The watermark column that Emit-On-Window-Close behavior is based on.
46    window_col_idx: Option<usize>,
47}
48
49impl StreamHashAgg {
50    pub fn new(
51        core: generic::Agg<PlanRef>,
52        vnode_col_idx: Option<usize>,
53        row_count_idx: usize,
54    ) -> Result<Self> {
55        Self::new_with_eowc(core, vnode_col_idx, row_count_idx, false)
56    }
57
58    pub fn new_with_eowc(
59        core: generic::Agg<PlanRef>,
60        vnode_col_idx: Option<usize>,
61        row_count_idx: usize,
62        emit_on_window_close: bool,
63    ) -> Result<Self> {
64        assert_eq!(core.agg_calls[row_count_idx], PlanAggCall::count_star());
65        reject_upsert_input!(core.input);
66
67        let input = core.input.clone();
68        let input_dist = input.distribution();
69        let dist = core
70            .i2o_col_mapping()
71            .rewrite_provided_distribution(input_dist);
72
73        let mut watermark_columns = WatermarkColumns::new();
74        let mut window_col_idx = None;
75        let mapping = core.i2o_col_mapping();
76        if emit_on_window_close {
77            let window_col = core
78                .eowc_window_column(input.watermark_columns())
79                .expect("checked in `to_eowc_version`");
80            // EOWC HashAgg only propagate one watermark column, the window column.
81            watermark_columns.insert(
82                mapping.map(window_col),
83                input.watermark_columns().get_group(window_col).unwrap(),
84            );
85            window_col_idx = Some(window_col);
86        } else {
87            for idx in core.group_key.indices() {
88                if let Some(wtmk_group) = input.watermark_columns().get_group(idx) {
89                    // Non-EOWC `StreamHashAgg` simply forwards the watermark messages from the input.
90                    watermark_columns.insert(mapping.map(idx), wtmk_group);
91                }
92            }
93        }
94
95        // Hash agg executor might change the append-only behavior of the stream.
96        let base = PlanBase::new_stream_with_core(
97            &core,
98            dist,
99            if emit_on_window_close {
100                // in EOWC mode, we produce append only output
101                StreamKind::AppendOnly
102            } else {
103                StreamKind::Retract
104            },
105            emit_on_window_close,
106            watermark_columns,
107            MonotonicityMap::new(), // TODO: derive monotonicity
108        );
109
110        Ok(StreamHashAgg {
111            base,
112            core,
113            vnode_col_idx,
114            row_count_idx,
115            emit_on_window_close,
116            window_col_idx,
117        })
118    }
119
120    pub fn agg_calls(&self) -> &[PlanAggCall] {
121        &self.core.agg_calls
122    }
123
124    pub fn group_key(&self) -> &IndexSet {
125        &self.core.group_key
126    }
127
128    pub(crate) fn i2o_col_mapping(&self) -> ColIndexMapping {
129        self.core.i2o_col_mapping()
130    }
131
132    // TODO(rc): It'll be better to force creation of EOWC version through `new`, especially when we
133    // optimize for 2-phase EOWC aggregation later.
134    pub fn to_eowc_version(&self) -> Result<PlanRef> {
135        let input = self.input();
136
137        // check whether the group by columns are valid
138        let _ = self.core.eowc_window_column(input.watermark_columns())?;
139
140        Ok(Self::new_with_eowc(
141            self.core.clone(),
142            self.vnode_col_idx,
143            self.row_count_idx,
144            true,
145        )?
146        .into())
147    }
148}
149
150impl Distill for StreamHashAgg {
151    fn distill<'a>(&self) -> XmlNode<'a> {
152        let mut vec = self.core.fields_pretty();
153        if let Some(ow) = watermark_pretty(self.base.watermark_columns(), self.schema()) {
154            vec.push(("output_watermarks", ow));
155        }
156        childless_record(
157            plan_node_name!(
158                "StreamHashAgg",
159                { "append_only", self.input().append_only() },
160                { "eowc", self.emit_on_window_close },
161            ),
162            vec,
163        )
164    }
165}
166
167impl PlanTreeNodeUnary<Stream> for StreamHashAgg {
168    fn input(&self) -> PlanRef {
169        self.core.input.clone()
170    }
171
172    fn clone_with_input(&self, input: PlanRef) -> Self {
173        let logical = generic::Agg {
174            input,
175            ..self.core.clone()
176        };
177
178        Self::new_with_eowc(
179            logical,
180            self.vnode_col_idx,
181            self.row_count_idx,
182            self.emit_on_window_close,
183        )
184        .unwrap()
185    }
186}
187impl_plan_tree_node_for_unary! { Stream, StreamHashAgg }
188
189impl StreamNode for StreamHashAgg {
190    fn to_stream_prost_body(&self, state: &mut BuildFragmentGraphState) -> PbNodeBody {
191        use risingwave_pb::stream_plan::*;
192        let (intermediate_state_table, agg_states, distinct_dedup_tables) =
193            self.core
194                .infer_tables(&self.base, self.vnode_col_idx, self.window_col_idx);
195
196        PbNodeBody::HashAgg(Box::new(HashAggNode {
197            group_key: self.group_key().to_vec_as_u32(),
198            agg_calls: self
199                .agg_calls()
200                .iter()
201                .map(PlanAggCall::to_protobuf)
202                .collect(),
203
204            is_append_only: self.input().append_only(),
205            agg_call_states: agg_states
206                .into_iter()
207                .map(|s| s.into_prost(state))
208                .collect(),
209            intermediate_state_table: Some(
210                intermediate_state_table
211                    .with_id(state.gen_table_id_wrapped())
212                    .to_internal_table_prost(),
213            ),
214            distinct_dedup_tables: distinct_dedup_tables
215                .into_iter()
216                .sorted_by_key(|(i, _)| *i)
217                .map(|(key_idx, table)| {
218                    (
219                        key_idx as u32,
220                        table
221                            .with_id(state.gen_table_id_wrapped())
222                            .to_internal_table_prost(),
223                    )
224                })
225                .collect(),
226            row_count_index: self.row_count_idx as u32,
227            emit_on_window_close: self.base.emit_on_window_close(),
228            version: PbAggNodeVersion::LATEST as _,
229        }))
230    }
231}
232
233impl ExprRewritable<Stream> for StreamHashAgg {
234    fn has_rewritable_expr(&self) -> bool {
235        true
236    }
237
238    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
239        let mut core = self.core.clone();
240        core.rewrite_exprs(r);
241        Self::new_with_eowc(
242            core,
243            self.vnode_col_idx,
244            self.row_count_idx,
245            self.emit_on_window_close,
246        )
247        .unwrap()
248        .into()
249    }
250}
251
252impl ExprVisitable for StreamHashAgg {
253    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
254        self.core.visit_exprs(v);
255    }
256}