risingwave_frontend/optimizer/plan_node/logical_agg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::types::{DataType, ScalarImpl};
18use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
19use risingwave_common::{bail, bail_not_implemented, not_implemented};
20use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};
21
22use super::generic::{self, Agg, GenericPlanRef, PlanAggCall, ProjectBuilder};
23use super::utils::impl_distill_by_unit;
24use super::{
25    BatchHashAgg, BatchSimpleAgg, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
26    PlanBase, PlanTreeNodeUnary, PredicatePushdown, StreamHashAgg, StreamPlanRef, StreamProject,
27    StreamShare, StreamSimpleAgg, StreamStatelessSimpleAgg, ToBatch, ToStream,
28    try_enforce_locality_requirement,
29};
30use crate::error::{ErrorCode, Result, RwError};
31use crate::expr::{
32    AggCall, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, Literal,
33    OrderBy, OrderByExpr, WindowFunction,
34};
35use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
36use crate::optimizer::plan_node::generic::GenericPlanNode;
37use crate::optimizer::plan_node::stream_global_approx_percentile::StreamGlobalApproxPercentile;
38use crate::optimizer::plan_node::stream_local_approx_percentile::StreamLocalApproxPercentile;
39use crate::optimizer::plan_node::stream_row_merge::StreamRowMerge;
40use crate::optimizer::plan_node::{
41    BatchSortAgg, ColumnPruningContext, LogicalDedup, LogicalProject, PredicatePushdownContext,
42    RewriteStreamContext, ToStreamContext, gen_filter_and_pushdown,
43};
44use crate::optimizer::property::{Distribution, Order, RequiredDist};
45use crate::utils::{
46    ColIndexMapping, ColIndexMappingRewriteExt, Condition, GroupBy, IndexSet, Substitute,
47};
48
49pub struct AggInfo {
50    pub calls: Vec<PlanAggCall>,
51    pub col_mapping: ColIndexMapping,
52}
53
54/// `SeparatedAggInfo` is used to separate normal and approx percentile aggs.
55pub struct SeparatedAggInfo {
56    normal: AggInfo,
57    approx: AggInfo,
58}
59
60/// `LogicalAgg` groups input data by their group key and computes aggregation functions.
61///
62/// It corresponds to the `GROUP BY` operator in a SQL query statement together with the aggregate
63/// functions in the `SELECT` clause.
64///
65/// The output schema will first include the group key and then the aggregation calls.
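/// For example (illustrative only), `SELECT v1, count(v2), sum(v3) FROM t GROUP BY v1`
/// yields a `LogicalAgg` whose output schema is:
///
/// ```text
/// [ v1 (group key), count(v2), sum(v3) ]
/// ```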
66#[derive(Clone, Debug, PartialEq, Eq, Hash)]
67pub struct LogicalAgg {
68    pub base: PlanBase<Logical>,
69    core: Agg<PlanRef>,
70}
71
72impl LogicalAgg {
73    /// Generate plan for stateless 2-phase streaming agg.
74    /// Should only be used if the input is distributed. The input must already be converted to stream form.
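    /// Ignoring the approx percentile branch, the generated plan is roughly (sketch only):
    ///
    /// ```text
    /// stream_input -> StreamStatelessSimpleAgg (local, partial aggs)
    ///              -> Exchange (single)
    ///              -> StreamSimpleAgg (global, total aggs)
    /// ```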
75    fn gen_stateless_two_phase_streaming_agg_plan(
76        &self,
77        stream_input: StreamPlanRef,
78    ) -> Result<StreamPlanRef> {
79        debug_assert!(self.group_key().is_empty());
80
81        // ====== Handle approx percentile aggs
82        let (
83            non_approx_percentile_col_mapping,
84            approx_percentile_col_mapping,
85            approx_percentile,
86            core,
87        ) = self.prepare_approx_percentile(stream_input)?;
88
89        if core.agg_calls.is_empty() {
90            if let Some(approx_percentile) = approx_percentile {
91                return Ok(approx_percentile);
92            };
93            bail!("expected at least one agg call");
94        }
95
96        let need_row_merge: bool = Self::need_row_merge(&approx_percentile);
97
98        // ====== Handle normal aggs
99        let total_agg_calls = core
100            .agg_calls
101            .iter()
102            .enumerate()
103            .map(|(partial_output_idx, agg_call)| {
104                agg_call.partial_to_total_agg_call(partial_output_idx)
105            })
106            .collect_vec();
107        let local_agg = StreamStatelessSimpleAgg::new(core)?;
108        let exchange =
109            RequiredDist::single().streaming_enforce_if_not_satisfies(local_agg.into())?;
110
111        let must_output_per_barrier = need_row_merge;
112        let global_agg = new_stream_simple_agg(
113            Agg::new(total_agg_calls, IndexSet::empty(), exchange),
114            must_output_per_barrier,
115        )?;
116
117        // ====== Merge approx percentile and normal aggs
118        Self::add_row_merge_if_needed(
119            approx_percentile,
120            global_agg.into(),
121            approx_percentile_col_mapping,
122            non_approx_percentile_col_mapping,
123        )
124    }
125
126    /// Generate plan for stateless/stateful 2-phase streaming agg.
127    /// Should only be used if the input is distributed.
128    /// The input must already be converted to stream form.
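    /// Ignoring the approx percentile branch, the generated plan is roughly (sketch only):
    ///
    /// ```text
    /// stream_input -> StreamProject (append vnode)
    ///              -> StreamHashAgg (local, group by group_key + vnode)
    ///              -> Exchange
    ///              -> StreamSimpleAgg / StreamHashAgg (global, group by group_key)
    /// ```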
129    fn gen_vnode_two_phase_streaming_agg_plan(
130        &self,
131        stream_input: StreamPlanRef,
132        dist_key: &[usize],
133    ) -> Result<StreamPlanRef> {
134        let (
135            non_approx_percentile_col_mapping,
136            approx_percentile_col_mapping,
137            approx_percentile,
138            core,
139        ) = self.prepare_approx_percentile(stream_input.clone())?;
140
141        if core.agg_calls.is_empty() {
142            if let Some(approx_percentile) = approx_percentile {
143                return Ok(approx_percentile);
144            };
145            bail!("expected at least one agg call");
146        }
147        let need_row_merge = Self::need_row_merge(&approx_percentile);
148
149        // Generate vnode via project
150        // TODO(kwannoel): We should apply Project optimization rules here.
151        let input_col_num = stream_input.schema().len(); // get schema len before moving `stream_input`.
152        let project = StreamProject::new(generic::Project::with_vnode_col(stream_input, dist_key));
153        let vnode_col_idx = project.base.schema().len() - 1;
154
155        // Generate local agg step
156        let mut local_group_key = self.group_key().clone();
157        local_group_key.insert(vnode_col_idx);
158        let n_local_group_key = local_group_key.len();
159        let local_agg = new_stream_hash_agg(
160            Agg::new(core.agg_calls.clone(), local_group_key, project.into()),
161            Some(vnode_col_idx),
162        )?;
163        // Global group key excludes vnode.
164        let local_agg_group_key_cardinality = local_agg.group_key().len();
165        let local_group_key_without_vnode =
166            &local_agg.group_key().to_vec()[..local_agg_group_key_cardinality - 1];
167        let global_group_key = local_agg
168            .i2o_col_mapping()
169            .rewrite_dist_key(local_group_key_without_vnode)
170            .expect("some input group key could not be mapped");
171
172        // Generate global agg step
173        let global_agg = if self.group_key().is_empty() {
174            let exchange =
175                RequiredDist::single().streaming_enforce_if_not_satisfies(local_agg.into())?;
176            let must_output_per_barrier = need_row_merge;
177            let global_agg = new_stream_simple_agg(
178                Agg::new(
179                    core.agg_calls
180                        .iter()
181                        .enumerate()
182                        .map(|(partial_output_idx, agg_call)| {
183                            agg_call
184                                .partial_to_total_agg_call(n_local_group_key + partial_output_idx)
185                        })
186                        .collect(),
187                    global_group_key.into_iter().collect(),
188                    exchange,
189                ),
190                must_output_per_barrier,
191            )?;
192            global_agg.into()
193        } else {
194            // `RowMergeExec` does not support keyed merge yet.
195            assert!(!need_row_merge);
196            let exchange = RequiredDist::shard_by_key(input_col_num, &global_group_key)
197                .streaming_enforce_if_not_satisfies(local_agg.into())?;
198            // The local phase should have reordered the group keys into their required order,
199            // so we can just follow it.
200            let global_agg = new_stream_hash_agg(
201                Agg::new(
202                    core.agg_calls
203                        .iter()
204                        .enumerate()
205                        .map(|(partial_output_idx, agg_call)| {
206                            agg_call
207                                .partial_to_total_agg_call(n_local_group_key + partial_output_idx)
208                        })
209                        .collect(),
210                    global_group_key.into_iter().collect(),
211                    exchange,
212                ),
213                None,
214            )?;
215            global_agg.into()
216        };
217        Self::add_row_merge_if_needed(
218            approx_percentile,
219            global_agg,
220            approx_percentile_col_mapping,
221            non_approx_percentile_col_mapping,
222        )
223    }
224
225    fn gen_single_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
226        let input = RequiredDist::single().streaming_enforce_if_not_satisfies(stream_input)?;
227        let core = self.core.clone_with_input(input);
228        Ok(new_stream_simple_agg(core, false)?.into())
229    }
230
231    fn gen_shuffle_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
232        let input =
233            RequiredDist::shard_by_key(stream_input.schema().len(), &self.group_key().to_vec())
234                .streaming_enforce_if_not_satisfies(stream_input)?;
235        let core = self.core.clone_with_input(input);
236        Ok(new_stream_hash_agg(core, None)?.into())
237    }
238
239    /// Generates the distributed stream plan.
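    /// Roughly, the strategy is chosen as follows (see the body for the exact conditions):
    /// - shuffle agg, if there is a group key and two-phase agg is not forced;
    /// - single agg, if there is no group key and two-phase agg is not possible;
    /// - stateless two-phase agg, if there is no group key, all local aggs are stateless,
    ///   and the input is already sharded;
    /// - vnode-based two-phase agg, if the input is hash-distributed (possibly after an
    ///   extra shuffle on the stream key);
    /// - otherwise, fall back to shuffle agg or single agg.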
240    fn gen_dist_stream_agg_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
241        use super::stream::prelude::*;
242
243        let input_dist = stream_input.distribution();
244        debug_assert!(*input_dist != Distribution::Broadcast);
245
246        // Shuffle agg
247        // If we have a group key and we won't try the two-phase agg optimization at all,
248        // we always choose shuffle agg over single agg.
249        if !self.group_key().is_empty() && !self.core.must_try_two_phase_agg() {
250            return self.gen_shuffle_plan(stream_input);
251        }
252
253        // Standalone agg
254        // If there is no group key and two-phase agg is not possible, we have to use the single plan.
255        if self.group_key().is_empty() && !self.core.can_two_phase_agg() {
256            return self.gen_single_plan(stream_input);
257        }
258
259        debug_assert!(if !self.group_key().is_empty() {
260            self.core.must_try_two_phase_agg()
261        } else {
262            self.core.can_two_phase_agg()
263        });
264
265        // Stateless 2-phase simple agg
266        // can be applied to stateless simple agg calls,
267        // with the input distributed by [`Distribution::AnyShard`].
268        if self.group_key().is_empty()
269            && self
270                .core
271                .all_local_aggs_are_stateless(stream_input.append_only())
272            && input_dist.satisfies(&RequiredDist::AnyShard)
273        {
274            return self.gen_stateless_two_phase_streaming_agg_plan(stream_input);
275        }
276
277        // If the input is [`Distribution::SomeShard`] and we must try two-phase agg,
278        // the only remaining strategy is vnode-based 2-phase agg.
279        // We first re-distribute the input by its stream key,
280        // so it obeys the consistent hash strategy via [`Distribution::HashShard`].
281        let stream_input =
282            if *input_dist == Distribution::SomeShard && self.core.must_try_two_phase_agg() {
283                RequiredDist::shard_by_key(
284                    stream_input.schema().len(),
285                    stream_input.expect_stream_key(),
286                )
287                .streaming_enforce_if_not_satisfies(stream_input)?
288            } else {
289                stream_input
290            };
291        let input_dist = stream_input.distribution();
292
293        // Vnode-based 2-phase agg
294        // can be applied to agg calls not affected by order,
295        // with the input distributed by dist_key.
296        match input_dist {
297            Distribution::HashShard(dist_key) | Distribution::UpstreamHashShard(dist_key, _)
298                if (self.group_key().is_empty()
299                    || !self.core.hash_agg_dist_satisfied_by_input_dist(input_dist)) =>
300            {
301                let dist_key = dist_key.clone();
302                return self.gen_vnode_two_phase_streaming_agg_plan(stream_input, &dist_key);
303            }
304            _ => {}
305        }
306
307        // Fall back to shuffle or single agg if we can't generate any 2-phase plan.
308        if !self.group_key().is_empty() {
309            self.gen_shuffle_plan(stream_input)
310        } else {
311            self.gen_single_plan(stream_input)
312        }
313    }
314
315    /// Prepares metadata and the `approx_percentile` plan, if one is present.
316    /// It may modify `core.agg_calls` to separate the normal aggs from the approx percentile aggs,
317    /// and `core.input` to share the input via `StreamShare`
318    /// between the approx percentile aggs and the normal aggs.
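    /// When both normal aggs and approx percentile aggs are present, the result is roughly
    /// (sketch only):
    ///
    /// ```text
    /// StreamRowMerge(approx_percentile_plan, normal_agg_plan)
    ///   where both sides read from StreamShare(stream_input)
    /// ```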
319    fn prepare_approx_percentile(
320        &self,
321        stream_input: StreamPlanRef,
322    ) -> Result<(
323        ColIndexMapping,
324        ColIndexMapping,
325        Option<StreamPlanRef>,
326        Agg<StreamPlanRef>,
327    )> {
328        let SeparatedAggInfo { normal, approx } = self.separate_normal_and_special_agg();
329
330        let AggInfo {
331            calls: non_approx_percentile_agg_calls,
332            col_mapping: non_approx_percentile_col_mapping,
333        } = normal;
334        let AggInfo {
335            calls: approx_percentile_agg_calls,
336            col_mapping: approx_percentile_col_mapping,
337        } = approx;
338        if !self.group_key().is_empty() && !approx_percentile_agg_calls.is_empty() {
339            bail_not_implemented!(
340                "two-phase streaming approx percentile aggregation with group key, \
341             please use single phase aggregation instead"
342            );
343        }
344
345        // Either we have both approx percentile aggs and non-approx-percentile aggs,
346        // or we have at least 2 approx percentile aggs.
347        let needs_row_merge = (!non_approx_percentile_agg_calls.is_empty()
348            && !approx_percentile_agg_calls.is_empty())
349            || approx_percentile_agg_calls.len() >= 2;
350        let input = if needs_row_merge {
351            // If there's row merge, we need to share the input.
352            StreamShare::new_from_input(stream_input).into()
353        } else {
354            stream_input
355        };
356        let mut core = self.core.clone_with_input(input);
357        core.agg_calls = non_approx_percentile_agg_calls;
358
359        let approx_percentile =
360            self.build_approx_percentile_aggs(core.input.clone(), &approx_percentile_agg_calls)?;
361        Ok((
362            non_approx_percentile_col_mapping,
363            approx_percentile_col_mapping,
364            approx_percentile,
365            core,
366        ))
367    }
368
369    fn need_row_merge(approx_percentile: &Option<StreamPlanRef>) -> bool {
370        approx_percentile.is_some()
371    }
372
373    /// Add `RowMerge` if needed
374    fn add_row_merge_if_needed(
375        approx_percentile: Option<StreamPlanRef>,
376        global_agg: StreamPlanRef,
377        approx_percentile_col_mapping: ColIndexMapping,
378        non_approx_percentile_col_mapping: ColIndexMapping,
379    ) -> Result<StreamPlanRef> {
380        // just for assert
381        let need_row_merge = Self::need_row_merge(&approx_percentile);
382
383        if let Some(approx_percentile) = approx_percentile {
384            assert!(need_row_merge);
385            let row_merge = StreamRowMerge::new(
386                approx_percentile,
387                global_agg,
388                approx_percentile_col_mapping,
389                non_approx_percentile_col_mapping,
390            )?;
391            Ok(row_merge.into())
392        } else {
393            assert!(!need_row_merge);
394            Ok(global_agg)
395        }
396    }
397
398    fn separate_normal_and_special_agg(&self) -> SeparatedAggInfo {
399        let estimated_len = self.agg_calls().len() - 1;
400        let mut approx_percentile_agg_calls = Vec::with_capacity(estimated_len);
401        let mut non_approx_percentile_agg_calls = Vec::with_capacity(estimated_len);
402        let mut approx_percentile_col_mapping = Vec::with_capacity(estimated_len);
403        let mut non_approx_percentile_col_mapping = Vec::with_capacity(estimated_len);
404        for (output_idx, agg_call) in self.agg_calls().iter().enumerate() {
405            if agg_call.agg_type == AggType::Builtin(PbAggKind::ApproxPercentile) {
406                approx_percentile_agg_calls.push(agg_call.clone());
407                approx_percentile_col_mapping.push(Some(output_idx));
408            } else {
409                non_approx_percentile_agg_calls.push(agg_call.clone());
410                non_approx_percentile_col_mapping.push(Some(output_idx));
411            }
412        }
413        let normal = AggInfo {
414            calls: non_approx_percentile_agg_calls,
415            col_mapping: ColIndexMapping::new(
416                non_approx_percentile_col_mapping,
417                self.agg_calls().len(),
418            ),
419        };
420        let approx = AggInfo {
421            calls: approx_percentile_agg_calls,
422            col_mapping: ColIndexMapping::new(
423                approx_percentile_col_mapping,
424                self.agg_calls().len(),
425            ),
426        };
427        SeparatedAggInfo { normal, approx }
428    }
429
430    fn build_approx_percentile_agg(
431        &self,
432        input: StreamPlanRef,
433        approx_percentile_agg_call: &PlanAggCall,
434    ) -> Result<StreamPlanRef> {
435        let local_approx_percentile =
436            StreamLocalApproxPercentile::new(input, approx_percentile_agg_call)?;
437        let exchange = RequiredDist::single()
438            .streaming_enforce_if_not_satisfies(local_approx_percentile.into())?;
439        let global_approx_percentile =
440            StreamGlobalApproxPercentile::new(exchange, approx_percentile_agg_call);
441        Ok(global_approx_percentile.into())
442    }
443
444    /// If there is only one approx percentile agg, just return it.
445    /// Otherwise build a left-deep tree of approx percentile aggs combined with `RowMerge`.
446    /// e.g.
447    /// ApproxPercentile(col1, 0.5) as x,
448    /// ApproxPercentile(col2, 0.5) as y,
449    /// ApproxPercentile(col3, 0.5) as z
450    /// will be built as
451    ///        `RowMerge`
452    ///       /          \
453    ///  `RowMerge`       z
454    ///  /        \
455    /// x          y
456    fn build_approx_percentile_aggs(
457        &self,
458        input: StreamPlanRef,
459        approx_percentile_agg_call: &[PlanAggCall],
460    ) -> Result<Option<StreamPlanRef>> {
461        if approx_percentile_agg_call.is_empty() {
462            return Ok(None);
463        }
464        let approx_percentile_plans: Vec<_> = approx_percentile_agg_call
465            .iter()
466            .map(|agg_call| self.build_approx_percentile_agg(input.clone(), agg_call))
467            .try_collect()?;
468        assert!(!approx_percentile_plans.is_empty());
469        let mut iter = approx_percentile_plans.into_iter();
470        let mut acc = iter.next().unwrap();
471        for (current_size, plan) in iter.enumerate().map(|(i, p)| (i + 1, p)) {
472            let new_size = current_size + 1;
473            let row_merge = StreamRowMerge::new(
474                acc,
475                plan,
476                ColIndexMapping::identity_or_none(current_size, new_size),
477                ColIndexMapping::new(vec![Some(current_size)], new_size),
478            )?;
479            acc = row_merge.into();
480        }
481        Ok(Some(acc))
482    }
483
484    pub fn core(&self) -> &Agg<PlanRef> {
485        &self.core
486    }
487}
488
489/// `LogicalAggBuilder` extracts agg calls and references to group columns from the select list and
490/// builds a plan like `LogicalAgg - LogicalProject`.
491/// It is constructed from the `group_exprs` and collects and rewrites the expressions in the
492/// select and having clauses.
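/// A conceptual example (column names are illustrative): for
/// `SELECT v1, min(v2) + max(v3) FROM t GROUP BY v1 HAVING min(v2) > 0`
/// the builder produces roughly
///
/// ```text
/// LogicalAgg { group_key: [v1], agg_calls: [min(v2), max(v3)] }
///   LogicalProject { exprs: [v1, v2, v3] }
///     (input)
/// ```
///
/// and the select exprs are rewritten to `[$0, $1 + $2]` and the having clause to `$1 > 0`,
/// where `$i` refers to the i-th output column of the `LogicalAgg`.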
493pub struct LogicalAggBuilder {
494    /// the builder of the input Project
495    input_proj_builder: ProjectBuilder,
496    /// the group key column indices in the project's output
497    group_key: IndexSet,
498    /// the grouping sets
499    grouping_sets: Vec<IndexSet>,
500    /// the agg calls
501    agg_calls: Vec<PlanAggCall>,
502    /// the error during the expression rewriting
503    error: Option<RwError>,
504    /// If `is_in_filter_clause` is true, it means that
505    /// we are processing a filter clause.
506    /// This field is needed because input refs in filter clauses
507    /// are allowed to refer to any columns, while those not in a filter
508    /// clause are only allowed to refer to group keys.
509    is_in_filter_clause: bool,
510}
511
512impl LogicalAggBuilder {
513    fn new(group_by: GroupBy, input_schema_len: usize) -> Result<Self> {
514        let mut input_proj_builder = ProjectBuilder::default();
515
516        let mut gen_group_key_and_grouping_sets =
517            |grouping_sets: Vec<Vec<ExprImpl>>| -> Result<(IndexSet, Vec<IndexSet>)> {
518                let grouping_sets: Vec<IndexSet> = grouping_sets
519                    .into_iter()
520                    .map(|set| {
521                        set.into_iter()
522                            .map(|expr| input_proj_builder.add_expr(&expr))
523                            .try_collect()
524                            .map_err(|err| not_implemented!("{err} inside GROUP BY"))
525                    })
526                    .try_collect()?;
527
528                // Construct group key based on grouping sets.
529                let group_key = grouping_sets
530                    .iter()
531                    .fold(FixedBitSet::with_capacity(input_schema_len), |acc, x| {
532                        acc.union(&x.to_bitset()).collect()
533                    });
534
535                Ok((IndexSet::from_iter(group_key.ones()), grouping_sets))
536            };
537
538        let (group_key, grouping_sets) = match group_by {
539            GroupBy::GroupKey(group_key) => {
540                let group_key = group_key
541                    .into_iter()
542                    .map(|expr| input_proj_builder.add_expr(&expr))
543                    .try_collect()
544                    .map_err(|err| not_implemented!("{err} inside GROUP BY"))?;
545                (group_key, vec![])
546            }
547            GroupBy::GroupingSets(grouping_sets) => gen_group_key_and_grouping_sets(grouping_sets)?,
548            GroupBy::Rollup(rollup) => {
549                // Convert rollup to grouping sets.
550                let grouping_sets = (0..=rollup.len())
551                    .map(|n| {
552                        rollup
553                            .iter()
554                            .take(n)
555                            .flat_map(|x| x.iter().cloned())
556                            .collect_vec()
557                    })
558                    .collect_vec();
559                gen_group_key_and_grouping_sets(grouping_sets)?
560            }
561            GroupBy::Cube(cube) => {
562                // Convert cube to grouping sets.
563                let grouping_sets = cube
564                    .into_iter()
565                    .powerset()
566                    .map(|x| x.into_iter().flatten().collect_vec())
567                    .collect_vec();
568                gen_group_key_and_grouping_sets(grouping_sets)?
569            }
570        };
571
572        Ok(LogicalAggBuilder {
573            group_key,
574            grouping_sets,
575            agg_calls: vec![],
576            error: None,
577            input_proj_builder,
578            is_in_filter_clause: false,
579        })
580    }
581
582    pub fn build(self, input: PlanRef) -> LogicalAgg {
583        // This LogicalProject focuses on the exprs in aggregates and GROUP BY clause.
584        let logical_project = LogicalProject::with_core(self.input_proj_builder.build(input));
585
586        // This LogicalAgg focuses on calculating the aggregates and grouping.
587        Agg::new(self.agg_calls, self.group_key, logical_project.into())
588            .with_grouping_sets(self.grouping_sets)
589            .into()
590    }
591
592    fn rewrite_with_error(&mut self, expr: ExprImpl) -> Result<ExprImpl> {
593        let rewritten_expr = self.rewrite_expr(expr);
594        if let Some(error) = self.error.take() {
595            return Err(error);
596        }
597        Ok(rewritten_expr)
598    }
599
600    /// Check if the expression is a GROUP BY key, and if so return its index in the group key.
601    pub fn try_as_group_expr(&self, expr: &ExprImpl) -> Option<usize> {
602        if let Some(input_index) = self.input_proj_builder.expr_index(expr)
603            && let Some(index) = self
604                .group_key
605                .indices()
606                .position(|group_key| group_key == input_index)
607        {
608            return Some(index);
609        }
610        None
611    }
612
613    fn schema_agg_start_offset(&self) -> usize {
614        self.group_key.len()
615    }
616
617    /// Rewrite [`AggCall`] if needed, and push it into the builder using `push_agg_call`.
618    /// This is shared by [`LogicalAggBuilder`] and `LogicalOverWindowBuilder`.
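    /// In particular, the match arms below rewrite:
    /// - `avg` into `sum / count`;
    /// - `var_pop` / `var_samp` / `stddev_pop` / `stddev_samp` into expressions over `sum`,
    ///   `sum` of squares and `count`;
    /// - `approx_percentile` with a DESC ordering into its `1.0 - percentile` ASC equivalent;
    /// - `arg_min` / `arg_max` into `first_value` ordered by the comparison argument.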
619    pub(crate) fn general_rewrite_agg_call(
620        agg_call: AggCall,
621        mut push_agg_call: impl FnMut(AggCall) -> Result<InputRef>,
622    ) -> Result<ExprImpl> {
623        match agg_call.agg_type {
624            // Rewrite avg to cast(sum as avg_return_type) / count.
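            // e.g. for input values [1, 2, 2]: sum = 5, count = 3, and the rewritten
            // expression evaluates to 5 / 3 after casting the sum to the avg return type.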
625            AggType::Builtin(PbAggKind::Avg) => {
626                assert_eq!(agg_call.args.len(), 1);
627
628                let sum = ExprImpl::from(push_agg_call(AggCall::new(
629                    PbAggKind::Sum.into(),
630                    agg_call.args.clone(),
631                    agg_call.distinct,
632                    agg_call.order_by.clone(),
633                    agg_call.filter.clone(),
634                    agg_call.direct_args.clone(),
635                )?)?)
636                .cast_explicit(&agg_call.return_type())?;
637
638                let count = ExprImpl::from(push_agg_call(AggCall::new(
639                    PbAggKind::Count.into(),
640                    agg_call.args.clone(),
641                    agg_call.distinct,
642                    agg_call.order_by.clone(),
643                    agg_call.filter.clone(),
644                    agg_call.direct_args,
645                )?)?);
646
647                Ok(FunctionCall::new(ExprType::Divide, Vec::from([sum, count]))?.into())
648            }
649            // We compute `var_samp` as
650            //     (sum(sq) - sum * sum / count) / (count - 1)
651            // and `var_pop` as
652            //     (sum(sq) - sum * sum / count) / count
653            // Since we don't have a dedicated square function, we use plain Multiply for squaring,
654            // which is in a sense more general than a pow function, especially when calculating
655            // covariances in the future. For the stddev variants, the result is additionally
656            // passed through `Sqrt` below.
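            // A worked example (not part of the rewrite itself): for input values [1, 2, 3],
            // sum(sq) = 14, sum = 6 and count = 3, so
            //     var_pop  = (14 - 6 * 6 / 3) / 3       = 2 / 3
            //     var_samp = (14 - 6 * 6 / 3) / (3 - 1) = 1
            // and the stddev variants are the square roots of these.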
657            AggType::Builtin(
658                kind @ (PbAggKind::StddevPop
659                | PbAggKind::StddevSamp
660                | PbAggKind::VarPop
661                | PbAggKind::VarSamp),
662            ) => {
663                let arg = agg_call.args().iter().exactly_one().unwrap();
664                let squared_arg = ExprImpl::from(FunctionCall::new(
665                    ExprType::Multiply,
666                    vec![arg.clone(), arg.clone()],
667                )?);
668
669                let sum_of_sq = ExprImpl::from(push_agg_call(AggCall::new(
670                    PbAggKind::Sum.into(),
671                    vec![squared_arg],
672                    agg_call.distinct,
673                    agg_call.order_by.clone(),
674                    agg_call.filter.clone(),
675                    agg_call.direct_args.clone(),
676                )?)?)
677                .cast_explicit(&agg_call.return_type())?;
678
679                let sum = ExprImpl::from(push_agg_call(AggCall::new(
680                    PbAggKind::Sum.into(),
681                    agg_call.args.clone(),
682                    agg_call.distinct,
683                    agg_call.order_by.clone(),
684                    agg_call.filter.clone(),
685                    agg_call.direct_args.clone(),
686                )?)?)
687                .cast_explicit(&agg_call.return_type())?;
688
689                let count = ExprImpl::from(push_agg_call(AggCall::new(
690                    PbAggKind::Count.into(),
691                    agg_call.args.clone(),
692                    agg_call.distinct,
693                    agg_call.order_by.clone(),
694                    agg_call.filter.clone(),
695                    agg_call.direct_args.clone(),
696                )?)?);
697
698                let zero = ExprImpl::literal_int(0);
699                let one = ExprImpl::literal_int(1);
700
701                let squared_sum = ExprImpl::from(FunctionCall::new(
702                    ExprType::Multiply,
703                    vec![sum.clone(), sum],
704                )?);
705
706                let raw_numerator = ExprImpl::from(FunctionCall::new(
707                    ExprType::Subtract,
708                    vec![
709                        sum_of_sq,
710                        ExprImpl::from(FunctionCall::new(
711                            ExprType::Divide,
712                            vec![squared_sum, count.clone()],
713                        )?),
714                    ],
715                )?);
716
717                // Floating-point error may occasionally make the numerator slightly negative, so clamp it to zero.
718                let numerator_type = raw_numerator.return_type();
719                let numerator = ExprImpl::from(FunctionCall::new(
720                    ExprType::Greatest,
721                    vec![raw_numerator, zero.clone().cast_explicit(&numerator_type)?],
722                )?);
723
724                let denominator = match kind {
725                    PbAggKind::VarPop | PbAggKind::StddevPop => count.clone(),
726                    PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from(
727                        FunctionCall::new(ExprType::Subtract, vec![count.clone(), one.clone()])?,
728                    ),
729                    _ => unreachable!(),
730                };
731
732                let mut target = ExprImpl::from(FunctionCall::new(
733                    ExprType::Divide,
734                    vec![numerator, denominator],
735                )?);
736
737                if matches!(kind, PbAggKind::StddevPop | PbAggKind::StddevSamp) {
738                    target = ExprImpl::from(FunctionCall::new(ExprType::Sqrt, vec![target])?);
739                }
740
741                let null = ExprImpl::from(Literal::new(None, agg_call.return_type()));
742                let case_cond = match kind {
743                    PbAggKind::VarPop | PbAggKind::StddevPop => {
744                        ExprImpl::from(FunctionCall::new(ExprType::Equal, vec![count, zero])?)
745                    }
746                    PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from(
747                        FunctionCall::new(ExprType::LessThanOrEqual, vec![count, one])?,
748                    ),
749                    _ => unreachable!(),
750                };
751
752                Ok(ExprImpl::from(FunctionCall::new(
753                    ExprType::Case,
754                    vec![case_cond, null, target],
755                )?))
756            }
757            AggType::Builtin(PbAggKind::ApproxPercentile) => {
758                if agg_call.order_by.sort_exprs[0].order_type == OrderType::descending() {
759                    // Rewrite a DESC ordering into the equivalent (1.0 - percentile) ASC form for approx_percentile.
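                    // e.g. the 0.9 percentile over a descending ordering corresponds to the
                    // 0.1 percentile over the ascending ordering.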
760                    let prev_percentile = agg_call.direct_args[0].clone();
761                    let new_percentile = 1.0
762                        - prev_percentile
763                            .get_data()
764                            .as_ref()
765                            .unwrap()
766                            .as_float64()
767                            .into_inner();
768                    let new_percentile = Some(ScalarImpl::Float64(new_percentile.into()));
769                    let new_percentile = Literal::new(new_percentile, DataType::Float64);
770                    let new_direct_args = vec![new_percentile, agg_call.direct_args[1].clone()];
771
772                    let new_agg_call = AggCall {
773                        order_by: OrderBy::any(),
774                        direct_args: new_direct_args,
775                        ..agg_call
776                    };
777                    Ok(push_agg_call(new_agg_call)?.into())
778                } else {
779                    let new_agg_call = AggCall {
780                        order_by: OrderBy::any(),
781                        ..agg_call
782                    };
783                    Ok(push_agg_call(new_agg_call)?.into())
784                }
785            }
786            AggType::Builtin(PbAggKind::ArgMin | PbAggKind::ArgMax) => {
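                // Rewrite `arg_min(v, c)` / `arg_max(v, c)` into `first_value(v)` ordered by
                // `c` ASC / DESC, with an `IS NOT NULL` filter on both arguments, as built below.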
787                let mut agg_call = agg_call;
788
789                let comparison_arg_type = agg_call.args[1].return_type();
790                match comparison_arg_type {
791                    DataType::Struct(_)
792                    | DataType::List(_)
793                    | DataType::Map(_)
794                    | DataType::Vector(_)
795                    | DataType::Jsonb => {
796                        bail!(format!(
797                            "{} does not support struct, array, map, vector, jsonb for comparison argument, got {}",
798                            agg_call.agg_type.to_string(),
799                            comparison_arg_type
800                        ));
801                    }
802                    _ => {}
803                }
804
805                let not_null_exprs: Vec<ExprImpl> = agg_call
806                    .args
807                    .iter()
808                    .map(|arg| -> Result<ExprImpl> {
809                        Ok(FunctionCall::new(ExprType::IsNotNull, vec![arg.clone()])?.into())
810                    })
811                    .try_collect()?;
812
813                let comparison_expr = agg_call.args[1].clone();
814                let mut order_exprs = vec![OrderByExpr {
815                    expr: comparison_expr,
816                    order_type: if agg_call.agg_type == AggType::Builtin(PbAggKind::ArgMin) {
817                        OrderType::ascending()
818                    } else {
819                        OrderType::descending()
820                    },
821                }];
822
823                order_exprs.extend(agg_call.order_by.sort_exprs);
824
825                let order_by = OrderBy::new(order_exprs);
826
827                let filter = agg_call.filter.clone().and(Condition {
828                    conjunctions: not_null_exprs,
829                });
830
831                agg_call.args.truncate(1);
832
833                let new_agg_call = AggCall {
834                    agg_type: AggType::Builtin(PbAggKind::FirstValue),
835                    order_by,
836                    filter,
837                    ..agg_call
838                };
839                Ok(push_agg_call(new_agg_call)?.into())
840            }
841            _ => Ok(push_agg_call(agg_call)?.into()),
842        }
843    }
844
845    /// Push a new agg call into the builder.
846    /// Return an `InputRef` to that agg call.
847    /// For existing agg calls, return an `InputRef` to the existing one.
848    fn push_agg_call(&mut self, agg_call: AggCall) -> Result<InputRef> {
849        let AggCall {
850            agg_type,
851            return_type,
852            args,
853            distinct,
854            order_by,
855            filter,
856            direct_args,
857        } = agg_call;
858
859        self.is_in_filter_clause = true;
860        // filter expr is not added to `input_proj_builder` as a whole. Special exprs incl
861        // subquery/agg/table are rejected in `bind_agg`.
862        let filter = filter.rewrite_expr(self);
863        self.is_in_filter_clause = false;
864
865        let args: Vec<_> = args
866            .iter()
867            .map(|expr| {
868                let index = self.input_proj_builder.add_expr(expr)?;
869                Ok(InputRef::new(index, expr.return_type()))
870            })
871            .try_collect()
872            .map_err(|err: &'static str| not_implemented!("{err} inside aggregation calls"))?;
873
874        let order_by: Vec<_> = order_by
875            .sort_exprs
876            .iter()
877            .map(|e| {
878                let index = self.input_proj_builder.add_expr(&e.expr)?;
879                Ok(ColumnOrder::new(index, e.order_type))
880            })
881            .try_collect()
882            .map_err(|err: &'static str| {
883                not_implemented!("{err} inside aggregation calls order by")
884            })?;
885
886        let plan_agg_call = PlanAggCall {
887            agg_type,
888            return_type: return_type.clone(),
889            inputs: args,
890            distinct,
891            order_by,
892            filter,
893            direct_args,
894        };
895
896        if let Some((pos, existing)) = self
897            .agg_calls
898            .iter()
899            .find_position(|&c| c == &plan_agg_call)
900        {
901            return Ok(InputRef::new(
902                self.schema_agg_start_offset() + pos,
903                existing.return_type.clone(),
904            ));
905        }
906        let index = self.schema_agg_start_offset() + self.agg_calls.len();
907        self.agg_calls.push(plan_agg_call);
908        Ok(InputRef::new(index, return_type))
909    }
910
911    /// When there is an agg call, there are 3 things to do:
912    /// 1. Rewrite `avg`, `var_samp`, etc. into a combination of `sum`, `count`, etc.;
913    /// 2. Add the exprs in its arguments to the input `Project`;
914    /// 3. Add the agg call to the current `Agg`, and return an `InputRef` to it.
915    ///
916    /// Note that the rewriter does not traverse into inputs of agg calls.
917    fn try_rewrite_agg_call(&mut self, mut agg_call: AggCall) -> Result<ExprImpl> {
918        if matches!(agg_call.agg_type, agg_types::must_have_order_by!())
919            && agg_call.order_by.sort_exprs.is_empty()
920        {
921            return Err(ErrorCode::InvalidInputSyntax(format!(
922                "Aggregation function {} requires ORDER BY clause",
923                agg_call.agg_type
924            ))
925            .into());
926        }
927
928        // try to ignore ORDER BY if it doesn't affect the result
929        if matches!(
930            agg_call.agg_type,
931            agg_types::result_unaffected_by_order_by!()
932        ) {
933            agg_call.order_by = OrderBy::any();
934        }
935        // try to ignore DISTINCT if it doesn't affect the result
936        if matches!(
937            agg_call.agg_type,
938            agg_types::result_unaffected_by_distinct!()
939        ) {
940            agg_call.distinct = false;
941        }
942
943        if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Grouping)) {
944            if self.grouping_sets.is_empty() {
945                return Err(ErrorCode::NotSupported(
946                    "GROUPING must be used in a query with grouping sets".into(),
947                    "try to use grouping sets instead".into(),
948                )
949                .into());
950            }
951            if agg_call.args.len() >= 32 {
952                return Err(ErrorCode::InvalidInputSyntax(
953                    "GROUPING must have fewer than 32 arguments".into(),
954                )
955                .into());
956            }
957            if agg_call
958                .args
959                .iter()
960                .any(|x| self.try_as_group_expr(x).is_none())
961            {
962                return Err(ErrorCode::InvalidInputSyntax(
963                    "arguments to GROUPING must be grouping expressions of the associated query level"
964                        .into(),
965                ).into());
966            }
967        }
968
969        Self::general_rewrite_agg_call(agg_call, |agg_call| self.push_agg_call(agg_call))
970    }
971}
972
973impl ExprRewriter for LogicalAggBuilder {
974    fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl {
975        let dummy = Literal::new(None, agg_call.return_type()).into();
976        match self.try_rewrite_agg_call(agg_call) {
977            Ok(expr) => expr,
978            Err(err) => {
979                self.error = Some(err);
980                dummy
981            }
982        }
983    }
984
985    /// When there is a `FunctionCall` (outside of an agg call), it must refer to a group column,
986    /// or all `InputRef`s appearing in it must refer to group columns.
987    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
988        let expr = func_call.into();
989        if let Some(group_key) = self.try_as_group_expr(&expr) {
990            InputRef::new(group_key, expr.return_type()).into()
991        } else {
992            let (func_type, inputs, ret) = expr.into_function_call().unwrap().decompose();
993            let inputs = inputs
994                .into_iter()
995                .map(|expr| self.rewrite_expr(expr))
996                .collect();
997            FunctionCall::new_unchecked(func_type, inputs, ret).into()
998        }
999    }
1000
1001    /// When there is a `WindowFunction` (outside of an agg call), it must refer to a group column,
1002    /// or all `InputRef`s appearing in it must refer to group columns.
1003    fn rewrite_window_function(&mut self, window_func: WindowFunction) -> ExprImpl {
1004        let WindowFunction {
1005            args,
1006            partition_by,
1007            order_by,
1008            ..
1009        } = window_func;
1010        let args = args
1011            .into_iter()
1012            .map(|expr| self.rewrite_expr(expr))
1013            .collect();
1014        let partition_by = partition_by
1015            .into_iter()
1016            .map(|expr| self.rewrite_expr(expr))
1017            .collect();
1018        let order_by = order_by.rewrite_expr(self);
1019        WindowFunction {
1020            args,
1021            partition_by,
1022            order_by,
1023            ..window_func
1024        }
1025        .into()
1026    }
1027
1028    /// When there is an `InputRef` (outside of an agg call), it must refer to a group column.
1029    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
1030        let expr = input_ref.into();
1031        if let Some(group_key) = self.try_as_group_expr(&expr) {
1032            InputRef::new(group_key, expr.return_type()).into()
1033        } else if self.is_in_filter_clause {
1034            InputRef::new(
1035                self.input_proj_builder.add_expr(&expr).unwrap(),
1036                expr.return_type(),
1037            )
1038            .into()
1039        } else {
1040            self.error = Some(
1041                ErrorCode::InvalidInputSyntax(
1042                    "column must appear in the GROUP BY clause or be used in an aggregate function"
1043                        .into(),
1044                )
1045                .into(),
1046            );
1047            expr
1048        }
1049    }
1050
1051    fn rewrite_subquery(&mut self, subquery: crate::expr::Subquery) -> ExprImpl {
1052        if subquery.is_correlated_by_depth(0) {
1053            self.error = Some(
1054                not_implemented!(
1055                    issue = 2275,
1056                    "correlated subquery in HAVING or SELECT with agg",
1057                )
1058                .into(),
1059            );
1060        }
1061        subquery.into()
1062    }
1063}
1064
1065impl From<Agg<PlanRef>> for LogicalAgg {
1066    fn from(core: Agg<PlanRef>) -> Self {
1067        let base = PlanBase::new_logical_with_core(&core);
1068        Self { base, core }
1069    }
1070}
1071
1072/// Because `From`/`Into` are not transitive
1073impl From<Agg<PlanRef>> for PlanRef {
1074    fn from(core: Agg<PlanRef>) -> Self {
1075        LogicalAgg::from(core).into()
1076    }
1077}
1078
1079impl LogicalAgg {
1080    /// Get the column index mapping from input column index to output column index.
1081    pub fn i2o_col_mapping(&self) -> ColIndexMapping {
1082        self.core.i2o_col_mapping()
1083    }
1084
1085    /// `create` will analyze select exprs, group exprs and having, and construct a plan like
1086    ///
1087    /// ```text
1088    /// LogicalAgg -> LogicalProject -> input
1089    /// ```
1090    ///
1091    /// It also returns the rewritten select exprs and having that reference into the aggregated
1092    /// results.
1093    pub fn create(
1094        select_exprs: Vec<ExprImpl>,
1095        group_by: GroupBy,
1096        having: Option<ExprImpl>,
1097        input: PlanRef,
1098    ) -> Result<(PlanRef, Vec<ExprImpl>, Option<ExprImpl>)> {
1099        let mut agg_builder = LogicalAggBuilder::new(group_by, input.schema().len())?;
1100
1101        let rewritten_select_exprs = select_exprs
1102            .into_iter()
1103            .map(|expr| agg_builder.rewrite_with_error(expr))
1104            .collect::<Result<_>>()?;
1105        let rewritten_having = having
1106            .map(|expr| agg_builder.rewrite_with_error(expr))
1107            .transpose()?;
1108
1109        Ok((
1110            agg_builder.build(input).into(),
1111            rewritten_select_exprs,
1112            rewritten_having,
1113        ))
1114    }
1115
1116    /// Get a reference to the logical agg's agg calls.
1117    pub fn agg_calls(&self) -> &Vec<PlanAggCall> {
1118        &self.core.agg_calls
1119    }
1120
1121    /// Get a reference to the logical agg's group key.
1122    pub fn group_key(&self) -> &IndexSet {
1123        &self.core.group_key
1124    }
1125
1126    pub fn grouping_sets(&self) -> &Vec<IndexSet> {
1127        &self.core.grouping_sets
1128    }
1129
1130    pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) {
1131        self.core.decompose()
1132    }
1133
1134    #[must_use]
1135    pub fn rewrite_with_input_agg(
1136        &self,
1137        input: PlanRef,
1138        agg_calls: &[PlanAggCall],
1139        mut input_col_change: ColIndexMapping,
1140    ) -> (Self, ColIndexMapping) {
1141        let agg_calls = agg_calls
1142            .iter()
1143            .cloned()
1144            .map(|mut agg_call| {
1145                agg_call.inputs.iter_mut().for_each(|i| {
1146                    *i = InputRef::new(input_col_change.map(i.index()), i.return_type())
1147                });
1148                agg_call.order_by.iter_mut().for_each(|o| {
1149                    o.column_index = input_col_change.map(o.column_index);
1150                });
1151                agg_call.filter = agg_call.filter.rewrite_expr(&mut input_col_change);
1152                agg_call
1153            })
1154            .collect();
1155        // This is what the group key order should be after rewriting.
1156        let group_key_in_vec: Vec<usize> = self
1157            .group_key()
1158            .indices()
1159            .map(|key| input_col_change.map(key))
1160            .collect();
1161        // This is the group key order we get after rewriting.
1162        let group_key: IndexSet = group_key_in_vec.iter().cloned().collect();
1163        let grouping_sets = self
1164            .grouping_sets()
1165            .iter()
1166            .map(|set| set.indices().map(|key| input_col_change.map(key)).collect())
1167            .collect();
1168
1169        let new_agg = Agg::new(agg_calls, group_key.clone(), input)
1170            .with_grouping_sets(grouping_sets)
1171            .with_enable_two_phase(self.core().enable_two_phase);
1172
1173        // group_key remapping might cause an output column change, since group key actually is a
1174        // `FixedBitSet`.
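        // e.g. if the original group key is {0, 2} and `input_col_change` maps 0 -> 3 and 2 -> 1,
        // then `group_key_in_vec` is [3, 1] while the new `IndexSet` orders it as {1, 3}, so the
        // first two output columns swap places and must be remapped below.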
1175        let mut out_col_change = vec![];
1176        for idx in group_key_in_vec {
1177            let pos = group_key.indices().position(|x| x == idx).unwrap();
1178            out_col_change.push(pos);
1179        }
1180        for i in (group_key.len())..new_agg.schema().len() {
1181            out_col_change.push(i);
1182        }
1183        let out_col_change =
1184            ColIndexMapping::with_remaining_columns(&out_col_change, new_agg.schema().len());
1185
1186        (new_agg.into(), out_col_change)
1187    }
1188}
1189
1190impl PlanTreeNodeUnary<Logical> for LogicalAgg {
1191    fn input(&self) -> PlanRef {
1192        self.core.input.clone()
1193    }
1194
1195    fn clone_with_input(&self, input: PlanRef) -> Self {
1196        Agg::new(self.agg_calls().clone(), self.group_key().clone(), input)
1197            .with_grouping_sets(self.grouping_sets().clone())
1198            .with_enable_two_phase(self.core().enable_two_phase)
1199            .into()
1200    }
1201
1202    fn rewrite_with_input(
1203        &self,
1204        input: PlanRef,
1205        input_col_change: ColIndexMapping,
1206    ) -> (Self, ColIndexMapping) {
1207        self.rewrite_with_input_agg(input, self.agg_calls(), input_col_change)
1208    }
1209}
1210
1211impl_plan_tree_node_for_unary! { Logical, LogicalAgg }
1212impl_distill_by_unit!(LogicalAgg, core, "LogicalAgg");
1213
1214impl ExprRewritable<Logical> for LogicalAgg {
1215    fn has_rewritable_expr(&self) -> bool {
1216        true
1217    }
1218
1219    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
1220        let mut core = self.core.clone();
1221        core.rewrite_exprs(r);
1222        Self {
1223            base: self.base.clone_with_new_plan_id(),
1224            core,
1225        }
1226        .into()
1227    }
1228}
1229
1230impl ExprVisitable for LogicalAgg {
1231    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
1232        self.core.visit_exprs(v);
1233    }
1234}
1235
1236impl ColPrunable for LogicalAgg {
1237    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
1238        let group_key_required_cols = self.group_key().to_bitset();
1239
1240        let (agg_call_required_cols, agg_calls) = {
1241            let input_cnt = self.input().schema().len();
1242            let mut tmp = FixedBitSet::with_capacity(input_cnt);
1243            let group_key_cardinality = self.group_key().len();
1244            let new_agg_calls = required_cols
1245                .iter()
1246                .filter(|&&index| index >= group_key_cardinality)
1247                .map(|&index| {
1248                    let index = index - group_key_cardinality;
1249                    let agg_call = self.agg_calls()[index].clone();
1250                    tmp.extend(agg_call.inputs.iter().map(|x| x.index()));
1251                    tmp.extend(agg_call.order_by.iter().map(|x| x.column_index));
1252                    // collect columns used in aggregate filter expressions
1253                    for i in &agg_call.filter.conjunctions {
1254                        tmp.union_with(&i.collect_input_refs(input_cnt));
1255                    }
1256                    agg_call
1257                })
1258                .collect_vec();
1259            (tmp, new_agg_calls)
1260        };
1261
1262        let input_required_cols = {
1263            let mut tmp = FixedBitSet::with_capacity(self.input().schema().len());
1264            tmp.union_with(&group_key_required_cols);
1265            tmp.union_with(&agg_call_required_cols);
1266            tmp.ones().collect_vec()
1267        };
1268        let input_col_change = ColIndexMapping::with_remaining_columns(
1269            &input_required_cols,
1270            self.input().schema().len(),
1271        );
1272        let agg = {
1273            let input = self.input().prune_col(&input_required_cols, ctx);
1274            let (agg, output_col_change) =
1275                self.rewrite_with_input_agg(input, &agg_calls, input_col_change);
1276            assert!(output_col_change.is_identity());
1277            agg
1278        };
1279        let new_output_cols = {
1280            // group keys are never pruned or even re-ordered in the current impl
1281            let group_key_cardinality = agg.group_key().len();
1282            let mut tmp = (0..group_key_cardinality).collect_vec();
1283            tmp.extend(
1284                required_cols
1285                    .iter()
1286                    .filter(|&&index| index >= group_key_cardinality),
1287            );
1288            tmp
1289        };
1290        if new_output_cols == required_cols {
1291            // the current schema perfectly fits the required columns
1292            agg.into()
1293        } else {
1294            // some columns are not needed, or the order needs to be adjusted,
1295            // so we add a projection to remove/reorder the columns.
1296            let mapping =
1297                &ColIndexMapping::with_remaining_columns(&new_output_cols, self.schema().len());
1298            let output_required_cols = required_cols
1299                .iter()
1300                .map(|&idx| mapping.map(idx))
1301                .collect_vec();
1302            let src_size = agg.schema().len();
1303            LogicalProject::with_mapping(
1304                agg.into(),
1305                ColIndexMapping::with_remaining_columns(&output_required_cols, src_size),
1306            )
1307            .into()
1308        }
1309    }
1310}
1311
1312impl PredicatePushdown for LogicalAgg {
1313    fn predicate_pushdown(
1314        &self,
1315        predicate: Condition,
1316        ctx: &mut PredicatePushdownContext,
1317    ) -> PlanRef {
1318        let num_group_key = self.group_key().len();
1319        let num_agg_calls = self.agg_calls().len();
1320        assert!(num_group_key + num_agg_calls == self.schema().len());
1321
1322        // SimpleAgg should be skipped because the predicate either references agg_calls
1323        // or is const.
1324        // If the filter references agg_calls, we can not push it.
1325        // When it is constantly true, pushing is useless and may actually cause more evaluation
1326        // cost of the predicate.
1327        // When it is constantly false, pushing is wrong: the old plan returns 0 rows but the new
1328        // one returns 1 row.
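        // e.g. for `SELECT count(*) FROM t HAVING false`: keeping the filter above the agg yields
        // 0 rows, while pushing it below would leave the simple agg emitting one row with count = 0.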
1329        if num_group_key == 0 {
1330            return gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx);
1331        }
1332
1333        // If the filter references agg_calls, we can not push it.
1334        let mut agg_call_columns = FixedBitSet::with_capacity(num_group_key + num_agg_calls);
1335        agg_call_columns.insert_range(num_group_key..num_group_key + num_agg_calls);
1336        let (agg_call_pred, pushed_predicate) = predicate.split_disjoint(&agg_call_columns);
1337
1338        // convert the predicate to one that references the child of the agg
1339        let mut subst = Substitute {
1340            mapping: self
1341                .group_key()
1342                .indices()
1343                .enumerate()
1344                .map(|(i, group_key)| {
1345                    InputRef::new(group_key, self.schema().fields()[i].data_type()).into()
1346                })
1347                .collect(),
1348        };
1349        let pushed_predicate = pushed_predicate.rewrite_expr(&mut subst);
1350
1351        gen_filter_and_pushdown(self, agg_call_pred, pushed_predicate, ctx)
1352    }
1353}
1354
1355impl ToBatch for LogicalAgg {
1356    fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1357        self.to_batch_with_order_required(&Order::any())
1358    }
1359
1360    // TODO(rc): `to_batch_with_order_required` seems to be useless now that we only use
1361    // `BatchSortAgg` when the input is already sorted
1362    fn to_batch_with_order_required(
1363        &self,
1364        required_order: &Order,
1365    ) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1366        let input = self.input().to_batch()?;
1367        let new_logical = self.core.clone_with_input(input);
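        // Pick the batch agg implementation: simple agg when there is no group key, sort agg when
        // the session enables it and the input already provides order on the group keys, and hash
        // agg otherwise.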
1368        let agg_plan = if self.group_key().is_empty() {
1369            BatchSimpleAgg::new(new_logical).into()
1370        } else if self.ctx().session_ctx().config().batch_enable_sort_agg()
1371            && new_logical.input_provides_order_on_group_keys()
1372        {
1373            BatchSortAgg::new(new_logical).into()
1374        } else {
1375            BatchHashAgg::new(new_logical).into()
1376        };
1377        required_order.enforce_if_not_satisfies(agg_plan)
1378    }
1379}
1380
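/// Find an existing `count(*)` among the agg calls, or append one, and return the updated agg
/// together with the index of the `count(*)` call.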
1381fn find_or_append_row_count(mut logical: Agg<StreamPlanRef>) -> (Agg<StreamPlanRef>, usize) {
1382    // `HashAgg`/`SimpleAgg` executors require a `count(*)` to correctly build changes, so
1383    // append one if it does not already exist.
1384    let count_star = PlanAggCall::count_star();
1385    let row_count_idx = if let Some((idx, _)) = logical
1386        .agg_calls
1387        .iter()
1388        .find_position(|&c| c == &count_star)
1389    {
1390        idx
1391    } else {
1392        let idx = logical.agg_calls.len();
1393        logical.agg_calls.push(count_star);
1394        idx
1395    };
1396    (logical, row_count_idx)
1397}
1398
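/// Build a `StreamSimpleAgg`, appending a `count(*)` as the row count column if one is not
/// already present.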
1399fn new_stream_simple_agg(
1400    core: Agg<StreamPlanRef>,
1401    must_output_per_barrier: bool,
1402) -> Result<StreamSimpleAgg> {
1403    let (logical, row_count_idx) = find_or_append_row_count(core);
1404    StreamSimpleAgg::new(logical, row_count_idx, must_output_per_barrier)
1405}
1406
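/// Build a `StreamHashAgg`, appending a `count(*)` as the row count column if one is not
/// already present.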
1407fn new_stream_hash_agg(
1408    core: Agg<StreamPlanRef>,
1409    vnode_col_idx: Option<usize>,
1410) -> Result<StreamHashAgg> {
1411    let (logical, row_count_idx) = find_or_append_row_count(core);
1412    StreamHashAgg::new(logical, vnode_col_idx, row_count_idx)
1413}
1414
1415impl ToStream for LogicalAgg {
1416    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<StreamPlanRef> {
1417        use super::stream::prelude::*;
1418
1419        let eowc = ctx.emit_on_window_close();
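        // When there is a group key, try to enforce the locality requirement on it for the input.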
1420        let input = if self.group_key().is_empty() {
1421            self.input()
1422        } else {
1423            try_enforce_locality_requirement(self.input(), &self.group_key().to_vec())
1424        };
1425
1426        let stream_input = input.to_stream(ctx)?;
1427
1428        // Use Dedup operator, if possible.
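        // A group-by with no aggregate calls over an append-only input is equivalent to a
        // DISTINCT on the group key.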
1429        if stream_input.append_only() && self.agg_calls().is_empty() && !self.group_key().is_empty()
1430        {
1431            let input = if self.group_key().len() != self.input().schema().len() {
1432                let cols = &self.group_key().to_vec();
1433                LogicalProject::with_mapping(
1434                    self.input(),
1435                    ColIndexMapping::with_remaining_columns(cols, self.input().schema().len()),
1436                )
1437                .into()
1438            } else {
1439                self.input()
1440            };
1441            let input_schema_len = input.schema().len();
1442            let logical_dedup = LogicalDedup::new(input, (0..input_schema_len).collect());
1443            return logical_dedup.to_stream(ctx);
1444        }
1445
1446        if self.agg_calls().iter().any(|call| {
1447            matches!(
1448                call.agg_type,
1449                AggType::Builtin(PbAggKind::ApproxCountDistinct)
1450            )
1451        }) {
1452            if stream_input.append_only() {
1453                self.core.ctx().session_ctx().notice_to_user(
1454                    "Streaming `APPROX_COUNT_DISTINCT` is still a preview feature and subject to change. Please do not use it in production environment.",
1455                );
1456            } else {
1457                bail_not_implemented!(
1458                    "Streaming `APPROX_COUNT_DISTINCT` is only supported in append-only stream"
1459                );
1460            }
1461        }
1462
1463        let plan = self.gen_dist_stream_agg_plan(stream_input)?;
1464
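        // Figure out how many agg calls the final operator outputs; one extra compared to `self`
        // means a `count(*)` was appended and must be projected away below.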
1465        let (plan, n_final_agg_calls) = if let Some(final_agg) = plan.as_stream_simple_agg() {
1466            if eowc {
1467                return Err(ErrorCode::InvalidInputSyntax(
1468                    "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1469                        .to_owned(),
1470                )
1471                .into());
1472            }
1473            (plan.clone(), final_agg.agg_calls().len())
1474        } else if let Some(final_agg) = plan.as_stream_hash_agg() {
1475            (
1476                if eowc {
1477                    final_agg.to_eowc_version()?
1478                } else {
1479                    plan.clone()
1480                },
1481                final_agg.agg_calls().len(),
1482            )
1483        } else if let Some(_approx_percentile_agg) = plan.as_stream_global_approx_percentile() {
1484            if eowc {
1485                return Err(ErrorCode::InvalidInputSyntax(
1486                    "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1487                        .to_owned(),
1488                )
1489                .into());
1490            }
1491            (plan.clone(), 1)
1492        } else if let Some(stream_row_merge) = plan.as_stream_row_merge() {
1493            if eowc {
1494                return Err(ErrorCode::InvalidInputSyntax(
1495                    "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1496                        .to_owned(),
1497                )
1498                .into());
1499            }
1500            (plan.clone(), stream_row_merge.base.schema().len())
1501        } else {
1502            panic!(
1503                "the root PlanNode must be StreamHashAgg, StreamSimpleAgg, StreamGlobalApproxPercentile, or StreamRowMerge"
1504            );
1505        };
1506
1507        if self.agg_calls().len() == n_final_agg_calls {
1508            // an existing `count(*)` is used as the row count column in `StreamXxxAgg`
1509            Ok(plan)
1510        } else {
1511            // a `count(*)` was appended, so project the output to drop it
1512            assert_eq!(self.agg_calls().len() + 1, n_final_agg_calls);
1513
1514            let mut project = StreamProject::new(generic::Project::with_out_col_idx(
1515                plan,
1516                0..self.schema().len(),
1517            ));
1518            // If there's no agg call, then `count(*)` will be the only column in the output besides keys.
1519            // Since it'll be pruned immediately in `StreamProject`, the update records are likely to be
1520            // no-ops. So we set the hint to instruct the executor to eliminate them.
1521            // See https://github.com/risingwavelabs/risingwave/issues/17030.
1522            if self.agg_calls().is_empty() {
1523                project = project.with_noop_update_hint(true);
1524            }
1525            Ok(project.into())
1526        }
1527    }
1528
1529    fn logical_rewrite_for_stream(
1530        &self,
1531        ctx: &mut RewriteStreamContext,
1532    ) -> Result<(PlanRef, ColIndexMapping)> {
1533        let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
1534        let (agg, out_col_change) = self.rewrite_with_input(input, input_col_change);
1535        let (map, _) = out_col_change.into_parts();
1536        let out_col_change = ColIndexMapping::new(map, agg.schema().len());
1537        Ok((agg.into(), out_col_change))
1538    }
1539}
1540
1541#[cfg(test)]
1542mod tests {
1543    use risingwave_common::catalog::{Field, Schema};
1544
1545    use super::*;
1546    use crate::expr::{assert_eq_input_ref, input_ref_to_column_indices};
1547    use crate::optimizer::optimizer_context::OptimizerContext;
1548    use crate::optimizer::plan_node::LogicalValues;
1549
1550    #[tokio::test]
1551    async fn test_create() {
1552        let ty = DataType::Int32;
1553        let ctx = OptimizerContext::mock().await;
1554        let fields: Vec<Field> = vec![
1555            Field::with_name(ty.clone(), "v1"),
1556            Field::with_name(ty.clone(), "v2"),
1557            Field::with_name(ty.clone(), "v3"),
1558        ];
1559        let values = LogicalValues::new(vec![], Schema { fields }, ctx);
1560        let input = PlanRef::from(values);
1561        let input_ref_1 = InputRef::new(0, ty.clone());
1562        let input_ref_2 = InputRef::new(1, ty.clone());
1563        let input_ref_3 = InputRef::new(2, ty.clone());
1564
1565        let gen_internal_value = |select_exprs: Vec<ExprImpl>,
1566                                  group_exprs|
1567         -> (Vec<ExprImpl>, Vec<PlanAggCall>, IndexSet) {
1568            let (plan, exprs, _) = LogicalAgg::create(
1569                select_exprs,
1570                GroupBy::GroupKey(group_exprs),
1571                None,
1572                input.clone(),
1573            )
1574            .unwrap();
1575
1576            let logical_agg = plan.as_logical_agg().unwrap();
1577            let agg_calls = logical_agg.agg_calls().clone();
1578            let group_key = logical_agg.group_key().clone();
1579
1580            (exprs, agg_calls, group_key)
1581        };
1582
1583        // Test case: select v1 from test group by v1;
1584        {
1585            let select_exprs = vec![input_ref_1.clone().into()];
1586            let group_exprs = vec![input_ref_1.clone().into()];
1587
1588            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
1589
1590            assert_eq!(exprs.len(), 1);
1591            assert_eq_input_ref!(&exprs[0], 0);
1592
1593            assert_eq!(agg_calls.len(), 0);
1594            assert_eq!(group_key, vec![0].into());
1595        }
1596
1597        // Test case: select v1, min(v2) from test group by v1;
1598        {
1599            let min_v2 = AggCall::new(
1600                PbAggKind::Min.into(),
1601                vec![input_ref_2.clone().into()],
1602                false,
1603                OrderBy::any(),
1604                Condition::true_cond(),
1605                vec![],
1606            )
1607            .unwrap();
1608            let select_exprs = vec![input_ref_1.clone().into(), min_v2.into()];
1609            let group_exprs = vec![input_ref_1.clone().into()];
1610
1611            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
1612
1613            assert_eq!(exprs.len(), 2);
1614            assert_eq_input_ref!(&exprs[0], 0);
1615            assert_eq_input_ref!(&exprs[1], 1);
1616
1617            assert_eq!(agg_calls.len(), 1);
1618            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
1619            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
1620            assert_eq!(group_key, vec![0].into());
1621        }
1622
1623        // Test case: select v1, min(v2) + max(v3) from test group by v1;
1624        {
1625            let min_v2 = AggCall::new(
1626                PbAggKind::Min.into(),
1627                vec![input_ref_2.clone().into()],
1628                false,
1629                OrderBy::any(),
1630                Condition::true_cond(),
1631                vec![],
1632            )
1633            .unwrap();
1634            let max_v3 = AggCall::new(
1635                PbAggKind::Max.into(),
1636                vec![input_ref_3.clone().into()],
1637                false,
1638                OrderBy::any(),
1639                Condition::true_cond(),
1640                vec![],
1641            )
1642            .unwrap();
1643            let func_call =
1644                FunctionCall::new(ExprType::Add, vec![min_v2.into(), max_v3.into()]).unwrap();
1645            let select_exprs = vec![input_ref_1.clone().into(), ExprImpl::from(func_call)];
1646            let group_exprs = vec![input_ref_1.clone().into()];
1647
1648            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
1649
1650            assert_eq_input_ref!(&exprs[0], 0);
1651            if let ExprImpl::FunctionCall(func_call) = &exprs[1] {
1652                assert_eq!(func_call.func_type(), ExprType::Add);
1653                let inputs = func_call.inputs();
1654                assert_eq_input_ref!(&inputs[0], 1);
1655                assert_eq_input_ref!(&inputs[1], 2);
1656            } else {
1657                panic!("Wrong expression type!");
1658            }
1659
1660            assert_eq!(agg_calls.len(), 2);
1661            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
1662            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
1663            assert_eq!(agg_calls[1].agg_type, PbAggKind::Max.into());
1664            assert_eq!(input_ref_to_column_indices(&agg_calls[1].inputs), vec![2]);
1665            assert_eq!(group_key, vec![0].into());
1666        }
1667
1668        // Test case: select v2, min(v1 * v3) from test group by v2;
1669        {
1670            let v1_mult_v3 = FunctionCall::new(
1671                ExprType::Multiply,
1672                vec![input_ref_1.into(), input_ref_3.into()],
1673            )
1674            .unwrap();
1675            let agg_call = AggCall::new(
1676                PbAggKind::Min.into(),
1677                vec![v1_mult_v3.into()],
1678                false,
1679                OrderBy::any(),
1680                Condition::true_cond(),
1681                vec![],
1682            )
1683            .unwrap();
1684            let select_exprs = vec![input_ref_2.clone().into(), agg_call.into()];
1685            let group_exprs = vec![input_ref_2.into()];
1686
1687            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
1688
1689            assert_eq_input_ref!(&exprs[0], 0);
1690            assert_eq_input_ref!(&exprs[1], 1);
1691
1692            assert_eq!(agg_calls.len(), 1);
1693            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
1694            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
1695            assert_eq!(group_key, vec![0].into());
1696        }
1697    }
1698
1699    /// Generate an agg call node with the given [`DataType`] and fields.
1700    /// For example, `generate_agg_call(Int32, [v1, v2, v3])` will result in:
1701    /// ```text
1702    /// Agg(min(input_ref(2))) group by (input_ref(1))
1703    ///   TableScan(v1, v2, v3)
1704    /// ```
1705    async fn generate_agg_call(ty: DataType, fields: Vec<Field>) -> LogicalAgg {
1706        let ctx = OptimizerContext::mock().await;
1707
1708        let values = LogicalValues::new(vec![], Schema { fields }, ctx);
1709        let agg_call = PlanAggCall {
1710            agg_type: PbAggKind::Min.into(),
1711            return_type: ty.clone(),
1712            inputs: vec![InputRef::new(2, ty.clone())],
1713            distinct: false,
1714            order_by: vec![],
1715            filter: Condition::true_cond(),
1716            direct_args: vec![],
1717        };
1718        Agg::new(vec![agg_call], vec![1].into(), values.into()).into()
1719    }
1720
1721    #[tokio::test]
1722    /// Pruning
1723    /// ```text
1724    /// Agg(min(input_ref(2))) group by (input_ref(1))
1725    ///   TableScan(v1, v2, v3)
1726    /// ```
1727    /// with required columns [0,1] (all columns) will result in
1728    /// ```text
1729    /// Agg(min(input_ref(1))) group by (input_ref(0))
1730    ///  TableScan(v2, v3)
1731    /// ```
1732    async fn test_prune_all() {
1733        let ty = DataType::Int32;
1734        let fields: Vec<Field> = vec![
1735            Field::with_name(ty.clone(), "v1"),
1736            Field::with_name(ty.clone(), "v2"),
1737            Field::with_name(ty.clone(), "v3"),
1738        ];
1739        let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).await.into();
1740        // Perform the prune
1741        let required_cols = vec![0, 1];
1742        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
1743
1744        // Check the result
1745        let agg_new = plan.as_logical_agg().unwrap();
1746        assert_eq!(agg_new.group_key(), &vec![0].into());
1747
1748        assert_eq!(agg_new.agg_calls().len(), 1);
1749        let agg_call_new = agg_new.agg_calls()[0].clone();
1750        assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
1751        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
1752        assert_eq!(agg_call_new.return_type, ty);
1753
1754        let values = agg_new.input();
1755        let values = values.as_logical_values().unwrap();
1756        assert_eq!(values.schema().fields(), &fields[1..]);
1757    }
1758
1759    #[tokio::test]
1760    /// Pruning
1761    /// ```text
1762    /// Agg(min(input_ref(2))) group by (input_ref(1))
1763    ///   TableScan(v1, v2, v3)
1764    /// ```
1765    /// with required columns [1,0] (all columns, with reversed order) will result in
1766    /// ```text
1767    /// Project [input_ref(1), input_ref(0)]
1768    ///   Agg(min(input_ref(1))) group by (input_ref(0))
1769    ///     TableScan(v2, v3)
1770    /// ```
1771    async fn test_prune_all_with_order_required() {
1772        let ty = DataType::Int32;
1773        let fields: Vec<Field> = vec![
1774            Field::with_name(ty.clone(), "v1"),
1775            Field::with_name(ty.clone(), "v2"),
1776            Field::with_name(ty.clone(), "v3"),
1777        ];
1778        let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).await.into();
1779        // Perform the prune
1780        let required_cols = vec![1, 0];
1781        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
1782        // Check the result
1783        let proj = plan.as_logical_project().unwrap();
1784        assert_eq!(proj.exprs().len(), 2);
1785        assert_eq!(proj.exprs()[0].as_input_ref().unwrap().index(), 1);
1786        assert_eq!(proj.exprs()[1].as_input_ref().unwrap().index(), 0);
1787        let proj_input = proj.input();
1788        let agg_new = proj_input.as_logical_agg().unwrap();
1789        assert_eq!(agg_new.group_key(), &vec![0].into());
1790
1791        assert_eq!(agg_new.agg_calls().len(), 1);
1792        let agg_call_new = agg_new.agg_calls()[0].clone();
1793        assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
1794        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
1795        assert_eq!(agg_call_new.return_type, ty);
1796
1797        let values = agg_new.input();
1798        let values = values.as_logical_values().unwrap();
1799        assert_eq!(values.schema().fields(), &fields[1..]);
1800    }
1801
1802    #[tokio::test]
1803    /// Pruning
1804    /// ```text
1805    /// Agg(min(input_ref(2))) group by (input_ref(1))
1806    ///   TableScan(v1, v2, v3)
1807    /// ```
1808    /// with required columns [1] (group key removed) will result in
1809    /// ```text
1810    /// Project(input_ref(1))
1811    ///   Agg(min(input_ref(1))) group by (input_ref(0))
1812    ///     TableScan(v2, v3)
1813    /// ```
1814    async fn test_prune_group_key() {
1815        let ctx = OptimizerContext::mock().await;
1816        let ty = DataType::Int32;
1817        let fields: Vec<Field> = vec![
1818            Field::with_name(ty.clone(), "v1"),
1819            Field::with_name(ty.clone(), "v2"),
1820            Field::with_name(ty.clone(), "v3"),
1821        ];
1822        let values: LogicalValues = LogicalValues::new(
1823            vec![],
1824            Schema {
1825                fields: fields.clone(),
1826            },
1827            ctx,
1828        );
1829        let agg_call = PlanAggCall {
1830            agg_type: PbAggKind::Min.into(),
1831            return_type: ty.clone(),
1832            inputs: vec![InputRef::new(2, ty.clone())],
1833            distinct: false,
1834            order_by: vec![],
1835            filter: Condition::true_cond(),
1836            direct_args: vec![],
1837        };
1838        let agg: PlanRef = Agg::new(vec![agg_call], vec![1].into(), values.into()).into();
1839
1840        // Perform the prune
1841        let required_cols = vec![1];
1842        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
1843
1844        // Check the result
1845        let project = plan.as_logical_project().unwrap();
1846        assert_eq!(project.exprs().len(), 1);
1847        assert_eq_input_ref!(&project.exprs()[0], 1);
1848
1849        let agg_new = project.input();
1850        let agg_new = agg_new.as_logical_agg().unwrap();
1851        assert_eq!(agg_new.group_key(), &vec![0].into());
1852
1853        assert_eq!(agg_new.agg_calls().len(), 1);
1854        let agg_call_new = agg_new.agg_calls()[0].clone();
1855        assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
1856        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
1857        assert_eq!(agg_call_new.return_type, ty);
1858
1859        let values = agg_new.input();
1860        let values = values.as_logical_values().unwrap();
1861        assert_eq!(values.schema().fields(), &fields[1..]);
1862    }
1863
1864    #[tokio::test]
1865    /// Pruning
1866    /// ```text
1867    /// Agg(min(input_ref(2)), max(input_ref(1))) group by (input_ref(1), input_ref(2))
1868    ///   TableScan(v1, v2, v3)
1869    /// ```
1870    /// with required columns [0,3] will result in
1871    /// ```text
1872    /// Project(input_ref(0), input_ref(2))
1873    ///   Agg(max(input_ref(0))) group by (input_ref(0), input_ref(1))
1874    ///     TableScan(v2, v3)
1875    /// ```
1876    async fn test_prune_agg() {
1877        let ty = DataType::Int32;
1878        let ctx = OptimizerContext::mock().await;
1879        let fields: Vec<Field> = vec![
1880            Field::with_name(ty.clone(), "v1"),
1881            Field::with_name(ty.clone(), "v2"),
1882            Field::with_name(ty.clone(), "v3"),
1883        ];
1884        let values = LogicalValues::new(
1885            vec![],
1886            Schema {
1887                fields: fields.clone(),
1888            },
1889            ctx,
1890        );
1891
1892        let agg_calls = vec![
1893            PlanAggCall {
1894                agg_type: PbAggKind::Min.into(),
1895                return_type: ty.clone(),
1896                inputs: vec![InputRef::new(2, ty.clone())],
1897                distinct: false,
1898                order_by: vec![],
1899                filter: Condition::true_cond(),
1900                direct_args: vec![],
1901            },
1902            PlanAggCall {
1903                agg_type: PbAggKind::Max.into(),
1904                return_type: ty.clone(),
1905                inputs: vec![InputRef::new(1, ty.clone())],
1906                distinct: false,
1907                order_by: vec![],
1908                filter: Condition::true_cond(),
1909                direct_args: vec![],
1910            },
1911        ];
1912        let agg: PlanRef = Agg::new(agg_calls, vec![1, 2].into(), values.into()).into();
1913
1914        // Perform the prune
1915        let required_cols = vec![0, 3];
1916        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
1917        // Check the result
1918        let project = plan.as_logical_project().unwrap();
1919        assert_eq!(project.exprs().len(), 2);
1920        assert_eq_input_ref!(&project.exprs()[0], 0);
1921        assert_eq_input_ref!(&project.exprs()[1], 2);
1922
1923        let agg_new = project.input();
1924        let agg_new = agg_new.as_logical_agg().unwrap();
1925        assert_eq!(agg_new.group_key(), &vec![0, 1].into());
1926
1927        assert_eq!(agg_new.agg_calls().len(), 1);
1928        let agg_call_new = agg_new.agg_calls()[0].clone();
1929        assert_eq!(agg_call_new.agg_type, PbAggKind::Max.into());
1930        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![0]);
1931        assert_eq!(agg_call_new.return_type, ty);
1932
1933        let values = agg_new.input();
1934        let values = values.as_logical_values().unwrap();
1935        assert_eq!(values.schema().fields(), &fields[1..]);
1936    }
1937}