risingwave_frontend/optimizer/plan_node/generic/
agg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, BTreeSet, HashMap};
16use std::{fmt, vec};
17
18use itertools::{Either, Itertools};
19use pretty_xmlish::{Pretty, StrAssocArr};
20use risingwave_common::catalog::{Field, FieldDisplay, Schema};
21use risingwave_common::types::DataType;
22use risingwave_common::util::iter_util::ZipEqFast;
23use risingwave_common::util::sort_util::{ColumnOrder, ColumnOrderDisplay, OrderType};
24use risingwave_common::util::value_encoding::DatumToProtoExt;
25use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};
26use risingwave_expr::sig::{FUNCTION_REGISTRY, FuncBuilder};
27use risingwave_pb::expr::{PbAggCall, PbConstant};
28use risingwave_pb::stream_plan::{AggCallState as PbAggCallState, agg_call_state};
29
30use super::super::utils::TableCatalogBuilder;
31use super::{GenericPlanNode, GenericPlanRef, PhysicalPlanRef, impl_distill_unit_from_fields};
32use crate::TableCatalog;
33use crate::error::{ErrorCode, Result};
34use crate::expr::{Expr, ExprRewriter, ExprVisitor, InputRef, InputRefDisplay, Literal};
35use crate::optimizer::optimizer_context::OptimizerContextRef;
36use crate::optimizer::plan_node::batch::BatchPlanNodeMetadata;
37use crate::optimizer::plan_node::{BatchPlanRef, StreamPlanNodeMetadata, StreamPlanRef};
38use crate::optimizer::property::{
39    Distribution, FunctionalDependencySet, RequiredDist, WatermarkColumns,
40};
41use crate::stream_fragmenter::BuildFragmentGraphState;
42use crate::utils::{
43    ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay, IndexRewriter,
44    IndexSet,
45};
46
/// [`Agg`] groups input data by their group key and computes aggregation functions.
///
/// It corresponds to the `GROUP BY` operator in a SQL query statement together with the aggregate
/// functions in the `SELECT` clause.
///
/// The output schema will first include the group key and then the aggregation calls.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Agg<PlanRef> {
    /// Aggregate calls to compute; each call contributes one output column after the group key.
    pub agg_calls: Vec<PlanAggCall>,
    /// Indices of the input columns to group by.
    pub group_key: IndexSet,
    /// Grouping sets attached to this node; empty unless set through `with_grouping_sets`.
    pub grouping_sets: Vec<IndexSet>,
    /// The input plan node.
    pub input: PlanRef,
    /// Whether two-phase aggregation is enabled for this node. `new` initializes it from the
    /// session config; `with_enable_two_phase` overrides it.
    pub enable_two_phase: bool,
}
61
62impl<PlanRef: GenericPlanRef> Agg<PlanRef> {
63    pub(crate) fn clone_with_input<OtherPlanRef>(&self, input: OtherPlanRef) -> Agg<OtherPlanRef> {
64        Agg {
65            agg_calls: self.agg_calls.clone(),
66            group_key: self.group_key.clone(),
67            grouping_sets: self.grouping_sets.clone(),
68            input,
69            enable_two_phase: self.enable_two_phase,
70        }
71    }
72
73    pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
74        self.agg_calls.iter_mut().for_each(|call| {
75            call.filter = call.filter.clone().rewrite_expr(r);
76        });
77    }
78
79    pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
80        self.agg_calls.iter().for_each(|call| {
81            call.filter.visit_expr(v);
82        });
83    }
84
85    pub(crate) fn output_len(&self) -> usize {
86        self.group_key.len() + self.agg_calls.len()
87    }
88
89    /// get the Mapping of columnIndex from input column index to output column index,if a input
90    /// column corresponds more than one out columns, mapping to any one
91    pub fn o2i_col_mapping(&self) -> ColIndexMapping {
92        let mut map = vec![None; self.output_len()];
93        for (i, key) in self.group_key.indices().enumerate() {
94            map[i] = Some(key);
95        }
96        ColIndexMapping::new(map, self.input.schema().len())
97    }
98
99    /// get the Mapping of columnIndex from input column index to out column index
100    pub fn i2o_col_mapping(&self) -> ColIndexMapping {
101        let mut map = vec![None; self.input.schema().len()];
102        for (i, key) in self.group_key.indices().enumerate() {
103            map[key] = Some(i);
104        }
105        ColIndexMapping::new(map, self.output_len())
106    }
107
108    fn two_phase_agg_forced(&self) -> bool {
109        self.ctx().session_ctx().config().force_two_phase_agg()
110    }
111
112    pub fn two_phase_agg_enabled(&self) -> bool {
113        self.enable_two_phase
114    }
115
116    pub(crate) fn can_two_phase_agg(&self) -> bool {
117        self.two_phase_agg_enabled()
118            && !self.agg_calls.is_empty()
119            && self.agg_calls.iter().all(|call| {
120                let agg_type_ok = !matches!(call.agg_type, agg_types::simply_cannot_two_phase!());
121                let order_ok = matches!(
122                    call.agg_type,
123                    agg_types::result_unaffected_by_order_by!()
124                        | AggType::Builtin(PbAggKind::ApproxPercentile)
125                ) || call.order_by.is_empty();
126                let distinct_ok =
127                    matches!(call.agg_type, agg_types::result_unaffected_by_distinct!())
128                        || !call.distinct;
129                agg_type_ok && order_ok && distinct_ok
130            })
131    }
132
133    /// Must try two phase agg iff we are forced to, and we satisfy the constraints.
134    pub(crate) fn must_try_two_phase_agg(&self) -> bool {
135        self.two_phase_agg_forced() && self.can_two_phase_agg()
136    }
137
138    /// Generally used by two phase hash agg.
139    /// If input dist already satisfies hash agg distribution,
140    /// it will be more expensive to do two phase agg, should just do shuffle agg.
141    pub(crate) fn hash_agg_dist_satisfied_by_input_dist(&self, input_dist: &Distribution) -> bool {
142        let required_dist =
143            RequiredDist::shard_by_key(self.input.schema().len(), &self.group_key.to_vec());
144        input_dist.satisfies(&required_dist)
145    }
146
147    /// See if all stream aggregation calls have a stateless local agg counterpart.
148    pub(crate) fn all_local_aggs_are_stateless(&self, stream_input_append_only: bool) -> bool {
149        self.agg_calls.iter().all(|c| {
150            matches!(c.agg_type, agg_types::single_value_state!())
151                || (matches!(c.agg_type, agg_types::single_value_state_iff_in_append_only!() if stream_input_append_only))
152        })
153    }
154
155    pub(crate) fn eowc_window_column(
156        &self,
157        input_watermark_columns: &WatermarkColumns,
158    ) -> Result<usize> {
159        let group_key_with_wtmk = self
160            .group_key
161            .indices()
162            .filter_map(|idx| {
163                input_watermark_columns
164                    .get_group(idx)
165                    .map(|group| (idx, group))
166            })
167            .collect::<Vec<_>>();
168
169        if group_key_with_wtmk.is_empty() {
170            return Err(ErrorCode::NotSupported(
171                "Emit-On-Window-Close mode requires a watermark column in GROUP BY.".to_owned(),
172                "Please try to GROUP BY a watermark column".to_owned(),
173            )
174            .into());
175        }
176        if group_key_with_wtmk.len() == 1
177            || group_key_with_wtmk
178                .iter()
179                .map(|(_, group)| group)
180                .all_equal()
181        {
182            // 1. only one watermark column, should be the window column
183            // 2. all watermark columns belong to the same group, choose the first one as the window column
184            return Ok(group_key_with_wtmk[0].0);
185        }
186        Err(ErrorCode::NotSupported(
187            "Emit-On-Window-Close mode requires that watermark columns in GROUP BY are derived from the same upstream column.".to_owned(),
188            "Please try to remove undesired columns from GROUP BY".to_owned(),
189        )
190        .into())
191    }
192
193    pub fn new(agg_calls: Vec<PlanAggCall>, group_key: IndexSet, input: PlanRef) -> Self {
194        let enable_two_phase = input.ctx().session_ctx().config().enable_two_phase_agg();
195        Self {
196            agg_calls,
197            group_key,
198            input,
199            grouping_sets: vec![],
200            enable_two_phase,
201        }
202    }
203
204    pub fn with_grouping_sets(mut self, grouping_sets: Vec<IndexSet>) -> Self {
205        self.grouping_sets = grouping_sets;
206        self
207    }
208
209    pub fn with_enable_two_phase(mut self, enable_two_phase: bool) -> Self {
210        self.enable_two_phase = enable_two_phase;
211        self
212    }
213}
214
215impl Agg<BatchPlanRef> {
216    // Check if the input is already sorted on group keys.
217    pub(crate) fn input_provides_order_on_group_keys(&self) -> bool {
218        let mut input_order_prefix = IndexSet::empty();
219        for input_order_col in &self.input.order().column_orders {
220            if !self.group_key.contains(input_order_col.column_index) {
221                break;
222            }
223            input_order_prefix.insert(input_order_col.column_index);
224        }
225        self.group_key == input_order_prefix
226    }
227}
228
229impl<PlanRef: GenericPlanRef> GenericPlanNode for Agg<PlanRef> {
230    fn schema(&self) -> Schema {
231        let fields = self
232            .group_key
233            .indices()
234            .map(|i| self.input.schema().fields()[i].clone())
235            .chain(self.agg_calls.iter().map(|agg_call| {
236                let plan_agg_call_display = PlanAggCallDisplay {
237                    plan_agg_call: agg_call,
238                    input_schema: self.input.schema(),
239                };
240                let name = format!("{:?}", plan_agg_call_display);
241                Field::with_name(agg_call.return_type.clone(), name)
242            }))
243            .collect();
244        Schema { fields }
245    }
246
247    fn stream_key(&self) -> Option<Vec<usize>> {
248        Some((0..self.group_key.len()).collect())
249    }
250
251    fn ctx(&self) -> OptimizerContextRef {
252        self.input.ctx()
253    }
254
255    fn functional_dependency(&self) -> FunctionalDependencySet {
256        let output_len = self.output_len();
257        let _input_len = self.input.schema().len();
258        let mut fd_set =
259            FunctionalDependencySet::with_key(output_len, &(0..self.group_key.len()).collect_vec());
260        // take group keys from input_columns, then grow the target size to column_cnt
261        let i2o = self.i2o_col_mapping();
262        for fd in self.input.functional_dependency().as_dependencies() {
263            if let Some(fd) = i2o.rewrite_functional_dependency(fd) {
264                fd_set.add_functional_dependency(fd);
265            }
266        }
267        fd_set
268    }
269}
270
/// State kind of one aggregate call in streaming agg, as produced by
/// `Agg::infer_stream_agg_state` and converted to protobuf by `into_prost`.
pub enum AggCallState {
    /// Value state; it carries no extra data and converts to the protobuf `ValueState`.
    Value,
    /// Materialized input state; carries the description of its state table.
    MaterializedInput(Box<MaterializedInputState>),
}
275
276impl AggCallState {
277    pub fn into_prost(self, state: &mut BuildFragmentGraphState) -> PbAggCallState {
278        PbAggCallState {
279            inner: Some(match self {
280                AggCallState::Value => {
281                    agg_call_state::Inner::ValueState(agg_call_state::ValueState {})
282                }
283                AggCallState::MaterializedInput(s) => {
284                    agg_call_state::Inner::MaterializedInputState(
285                        agg_call_state::MaterializedInputState {
286                            table: Some(
287                                s.table
288                                    .with_id(state.gen_table_id_wrapped())
289                                    .to_internal_table_prost(),
290                            ),
291                            included_upstream_indices: s
292                                .included_upstream_indices
293                                .into_iter()
294                                .map(|x| x as _)
295                                .collect(),
296                            table_value_indices: s
297                                .table_value_indices
298                                .into_iter()
299                                .map(|x| x as _)
300                                .collect(),
301                            order_columns: s
302                                .order_columns
303                                .into_iter()
304                                .map(|x| x.to_protobuf())
305                                .collect(),
306                        },
307                    )
308                }
309            }),
310        }
311    }
312}
313
/// Description of the state table backing a materialized input state of one aggregate call.
pub struct MaterializedInputState {
    /// Catalog of the state table.
    pub table: TableCatalog,
    /// Upstream (input) column indices of the columns stored in the table, in table column order.
    pub included_upstream_indices: Vec<usize>,
    /// Table column indices of the value columns.
    pub table_value_indices: Vec<usize>,
    /// Order columns added for this call on top of the group key, expressed with upstream
    /// column indices.
    pub order_columns: Vec<ColumnOrder>,
}
320
321impl Agg<StreamPlanRef> {
322    pub fn infer_tables(
323        &self,
324        me: impl StreamPlanNodeMetadata,
325        vnode_col_idx: Option<usize>,
326        window_col_idx: Option<usize>,
327    ) -> (
328        TableCatalog,
329        Vec<AggCallState>,
330        HashMap<usize, TableCatalog>,
331    ) {
332        (
333            self.infer_intermediate_state_table(&me, vnode_col_idx, window_col_idx),
334            self.infer_stream_agg_state(&me, vnode_col_idx, window_col_idx),
335            self.infer_distinct_dedup_tables(&me, vnode_col_idx, window_col_idx),
336        )
337    }
338
339    fn get_ordered_group_key(&self, window_col_idx: Option<usize>) -> Vec<usize> {
340        if let Some(window_col_idx) = window_col_idx {
341            assert!(self.group_key.contains(window_col_idx));
342            Either::Left(
343                std::iter::once(window_col_idx).chain(
344                    self.group_key
345                        .indices()
346                        .filter(move |&i| i != window_col_idx),
347                ),
348            )
349        } else {
350            Either::Right(self.group_key.indices())
351        }
352        .collect()
353    }
354
355    /// Create a new table builder with group key columns added.
356    ///
357    /// # Returns
358    ///
359    /// - table builder with group key columns added
360    /// - included upstream indices
361    /// - column mapping from upstream to table
362    fn create_table_builder(
363        &self,
364        _ctx: OptimizerContextRef,
365        window_col_idx: Option<usize>,
366    ) -> (TableCatalogBuilder, Vec<usize>, BTreeMap<usize, usize>) {
367        // NOTE: this function should be called to get a table builder, so that all state tables
368        // created for Agg node have the same group key columns and pk ordering.
369        let mut table_builder = TableCatalogBuilder::default();
370
371        assert!(table_builder.columns().is_empty());
372        assert_eq!(table_builder.get_current_pk_len(), 0);
373
374        // add group key column to table builder
375        let mut included_upstream_indices = vec![];
376        let mut column_mapping = BTreeMap::new();
377        let in_fields = self.input.schema().fields();
378        for idx in self.group_key.indices() {
379            let tbl_col_idx = table_builder.add_column(&in_fields[idx]);
380            included_upstream_indices.push(idx);
381            column_mapping.insert(idx, tbl_col_idx);
382        }
383
384        // configure state table primary key (ordering)
385        let ordered_group_key = self.get_ordered_group_key(window_col_idx);
386        for idx in ordered_group_key {
387            table_builder.add_order_column(column_mapping[&idx], OrderType::ascending());
388        }
389
390        (table_builder, included_upstream_indices, column_mapping)
391    }
392
393    /// Infer `AggCallState`s for streaming agg.
394    pub fn infer_stream_agg_state(
395        &self,
396        me: impl StreamPlanNodeMetadata,
397        vnode_col_idx: Option<usize>,
398        window_col_idx: Option<usize>,
399    ) -> Vec<AggCallState> {
400        let in_fields = self.input.schema().fields().to_vec();
401        let in_pks = self.input.expect_stream_key().to_vec();
402        let in_append_only = self.input.append_only();
403        let in_dist_key = self.input.distribution().dist_column_indices().to_vec();
404
405        let gen_materialized_input_state = |sort_keys: Vec<(OrderType, usize)>,
406                                            extra_keys: Vec<usize>,
407                                            include_keys: Vec<usize>|
408         -> MaterializedInputState {
409            let (mut table_builder, mut included_upstream_indices, mut column_mapping) =
410                self.create_table_builder(me.ctx(), window_col_idx);
411            let read_prefix_len_hint = table_builder.get_current_pk_len();
412
413            let mut order_columns = vec![];
414            let mut table_value_indices = BTreeSet::new(); // table column indices of value columns
415            let mut add_column =
416                |upstream_idx, order_type, table_builder: &mut TableCatalogBuilder| {
417                    column_mapping.entry(upstream_idx).or_insert_with(|| {
418                        let table_col_idx = table_builder.add_column(&in_fields[upstream_idx]);
419                        if let Some(order_type) = order_type {
420                            table_builder.add_order_column(table_col_idx, order_type);
421                            order_columns.push(ColumnOrder::new(upstream_idx, order_type));
422                        }
423                        included_upstream_indices.push(upstream_idx);
424                        table_col_idx
425                    });
426                    table_value_indices.insert(column_mapping[&upstream_idx]);
427                };
428
429            for (order_type, idx) in sort_keys {
430                add_column(idx, Some(order_type), &mut table_builder);
431            }
432            for idx in extra_keys {
433                add_column(idx, Some(OrderType::ascending()), &mut table_builder);
434            }
435            for idx in include_keys {
436                add_column(idx, None, &mut table_builder);
437            }
438
439            let mapping =
440                ColIndexMapping::with_included_columns(&included_upstream_indices, in_fields.len());
441            let tb_dist = mapping.rewrite_dist_key(&in_dist_key);
442            if let Some(tb_vnode_idx) = vnode_col_idx.and_then(|idx| mapping.try_map(idx)) {
443                table_builder.set_vnode_col_idx(tb_vnode_idx);
444            }
445
446            // set value indices to reduce ser/de overhead
447            let table_value_indices = table_value_indices.into_iter().collect_vec();
448            table_builder.set_value_indices(table_value_indices.clone());
449
450            MaterializedInputState {
451                table: table_builder.build(tb_dist.unwrap_or_default(), read_prefix_len_hint),
452                included_upstream_indices,
453                table_value_indices,
454                order_columns,
455            }
456        };
457
458        self.agg_calls
459            .iter()
460            .map(|agg_call| match agg_call.agg_type {
461                agg_types::single_value_state_iff_in_append_only!() if in_append_only => {
462                    AggCallState::Value
463                }
464                agg_types::single_value_state!() => AggCallState::Value,
465                agg_types::materialized_input_state!() => {
466                    // columns with order requirement in state table
467                    let sort_keys = {
468                        match agg_call.agg_type {
469                            AggType::Builtin(PbAggKind::Min) => {
470                                vec![(OrderType::ascending(), agg_call.inputs[0].index)]
471                            }
472                            AggType::Builtin(PbAggKind::Max) => {
473                                vec![(OrderType::descending(), agg_call.inputs[0].index)]
474                            }
475                            AggType::Builtin(
476                                PbAggKind::FirstValue
477                                | PbAggKind::LastValue
478                                | PbAggKind::StringAgg
479                                | PbAggKind::ArrayAgg
480                                | PbAggKind::JsonbAgg,
481                            )
482                            | AggType::WrapScalar(_) => {
483                                if agg_call.order_by.is_empty() {
484                                    me.ctx().warn_to_user(format!(
485                                        "{} without ORDER BY may produce non-deterministic result",
486                                        agg_call.agg_type,
487                                    ));
488                                }
489                                agg_call
490                                    .order_by
491                                    .iter()
492                                    .map(|o| {
493                                        (
494                                            if matches!(
495                                                agg_call.agg_type,
496                                                AggType::Builtin(PbAggKind::LastValue)
497                                            ) {
498                                                o.order_type.reverse()
499                                            } else {
500                                                o.order_type
501                                            },
502                                            o.column_index,
503                                        )
504                                    })
505                                    .collect()
506                            }
507                            AggType::Builtin(PbAggKind::JsonbObjectAgg) => agg_call
508                                .order_by
509                                .iter()
510                                .map(|o| (o.order_type, o.column_index))
511                                .collect(),
512                            _ => unreachable!(),
513                        }
514                    };
515
516                    // columns to ensure each row unique
517                    let extra_keys = if agg_call.distinct {
518                        // if distinct, use distinct keys as extra keys
519                        let distinct_key = agg_call.inputs[0].index;
520                        vec![distinct_key]
521                    } else {
522                        // if not distinct, use primary keys as extra keys
523                        in_pks.clone()
524                    };
525
526                    // other columns that should be contained in state table
527                    let include_keys = match agg_call.agg_type {
528                        // `agg_types::materialized_input_state` except for `min`/`max`
529                        AggType::Builtin(
530                            PbAggKind::FirstValue
531                            | PbAggKind::LastValue
532                            | PbAggKind::StringAgg
533                            | PbAggKind::ArrayAgg
534                            | PbAggKind::JsonbAgg
535                            | PbAggKind::JsonbObjectAgg,
536                        )
537                        | AggType::WrapScalar(_) => {
538                            agg_call.inputs.iter().map(|i| i.index).collect()
539                        }
540                        _ => vec![],
541                    };
542
543                    let state = gen_materialized_input_state(sort_keys, extra_keys, include_keys);
544                    AggCallState::MaterializedInput(Box::new(state))
545                }
546                agg_types::rewritten!() => {
547                    unreachable!("should have been rewritten")
548                }
549                agg_types::unimplemented_in_stream!() => {
550                    unreachable!("should have been banned")
551                }
552                AggType::Builtin(
553                    PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar,
554                ) => {
555                    unreachable!("invalid agg kind")
556                }
557            })
558            .collect()
559    }
560
561    /// table schema:
562    /// group key | state for AGG1 | state for AGG2 | ...
563    pub fn infer_intermediate_state_table(
564        &self,
565        me: impl GenericPlanRef,
566        vnode_col_idx: Option<usize>,
567        window_col_idx: Option<usize>,
568    ) -> TableCatalog {
569        let mut out_fields = me.schema().fields().to_vec();
570
571        // rewrite data types in fields
572        let in_append_only = self.input.append_only();
573        for (agg_call, field) in self
574            .agg_calls
575            .iter()
576            .zip_eq_fast(&mut out_fields[self.group_key.len()..])
577        {
578            let agg_kind = match agg_call.agg_type {
579                AggType::UserDefined(_) => {
580                    // for user defined aggregate, the state type is always BYTEA
581                    field.data_type = DataType::Bytea;
582                    continue;
583                }
584                AggType::WrapScalar(_) => {
585                    // for wrapped scalar function, the state is always NULL
586                    continue;
587                }
588                AggType::Builtin(kind) => kind,
589            };
590            let sig = FUNCTION_REGISTRY
591                .get(
592                    agg_kind,
593                    &agg_call
594                        .inputs
595                        .iter()
596                        .map(|input| input.data_type.clone())
597                        .collect_vec(),
598                    &agg_call.return_type,
599                )
600                .expect("agg not found");
601            // in_append_only: whether the input is append-only
602            // sig.is_append_only(): whether the agg function has append-only version
603            match (in_append_only, sig.is_append_only()) {
604                (false, true) => {
605                    // we use materialized input state for non-retractable aggregate function.
606                    // for backward compatibility, the state type is same as the return type.
607                    // its values in the intermediate state table are always null.
608                }
609                (true, true) => {
610                    // use append-only version
611                    if let FuncBuilder::Aggregate {
612                        append_only_state_type: Some(state_type),
613                        ..
614                    } = &sig.build
615                    {
616                        field.data_type = state_type.clone();
617                    }
618                }
619                (_, false) => {
620                    // there is only retractable version, use it
621                    if let FuncBuilder::Aggregate {
622                        retractable_state_type: Some(state_type),
623                        ..
624                    } = &sig.build
625                    {
626                        field.data_type = state_type.clone();
627                    }
628                }
629            }
630        }
631        let in_dist_key = self.input.distribution().dist_column_indices().to_vec();
632        let n_group_key_cols = self.group_key.len();
633
634        let (mut table_builder, _, _) = self.create_table_builder(me.ctx(), window_col_idx);
635        let read_prefix_len_hint = table_builder.get_current_pk_len();
636
637        for field in out_fields.iter().skip(n_group_key_cols) {
638            table_builder.add_column(field);
639        }
640
641        let mapping = self.i2o_col_mapping();
642        let tb_dist = mapping.rewrite_dist_key(&in_dist_key).unwrap_or_default();
643        if let Some(tb_vnode_idx) = vnode_col_idx.and_then(|idx| mapping.try_map(idx)) {
644            table_builder.set_vnode_col_idx(tb_vnode_idx);
645        }
646
647        // the result_table is composed of group_key and all agg_call's values, so the value_indices
648        // of this table should skip group_key.len().
649        table_builder.set_value_indices((n_group_key_cols..out_fields.len()).collect());
650        table_builder.build(tb_dist, read_prefix_len_hint)
651    }
652
653    /// Infer dedup tables for distinct agg calls, partitioned by distinct columns.
654    /// Since distinct agg calls only dedup on the first argument, the key of the result map is
655    /// `usize`, i.e. the distinct column index.
656    ///
657    /// Dedup table schema:
658    /// group key | distinct key | count for AGG1(distinct x) | count for AGG2(distinct x) | ...
659    pub fn infer_distinct_dedup_tables(
660        &self,
661        me: impl GenericPlanRef,
662        vnode_col_idx: Option<usize>,
663        window_col_idx: Option<usize>,
664    ) -> HashMap<usize, TableCatalog> {
665        let in_dist_key = self.input.distribution().dist_column_indices().to_vec();
666        let in_fields = self.input.schema().fields();
667
668        self.agg_calls
669            .iter()
670            .enumerate()
671            .filter(|(_, call)| call.distinct) // only distinct agg calls need dedup table
672            .into_group_map_by(|(_, call)| call.inputs[0].index) // one table per distinct column
673            .into_iter()
674            .map(|(distinct_col, indices_and_calls)| {
675                let (mut table_builder, mut key_cols, _) =
676                    self.create_table_builder(me.ctx(), window_col_idx);
677                let table_col_idx = table_builder.add_column(&in_fields[distinct_col]);
678                table_builder.add_order_column(table_col_idx, OrderType::ascending());
679                key_cols.push(distinct_col);
680
681                let read_prefix_len_hint = table_builder.get_current_pk_len();
682
683                // Agg calls with same distinct column share the same dedup table, but they may have
684                // different filter conditions, so the count of occurrence of one distinct key may
685                // differ among different calls. We add one column for each call in the dedup table.
686                for (call_index, _) in indices_and_calls {
687                    table_builder.add_column(&Field {
688                        data_type: DataType::Int64,
689                        name: format!("count_for_agg_call_{}", call_index),
690                    });
691                }
692                table_builder
693                    .set_value_indices((key_cols.len()..table_builder.columns().len()).collect());
694
695                let mapping = ColIndexMapping::with_included_columns(&key_cols, in_fields.len());
696                if let Some(idx) = vnode_col_idx.and_then(|idx| mapping.try_map(idx)) {
697                    table_builder.set_vnode_col_idx(idx);
698                }
699                let dist_key = mapping.rewrite_dist_key(&in_dist_key).unwrap_or_default();
700                let table = table_builder.build(dist_key, read_prefix_len_hint);
701                (distinct_col, table)
702            })
703            .collect()
704    }
705}
706
707impl<PlanRef: GenericPlanRef> Agg<PlanRef> {
708    pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) {
709        (
710            self.agg_calls,
711            self.group_key,
712            self.grouping_sets,
713            self.input,
714            self.enable_two_phase,
715        )
716    }
717
718    pub fn fields_pretty<'a>(&self) -> StrAssocArr<'a> {
719        let last = ("aggs", self.agg_calls_pretty());
720        if !self.group_key.is_empty() {
721            let first = ("group_key", self.group_key_pretty());
722            vec![first, last]
723        } else {
724            vec![last]
725        }
726    }
727
728    fn agg_calls_pretty<'a>(&self) -> Pretty<'a> {
729        let f = |plan_agg_call| {
730            Pretty::debug(&PlanAggCallDisplay {
731                plan_agg_call,
732                input_schema: self.input.schema(),
733            })
734        };
735        Pretty::Array(self.agg_calls.iter().map(f).collect())
736    }
737
738    fn group_key_pretty<'a>(&self) -> Pretty<'a> {
739        let f = |i| Pretty::display(&FieldDisplay(self.input.schema().fields.get(i).unwrap()));
740        Pretty::Array(self.group_key.indices().map(f).collect())
741    }
742}
743
// NOTE(review): `impl_distill_unit_from_fields` is imported from the parent module (`super`).
// Presumably it generates the distill (pretty-print) implementation for `Agg` on top of
// `fields_pretty` — confirm against the macro definition.
impl_distill_unit_from_fields!(Agg, GenericPlanRef);
745
746/// Rewritten version of [`crate::expr::AggCall`] which uses `InputRef` instead of `ExprImpl`.
747/// Refer to [`crate::optimizer::plan_node::logical_agg::LogicalAggBuilder::try_rewrite_agg_call`]
748/// for more details.
749#[derive(Clone, PartialEq, Eq, Hash)]
750pub struct PlanAggCall {
751    /// Type of aggregation function
752    pub agg_type: AggType,
753
754    /// Data type of the returned column
755    pub return_type: DataType,
756
757    /// Column indexes of input columns.
758    ///
759    /// Its length can be:
760    /// - 0 (`RowCount`)
761    /// - 1 (`Max`, `Min`)
762    /// - 2 (`StringAgg`).
763    ///
764    /// Usually, we mark the first column as the aggregated column.
765    pub inputs: Vec<InputRef>,
766
767    pub distinct: bool,
768    pub order_by: Vec<ColumnOrder>,
769    /// Selective aggregation: only the input rows for which
770    /// `filter` evaluates to `true` will be fed to the aggregate function.
771    pub filter: Condition,
772    pub direct_args: Vec<Literal>,
773}
774
775impl fmt::Debug for PlanAggCall {
776    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
777        write!(f, "{}", self.agg_type)?;
778        if !self.inputs.is_empty() {
779            write!(f, "(")?;
780            for (idx, input) in self.inputs.iter().enumerate() {
781                if idx == 0 && self.distinct {
782                    write!(f, "distinct ")?;
783                }
784                write!(f, "{:?}", input)?;
785                if idx != (self.inputs.len() - 1) {
786                    write!(f, ",")?;
787                }
788            }
789            if !self.order_by.is_empty() {
790                let clause_text = self.order_by.iter().map(|e| format!("{:?}", e)).join(", ");
791                write!(f, " order_by({})", clause_text)?;
792            }
793            write!(f, ")")?;
794        }
795        if !self.filter.always_true() {
796            write!(
797                f,
798                " filter({:?})",
799                self.filter.as_expr_unless_true().unwrap()
800            )?;
801        }
802        Ok(())
803    }
804}
805
806impl PlanAggCall {
807    pub fn rewrite_input_index(&mut self, mapping: ColIndexMapping) {
808        // modify input
809        self.inputs.iter_mut().for_each(|x| {
810            x.index = mapping.map(x.index);
811        });
812
813        // modify order_by exprs
814        self.order_by.iter_mut().for_each(|x| {
815            x.column_index = mapping.map(x.column_index);
816        });
817
818        // modify filter
819        let mut rewriter = IndexRewriter::new(mapping);
820        self.filter.conjunctions.iter_mut().for_each(|x| {
821            *x = rewriter.rewrite_expr(x.clone());
822        });
823    }
824
825    pub fn to_protobuf(&self) -> PbAggCall {
826        PbAggCall {
827            kind: match &self.agg_type {
828                AggType::Builtin(kind) => *kind,
829                AggType::UserDefined(_) => PbAggKind::UserDefined,
830                AggType::WrapScalar(_) => PbAggKind::WrapScalar,
831            }
832            .into(),
833            return_type: Some(self.return_type.to_protobuf()),
834            args: self.inputs.iter().map(InputRef::to_proto).collect(),
835            distinct: self.distinct,
836            order_by: self
837                .order_by
838                .iter()
839                .copied()
840                .map(ColumnOrder::to_protobuf)
841                .collect(),
842            filter: self.filter.as_expr_unless_true().map(|x| x.to_expr_proto()),
843            direct_args: self
844                .direct_args
845                .iter()
846                .map(|x| PbConstant {
847                    datum: Some(x.get_data().to_protobuf()),
848                    r#type: Some(x.return_type().to_protobuf()),
849                })
850                .collect(),
851            udf: match &self.agg_type {
852                AggType::UserDefined(udf) => Some(udf.clone()),
853                _ => None,
854            },
855            scalar: match &self.agg_type {
856                AggType::WrapScalar(expr) => Some(expr.clone()),
857                _ => None,
858            },
859        }
860    }
861
862    pub fn partial_to_total_agg_call(&self, partial_output_idx: usize) -> PlanAggCall {
863        let total_agg_type = self
864            .agg_type
865            .partial_to_total()
866            .expect("unsupported kinds shouldn't get here");
867        PlanAggCall {
868            agg_type: total_agg_type,
869            inputs: vec![InputRef::new(partial_output_idx, self.return_type.clone())],
870            order_by: vec![], // order must make no difference when we use 2-phase agg
871            filter: Condition::true_cond(),
872            ..self.clone()
873        }
874    }
875
876    pub fn count_star() -> Self {
877        PlanAggCall {
878            agg_type: PbAggKind::Count.into(),
879            return_type: DataType::Int64,
880            inputs: vec![],
881            distinct: false,
882            order_by: vec![],
883            filter: Condition::true_cond(),
884            direct_args: vec![],
885        }
886    }
887
888    pub fn with_condition(mut self, filter: Condition) -> Self {
889        self.filter = filter;
890        self
891    }
892
893    pub fn input_indices(&self) -> Vec<usize> {
894        self.inputs.iter().map(|input| input.index()).collect()
895    }
896}
897
898pub struct PlanAggCallDisplay<'a> {
899    pub plan_agg_call: &'a PlanAggCall,
900    pub input_schema: &'a Schema,
901}
902
903impl fmt::Debug for PlanAggCallDisplay<'_> {
904    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
905        let that = self.plan_agg_call;
906        write!(f, "{}", that.agg_type)?;
907        if !that.inputs.is_empty() {
908            write!(f, "(")?;
909            for (idx, input) in that.inputs.iter().enumerate() {
910                if idx == 0 && that.distinct {
911                    write!(f, "distinct ")?;
912                }
913                write!(
914                    f,
915                    "{}",
916                    InputRefDisplay {
917                        input_ref: input,
918                        input_schema: self.input_schema
919                    }
920                )?;
921                if idx != (that.inputs.len() - 1) {
922                    write!(f, ", ")?;
923                }
924            }
925            if !that.order_by.is_empty() {
926                write!(
927                    f,
928                    " order_by({})",
929                    that.order_by.iter().format_with(", ", |o, f| {
930                        f(&ColumnOrderDisplay {
931                            column_order: o,
932                            input_schema: self.input_schema,
933                        })
934                    })
935                )?;
936            }
937            write!(f, ")")?;
938        }
939
940        if !that.filter.always_true() {
941            write!(
942                f,
943                " filter({:?})",
944                ConditionDisplay {
945                    condition: &that.filter,
946                    input_schema: self.input_schema,
947                }
948            )?;
949        }
950        Ok(())
951    }
952}