risingwave_frontend/optimizer/plan_node/logical_agg.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use fixedbitset::FixedBitSet;
use itertools::Itertools;
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
use risingwave_common::{bail, bail_not_implemented, not_implemented};
use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};

use super::generic::{self, Agg, GenericPlanRef, PlanAggCall, ProjectBuilder};
use super::utils::impl_distill_by_unit;
use super::{
    BatchHashAgg, BatchSimpleAgg, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
    PlanBase, PlanTreeNodeUnary, PredicatePushdown, StreamHashAgg, StreamPlanRef, StreamProject,
    StreamShare, StreamSimpleAgg, StreamStatelessSimpleAgg, ToBatch, ToStream,
};
use crate::error::{ErrorCode, Result, RwError};
use crate::expr::{
    AggCall, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, Literal,
    OrderBy, OrderByExpr, WindowFunction,
};
use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
use crate::optimizer::plan_node::generic::GenericPlanNode;
use crate::optimizer::plan_node::stream_global_approx_percentile::StreamGlobalApproxPercentile;
use crate::optimizer::plan_node::stream_local_approx_percentile::StreamLocalApproxPercentile;
use crate::optimizer::plan_node::stream_row_merge::StreamRowMerge;
use crate::optimizer::plan_node::{
    BatchSortAgg, ColumnPruningContext, LogicalDedup, LogicalProject, PredicatePushdownContext,
    RewriteStreamContext, ToStreamContext, gen_filter_and_pushdown,
};
use crate::optimizer::property::{Distribution, Order, RequiredDist};
use crate::utils::{
    ColIndexMapping, ColIndexMappingRewriteExt, Condition, GroupBy, IndexSet, Substitute,
};

pub struct AggInfo {
    pub calls: Vec<PlanAggCall>,
    pub col_mapping: ColIndexMapping,
}

/// `SeparatedAggInfo` is used to separate normal and approx percentile aggs.
pub struct SeparatedAggInfo {
    normal: AggInfo,
    approx: AggInfo,
}

/// `LogicalAgg` groups input data by their group key and computes aggregation functions.
///
/// It corresponds to the `GROUP BY` operator in a SQL query statement together with the aggregate
/// functions in the `SELECT` clause.
///
/// The output schema will first include the group key and then the aggregation calls.
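///
/// For example (an illustrative sketch with a hypothetical table `t(v1, v2)`),
/// `SELECT v1, count(*), max(v2) FROM t GROUP BY v1` produces the output schema
/// `[v1, count, max]`.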
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct LogicalAgg {
    pub base: PlanBase<Logical>,
    core: Agg<PlanRef>,
}

impl LogicalAgg {
    /// Generate plan for stateless 2-phase streaming agg.
    /// Should only be used if the input is distributed. The input must already be in stream form.
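    ///
    /// A rough sketch of the resulting plan shape (assuming no approx percentile aggs):
    ///
    /// ```text
    /// StreamSimpleAgg (global, combines the partial results)
    ///  └─ StreamExchange (single)
    ///      └─ StreamStatelessSimpleAgg (local, per-shard partial aggregation)
    ///          └─ input
    /// ```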
    fn gen_stateless_two_phase_streaming_agg_plan(
        &self,
        stream_input: StreamPlanRef,
    ) -> Result<StreamPlanRef> {
        debug_assert!(self.group_key().is_empty());

        // ====== Handle approx percentile aggs
        let (
            non_approx_percentile_col_mapping,
            approx_percentile_col_mapping,
            approx_percentile,
            core,
        ) = self.prepare_approx_percentile(stream_input.clone())?;

        if core.agg_calls.is_empty() {
            if let Some(approx_percentile) = approx_percentile {
                return Ok(approx_percentile);
            };
            bail!("expected at least one agg call");
        }

        let need_row_merge: bool = Self::need_row_merge(&approx_percentile);

        // ====== Handle normal aggs
        let total_agg_calls = core
            .agg_calls
            .iter()
            .enumerate()
            .map(|(partial_output_idx, agg_call)| {
                agg_call.partial_to_total_agg_call(partial_output_idx)
            })
            .collect_vec();
        let local_agg = StreamStatelessSimpleAgg::new(core)?;
        let exchange =
            RequiredDist::single().streaming_enforce_if_not_satisfies(local_agg.into())?;

        let must_output_per_barrier = need_row_merge;
        let global_agg = new_stream_simple_agg(
            Agg::new(total_agg_calls, IndexSet::empty(), exchange),
            must_output_per_barrier,
        )?;

        // ====== Merge approx percentile and normal aggs
        Self::add_row_merge_if_needed(
            approx_percentile,
            global_agg.into(),
            approx_percentile_col_mapping,
            non_approx_percentile_col_mapping,
        )
    }

    /// Generate plan for stateless/stateful 2-phase streaming agg.
    /// Should only be used if the input is distributed.
    /// The input must already be converted to stream form.
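    ///
    /// A rough sketch of the resulting plan shape when there is a group key (assuming no
    /// approx percentile aggs):
    ///
    /// ```text
    /// StreamHashAgg (global, group by group key)
    ///  └─ StreamExchange (hash-shard by group key)
    ///      └─ StreamHashAgg (local, group by group key + vnode)
    ///          └─ StreamProject (append a vnode column derived from `dist_key`)
    ///              └─ input
    /// ```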
    fn gen_vnode_two_phase_streaming_agg_plan(
        &self,
        stream_input: StreamPlanRef,
        dist_key: &[usize],
    ) -> Result<StreamPlanRef> {
        let (
            non_approx_percentile_col_mapping,
            approx_percentile_col_mapping,
            approx_percentile,
            core,
        ) = self.prepare_approx_percentile(stream_input.clone())?;

        if core.agg_calls.is_empty() {
            if let Some(approx_percentile) = approx_percentile {
                return Ok(approx_percentile);
            };
            bail!("expected at least one agg call");
        }
        let need_row_merge = Self::need_row_merge(&approx_percentile);

        // Generate vnode via project
        // TODO(kwannoel): We should apply Project optimization rules here.
        let input_col_num = stream_input.schema().len(); // get schema len before moving `stream_input`.
        let project = StreamProject::new(generic::Project::with_vnode_col(stream_input, dist_key));
        let vnode_col_idx = project.base.schema().len() - 1;

        // Generate local agg step
        let mut local_group_key = self.group_key().clone();
        local_group_key.insert(vnode_col_idx);
        let n_local_group_key = local_group_key.len();
        let local_agg = new_stream_hash_agg(
            Agg::new(core.agg_calls.to_vec(), local_group_key, project.into()),
            Some(vnode_col_idx),
        )?;
        // Global group key excludes vnode.
        let local_agg_group_key_cardinality = local_agg.group_key().len();
        let local_group_key_without_vnode =
            &local_agg.group_key().to_vec()[..local_agg_group_key_cardinality - 1];
        let global_group_key = local_agg
            .i2o_col_mapping()
            .rewrite_dist_key(local_group_key_without_vnode)
            .expect("some input group key could not be mapped");

        // Generate global agg step
        let global_agg = if self.group_key().is_empty() {
            let exchange =
                RequiredDist::single().streaming_enforce_if_not_satisfies(local_agg.into())?;
            let must_output_per_barrier = need_row_merge;
            let global_agg = new_stream_simple_agg(
                Agg::new(
                    core.agg_calls
                        .iter()
                        .enumerate()
                        .map(|(partial_output_idx, agg_call)| {
                            agg_call
                                .partial_to_total_agg_call(n_local_group_key + partial_output_idx)
                        })
                        .collect(),
                    global_group_key.into_iter().collect(),
                    exchange,
                ),
                must_output_per_barrier,
            )?;
            global_agg.into()
        } else {
            // `RowMergeExec` does not support keyed merge yet.
            assert!(!need_row_merge);
            let exchange = RequiredDist::shard_by_key(input_col_num, &global_group_key)
                .streaming_enforce_if_not_satisfies(local_agg.into())?;
            // The local phase should have reordered the group keys into their required order,
            // so we can just follow it.
            let global_agg = new_stream_hash_agg(
                Agg::new(
                    core.agg_calls
                        .iter()
                        .enumerate()
                        .map(|(partial_output_idx, agg_call)| {
                            agg_call
                                .partial_to_total_agg_call(n_local_group_key + partial_output_idx)
                        })
                        .collect(),
                    global_group_key.into_iter().collect(),
                    exchange,
                ),
                None,
            )?;
            global_agg.into()
        };
        Self::add_row_merge_if_needed(
            approx_percentile,
            global_agg,
            approx_percentile_col_mapping,
            non_approx_percentile_col_mapping,
        )
    }

    fn gen_single_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
        let input = RequiredDist::single().streaming_enforce_if_not_satisfies(stream_input)?;
        let core = self.core.clone_with_input(input);
        Ok(new_stream_simple_agg(core, false)?.into())
    }

    fn gen_shuffle_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
        let input =
            RequiredDist::shard_by_key(stream_input.schema().len(), &self.group_key().to_vec())
                .streaming_enforce_if_not_satisfies(stream_input)?;
        let core = self.core.clone_with_input(input);
        Ok(new_stream_hash_agg(core, None)?.into())
    }

    /// Generates the distributed stream plan.
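    ///
    /// Roughly, the strategies are tried in the following order (see the body for the exact
    /// conditions): shuffle agg, single agg, stateless 2-phase agg, vnode-based 2-phase agg,
    /// then a fallback to shuffle or single agg.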
    fn gen_dist_stream_agg_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
        use super::stream::prelude::*;

        let input_dist = stream_input.distribution();
        debug_assert!(*input_dist != Distribution::Broadcast);

        // Shuffle agg
        // If we have a group key and we won't try the two-phase agg optimization at all,
        // we always choose shuffle agg over single agg.
        if !self.group_key().is_empty() && !self.core.must_try_two_phase_agg() {
            return self.gen_shuffle_plan(stream_input);
        }

        // Standalone agg
        // If there is no group key and two-phase agg is not possible, we have to use the single plan.
        if self.group_key().is_empty() && !self.core.can_two_phase_agg() {
            return self.gen_single_plan(stream_input);
        }

        debug_assert!(if !self.group_key().is_empty() {
            self.core.must_try_two_phase_agg()
        } else {
            self.core.can_two_phase_agg()
        });

        // Stateless 2-phase simple agg
        // can be applied on stateless simple agg calls,
        // with input distributed by [`Distribution::AnyShard`]
        if self.group_key().is_empty()
            && self
                .core
                .all_local_aggs_are_stateless(stream_input.append_only())
            && input_dist.satisfies(&RequiredDist::AnyShard)
        {
            return self.gen_stateless_two_phase_streaming_agg_plan(stream_input);
        }

        // If the input is [`Distribution::SomeShard`] and we must try two-phase agg,
        // the only remaining strategy is vnode-based 2-phase agg.
        // We shall first distribute it by PK,
        // so it obeys the consistent hash strategy via [`Distribution::HashShard`].
        let stream_input =
            if *input_dist == Distribution::SomeShard && self.core.must_try_two_phase_agg() {
                RequiredDist::shard_by_key(
                    stream_input.schema().len(),
                    stream_input.expect_stream_key(),
                )
                .streaming_enforce_if_not_satisfies(stream_input)?
            } else {
                stream_input
            };
        let input_dist = stream_input.distribution();

        // Vnode-based 2-phase agg
        // can be applied on agg calls not affected by order,
        // with input distributed by dist_key.
        match input_dist {
            Distribution::HashShard(dist_key) | Distribution::UpstreamHashShard(dist_key, _)
                if (self.group_key().is_empty()
                    || !self.core.hash_agg_dist_satisfied_by_input_dist(input_dist)) =>
            {
                let dist_key = dist_key.clone();
                return self.gen_vnode_two_phase_streaming_agg_plan(stream_input, &dist_key);
            }
            _ => {}
        }

        // Fall back to shuffle or single if we can't generate any 2-phase plans.
        if !self.group_key().is_empty() {
            self.gen_shuffle_plan(stream_input)
        } else {
            self.gen_single_plan(stream_input)
        }
    }

    /// Prepares metadata and the `approx_percentile` plan, if one is present.
    /// It may modify `core.agg_calls` to separate normal aggs from approx percentile aggs,
    /// and `core.input` to share the input via `StreamShare`
    /// between the approx percentile aggs and the normal aggs.
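    ///
    /// Returns, in order: the column mapping for the normal agg calls, the column mapping
    /// for the approx percentile calls, the (optional) approx percentile plan, and the agg
    /// core containing only the normal agg calls.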
    fn prepare_approx_percentile(
        &self,
        stream_input: StreamPlanRef,
    ) -> Result<(
        ColIndexMapping,
        ColIndexMapping,
        Option<StreamPlanRef>,
        Agg<StreamPlanRef>,
    )> {
        let SeparatedAggInfo { normal, approx } = self.separate_normal_and_special_agg();

        let AggInfo {
            calls: non_approx_percentile_agg_calls,
            col_mapping: non_approx_percentile_col_mapping,
        } = normal;
        let AggInfo {
            calls: approx_percentile_agg_calls,
            col_mapping: approx_percentile_col_mapping,
        } = approx;
        if !self.group_key().is_empty() && !approx_percentile_agg_calls.is_empty() {
            bail_not_implemented!(
                "two-phase streaming approx percentile aggregation with group key, \
             please use single phase aggregation instead"
            );
        }

        // Either we have both approx percentile aggs and non-approx percentile aggs,
        // or we have at least 2 approx percentile aggs.
        let needs_row_merge = (!non_approx_percentile_agg_calls.is_empty()
            && !approx_percentile_agg_calls.is_empty())
            || approx_percentile_agg_calls.len() >= 2;
        let input = if needs_row_merge {
            // If there's a row merge, we need to share the input.
            StreamShare::new_from_input(stream_input.clone()).into()
        } else {
            stream_input
        };
        let mut core = self.core.clone_with_input(input);
        core.agg_calls = non_approx_percentile_agg_calls;

        let approx_percentile =
            self.build_approx_percentile_aggs(core.input.clone(), &approx_percentile_agg_calls)?;
        Ok((
            non_approx_percentile_col_mapping,
            approx_percentile_col_mapping,
            approx_percentile,
            core,
        ))
    }

    fn need_row_merge(approx_percentile: &Option<StreamPlanRef>) -> bool {
        approx_percentile.is_some()
    }

    /// Add `RowMerge` if needed.
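    ///
    /// The two `ColIndexMapping`s map the approx percentile outputs and the normal agg
    /// outputs, respectively, back to their positions among the original agg calls.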
    fn add_row_merge_if_needed(
        approx_percentile: Option<StreamPlanRef>,
        global_agg: StreamPlanRef,
        approx_percentile_col_mapping: ColIndexMapping,
        non_approx_percentile_col_mapping: ColIndexMapping,
    ) -> Result<StreamPlanRef> {
        // just for assert
        let need_row_merge = Self::need_row_merge(&approx_percentile);

        if let Some(approx_percentile) = approx_percentile {
            assert!(need_row_merge);
            let row_merge = StreamRowMerge::new(
                approx_percentile,
                global_agg,
                approx_percentile_col_mapping,
                non_approx_percentile_col_mapping,
            )?;
            Ok(row_merge.into())
        } else {
            assert!(!need_row_merge);
            Ok(global_agg)
        }
    }

    fn separate_normal_and_special_agg(&self) -> SeparatedAggInfo {
        let estimated_len = self.agg_calls().len() - 1;
        let mut approx_percentile_agg_calls = Vec::with_capacity(estimated_len);
        let mut non_approx_percentile_agg_calls = Vec::with_capacity(estimated_len);
        let mut approx_percentile_col_mapping = Vec::with_capacity(estimated_len);
        let mut non_approx_percentile_col_mapping = Vec::with_capacity(estimated_len);
        for (output_idx, agg_call) in self.agg_calls().iter().enumerate() {
            if agg_call.agg_type == AggType::Builtin(PbAggKind::ApproxPercentile) {
                approx_percentile_agg_calls.push(agg_call.clone());
                approx_percentile_col_mapping.push(Some(output_idx));
            } else {
                non_approx_percentile_agg_calls.push(agg_call.clone());
                non_approx_percentile_col_mapping.push(Some(output_idx));
            }
        }
        let normal = AggInfo {
            calls: non_approx_percentile_agg_calls,
            col_mapping: ColIndexMapping::new(
                non_approx_percentile_col_mapping,
                self.agg_calls().len(),
            ),
        };
        let approx = AggInfo {
            calls: approx_percentile_agg_calls,
            col_mapping: ColIndexMapping::new(
                approx_percentile_col_mapping,
                self.agg_calls().len(),
            ),
        };
        SeparatedAggInfo { normal, approx }
    }

    fn build_approx_percentile_agg(
        &self,
        input: StreamPlanRef,
        approx_percentile_agg_call: &PlanAggCall,
    ) -> Result<StreamPlanRef> {
        let local_approx_percentile =
            StreamLocalApproxPercentile::new(input, approx_percentile_agg_call)?;
        let exchange = RequiredDist::single()
            .streaming_enforce_if_not_satisfies(local_approx_percentile.into())?;
        let global_approx_percentile =
            StreamGlobalApproxPercentile::new(exchange, approx_percentile_agg_call);
        Ok(global_approx_percentile.into())
    }

    /// If there is only 1 approx percentile agg, just return it.
    /// Otherwise build a left-deep tree of approx percentile aggs combined with `StreamRowMerge`.
    /// e.g.
    /// ApproxPercentile(col1, 0.5) as x,
    /// ApproxPercentile(col2, 0.5) as y,
    /// ApproxPercentile(col3, 0.5) as z
    /// will be built as
    ///        `StreamRowMerge`
    ///       /          \
    ///  `StreamRowMerge`     z
    ///  /        \
    /// x          y
    fn build_approx_percentile_aggs(
        &self,
        input: StreamPlanRef,
        approx_percentile_agg_call: &[PlanAggCall],
    ) -> Result<Option<StreamPlanRef>> {
        if approx_percentile_agg_call.is_empty() {
            return Ok(None);
        }
        let approx_percentile_plans: Vec<_> = approx_percentile_agg_call
            .iter()
            .map(|agg_call| self.build_approx_percentile_agg(input.clone(), agg_call))
            .try_collect()?;
        assert!(!approx_percentile_plans.is_empty());
        let mut iter = approx_percentile_plans.into_iter();
        let mut acc = iter.next().unwrap();
        for (current_size, plan) in iter.enumerate().map(|(i, p)| (i + 1, p)) {
            let new_size = current_size + 1;
            let row_merge = StreamRowMerge::new(
                acc,
                plan,
                ColIndexMapping::identity_or_none(current_size, new_size),
                ColIndexMapping::new(vec![Some(current_size)], new_size),
            )?;
            acc = row_merge.into();
        }
        Ok(Some(acc))
    }

    pub fn core(&self) -> &Agg<PlanRef> {
        &self.core
    }
}

/// `LogicalAggBuilder` extracts agg calls and references to group columns from the select list and
/// builds a plan like `LogicalAgg - LogicalProject`.
/// It is constructed from the `group_exprs`, and collects and rewrites the expressions in the
/// selection and having clauses.
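///
/// For example (an illustrative sketch with a hypothetical table `t(v1, v2)`),
/// `SELECT v1, min(v2) FROM t GROUP BY v1` is built as:
///
/// ```text
/// LogicalAgg { group_key: [v1], agg_calls: [min(v2)] }
///  └─ LogicalProject { exprs: [v1, v2] }
///      └─ input
/// ```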
pub struct LogicalAggBuilder {
    /// the builder of the input Project
    input_proj_builder: ProjectBuilder,
    /// the group key column indices in the project's output
    group_key: IndexSet,
    /// the grouping sets
    grouping_sets: Vec<IndexSet>,
    /// the agg calls
    agg_calls: Vec<PlanAggCall>,
    /// the error during the expression rewriting
    error: Option<RwError>,
    /// If `is_in_filter_clause` is true, we are processing an aggregate filter clause.
    /// This field is needed because input refs in these clauses are allowed to refer to
    /// any column, while those outside a filter clause may only refer to group keys.
    is_in_filter_clause: bool,
}

impl LogicalAggBuilder {
    fn new(group_by: GroupBy, input_schema_len: usize) -> Result<Self> {
        let mut input_proj_builder = ProjectBuilder::default();

        let mut gen_group_key_and_grouping_sets =
            |grouping_sets: Vec<Vec<ExprImpl>>| -> Result<(IndexSet, Vec<IndexSet>)> {
                let grouping_sets: Vec<IndexSet> = grouping_sets
                    .into_iter()
                    .map(|set| {
                        set.into_iter()
                            .map(|expr| input_proj_builder.add_expr(&expr))
                            .try_collect()
                            .map_err(|err| not_implemented!("{err} inside GROUP BY"))
                    })
                    .try_collect()?;

                // Construct group key based on grouping sets.
                let group_key = grouping_sets
                    .iter()
                    .fold(FixedBitSet::with_capacity(input_schema_len), |acc, x| {
                        acc.union(&x.to_bitset()).collect()
                    });

                Ok((IndexSet::from_iter(group_key.ones()), grouping_sets))
            };

        let (group_key, grouping_sets) = match group_by {
            GroupBy::GroupKey(group_key) => {
                let group_key = group_key
                    .into_iter()
                    .map(|expr| input_proj_builder.add_expr(&expr))
                    .try_collect()
                    .map_err(|err| not_implemented!("{err} inside GROUP BY"))?;
                (group_key, vec![])
            }
            GroupBy::GroupingSets(grouping_sets) => gen_group_key_and_grouping_sets(grouping_sets)?,
            GroupBy::Rollup(rollup) => {
                // Convert rollup to grouping sets.
                let grouping_sets = (0..=rollup.len())
                    .map(|n| {
                        rollup
                            .iter()
                            .take(n)
                            .flat_map(|x| x.iter().cloned())
                            .collect_vec()
                    })
                    .collect_vec();
                gen_group_key_and_grouping_sets(grouping_sets)?
            }
            GroupBy::Cube(cube) => {
                // Convert cube to grouping sets.
                let grouping_sets = cube
                    .into_iter()
                    .powerset()
                    .map(|x| x.into_iter().flatten().collect_vec())
                    .collect_vec();
                gen_group_key_and_grouping_sets(grouping_sets)?
            }
        };

        Ok(LogicalAggBuilder {
            group_key,
            grouping_sets,
            agg_calls: vec![],
            error: None,
            input_proj_builder,
            is_in_filter_clause: false,
        })
    }

    pub fn build(self, input: PlanRef) -> LogicalAgg {
        // This LogicalProject focuses on the exprs in aggregates and GROUP BY clause.
        let logical_project = LogicalProject::with_core(self.input_proj_builder.build(input));

        // This LogicalAgg focuses on calculating the aggregates and grouping.
        Agg::new(self.agg_calls, self.group_key, logical_project.into())
            .with_grouping_sets(self.grouping_sets)
            .into()
    }

    fn rewrite_with_error(&mut self, expr: ExprImpl) -> Result<ExprImpl> {
        let rewritten_expr = self.rewrite_expr(expr);
        if let Some(error) = self.error.take() {
            return Err(error);
        }
        Ok(rewritten_expr)
    }

    /// Check whether the expression is a GROUP BY key, and if so return its index in the group key.
    pub fn try_as_group_expr(&self, expr: &ExprImpl) -> Option<usize> {
        if let Some(input_index) = self.input_proj_builder.expr_index(expr)
            && let Some(index) = self
                .group_key
                .indices()
                .position(|group_key| group_key == input_index)
        {
            return Some(index);
        }
        None
    }

    fn schema_agg_start_offset(&self) -> usize {
        self.group_key.len()
    }

    /// Rewrite [`AggCall`] if needed, and push it into the builder using `push_agg_call`.
    /// This is shared by [`LogicalAggBuilder`] and `LogicalOverWindowBuilder`.
    pub(crate) fn general_rewrite_agg_call(
        agg_call: AggCall,
        mut push_agg_call: impl FnMut(AggCall) -> Result<InputRef>,
    ) -> Result<ExprImpl> {
        match agg_call.agg_type {
            // Rewrite avg to cast(sum as avg_return_type) / count.
            AggType::Builtin(PbAggKind::Avg) => {
                assert_eq!(agg_call.args.len(), 1);

                let sum = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Sum.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?)
                .cast_explicit(&agg_call.return_type())?;

                let count = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Count.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?);

                Ok(FunctionCall::new(ExprType::Divide, Vec::from([sum, count]))?.into())
            }
            // We compute `var_samp` as
            // (sum(sq) - sum * sum / count) / (count - 1)
            // and `var_pop` as
            // (sum(sq) - sum * sum / count) / count
            // Since we don't have a dedicated square function, we use plain Multiply for squaring,
            // which is in a sense more general than a pow function, especially when calculating
            // covariances in the future. The square root for stddev is taken with `Sqrt` below.
            AggType::Builtin(
                kind @ (PbAggKind::StddevPop
                | PbAggKind::StddevSamp
                | PbAggKind::VarPop
                | PbAggKind::VarSamp),
            ) => {
                let arg = agg_call.args().iter().exactly_one().unwrap();
                let squared_arg = ExprImpl::from(FunctionCall::new(
                    ExprType::Multiply,
                    vec![arg.clone(), arg.clone()],
                )?);

                let sum_of_sq = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Sum.into(),
                    vec![squared_arg],
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?)
                .cast_explicit(&agg_call.return_type())?;

                let sum = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Sum.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?)
                .cast_explicit(&agg_call.return_type())?;

                let count = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Count.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?);

                let zero = ExprImpl::literal_int(0);
                let one = ExprImpl::literal_int(1);

                let squared_sum = ExprImpl::from(FunctionCall::new(
                    ExprType::Multiply,
                    vec![sum.clone(), sum],
                )?);

                let raw_numerator = ExprImpl::from(FunctionCall::new(
                    ExprType::Subtract,
                    vec![
                        sum_of_sq,
                        ExprImpl::from(FunctionCall::new(
                            ExprType::Divide,
                            vec![squared_sum, count.clone()],
                        )?),
                    ],
                )?);

                // We need to check for potential accuracy issues that may occasionally lead to results less than 0.
                let numerator_type = raw_numerator.return_type();
                let numerator = ExprImpl::from(FunctionCall::new(
                    ExprType::Greatest,
                    vec![raw_numerator, zero.clone().cast_explicit(&numerator_type)?],
                )?);

                let denominator = match kind {
                    PbAggKind::VarPop | PbAggKind::StddevPop => count.clone(),
                    PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from(
                        FunctionCall::new(ExprType::Subtract, vec![count.clone(), one.clone()])?,
                    ),
                    _ => unreachable!(),
                };

                let mut target = ExprImpl::from(FunctionCall::new(
                    ExprType::Divide,
                    vec![numerator, denominator],
                )?);

                if matches!(kind, PbAggKind::StddevPop | PbAggKind::StddevSamp) {
                    target = ExprImpl::from(FunctionCall::new(ExprType::Sqrt, vec![target])?);
                }

                let null = ExprImpl::from(Literal::new(None, agg_call.return_type()));
                let case_cond = match kind {
                    PbAggKind::VarPop | PbAggKind::StddevPop => {
                        ExprImpl::from(FunctionCall::new(ExprType::Equal, vec![count, zero])?)
                    }
                    PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from(
                        FunctionCall::new(ExprType::LessThanOrEqual, vec![count, one])?,
                    ),
                    _ => unreachable!(),
                };

                Ok(ExprImpl::from(FunctionCall::new(
                    ExprType::Case,
                    vec![case_cond, null, target],
                )?))
            }
            AggType::Builtin(PbAggKind::ApproxPercentile) => {
                if agg_call.order_by.sort_exprs[0].order_type == OrderType::descending() {
                    // Rewrite DESC into 1.0-percentile for approx_percentile.
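                    // e.g. a call with percentile 0.25 and a descending ORDER BY is rewritten to
                    // percentile 0.75 with the ORDER BY dropped.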
                    let prev_percentile = agg_call.direct_args[0].clone();
                    let new_percentile = 1.0
                        - prev_percentile
                            .get_data()
                            .as_ref()
                            .unwrap()
                            .as_float64()
                            .into_inner();
                    let new_percentile = Some(ScalarImpl::Float64(new_percentile.into()));
                    let new_percentile = Literal::new(new_percentile, DataType::Float64);
                    let new_direct_args = vec![new_percentile, agg_call.direct_args[1].clone()];

                    let new_agg_call = AggCall {
                        order_by: OrderBy::any(),
                        direct_args: new_direct_args,
                        ..agg_call
                    };
                    Ok(push_agg_call(new_agg_call)?.into())
                } else {
                    let new_agg_call = AggCall {
                        order_by: OrderBy::any(),
                        ..agg_call
                    };
                    Ok(push_agg_call(new_agg_call)?.into())
                }
            }
            AggType::Builtin(PbAggKind::ArgMin | PbAggKind::ArgMax) => {
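                // Rewrite `arg_min(a, b)` / `arg_max(a, b)` into a `first_value(a)` call ordered
                // by `b` (ascending for `arg_min`, descending for `arg_max`), with a filter that
                // both arguments are not null.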
                let mut agg_call = agg_call;

                let comparison_arg_type = agg_call.args[1].return_type();
                match comparison_arg_type {
                    DataType::Struct(_)
                    | DataType::List(_)
                    | DataType::Map(_)
                    | DataType::Vector(_)
                    | DataType::Jsonb => {
                        bail!(format!(
                            "{} does not support struct, array, map, vector, jsonb for comparison argument, got {}",
                            agg_call.agg_type.to_string(),
                            comparison_arg_type
                        ));
                    }
                    _ => {}
                }

                let not_null_exprs: Vec<ExprImpl> = agg_call
                    .args
                    .iter()
                    .map(|arg| -> Result<ExprImpl> {
                        Ok(FunctionCall::new(ExprType::IsNotNull, vec![arg.clone()])?.into())
                    })
                    .try_collect()?;

                let comparison_expr = agg_call.args[1].clone();
                let mut order_exprs = vec![OrderByExpr {
                    expr: comparison_expr,
                    order_type: if agg_call.agg_type == AggType::Builtin(PbAggKind::ArgMin) {
                        OrderType::ascending()
                    } else {
                        OrderType::descending()
                    },
                }];

                order_exprs.extend(agg_call.order_by.sort_exprs);

                let order_by = OrderBy::new(order_exprs);

                let filter = agg_call.filter.clone().and(Condition {
                    conjunctions: not_null_exprs,
                });

                agg_call.args.truncate(1);

                let new_agg_call = AggCall {
                    agg_type: AggType::Builtin(PbAggKind::FirstValue),
                    order_by,
                    filter,
                    ..agg_call
                };
                Ok(push_agg_call(new_agg_call)?.into())
            }
            _ => Ok(push_agg_call(agg_call)?.into()),
        }
    }

    /// Push a new agg call into the builder.
    /// Return an `InputRef` to that agg call.
    /// If an identical agg call already exists, return an `InputRef` to the existing one.
    fn push_agg_call(&mut self, agg_call: AggCall) -> Result<InputRef> {
        let AggCall {
            agg_type,
            return_type,
            args,
            distinct,
            order_by,
            filter,
            direct_args,
        } = agg_call;

        self.is_in_filter_clause = true;
        // The filter expr is not added to `input_proj_builder` as a whole. Special exprs incl.
        // subquery/agg/table are rejected in `bind_agg`.
        let filter = filter.rewrite_expr(self);
        self.is_in_filter_clause = false;

        let args: Vec<_> = args
            .iter()
            .map(|expr| {
                let index = self.input_proj_builder.add_expr(expr)?;
                Ok(InputRef::new(index, expr.return_type()))
            })
            .try_collect()
            .map_err(|err: &'static str| not_implemented!("{err} inside aggregation calls"))?;

        let order_by: Vec<_> = order_by
            .sort_exprs
            .iter()
            .map(|e| {
                let index = self.input_proj_builder.add_expr(&e.expr)?;
                Ok(ColumnOrder::new(index, e.order_type))
            })
            .try_collect()
            .map_err(|err: &'static str| {
                not_implemented!("{err} inside aggregation calls order by")
            })?;

        let plan_agg_call = PlanAggCall {
            agg_type,
            return_type: return_type.clone(),
            inputs: args,
            distinct,
            order_by,
            filter,
            direct_args,
        };

        if let Some((pos, existing)) = self
            .agg_calls
            .iter()
            .find_position(|&c| c == &plan_agg_call)
        {
            return Ok(InputRef::new(
                self.schema_agg_start_offset() + pos,
                existing.return_type.clone(),
            ));
        }
        let index = self.schema_agg_start_offset() + self.agg_calls.len();
        self.agg_calls.push(plan_agg_call);
        Ok(InputRef::new(index, return_type))
    }

    /// When there is an agg call, there are 3 things to do:
    /// 1. Rewrite `avg`, `var_samp`, etc. into a combination of `sum`, `count`, etc.;
    /// 2. Add the exprs in its arguments to the input `Project`;
    /// 3. Add the agg call to the current `Agg`, and return an `InputRef` to it.
    ///
    /// Note that the rewriter does not traverse into inputs of agg calls.
    fn try_rewrite_agg_call(&mut self, mut agg_call: AggCall) -> Result<ExprImpl> {
        if matches!(agg_call.agg_type, agg_types::must_have_order_by!())
            && agg_call.order_by.sort_exprs.is_empty()
        {
            return Err(ErrorCode::InvalidInputSyntax(format!(
                "Aggregation function {} requires ORDER BY clause",
                agg_call.agg_type
            ))
            .into());
        }

        // try to ignore ORDER BY if it doesn't affect the result
        if matches!(
            agg_call.agg_type,
            agg_types::result_unaffected_by_order_by!()
        ) {
            agg_call.order_by = OrderBy::any();
        }
        // try to ignore DISTINCT if it doesn't affect the result
        if matches!(
            agg_call.agg_type,
            agg_types::result_unaffected_by_distinct!()
        ) {
            agg_call.distinct = false;
        }

        if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Grouping)) {
            if self.grouping_sets.is_empty() {
                return Err(ErrorCode::NotSupported(
                    "GROUPING must be used in a query with grouping sets".into(),
                    "try to use grouping sets instead".into(),
                )
                .into());
            }
            if agg_call.args.len() >= 32 {
                return Err(ErrorCode::InvalidInputSyntax(
                    "GROUPING must have fewer than 32 arguments".into(),
                )
                .into());
            }
            if agg_call
                .args
                .iter()
                .any(|x| self.try_as_group_expr(x).is_none())
            {
                return Err(ErrorCode::InvalidInputSyntax(
                    "arguments to GROUPING must be grouping expressions of the associated query level"
                        .into(),
                ).into());
            }
        }

        Self::general_rewrite_agg_call(agg_call, |agg_call| self.push_agg_call(agg_call))
    }
}

impl ExprRewriter for LogicalAggBuilder {
    fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl {
        let dummy = Literal::new(None, agg_call.return_type()).into();
        match self.try_rewrite_agg_call(agg_call) {
            Ok(expr) => expr,
            Err(err) => {
                self.error = Some(err);
                dummy
            }
        }
    }

    /// When there is a `FunctionCall` (outside of an agg call), either it must refer to a group
    /// column as a whole, or all `InputRef`s appearing in it must refer to group columns.
    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
        let expr = func_call.into();
        if let Some(group_key) = self.try_as_group_expr(&expr) {
            InputRef::new(group_key, expr.return_type()).into()
        } else {
            let (func_type, inputs, ret) = expr.into_function_call().unwrap().decompose();
            let inputs = inputs
                .into_iter()
                .map(|expr| self.rewrite_expr(expr))
                .collect();
            FunctionCall::new_unchecked(func_type, inputs, ret).into()
        }
    }

    /// When there is a `WindowFunction` (outside of an agg call), either it must refer to a group
    /// column as a whole, or all `InputRef`s appearing in it must refer to group columns.
    fn rewrite_window_function(&mut self, window_func: WindowFunction) -> ExprImpl {
        let WindowFunction {
            args,
            partition_by,
            order_by,
            ..
        } = window_func;
        let args = args
            .into_iter()
            .map(|expr| self.rewrite_expr(expr))
            .collect();
        let partition_by = partition_by
            .into_iter()
            .map(|expr| self.rewrite_expr(expr))
            .collect();
        let order_by = order_by.rewrite_expr(self);
        WindowFunction {
            args,
            partition_by,
            order_by,
            ..window_func
        }
        .into()
    }

    /// When there is an `InputRef` (outside of an agg call), it must refer to a group column.
    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
        let expr = input_ref.into();
        if let Some(group_key) = self.try_as_group_expr(&expr) {
            InputRef::new(group_key, expr.return_type()).into()
        } else if self.is_in_filter_clause {
            InputRef::new(
                self.input_proj_builder.add_expr(&expr).unwrap(),
                expr.return_type(),
            )
            .into()
        } else {
            self.error = Some(
                ErrorCode::InvalidInputSyntax(
                    "column must appear in the GROUP BY clause or be used in an aggregate function"
                        .into(),
                )
                .into(),
            );
            expr
        }
    }

    fn rewrite_subquery(&mut self, subquery: crate::expr::Subquery) -> ExprImpl {
        if subquery.is_correlated_by_depth(0) {
            self.error = Some(
                not_implemented!(
                    issue = 2275,
                    "correlated subquery in HAVING or SELECT with agg",
                )
                .into(),
            );
        }
        subquery.into()
    }
}

impl From<Agg<PlanRef>> for LogicalAgg {
    fn from(core: Agg<PlanRef>) -> Self {
        let base = PlanBase::new_logical_with_core(&core);
        Self { base, core }
    }
}

/// Because `From`/`Into` are not transitive
impl From<Agg<PlanRef>> for PlanRef {
    fn from(core: Agg<PlanRef>) -> Self {
        LogicalAgg::from(core).into()
    }
}

impl LogicalAgg {
    /// Get the `ColIndexMapping` from input column index to output column index.
    pub fn i2o_col_mapping(&self) -> ColIndexMapping {
        self.core.i2o_col_mapping()
    }

    /// `create` will analyze select exprs, group exprs and having, and construct a plan like
    ///
    /// ```text
    /// LogicalAgg -> LogicalProject -> input
    /// ```
    ///
    /// It also returns the rewritten select exprs and having that reference into the aggregated
    /// results.
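    ///
    /// The rewritten exprs refer to the output schema of the `LogicalAgg`, i.e. the group keys
    /// first and then the agg call results.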
    pub fn create(
        select_exprs: Vec<ExprImpl>,
        group_by: GroupBy,
        having: Option<ExprImpl>,
        input: PlanRef,
    ) -> Result<(PlanRef, Vec<ExprImpl>, Option<ExprImpl>)> {
        let mut agg_builder = LogicalAggBuilder::new(group_by, input.schema().len())?;

        let rewritten_select_exprs = select_exprs
            .into_iter()
            .map(|expr| agg_builder.rewrite_with_error(expr))
            .collect::<Result<_>>()?;
        let rewritten_having = having
            .map(|expr| agg_builder.rewrite_with_error(expr))
            .transpose()?;

        Ok((
            agg_builder.build(input).into(),
            rewritten_select_exprs,
            rewritten_having,
        ))
    }

    /// Get a reference to the logical agg's agg calls.
    pub fn agg_calls(&self) -> &Vec<PlanAggCall> {
        &self.core.agg_calls
    }

    /// Get a reference to the logical agg's group key.
    pub fn group_key(&self) -> &IndexSet {
        &self.core.group_key
    }

    pub fn grouping_sets(&self) -> &Vec<IndexSet> {
        &self.core.grouping_sets
    }

    pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) {
        self.core.decompose()
    }

    #[must_use]
    pub fn rewrite_with_input_agg(
        &self,
        input: PlanRef,
        agg_calls: &[PlanAggCall],
        mut input_col_change: ColIndexMapping,
    ) -> (Self, ColIndexMapping) {
        let agg_calls = agg_calls
            .iter()
            .cloned()
            .map(|mut agg_call| {
                agg_call.inputs.iter_mut().for_each(|i| {
                    *i = InputRef::new(input_col_change.map(i.index()), i.return_type())
                });
                agg_call.order_by.iter_mut().for_each(|o| {
                    o.column_index = input_col_change.map(o.column_index);
                });
                agg_call.filter = agg_call.filter.rewrite_expr(&mut input_col_change);
                agg_call
            })
            .collect();
        // This is the order the group key should be in after rewriting.
        let group_key_in_vec: Vec<usize> = self
            .group_key()
            .indices()
            .map(|key| input_col_change.map(key))
            .collect();
        // This is the group key order we actually get after rewriting.
        let group_key: IndexSet = group_key_in_vec.iter().cloned().collect();
        let grouping_sets = self
            .grouping_sets()
            .iter()
            .map(|set| set.indices().map(|key| input_col_change.map(key)).collect())
            .collect();

        let new_agg = Agg::new(agg_calls, group_key.clone(), input)
            .with_grouping_sets(grouping_sets)
            .with_enable_two_phase(self.core().enable_two_phase);

        // group_key remapping might cause an output column change, since the group key is actually
        // a `FixedBitSet`.
        let mut out_col_change = vec![];
        for idx in group_key_in_vec {
            let pos = group_key.indices().position(|x| x == idx).unwrap();
            out_col_change.push(pos);
        }
        for i in (group_key.len())..new_agg.schema().len() {
            out_col_change.push(i);
        }
        let out_col_change =
            ColIndexMapping::with_remaining_columns(&out_col_change, new_agg.schema().len());

        (new_agg.into(), out_col_change)
    }
}

impl PlanTreeNodeUnary<Logical> for LogicalAgg {
    fn input(&self) -> PlanRef {
        self.core.input.clone()
    }

    fn clone_with_input(&self, input: PlanRef) -> Self {
        Agg::new(self.agg_calls().to_vec(), self.group_key().clone(), input)
            .with_grouping_sets(self.grouping_sets().clone())
            .with_enable_two_phase(self.core().enable_two_phase)
            .into()
    }

    fn rewrite_with_input(
        &self,
        input: PlanRef,
        input_col_change: ColIndexMapping,
    ) -> (Self, ColIndexMapping) {
        self.rewrite_with_input_agg(input, self.agg_calls(), input_col_change)
    }
}

impl_plan_tree_node_for_unary! { Logical, LogicalAgg }
impl_distill_by_unit!(LogicalAgg, core, "LogicalAgg");

impl ExprRewritable<Logical> for LogicalAgg {
    fn has_rewritable_expr(&self) -> bool {
        true
    }

    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
        let mut core = self.core.clone();
        core.rewrite_exprs(r);
        Self {
            base: self.base.clone_with_new_plan_id(),
            core,
        }
        .into()
    }
}

impl ExprVisitable for LogicalAgg {
    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
        self.core.visit_exprs(v);
    }
}

impl ColPrunable for LogicalAgg {
    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
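        // Overview: keep all group key columns, keep only the required agg calls, derive the
        // input columns they need, prune the input recursively, and finally add a
        // `LogicalProject` on top if the resulting schema doesn't match `required_cols` exactly.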
        let group_key_required_cols = self.group_key().to_bitset();

        let (agg_call_required_cols, agg_calls) = {
            let input_cnt = self.input().schema().len();
            let mut tmp = FixedBitSet::with_capacity(input_cnt);
            let group_key_cardinality = self.group_key().len();
            let new_agg_calls = required_cols
                .iter()
                .filter(|&&index| index >= group_key_cardinality)
                .map(|&index| {
                    let index = index - group_key_cardinality;
                    let agg_call = self.agg_calls()[index].clone();
                    tmp.extend(agg_call.inputs.iter().map(|x| x.index()));
                    tmp.extend(agg_call.order_by.iter().map(|x| x.column_index));
                    // collect columns used in aggregate filter expressions
                    for i in &agg_call.filter.conjunctions {
                        tmp.union_with(&i.collect_input_refs(input_cnt));
                    }
                    agg_call
                })
                .collect_vec();
            (tmp, new_agg_calls)
        };

        let input_required_cols = {
            let mut tmp = FixedBitSet::with_capacity(self.input().schema().len());
            tmp.union_with(&group_key_required_cols);
            tmp.union_with(&agg_call_required_cols);
            tmp.ones().collect_vec()
        };
        let input_col_change = ColIndexMapping::with_remaining_columns(
            &input_required_cols,
            self.input().schema().len(),
        );
        let agg = {
            let input = self.input().prune_col(&input_required_cols, ctx);
            let (agg, output_col_change) =
                self.rewrite_with_input_agg(input, &agg_calls, input_col_change);
            assert!(output_col_change.is_identity());
            agg
        };
        let new_output_cols = {
            // group keys are never pruned or re-ordered in the current impl
            let group_key_cardinality = agg.group_key().len();
            let mut tmp = (0..group_key_cardinality).collect_vec();
            tmp.extend(
                required_cols
                    .iter()
                    .filter(|&&index| index >= group_key_cardinality),
            );
            tmp
        };
        if new_output_cols == required_cols {
            // the current schema perfectly fits the required columns
            agg.into()
        } else {
            // Some columns are not needed, or the order needs to be adjusted,
            // so we add a projection to remove/reorder the columns.
            let mapping =
                &ColIndexMapping::with_remaining_columns(&new_output_cols, self.schema().len());
            let output_required_cols = required_cols
                .iter()
                .map(|&idx| mapping.map(idx))
                .collect_vec();
            let src_size = agg.schema().len();
            LogicalProject::with_mapping(
                agg.into(),
                ColIndexMapping::with_remaining_columns(&output_required_cols, src_size),
            )
            .into()
        }
    }
}
1310
1311impl PredicatePushdown for LogicalAgg {
1312    fn predicate_pushdown(
1313        &self,
1314        predicate: Condition,
1315        ctx: &mut PredicatePushdownContext,
1316    ) -> PlanRef {
1317        let num_group_key = self.group_key().len();
1318        let num_agg_calls = self.agg_calls().len();
1319        assert!(num_group_key + num_agg_calls == self.schema().len());
1320
1321        // SimpleAgg should be skipped because the predicate either references agg_calls
1322        // or is const.
1323        // If the filter references agg_calls, we can not push it.
1324        // When it is constantly true, pushing is useless and may actually cause more evaluation
1325        // cost of the predicate.
1326        // When it is constantly false, pushing is wrong - the old plan returns 0 rows but new one
1327        // returns 1 row.
1328        if num_group_key == 0 {
1329            return gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx);
1330        }
1331
1332        // If the filter references agg_calls, we can not push it.
1333        let mut agg_call_columns = FixedBitSet::with_capacity(num_group_key + num_agg_calls);
1334        agg_call_columns.insert_range(num_group_key..num_group_key + num_agg_calls);
1335        let (agg_call_pred, pushed_predicate) = predicate.split_disjoint(&agg_call_columns);
1336
1337        // convert the predicate to one that references the child of the agg
1338        let mut subst = Substitute {
1339            mapping: self
1340                .group_key()
1341                .indices()
1342                .enumerate()
1343                .map(|(i, group_key)| {
1344                    InputRef::new(group_key, self.schema().fields()[i].data_type()).into()
1345                })
1346                .collect(),
1347        };
1348        let pushed_predicate = pushed_predicate.rewrite_expr(&mut subst);
1349
1350        gen_filter_and_pushdown(self, agg_call_pred, pushed_predicate, ctx)
1351    }
1352}
1353
1354impl ToBatch for LogicalAgg {
1355    fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1356        self.to_batch_with_order_required(&Order::any())
1357    }
1358
1359    // TODO(rc): `to_batch_with_order_required` seems to be useless now that we decided to use
1360    // `BatchSortAgg` only when the input is already sorted
1361    fn to_batch_with_order_required(
1362        &self,
1363        required_order: &Order,
1364    ) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1365        let input = self.input().to_batch()?;
1366        let new_logical = self.core.clone_with_input(input);
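        // Pick the batch implementation: simple agg when there is no group key, sort agg when
        // the session enables it and the input already provides order on the group keys,
        // and hash agg otherwise.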
1367        let agg_plan = if self.group_key().is_empty() {
1368            BatchSimpleAgg::new(new_logical).into()
1369        } else if self.ctx().session_ctx().config().batch_enable_sort_agg()
1370            && new_logical.input_provides_order_on_group_keys()
1371        {
1372            BatchSortAgg::new(new_logical).into()
1373        } else {
1374            BatchHashAgg::new(new_logical).into()
1375        };
1376        required_order.enforce_if_not_satisfies(agg_plan)
1377    }
1378}
1379
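/// Finds the position of an existing plain `count(*)` call, or appends one, and returns the
/// (possibly extended) agg together with the index of that call.
///
/// For example, `[count(v1), sum(v2)]` becomes `[count(v1), sum(v2), count(*)]` with index 2,
/// while `[count(*), max(v1)]` is returned unchanged with index 0.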
1380fn find_or_append_row_count(mut logical: Agg<StreamPlanRef>) -> (Agg<StreamPlanRef>, usize) {
1381    // `HashAgg`/`SimpleAgg` executors require a `count(*)` to correctly build changes, so
1382    // append a `count(*)` if one does not exist yet.
1383    let count_star = PlanAggCall::count_star();
1384    let row_count_idx = if let Some((idx, _)) = logical
1385        .agg_calls
1386        .iter()
1387        .find_position(|&c| c == &count_star)
1388    {
1389        idx
1390    } else {
1391        let idx = logical.agg_calls.len();
1392        logical.agg_calls.push(count_star);
1393        idx
1394    };
1395    (logical, row_count_idx)
1396}
1397
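/// Builds a `StreamSimpleAgg` on top of `core`, appending a `count(*)` call for row-count
/// tracking if one is not already present.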
1398fn new_stream_simple_agg(
1399    core: Agg<StreamPlanRef>,
1400    must_output_per_barrier: bool,
1401) -> Result<StreamSimpleAgg> {
1402    let (logical, row_count_idx) = find_or_append_row_count(core);
1403    StreamSimpleAgg::new(logical, row_count_idx, must_output_per_barrier)
1404}
1405
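/// Builds a `StreamHashAgg` on top of `core`, appending a `count(*)` call for row-count
/// tracking if one is not already present.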
1406fn new_stream_hash_agg(
1407    core: Agg<StreamPlanRef>,
1408    vnode_col_idx: Option<usize>,
1409) -> Result<StreamHashAgg> {
1410    let (logical, row_count_idx) = find_or_append_row_count(core);
1411    StreamHashAgg::new(logical, vnode_col_idx, row_count_idx)
1412}
1413
1414impl ToStream for LogicalAgg {
1415    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<StreamPlanRef> {
1416        use super::stream::prelude::*;
1417
1418        for agg_call in self.agg_calls() {
1419            if matches!(agg_call.agg_type, agg_types::unimplemented_in_stream!()) {
1420                bail_not_implemented!("{} aggregation in materialized view", agg_call.agg_type);
1421            }
1422        }
1423        let eowc = ctx.emit_on_window_close();
1424        let input = self
1425            .input()
1426            .try_better_locality(&self.group_key().to_vec())
1427            .unwrap_or_else(|| self.input());
1428        let stream_input = input.to_stream(ctx)?;
1429
1430        // Use Dedup operator, if possible.
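        // e.g. `SELECT DISTINCT v1 FROM t` over an append-only stream has a group key but no
        // agg calls, so it can be served by a dedup operator instead of a hash agg.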
1431        if stream_input.append_only() && self.agg_calls().is_empty() && !self.group_key().is_empty()
1432        {
1433            let input = if self.group_key().len() != self.input().schema().len() {
1434                let cols = &self.group_key().to_vec();
1435                LogicalProject::with_mapping(
1436                    self.input(),
1437                    ColIndexMapping::with_remaining_columns(cols, self.input().schema().len()),
1438                )
1439                .into()
1440            } else {
1441                self.input()
1442            };
1443            let input_schema_len = input.schema().len();
1444            let logical_dedup = LogicalDedup::new(input, (0..input_schema_len).collect());
1445            return logical_dedup.to_stream(ctx);
1446        }
1447
1448        if self.agg_calls().iter().any(|call| {
1449            matches!(
1450                call.agg_type,
1451                AggType::Builtin(PbAggKind::ApproxCountDistinct)
1452            )
1453        }) {
1454            if stream_input.append_only() {
1455                self.core.ctx().session_ctx().notice_to_user(
1456                    "Streaming `APPROX_COUNT_DISTINCT` is still a preview feature and subject to change. Please do not use it in production environment.",
1457                );
1458            } else {
1459                bail_not_implemented!(
1460                    "Streaming `APPROX_COUNT_DISTINCT` is only supported in append-only stream"
1461                );
1462            }
1463        }
1464
1465        let plan = self.gen_dist_stream_agg_plan(stream_input)?;
1466
1467        let (plan, n_final_agg_calls) = if let Some(final_agg) = plan.as_stream_simple_agg() {
1468            if eowc {
1469                return Err(ErrorCode::InvalidInputSyntax(
1470                    "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1471                        .to_owned(),
1472                )
1473                .into());
1474            }
1475            (plan.clone(), final_agg.agg_calls().len())
1476        } else if let Some(final_agg) = plan.as_stream_hash_agg() {
1477            (
1478                if eowc {
1479                    final_agg.to_eowc_version()?
1480                } else {
1481                    plan.clone()
1482                },
1483                final_agg.agg_calls().len(),
1484            )
1485        } else if let Some(_approx_percentile_agg) = plan.as_stream_global_approx_percentile() {
1486            if eowc {
1487                return Err(ErrorCode::InvalidInputSyntax(
1488                    "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1489                        .to_owned(),
1490                )
1491                .into());
1492            }
1493            (plan.clone(), 1)
1494        } else if let Some(stream_row_merge) = plan.as_stream_row_merge() {
1495            if eowc {
1496                return Err(ErrorCode::InvalidInputSyntax(
1497                    "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1498                        .to_owned(),
1499                )
1500                .into());
1501            }
1502            (plan.clone(), stream_row_merge.base.schema().len())
1503        } else {
1504            panic!(
1505                "the root PlanNode must be StreamHashAgg, StreamSimpleAgg, StreamGlobalApproxPercentile, or StreamRowMerge"
1506            );
1507        };
1508
1509        if self.agg_calls().len() == n_final_agg_calls {
1510            // an existing `count(*)` is used as the row count column in `StreamXxxAgg`
1511            Ok(plan)
1512        } else {
1513            // a `count(*)` was appended, so we project it away from the output
1514            assert_eq!(self.agg_calls().len() + 1, n_final_agg_calls);
1515
1516            let mut project = StreamProject::new(generic::Project::with_out_col_idx(
1517                plan,
1518                0..self.schema().len(),
1519            ));
1520            // If there's no agg call, then `count(*)` will be the only column in the output besides keys.
1521            // Since it'll be pruned immediately in `StreamProject`, the update records are likely to be
1522            // no-ops. So we set the hint to instruct the executor to eliminate them.
1523            // See https://github.com/risingwavelabs/risingwave/issues/17030.
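            // e.g. for `SELECT v1 FROM t GROUP BY v1`, the hash agg maintains only the appended
            // `count(*)`, and the projection below keeps just the group key column `v1`.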
1524            if self.agg_calls().is_empty() {
1525                project = project.with_noop_update_hint(true);
1526            }
1527            Ok(project.into())
1528        }
1529    }
1530
1531    fn logical_rewrite_for_stream(
1532        &self,
1533        ctx: &mut RewriteStreamContext,
1534    ) -> Result<(PlanRef, ColIndexMapping)> {
1535        let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
1536        let (agg, out_col_change) = self.rewrite_with_input(input, input_col_change);
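        // Rebuild the output mapping so that its target size matches the rewritten agg's schema.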
1537        let (map, _) = out_col_change.into_parts();
1538        let out_col_change = ColIndexMapping::new(map, agg.schema().len());
1539        Ok((agg.into(), out_col_change))
1540    }
1541}
1542
1543#[cfg(test)]
1544mod tests {
1545    use risingwave_common::catalog::{Field, Schema};
1546
1547    use super::*;
1548    use crate::expr::{assert_eq_input_ref, input_ref_to_column_indices};
1549    use crate::optimizer::optimizer_context::OptimizerContext;
1550    use crate::optimizer::plan_node::LogicalValues;
1551
1552    #[tokio::test]
1553    async fn test_create() {
1554        let ty = DataType::Int32;
1555        let ctx = OptimizerContext::mock().await;
1556        let fields: Vec<Field> = vec![
1557            Field::with_name(ty.clone(), "v1"),
1558            Field::with_name(ty.clone(), "v2"),
1559            Field::with_name(ty.clone(), "v3"),
1560        ];
1561        let values = LogicalValues::new(vec![], Schema { fields }, ctx);
1562        let input = PlanRef::from(values);
1563        let input_ref_1 = InputRef::new(0, ty.clone());
1564        let input_ref_2 = InputRef::new(1, ty.clone());
1565        let input_ref_3 = InputRef::new(2, ty.clone());
1566
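        // Plans the given select / group-by exprs through `LogicalAgg::create` and returns the
        // rewritten select exprs, the extracted agg calls, and the group key.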
1567        let gen_internal_value = |select_exprs: Vec<ExprImpl>,
1568                                  group_exprs|
1569         -> (Vec<ExprImpl>, Vec<PlanAggCall>, IndexSet) {
1570            let (plan, exprs, _) = LogicalAgg::create(
1571                select_exprs,
1572                GroupBy::GroupKey(group_exprs),
1573                None,
1574                input.clone(),
1575            )
1576            .unwrap();
1577
1578            let logical_agg = plan.as_logical_agg().unwrap();
1579            let agg_calls = logical_agg.agg_calls().to_vec();
1580            let group_key = logical_agg.group_key().clone();
1581
1582            (exprs, agg_calls, group_key)
1583        };
1584
1585        // Test case: select v1 from test group by v1;
1586        {
1587            let select_exprs = vec![input_ref_1.clone().into()];
1588            let group_exprs = vec![input_ref_1.clone().into()];
1589
1590            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
1591
1592            assert_eq!(exprs.len(), 1);
1593            assert_eq_input_ref!(&exprs[0], 0);
1594
1595            assert_eq!(agg_calls.len(), 0);
1596            assert_eq!(group_key, vec![0].into());
1597        }
1598
1599        // Test case: select v1, min(v2) from test group by v1;
1600        {
1601            let min_v2 = AggCall::new(
1602                PbAggKind::Min.into(),
1603                vec![input_ref_2.clone().into()],
1604                false,
1605                OrderBy::any(),
1606                Condition::true_cond(),
1607                vec![],
1608            )
1609            .unwrap();
1610            let select_exprs = vec![input_ref_1.clone().into(), min_v2.into()];
1611            let group_exprs = vec![input_ref_1.clone().into()];
1612
1613            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
1614
1615            assert_eq!(exprs.len(), 2);
1616            assert_eq_input_ref!(&exprs[0], 0);
1617            assert_eq_input_ref!(&exprs[1], 1);
1618
1619            assert_eq!(agg_calls.len(), 1);
1620            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
1621            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
1622            assert_eq!(group_key, vec![0].into());
1623        }
1624
1625        // Test case: select v1, min(v2) + max(v3) from test group by v1;
1626        {
1627            let min_v2 = AggCall::new(
1628                PbAggKind::Min.into(),
1629                vec![input_ref_2.clone().into()],
1630                false,
1631                OrderBy::any(),
1632                Condition::true_cond(),
1633                vec![],
1634            )
1635            .unwrap();
1636            let max_v3 = AggCall::new(
1637                PbAggKind::Max.into(),
1638                vec![input_ref_3.clone().into()],
1639                false,
1640                OrderBy::any(),
1641                Condition::true_cond(),
1642                vec![],
1643            )
1644            .unwrap();
1645            let func_call =
1646                FunctionCall::new(ExprType::Add, vec![min_v2.into(), max_v3.into()]).unwrap();
1647            let select_exprs = vec![input_ref_1.clone().into(), ExprImpl::from(func_call)];
1648            let group_exprs = vec![input_ref_1.clone().into()];
1649
1650            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
1651
1652            assert_eq_input_ref!(&exprs[0], 0);
1653            if let ExprImpl::FunctionCall(func_call) = &exprs[1] {
1654                assert_eq!(func_call.func_type(), ExprType::Add);
1655                let inputs = func_call.inputs();
1656                assert_eq_input_ref!(&inputs[0], 1);
1657                assert_eq_input_ref!(&inputs[1], 2);
1658            } else {
1659                panic!("Wrong expression type!");
1660            }
1661
1662            assert_eq!(agg_calls.len(), 2);
1663            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
1664            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
1665            assert_eq!(agg_calls[1].agg_type, PbAggKind::Max.into());
1666            assert_eq!(input_ref_to_column_indices(&agg_calls[1].inputs), vec![2]);
1667            assert_eq!(group_key, vec![0].into());
1668        }
1669
1670        // Test case: select v2, min(v1 * v3) from test group by v2;
1671        {
1672            let v1_mult_v3 = FunctionCall::new(
1673                ExprType::Multiply,
1674                vec![input_ref_1.into(), input_ref_3.into()],
1675            )
1676            .unwrap();
1677            let agg_call = AggCall::new(
1678                PbAggKind::Min.into(),
1679                vec![v1_mult_v3.into()],
1680                false,
1681                OrderBy::any(),
1682                Condition::true_cond(),
1683                vec![],
1684            )
1685            .unwrap();
1686            let select_exprs = vec![input_ref_2.clone().into(), agg_call.into()];
1687            let group_exprs = vec![input_ref_2.into()];
1688
1689            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
1690
1691            assert_eq_input_ref!(&exprs[0], 0);
1692            assert_eq_input_ref!(&exprs[1], 1);
1693
1694            assert_eq!(agg_calls.len(), 1);
1695            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
1696            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
1697            assert_eq!(group_key, vec![0].into());
1698        }
1699    }
1700
1701    /// Generate an agg call node with the given [`DataType`] and fields.
1702    /// For example, `generate_agg_call(Int32, [v1, v2, v3])` will result in:
1703    /// ```text
1704    /// Agg(min(input_ref(2))) group by (input_ref(1))
1705    ///   TableScan(v1, v2, v3)
1706    /// ```
1707    async fn generate_agg_call(ty: DataType, fields: Vec<Field>) -> LogicalAgg {
1708        let ctx = OptimizerContext::mock().await;
1709
1710        let values = LogicalValues::new(vec![], Schema { fields }, ctx);
1711        let agg_call = PlanAggCall {
1712            agg_type: PbAggKind::Min.into(),
1713            return_type: ty.clone(),
1714            inputs: vec![InputRef::new(2, ty.clone())],
1715            distinct: false,
1716            order_by: vec![],
1717            filter: Condition::true_cond(),
1718            direct_args: vec![],
1719        };
1720        Agg::new(vec![agg_call], vec![1].into(), values.into()).into()
1721    }
1722
1723    #[tokio::test]
1724    /// Pruning
1725    /// ```text
1726    /// Agg(min(input_ref(2))) group by (input_ref(1))
1727    ///   TableScan(v1, v2, v3)
1728    /// ```
1729    /// with required columns [0,1] (all columns) will result in
1730    /// ```text
1731    /// Agg(min(input_ref(1))) group by (input_ref(0))
1732    ///  TableScan(v2, v3)
1733    /// ```
1734    async fn test_prune_all() {
1735        let ty = DataType::Int32;
1736        let fields: Vec<Field> = vec![
1737            Field::with_name(ty.clone(), "v1"),
1738            Field::with_name(ty.clone(), "v2"),
1739            Field::with_name(ty.clone(), "v3"),
1740        ];
1741        let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).await.into();
1742        // Perform the prune
1743        let required_cols = vec![0, 1];
1744        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
1745
1746        // Check the result
1747        let agg_new = plan.as_logical_agg().unwrap();
1748        assert_eq!(agg_new.group_key(), &vec![0].into());
1749
1750        assert_eq!(agg_new.agg_calls().len(), 1);
1751        let agg_call_new = agg_new.agg_calls()[0].clone();
1752        assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
1753        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
1754        assert_eq!(agg_call_new.return_type, ty);
1755
1756        let values = agg_new.input();
1757        let values = values.as_logical_values().unwrap();
1758        assert_eq!(values.schema().fields(), &fields[1..]);
1759    }
1760
1761    #[tokio::test]
1762    /// Pruning
1763    /// ```text
1764    /// Agg(min(input_ref(2))) group by (input_ref(1))
1765    ///   TableScan(v1, v2, v3)
1766    /// ```
1767    /// with required columns [1,0] (all columns, with reversed order) will result in
1768    /// ```text
1769    /// Project [input_ref(1), input_ref(0)]
1770    ///   Agg(min(input_ref(1))) group by (input_ref(0))
1771    ///     TableScan(v2, v3)
1772    /// ```
1773    async fn test_prune_all_with_order_required() {
1774        let ty = DataType::Int32;
1775        let fields: Vec<Field> = vec![
1776            Field::with_name(ty.clone(), "v1"),
1777            Field::with_name(ty.clone(), "v2"),
1778            Field::with_name(ty.clone(), "v3"),
1779        ];
1780        let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).await.into();
1781        // Perform the prune
1782        let required_cols = vec![1, 0];
1783        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
1784        // Check the result
1785        let proj = plan.as_logical_project().unwrap();
1786        assert_eq!(proj.exprs().len(), 2);
1787        assert_eq!(proj.exprs()[0].as_input_ref().unwrap().index(), 1);
1788        assert_eq!(proj.exprs()[1].as_input_ref().unwrap().index(), 0);
1789        let proj_input = proj.input();
1790        let agg_new = proj_input.as_logical_agg().unwrap();
1791        assert_eq!(agg_new.group_key(), &vec![0].into());
1792
1793        assert_eq!(agg_new.agg_calls().len(), 1);
1794        let agg_call_new = agg_new.agg_calls()[0].clone();
1795        assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
1796        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
1797        assert_eq!(agg_call_new.return_type, ty);
1798
1799        let values = agg_new.input();
1800        let values = values.as_logical_values().unwrap();
1801        assert_eq!(values.schema().fields(), &fields[1..]);
1802    }
1803
1804    #[tokio::test]
1805    /// Pruning
1806    /// ```text
1807    /// Agg(min(input_ref(2))) group by (input_ref(1))
1808    ///   TableScan(v1, v2, v3)
1809    /// ```
1810    /// with required columns [1] (group key removed) will result in
1811    /// ```text
1812    /// Project(input_ref(1))
1813    ///   Agg(min(input_ref(1))) group by (input_ref(0))
1814    ///     TableScan(v2, v3)
1815    /// ```
1816    async fn test_prune_group_key() {
1817        let ctx = OptimizerContext::mock().await;
1818        let ty = DataType::Int32;
1819        let fields: Vec<Field> = vec![
1820            Field::with_name(ty.clone(), "v1"),
1821            Field::with_name(ty.clone(), "v2"),
1822            Field::with_name(ty.clone(), "v3"),
1823        ];
1824        let values: LogicalValues = LogicalValues::new(
1825            vec![],
1826            Schema {
1827                fields: fields.clone(),
1828            },
1829            ctx,
1830        );
1831        let agg_call = PlanAggCall {
1832            agg_type: PbAggKind::Min.into(),
1833            return_type: ty.clone(),
1834            inputs: vec![InputRef::new(2, ty.clone())],
1835            distinct: false,
1836            order_by: vec![],
1837            filter: Condition::true_cond(),
1838            direct_args: vec![],
1839        };
1840        let agg: PlanRef = Agg::new(vec![agg_call], vec![1].into(), values.into()).into();
1841
1842        // Perform the prune
1843        let required_cols = vec![1];
1844        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
1845
1846        // Check the result
1847        let project = plan.as_logical_project().unwrap();
1848        assert_eq!(project.exprs().len(), 1);
1849        assert_eq_input_ref!(&project.exprs()[0], 1);
1850
1851        let agg_new = project.input();
1852        let agg_new = agg_new.as_logical_agg().unwrap();
1853        assert_eq!(agg_new.group_key(), &vec![0].into());
1854
1855        assert_eq!(agg_new.agg_calls().len(), 1);
1856        let agg_call_new = agg_new.agg_calls()[0].clone();
1857        assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
1858        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
1859        assert_eq!(agg_call_new.return_type, ty);
1860
1861        let values = agg_new.input();
1862        let values = values.as_logical_values().unwrap();
1863        assert_eq!(values.schema().fields(), &fields[1..]);
1864    }
1865
1866    #[tokio::test]
1867    /// Pruning
1868    /// ```text
1869    /// Agg(min(input_ref(2)), max(input_ref(1))) group by (input_ref(1), input_ref(2))
1870    ///   TableScan(v1, v2, v3)
1871    /// ```
1872    /// with required columns [0,3] will result in
1873    /// ```text
1874    /// Project(input_ref(0), input_ref(2))
1875    ///   Agg(max(input_ref(0))) group by (input_ref(0), input_ref(1))
1876    ///     TableScan(v2, v3)
1877    /// ```
1878    async fn test_prune_agg() {
1879        let ty = DataType::Int32;
1880        let ctx = OptimizerContext::mock().await;
1881        let fields: Vec<Field> = vec![
1882            Field::with_name(ty.clone(), "v1"),
1883            Field::with_name(ty.clone(), "v2"),
1884            Field::with_name(ty.clone(), "v3"),
1885        ];
1886        let values = LogicalValues::new(
1887            vec![],
1888            Schema {
1889                fields: fields.clone(),
1890            },
1891            ctx,
1892        );
1893
1894        let agg_calls = vec![
1895            PlanAggCall {
1896                agg_type: PbAggKind::Min.into(),
1897                return_type: ty.clone(),
1898                inputs: vec![InputRef::new(2, ty.clone())],
1899                distinct: false,
1900                order_by: vec![],
1901                filter: Condition::true_cond(),
1902                direct_args: vec![],
1903            },
1904            PlanAggCall {
1905                agg_type: PbAggKind::Max.into(),
1906                return_type: ty.clone(),
1907                inputs: vec![InputRef::new(1, ty.clone())],
1908                distinct: false,
1909                order_by: vec![],
1910                filter: Condition::true_cond(),
1911                direct_args: vec![],
1912            },
1913        ];
1914        let agg: PlanRef = Agg::new(agg_calls, vec![1, 2].into(), values.into()).into();
1915
1916        // Perform the prune
1917        let required_cols = vec![0, 3];
1918        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
1919        // Check the result
1920        let project = plan.as_logical_project().unwrap();
1921        assert_eq!(project.exprs().len(), 2);
1922        assert_eq_input_ref!(&project.exprs()[0], 0);
1923        assert_eq_input_ref!(&project.exprs()[1], 2);
1924
1925        let agg_new = project.input();
1926        let agg_new = agg_new.as_logical_agg().unwrap();
1927        assert_eq!(agg_new.group_key(), &vec![0, 1].into());
1928
1929        assert_eq!(agg_new.agg_calls().len(), 1);
1930        let agg_call_new = agg_new.agg_calls()[0].clone();
1931        assert_eq!(agg_call_new.agg_type, PbAggKind::Max.into());
1932        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![0]);
1933        assert_eq!(agg_call_new.return_type, ty);
1934
1935        let values = agg_new.input();
1936        let values = values.as_logical_values().unwrap();
1937        assert_eq!(values.schema().fields(), &fields[1..]);
1938    }
1939}