risingwave_frontend/optimizer/plan_node/generic/
agg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, BTreeSet, HashMap};
16use std::{fmt, vec};
17
18use itertools::{Either, Itertools};
19use pretty_xmlish::{Pretty, StrAssocArr};
20use risingwave_common::catalog::{Field, FieldDisplay, Schema};
21use risingwave_common::types::DataType;
22use risingwave_common::util::iter_util::ZipEqFast;
23use risingwave_common::util::sort_util::{ColumnOrder, ColumnOrderDisplay, OrderType};
24use risingwave_common::util::value_encoding::DatumToProtoExt;
25use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};
26use risingwave_expr::sig::{FUNCTION_REGISTRY, FuncBuilder};
27use risingwave_pb::expr::{PbAggCall, PbConstant};
28use risingwave_pb::stream_plan::{AggCallState as PbAggCallState, agg_call_state};
29
30use super::super::utils::TableCatalogBuilder;
31use super::{GenericPlanNode, GenericPlanRef, impl_distill_unit_from_fields, stream};
32use crate::TableCatalog;
33use crate::error::{ErrorCode, Result};
34use crate::expr::{Expr, ExprRewriter, ExprVisitor, InputRef, InputRefDisplay, Literal};
35use crate::optimizer::optimizer_context::OptimizerContextRef;
36use crate::optimizer::plan_node::batch::BatchPlanRef;
37use crate::optimizer::property::{
38    Distribution, FunctionalDependencySet, RequiredDist, WatermarkColumns,
39};
40use crate::stream_fragmenter::BuildFragmentGraphState;
41use crate::utils::{
42    ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay, IndexRewriter,
43    IndexSet,
44};
45
46/// [`Agg`] groups input data by their group key and computes aggregation functions.
47///
48/// It corresponds to the `GROUP BY` operator in a SQL query statement together with the aggregate
49/// functions in the `SELECT` clause.
50///
51/// The output schema will first include the group key and then the aggregation calls.
52#[derive(Debug, Clone, PartialEq, Eq, Hash)]
53pub struct Agg<PlanRef> {
54    pub agg_calls: Vec<PlanAggCall>,
55    pub group_key: IndexSet,
56    pub grouping_sets: Vec<IndexSet>,
57    pub input: PlanRef,
58    pub enable_two_phase: bool,
59}
60
61impl<PlanRef: GenericPlanRef> Agg<PlanRef> {
62    pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
63        self.agg_calls.iter_mut().for_each(|call| {
64            call.filter = call.filter.clone().rewrite_expr(r);
65        });
66    }
67
68    pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
69        self.agg_calls.iter().for_each(|call| {
70            call.filter.visit_expr(v);
71        });
72    }
73
74    pub(crate) fn output_len(&self) -> usize {
75        self.group_key.len() + self.agg_calls.len()
76    }
77
78    /// get the Mapping of columnIndex from input column index to output column index,if a input
79    /// column corresponds more than one out columns, mapping to any one
80    pub fn o2i_col_mapping(&self) -> ColIndexMapping {
81        let mut map = vec![None; self.output_len()];
82        for (i, key) in self.group_key.indices().enumerate() {
83            map[i] = Some(key);
84        }
85        ColIndexMapping::new(map, self.input.schema().len())
86    }
87
88    /// get the Mapping of columnIndex from input column index to out column index
89    pub fn i2o_col_mapping(&self) -> ColIndexMapping {
90        let mut map = vec![None; self.input.schema().len()];
91        for (i, key) in self.group_key.indices().enumerate() {
92            map[key] = Some(i);
93        }
94        ColIndexMapping::new(map, self.output_len())
95    }
96
97    fn two_phase_agg_forced(&self) -> bool {
98        self.ctx().session_ctx().config().force_two_phase_agg()
99    }
100
101    pub fn two_phase_agg_enabled(&self) -> bool {
102        self.enable_two_phase
103    }
104
105    pub(crate) fn can_two_phase_agg(&self) -> bool {
106        self.two_phase_agg_enabled()
107            && !self.agg_calls.is_empty()
108            && self.agg_calls.iter().all(|call| {
109                let agg_type_ok = !matches!(call.agg_type, agg_types::simply_cannot_two_phase!());
110                let order_ok = matches!(
111                    call.agg_type,
112                    agg_types::result_unaffected_by_order_by!()
113                        | AggType::Builtin(PbAggKind::ApproxPercentile)
114                ) || call.order_by.is_empty();
115                let distinct_ok =
116                    matches!(call.agg_type, agg_types::result_unaffected_by_distinct!())
117                        || !call.distinct;
118                agg_type_ok && order_ok && distinct_ok
119            })
120    }
121
122    /// Must try two phase agg iff we are forced to, and we satisfy the constraints.
123    pub(crate) fn must_try_two_phase_agg(&self) -> bool {
124        self.two_phase_agg_forced() && self.can_two_phase_agg()
125    }
126
127    /// Generally used by two phase hash agg.
128    /// If input dist already satisfies hash agg distribution,
129    /// it will be more expensive to do two phase agg, should just do shuffle agg.
130    pub(crate) fn hash_agg_dist_satisfied_by_input_dist(&self, input_dist: &Distribution) -> bool {
131        let required_dist =
132            RequiredDist::shard_by_key(self.input.schema().len(), &self.group_key.to_vec());
133        input_dist.satisfies(&required_dist)
134    }
135
136    /// See if all stream aggregation calls have a stateless local agg counterpart.
137    pub(crate) fn all_local_aggs_are_stateless(&self, stream_input_append_only: bool) -> bool {
138        self.agg_calls.iter().all(|c| {
139            matches!(c.agg_type, agg_types::single_value_state!())
140                || (matches!(c.agg_type, agg_types::single_value_state_iff_in_append_only!() if stream_input_append_only))
141        })
142    }
143
144    pub(crate) fn eowc_window_column(
145        &self,
146        input_watermark_columns: &WatermarkColumns,
147    ) -> Result<usize> {
148        let group_key_with_wtmk = self
149            .group_key
150            .indices()
151            .filter_map(|idx| {
152                input_watermark_columns
153                    .get_group(idx)
154                    .map(|group| (idx, group))
155            })
156            .collect::<Vec<_>>();
157
158        if group_key_with_wtmk.is_empty() {
159            return Err(ErrorCode::NotSupported(
160                "Emit-On-Window-Close mode requires a watermark column in GROUP BY.".to_owned(),
161                "Please try to GROUP BY a watermark column".to_owned(),
162            )
163            .into());
164        }
165        if group_key_with_wtmk.len() == 1
166            || group_key_with_wtmk
167                .iter()
168                .map(|(_, group)| group)
169                .all_equal()
170        {
171            // 1. only one watermark column, should be the window column
172            // 2. all watermark columns belong to the same group, choose the first one as the window column
173            return Ok(group_key_with_wtmk[0].0);
174        }
175        Err(ErrorCode::NotSupported(
176            "Emit-On-Window-Close mode requires that watermark columns in GROUP BY are derived from the same upstream column.".to_owned(),
177            "Please try to remove undesired columns from GROUP BY".to_owned(),
178        )
179        .into())
180    }
181
182    pub fn new(agg_calls: Vec<PlanAggCall>, group_key: IndexSet, input: PlanRef) -> Self {
183        let enable_two_phase = input.ctx().session_ctx().config().enable_two_phase_agg();
184        Self {
185            agg_calls,
186            group_key,
187            input,
188            grouping_sets: vec![],
189            enable_two_phase,
190        }
191    }
192
193    pub fn with_grouping_sets(mut self, grouping_sets: Vec<IndexSet>) -> Self {
194        self.grouping_sets = grouping_sets;
195        self
196    }
197
198    pub fn with_enable_two_phase(mut self, enable_two_phase: bool) -> Self {
199        self.enable_two_phase = enable_two_phase;
200        self
201    }
202}
203
204impl<PlanRef: BatchPlanRef> Agg<PlanRef> {
205    // Check if the input is already sorted on group keys.
206    pub(crate) fn input_provides_order_on_group_keys(&self) -> bool {
207        let mut input_order_prefix = IndexSet::empty();
208        for input_order_col in &self.input.order().column_orders {
209            if !self.group_key.contains(input_order_col.column_index) {
210                break;
211            }
212            input_order_prefix.insert(input_order_col.column_index);
213        }
214        self.group_key == input_order_prefix
215    }
216}
217
218impl<PlanRef: GenericPlanRef> GenericPlanNode for Agg<PlanRef> {
219    fn schema(&self) -> Schema {
220        let fields = self
221            .group_key
222            .indices()
223            .map(|i| self.input.schema().fields()[i].clone())
224            .chain(self.agg_calls.iter().map(|agg_call| {
225                let plan_agg_call_display = PlanAggCallDisplay {
226                    plan_agg_call: agg_call,
227                    input_schema: self.input.schema(),
228                };
229                let name = format!("{:?}", plan_agg_call_display);
230                Field::with_name(agg_call.return_type.clone(), name)
231            }))
232            .collect();
233        Schema { fields }
234    }
235
236    fn stream_key(&self) -> Option<Vec<usize>> {
237        Some((0..self.group_key.len()).collect())
238    }
239
240    fn ctx(&self) -> OptimizerContextRef {
241        self.input.ctx()
242    }
243
244    fn functional_dependency(&self) -> FunctionalDependencySet {
245        let output_len = self.output_len();
246        let _input_len = self.input.schema().len();
247        let mut fd_set =
248            FunctionalDependencySet::with_key(output_len, &(0..self.group_key.len()).collect_vec());
249        // take group keys from input_columns, then grow the target size to column_cnt
250        let i2o = self.i2o_col_mapping();
251        for fd in self.input.functional_dependency().as_dependencies() {
252            if let Some(fd) = i2o.rewrite_functional_dependency(fd) {
253                fd_set.add_functional_dependency(fd);
254            }
255        }
256        fd_set
257    }
258}
259
260pub enum AggCallState {
261    Value,
262    MaterializedInput(Box<MaterializedInputState>),
263}
264
265impl AggCallState {
266    pub fn into_prost(self, state: &mut BuildFragmentGraphState) -> PbAggCallState {
267        PbAggCallState {
268            inner: Some(match self {
269                AggCallState::Value => {
270                    agg_call_state::Inner::ValueState(agg_call_state::ValueState {})
271                }
272                AggCallState::MaterializedInput(s) => {
273                    agg_call_state::Inner::MaterializedInputState(
274                        agg_call_state::MaterializedInputState {
275                            table: Some(
276                                s.table
277                                    .with_id(state.gen_table_id_wrapped())
278                                    .to_internal_table_prost(),
279                            ),
280                            included_upstream_indices: s
281                                .included_upstream_indices
282                                .into_iter()
283                                .map(|x| x as _)
284                                .collect(),
285                            table_value_indices: s
286                                .table_value_indices
287                                .into_iter()
288                                .map(|x| x as _)
289                                .collect(),
290                            order_columns: s
291                                .order_columns
292                                .into_iter()
293                                .map(|x| x.to_protobuf())
294                                .collect(),
295                        },
296                    )
297                }
298            }),
299        }
300    }
301}
302
303pub struct MaterializedInputState {
304    pub table: TableCatalog,
305    pub included_upstream_indices: Vec<usize>,
306    pub table_value_indices: Vec<usize>,
307    pub order_columns: Vec<ColumnOrder>,
308}
309
310impl<PlanRef: stream::StreamPlanRef> Agg<PlanRef> {
311    pub fn infer_tables(
312        &self,
313        me: impl stream::StreamPlanRef,
314        vnode_col_idx: Option<usize>,
315        window_col_idx: Option<usize>,
316    ) -> (
317        TableCatalog,
318        Vec<AggCallState>,
319        HashMap<usize, TableCatalog>,
320    ) {
321        (
322            self.infer_intermediate_state_table(&me, vnode_col_idx, window_col_idx),
323            self.infer_stream_agg_state(&me, vnode_col_idx, window_col_idx),
324            self.infer_distinct_dedup_tables(&me, vnode_col_idx, window_col_idx),
325        )
326    }
327
328    fn get_ordered_group_key(&self, window_col_idx: Option<usize>) -> Vec<usize> {
329        if let Some(window_col_idx) = window_col_idx {
330            assert!(self.group_key.contains(window_col_idx));
331            Either::Left(
332                std::iter::once(window_col_idx).chain(
333                    self.group_key
334                        .indices()
335                        .filter(move |&i| i != window_col_idx),
336                ),
337            )
338        } else {
339            Either::Right(self.group_key.indices())
340        }
341        .collect()
342    }
343
344    /// Create a new table builder with group key columns added.
345    ///
346    /// # Returns
347    ///
348    /// - table builder with group key columns added
349    /// - included upstream indices
350    /// - column mapping from upstream to table
351    fn create_table_builder(
352        &self,
353        _ctx: OptimizerContextRef,
354        window_col_idx: Option<usize>,
355    ) -> (TableCatalogBuilder, Vec<usize>, BTreeMap<usize, usize>) {
356        // NOTE: this function should be called to get a table builder, so that all state tables
357        // created for Agg node have the same group key columns and pk ordering.
358        let mut table_builder = TableCatalogBuilder::default();
359
360        assert!(table_builder.columns().is_empty());
361        assert_eq!(table_builder.get_current_pk_len(), 0);
362
363        // add group key column to table builder
364        let mut included_upstream_indices = vec![];
365        let mut column_mapping = BTreeMap::new();
366        let in_fields = self.input.schema().fields();
367        for idx in self.group_key.indices() {
368            let tbl_col_idx = table_builder.add_column(&in_fields[idx]);
369            included_upstream_indices.push(idx);
370            column_mapping.insert(idx, tbl_col_idx);
371        }
372
373        // configure state table primary key (ordering)
374        let ordered_group_key = self.get_ordered_group_key(window_col_idx);
375        for idx in ordered_group_key {
376            table_builder.add_order_column(column_mapping[&idx], OrderType::ascending());
377        }
378
379        (table_builder, included_upstream_indices, column_mapping)
380    }
381
382    /// Infer `AggCallState`s for streaming agg.
383    pub fn infer_stream_agg_state(
384        &self,
385        me: impl stream::StreamPlanRef,
386        vnode_col_idx: Option<usize>,
387        window_col_idx: Option<usize>,
388    ) -> Vec<AggCallState> {
389        let in_fields = self.input.schema().fields().to_vec();
390        let in_pks = self.input.stream_key().unwrap().to_vec();
391        let in_append_only = self.input.append_only();
392        let in_dist_key = self.input.distribution().dist_column_indices().to_vec();
393
394        let gen_materialized_input_state = |sort_keys: Vec<(OrderType, usize)>,
395                                            extra_keys: Vec<usize>,
396                                            include_keys: Vec<usize>|
397         -> MaterializedInputState {
398            let (mut table_builder, mut included_upstream_indices, mut column_mapping) =
399                self.create_table_builder(me.ctx(), window_col_idx);
400            let read_prefix_len_hint = table_builder.get_current_pk_len();
401
402            let mut order_columns = vec![];
403            let mut table_value_indices = BTreeSet::new(); // table column indices of value columns
404            let mut add_column =
405                |upstream_idx, order_type, table_builder: &mut TableCatalogBuilder| {
406                    column_mapping.entry(upstream_idx).or_insert_with(|| {
407                        let table_col_idx = table_builder.add_column(&in_fields[upstream_idx]);
408                        if let Some(order_type) = order_type {
409                            table_builder.add_order_column(table_col_idx, order_type);
410                            order_columns.push(ColumnOrder::new(upstream_idx, order_type));
411                        }
412                        included_upstream_indices.push(upstream_idx);
413                        table_col_idx
414                    });
415                    table_value_indices.insert(column_mapping[&upstream_idx]);
416                };
417
418            for (order_type, idx) in sort_keys {
419                add_column(idx, Some(order_type), &mut table_builder);
420            }
421            for idx in extra_keys {
422                add_column(idx, Some(OrderType::ascending()), &mut table_builder);
423            }
424            for idx in include_keys {
425                add_column(idx, None, &mut table_builder);
426            }
427
428            let mapping =
429                ColIndexMapping::with_included_columns(&included_upstream_indices, in_fields.len());
430            let tb_dist = mapping.rewrite_dist_key(&in_dist_key);
431            if let Some(tb_vnode_idx) = vnode_col_idx.and_then(|idx| mapping.try_map(idx)) {
432                table_builder.set_vnode_col_idx(tb_vnode_idx);
433            }
434
435            // set value indices to reduce ser/de overhead
436            let table_value_indices = table_value_indices.into_iter().collect_vec();
437            table_builder.set_value_indices(table_value_indices.clone());
438
439            MaterializedInputState {
440                table: table_builder.build(tb_dist.unwrap_or_default(), read_prefix_len_hint),
441                included_upstream_indices,
442                table_value_indices,
443                order_columns,
444            }
445        };
446
447        self.agg_calls
448            .iter()
449            .map(|agg_call| match agg_call.agg_type {
450                agg_types::single_value_state_iff_in_append_only!() if in_append_only => {
451                    AggCallState::Value
452                }
453                agg_types::single_value_state!() => AggCallState::Value,
454                agg_types::materialized_input_state!() => {
455                    // columns with order requirement in state table
456                    let sort_keys = {
457                        match agg_call.agg_type {
458                            AggType::Builtin(PbAggKind::Min) => {
459                                vec![(OrderType::ascending(), agg_call.inputs[0].index)]
460                            }
461                            AggType::Builtin(PbAggKind::Max) => {
462                                vec![(OrderType::descending(), agg_call.inputs[0].index)]
463                            }
464                            AggType::Builtin(
465                                PbAggKind::FirstValue
466                                | PbAggKind::LastValue
467                                | PbAggKind::StringAgg
468                                | PbAggKind::ArrayAgg
469                                | PbAggKind::JsonbAgg,
470                            )
471                            | AggType::WrapScalar(_) => {
472                                if agg_call.order_by.is_empty() {
473                                    me.ctx().warn_to_user(format!(
474                                        "{} without ORDER BY may produce non-deterministic result",
475                                        agg_call.agg_type,
476                                    ));
477                                }
478                                agg_call
479                                    .order_by
480                                    .iter()
481                                    .map(|o| {
482                                        (
483                                            if matches!(
484                                                agg_call.agg_type,
485                                                AggType::Builtin(PbAggKind::LastValue)
486                                            ) {
487                                                o.order_type.reverse()
488                                            } else {
489                                                o.order_type
490                                            },
491                                            o.column_index,
492                                        )
493                                    })
494                                    .collect()
495                            }
496                            AggType::Builtin(PbAggKind::JsonbObjectAgg) => agg_call
497                                .order_by
498                                .iter()
499                                .map(|o| (o.order_type, o.column_index))
500                                .collect(),
501                            _ => unreachable!(),
502                        }
503                    };
504
505                    // columns to ensure each row unique
506                    let extra_keys = if agg_call.distinct {
507                        // if distinct, use distinct keys as extra keys
508                        let distinct_key = agg_call.inputs[0].index;
509                        vec![distinct_key]
510                    } else {
511                        // if not distinct, use primary keys as extra keys
512                        in_pks.clone()
513                    };
514
515                    // other columns that should be contained in state table
516                    let include_keys = match agg_call.agg_type {
517                        // `agg_types::materialized_input_state` except for `min`/`max`
518                        AggType::Builtin(
519                            PbAggKind::FirstValue
520                            | PbAggKind::LastValue
521                            | PbAggKind::StringAgg
522                            | PbAggKind::ArrayAgg
523                            | PbAggKind::JsonbAgg
524                            | PbAggKind::JsonbObjectAgg,
525                        )
526                        | AggType::WrapScalar(_) => {
527                            agg_call.inputs.iter().map(|i| i.index).collect()
528                        }
529                        _ => vec![],
530                    };
531
532                    let state = gen_materialized_input_state(sort_keys, extra_keys, include_keys);
533                    AggCallState::MaterializedInput(Box::new(state))
534                }
535                agg_types::rewritten!() => {
536                    unreachable!("should have been rewritten")
537                }
538                agg_types::unimplemented_in_stream!() => {
539                    unreachable!("should have been banned")
540                }
541                AggType::Builtin(
542                    PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar,
543                ) => {
544                    unreachable!("invalid agg kind")
545                }
546            })
547            .collect()
548    }
549
550    /// table schema:
551    /// group key | state for AGG1 | state for AGG2 | ...
552    pub fn infer_intermediate_state_table(
553        &self,
554        me: impl GenericPlanRef,
555        vnode_col_idx: Option<usize>,
556        window_col_idx: Option<usize>,
557    ) -> TableCatalog {
558        let mut out_fields = me.schema().fields().to_vec();
559
560        // rewrite data types in fields
561        let in_append_only = self.input.append_only();
562        for (agg_call, field) in self
563            .agg_calls
564            .iter()
565            .zip_eq_fast(&mut out_fields[self.group_key.len()..])
566        {
567            let agg_kind = match agg_call.agg_type {
568                AggType::UserDefined(_) => {
569                    // for user defined aggregate, the state type is always BYTEA
570                    field.data_type = DataType::Bytea;
571                    continue;
572                }
573                AggType::WrapScalar(_) => {
574                    // for wrapped scalar function, the state is always NULL
575                    continue;
576                }
577                AggType::Builtin(kind) => kind,
578            };
579            let sig = FUNCTION_REGISTRY
580                .get(
581                    agg_kind,
582                    &agg_call
583                        .inputs
584                        .iter()
585                        .map(|input| input.data_type.clone())
586                        .collect_vec(),
587                    &agg_call.return_type,
588                )
589                .expect("agg not found");
590            // in_append_only: whether the input is append-only
591            // sig.is_append_only(): whether the agg function has append-only version
592            match (in_append_only, sig.is_append_only()) {
593                (false, true) => {
594                    // we use materialized input state for non-retractable aggregate function.
595                    // for backward compatibility, the state type is same as the return type.
596                    // its values in the intermediate state table are always null.
597                }
598                (true, true) => {
599                    // use append-only version
600                    if let FuncBuilder::Aggregate {
601                        append_only_state_type: Some(state_type),
602                        ..
603                    } = &sig.build
604                    {
605                        field.data_type = state_type.clone();
606                    }
607                }
608                (_, false) => {
609                    // there is only retractable version, use it
610                    if let FuncBuilder::Aggregate {
611                        retractable_state_type: Some(state_type),
612                        ..
613                    } = &sig.build
614                    {
615                        field.data_type = state_type.clone();
616                    }
617                }
618            }
619        }
620        let in_dist_key = self.input.distribution().dist_column_indices().to_vec();
621        let n_group_key_cols = self.group_key.len();
622
623        let (mut table_builder, _, _) = self.create_table_builder(me.ctx(), window_col_idx);
624        let read_prefix_len_hint = table_builder.get_current_pk_len();
625
626        for field in out_fields.iter().skip(n_group_key_cols) {
627            table_builder.add_column(field);
628        }
629
630        let mapping = self.i2o_col_mapping();
631        let tb_dist = mapping.rewrite_dist_key(&in_dist_key).unwrap_or_default();
632        if let Some(tb_vnode_idx) = vnode_col_idx.and_then(|idx| mapping.try_map(idx)) {
633            table_builder.set_vnode_col_idx(tb_vnode_idx);
634        }
635
636        // the result_table is composed of group_key and all agg_call's values, so the value_indices
637        // of this table should skip group_key.len().
638        table_builder.set_value_indices((n_group_key_cols..out_fields.len()).collect());
639        table_builder.build(tb_dist, read_prefix_len_hint)
640    }
641
642    /// Infer dedup tables for distinct agg calls, partitioned by distinct columns.
643    /// Since distinct agg calls only dedup on the first argument, the key of the result map is
644    /// `usize`, i.e. the distinct column index.
645    ///
646    /// Dedup table schema:
647    /// group key | distinct key | count for AGG1(distinct x) | count for AGG2(distinct x) | ...
648    pub fn infer_distinct_dedup_tables(
649        &self,
650        me: impl GenericPlanRef,
651        vnode_col_idx: Option<usize>,
652        window_col_idx: Option<usize>,
653    ) -> HashMap<usize, TableCatalog> {
654        let in_dist_key = self.input.distribution().dist_column_indices().to_vec();
655        let in_fields = self.input.schema().fields();
656
657        self.agg_calls
658            .iter()
659            .enumerate()
660            .filter(|(_, call)| call.distinct) // only distinct agg calls need dedup table
661            .into_group_map_by(|(_, call)| call.inputs[0].index) // one table per distinct column
662            .into_iter()
663            .map(|(distinct_col, indices_and_calls)| {
664                let (mut table_builder, mut key_cols, _) =
665                    self.create_table_builder(me.ctx(), window_col_idx);
666                let table_col_idx = table_builder.add_column(&in_fields[distinct_col]);
667                table_builder.add_order_column(table_col_idx, OrderType::ascending());
668                key_cols.push(distinct_col);
669
670                let read_prefix_len_hint = table_builder.get_current_pk_len();
671
672                // Agg calls with same distinct column share the same dedup table, but they may have
673                // different filter conditions, so the count of occurrence of one distinct key may
674                // differ among different calls. We add one column for each call in the dedup table.
675                for (call_index, _) in indices_and_calls {
676                    table_builder.add_column(&Field {
677                        data_type: DataType::Int64,
678                        name: format!("count_for_agg_call_{}", call_index),
679                    });
680                }
681                table_builder
682                    .set_value_indices((key_cols.len()..table_builder.columns().len()).collect());
683
684                let mapping = ColIndexMapping::with_included_columns(&key_cols, in_fields.len());
685                if let Some(idx) = vnode_col_idx.and_then(|idx| mapping.try_map(idx)) {
686                    table_builder.set_vnode_col_idx(idx);
687                }
688                let dist_key = mapping.rewrite_dist_key(&in_dist_key).unwrap_or_default();
689                let table = table_builder.build(dist_key, read_prefix_len_hint);
690                (distinct_col, table)
691            })
692            .collect()
693    }
694
695    pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) {
696        (
697            self.agg_calls,
698            self.group_key,
699            self.grouping_sets,
700            self.input,
701            self.enable_two_phase,
702        )
703    }
704
705    pub fn fields_pretty<'a>(&self) -> StrAssocArr<'a> {
706        let last = ("aggs", self.agg_calls_pretty());
707        if !self.group_key.is_empty() {
708            let first = ("group_key", self.group_key_pretty());
709            vec![first, last]
710        } else {
711            vec![last]
712        }
713    }
714
715    fn agg_calls_pretty<'a>(&self) -> Pretty<'a> {
716        let f = |plan_agg_call| {
717            Pretty::debug(&PlanAggCallDisplay {
718                plan_agg_call,
719                input_schema: self.input.schema(),
720            })
721        };
722        Pretty::Array(self.agg_calls.iter().map(f).collect())
723    }
724
725    fn group_key_pretty<'a>(&self) -> Pretty<'a> {
726        let f = |i| Pretty::display(&FieldDisplay(self.input.schema().fields.get(i).unwrap()));
727        Pretty::Array(self.group_key.indices().map(f).collect())
728    }
729}
730
// NOTE(review): judging by the macro name, this presumably derives the `distill` output of
// `Agg<stream::StreamPlanRef>` from `Agg::fields_pretty` — confirm against the macro definition.
impl_distill_unit_from_fields!(Agg, stream::StreamPlanRef);
732
/// Rewritten version of [`crate::expr::AggCall`] which uses `InputRef` instead of `ExprImpl`.
/// Refer to [`crate::optimizer::plan_node::logical_agg::LogicalAggBuilder::try_rewrite_agg_call`]
/// for more details.
///
/// All column indices held by a call (`inputs`, `order_by`, and `InputRef`s inside `filter`)
/// refer to the aggregation's input and are remapped together by
/// [`PlanAggCall::rewrite_input_index`].
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct PlanAggCall {
    /// Type of aggregation function: a builtin kind, a user-defined aggregate, or a wrapped
    /// scalar expression.
    pub agg_type: AggType,

    /// Data type of the returned column
    pub return_type: DataType,

    /// Column indexes of input columns.
    ///
    /// Its length can be, for example:
    /// - 0 (`count(*)`, i.e. `Count` with no arguments; see [`PlanAggCall::count_star`])
    /// - 1 (`Max`, `Min`)
    /// - 2 (`StringAgg`).
    ///
    /// Usually, we mark the first column as the aggregated column.
    pub inputs: Vec<InputRef>,

    /// Whether this is a `DISTINCT` aggregate; rendered as `distinct` before the first argument.
    pub distinct: bool,
    /// `ORDER BY` columns of the aggregate, indexing into the input. Dropped for the total
    /// phase of a two-phase aggregation (see [`PlanAggCall::partial_to_total_agg_call`]).
    pub order_by: Vec<ColumnOrder>,
    /// Selective aggregation: only the input rows for which
    /// `filter` evaluates to `true` will be fed to the aggregate function.
    /// An always-true condition means "no filter" and is omitted from protobuf and display.
    pub filter: Condition,
    /// Literal arguments of the call, serialized as `PbConstant` values by `to_protobuf`.
    /// NOTE(review): presumably the direct arguments of ordered-set aggregates
    /// (e.g. the fraction in `percentile_cont(0.5)`) — confirm.
    pub direct_args: Vec<Literal>,
}
761
762impl fmt::Debug for PlanAggCall {
763    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
764        write!(f, "{}", self.agg_type)?;
765        if !self.inputs.is_empty() {
766            write!(f, "(")?;
767            for (idx, input) in self.inputs.iter().enumerate() {
768                if idx == 0 && self.distinct {
769                    write!(f, "distinct ")?;
770                }
771                write!(f, "{:?}", input)?;
772                if idx != (self.inputs.len() - 1) {
773                    write!(f, ",")?;
774                }
775            }
776            if !self.order_by.is_empty() {
777                let clause_text = self.order_by.iter().map(|e| format!("{:?}", e)).join(", ");
778                write!(f, " order_by({})", clause_text)?;
779            }
780            write!(f, ")")?;
781        }
782        if !self.filter.always_true() {
783            write!(
784                f,
785                " filter({:?})",
786                self.filter.as_expr_unless_true().unwrap()
787            )?;
788        }
789        Ok(())
790    }
791}
792
793impl PlanAggCall {
794    pub fn rewrite_input_index(&mut self, mapping: ColIndexMapping) {
795        // modify input
796        self.inputs.iter_mut().for_each(|x| {
797            x.index = mapping.map(x.index);
798        });
799
800        // modify order_by exprs
801        self.order_by.iter_mut().for_each(|x| {
802            x.column_index = mapping.map(x.column_index);
803        });
804
805        // modify filter
806        let mut rewriter = IndexRewriter::new(mapping);
807        self.filter.conjunctions.iter_mut().for_each(|x| {
808            *x = rewriter.rewrite_expr(x.clone());
809        });
810    }
811
812    pub fn to_protobuf(&self) -> PbAggCall {
813        PbAggCall {
814            kind: match &self.agg_type {
815                AggType::Builtin(kind) => *kind,
816                AggType::UserDefined(_) => PbAggKind::UserDefined,
817                AggType::WrapScalar(_) => PbAggKind::WrapScalar,
818            }
819            .into(),
820            return_type: Some(self.return_type.to_protobuf()),
821            args: self.inputs.iter().map(InputRef::to_proto).collect(),
822            distinct: self.distinct,
823            order_by: self.order_by.iter().map(ColumnOrder::to_protobuf).collect(),
824            filter: self.filter.as_expr_unless_true().map(|x| x.to_expr_proto()),
825            direct_args: self
826                .direct_args
827                .iter()
828                .map(|x| PbConstant {
829                    datum: Some(x.get_data().to_protobuf()),
830                    r#type: Some(x.return_type().to_protobuf()),
831                })
832                .collect(),
833            udf: match &self.agg_type {
834                AggType::UserDefined(udf) => Some(udf.clone()),
835                _ => None,
836            },
837            scalar: match &self.agg_type {
838                AggType::WrapScalar(expr) => Some(expr.clone()),
839                _ => None,
840            },
841        }
842    }
843
844    pub fn partial_to_total_agg_call(&self, partial_output_idx: usize) -> PlanAggCall {
845        let total_agg_type = self
846            .agg_type
847            .partial_to_total()
848            .expect("unsupported kinds shouldn't get here");
849        PlanAggCall {
850            agg_type: total_agg_type,
851            inputs: vec![InputRef::new(partial_output_idx, self.return_type.clone())],
852            order_by: vec![], // order must make no difference when we use 2-phase agg
853            filter: Condition::true_cond(),
854            ..self.clone()
855        }
856    }
857
858    pub fn count_star() -> Self {
859        PlanAggCall {
860            agg_type: PbAggKind::Count.into(),
861            return_type: DataType::Int64,
862            inputs: vec![],
863            distinct: false,
864            order_by: vec![],
865            filter: Condition::true_cond(),
866            direct_args: vec![],
867        }
868    }
869
870    pub fn with_condition(mut self, filter: Condition) -> Self {
871        self.filter = filter;
872        self
873    }
874
875    pub fn input_indices(&self) -> Vec<usize> {
876        self.inputs.iter().map(|input| input.index()).collect()
877    }
878}
879
880pub struct PlanAggCallDisplay<'a> {
881    pub plan_agg_call: &'a PlanAggCall,
882    pub input_schema: &'a Schema,
883}
884
885impl fmt::Debug for PlanAggCallDisplay<'_> {
886    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
887        let that = self.plan_agg_call;
888        write!(f, "{}", that.agg_type)?;
889        if !that.inputs.is_empty() {
890            write!(f, "(")?;
891            for (idx, input) in that.inputs.iter().enumerate() {
892                if idx == 0 && that.distinct {
893                    write!(f, "distinct ")?;
894                }
895                write!(
896                    f,
897                    "{}",
898                    InputRefDisplay {
899                        input_ref: input,
900                        input_schema: self.input_schema
901                    }
902                )?;
903                if idx != (that.inputs.len() - 1) {
904                    write!(f, ", ")?;
905                }
906            }
907            if !that.order_by.is_empty() {
908                write!(
909                    f,
910                    " order_by({})",
911                    that.order_by.iter().format_with(", ", |o, f| {
912                        f(&ColumnOrderDisplay {
913                            column_order: o,
914                            input_schema: self.input_schema,
915                        })
916                    })
917                )?;
918            }
919            write!(f, ")")?;
920        }
921
922        if !that.filter.always_true() {
923            write!(
924                f,
925                " filter({:?})",
926                ConditionDisplay {
927                    condition: &that.filter,
928                    input_schema: self.input_schema,
929                }
930            )?;
931        }
932        Ok(())
933    }
934}