risingwave_frontend/optimizer/plan_node/generic/agg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, BTreeSet, HashMap};
16use std::{fmt, vec};
17
18use itertools::{Either, Itertools};
19use pretty_xmlish::{Pretty, StrAssocArr};
20use risingwave_common::catalog::{Field, FieldDisplay, Schema};
21use risingwave_common::types::DataType;
22use risingwave_common::util::iter_util::ZipEqFast;
23use risingwave_common::util::sort_util::{ColumnOrder, ColumnOrderDisplay, OrderType};
24use risingwave_common::util::value_encoding::DatumToProtoExt;
25use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};
26use risingwave_expr::sig::{FUNCTION_REGISTRY, FuncBuilder};
27use risingwave_pb::expr::{PbAggCall, PbConstant};
28use risingwave_pb::stream_plan::{AggCallState as PbAggCallState, agg_call_state};
29
30use super::super::utils::TableCatalogBuilder;
31use super::{GenericPlanNode, GenericPlanRef, PhysicalPlanRef, impl_distill_unit_from_fields};
32use crate::TableCatalog;
33use crate::error::{ErrorCode, Result};
34use crate::expr::{Expr, ExprRewriter, ExprVisitor, InputRef, InputRefDisplay, Literal};
35use crate::optimizer::optimizer_context::OptimizerContextRef;
36use crate::optimizer::plan_node::batch::BatchPlanNodeMetadata;
37use crate::optimizer::plan_node::{BatchPlanRef, StreamPlanNodeMetadata, StreamPlanRef};
38use crate::optimizer::property::{
39    Distribution, FunctionalDependencySet, RequiredDist, WatermarkColumns,
40};
41use crate::stream_fragmenter::BuildFragmentGraphState;
42use crate::utils::{
43    ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay, IndexRewriter,
44    IndexSet,
45};
46
47/// [`Agg`] groups input data by their group key and computes aggregation functions.
48///
49/// It corresponds to the `GROUP BY` operator in a SQL query statement together with the aggregate
50/// functions in the `SELECT` clause.
51///
52/// The output schema will first include the group key and then the aggregation calls.
53#[derive(Debug, Clone, PartialEq, Eq, Hash)]
54pub struct Agg<PlanRef> {
55    pub agg_calls: Vec<PlanAggCall>,
56    pub group_key: IndexSet,
57    pub grouping_sets: Vec<IndexSet>,
58    pub input: PlanRef,
59    pub enable_two_phase: bool,
60}
61
62impl<PlanRef: GenericPlanRef> Agg<PlanRef> {
63    pub(crate) fn clone_with_input<OtherPlanRef>(&self, input: OtherPlanRef) -> Agg<OtherPlanRef> {
64        Agg {
65            agg_calls: self.agg_calls.clone(),
66            group_key: self.group_key.clone(),
67            grouping_sets: self.grouping_sets.clone(),
68            input,
69            enable_two_phase: self.enable_two_phase,
70        }
71    }
72
73    pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
74        self.agg_calls.iter_mut().for_each(|call| {
75            call.filter = call.filter.clone().rewrite_expr(r);
76        });
77    }
78
79    pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
80        self.agg_calls.iter().for_each(|call| {
81            call.filter.visit_expr(v);
82        });
83    }
84
85    pub(crate) fn output_len(&self) -> usize {
86        self.group_key.len() + self.agg_calls.len()
87    }
88
89    /// get the Mapping of columnIndex from input column index to output column index,if a input
90    /// column corresponds more than one out columns, mapping to any one
91    pub fn o2i_col_mapping(&self) -> ColIndexMapping {
92        let mut map = vec![None; self.output_len()];
93        for (i, key) in self.group_key.indices().enumerate() {
94            map[i] = Some(key);
95        }
96        ColIndexMapping::new(map, self.input.schema().len())
97    }
98
99    /// get the Mapping of columnIndex from input column index to out column index
100    pub fn i2o_col_mapping(&self) -> ColIndexMapping {
101        let mut map = vec![None; self.input.schema().len()];
102        for (i, key) in self.group_key.indices().enumerate() {
103            map[key] = Some(i);
104        }
105        ColIndexMapping::new(map, self.output_len())
106    }
107
108    fn two_phase_agg_forced(&self) -> bool {
109        self.ctx().session_ctx().config().force_two_phase_agg()
110    }
111
112    pub fn two_phase_agg_enabled(&self) -> bool {
113        self.enable_two_phase
114    }
115
116    pub(crate) fn can_two_phase_agg(&self) -> bool {
117        self.two_phase_agg_enabled()
118            && !self.agg_calls.is_empty()
119            && self.agg_calls.iter().all(|call| {
120                let agg_type_ok = !matches!(call.agg_type, agg_types::simply_cannot_two_phase!());
121                let order_ok = matches!(
122                    call.agg_type,
123                    agg_types::result_unaffected_by_order_by!()
124                        | AggType::Builtin(PbAggKind::ApproxPercentile)
125                ) || call.order_by.is_empty();
126                let distinct_ok =
127                    matches!(call.agg_type, agg_types::result_unaffected_by_distinct!())
128                        || !call.distinct;
129                agg_type_ok && order_ok && distinct_ok
130            })
131    }
132
133    /// Must try two phase agg iff we are forced to, and we satisfy the constraints.
134    pub(crate) fn must_try_two_phase_agg(&self) -> bool {
135        self.two_phase_agg_forced() && self.can_two_phase_agg()
136    }
137
138    /// Generally used by two phase hash agg.
139    /// If input dist already satisfies hash agg distribution,
140    /// it will be more expensive to do two phase agg, should just do shuffle agg.
141    pub(crate) fn hash_agg_dist_satisfied_by_input_dist(&self, input_dist: &Distribution) -> bool {
142        let required_dist =
143            RequiredDist::shard_by_key(self.input.schema().len(), &self.group_key.to_vec());
144        input_dist.satisfies(&required_dist)
145    }
146
147    /// See if all stream aggregation calls have a stateless local agg counterpart.
148    pub(crate) fn all_local_aggs_are_stateless(&self, stream_input_append_only: bool) -> bool {
149        self.agg_calls.iter().all(|c| {
150            matches!(c.agg_type, agg_types::single_value_state!())
151                || (matches!(c.agg_type, agg_types::single_value_state_iff_in_append_only!() if stream_input_append_only))
152        })
153    }
154
155    pub(crate) fn eowc_window_column(
156        &self,
157        input_watermark_columns: &WatermarkColumns,
158    ) -> Result<usize> {
159        let group_key_with_wtmk = self
160            .group_key
161            .indices()
162            .filter_map(|idx| {
163                input_watermark_columns
164                    .get_group(idx)
165                    .map(|group| (idx, group))
166            })
167            .collect::<Vec<_>>();
168
169        if group_key_with_wtmk.is_empty() {
170            return Err(ErrorCode::NotSupported(
171                "Emit-On-Window-Close mode requires a watermark column in GROUP BY.".to_owned(),
172                "Please try to GROUP BY a watermark column".to_owned(),
173            )
174            .into());
175        }
176        if group_key_with_wtmk.len() == 1
177            || group_key_with_wtmk
178                .iter()
179                .map(|(_, group)| group)
180                .all_equal()
181        {
182            // 1. only one watermark column, should be the window column
183            // 2. all watermark columns belong to the same group, choose the first one as the window column
184            return Ok(group_key_with_wtmk[0].0);
185        }
186        Err(ErrorCode::NotSupported(
187            "Emit-On-Window-Close mode requires that watermark columns in GROUP BY are derived from the same upstream column.".to_owned(),
188            "Please try to remove undesired columns from GROUP BY".to_owned(),
189        )
190        .into())
191    }
192
193    pub fn new(agg_calls: Vec<PlanAggCall>, group_key: IndexSet, input: PlanRef) -> Self {
194        let enable_two_phase = input.ctx().session_ctx().config().enable_two_phase_agg();
195        Self {
196            agg_calls,
197            group_key,
198            input,
199            grouping_sets: vec![],
200            enable_two_phase,
201        }
202    }
203
204    pub fn with_grouping_sets(mut self, grouping_sets: Vec<IndexSet>) -> Self {
205        self.grouping_sets = grouping_sets;
206        self
207    }
208
209    pub fn with_enable_two_phase(mut self, enable_two_phase: bool) -> Self {
210        self.enable_two_phase = enable_two_phase;
211        self
212    }
213}
214
215impl Agg<BatchPlanRef> {
216    // Check if the input is already sorted on group keys.
217    pub(crate) fn input_provides_order_on_group_keys(&self) -> bool {
218        let mut input_order_prefix = IndexSet::empty();
219        for input_order_col in &self.input.order().column_orders {
220            if !self.group_key.contains(input_order_col.column_index) {
221                break;
222            }
223            input_order_prefix.insert(input_order_col.column_index);
224        }
225        self.group_key == input_order_prefix
226    }
227}
228
229impl<PlanRef: GenericPlanRef> GenericPlanNode for Agg<PlanRef> {
230    fn schema(&self) -> Schema {
231        let fields = self
232            .group_key
233            .indices()
234            .map(|i| self.input.schema().fields()[i].clone())
235            .chain(self.agg_calls.iter().map(|agg_call| {
236                let plan_agg_call_display = PlanAggCallDisplay {
237                    plan_agg_call: agg_call,
238                    input_schema: self.input.schema(),
239                };
240                let name = format!("{:?}", plan_agg_call_display);
241                Field::with_name(agg_call.return_type.clone(), name)
242            }))
243            .collect();
244        Schema { fields }
245    }
246
247    fn stream_key(&self) -> Option<Vec<usize>> {
248        Some((0..self.group_key.len()).collect())
249    }
250
251    fn ctx(&self) -> OptimizerContextRef {
252        self.input.ctx()
253    }
254
255    fn functional_dependency(&self) -> FunctionalDependencySet {
256        let output_len = self.output_len();
257        let _input_len = self.input.schema().len();
258        let mut fd_set =
259            FunctionalDependencySet::with_key(output_len, &(0..self.group_key.len()).collect_vec());
260        // take group keys from input_columns, then grow the target size to column_cnt
261        let i2o = self.i2o_col_mapping();
262        for fd in self.input.functional_dependency().as_dependencies() {
263            if let Some(fd) = i2o.rewrite_functional_dependency(fd) {
264                fd_set.add_functional_dependency(fd);
265            }
266        }
267        fd_set
268    }
269}
270
271pub enum AggCallState {
272    Value,
273    MaterializedInput(Box<MaterializedInputState>),
274}
275
276impl AggCallState {
277    pub fn into_prost(self, state: &mut BuildFragmentGraphState) -> PbAggCallState {
278        PbAggCallState {
279            inner: Some(match self {
280                AggCallState::Value => {
281                    agg_call_state::Inner::ValueState(agg_call_state::ValueState {})
282                }
283                AggCallState::MaterializedInput(s) => {
284                    agg_call_state::Inner::MaterializedInputState(
285                        agg_call_state::MaterializedInputState {
286                            table: Some(
287                                s.table
288                                    .with_id(state.gen_table_id_wrapped())
289                                    .to_internal_table_prost(),
290                            ),
291                            included_upstream_indices: s
292                                .included_upstream_indices
293                                .into_iter()
294                                .map(|x| x as _)
295                                .collect(),
296                            table_value_indices: s
297                                .table_value_indices
298                                .into_iter()
299                                .map(|x| x as _)
300                                .collect(),
301                            order_columns: s
302                                .order_columns
303                                .into_iter()
304                                .map(|x| x.to_protobuf())
305                                .collect(),
306                        },
307                    )
308                }
309            }),
310        }
311    }
312}
313
314pub struct MaterializedInputState {
315    pub table: TableCatalog,
316    pub included_upstream_indices: Vec<usize>,
317    pub table_value_indices: Vec<usize>,
318    pub order_columns: Vec<ColumnOrder>,
319}
320
321impl Agg<StreamPlanRef> {
322    pub fn infer_tables(
323        &self,
324        me: impl StreamPlanNodeMetadata,
325        vnode_col_idx: Option<usize>,
326        window_col_idx: Option<usize>,
327    ) -> (
328        TableCatalog,
329        Vec<AggCallState>,
330        HashMap<usize, TableCatalog>,
331    ) {
332        (
333            self.infer_intermediate_state_table(&me, vnode_col_idx, window_col_idx),
334            self.infer_stream_agg_state(&me, vnode_col_idx, window_col_idx),
335            self.infer_distinct_dedup_tables(&me, vnode_col_idx, window_col_idx),
336        )
337    }
338
339    fn get_ordered_group_key(&self, window_col_idx: Option<usize>) -> Vec<usize> {
340        if let Some(window_col_idx) = window_col_idx {
341            assert!(self.group_key.contains(window_col_idx));
342            Either::Left(
343                std::iter::once(window_col_idx).chain(
344                    self.group_key
345                        .indices()
346                        .filter(move |&i| i != window_col_idx),
347                ),
348            )
349        } else {
350            Either::Right(self.group_key.indices())
351        }
352        .collect()
353    }
354
355    /// Create a new table builder with group key columns added.
356    ///
357    /// # Returns
358    ///
359    /// - table builder with group key columns added
360    /// - included upstream indices
361    /// - column mapping from upstream to table
362    fn create_table_builder(
363        &self,
364        _ctx: OptimizerContextRef,
365        window_col_idx: Option<usize>,
366    ) -> (TableCatalogBuilder, Vec<usize>, BTreeMap<usize, usize>) {
367        // NOTE: this function should be called to get a table builder, so that all state tables
368        // created for Agg node have the same group key columns and pk ordering.
369        let mut table_builder = TableCatalogBuilder::default();
370
371        assert!(table_builder.columns().is_empty());
372        assert_eq!(table_builder.get_current_pk_len(), 0);
373
374        // add group key column to table builder
375        let mut included_upstream_indices = vec![];
376        let mut column_mapping = BTreeMap::new();
377        let in_fields = self.input.schema().fields();
378        for idx in self.group_key.indices() {
379            let tbl_col_idx = table_builder.add_column(&in_fields[idx]);
380            included_upstream_indices.push(idx);
381            column_mapping.insert(idx, tbl_col_idx);
382        }
383
384        // configure state table primary key (ordering)
385        let ordered_group_key = self.get_ordered_group_key(window_col_idx);
386        for idx in ordered_group_key {
387            table_builder.add_order_column(column_mapping[&idx], OrderType::ascending());
388        }
389
390        (table_builder, included_upstream_indices, column_mapping)
391    }
392
393    /// Infer `AggCallState`s for streaming agg.
394    pub fn infer_stream_agg_state(
395        &self,
396        me: impl StreamPlanNodeMetadata,
397        vnode_col_idx: Option<usize>,
398        window_col_idx: Option<usize>,
399    ) -> Vec<AggCallState> {
400        let in_fields = self.input.schema().fields().to_vec();
401        let in_pks = self.input.expect_stream_key().to_vec();
402        let in_append_only = self.input.append_only();
403        let in_dist_key = self.input.distribution().dist_column_indices().to_vec();
404
405        let gen_materialized_input_state = |sort_keys: Vec<(OrderType, usize)>,
406                                            extra_keys: Vec<usize>,
407                                            include_keys: Vec<usize>|
408         -> MaterializedInputState {
409            let (mut table_builder, mut included_upstream_indices, mut column_mapping) =
410                self.create_table_builder(me.ctx(), window_col_idx);
411            let read_prefix_len_hint = table_builder.get_current_pk_len();
412
413            let mut order_columns = vec![];
414            let mut table_value_indices = BTreeSet::new(); // table column indices of value columns
415            let mut add_column =
416                |upstream_idx, order_type, table_builder: &mut TableCatalogBuilder| {
417                    column_mapping.entry(upstream_idx).or_insert_with(|| {
418                        let table_col_idx = table_builder.add_column(&in_fields[upstream_idx]);
419                        if let Some(order_type) = order_type {
420                            table_builder.add_order_column(table_col_idx, order_type);
421                            order_columns.push(ColumnOrder::new(upstream_idx, order_type));
422                        }
423                        included_upstream_indices.push(upstream_idx);
424                        table_col_idx
425                    });
426                    table_value_indices.insert(column_mapping[&upstream_idx]);
427                };
428
429            for (order_type, idx) in sort_keys {
430                add_column(idx, Some(order_type), &mut table_builder);
431            }
432            for idx in extra_keys {
433                add_column(idx, Some(OrderType::ascending()), &mut table_builder);
434            }
435            for idx in include_keys {
436                add_column(idx, None, &mut table_builder);
437            }
438
439            let mapping =
440                ColIndexMapping::with_included_columns(&included_upstream_indices, in_fields.len());
441            let tb_dist = mapping.rewrite_dist_key(&in_dist_key);
442            if let Some(tb_vnode_idx) = vnode_col_idx.and_then(|idx| mapping.try_map(idx)) {
443                table_builder.set_vnode_col_idx(tb_vnode_idx);
444            }
445
446            // set value indices to reduce ser/de overhead
447            let table_value_indices = table_value_indices.into_iter().collect_vec();
448            table_builder.set_value_indices(table_value_indices.clone());
449
450            MaterializedInputState {
451                table: table_builder.build(tb_dist.unwrap_or_default(), read_prefix_len_hint),
452                included_upstream_indices,
453                table_value_indices,
454                order_columns,
455            }
456        };
457
458        self.agg_calls
459            .iter()
460            .map(|agg_call| match agg_call.agg_type {
461                agg_types::single_value_state_iff_in_append_only!() if in_append_only => {
462                    AggCallState::Value
463                }
464                agg_types::single_value_state!() => AggCallState::Value,
465                agg_types::materialized_input_state!() => {
466                    // columns with order requirement in state table
467                    let sort_keys = {
468                        match agg_call.agg_type {
469                            AggType::Builtin(PbAggKind::Min) => {
470                                vec![(OrderType::ascending(), agg_call.inputs[0].index)]
471                            }
472                            AggType::Builtin(PbAggKind::Max) => {
473                                vec![(OrderType::descending(), agg_call.inputs[0].index)]
474                            }
475                            AggType::Builtin(
476                                PbAggKind::FirstValue
477                                | PbAggKind::LastValue
478                                | PbAggKind::StringAgg
479                                | PbAggKind::ArrayAgg
480                                | PbAggKind::JsonbAgg
481                                | PbAggKind::PercentileCont
482                                | PbAggKind::PercentileDisc
483                                | PbAggKind::Mode,
484                            )
485                            | AggType::WrapScalar(_) => {
486                                if agg_call.order_by.is_empty() {
487                                    me.ctx().warn_to_user(format!(
488                                        "{} without ORDER BY may produce non-deterministic result",
489                                        agg_call.agg_type,
490                                    ));
491                                }
492                                agg_call
493                                    .order_by
494                                    .iter()
495                                    .map(|o| {
496                                        (
497                                            if matches!(
498                                                agg_call.agg_type,
499                                                AggType::Builtin(PbAggKind::LastValue)
500                                            ) {
501                                                o.order_type.reverse()
502                                            } else {
503                                                o.order_type
504                                            },
505                                            o.column_index,
506                                        )
507                                    })
508                                    .collect()
509                            }
510                            AggType::Builtin(PbAggKind::JsonbObjectAgg) => agg_call
511                                .order_by
512                                .iter()
513                                .map(|o| (o.order_type, o.column_index))
514                                .collect(),
515                            _ => unreachable!(),
516                        }
517                    };
518
519                    // columns to ensure each row unique
520                    let extra_keys = if agg_call.distinct {
521                        // if distinct, use distinct keys as extra keys
522                        let distinct_key = agg_call.inputs[0].index;
523                        vec![distinct_key]
524                    } else {
525                        // if not distinct, use primary keys as extra keys
526                        in_pks.clone()
527                    };
528
529                    // other columns that should be contained in state table
530                    let include_keys = match agg_call.agg_type {
531                        // `agg_types::materialized_input_state` except for `min`/`max`
532                        AggType::Builtin(
533                            PbAggKind::FirstValue
534                            | PbAggKind::LastValue
535                            | PbAggKind::StringAgg
536                            | PbAggKind::ArrayAgg
537                            | PbAggKind::JsonbAgg
538                            | PbAggKind::JsonbObjectAgg
539                            | PbAggKind::PercentileCont
540                            | PbAggKind::PercentileDisc
541                            | PbAggKind::Mode,
542                        )
543                        | AggType::WrapScalar(_) => {
544                            agg_call.inputs.iter().map(|i| i.index).collect()
545                        }
546                        _ => vec![],
547                    };
548
549                    let state = gen_materialized_input_state(sort_keys, extra_keys, include_keys);
550                    AggCallState::MaterializedInput(Box::new(state))
551                }
552                agg_types::rewritten!() => {
553                    unreachable!("should have been rewritten")
554                }
555                AggType::Builtin(
556                    PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar,
557                ) => {
558                    unreachable!("invalid agg kind")
559                }
560            })
561            .collect()
562    }
563
564    /// table schema:
565    /// group key | state for AGG1 | state for AGG2 | ...
566    pub fn infer_intermediate_state_table(
567        &self,
568        me: impl GenericPlanRef,
569        vnode_col_idx: Option<usize>,
570        window_col_idx: Option<usize>,
571    ) -> TableCatalog {
572        let mut out_fields = me.schema().fields().to_vec();
573
574        // rewrite data types in fields
575        let in_append_only = self.input.append_only();
576        for (agg_call, field) in self
577            .agg_calls
578            .iter()
579            .zip_eq_fast(&mut out_fields[self.group_key.len()..])
580        {
581            let agg_kind = match agg_call.agg_type {
582                AggType::UserDefined(_) => {
583                    // for user defined aggregate, the state type is always BYTEA
584                    field.data_type = DataType::Bytea;
585                    continue;
586                }
587                AggType::WrapScalar(_) => {
588                    // for wrapped scalar function, the state is always NULL
589                    continue;
590                }
591                AggType::Builtin(kind) => kind,
592            };
593            let sig = FUNCTION_REGISTRY
594                .get(
595                    agg_kind,
596                    &agg_call
597                        .inputs
598                        .iter()
599                        .map(|input| input.data_type.clone())
600                        .collect_vec(),
601                    &agg_call.return_type,
602                )
603                .expect("agg not found");
604            // in_append_only: whether the input is append-only
605            // sig.is_append_only(): whether the agg function has append-only version
606            match (in_append_only, sig.is_append_only()) {
607                (false, true) => {
608                    // we use materialized input state for non-retractable aggregate function.
609                    // for backward compatibility, the state type is same as the return type.
610                    // its values in the intermediate state table are always null.
611                }
612                (true, true) => {
613                    // use append-only version
614                    if let FuncBuilder::Aggregate {
615                        append_only_state_type: Some(state_type),
616                        ..
617                    } = &sig.build
618                    {
619                        field.data_type = state_type.clone();
620                    }
621                }
622                (_, false) => {
623                    // there is only retractable version, use it
624                    if let FuncBuilder::Aggregate {
625                        retractable_state_type: Some(state_type),
626                        ..
627                    } = &sig.build
628                    {
629                        field.data_type = state_type.clone();
630                    }
631                }
632            }
633        }
634        let in_dist_key = self.input.distribution().dist_column_indices().to_vec();
635        let n_group_key_cols = self.group_key.len();
636
637        let (mut table_builder, _, _) = self.create_table_builder(me.ctx(), window_col_idx);
638        let read_prefix_len_hint = table_builder.get_current_pk_len();
639
640        for field in out_fields.iter().skip(n_group_key_cols) {
641            table_builder.add_column(field);
642        }
643
644        let mapping = self.i2o_col_mapping();
645        let tb_dist = mapping.rewrite_dist_key(&in_dist_key).unwrap_or_default();
646        if let Some(tb_vnode_idx) = vnode_col_idx.and_then(|idx| mapping.try_map(idx)) {
647            table_builder.set_vnode_col_idx(tb_vnode_idx);
648        }
649
650        // the result_table is composed of group_key and all agg_call's values, so the value_indices
651        // of this table should skip group_key.len().
652        table_builder.set_value_indices((n_group_key_cols..out_fields.len()).collect());
653        table_builder.build(tb_dist, read_prefix_len_hint)
654    }
655
656    /// Infer dedup tables for distinct agg calls, partitioned by distinct columns.
657    /// Since distinct agg calls only dedup on the first argument, the key of the result map is
658    /// `usize`, i.e. the distinct column index.
659    ///
660    /// Dedup table schema:
661    /// group key | distinct key | count for AGG1(distinct x) | count for AGG2(distinct x) | ...
662    pub fn infer_distinct_dedup_tables(
663        &self,
664        me: impl GenericPlanRef,
665        vnode_col_idx: Option<usize>,
666        window_col_idx: Option<usize>,
667    ) -> HashMap<usize, TableCatalog> {
668        let in_dist_key = self.input.distribution().dist_column_indices().to_vec();
669        let in_fields = self.input.schema().fields();
670
671        self.agg_calls
672            .iter()
673            .enumerate()
674            .filter(|(_, call)| call.distinct) // only distinct agg calls need dedup table
675            .into_group_map_by(|(_, call)| call.inputs[0].index) // one table per distinct column
676            .into_iter()
677            .map(|(distinct_col, indices_and_calls)| {
678                let (mut table_builder, mut key_cols, _) =
679                    self.create_table_builder(me.ctx(), window_col_idx);
680                let table_col_idx = table_builder.add_column(&in_fields[distinct_col]);
681                table_builder.add_order_column(table_col_idx, OrderType::ascending());
682                key_cols.push(distinct_col);
683
684                let read_prefix_len_hint = table_builder.get_current_pk_len();
685
686                // Agg calls with same distinct column share the same dedup table, but they may have
687                // different filter conditions, so the count of occurrence of one distinct key may
688                // differ among different calls. We add one column for each call in the dedup table.
689                for (call_index, _) in indices_and_calls {
690                    table_builder.add_column(&Field {
691                        data_type: DataType::Int64,
692                        name: format!("count_for_agg_call_{}", call_index),
693                    });
694                }
695                table_builder
696                    .set_value_indices((key_cols.len()..table_builder.columns().len()).collect());
697
698                let mapping = ColIndexMapping::with_included_columns(&key_cols, in_fields.len());
699                if let Some(idx) = vnode_col_idx.and_then(|idx| mapping.try_map(idx)) {
700                    table_builder.set_vnode_col_idx(idx);
701                }
702                let dist_key = mapping.rewrite_dist_key(&in_dist_key).unwrap_or_default();
703                let table = table_builder.build(dist_key, read_prefix_len_hint);
704                (distinct_col, table)
705            })
706            .collect()
707    }
708}
709
710impl<PlanRef: GenericPlanRef> Agg<PlanRef> {
711    pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) {
712        (
713            self.agg_calls,
714            self.group_key,
715            self.grouping_sets,
716            self.input,
717            self.enable_two_phase,
718        )
719    }
720
721    pub fn fields_pretty<'a>(&self) -> StrAssocArr<'a> {
722        let last = ("aggs", self.agg_calls_pretty());
723        if !self.group_key.is_empty() {
724            let first = ("group_key", self.group_key_pretty());
725            vec![first, last]
726        } else {
727            vec![last]
728        }
729    }
730
731    fn agg_calls_pretty<'a>(&self) -> Pretty<'a> {
732        let f = |plan_agg_call| {
733            Pretty::debug(&PlanAggCallDisplay {
734                plan_agg_call,
735                input_schema: self.input.schema(),
736            })
737        };
738        Pretty::Array(self.agg_calls.iter().map(f).collect())
739    }
740
741    fn group_key_pretty<'a>(&self) -> Pretty<'a> {
742        let f = |i| Pretty::display(&FieldDisplay(self.input.schema().fields.get(i).unwrap()));
743        Pretty::Array(self.group_key.indices().map(f).collect())
744    }
745}
746
// NOTE(review): judging by its name, this macro derives the plan-distill (pretty-print) impl
// for `Agg` from `Agg::fields_pretty`; confirm against the macro's definition in `super`.
impl_distill_unit_from_fields!(Agg, GenericPlanRef);
748
/// Rewritten version of [`crate::expr::AggCall`] which uses `InputRef` instead of `ExprImpl`.
/// Refer to [`crate::optimizer::plan_node::logical_agg::LogicalAggBuilder::try_rewrite_agg_call`]
/// for more details.
///
/// Every column index held by a call (`inputs`, `order_by`, and any `InputRef` inside
/// `filter`) refers to the aggregation's input; [`PlanAggCall::rewrite_input_index`] remaps
/// all of them together.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct PlanAggCall {
    /// Type of aggregation function
    pub agg_type: AggType,

    /// Data type of the returned column
    pub return_type: DataType,

    /// Column indexes of input columns.
    ///
    /// Its length depends on the function, for example:
    /// - 0 (`count(*)`, i.e. `Count` with no inputs; see [`PlanAggCall::count_star`])
    /// - 1 (`Max`, `Min`)
    /// - 2 (`StringAgg`).
    ///
    /// Usually, we mark the first column as the aggregated column.
    pub inputs: Vec<InputRef>,

    /// Whether this is a `DISTINCT` aggregate. Distinct calls are deduplicated on their first
    /// input column (`inputs[0]`), which is also the key of the dedup tables built for them.
    pub distinct: bool,
    /// `ORDER BY` specification of the call. Each `column_index` refers to an input column.
    /// It is dropped when the call is turned into the total phase of a two-phase aggregation.
    pub order_by: Vec<ColumnOrder>,
    /// Selective aggregation: only the input rows for which
    /// `filter` evaluates to `true` will be fed to the aggregate function.
    pub filter: Condition,
    /// Constant (literal) arguments of the call, kept separately from `inputs`; each one is
    /// serialized as a `PbConstant` by `to_protobuf`.
    pub direct_args: Vec<Literal>,
}
777
778impl fmt::Debug for PlanAggCall {
779    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
780        write!(f, "{}", self.agg_type)?;
781        if !self.inputs.is_empty() {
782            write!(f, "(")?;
783            for (idx, input) in self.inputs.iter().enumerate() {
784                if idx == 0 && self.distinct {
785                    write!(f, "distinct ")?;
786                }
787                write!(f, "{:?}", input)?;
788                if idx != (self.inputs.len() - 1) {
789                    write!(f, ",")?;
790                }
791            }
792            if !self.order_by.is_empty() {
793                let clause_text = self.order_by.iter().map(|e| format!("{:?}", e)).join(", ");
794                write!(f, " order_by({})", clause_text)?;
795            }
796            write!(f, ")")?;
797        }
798        if !self.filter.always_true() {
799            write!(
800                f,
801                " filter({:?})",
802                self.filter.as_expr_unless_true().unwrap()
803            )?;
804        }
805        Ok(())
806    }
807}
808
809impl PlanAggCall {
810    pub fn rewrite_input_index(&mut self, mapping: ColIndexMapping) {
811        // modify input
812        self.inputs.iter_mut().for_each(|x| {
813            x.index = mapping.map(x.index);
814        });
815
816        // modify order_by exprs
817        self.order_by.iter_mut().for_each(|x| {
818            x.column_index = mapping.map(x.column_index);
819        });
820
821        // modify filter
822        let mut rewriter = IndexRewriter::new(mapping);
823        self.filter.conjunctions.iter_mut().for_each(|x| {
824            *x = rewriter.rewrite_expr(x.clone());
825        });
826    }
827
828    pub fn to_protobuf(&self) -> PbAggCall {
829        PbAggCall {
830            kind: match &self.agg_type {
831                AggType::Builtin(kind) => *kind,
832                AggType::UserDefined(_) => PbAggKind::UserDefined,
833                AggType::WrapScalar(_) => PbAggKind::WrapScalar,
834            }
835            .into(),
836            return_type: Some(self.return_type.to_protobuf()),
837            args: self.inputs.iter().map(InputRef::to_proto).collect(),
838            distinct: self.distinct,
839            order_by: self
840                .order_by
841                .iter()
842                .copied()
843                .map(ColumnOrder::to_protobuf)
844                .collect(),
845            filter: self.filter.as_expr_unless_true().map(|x| x.to_expr_proto()),
846            direct_args: self
847                .direct_args
848                .iter()
849                .map(|x| PbConstant {
850                    datum: Some(x.get_data().to_protobuf()),
851                    r#type: Some(x.return_type().to_protobuf()),
852                })
853                .collect(),
854            udf: match &self.agg_type {
855                AggType::UserDefined(udf) => Some(udf.clone()),
856                _ => None,
857            },
858            scalar: match &self.agg_type {
859                AggType::WrapScalar(expr) => Some(expr.clone()),
860                _ => None,
861            },
862        }
863    }
864
865    pub fn partial_to_total_agg_call(&self, partial_output_idx: usize) -> PlanAggCall {
866        let total_agg_type = self
867            .agg_type
868            .partial_to_total()
869            .expect("unsupported kinds shouldn't get here");
870        PlanAggCall {
871            agg_type: total_agg_type,
872            inputs: vec![InputRef::new(partial_output_idx, self.return_type.clone())],
873            order_by: vec![], // order must make no difference when we use 2-phase agg
874            filter: Condition::true_cond(),
875            ..self.clone()
876        }
877    }
878
879    pub fn count_star() -> Self {
880        PlanAggCall {
881            agg_type: PbAggKind::Count.into(),
882            return_type: DataType::Int64,
883            inputs: vec![],
884            distinct: false,
885            order_by: vec![],
886            filter: Condition::true_cond(),
887            direct_args: vec![],
888        }
889    }
890
891    pub fn with_condition(mut self, filter: Condition) -> Self {
892        self.filter = filter;
893        self
894    }
895
896    pub fn input_indices(&self) -> Vec<usize> {
897        self.inputs.iter().map(|input| input.index()).collect()
898    }
899}
900
/// Schema-aware `Debug` wrapper for a [`PlanAggCall`], used e.g. by `Agg::agg_calls_pretty`
/// when pretty-printing aggregation nodes.
///
/// Unlike `PlanAggCall`'s own `Debug` impl, it renders arguments, `order_by` columns and the
/// filter through `InputRefDisplay`, `ColumnOrderDisplay` and `ConditionDisplay` against
/// `input_schema`. It also separates arguments with `", "` rather than `","`.
pub struct PlanAggCallDisplay<'a> {
    /// The aggregate call to render.
    pub plan_agg_call: &'a PlanAggCall,
    /// Schema of the aggregation's input, used to display the columns the call references.
    pub input_schema: &'a Schema,
}
905
906impl fmt::Debug for PlanAggCallDisplay<'_> {
907    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
908        let that = self.plan_agg_call;
909        write!(f, "{}", that.agg_type)?;
910        if !that.inputs.is_empty() {
911            write!(f, "(")?;
912            for (idx, input) in that.inputs.iter().enumerate() {
913                if idx == 0 && that.distinct {
914                    write!(f, "distinct ")?;
915                }
916                write!(
917                    f,
918                    "{}",
919                    InputRefDisplay {
920                        input_ref: input,
921                        input_schema: self.input_schema
922                    }
923                )?;
924                if idx != (that.inputs.len() - 1) {
925                    write!(f, ", ")?;
926                }
927            }
928            if !that.order_by.is_empty() {
929                write!(
930                    f,
931                    " order_by({})",
932                    that.order_by.iter().format_with(", ", |o, f| {
933                        f(&ColumnOrderDisplay {
934                            column_order: o,
935                            input_schema: self.input_schema,
936                        })
937                    })
938                )?;
939            }
940            write!(f, ")")?;
941        }
942
943        if !that.filter.always_true() {
944            write!(
945                f,
946                " filter({:?})",
947                ConditionDisplay {
948                    condition: &that.filter,
949                    input_schema: self.input_schema,
950                }
951            )?;
952        }
953        Ok(())
954    }
955}