risingwave_frontend/optimizer/plan_node/
logical_agg.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::types::{DataType, ScalarImpl};
18use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
19use risingwave_common::{bail, bail_not_implemented, not_implemented};
20use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};
21
22use super::generic::{self, Agg, GenericPlanRef, PlanAggCall, ProjectBuilder};
23use super::utils::impl_distill_by_unit;
24use super::{
25    BatchHashAgg, BatchSimpleAgg, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
26    PlanBase, PlanTreeNodeUnary, PredicatePushdown, StreamHashAgg, StreamPlanRef, StreamProject,
27    StreamShare, StreamSimpleAgg, StreamStatelessSimpleAgg, ToBatch, ToStream,
28    try_enforce_locality_requirement,
29};
30use crate::error::{ErrorCode, Result, RwError};
31use crate::expr::{
32    AggCall, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, Literal,
33    OrderBy, OrderByExpr, WindowFunction,
34};
35use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
36use crate::optimizer::plan_node::generic::GenericPlanNode;
37use crate::optimizer::plan_node::stream_global_approx_percentile::StreamGlobalApproxPercentile;
38use crate::optimizer::plan_node::stream_local_approx_percentile::StreamLocalApproxPercentile;
39use crate::optimizer::plan_node::stream_row_merge::StreamRowMerge;
40use crate::optimizer::plan_node::{
41    BatchSortAgg, ColumnPruningContext, LogicalDedup, LogicalProject, PredicatePushdownContext,
42    RewriteStreamContext, ToStreamContext, gen_filter_and_pushdown,
43};
44use crate::optimizer::property::{Distribution, Order, RequiredDist};
45use crate::utils::{
46    ColIndexMapping, ColIndexMappingRewriteExt, Condition, GroupBy, IndexSet, Substitute,
47};
48
49pub struct AggInfo {
50    pub calls: Vec<PlanAggCall>,
51    pub col_mapping: ColIndexMapping,
52}
53
54/// `SeparatedAggInfo` is used to separate normal and approx percentile aggs.
55pub struct SeparatedAggInfo {
56    normal: AggInfo,
57    approx: AggInfo,
58}
59
/// `LogicalAgg` groups input data by their group key and computes aggregation functions.
///
/// It corresponds to the `GROUP BY` operator in a SQL query statement together with the aggregate
/// functions in the `SELECT` clause.
///
/// The output schema will first include the group key and then the aggregation calls.
// NOTE: field order matters for the derived `Hash`/`Debug`; keep it stable.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct LogicalAgg {
    /// Common logical plan node properties.
    pub base: PlanBase<Logical>,
    /// Generic agg definition: agg calls, group key (and grouping sets) and the input plan.
    core: Agg<PlanRef>,
}
71
72impl LogicalAgg {
73    /// Generate plan for stateless 2-phase streaming agg.
74    /// Should only be used iff input is distributed. Input must be converted to stream form.
75    fn gen_stateless_two_phase_streaming_agg_plan(
76        &self,
77        stream_input: StreamPlanRef,
78    ) -> Result<StreamPlanRef> {
79        debug_assert!(self.group_key().is_empty());
80
81        // ====== Handle approx percentile aggs
82        let (
83            non_approx_percentile_col_mapping,
84            approx_percentile_col_mapping,
85            approx_percentile,
86            core,
87        ) = self.prepare_approx_percentile(stream_input)?;
88
89        if core.agg_calls.is_empty() {
90            if let Some(approx_percentile) = approx_percentile {
91                return Ok(approx_percentile);
92            };
93            bail!("expected at least one agg call");
94        }
95
96        let need_row_merge: bool = Self::need_row_merge(&approx_percentile);
97
98        // ====== Handle normal aggs
99        let total_agg_calls = core
100            .agg_calls
101            .iter()
102            .enumerate()
103            .map(|(partial_output_idx, agg_call)| {
104                agg_call.partial_to_total_agg_call(partial_output_idx)
105            })
106            .collect_vec();
107        let local_agg = StreamStatelessSimpleAgg::new(core)?;
108        let exchange =
109            RequiredDist::single().streaming_enforce_if_not_satisfies(local_agg.into())?;
110
111        let must_output_per_barrier = need_row_merge;
112        let global_agg = new_stream_simple_agg(
113            Agg::new(total_agg_calls, IndexSet::empty(), exchange),
114            must_output_per_barrier,
115        )?;
116
117        // ====== Merge approx percentile and normal aggs
118        Self::add_row_merge_if_needed(
119            approx_percentile,
120            global_agg.into(),
121            approx_percentile_col_mapping,
122            non_approx_percentile_col_mapping,
123        )
124    }
125
126    /// Generate plan for stateless/stateful 2-phase streaming agg.
127    /// Should only be used iff input is distributed.
128    /// Input must be converted to stream form.
129    fn gen_vnode_two_phase_streaming_agg_plan(
130        &self,
131        stream_input: StreamPlanRef,
132        dist_key: &[usize],
133    ) -> Result<StreamPlanRef> {
134        let (
135            non_approx_percentile_col_mapping,
136            approx_percentile_col_mapping,
137            approx_percentile,
138            core,
139        ) = self.prepare_approx_percentile(stream_input.clone())?;
140
141        if core.agg_calls.is_empty() {
142            if let Some(approx_percentile) = approx_percentile {
143                return Ok(approx_percentile);
144            };
145            bail!("expected at least one agg call");
146        }
147        let need_row_merge = Self::need_row_merge(&approx_percentile);
148
149        // Generate vnode via project
150        // TODO(kwannoel): We should apply Project optimization rules here.
151        let input_col_num = stream_input.schema().len(); // get schema len before moving `stream_input`.
152        let project = StreamProject::new(generic::Project::with_vnode_col(stream_input, dist_key));
153        let vnode_col_idx = project.base.schema().len() - 1;
154
155        // Generate local agg step
156        let mut local_group_key = self.group_key().clone();
157        local_group_key.insert(vnode_col_idx);
158        let n_local_group_key = local_group_key.len();
159        let local_agg = new_stream_hash_agg(
160            Agg::new(core.agg_calls.clone(), local_group_key, project.into()),
161            Some(vnode_col_idx),
162        )?;
163        // Global group key excludes vnode.
164        let local_agg_group_key_cardinality = local_agg.group_key().len();
165        let local_group_key_without_vnode =
166            &local_agg.group_key().to_vec()[..local_agg_group_key_cardinality - 1];
167        let global_group_key = local_agg
168            .i2o_col_mapping()
169            .rewrite_dist_key(local_group_key_without_vnode)
170            .expect("some input group key could not be mapped");
171
172        // Generate global agg step
173        let global_agg = if self.group_key().is_empty() {
174            let exchange =
175                RequiredDist::single().streaming_enforce_if_not_satisfies(local_agg.into())?;
176            let must_output_per_barrier = need_row_merge;
177            let global_agg = new_stream_simple_agg(
178                Agg::new(
179                    core.agg_calls
180                        .iter()
181                        .enumerate()
182                        .map(|(partial_output_idx, agg_call)| {
183                            agg_call
184                                .partial_to_total_agg_call(n_local_group_key + partial_output_idx)
185                        })
186                        .collect(),
187                    global_group_key.into_iter().collect(),
188                    exchange,
189                ),
190                must_output_per_barrier,
191            )?;
192            global_agg.into()
193        } else {
194            // the `RowMergeExec` has not supported keyed merge
195            assert!(!need_row_merge);
196            let exchange = RequiredDist::shard_by_key(input_col_num, &global_group_key)
197                .streaming_enforce_if_not_satisfies(local_agg.into())?;
198            // Local phase should have reordered the group keys into their required order.
199            // we can just follow it.
200            let global_agg = new_stream_hash_agg(
201                Agg::new(
202                    core.agg_calls
203                        .iter()
204                        .enumerate()
205                        .map(|(partial_output_idx, agg_call)| {
206                            agg_call
207                                .partial_to_total_agg_call(n_local_group_key + partial_output_idx)
208                        })
209                        .collect(),
210                    global_group_key.into_iter().collect(),
211                    exchange,
212                ),
213                None,
214            )?;
215            global_agg.into()
216        };
217        Self::add_row_merge_if_needed(
218            approx_percentile,
219            global_agg,
220            approx_percentile_col_mapping,
221            non_approx_percentile_col_mapping,
222        )
223    }
224
225    fn gen_single_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
226        let input = RequiredDist::single().streaming_enforce_if_not_satisfies(stream_input)?;
227        let core = self.core.clone_with_input(input);
228        Ok(new_stream_simple_agg(core, false)?.into())
229    }
230
231    fn gen_shuffle_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
232        let input =
233            RequiredDist::shard_by_key(stream_input.schema().len(), &self.group_key().to_vec())
234                .streaming_enforce_if_not_satisfies(stream_input)?;
235        let core = self.core.clone_with_input(input);
236        Ok(new_stream_hash_agg(core, None)?.into())
237    }
238
239    /// Generates distributed stream plan.
240    fn gen_dist_stream_agg_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
241        use super::stream::prelude::*;
242
243        let input_dist = stream_input.distribution();
244        debug_assert!(*input_dist != Distribution::Broadcast);
245
246        // Shuffle agg
247        // If we have group key, and we won't try two phase agg optimization at all,
248        // we will always choose shuffle agg over single agg.
249        if !self.group_key().is_empty() && !self.core.must_try_two_phase_agg() {
250            return self.gen_shuffle_plan(stream_input);
251        }
252
253        // Standalone agg
254        // If no group key, and cannot two phase agg, we have to use single plan.
255        if self.group_key().is_empty() && !self.core.can_two_phase_agg() {
256            return self.gen_single_plan(stream_input);
257        }
258
259        debug_assert!(if !self.group_key().is_empty() {
260            self.core.must_try_two_phase_agg()
261        } else {
262            self.core.can_two_phase_agg()
263        });
264
265        // Stateless 2-phase simple agg
266        // can be applied on stateless simple agg calls,
267        // with input distributed by [`Distribution::AnyShard`]
268        if self.group_key().is_empty()
269            && self
270                .core
271                .all_local_aggs_are_stateless(stream_input.append_only())
272            && input_dist.satisfies(&RequiredDist::AnyShard)
273        {
274            return self.gen_stateless_two_phase_streaming_agg_plan(stream_input);
275        }
276
277        // If input is [`Distribution::SomeShard`] and we must try to use two phase agg,
278        // The only remaining strategy is Vnode-based 2-phase agg.
279        // We shall first distribute it by PK,
280        // so it obeys consistent hash strategy via [`Distribution::HashShard`].
281        let stream_input =
282            if *input_dist == Distribution::SomeShard && self.core.must_try_two_phase_agg() {
283                RequiredDist::shard_by_key(
284                    stream_input.schema().len(),
285                    stream_input.expect_stream_key(),
286                )
287                .streaming_enforce_if_not_satisfies(stream_input)?
288            } else {
289                stream_input
290            };
291        let input_dist = stream_input.distribution();
292
293        // Vnode-based 2-phase agg
294        // can be applied on agg calls not affected by order,
295        // with input distributed by dist_key.
296        match input_dist {
297            Distribution::HashShard(dist_key) | Distribution::UpstreamHashShard(dist_key, _)
298                if (self.group_key().is_empty()
299                    || !self.core.hash_agg_dist_satisfied_by_input_dist(input_dist)) =>
300            {
301                let dist_key = dist_key.clone();
302                return self.gen_vnode_two_phase_streaming_agg_plan(stream_input, &dist_key);
303            }
304            _ => {}
305        }
306
307        // Fallback to shuffle or single, if we can't generate any 2-phase plans.
308        if !self.group_key().is_empty() {
309            self.gen_shuffle_plan(stream_input)
310        } else {
311            self.gen_single_plan(stream_input)
312        }
313    }
314
315    /// Prepares metadata and the `approx_percentile` plan, if there's one present.
316    /// It may modify `core.agg_calls` to separate normal agg and approx percentile agg,
317    /// and `core.input` to share the input via `StreamShare`,
318    /// to both approx percentile agg and normal agg.
319    fn prepare_approx_percentile(
320        &self,
321        stream_input: StreamPlanRef,
322    ) -> Result<(
323        ColIndexMapping,
324        ColIndexMapping,
325        Option<StreamPlanRef>,
326        Agg<StreamPlanRef>,
327    )> {
328        let SeparatedAggInfo { normal, approx } = self.separate_normal_and_special_agg();
329
330        let AggInfo {
331            calls: non_approx_percentile_agg_calls,
332            col_mapping: non_approx_percentile_col_mapping,
333        } = normal;
334        let AggInfo {
335            calls: approx_percentile_agg_calls,
336            col_mapping: approx_percentile_col_mapping,
337        } = approx;
338        if !self.group_key().is_empty() && !approx_percentile_agg_calls.is_empty() {
339            bail_not_implemented!(
340                "two-phase streaming approx percentile aggregation with group key, \
341             please use single phase aggregation instead"
342            );
343        }
344
345        // Either we have approx percentile aggs and non_approx percentile aggs,
346        // or we have at least 2 approx percentile aggs.
347        let needs_row_merge = (!non_approx_percentile_agg_calls.is_empty()
348            && !approx_percentile_agg_calls.is_empty())
349            || approx_percentile_agg_calls.len() >= 2;
350        let input = if needs_row_merge {
351            // If there's row merge, we need to share the input.
352            StreamShare::new_from_input(stream_input).into()
353        } else {
354            stream_input
355        };
356        let mut core = self.core.clone_with_input(input);
357        core.agg_calls = non_approx_percentile_agg_calls;
358
359        let approx_percentile =
360            self.build_approx_percentile_aggs(core.input.clone(), &approx_percentile_agg_calls)?;
361        Ok((
362            non_approx_percentile_col_mapping,
363            approx_percentile_col_mapping,
364            approx_percentile,
365            core,
366        ))
367    }
368
369    fn need_row_merge(approx_percentile: &Option<StreamPlanRef>) -> bool {
370        approx_percentile.is_some()
371    }
372
373    /// Add `RowMerge` if needed
374    fn add_row_merge_if_needed(
375        approx_percentile: Option<StreamPlanRef>,
376        global_agg: StreamPlanRef,
377        approx_percentile_col_mapping: ColIndexMapping,
378        non_approx_percentile_col_mapping: ColIndexMapping,
379    ) -> Result<StreamPlanRef> {
380        // just for assert
381        let need_row_merge = Self::need_row_merge(&approx_percentile);
382
383        if let Some(approx_percentile) = approx_percentile {
384            assert!(need_row_merge);
385            let row_merge = StreamRowMerge::new(
386                approx_percentile,
387                global_agg,
388                approx_percentile_col_mapping,
389                non_approx_percentile_col_mapping,
390            )?;
391            Ok(row_merge.into())
392        } else {
393            assert!(!need_row_merge);
394            Ok(global_agg)
395        }
396    }
397
398    fn separate_normal_and_special_agg(&self) -> SeparatedAggInfo {
399        let estimated_len = self.agg_calls().len() - 1;
400        let mut approx_percentile_agg_calls = Vec::with_capacity(estimated_len);
401        let mut non_approx_percentile_agg_calls = Vec::with_capacity(estimated_len);
402        let mut approx_percentile_col_mapping = Vec::with_capacity(estimated_len);
403        let mut non_approx_percentile_col_mapping = Vec::with_capacity(estimated_len);
404        for (output_idx, agg_call) in self.agg_calls().iter().enumerate() {
405            if agg_call.agg_type == AggType::Builtin(PbAggKind::ApproxPercentile) {
406                approx_percentile_agg_calls.push(agg_call.clone());
407                approx_percentile_col_mapping.push(Some(output_idx));
408            } else {
409                non_approx_percentile_agg_calls.push(agg_call.clone());
410                non_approx_percentile_col_mapping.push(Some(output_idx));
411            }
412        }
413        let normal = AggInfo {
414            calls: non_approx_percentile_agg_calls,
415            col_mapping: ColIndexMapping::new(
416                non_approx_percentile_col_mapping,
417                self.agg_calls().len(),
418            ),
419        };
420        let approx = AggInfo {
421            calls: approx_percentile_agg_calls,
422            col_mapping: ColIndexMapping::new(
423                approx_percentile_col_mapping,
424                self.agg_calls().len(),
425            ),
426        };
427        SeparatedAggInfo { normal, approx }
428    }
429
430    fn build_approx_percentile_agg(
431        &self,
432        input: StreamPlanRef,
433        approx_percentile_agg_call: &PlanAggCall,
434    ) -> Result<StreamPlanRef> {
435        let local_approx_percentile =
436            StreamLocalApproxPercentile::new(input, approx_percentile_agg_call)?;
437        let exchange = RequiredDist::single()
438            .streaming_enforce_if_not_satisfies(local_approx_percentile.into())?;
439        let global_approx_percentile =
440            StreamGlobalApproxPercentile::new(exchange, approx_percentile_agg_call);
441        Ok(global_approx_percentile.into())
442    }
443
444    /// If only 1 approx percentile, just return it.
445    /// Otherwise build a tree of approx percentile with `MergeProject`.
446    /// e.g.
447    /// ApproxPercentile(col1, 0.5) as x,
448    /// ApproxPercentile(col2, 0.5) as y,
449    /// ApproxPercentile(col3, 0.5) as z
450    /// will be built as
451    ///        `MergeProject`
452    ///       /          \
453    ///  `MergeProject`       z
454    ///  /        \
455    /// x          y
456    fn build_approx_percentile_aggs(
457        &self,
458        input: StreamPlanRef,
459        approx_percentile_agg_call: &[PlanAggCall],
460    ) -> Result<Option<StreamPlanRef>> {
461        if approx_percentile_agg_call.is_empty() {
462            return Ok(None);
463        }
464        let approx_percentile_plans: Vec<_> = approx_percentile_agg_call
465            .iter()
466            .map(|agg_call| self.build_approx_percentile_agg(input.clone(), agg_call))
467            .try_collect()?;
468        assert!(!approx_percentile_plans.is_empty());
469        let mut iter = approx_percentile_plans.into_iter();
470        let mut acc = iter.next().unwrap();
471        for (current_size, plan) in iter.enumerate().map(|(i, p)| (i + 1, p)) {
472            let new_size = current_size + 1;
473            let row_merge = StreamRowMerge::new(
474                acc,
475                plan,
476                ColIndexMapping::identity_or_none(current_size, new_size),
477                ColIndexMapping::new(vec![Some(current_size)], new_size),
478            )?;
479            acc = row_merge.into();
480        }
481        Ok(Some(acc))
482    }
483
484    pub fn core(&self) -> &Agg<PlanRef> {
485        &self.core
486    }
487}
488
489/// `LogicalAggBuilder` extracts agg calls and references to group columns from select list and
490/// build the plan like `LogicalAgg - LogicalProject`.
491/// it is constructed by `group_exprs` and collect and rewrite the expression in selection and
492/// having clause.
493pub struct LogicalAggBuilder {
494    /// the builder of the input Project
495    input_proj_builder: ProjectBuilder,
496    /// the group key column indices in the project's output
497    group_key: IndexSet,
498    /// the grouping sets
499    grouping_sets: Vec<IndexSet>,
500    /// the agg calls
501    agg_calls: Vec<PlanAggCall>,
502    /// the error during the expression rewriting
503    error: Option<RwError>,
504    /// If `is_in_filter_clause` is true, it means that
505    /// we are processing filter clause.
506    /// This field is needed because input refs in these clauses
507    /// are allowed to refer to any columns, while those not in filter
508    /// clause are only allowed to refer to group keys.
509    is_in_filter_clause: bool,
510}
511
512impl LogicalAggBuilder {
513    fn new(group_by: GroupBy, input_schema_len: usize) -> Result<Self> {
514        let mut input_proj_builder = ProjectBuilder::default();
515
516        let mut gen_group_key_and_grouping_sets =
517            |grouping_sets: Vec<Vec<ExprImpl>>| -> Result<(IndexSet, Vec<IndexSet>)> {
518                let grouping_sets: Vec<IndexSet> = grouping_sets
519                    .into_iter()
520                    .map(|set| {
521                        set.into_iter()
522                            .map(|expr| input_proj_builder.add_expr(&expr))
523                            .try_collect()
524                            .map_err(|err| not_implemented!("{err} inside GROUP BY"))
525                    })
526                    .try_collect()?;
527
528                // Construct group key based on grouping sets.
529                let group_key = grouping_sets
530                    .iter()
531                    .fold(FixedBitSet::with_capacity(input_schema_len), |acc, x| {
532                        acc.union(&x.to_bitset()).collect()
533                    });
534
535                Ok((IndexSet::from_iter(group_key.ones()), grouping_sets))
536            };
537
538        let (group_key, grouping_sets) = match group_by {
539            GroupBy::GroupKey(group_key) => {
540                let group_key = group_key
541                    .into_iter()
542                    .map(|expr| input_proj_builder.add_expr(&expr))
543                    .try_collect()
544                    .map_err(|err| not_implemented!("{err} inside GROUP BY"))?;
545                (group_key, vec![])
546            }
547            GroupBy::GroupingSets(grouping_sets) => gen_group_key_and_grouping_sets(grouping_sets)?,
548            GroupBy::Rollup(rollup) => {
549                // Convert rollup to grouping sets.
550                let grouping_sets = (0..=rollup.len())
551                    .map(|n| {
552                        rollup
553                            .iter()
554                            .take(n)
555                            .flat_map(|x| x.iter().cloned())
556                            .collect_vec()
557                    })
558                    .collect_vec();
559                gen_group_key_and_grouping_sets(grouping_sets)?
560            }
561            GroupBy::Cube(cube) => {
562                // Convert cube to grouping sets.
563                let grouping_sets = cube
564                    .into_iter()
565                    .powerset()
566                    .map(|x| x.into_iter().flatten().collect_vec())
567                    .collect_vec();
568                gen_group_key_and_grouping_sets(grouping_sets)?
569            }
570        };
571
572        Ok(LogicalAggBuilder {
573            group_key,
574            grouping_sets,
575            agg_calls: vec![],
576            error: None,
577            input_proj_builder,
578            is_in_filter_clause: false,
579        })
580    }
581
582    pub fn build(self, input: PlanRef) -> LogicalAgg {
583        // This LogicalProject focuses on the exprs in aggregates and GROUP BY clause.
584        let logical_project = LogicalProject::with_core(self.input_proj_builder.build(input));
585
586        // This LogicalAgg focuses on calculating the aggregates and grouping.
587        Agg::new(self.agg_calls, self.group_key, logical_project.into())
588            .with_grouping_sets(self.grouping_sets)
589            .into()
590    }
591
592    fn rewrite_with_error(&mut self, expr: ExprImpl) -> Result<ExprImpl> {
593        let rewritten_expr = self.rewrite_expr(expr);
594        if let Some(error) = self.error.take() {
595            return Err(error);
596        }
597        Ok(rewritten_expr)
598    }
599
600    /// check if the expression is a group by key, and try to return the group key
601    pub fn try_as_group_expr(&self, expr: &ExprImpl) -> Option<usize> {
602        if let Some(input_index) = self.input_proj_builder.expr_index(expr)
603            && let Some(index) = self
604                .group_key
605                .indices()
606                .position(|group_key| group_key == input_index)
607        {
608            return Some(index);
609        }
610        None
611    }
612
613    fn schema_agg_start_offset(&self) -> usize {
614        self.group_key.len()
615    }
616
617    /// Rewrite [`AggCall`] if needed, and push it into the builder using `push_agg_call`.
618    /// This is shared by [`LogicalAggBuilder`] and `LogicalOverWindowBuilder`.
619    pub(crate) fn general_rewrite_agg_call(
620        agg_call: AggCall,
621        mut push_agg_call: impl FnMut(AggCall) -> Result<InputRef>,
622    ) -> Result<ExprImpl> {
623        match agg_call.agg_type {
624            // Rewrite avg to cast(sum as avg_return_type) / count.
625            AggType::Builtin(PbAggKind::Avg) => {
626                assert_eq!(agg_call.args.len(), 1);
627                let return_type = agg_call.return_type();
628
629                let sum = ExprImpl::from(push_agg_call(AggCall::new(
630                    PbAggKind::Sum.into(),
631                    agg_call.args.clone(),
632                    agg_call.distinct,
633                    agg_call.order_by.clone(),
634                    agg_call.filter.clone(),
635                    agg_call.direct_args.clone(),
636                )?)?)
637                .cast_explicit(&agg_call.return_type())?;
638
639                let count = ExprImpl::from(push_agg_call(AggCall::new(
640                    PbAggKind::Count.into(),
641                    agg_call.args.clone(),
642                    agg_call.distinct,
643                    agg_call.order_by.clone(),
644                    agg_call.filter.clone(),
645                    agg_call.direct_args,
646                )?)?);
647
648                let target = ExprImpl::from(FunctionCall::new(
649                    ExprType::Divide,
650                    Vec::from([sum, count.clone()]),
651                )?);
652                let null = ExprImpl::from(Literal::new(None, return_type));
653                let zero = ExprImpl::literal_int(0);
654                let case_cond =
655                    ExprImpl::from(FunctionCall::new(ExprType::Equal, vec![count, zero])?);
656                Ok(ExprImpl::from(FunctionCall::new(
657                    ExprType::Case,
658                    vec![case_cond, null, target],
659                )?))
660            }
661            // We compute `var_samp` as
662            // (sum(sq) - sum * sum / count) / (count - 1)
663            // and `var_pop` as
664            // (sum(sq) - sum * sum / count) / count
665            // Since we don't have the square function, we use the plain Multiply for squaring,
666            // which is in a sense more general than the pow function, especially when calculating
667            // covariances in the future. Also we don't have the sqrt function for rooting, so we
668            // use pow(x, 0.5) to simulate
669            AggType::Builtin(
670                kind @ (PbAggKind::StddevPop
671                | PbAggKind::StddevSamp
672                | PbAggKind::VarPop
673                | PbAggKind::VarSamp),
674            ) => {
675                let arg = agg_call.args().iter().exactly_one().unwrap();
676                let squared_arg = ExprImpl::from(FunctionCall::new(
677                    ExprType::Multiply,
678                    vec![arg.clone(), arg.clone()],
679                )?);
680
681                let sum_of_sq = ExprImpl::from(push_agg_call(AggCall::new(
682                    PbAggKind::Sum.into(),
683                    vec![squared_arg],
684                    agg_call.distinct,
685                    agg_call.order_by.clone(),
686                    agg_call.filter.clone(),
687                    agg_call.direct_args.clone(),
688                )?)?)
689                .cast_explicit(&agg_call.return_type())?;
690
691                let sum = ExprImpl::from(push_agg_call(AggCall::new(
692                    PbAggKind::Sum.into(),
693                    agg_call.args.clone(),
694                    agg_call.distinct,
695                    agg_call.order_by.clone(),
696                    agg_call.filter.clone(),
697                    agg_call.direct_args.clone(),
698                )?)?)
699                .cast_explicit(&agg_call.return_type())?;
700
701                let count = ExprImpl::from(push_agg_call(AggCall::new(
702                    PbAggKind::Count.into(),
703                    agg_call.args.clone(),
704                    agg_call.distinct,
705                    agg_call.order_by.clone(),
706                    agg_call.filter.clone(),
707                    agg_call.direct_args.clone(),
708                )?)?);
709
710                let zero = ExprImpl::literal_int(0);
711                let one = ExprImpl::literal_int(1);
712
713                let squared_sum = ExprImpl::from(FunctionCall::new(
714                    ExprType::Multiply,
715                    vec![sum.clone(), sum],
716                )?);
717
718                let raw_numerator = ExprImpl::from(FunctionCall::new(
719                    ExprType::Subtract,
720                    vec![
721                        sum_of_sq,
722                        ExprImpl::from(FunctionCall::new(
723                            ExprType::Divide,
724                            vec![squared_sum, count.clone()],
725                        )?),
726                    ],
727                )?);
728
729                // We need to check for potential accuracy issues that may occasionally lead to results less than 0.
730                let numerator_type = raw_numerator.return_type();
731                let numerator = ExprImpl::from(FunctionCall::new(
732                    ExprType::Greatest,
733                    vec![raw_numerator, zero.clone().cast_explicit(&numerator_type)?],
734                )?);
735
736                let denominator = match kind {
737                    PbAggKind::VarPop | PbAggKind::StddevPop => count.clone(),
738                    PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from(
739                        FunctionCall::new(ExprType::Subtract, vec![count.clone(), one.clone()])?,
740                    ),
741                    _ => unreachable!(),
742                };
743
744                let mut target = ExprImpl::from(FunctionCall::new(
745                    ExprType::Divide,
746                    vec![numerator, denominator],
747                )?);
748
749                if matches!(kind, PbAggKind::StddevPop | PbAggKind::StddevSamp) {
750                    target = ExprImpl::from(FunctionCall::new(ExprType::Sqrt, vec![target])?);
751                }
752
753                let null = ExprImpl::from(Literal::new(None, agg_call.return_type()));
754                let case_cond = match kind {
755                    PbAggKind::VarPop | PbAggKind::StddevPop => {
756                        ExprImpl::from(FunctionCall::new(ExprType::Equal, vec![count, zero])?)
757                    }
758                    PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from(
759                        FunctionCall::new(ExprType::LessThanOrEqual, vec![count, one])?,
760                    ),
761                    _ => unreachable!(),
762                };
763
764                Ok(ExprImpl::from(FunctionCall::new(
765                    ExprType::Case,
766                    vec![case_cond, null, target],
767                )?))
768            }
769            AggType::Builtin(PbAggKind::ApproxPercentile) => {
770                if agg_call.order_by.sort_exprs[0].order_type == OrderType::descending() {
771                    // Rewrite DESC into 1.0-percentile for approx_percentile.
772                    let prev_percentile = agg_call.direct_args[0].clone();
773                    let new_percentile = 1.0
774                        - prev_percentile
775                            .get_data()
776                            .as_ref()
777                            .unwrap()
778                            .as_float64()
779                            .into_inner();
780                    let new_percentile = Some(ScalarImpl::Float64(new_percentile.into()));
781                    let new_percentile = Literal::new(new_percentile, DataType::Float64);
782                    let new_direct_args = vec![new_percentile, agg_call.direct_args[1].clone()];
783
784                    let new_agg_call = AggCall {
785                        order_by: OrderBy::any(),
786                        direct_args: new_direct_args,
787                        ..agg_call
788                    };
789                    Ok(push_agg_call(new_agg_call)?.into())
790                } else {
791                    let new_agg_call = AggCall {
792                        order_by: OrderBy::any(),
793                        ..agg_call
794                    };
795                    Ok(push_agg_call(new_agg_call)?.into())
796                }
797            }
798            AggType::Builtin(PbAggKind::ArgMin | PbAggKind::ArgMax) => {
799                let mut agg_call = agg_call;
800
801                let comparison_arg_type = agg_call.args[1].return_type();
802                match comparison_arg_type {
803                    DataType::Struct(_)
804                    | DataType::List(_)
805                    | DataType::Map(_)
806                    | DataType::Vector(_)
807                    | DataType::Jsonb => {
808                        bail!(format!(
809                            "{} does not support struct, array, map, vector, jsonb for comparison argument, got {}",
810                            agg_call.agg_type.to_string(),
811                            comparison_arg_type
812                        ));
813                    }
814                    _ => {}
815                }
816
817                let not_null_exprs: Vec<ExprImpl> = agg_call
818                    .args
819                    .iter()
820                    .map(|arg| -> Result<ExprImpl> {
821                        Ok(FunctionCall::new(ExprType::IsNotNull, vec![arg.clone()])?.into())
822                    })
823                    .try_collect()?;
824
825                let comparison_expr = agg_call.args[1].clone();
826                let mut order_exprs = vec![OrderByExpr {
827                    expr: comparison_expr,
828                    order_type: if agg_call.agg_type == AggType::Builtin(PbAggKind::ArgMin) {
829                        OrderType::ascending()
830                    } else {
831                        OrderType::descending()
832                    },
833                }];
834
835                order_exprs.extend(agg_call.order_by.sort_exprs);
836
837                let order_by = OrderBy::new(order_exprs);
838
839                let filter = agg_call.filter.clone().and(Condition {
840                    conjunctions: not_null_exprs,
841                });
842
843                agg_call.args.truncate(1);
844
845                let new_agg_call = AggCall {
846                    agg_type: AggType::Builtin(PbAggKind::FirstValue),
847                    order_by,
848                    filter,
849                    ..agg_call
850                };
851                Ok(push_agg_call(new_agg_call)?.into())
852            }
853            _ => Ok(push_agg_call(agg_call)?.into()),
854        }
855    }
856
857    /// Push a new agg call into the builder.
858    /// Return an `InputRef` to that agg call.
859    /// For existing agg calls, return an `InputRef` to the existing one.
860    fn push_agg_call(&mut self, agg_call: AggCall) -> Result<InputRef> {
861        let AggCall {
862            agg_type,
863            return_type,
864            args,
865            distinct,
866            order_by,
867            filter,
868            direct_args,
869        } = agg_call;
870
871        self.is_in_filter_clause = true;
872        // filter expr is not added to `input_proj_builder` as a whole. Special exprs incl
873        // subquery/agg/table are rejected in `bind_agg`.
874        let filter = filter.rewrite_expr(self);
875        self.is_in_filter_clause = false;
876
877        let args: Vec<_> = args
878            .iter()
879            .map(|expr| {
880                let index = self.input_proj_builder.add_expr(expr)?;
881                Ok(InputRef::new(index, expr.return_type()))
882            })
883            .try_collect()
884            .map_err(|err: &'static str| not_implemented!("{err} inside aggregation calls"))?;
885
886        let order_by: Vec<_> = order_by
887            .sort_exprs
888            .iter()
889            .map(|e| {
890                let index = self.input_proj_builder.add_expr(&e.expr)?;
891                Ok(ColumnOrder::new(index, e.order_type))
892            })
893            .try_collect()
894            .map_err(|err: &'static str| {
895                not_implemented!("{err} inside aggregation calls order by")
896            })?;
897
898        let plan_agg_call = PlanAggCall {
899            agg_type,
900            return_type: return_type.clone(),
901            inputs: args,
902            distinct,
903            order_by,
904            filter,
905            direct_args,
906        };
907
908        if let Some((pos, existing)) = self
909            .agg_calls
910            .iter()
911            .find_position(|&c| c == &plan_agg_call)
912        {
913            return Ok(InputRef::new(
914                self.schema_agg_start_offset() + pos,
915                existing.return_type.clone(),
916            ));
917        }
918        let index = self.schema_agg_start_offset() + self.agg_calls.len();
919        self.agg_calls.push(plan_agg_call);
920        Ok(InputRef::new(index, return_type))
921    }
922
923    /// When there is an agg call, there are 3 things to do:
924    /// 1. Rewrite `avg`, `var_samp`, etc. into a combination of `sum`, `count`, etc.;
925    /// 2. Add exprs in arguments to input `Project`;
926    /// 2. Add the agg call to current `Agg`, and return an `InputRef` to it.
927    ///
928    /// Note that the rewriter does not traverse into inputs of agg calls.
929    fn try_rewrite_agg_call(&mut self, mut agg_call: AggCall) -> Result<ExprImpl> {
930        if matches!(agg_call.agg_type, agg_types::must_have_order_by!())
931            && agg_call.order_by.sort_exprs.is_empty()
932        {
933            return Err(ErrorCode::InvalidInputSyntax(format!(
934                "Aggregation function {} requires ORDER BY clause",
935                agg_call.agg_type
936            ))
937            .into());
938        }
939
940        // try ignore ORDER BY if it doesn't affect the result
941        if matches!(
942            agg_call.agg_type,
943            agg_types::result_unaffected_by_order_by!()
944        ) {
945            agg_call.order_by = OrderBy::any();
946        }
947        // try ignore DISTINCT if it doesn't affect the result
948        if matches!(
949            agg_call.agg_type,
950            agg_types::result_unaffected_by_distinct!()
951        ) {
952            agg_call.distinct = false;
953        }
954
955        if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Grouping)) {
956            if self.grouping_sets.is_empty() {
957                return Err(ErrorCode::NotSupported(
958                    "GROUPING must be used in a query with grouping sets".into(),
959                    "try to use grouping sets instead".into(),
960                )
961                .into());
962            }
963            if agg_call.args.len() >= 32 {
964                return Err(ErrorCode::InvalidInputSyntax(
965                    "GROUPING must have fewer than 32 arguments".into(),
966                )
967                .into());
968            }
969            if agg_call
970                .args
971                .iter()
972                .any(|x| self.try_as_group_expr(x).is_none())
973            {
974                return Err(ErrorCode::InvalidInputSyntax(
975                    "arguments to GROUPING must be grouping expressions of the associated query level"
976                        .into(),
977                ).into());
978            }
979        }
980
981        Self::general_rewrite_agg_call(agg_call, |agg_call| self.push_agg_call(agg_call))
982    }
983}
984
985impl ExprRewriter for LogicalAggBuilder {
986    fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl {
987        let dummy = Literal::new(None, agg_call.return_type()).into();
988        match self.try_rewrite_agg_call(agg_call) {
989            Ok(expr) => expr,
990            Err(err) => {
991                self.error = Some(err);
992                dummy
993            }
994        }
995    }
996
997    /// When there is an `FunctionCall` (outside of agg call), it must refers to a group column.
998    /// Or all `InputRef`s appears in it must refer to a group column.
999    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
1000        let expr = func_call.into();
1001        if let Some(group_key) = self.try_as_group_expr(&expr) {
1002            InputRef::new(group_key, expr.return_type()).into()
1003        } else {
1004            let (func_type, inputs, ret) = expr.into_function_call().unwrap().decompose();
1005            let inputs = inputs
1006                .into_iter()
1007                .map(|expr| self.rewrite_expr(expr))
1008                .collect();
1009            FunctionCall::new_unchecked(func_type, inputs, ret).into()
1010        }
1011    }
1012
1013    /// When there is an `WindowFunction` (outside of agg call), it must refers to a group column.
1014    /// Or all `InputRef`s appears in it must refer to a group column.
1015    fn rewrite_window_function(&mut self, window_func: WindowFunction) -> ExprImpl {
1016        let WindowFunction {
1017            args,
1018            partition_by,
1019            order_by,
1020            ..
1021        } = window_func;
1022        let args = args
1023            .into_iter()
1024            .map(|expr| self.rewrite_expr(expr))
1025            .collect();
1026        let partition_by = partition_by
1027            .into_iter()
1028            .map(|expr| self.rewrite_expr(expr))
1029            .collect();
1030        let order_by = order_by.rewrite_expr(self);
1031        WindowFunction {
1032            args,
1033            partition_by,
1034            order_by,
1035            ..window_func
1036        }
1037        .into()
1038    }
1039
1040    /// When there is an `InputRef` (outside of agg call), it must refers to a group column.
1041    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
1042        let expr = input_ref.into();
1043        if let Some(group_key) = self.try_as_group_expr(&expr) {
1044            InputRef::new(group_key, expr.return_type()).into()
1045        } else if self.is_in_filter_clause {
1046            InputRef::new(
1047                self.input_proj_builder.add_expr(&expr).unwrap(),
1048                expr.return_type(),
1049            )
1050            .into()
1051        } else {
1052            self.error = Some(
1053                ErrorCode::InvalidInputSyntax(
1054                    "column must appear in the GROUP BY clause or be used in an aggregate function"
1055                        .into(),
1056                )
1057                .into(),
1058            );
1059            expr
1060        }
1061    }
1062
1063    fn rewrite_subquery(&mut self, subquery: crate::expr::Subquery) -> ExprImpl {
1064        if subquery.is_correlated_by_depth(0) {
1065            self.error = Some(
1066                not_implemented!(
1067                    issue = 2275,
1068                    "correlated subquery in HAVING or SELECT with agg",
1069                )
1070                .into(),
1071            );
1072        }
1073        subquery.into()
1074    }
1075}
1076
1077impl From<Agg<PlanRef>> for LogicalAgg {
1078    fn from(core: Agg<PlanRef>) -> Self {
1079        let base = PlanBase::new_logical_with_core(&core);
1080        Self { base, core }
1081    }
1082}
1083
1084/// Because `From`/`Into` are not transitive
1085impl From<Agg<PlanRef>> for PlanRef {
1086    fn from(core: Agg<PlanRef>) -> Self {
1087        LogicalAgg::from(core).into()
1088    }
1089}
1090
1091impl LogicalAgg {
1092    /// get the Mapping of columnIndex from input column index to out column index
1093    pub fn i2o_col_mapping(&self) -> ColIndexMapping {
1094        self.core.i2o_col_mapping()
1095    }
1096
1097    /// `create` will analyze select exprs, group exprs and having, and construct a plan like
1098    ///
1099    /// ```text
1100    /// LogicalAgg -> LogicalProject -> input
1101    /// ```
1102    ///
1103    /// It also returns the rewritten select exprs and having that reference into the aggregated
1104    /// results.
1105    pub fn create(
1106        select_exprs: Vec<ExprImpl>,
1107        group_by: GroupBy,
1108        having: Option<ExprImpl>,
1109        input: PlanRef,
1110    ) -> Result<(PlanRef, Vec<ExprImpl>, Option<ExprImpl>)> {
1111        let mut agg_builder = LogicalAggBuilder::new(group_by, input.schema().len())?;
1112
1113        let rewritten_select_exprs = select_exprs
1114            .into_iter()
1115            .map(|expr| agg_builder.rewrite_with_error(expr))
1116            .collect::<Result<_>>()?;
1117        let rewritten_having = having
1118            .map(|expr| agg_builder.rewrite_with_error(expr))
1119            .transpose()?;
1120
1121        Ok((
1122            agg_builder.build(input).into(),
1123            rewritten_select_exprs,
1124            rewritten_having,
1125        ))
1126    }
1127
1128    /// Get a reference to the logical agg's agg calls.
1129    pub fn agg_calls(&self) -> &Vec<PlanAggCall> {
1130        &self.core.agg_calls
1131    }
1132
1133    /// Get a reference to the logical agg's group key.
1134    pub fn group_key(&self) -> &IndexSet {
1135        &self.core.group_key
1136    }
1137
1138    pub fn grouping_sets(&self) -> &Vec<IndexSet> {
1139        &self.core.grouping_sets
1140    }
1141
1142    pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) {
1143        self.core.decompose()
1144    }
1145
1146    #[must_use]
1147    pub fn rewrite_with_input_agg(
1148        &self,
1149        input: PlanRef,
1150        agg_calls: &[PlanAggCall],
1151        mut input_col_change: ColIndexMapping,
1152    ) -> (Self, ColIndexMapping) {
1153        let agg_calls = agg_calls
1154            .iter()
1155            .cloned()
1156            .map(|mut agg_call| {
1157                agg_call.inputs.iter_mut().for_each(|i| {
1158                    *i = InputRef::new(input_col_change.map(i.index()), i.return_type())
1159                });
1160                agg_call.order_by.iter_mut().for_each(|o| {
1161                    o.column_index = input_col_change.map(o.column_index);
1162                });
1163                agg_call.filter = agg_call.filter.rewrite_expr(&mut input_col_change);
1164                agg_call
1165            })
1166            .collect();
1167        // This is the group key order should be after rewriting.
1168        let group_key_in_vec: Vec<usize> = self
1169            .group_key()
1170            .indices()
1171            .map(|key| input_col_change.map(key))
1172            .collect();
1173        // This is the group key order we get after rewriting.
1174        let group_key: IndexSet = group_key_in_vec.iter().cloned().collect();
1175        let grouping_sets = self
1176            .grouping_sets()
1177            .iter()
1178            .map(|set| set.indices().map(|key| input_col_change.map(key)).collect())
1179            .collect();
1180
1181        let new_agg = Agg::new(agg_calls, group_key.clone(), input)
1182            .with_grouping_sets(grouping_sets)
1183            .with_enable_two_phase(self.core().enable_two_phase);
1184
1185        // group_key remapping might cause an output column change, since group key actually is a
1186        // `FixedBitSet`.
1187        let mut out_col_change = vec![];
1188        for idx in group_key_in_vec {
1189            let pos = group_key.indices().position(|x| x == idx).unwrap();
1190            out_col_change.push(pos);
1191        }
1192        for i in (group_key.len())..new_agg.schema().len() {
1193            out_col_change.push(i);
1194        }
1195        let out_col_change =
1196            ColIndexMapping::with_remaining_columns(&out_col_change, new_agg.schema().len());
1197
1198        (new_agg.into(), out_col_change)
1199    }
1200}
1201
1202impl PlanTreeNodeUnary<Logical> for LogicalAgg {
1203    fn input(&self) -> PlanRef {
1204        self.core.input.clone()
1205    }
1206
1207    fn clone_with_input(&self, input: PlanRef) -> Self {
1208        Agg::new(self.agg_calls().clone(), self.group_key().clone(), input)
1209            .with_grouping_sets(self.grouping_sets().clone())
1210            .with_enable_two_phase(self.core().enable_two_phase)
1211            .into()
1212    }
1213
1214    fn rewrite_with_input(
1215        &self,
1216        input: PlanRef,
1217        input_col_change: ColIndexMapping,
1218    ) -> (Self, ColIndexMapping) {
1219        self.rewrite_with_input_agg(input, self.agg_calls(), input_col_change)
1220    }
1221}
1222
// Presumably derives the generic plan-tree-node impls for `LogicalAgg` from its
// `PlanTreeNodeUnary<Logical>` impl — TODO confirm against the macro definition.
impl_plan_tree_node_for_unary! { Logical, LogicalAgg }
// Explain/distill support rendered from the `core` field under the label "LogicalAgg"
// (NOTE(review): exact output format is defined by `impl_distill_by_unit`).
impl_distill_by_unit!(LogicalAgg, core, "LogicalAgg");
1225
1226impl ExprRewritable<Logical> for LogicalAgg {
1227    fn has_rewritable_expr(&self) -> bool {
1228        true
1229    }
1230
1231    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
1232        let mut core = self.core.clone();
1233        core.rewrite_exprs(r);
1234        Self {
1235            base: self.base.clone_with_new_plan_id(),
1236            core,
1237        }
1238        .into()
1239    }
1240}
1241
1242impl ExprVisitable for LogicalAgg {
1243    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
1244        self.core.visit_exprs(v);
1245    }
1246}
1247
1248impl ColPrunable for LogicalAgg {
1249    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
1250        let group_key_required_cols = self.group_key().to_bitset();
1251
1252        let (agg_call_required_cols, agg_calls) = {
1253            let input_cnt = self.input().schema().len();
1254            let mut tmp = FixedBitSet::with_capacity(input_cnt);
1255            let group_key_cardinality = self.group_key().len();
1256            let new_agg_calls = required_cols
1257                .iter()
1258                .filter(|&&index| index >= group_key_cardinality)
1259                .map(|&index| {
1260                    let index = index - group_key_cardinality;
1261                    let agg_call = self.agg_calls()[index].clone();
1262                    tmp.extend(agg_call.inputs.iter().map(|x| x.index()));
1263                    tmp.extend(agg_call.order_by.iter().map(|x| x.column_index));
1264                    // collect columns used in aggregate filter expressions
1265                    for i in &agg_call.filter.conjunctions {
1266                        tmp.union_with(&i.collect_input_refs(input_cnt));
1267                    }
1268                    agg_call
1269                })
1270                .collect_vec();
1271            (tmp, new_agg_calls)
1272        };
1273
1274        let input_required_cols = {
1275            let mut tmp = FixedBitSet::with_capacity(self.input().schema().len());
1276            tmp.union_with(&group_key_required_cols);
1277            tmp.union_with(&agg_call_required_cols);
1278            tmp.ones().collect_vec()
1279        };
1280        let input_col_change = ColIndexMapping::with_remaining_columns(
1281            &input_required_cols,
1282            self.input().schema().len(),
1283        );
1284        let agg = {
1285            let input = self.input().prune_col(&input_required_cols, ctx);
1286            let (agg, output_col_change) =
1287                self.rewrite_with_input_agg(input, &agg_calls, input_col_change);
1288            assert!(output_col_change.is_identity());
1289            agg
1290        };
1291        let new_output_cols = {
1292            // group key were never pruned or even re-ordered in current impl
1293            let group_key_cardinality = agg.group_key().len();
1294            let mut tmp = (0..group_key_cardinality).collect_vec();
1295            tmp.extend(
1296                required_cols
1297                    .iter()
1298                    .filter(|&&index| index >= group_key_cardinality),
1299            );
1300            tmp
1301        };
1302        if new_output_cols == required_cols {
1303            // current schema perfectly fit the required columns
1304            agg.into()
1305        } else {
1306            // some columns are not needed, or the order need to be adjusted.
1307            // so we did a projection to remove/reorder the columns.
1308            let mapping =
1309                &ColIndexMapping::with_remaining_columns(&new_output_cols, self.schema().len());
1310            let output_required_cols = required_cols
1311                .iter()
1312                .map(|&idx| mapping.map(idx))
1313                .collect_vec();
1314            let src_size = agg.schema().len();
1315            LogicalProject::with_mapping(
1316                agg.into(),
1317                ColIndexMapping::with_remaining_columns(&output_required_cols, src_size),
1318            )
1319            .into()
1320        }
1321    }
1322}
1323
1324impl PredicatePushdown for LogicalAgg {
1325    fn predicate_pushdown(
1326        &self,
1327        predicate: Condition,
1328        ctx: &mut PredicatePushdownContext,
1329    ) -> PlanRef {
1330        let num_group_key = self.group_key().len();
1331        let num_agg_calls = self.agg_calls().len();
1332        assert!(num_group_key + num_agg_calls == self.schema().len());
1333
1334        // SimpleAgg should be skipped because the predicate either references agg_calls
1335        // or is const.
1336        // If the filter references agg_calls, we can not push it.
1337        // When it is constantly true, pushing is useless and may actually cause more evaluation
1338        // cost of the predicate.
1339        // When it is constantly false, pushing is wrong - the old plan returns 0 rows but new one
1340        // returns 1 row.
1341        if num_group_key == 0 {
1342            return gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx);
1343        }
1344
1345        // If the filter references agg_calls, we can not push it.
1346        let mut agg_call_columns = FixedBitSet::with_capacity(num_group_key + num_agg_calls);
1347        agg_call_columns.insert_range(num_group_key..num_group_key + num_agg_calls);
1348        let (agg_call_pred, pushed_predicate) = predicate.split_disjoint(&agg_call_columns);
1349
1350        // convert the predicate to one that references the child of the agg
1351        let mut subst = Substitute {
1352            mapping: self
1353                .group_key()
1354                .indices()
1355                .enumerate()
1356                .map(|(i, group_key)| {
1357                    InputRef::new(group_key, self.schema().fields()[i].data_type()).into()
1358                })
1359                .collect(),
1360        };
1361        let pushed_predicate = pushed_predicate.rewrite_expr(&mut subst);
1362
1363        gen_filter_and_pushdown(self, agg_call_pred, pushed_predicate, ctx)
1364    }
1365}
1366
1367impl ToBatch for LogicalAgg {
1368    fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1369        self.to_batch_with_order_required(&Order::any())
1370    }
1371
1372    // TODO(rc): `to_batch_with_order_required` seems to be useless after we decide to use
1373    // `BatchSortAgg` only when input is already sorted
1374    fn to_batch_with_order_required(
1375        &self,
1376        required_order: &Order,
1377    ) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1378        let input = self.input().to_batch()?;
1379        let new_logical = self.core.clone_with_input(input);
1380        let agg_plan = if self.group_key().is_empty() {
1381            BatchSimpleAgg::new(new_logical).into()
1382        } else if self.ctx().session_ctx().config().batch_enable_sort_agg()
1383            && new_logical.input_provides_order_on_group_keys()
1384        {
1385            BatchSortAgg::new(new_logical).into()
1386        } else {
1387            BatchHashAgg::new(new_logical).into()
1388        };
1389        required_order.enforce_if_not_satisfies(agg_plan)
1390    }
1391}
1392
1393fn find_or_append_row_count(mut logical: Agg<StreamPlanRef>) -> (Agg<StreamPlanRef>, usize) {
1394    // `HashAgg`/`SimpleAgg` executors require a `count(*)` to correctly build changes, so
1395    // append a `count(*)` if not exists.
1396    let count_star = PlanAggCall::count_star();
1397    let row_count_idx = if let Some((idx, _)) = logical
1398        .agg_calls
1399        .iter()
1400        .find_position(|&c| c == &count_star)
1401    {
1402        idx
1403    } else {
1404        let idx = logical.agg_calls.len();
1405        logical.agg_calls.push(count_star);
1406        idx
1407    };
1408    (logical, row_count_idx)
1409}
1410
1411fn new_stream_simple_agg(
1412    core: Agg<StreamPlanRef>,
1413    must_output_per_barrier: bool,
1414) -> Result<StreamSimpleAgg> {
1415    let (logical, row_count_idx) = find_or_append_row_count(core);
1416    StreamSimpleAgg::new(logical, row_count_idx, must_output_per_barrier)
1417}
1418
1419fn new_stream_hash_agg(
1420    core: Agg<StreamPlanRef>,
1421    vnode_col_idx: Option<usize>,
1422) -> Result<StreamHashAgg> {
1423    let (logical, row_count_idx) = find_or_append_row_count(core);
1424    StreamHashAgg::new(logical, vnode_col_idx, row_count_idx)
1425}
1426
1427impl ToStream for LogicalAgg {
1428    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<StreamPlanRef> {
1429        use super::stream::prelude::*;
1430
1431        let eowc = ctx.emit_on_window_close();
1432        let input = self.input();
1433
1434        let stream_input = input.to_stream(ctx)?;
1435
1436        // Use Dedup operator, if possible.
1437        if stream_input.append_only() && self.agg_calls().is_empty() && !self.group_key().is_empty()
1438        {
1439            let group_key = self.group_key().to_vec();
1440            let input_schema_len = input.schema().len();
1441            let dedup: PlanRef = LogicalDedup::new(input, group_key.clone()).into();
1442            let project = LogicalProject::with_mapping(
1443                dedup,
1444                ColIndexMapping::with_remaining_columns(&group_key, input_schema_len),
1445            );
1446            return project.to_stream(ctx);
1447        }
1448
1449        if self.agg_calls().iter().any(|call| {
1450            matches!(
1451                call.agg_type,
1452                AggType::Builtin(PbAggKind::ApproxCountDistinct)
1453            )
1454        }) {
1455            if stream_input.append_only() {
1456                self.core.ctx().session_ctx().notice_to_user(
1457                    "Streaming `APPROX_COUNT_DISTINCT` is still a preview feature and subject to change. Please do not use it in production environment.",
1458                );
1459            } else {
1460                bail_not_implemented!(
1461                    "Streaming `APPROX_COUNT_DISTINCT` is only supported in append-only stream"
1462                );
1463            }
1464        }
1465
1466        let plan = self.gen_dist_stream_agg_plan(stream_input)?;
1467
1468        let (plan, n_final_agg_calls) = if let Some(final_agg) = plan.as_stream_simple_agg() {
1469            if eowc {
1470                return Err(ErrorCode::InvalidInputSyntax(
1471                    "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1472                        .to_owned(),
1473                )
1474                .into());
1475            }
1476            (plan.clone(), final_agg.agg_calls().len())
1477        } else if let Some(final_agg) = plan.as_stream_hash_agg() {
1478            (
1479                if eowc {
1480                    final_agg.to_eowc_version()?
1481                } else {
1482                    plan.clone()
1483                },
1484                final_agg.agg_calls().len(),
1485            )
1486        } else if let Some(_approx_percentile_agg) = plan.as_stream_global_approx_percentile() {
1487            if eowc {
1488                return Err(ErrorCode::InvalidInputSyntax(
1489                    "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1490                        .to_owned(),
1491                )
1492                .into());
1493            }
1494            (plan.clone(), 1)
1495        } else if let Some(stream_row_merge) = plan.as_stream_row_merge() {
1496            if eowc {
1497                return Err(ErrorCode::InvalidInputSyntax(
1498                    "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1499                        .to_owned(),
1500                )
1501                .into());
1502            }
1503            (plan.clone(), stream_row_merge.base.schema().len())
1504        } else {
1505            panic!(
1506                "the root PlanNode must be StreamHashAgg, StreamSimpleAgg, StreamGlobalApproxPercentile, or StreamRowMerge"
1507            );
1508        };
1509
1510        if self.agg_calls().len() == n_final_agg_calls {
1511            // an existing `count(*)` is used as row count column in `StreamXxxAgg`
1512            Ok(plan)
1513        } else {
1514            // a `count(*)` is appended, should project the output
1515            assert_eq!(self.agg_calls().len() + 1, n_final_agg_calls);
1516
1517            let mut project = StreamProject::new(generic::Project::with_out_col_idx(
1518                plan,
1519                0..self.schema().len(),
1520            ));
1521            // If there's no agg call, then `count(*)` will be the only column in the output besides keys.
1522            // Since it'll be pruned immediately in `StreamProject`, the update records are likely to be
1523            // no-op. So we set the hint to instruct the executor to eliminate them.
1524            // See https://github.com/risingwavelabs/risingwave/issues/17030.
1525            if self.agg_calls().is_empty() {
1526                project = project.with_noop_update_hint(true);
1527            }
1528            Ok(project.into())
1529        }
1530    }
1531
1532    fn try_better_locality(&self, columns: &[usize]) -> Option<PlanRef> {
1533        if columns.is_empty() {
1534            return None;
1535        }
1536
1537        // Check if the given columns are a prefix of group keys.
1538        let group_key = self.group_key().to_vec();
1539        if columns.len() > group_key.len() || columns != &group_key[..columns.len()] {
1540            return None;
1541        }
1542
1543        // Return the same plan directly without calling `try_better_locality` on input.
1544        // Because in `logical_rewrite_for_stream`, we will enforce the locality requirement on the group keys anyway.
1545        // If we call `try_better_locality` on input, it would miss the chance to utilize the locality of the current agg,
1546        // since the agg's input doesn't have the locality yet at that moment.
1547        Some(self.clone_with_input(self.input()).into())
1548    }
1549
1550    fn logical_rewrite_for_stream(
1551        &self,
1552        ctx: &mut RewriteStreamContext,
1553    ) -> Result<(PlanRef, ColIndexMapping)> {
1554        let logical_input = if self.group_key().is_empty() {
1555            self.input()
1556        } else {
1557            try_enforce_locality_requirement(self.input(), &self.group_key().to_vec())
1558        };
1559        let (input, input_col_change) = logical_input.logical_rewrite_for_stream(ctx)?;
1560        let (agg, out_col_change) = self.rewrite_with_input(input, input_col_change);
1561        let (map, _) = out_col_change.into_parts();
1562        let out_col_change = ColIndexMapping::new(map, agg.schema().len());
1563        Ok((agg.into(), out_col_change))
1564    }
1565}
1566
1567#[cfg(test)]
1568mod tests {
1569    use risingwave_common::catalog::{Field, Schema};
1570
1571    use super::*;
1572    use crate::expr::{assert_eq_input_ref, input_ref_to_column_indices};
1573    use crate::optimizer::optimizer_context::OptimizerContext;
1574    use crate::optimizer::plan_node::LogicalValues;
1575
1576    #[tokio::test]
1577    async fn test_create() {
1578        let ty = DataType::Int32;
1579        let ctx = OptimizerContext::mock();
1580        let fields: Vec<Field> = vec![
1581            Field::with_name(ty.clone(), "v1"),
1582            Field::with_name(ty.clone(), "v2"),
1583            Field::with_name(ty.clone(), "v3"),
1584        ];
1585        let values = LogicalValues::new(vec![], Schema { fields }, ctx);
1586        let input = PlanRef::from(values);
1587        let input_ref_1 = InputRef::new(0, ty.clone());
1588        let input_ref_2 = InputRef::new(1, ty.clone());
1589        let input_ref_3 = InputRef::new(2, ty);
1590
1591        let gen_internal_value = |select_exprs: Vec<ExprImpl>,
1592                                  group_exprs|
1593         -> (Vec<ExprImpl>, Vec<PlanAggCall>, IndexSet) {
1594            let (plan, exprs, _) = LogicalAgg::create(
1595                select_exprs,
1596                GroupBy::GroupKey(group_exprs),
1597                None,
1598                input.clone(),
1599            )
1600            .unwrap();
1601
1602            let logical_agg = plan.as_logical_agg().unwrap();
1603            let agg_calls = logical_agg.agg_calls().clone();
1604            let group_key = logical_agg.group_key().clone();
1605
1606            (exprs, agg_calls, group_key)
1607        };
1608
1609        // Test case: select v1 from test group by v1;
1610        {
1611            let select_exprs = vec![input_ref_1.clone().into()];
1612            let group_exprs = vec![input_ref_1.clone().into()];
1613
1614            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
1615
1616            assert_eq!(exprs.len(), 1);
1617            assert_eq_input_ref!(&exprs[0], 0);
1618
1619            assert_eq!(agg_calls.len(), 0);
1620            assert_eq!(group_key, vec![0].into());
1621        }
1622
1623        // Test case: select v1, min(v2) from test group by v1;
1624        {
1625            let min_v2 = AggCall::new(
1626                PbAggKind::Min.into(),
1627                vec![input_ref_2.clone().into()],
1628                false,
1629                OrderBy::any(),
1630                Condition::true_cond(),
1631                vec![],
1632            )
1633            .unwrap();
1634            let select_exprs = vec![input_ref_1.clone().into(), min_v2.into()];
1635            let group_exprs = vec![input_ref_1.clone().into()];
1636
1637            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
1638
1639            assert_eq!(exprs.len(), 2);
1640            assert_eq_input_ref!(&exprs[0], 0);
1641            assert_eq_input_ref!(&exprs[1], 1);
1642
1643            assert_eq!(agg_calls.len(), 1);
1644            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
1645            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
1646            assert_eq!(group_key, vec![0].into());
1647        }
1648
1649        // Test case: select v1, min(v2) + max(v3) from t group by v1;
1650        {
1651            let min_v2 = AggCall::new(
1652                PbAggKind::Min.into(),
1653                vec![input_ref_2.clone().into()],
1654                false,
1655                OrderBy::any(),
1656                Condition::true_cond(),
1657                vec![],
1658            )
1659            .unwrap();
1660            let max_v3 = AggCall::new(
1661                PbAggKind::Max.into(),
1662                vec![input_ref_3.clone().into()],
1663                false,
1664                OrderBy::any(),
1665                Condition::true_cond(),
1666                vec![],
1667            )
1668            .unwrap();
1669            let func_call =
1670                FunctionCall::new(ExprType::Add, vec![min_v2.into(), max_v3.into()]).unwrap();
1671            let select_exprs = vec![input_ref_1.clone().into(), ExprImpl::from(func_call)];
1672            let group_exprs = vec![input_ref_1.clone().into()];
1673
1674            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
1675
1676            assert_eq_input_ref!(&exprs[0], 0);
1677            if let ExprImpl::FunctionCall(func_call) = &exprs[1] {
1678                assert_eq!(func_call.func_type(), ExprType::Add);
1679                let inputs = func_call.inputs();
1680                assert_eq_input_ref!(&inputs[0], 1);
1681                assert_eq_input_ref!(&inputs[1], 2);
1682            } else {
1683                panic!("Wrong expression type!");
1684            }
1685
1686            assert_eq!(agg_calls.len(), 2);
1687            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
1688            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
1689            assert_eq!(agg_calls[1].agg_type, PbAggKind::Max.into());
1690            assert_eq!(input_ref_to_column_indices(&agg_calls[1].inputs), vec![2]);
1691            assert_eq!(group_key, vec![0].into());
1692        }
1693
1694        // Test case: select v2, min(v1 * v3) from test group by v2;
1695        {
1696            let v1_mult_v3 = FunctionCall::new(
1697                ExprType::Multiply,
1698                vec![input_ref_1.into(), input_ref_3.into()],
1699            )
1700            .unwrap();
1701            let agg_call = AggCall::new(
1702                PbAggKind::Min.into(),
1703                vec![v1_mult_v3.into()],
1704                false,
1705                OrderBy::any(),
1706                Condition::true_cond(),
1707                vec![],
1708            )
1709            .unwrap();
1710            let select_exprs = vec![input_ref_2.clone().into(), agg_call.into()];
1711            let group_exprs = vec![input_ref_2.into()];
1712
1713            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
1714
1715            assert_eq_input_ref!(&exprs[0], 0);
1716            assert_eq_input_ref!(&exprs[1], 1);
1717
1718            assert_eq!(agg_calls.len(), 1);
1719            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
1720            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
1721            assert_eq!(group_key, vec![0].into());
1722        }
1723    }
1724
1725    /// Generate a agg call node with given [`DataType`] and fields.
1726    /// For example, `generate_agg_call(Int32, [v1, v2, v3])` will result in:
1727    /// ```text
1728    /// Agg(min(input_ref(2))) group by (input_ref(1))
1729    ///   TableScan(v1, v2, v3)
1730    /// ```
1731    fn generate_agg_call(ty: DataType, fields: Vec<Field>) -> LogicalAgg {
1732        let ctx = OptimizerContext::mock();
1733
1734        let values = LogicalValues::new(vec![], Schema { fields }, ctx);
1735        let agg_call = PlanAggCall {
1736            agg_type: PbAggKind::Min.into(),
1737            return_type: ty.clone(),
1738            inputs: vec![InputRef::new(2, ty)],
1739            distinct: false,
1740            order_by: vec![],
1741            filter: Condition::true_cond(),
1742            direct_args: vec![],
1743        };
1744        Agg::new(vec![agg_call], vec![1].into(), values.into()).into()
1745    }
1746
1747    #[tokio::test]
1748    /// Pruning
1749    /// ```text
1750    /// Agg(min(input_ref(2))) group by (input_ref(1))
1751    ///   TableScan(v1, v2, v3)
1752    /// ```
1753    /// with required columns [0,1] (all columns) will result in
1754    /// ```text
1755    /// Agg(min(input_ref(1))) group by (input_ref(0))
1756    ///  TableScan(v2, v3)
1757    /// ```
1758    async fn test_prune_all() {
1759        let ty = DataType::Int32;
1760        let fields: Vec<Field> = vec![
1761            Field::with_name(ty.clone(), "v1"),
1762            Field::with_name(ty.clone(), "v2"),
1763            Field::with_name(ty.clone(), "v3"),
1764        ];
1765        let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).into();
1766        // Perform the prune
1767        let required_cols = vec![0, 1];
1768        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
1769
1770        // Check the result
1771        let agg_new = plan.as_logical_agg().unwrap();
1772        assert_eq!(agg_new.group_key(), &vec![0].into());
1773
1774        assert_eq!(agg_new.agg_calls().len(), 1);
1775        let agg_call_new = agg_new.agg_calls()[0].clone();
1776        assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
1777        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
1778        assert_eq!(agg_call_new.return_type, ty);
1779
1780        let values = agg_new.input();
1781        let values = values.as_logical_values().unwrap();
1782        assert_eq!(values.schema().fields(), &fields[1..]);
1783    }
1784
1785    #[tokio::test]
1786    /// Pruning
1787    /// ```text
1788    /// Agg(min(input_ref(2))) group by (input_ref(1))
1789    ///   TableScan(v1, v2, v3)
1790    /// ```
1791    /// with required columns [1,0] (all columns, with reversed order) will result in
1792    /// ```text
1793    /// Project [input_ref(1), input_ref(0)]
1794    ///   Agg(min(input_ref(1))) group by (input_ref(0))
1795    ///     TableScan(v2, v3)
1796    /// ```
1797    async fn test_prune_all_with_order_required() {
1798        let ty = DataType::Int32;
1799        let fields: Vec<Field> = vec![
1800            Field::with_name(ty.clone(), "v1"),
1801            Field::with_name(ty.clone(), "v2"),
1802            Field::with_name(ty.clone(), "v3"),
1803        ];
1804        let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).into();
1805        // Perform the prune
1806        let required_cols = vec![1, 0];
1807        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
1808        // Check the result
1809        let proj = plan.as_logical_project().unwrap();
1810        assert_eq!(proj.exprs().len(), 2);
1811        assert_eq!(proj.exprs()[0].as_input_ref().unwrap().index(), 1);
1812        assert_eq!(proj.exprs()[1].as_input_ref().unwrap().index(), 0);
1813        let proj_input = proj.input();
1814        let agg_new = proj_input.as_logical_agg().unwrap();
1815        assert_eq!(agg_new.group_key(), &vec![0].into());
1816
1817        assert_eq!(agg_new.agg_calls().len(), 1);
1818        let agg_call_new = agg_new.agg_calls()[0].clone();
1819        assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
1820        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
1821        assert_eq!(agg_call_new.return_type, ty);
1822
1823        let values = agg_new.input();
1824        let values = values.as_logical_values().unwrap();
1825        assert_eq!(values.schema().fields(), &fields[1..]);
1826    }
1827
1828    #[tokio::test]
1829    /// Pruning
1830    /// ```text
1831    /// Agg(min(input_ref(2))) group by (input_ref(1))
1832    ///   TableScan(v1, v2, v3)
1833    /// ```
1834    /// with required columns [1] (group key removed) will result in
1835    /// ```text
1836    /// Project(input_ref(1))
1837    ///   Agg(min(input_ref(1))) group by (input_ref(0))
1838    ///     TableScan(v2, v3)
1839    /// ```
1840    async fn test_prune_group_key() {
1841        let ctx = OptimizerContext::mock();
1842        let ty = DataType::Int32;
1843        let fields: Vec<Field> = vec![
1844            Field::with_name(ty.clone(), "v1"),
1845            Field::with_name(ty.clone(), "v2"),
1846            Field::with_name(ty.clone(), "v3"),
1847        ];
1848        let values: LogicalValues = LogicalValues::new(
1849            vec![],
1850            Schema {
1851                fields: fields.clone(),
1852            },
1853            ctx,
1854        );
1855        let agg_call = PlanAggCall {
1856            agg_type: PbAggKind::Min.into(),
1857            return_type: ty.clone(),
1858            inputs: vec![InputRef::new(2, ty.clone())],
1859            distinct: false,
1860            order_by: vec![],
1861            filter: Condition::true_cond(),
1862            direct_args: vec![],
1863        };
1864        let agg: PlanRef = Agg::new(vec![agg_call], vec![1].into(), values.into()).into();
1865
1866        // Perform the prune
1867        let required_cols = vec![1];
1868        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
1869
1870        // Check the result
1871        let project = plan.as_logical_project().unwrap();
1872        assert_eq!(project.exprs().len(), 1);
1873        assert_eq_input_ref!(&project.exprs()[0], 1);
1874
1875        let agg_new = project.input();
1876        let agg_new = agg_new.as_logical_agg().unwrap();
1877        assert_eq!(agg_new.group_key(), &vec![0].into());
1878
1879        assert_eq!(agg_new.agg_calls().len(), 1);
1880        let agg_call_new = agg_new.agg_calls()[0].clone();
1881        assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
1882        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
1883        assert_eq!(agg_call_new.return_type, ty);
1884
1885        let values = agg_new.input();
1886        let values = values.as_logical_values().unwrap();
1887        assert_eq!(values.schema().fields(), &fields[1..]);
1888    }
1889
1890    #[tokio::test]
1891    /// Pruning
1892    /// ```text
1893    /// Agg(min(input_ref(2)), max(input_ref(1))) group by (input_ref(1), input_ref(2))
1894    ///   TableScan(v1, v2, v3)
1895    /// ```
1896    /// with required columns [0,3] will result in
1897    /// ```text
1898    /// Project(input_ref(0), input_ref(2))
1899    ///   Agg(max(input_ref(0))) group by (input_ref(0), input_ref(1))
1900    ///     TableScan(v2, v3)
1901    /// ```
1902    async fn test_prune_agg() {
1903        let ty = DataType::Int32;
1904        let ctx = OptimizerContext::mock();
1905        let fields: Vec<Field> = vec![
1906            Field::with_name(ty.clone(), "v1"),
1907            Field::with_name(ty.clone(), "v2"),
1908            Field::with_name(ty.clone(), "v3"),
1909        ];
1910        let values = LogicalValues::new(
1911            vec![],
1912            Schema {
1913                fields: fields.clone(),
1914            },
1915            ctx,
1916        );
1917
1918        let agg_calls = vec![
1919            PlanAggCall {
1920                agg_type: PbAggKind::Min.into(),
1921                return_type: ty.clone(),
1922                inputs: vec![InputRef::new(2, ty.clone())],
1923                distinct: false,
1924                order_by: vec![],
1925                filter: Condition::true_cond(),
1926                direct_args: vec![],
1927            },
1928            PlanAggCall {
1929                agg_type: PbAggKind::Max.into(),
1930                return_type: ty.clone(),
1931                inputs: vec![InputRef::new(1, ty.clone())],
1932                distinct: false,
1933                order_by: vec![],
1934                filter: Condition::true_cond(),
1935                direct_args: vec![],
1936            },
1937        ];
1938        let agg: PlanRef = Agg::new(agg_calls, vec![1, 2].into(), values.into()).into();
1939
1940        // Perform the prune
1941        let required_cols = vec![0, 3];
1942        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
1943        // Check the result
1944        let project = plan.as_logical_project().unwrap();
1945        assert_eq!(project.exprs().len(), 2);
1946        assert_eq_input_ref!(&project.exprs()[0], 0);
1947        assert_eq_input_ref!(&project.exprs()[1], 2);
1948
1949        let agg_new = project.input();
1950        let agg_new = agg_new.as_logical_agg().unwrap();
1951        assert_eq!(agg_new.group_key(), &vec![0, 1].into());
1952
1953        assert_eq!(agg_new.agg_calls().len(), 1);
1954        let agg_call_new = agg_new.agg_calls()[0].clone();
1955        assert_eq!(agg_call_new.agg_type, PbAggKind::Max.into());
1956        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![0]);
1957        assert_eq!(agg_call_new.return_type, ty);
1958
1959        let values = agg_new.input();
1960        let values = values.as_logical_values().unwrap();
1961        assert_eq!(values.schema().fields(), &fields[1..]);
1962    }
1963}