risingwave_frontend/optimizer/plan_node/logical_agg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::types::{DataType, ScalarImpl};
18use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
19use risingwave_common::{bail, bail_not_implemented, not_implemented};
20use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};
21
22use super::generic::{self, Agg, GenericPlanRef, PlanAggCall, ProjectBuilder};
23use super::utils::impl_distill_by_unit;
24use super::{
25    BatchHashAgg, BatchSimpleAgg, ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef,
26    PlanTreeNodeUnary, PredicatePushdown, StreamHashAgg, StreamProject, StreamShare,
27    StreamSimpleAgg, StreamStatelessSimpleAgg, ToBatch, ToStream,
28};
29use crate::error::{ErrorCode, Result, RwError};
30use crate::expr::{
31    AggCall, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, Literal,
32    OrderBy, WindowFunction,
33};
34use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
35use crate::optimizer::plan_node::generic::GenericPlanNode;
36use crate::optimizer::plan_node::stream_global_approx_percentile::StreamGlobalApproxPercentile;
37use crate::optimizer::plan_node::stream_local_approx_percentile::StreamLocalApproxPercentile;
38use crate::optimizer::plan_node::stream_row_merge::StreamRowMerge;
39use crate::optimizer::plan_node::{
40    BatchSortAgg, ColumnPruningContext, LogicalDedup, LogicalProject, PredicatePushdownContext,
41    RewriteStreamContext, ToStreamContext, gen_filter_and_pushdown,
42};
43use crate::optimizer::property::{Distribution, Order, RequiredDist};
44use crate::utils::{
45    ColIndexMapping, ColIndexMappingRewriteExt, Condition, GroupBy, IndexSet, Substitute,
46};
47
48pub struct AggInfo {
49    pub calls: Vec<PlanAggCall>,
50    pub col_mapping: ColIndexMapping,
51}
52
53/// `SeparatedAggInfo` is used to separate normal and approx percentile aggs.
54pub struct SeparatedAggInfo {
55    normal: AggInfo,
56    approx: AggInfo,
57}
58
59/// `LogicalAgg` groups input data by their group key and computes aggregation functions.
60///
61/// It corresponds to the `GROUP BY` operator in a SQL query statement together with the aggregate
62/// functions in the `SELECT` clause.
63///
64/// The output schema will first include the group key and then the aggregation calls.
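///
/// For example (an illustrative sketch, not verified plan output), a query like
/// `SELECT v1, count(v2) FROM t GROUP BY v1` is planned roughly as:
///
/// ```text
/// LogicalAgg { group_key: [v1], aggs: [count(v2)] }
///   └─ LogicalProject { exprs: [v1, v2] }
///        └─ LogicalScan { table: t }
/// ```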
65#[derive(Clone, Debug, PartialEq, Eq, Hash)]
66pub struct LogicalAgg {
67    pub base: PlanBase<Logical>,
68    core: Agg<PlanRef>,
69}
70
71impl LogicalAgg {
72    /// Generate plan for stateless 2-phase streaming agg.
73    /// Should only be used if the input is distributed. The input must already be converted to stream form.
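    ///
    /// The resulting plan is roughly (a sketch; exact agg calls and columns may differ):
    ///
    /// ```text
    /// StreamSimpleAgg (global, e.g. sum0(count), sum(sum))
    ///   └─ Exchange (single)
    ///        └─ StreamStatelessSimpleAgg (local, e.g. count, sum)
    ///             └─ stream input
    /// ```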
74    fn gen_stateless_two_phase_streaming_agg_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
75        debug_assert!(self.group_key().is_empty());
76        let mut core = self.core.clone();
77
78        // ====== Handle approx percentile aggs
79        let (non_approx_percentile_col_mapping, approx_percentile_col_mapping, approx_percentile) =
80            self.prepare_approx_percentile(&mut core, stream_input.clone())?;
81
82        if core.agg_calls.is_empty() {
83            if let Some(approx_percentile) = approx_percentile {
84                return Ok(approx_percentile);
85            };
86            bail!("expected at least one agg call");
87        }
88
89        let need_row_merge: bool = Self::need_row_merge(&approx_percentile);
90
91        // ====== Handle normal aggs
92        let total_agg_calls = core
93            .agg_calls
94            .iter()
95            .enumerate()
96            .map(|(partial_output_idx, agg_call)| {
97                agg_call.partial_to_total_agg_call(partial_output_idx)
98            })
99            .collect_vec();
100        let local_agg = StreamStatelessSimpleAgg::new(core);
101        let exchange =
102            RequiredDist::single().enforce_if_not_satisfies(local_agg.into(), &Order::any())?;
103
104        let must_output_per_barrier = need_row_merge;
105        let global_agg = new_stream_simple_agg(
106            Agg::new(total_agg_calls, IndexSet::empty(), exchange),
107            must_output_per_barrier,
108        );
109
110        // ====== Merge approx percentile and normal aggs
111        Self::add_row_merge_if_needed(
112            approx_percentile,
113            global_agg.into(),
114            approx_percentile_col_mapping,
115            non_approx_percentile_col_mapping,
116        )
117    }
118
119    /// Generate plan for stateless/stateful 2-phase streaming agg.
120    /// Should only be used if the input is distributed.
121    /// The input must already be converted to stream form.
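    ///
    /// The resulting plan is roughly (a sketch; the global phase depends on whether a group key exists):
    ///
    /// ```text
    /// StreamSimpleAgg / StreamHashAgg   (global, grouped by the original group key)
    ///   └─ Exchange (single, or hash-shard by the group key)
    ///        └─ StreamHashAgg (local, grouped by [group key..., vnode])
    ///             └─ StreamProject (append a vnode column computed from `dist_key`)
    ///                  └─ stream input
    /// ```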
122    fn gen_vnode_two_phase_streaming_agg_plan(
123        &self,
124        stream_input: PlanRef,
125        dist_key: &[usize],
126    ) -> Result<PlanRef> {
127        let mut core = self.core.clone();
128
129        let (non_approx_percentile_col_mapping, approx_percentile_col_mapping, approx_percentile) =
130            self.prepare_approx_percentile(&mut core, stream_input.clone())?;
131
132        if core.agg_calls.is_empty() {
133            if let Some(approx_percentile) = approx_percentile {
134                return Ok(approx_percentile);
135            };
136            bail!("expected at least one agg call");
137        }
138        let need_row_merge = Self::need_row_merge(&approx_percentile);
139
140        // Generate vnode via project
141        // TODO(kwannoel): We should apply Project optimization rules here.
142        let input_col_num = stream_input.schema().len(); // get schema len before moving `stream_input`.
143        let project = StreamProject::new(generic::Project::with_vnode_col(stream_input, dist_key));
144        let vnode_col_idx = project.base.schema().len() - 1;
145
146        // Generate local agg step
147        let mut local_group_key = self.group_key().clone();
148        local_group_key.insert(vnode_col_idx);
149        let n_local_group_key = local_group_key.len();
150        let local_agg = new_stream_hash_agg(
151            Agg::new(core.agg_calls.to_vec(), local_group_key, project.into()),
152            Some(vnode_col_idx),
153        );
154        // Global group key excludes vnode.
155        let local_agg_group_key_cardinality = local_agg.group_key().len();
156        let local_group_key_without_vnode =
157            &local_agg.group_key().to_vec()[..local_agg_group_key_cardinality - 1];
158        let global_group_key = local_agg
159            .i2o_col_mapping()
160            .rewrite_dist_key(local_group_key_without_vnode)
161            .expect("some input group key could not be mapped");
162
163        // Generate global agg step
164        let global_agg = if self.group_key().is_empty() {
165            let exchange =
166                RequiredDist::single().enforce_if_not_satisfies(local_agg.into(), &Order::any())?;
167            let must_output_per_barrier = need_row_merge;
168            let global_agg = new_stream_simple_agg(
169                Agg::new(
170                    core.agg_calls
171                        .iter()
172                        .enumerate()
173                        .map(|(partial_output_idx, agg_call)| {
174                            agg_call
175                                .partial_to_total_agg_call(n_local_group_key + partial_output_idx)
176                        })
177                        .collect(),
178                    global_group_key.into_iter().collect(),
179                    exchange,
180                ),
181                must_output_per_barrier,
182            );
183            global_agg.into()
184        } else {
185            // `RowMergeExec` does not support keyed merge yet.
186            assert!(!need_row_merge);
187            let exchange = RequiredDist::shard_by_key(input_col_num, &global_group_key)
188                .enforce_if_not_satisfies(local_agg.into(), &Order::any())?;
189            // The local phase should have already reordered the group keys into their required order,
190            // so we can just follow it.
191            let global_agg = new_stream_hash_agg(
192                Agg::new(
193                    core.agg_calls
194                        .iter()
195                        .enumerate()
196                        .map(|(partial_output_idx, agg_call)| {
197                            agg_call
198                                .partial_to_total_agg_call(n_local_group_key + partial_output_idx)
199                        })
200                        .collect(),
201                    global_group_key.into_iter().collect(),
202                    exchange,
203                ),
204                None,
205            );
206            global_agg.into()
207        };
208        Self::add_row_merge_if_needed(
209            approx_percentile,
210            global_agg,
211            approx_percentile_col_mapping,
212            non_approx_percentile_col_mapping,
213        )
214    }
215
216    fn gen_single_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
217        let mut core = self.core.clone();
218        let input = RequiredDist::single().enforce_if_not_satisfies(stream_input, &Order::any())?;
219        core.input = input;
220        Ok(new_stream_simple_agg(core, false).into())
221    }
222
223    fn gen_shuffle_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
224        let input =
225            RequiredDist::shard_by_key(stream_input.schema().len(), &self.group_key().to_vec())
226                .enforce_if_not_satisfies(stream_input, &Order::any())?;
227        let mut core = self.core.clone();
228        core.input = input;
229        Ok(new_stream_hash_agg(core, None).into())
230    }
231
232    /// Generates distributed stream plan.
233    fn gen_dist_stream_agg_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
234        use super::stream::prelude::*;
235
236        let input_dist = stream_input.distribution();
237        debug_assert!(*input_dist != Distribution::Broadcast);
238
239        // Shuffle agg
240        // If we have a group key and we won't try the two-phase agg optimization at all,
241        // we always choose shuffle agg over single agg.
242        if !self.group_key().is_empty() && !self.core.must_try_two_phase_agg() {
243            return self.gen_shuffle_plan(stream_input);
244        }
245
246        // Standalone agg
247        // If there is no group key and two-phase agg is not possible, we have to use the single plan.
248        if self.group_key().is_empty() && !self.core.can_two_phase_agg() {
249            return self.gen_single_plan(stream_input);
250        }
251
252        debug_assert!(if !self.group_key().is_empty() {
253            self.core.must_try_two_phase_agg()
254        } else {
255            self.core.can_two_phase_agg()
256        });
257
258        // Stateless 2-phase simple agg
259        // can be applied on stateless simple agg calls,
260        // with input distributed by [`Distribution::AnyShard`]
261        if self.group_key().is_empty()
262            && self
263                .core
264                .all_local_aggs_are_stateless(stream_input.append_only())
265            && input_dist.satisfies(&RequiredDist::AnyShard)
266        {
267            return self.gen_stateless_two_phase_streaming_agg_plan(stream_input);
268        }
269
270        // If the input is [`Distribution::SomeShard`] and we must try two-phase agg,
271        // the only remaining strategy is vnode-based 2-phase agg.
272        // We shall first distribute the input by its stream key,
273        // so it obeys the consistent hash strategy via [`Distribution::HashShard`].
274        let stream_input =
275            if *input_dist == Distribution::SomeShard && self.core.must_try_two_phase_agg() {
276                RequiredDist::shard_by_key(
277                    stream_input.schema().len(),
278                    stream_input.expect_stream_key(),
279                )
280                .enforce_if_not_satisfies(stream_input, &Order::any())?
281            } else {
282                stream_input
283            };
284        let input_dist = stream_input.distribution();
285
286        // Vnode-based 2-phase agg
287        // can be applied on agg calls not affected by order,
288        // with input distributed by dist_key.
289        match input_dist {
290            Distribution::HashShard(dist_key) | Distribution::UpstreamHashShard(dist_key, _)
291                if (self.group_key().is_empty()
292                    || !self.core.hash_agg_dist_satisfied_by_input_dist(input_dist)) =>
293            {
294                let dist_key = dist_key.clone();
295                return self.gen_vnode_two_phase_streaming_agg_plan(stream_input, &dist_key);
296            }
297            _ => {}
298        }
299
300        // Fallback to shuffle or single, if we can't generate any 2-phase plans.
301        if !self.group_key().is_empty() {
302            self.gen_shuffle_plan(stream_input)
303        } else {
304            self.gen_single_plan(stream_input)
305        }
306    }
307
308    /// Prepares metadata and the `approx_percentile` plan, if one is present.
309    /// It may modify `core.agg_calls` to separate the normal aggs from the approx percentile aggs,
310    /// and `core.input` to share the input via `StreamShare`
311    /// between the approx percentile aggs and the normal aggs.
312    fn prepare_approx_percentile(
313        &self,
314        core: &mut Agg<PlanRef>,
315        stream_input: PlanRef,
316    ) -> Result<(ColIndexMapping, ColIndexMapping, Option<PlanRef>)> {
317        let SeparatedAggInfo { normal, approx } = self.separate_normal_and_special_agg();
318
319        let AggInfo {
320            calls: non_approx_percentile_agg_calls,
321            col_mapping: non_approx_percentile_col_mapping,
322        } = normal;
323        let AggInfo {
324            calls: approx_percentile_agg_calls,
325            col_mapping: approx_percentile_col_mapping,
326        } = approx;
327        if !self.group_key().is_empty() && !approx_percentile_agg_calls.is_empty() {
328            bail_not_implemented!(
329                "two-phase streaming approx percentile aggregation with group key, \
330             please use single phase aggregation instead"
331            );
332        }
333
334        // Row merge is needed if we have both approx percentile aggs and non-approx percentile aggs,
335        // or at least 2 approx percentile aggs.
336        let needs_row_merge = (!non_approx_percentile_agg_calls.is_empty()
337            && !approx_percentile_agg_calls.is_empty())
338            || approx_percentile_agg_calls.len() >= 2;
339        core.input = if needs_row_merge {
340            // If there's row merge, we need to share the input.
341            StreamShare::new_from_input(stream_input.clone()).into()
342        } else {
343            stream_input
344        };
345        core.agg_calls = non_approx_percentile_agg_calls;
346
347        let approx_percentile =
348            self.build_approx_percentile_aggs(core.input.clone(), &approx_percentile_agg_calls)?;
349        Ok((
350            non_approx_percentile_col_mapping,
351            approx_percentile_col_mapping,
352            approx_percentile,
353        ))
354    }
355
356    fn need_row_merge(approx_percentile: &Option<PlanRef>) -> bool {
357        approx_percentile.is_some()
358    }
359
360    /// Add `RowMerge` if needed
361    fn add_row_merge_if_needed(
362        approx_percentile: Option<PlanRef>,
363        global_agg: PlanRef,
364        approx_percentile_col_mapping: ColIndexMapping,
365        non_approx_percentile_col_mapping: ColIndexMapping,
366    ) -> Result<PlanRef> {
367        // just for assert
368        let need_row_merge = Self::need_row_merge(&approx_percentile);
369
370        if let Some(approx_percentile) = approx_percentile {
371            assert!(need_row_merge);
372            let row_merge = StreamRowMerge::new(
373                approx_percentile,
374                global_agg,
375                approx_percentile_col_mapping,
376                non_approx_percentile_col_mapping,
377            )?;
378            Ok(row_merge.into())
379        } else {
380            assert!(!need_row_merge);
381            Ok(global_agg)
382        }
383    }
384
385    fn separate_normal_and_special_agg(&self) -> SeparatedAggInfo {
386        let estimated_len = self.agg_calls().len() - 1;
387        let mut approx_percentile_agg_calls = Vec::with_capacity(estimated_len);
388        let mut non_approx_percentile_agg_calls = Vec::with_capacity(estimated_len);
389        let mut approx_percentile_col_mapping = Vec::with_capacity(estimated_len);
390        let mut non_approx_percentile_col_mapping = Vec::with_capacity(estimated_len);
391        for (output_idx, agg_call) in self.agg_calls().iter().enumerate() {
392            if agg_call.agg_type == AggType::Builtin(PbAggKind::ApproxPercentile) {
393                approx_percentile_agg_calls.push(agg_call.clone());
394                approx_percentile_col_mapping.push(Some(output_idx));
395            } else {
396                non_approx_percentile_agg_calls.push(agg_call.clone());
397                non_approx_percentile_col_mapping.push(Some(output_idx));
398            }
399        }
400        let normal = AggInfo {
401            calls: non_approx_percentile_agg_calls,
402            col_mapping: ColIndexMapping::new(
403                non_approx_percentile_col_mapping,
404                self.agg_calls().len(),
405            ),
406        };
407        let approx = AggInfo {
408            calls: approx_percentile_agg_calls,
409            col_mapping: ColIndexMapping::new(
410                approx_percentile_col_mapping,
411                self.agg_calls().len(),
412            ),
413        };
414        SeparatedAggInfo { normal, approx }
415    }
416
417    fn build_approx_percentile_agg(
418        &self,
419        input: PlanRef,
420        approx_percentile_agg_call: &PlanAggCall,
421    ) -> Result<PlanRef> {
422        let local_approx_percentile =
423            StreamLocalApproxPercentile::new(input, approx_percentile_agg_call);
424        let exchange = RequiredDist::single()
425            .enforce_if_not_satisfies(local_approx_percentile.into(), &Order::any())?;
426        let global_approx_percentile =
427            StreamGlobalApproxPercentile::new(exchange, approx_percentile_agg_call);
428        Ok(global_approx_percentile.into())
429    }
430
431    /// If there is only 1 approx percentile agg, just return it.
432    /// Otherwise build a left-deep tree of approx percentile aggs combined via `StreamRowMerge`.
433    /// e.g.
434    /// `ApproxPercentile(col1, 0.5) as x`,
435    /// `ApproxPercentile(col2, 0.5) as y`,
436    /// `ApproxPercentile(col3, 0.5) as z`
437    /// will be built as
438    ///        `StreamRowMerge`
439    ///       /              \
440    ///  `StreamRowMerge`      z
441    ///  /          \
442    /// x            y
443    fn build_approx_percentile_aggs(
444        &self,
445        input: PlanRef,
446        approx_percentile_agg_call: &[PlanAggCall],
447    ) -> Result<Option<PlanRef>> {
448        if approx_percentile_agg_call.is_empty() {
449            return Ok(None);
450        }
451        let approx_percentile_plans: Vec<PlanRef> = approx_percentile_agg_call
452            .iter()
453            .map(|agg_call| self.build_approx_percentile_agg(input.clone(), agg_call))
454            .try_collect()?;
455        assert!(!approx_percentile_plans.is_empty());
456        let mut iter = approx_percentile_plans.into_iter();
457        let mut acc = iter.next().unwrap();
458        for (current_size, plan) in iter.enumerate().map(|(i, p)| (i + 1, p)) {
459            let new_size = current_size + 1;
460            let row_merge = StreamRowMerge::new(
461                acc,
462                plan,
463                ColIndexMapping::identity_or_none(current_size, new_size),
464                ColIndexMapping::new(vec![Some(current_size)], new_size),
465            )
466            .expect("failed to build row merge");
467            acc = row_merge.into();
468        }
469        Ok(Some(acc))
470    }
471
472    pub fn core(&self) -> &Agg<PlanRef> {
473        &self.core
474    }
475}
476
477/// `LogicalAggBuilder` extracts agg calls and references to group columns from the select list and
478/// builds a plan like `LogicalAgg - LogicalProject`.
479/// It is constructed from the group exprs, and collects and rewrites the expressions in the select
480/// and having clauses.
481pub struct LogicalAggBuilder {
482    /// the builder of the input Project
483    input_proj_builder: ProjectBuilder,
484    /// the group key column indices in the project's output
485    group_key: IndexSet,
486    /// the grouping sets
487    grouping_sets: Vec<IndexSet>,
488    /// the agg calls
489    agg_calls: Vec<PlanAggCall>,
490    /// the error during the expression rewriting
491    error: Option<RwError>,
492    /// If `is_in_filter_clause` is true, it means that
493    /// we are processing an agg call's filter clause.
494    /// This field is needed because input refs in filter clauses
495    /// are allowed to refer to any columns, while those outside filter
496    /// clauses are only allowed to refer to group keys.
497    is_in_filter_clause: bool,
498}
499
500impl LogicalAggBuilder {
501    fn new(group_by: GroupBy, input_schema_len: usize) -> Result<Self> {
502        let mut input_proj_builder = ProjectBuilder::default();
503
504        let mut gen_group_key_and_grouping_sets =
505            |grouping_sets: Vec<Vec<ExprImpl>>| -> Result<(IndexSet, Vec<IndexSet>)> {
506                let grouping_sets: Vec<IndexSet> = grouping_sets
507                    .into_iter()
508                    .map(|set| {
509                        set.into_iter()
510                            .map(|expr| input_proj_builder.add_expr(&expr))
511                            .try_collect()
512                            .map_err(|err| not_implemented!("{err} inside GROUP BY"))
513                    })
514                    .try_collect()?;
515
516                // Construct group key based on grouping sets.
517                let group_key = grouping_sets
518                    .iter()
519                    .fold(FixedBitSet::with_capacity(input_schema_len), |acc, x| {
520                        acc.union(&x.to_bitset()).collect()
521                    });
522
523                Ok((IndexSet::from_iter(group_key.ones()), grouping_sets))
524            };
525
526        let (group_key, grouping_sets) = match group_by {
527            GroupBy::GroupKey(group_key) => {
528                let group_key = group_key
529                    .into_iter()
530                    .map(|expr| input_proj_builder.add_expr(&expr))
531                    .try_collect()
532                    .map_err(|err| not_implemented!("{err} inside GROUP BY"))?;
533                (group_key, vec![])
534            }
535            GroupBy::GroupingSets(grouping_sets) => gen_group_key_and_grouping_sets(grouping_sets)?,
536            GroupBy::Rollup(rollup) => {
537                // Convert rollup to grouping sets.
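                // e.g. `ROLLUP (a, b, c)` expands to the grouping sets
                // `()`, `(a)`, `(a, b)`, `(a, b, c)` (all prefixes).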
538                let grouping_sets = (0..=rollup.len())
539                    .map(|n| {
540                        rollup
541                            .iter()
542                            .take(n)
543                            .flat_map(|x| x.iter().cloned())
544                            .collect_vec()
545                    })
546                    .collect_vec();
547                gen_group_key_and_grouping_sets(grouping_sets)?
548            }
549            GroupBy::Cube(cube) => {
550                // Convert cube to grouping sets.
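                // e.g. `CUBE (a, b)` expands to the grouping sets
                // `()`, `(a)`, `(b)`, `(a, b)` (the powerset).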
551                let grouping_sets = cube
552                    .into_iter()
553                    .powerset()
554                    .map(|x| x.into_iter().flatten().collect_vec())
555                    .collect_vec();
556                gen_group_key_and_grouping_sets(grouping_sets)?
557            }
558        };
559
560        Ok(LogicalAggBuilder {
561            group_key,
562            grouping_sets,
563            agg_calls: vec![],
564            error: None,
565            input_proj_builder,
566            is_in_filter_clause: false,
567        })
568    }
569
570    pub fn build(self, input: PlanRef) -> LogicalAgg {
571        // This LogicalProject focuses on the exprs in aggregates and GROUP BY clause.
572        let logical_project = LogicalProject::with_core(self.input_proj_builder.build(input));
573
574        // This LogicalAgg focuses on calculating the aggregates and grouping.
575        Agg::new(self.agg_calls, self.group_key, logical_project.into())
576            .with_grouping_sets(self.grouping_sets)
577            .into()
578    }
579
580    fn rewrite_with_error(&mut self, expr: ExprImpl) -> Result<ExprImpl> {
581        let rewritten_expr = self.rewrite_expr(expr);
582        if let Some(error) = self.error.take() {
583            return Err(error);
584        }
585        Ok(rewritten_expr)
586    }
587
588    /// Check whether the expression is a group key; if so, return its index among the group keys.
589    pub fn try_as_group_expr(&self, expr: &ExprImpl) -> Option<usize> {
590        if let Some(input_index) = self.input_proj_builder.expr_index(expr) {
591            if let Some(index) = self
592                .group_key
593                .indices()
594                .position(|group_key| group_key == input_index)
595            {
596                return Some(index);
597            }
598        }
599        None
600    }
601
602    fn schema_agg_start_offset(&self) -> usize {
603        self.group_key.len()
604    }
605
606    /// Rewrite [`AggCall`] if needed, and push it into the builder using `push_agg_call`.
607    /// This is shared by [`LogicalAggBuilder`] and `LogicalOverWindowBuilder`.
608    pub(crate) fn general_rewrite_agg_call(
609        agg_call: AggCall,
610        mut push_agg_call: impl FnMut(AggCall) -> Result<InputRef>,
611    ) -> Result<ExprImpl> {
612        match agg_call.agg_type {
613            // Rewrite avg to cast(sum as avg_return_type) / count.
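            // e.g. `avg(x)` becomes `sum(x)::avg_return_type / count(x)`, so only `sum` and
            // `count` need to be supported by the (possibly two-phase) aggregation below.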
614            AggType::Builtin(PbAggKind::Avg) => {
615                assert_eq!(agg_call.args.len(), 1);
616
617                let sum = ExprImpl::from(push_agg_call(AggCall::new(
618                    PbAggKind::Sum.into(),
619                    agg_call.args.clone(),
620                    agg_call.distinct,
621                    agg_call.order_by.clone(),
622                    agg_call.filter.clone(),
623                    agg_call.direct_args.clone(),
624                )?)?)
625                .cast_explicit(agg_call.return_type())?;
626
627                let count = ExprImpl::from(push_agg_call(AggCall::new(
628                    PbAggKind::Count.into(),
629                    agg_call.args.clone(),
630                    agg_call.distinct,
631                    agg_call.order_by.clone(),
632                    agg_call.filter.clone(),
633                    agg_call.direct_args.clone(),
634                )?)?);
635
636                Ok(FunctionCall::new(ExprType::Divide, Vec::from([sum, count]))?.into())
637            }
638            // We compute `var_samp` as
639            //     (sum(sq) - sum * sum / count) / (count - 1)
640            // and `var_pop` as
641            //     (sum(sq) - sum * sum / count) / count
642            // Since we don't have a dedicated square function, we use a plain `Multiply` for squaring,
643            // which is in a sense more general than a `pow` function, especially when calculating
644            // covariances in the future. `stddev_pop` and `stddev_samp` additionally take the
645            // square root of the corresponding variance via `Sqrt`.
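            // e.g. `var_samp(x)` is rewritten into roughly
            //     CASE WHEN count <= 1 THEN NULL
            //          ELSE greatest(sum(x*x) - sum(x)*sum(x)/count, 0) / (count - 1) END
            // and the `stddev_*` variants additionally apply `sqrt(...)` to the non-NULL branch.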
646            AggType::Builtin(
647                kind @ (PbAggKind::StddevPop
648                | PbAggKind::StddevSamp
649                | PbAggKind::VarPop
650                | PbAggKind::VarSamp),
651            ) => {
652                let arg = agg_call.args().iter().exactly_one().unwrap();
653                let squared_arg = ExprImpl::from(FunctionCall::new(
654                    ExprType::Multiply,
655                    vec![arg.clone(), arg.clone()],
656                )?);
657
658                let sum_of_sq = ExprImpl::from(push_agg_call(AggCall::new(
659                    PbAggKind::Sum.into(),
660                    vec![squared_arg],
661                    agg_call.distinct,
662                    agg_call.order_by.clone(),
663                    agg_call.filter.clone(),
664                    agg_call.direct_args.clone(),
665                )?)?)
666                .cast_explicit(agg_call.return_type())?;
667
668                let sum = ExprImpl::from(push_agg_call(AggCall::new(
669                    PbAggKind::Sum.into(),
670                    agg_call.args.clone(),
671                    agg_call.distinct,
672                    agg_call.order_by.clone(),
673                    agg_call.filter.clone(),
674                    agg_call.direct_args.clone(),
675                )?)?)
676                .cast_explicit(agg_call.return_type())?;
677
678                let count = ExprImpl::from(push_agg_call(AggCall::new(
679                    PbAggKind::Count.into(),
680                    agg_call.args.clone(),
681                    agg_call.distinct,
682                    agg_call.order_by.clone(),
683                    agg_call.filter.clone(),
684                    agg_call.direct_args.clone(),
685                )?)?);
686
687                let zero = ExprImpl::literal_int(0);
688                let one = ExprImpl::literal_int(1);
689
690                let squared_sum = ExprImpl::from(FunctionCall::new(
691                    ExprType::Multiply,
692                    vec![sum.clone(), sum],
693                )?);
694
695                let raw_numerator = ExprImpl::from(FunctionCall::new(
696                    ExprType::Subtract,
697                    vec![
698                        sum_of_sq,
699                        ExprImpl::from(FunctionCall::new(
700                            ExprType::Divide,
701                            vec![squared_sum, count.clone()],
702                        )?),
703                    ],
704                )?);
705
706                // Floating-point inaccuracy may occasionally make the raw numerator slightly negative, so clamp it at zero.
707                let numerator_type = raw_numerator.return_type();
708                let numerator = ExprImpl::from(FunctionCall::new(
709                    ExprType::Greatest,
710                    vec![raw_numerator, zero.clone().cast_explicit(numerator_type)?],
711                )?);
712
713                let denominator = match kind {
714                    PbAggKind::VarPop | PbAggKind::StddevPop => count.clone(),
715                    PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from(
716                        FunctionCall::new(ExprType::Subtract, vec![count.clone(), one.clone()])?,
717                    ),
718                    _ => unreachable!(),
719                };
720
721                let mut target = ExprImpl::from(FunctionCall::new(
722                    ExprType::Divide,
723                    vec![numerator, denominator],
724                )?);
725
726                if matches!(kind, PbAggKind::StddevPop | PbAggKind::StddevSamp) {
727                    target = ExprImpl::from(FunctionCall::new(ExprType::Sqrt, vec![target])?);
728                }
729
730                let null = ExprImpl::from(Literal::new(None, agg_call.return_type()));
731                let case_cond = match kind {
732                    PbAggKind::VarPop | PbAggKind::StddevPop => {
733                        ExprImpl::from(FunctionCall::new(ExprType::Equal, vec![count, zero])?)
734                    }
735                    PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from(
736                        FunctionCall::new(ExprType::LessThanOrEqual, vec![count, one])?,
737                    ),
738                    _ => unreachable!(),
739                };
740
741                Ok(ExprImpl::from(FunctionCall::new(
742                    ExprType::Case,
743                    vec![case_cond, null, target],
744                )?))
745            }
746            AggType::Builtin(PbAggKind::ApproxPercentile) => {
747                if agg_call.order_by.sort_exprs[0].order_type == OrderType::descending() {
748                    // Rewrite DESC into 1.0-percentile for approx_percentile.
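                    // e.g. a percentile of 0.9 over `ORDER BY x DESC` is equivalent to a
                    // percentile of 0.1 over `ORDER BY x ASC`.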
749                    let prev_percentile = agg_call.direct_args[0].clone();
750                    let new_percentile = 1.0
751                        - prev_percentile
752                            .get_data()
753                            .as_ref()
754                            .unwrap()
755                            .as_float64()
756                            .into_inner();
757                    let new_percentile = Some(ScalarImpl::Float64(new_percentile.into()));
758                    let new_percentile = Literal::new(new_percentile, DataType::Float64);
759                    let new_direct_args = vec![new_percentile, agg_call.direct_args[1].clone()];
760
761                    let new_agg_call = AggCall {
762                        order_by: OrderBy::any(),
763                        direct_args: new_direct_args,
764                        ..agg_call
765                    };
766                    Ok(push_agg_call(new_agg_call)?.into())
767                } else {
768                    let new_agg_call = AggCall {
769                        order_by: OrderBy::any(),
770                        ..agg_call
771                    };
772                    Ok(push_agg_call(new_agg_call)?.into())
773                }
774            }
775            _ => Ok(push_agg_call(agg_call)?.into()),
776        }
777    }
778
779    /// Push a new agg call into the builder.
780    /// Return an `InputRef` to that agg call.
781    /// For existing agg calls, return an `InputRef` to the existing one.
782    fn push_agg_call(&mut self, agg_call: AggCall) -> Result<InputRef> {
783        let AggCall {
784            agg_type,
785            return_type,
786            args,
787            distinct,
788            order_by,
789            filter,
790            direct_args,
791        } = agg_call;
792
793        self.is_in_filter_clause = true;
794        // The filter expr is not added to `input_proj_builder` as a whole. Special exprs including
795        // subquery/agg/table are rejected in `bind_agg`.
796        let filter = filter.rewrite_expr(self);
797        self.is_in_filter_clause = false;
798
799        let args: Vec<_> = args
800            .iter()
801            .map(|expr| {
802                let index = self.input_proj_builder.add_expr(expr)?;
803                Ok(InputRef::new(index, expr.return_type()))
804            })
805            .try_collect()
806            .map_err(|err: &'static str| not_implemented!("{err} inside aggregation calls"))?;
807
808        let order_by: Vec<_> = order_by
809            .sort_exprs
810            .iter()
811            .map(|e| {
812                let index = self.input_proj_builder.add_expr(&e.expr)?;
813                Ok(ColumnOrder::new(index, e.order_type))
814            })
815            .try_collect()
816            .map_err(|err: &'static str| {
817                not_implemented!("{err} inside aggregation calls order by")
818            })?;
819
820        let plan_agg_call = PlanAggCall {
821            agg_type,
822            return_type: return_type.clone(),
823            inputs: args,
824            distinct,
825            order_by,
826            filter,
827            direct_args,
828        };
829
830        if let Some((pos, existing)) = self
831            .agg_calls
832            .iter()
833            .find_position(|&c| c == &plan_agg_call)
834        {
835            return Ok(InputRef::new(
836                self.schema_agg_start_offset() + pos,
837                existing.return_type.clone(),
838            ));
839        }
840        let index = self.schema_agg_start_offset() + self.agg_calls.len();
841        self.agg_calls.push(plan_agg_call);
842        Ok(InputRef::new(index, return_type))
843    }
844
845    /// When there is an agg call, there are 3 things to do:
846    /// 1. Rewrite `avg`, `var_samp`, etc. into a combination of `sum`, `count`, etc.;
847    /// 2. Add exprs in arguments to the input `Project`;
848    /// 3. Add the agg call to the current `Agg`, and return an `InputRef` to it.
849    ///
850    /// Note that the rewriter does not traverse into inputs of agg calls.
851    fn try_rewrite_agg_call(&mut self, mut agg_call: AggCall) -> Result<ExprImpl> {
852        if matches!(agg_call.agg_type, agg_types::must_have_order_by!())
853            && agg_call.order_by.sort_exprs.is_empty()
854        {
855            return Err(ErrorCode::InvalidInputSyntax(format!(
856                "Aggregation function {} requires ORDER BY clause",
857                agg_call.agg_type
858            ))
859            .into());
860        }
861
862        // try to ignore ORDER BY if it doesn't affect the result
863        if matches!(
864            agg_call.agg_type,
865            agg_types::result_unaffected_by_order_by!()
866        ) {
867            agg_call.order_by = OrderBy::any();
868        }
869        // try to ignore DISTINCT if it doesn't affect the result
870        if matches!(
871            agg_call.agg_type,
872            agg_types::result_unaffected_by_distinct!()
873        ) {
874            agg_call.distinct = false;
875        }
876
877        if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Grouping)) {
878            if self.grouping_sets.is_empty() {
879                return Err(ErrorCode::NotSupported(
880                    "GROUPING must be used in a query with grouping sets".into(),
881                    "try to use grouping sets instead".into(),
882                )
883                .into());
884            }
885            if agg_call.args.len() >= 32 {
886                return Err(ErrorCode::InvalidInputSyntax(
887                    "GROUPING must have fewer than 32 arguments".into(),
888                )
889                .into());
890            }
891            if agg_call
892                .args
893                .iter()
894                .any(|x| self.try_as_group_expr(x).is_none())
895            {
896                return Err(ErrorCode::InvalidInputSyntax(
897                    "arguments to GROUPING must be grouping expressions of the associated query level"
898                        .into(),
899                ).into());
900            }
901        }
902
903        Self::general_rewrite_agg_call(agg_call, |agg_call| self.push_agg_call(agg_call))
904    }
905}
906
907impl ExprRewriter for LogicalAggBuilder {
908    fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl {
909        let dummy = Literal::new(None, agg_call.return_type()).into();
910        match self.try_rewrite_agg_call(agg_call) {
911            Ok(expr) => expr,
912            Err(err) => {
913                self.error = Some(err);
914                dummy
915            }
916        }
917    }
918
919    /// When there is a `FunctionCall` (outside of an agg call), it must refer to a group column,
920    /// or all `InputRef`s appearing in it must refer to group columns.
921    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
922        let expr = func_call.into();
923        if let Some(group_key) = self.try_as_group_expr(&expr) {
924            InputRef::new(group_key, expr.return_type()).into()
925        } else {
926            let (func_type, inputs, ret) = expr.into_function_call().unwrap().decompose();
927            let inputs = inputs
928                .into_iter()
929                .map(|expr| self.rewrite_expr(expr))
930                .collect();
931            FunctionCall::new_unchecked(func_type, inputs, ret).into()
932        }
933    }
934
935    /// When there is a `WindowFunction` (outside of an agg call), it must refer to a group column,
936    /// or all `InputRef`s appearing in it must refer to group columns.
937    fn rewrite_window_function(&mut self, window_func: WindowFunction) -> ExprImpl {
938        let WindowFunction {
939            args,
940            partition_by,
941            order_by,
942            ..
943        } = window_func;
944        let args = args
945            .into_iter()
946            .map(|expr| self.rewrite_expr(expr))
947            .collect();
948        let partition_by = partition_by
949            .into_iter()
950            .map(|expr| self.rewrite_expr(expr))
951            .collect();
952        let order_by = order_by.rewrite_expr(self);
953        WindowFunction {
954            args,
955            partition_by,
956            order_by,
957            ..window_func
958        }
959        .into()
960    }
961
962    /// When there is an `InputRef` (outside of an agg call), it must refer to a group column.
963    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
964        let expr = input_ref.into();
965        if let Some(group_key) = self.try_as_group_expr(&expr) {
966            InputRef::new(group_key, expr.return_type()).into()
967        } else if self.is_in_filter_clause {
968            InputRef::new(
969                self.input_proj_builder.add_expr(&expr).unwrap(),
970                expr.return_type(),
971            )
972            .into()
973        } else {
974            self.error = Some(
975                ErrorCode::InvalidInputSyntax(
976                    "column must appear in the GROUP BY clause or be used in an aggregate function"
977                        .into(),
978                )
979                .into(),
980            );
981            expr
982        }
983    }
984
985    fn rewrite_subquery(&mut self, subquery: crate::expr::Subquery) -> ExprImpl {
986        if subquery.is_correlated(0) {
987            self.error = Some(
988                not_implemented!(
989                    issue = 2275,
990                    "correlated subquery in HAVING or SELECT with agg",
991                )
992                .into(),
993            );
994        }
995        subquery.into()
996    }
997}
998
999impl From<Agg<PlanRef>> for LogicalAgg {
1000    fn from(core: Agg<PlanRef>) -> Self {
1001        let base = PlanBase::new_logical_with_core(&core);
1002        Self { base, core }
1003    }
1004}
1005
1006/// Because `From`/`Into` are not transitive
1007impl From<Agg<PlanRef>> for PlanRef {
1008    fn from(core: Agg<PlanRef>) -> Self {
1009        LogicalAgg::from(core).into()
1010    }
1011}
1012
1013impl LogicalAgg {
1014    /// Get the `ColIndexMapping` from input column index to output column index.
1015    pub fn i2o_col_mapping(&self) -> ColIndexMapping {
1016        self.core.i2o_col_mapping()
1017    }
1018
1019    /// `create` will analyze select exprs, group exprs and having, and construct a plan like
1020    ///
1021    /// ```text
1022    /// LogicalAgg -> LogicalProject -> input
1023    /// ```
1024    ///
1025    /// It also returns the rewritten select exprs and having clause, which reference
1026    /// the aggregated results.
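    ///
    /// For example (a sketch), for `SELECT v1, avg(v2) FROM t GROUP BY v1 HAVING count(v2) > 1`,
    /// the agg computes `[sum(v2), count(v2)]` over group key `v1`; the select exprs are rewritten
    /// to roughly `[$0, $1 / $2]` (referencing the agg output, cast omitted), and the having
    /// clause to `$2 > 1`.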
1027    pub fn create(
1028        select_exprs: Vec<ExprImpl>,
1029        group_by: GroupBy,
1030        having: Option<ExprImpl>,
1031        input: PlanRef,
1032    ) -> Result<(PlanRef, Vec<ExprImpl>, Option<ExprImpl>)> {
1033        let mut agg_builder = LogicalAggBuilder::new(group_by, input.schema().len())?;
1034
1035        let rewritten_select_exprs = select_exprs
1036            .into_iter()
1037            .map(|expr| agg_builder.rewrite_with_error(expr))
1038            .collect::<Result<_>>()?;
1039        let rewritten_having = having
1040            .map(|expr| agg_builder.rewrite_with_error(expr))
1041            .transpose()?;
1042
1043        Ok((
1044            agg_builder.build(input).into(),
1045            rewritten_select_exprs,
1046            rewritten_having,
1047        ))
1048    }
1049
1050    /// Get a reference to the logical agg's agg calls.
1051    pub fn agg_calls(&self) -> &Vec<PlanAggCall> {
1052        &self.core.agg_calls
1053    }
1054
1055    /// Get a reference to the logical agg's group key.
1056    pub fn group_key(&self) -> &IndexSet {
1057        &self.core.group_key
1058    }
1059
1060    pub fn grouping_sets(&self) -> &Vec<IndexSet> {
1061        &self.core.grouping_sets
1062    }
1063
1064    pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) {
1065        self.core.decompose()
1066    }
1067
1068    #[must_use]
1069    pub fn rewrite_with_input_agg(
1070        &self,
1071        input: PlanRef,
1072        agg_calls: &[PlanAggCall],
1073        mut input_col_change: ColIndexMapping,
1074    ) -> (Self, ColIndexMapping) {
1075        let agg_calls = agg_calls
1076            .iter()
1077            .cloned()
1078            .map(|mut agg_call| {
1079                agg_call.inputs.iter_mut().for_each(|i| {
1080                    *i = InputRef::new(input_col_change.map(i.index()), i.return_type())
1081                });
1082                agg_call.order_by.iter_mut().for_each(|o| {
1083                    o.column_index = input_col_change.map(o.column_index);
1084                });
1085                agg_call.filter = agg_call.filter.rewrite_expr(&mut input_col_change);
1086                agg_call
1087            })
1088            .collect();
1089        // This is what the group key order should be after rewriting.
1090        let group_key_in_vec: Vec<usize> = self
1091            .group_key()
1092            .indices()
1093            .map(|key| input_col_change.map(key))
1094            .collect();
1095        // This is the group key order we get after rewriting.
1096        let group_key: IndexSet = group_key_in_vec.iter().cloned().collect();
1097        let grouping_sets = self
1098            .grouping_sets()
1099            .iter()
1100            .map(|set| set.indices().map(|key| input_col_change.map(key)).collect())
1101            .collect();
1102
1103        let new_agg = Agg::new(agg_calls, group_key.clone(), input)
1104            .with_grouping_sets(grouping_sets)
1105            .with_enable_two_phase(self.core().enable_two_phase);
1106
1107        // Remapping the group key might change the output column order, since the group key is
1108        // actually a `FixedBitSet` and thus kept in sorted index order.
1109        let mut out_col_change = vec![];
1110        for idx in group_key_in_vec {
1111            let pos = group_key.indices().position(|x| x == idx).unwrap();
1112            out_col_change.push(pos);
1113        }
1114        for i in (group_key.len())..new_agg.schema().len() {
1115            out_col_change.push(i);
1116        }
1117        let out_col_change =
1118            ColIndexMapping::with_remaining_columns(&out_col_change, new_agg.schema().len());
1119
1120        (new_agg.into(), out_col_change)
1121    }
1122}
1123
1124impl PlanTreeNodeUnary for LogicalAgg {
1125    fn input(&self) -> PlanRef {
1126        self.core.input.clone()
1127    }
1128
1129    fn clone_with_input(&self, input: PlanRef) -> Self {
1130        Agg::new(self.agg_calls().to_vec(), self.group_key().clone(), input)
1131            .with_grouping_sets(self.grouping_sets().clone())
1132            .with_enable_two_phase(self.core().enable_two_phase)
1133            .into()
1134    }
1135
1136    fn rewrite_with_input(
1137        &self,
1138        input: PlanRef,
1139        input_col_change: ColIndexMapping,
1140    ) -> (Self, ColIndexMapping) {
1141        self.rewrite_with_input_agg(input, self.agg_calls(), input_col_change)
1142    }
1143}
1144
1145impl_plan_tree_node_for_unary! {LogicalAgg}
1146impl_distill_by_unit!(LogicalAgg, core, "LogicalAgg");
1147
1148impl ExprRewritable for LogicalAgg {
1149    fn has_rewritable_expr(&self) -> bool {
1150        true
1151    }
1152
1153    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
1154        let mut core = self.core.clone();
1155        core.rewrite_exprs(r);
1156        Self {
1157            base: self.base.clone_with_new_plan_id(),
1158            core,
1159        }
1160        .into()
1161    }
1162}
1163
1164impl ExprVisitable for LogicalAgg {
1165    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
1166        self.core.visit_exprs(v);
1167    }
1168}
1169
1170impl ColPrunable for LogicalAgg {
1171    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
1172        let group_key_required_cols = self.group_key().to_bitset();
1173
1174        let (agg_call_required_cols, agg_calls) = {
1175            let input_cnt = self.input().schema().len();
1176            let mut tmp = FixedBitSet::with_capacity(input_cnt);
1177            let group_key_cardinality = self.group_key().len();
1178            let new_agg_calls = required_cols
1179                .iter()
1180                .filter(|&&index| index >= group_key_cardinality)
1181                .map(|&index| {
1182                    let index = index - group_key_cardinality;
1183                    let agg_call = self.agg_calls()[index].clone();
1184                    tmp.extend(agg_call.inputs.iter().map(|x| x.index()));
1185                    tmp.extend(agg_call.order_by.iter().map(|x| x.column_index));
1186                    // collect columns used in aggregate filter expressions
1187                    for i in &agg_call.filter.conjunctions {
1188                        tmp.union_with(&i.collect_input_refs(input_cnt));
1189                    }
1190                    agg_call
1191                })
1192                .collect_vec();
1193            (tmp, new_agg_calls)
1194        };
1195
1196        let input_required_cols = {
1197            let mut tmp = FixedBitSet::with_capacity(self.input().schema().len());
1198            tmp.union_with(&group_key_required_cols);
1199            tmp.union_with(&agg_call_required_cols);
1200            tmp.ones().collect_vec()
1201        };
1202        let input_col_change = ColIndexMapping::with_remaining_columns(
1203            &input_required_cols,
1204            self.input().schema().len(),
1205        );
1206        let agg = {
1207            let input = self.input().prune_col(&input_required_cols, ctx);
1208            let (agg, output_col_change) =
1209                self.rewrite_with_input_agg(input, &agg_calls, input_col_change);
1210            assert!(output_col_change.is_identity());
1211            agg
1212        };
1213        let new_output_cols = {
1214            // group keys are never pruned or re-ordered in the current impl
1215            let group_key_cardinality = agg.group_key().len();
1216            let mut tmp = (0..group_key_cardinality).collect_vec();
1217            tmp.extend(
1218                required_cols
1219                    .iter()
1220                    .filter(|&&index| index >= group_key_cardinality),
1221            );
1222            tmp
1223        };
1224        if new_output_cols == required_cols {
1225            // current schema perfectly fit the required columns
1226            agg.into()
1227        } else {
1228            // some columns are not needed, or the order needs to be adjusted,
1229            // so we add a projection to remove/reorder the columns.
1230            let mapping =
1231                &ColIndexMapping::with_remaining_columns(&new_output_cols, self.schema().len());
1232            let output_required_cols = required_cols
1233                .iter()
1234                .map(|&idx| mapping.map(idx))
1235                .collect_vec();
1236            let src_size = agg.schema().len();
1237            LogicalProject::with_mapping(
1238                agg.into(),
1239                ColIndexMapping::with_remaining_columns(&output_required_cols, src_size),
1240            )
1241            .into()
1242        }
1243    }
1244}
1245
1246impl PredicatePushdown for LogicalAgg {
1247    fn predicate_pushdown(
1248        &self,
1249        predicate: Condition,
1250        ctx: &mut PredicatePushdownContext,
1251    ) -> PlanRef {
1252        let num_group_key = self.group_key().len();
1253        let num_agg_calls = self.agg_calls().len();
1254        assert!(num_group_key + num_agg_calls == self.schema().len());
1255
1256        // Pushdown into a SimpleAgg (no group key) should be skipped, because the predicate
1257        // either references agg_calls or is constant.
1258        // If the filter references agg_calls, we cannot push it.
1259        // When it is constantly true, pushing is useless and may actually increase the
1260        // evaluation cost of the predicate.
1261        // When it is constantly false, pushing is wrong - the old plan returns 0 rows but the
1262        // new one returns 1 row.
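        // e.g. for `SELECT count(*) FROM t HAVING false`: keeping the filter above the agg yields
        // 0 rows, while pushing `false` below the agg would still emit one `count = 0` row.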
1263        if num_group_key == 0 {
1264            return gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx);
1265        }
1266
1267        // If the filter references agg_calls, we cannot push it.
1268        let mut agg_call_columns = FixedBitSet::with_capacity(num_group_key + num_agg_calls);
1269        agg_call_columns.insert_range(num_group_key..num_group_key + num_agg_calls);
1270        let (agg_call_pred, pushed_predicate) = predicate.split_disjoint(&agg_call_columns);
1271
1272        // convert the predicate to one that references the child of the agg
1273        let mut subst = Substitute {
1274            mapping: self
1275                .group_key()
1276                .indices()
1277                .enumerate()
1278                .map(|(i, group_key)| {
1279                    InputRef::new(group_key, self.schema().fields()[i].data_type()).into()
1280                })
1281                .collect(),
1282        };
1283        let pushed_predicate = pushed_predicate.rewrite_expr(&mut subst);
1284
1285        gen_filter_and_pushdown(self, agg_call_pred, pushed_predicate, ctx)
1286    }
1287}
1288
1289impl ToBatch for LogicalAgg {
1290    fn to_batch(&self) -> Result<PlanRef> {
1291        self.to_batch_with_order_required(&Order::any())
1292    }
1293
1294    // TODO(rc): `to_batch_with_order_required` seems to be useless after we decide to use
1295    // `BatchSortAgg` only when input is already sorted
1296    fn to_batch_with_order_required(&self, required_order: &Order) -> Result<PlanRef> {
1297        let input = self.input().to_batch()?;
1298        let new_logical = Agg {
1299            input,
1300            ..self.core.clone()
1301        };
1302        let agg_plan = if self.group_key().is_empty() {
1303            BatchSimpleAgg::new(new_logical).into()
1304        } else if self.ctx().session_ctx().config().batch_enable_sort_agg()
1305            && new_logical.input_provides_order_on_group_keys()
1306        {
1307            BatchSortAgg::new(new_logical).into()
1308        } else {
1309            BatchHashAgg::new(new_logical).into()
1310        };
1311        required_order.enforce_if_not_satisfies(agg_plan)
1312    }
1313}
1314
1315fn find_or_append_row_count(mut logical: Agg<PlanRef>) -> (Agg<PlanRef>, usize) {
1316    // `HashAgg`/`SimpleAgg` executors require a `count(*)` to correctly build changes, so
1317    // append a `count(*)` if one does not already exist.
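    // The row count tells the executor whether a group has become empty (emit a DELETE for the
    // group row) or was merely updated, which agg results like `sum` alone cannot distinguish.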
1318    let count_star = PlanAggCall::count_star();
1319    let row_count_idx = if let Some((idx, _)) = logical
1320        .agg_calls
1321        .iter()
1322        .find_position(|&c| c == &count_star)
1323    {
1324        idx
1325    } else {
1326        let idx = logical.agg_calls.len();
1327        logical.agg_calls.push(count_star);
1328        idx
1329    };
1330    (logical, row_count_idx)
1331}
1332
1333fn new_stream_simple_agg(core: Agg<PlanRef>, must_output_per_barrier: bool) -> StreamSimpleAgg {
1334    let (logical, row_count_idx) = find_or_append_row_count(core);
1335    StreamSimpleAgg::new(logical, row_count_idx, must_output_per_barrier)
1336}
1337
1338fn new_stream_hash_agg(core: Agg<PlanRef>, vnode_col_idx: Option<usize>) -> StreamHashAgg {
1339    let (logical, row_count_idx) = find_or_append_row_count(core);
1340    StreamHashAgg::new(logical, vnode_col_idx, row_count_idx)
1341}
1342
1343impl ToStream for LogicalAgg {
1344    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
1345        use super::stream::prelude::*;
1346
1347        for agg_call in self.agg_calls() {
1348            if matches!(agg_call.agg_type, agg_types::unimplemented_in_stream!()) {
1349                bail_not_implemented!("{} aggregation in materialized view", agg_call.agg_type);
1350            }
1351        }
1352        let eowc = ctx.emit_on_window_close();
1353        let stream_input = self.input().to_stream(ctx)?;
1354
1355        // Use the Dedup operator if possible.
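        // e.g. `SELECT DISTINCT v1 FROM t` over an append-only stream has no agg calls, so it can
        // be implemented with a `Dedup` operator instead of a hash agg.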
1356        if stream_input.append_only() && self.agg_calls().is_empty() {
1357            let input = if self.group_key().len() != self.input().schema().len() {
1358                let cols = &self.group_key().to_vec();
1359                LogicalProject::with_mapping(
1360                    self.input(),
1361                    ColIndexMapping::with_remaining_columns(cols, self.input().schema().len()),
1362                )
1363                .into()
1364            } else {
1365                self.input()
1366            };
1367            let input_schema_len = input.schema().len();
1368            let logical_dedup = LogicalDedup::new(input, (0..input_schema_len).collect());
1369            return logical_dedup.to_stream(ctx);
1370        }
1371
1372        let plan = self.gen_dist_stream_agg_plan(stream_input)?;
1373
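        // Check the root of the generated plan: a `count(*)` may have been appended as the row
        // count column, in which case the root outputs one more column than this `LogicalAgg`
        // and a `StreamProject` is added below to strip it from the output.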
1374        let (plan, n_final_agg_calls) = if let Some(final_agg) = plan.as_stream_simple_agg() {
1375            if eowc {
1376                return Err(ErrorCode::InvalidInputSyntax(
1377                    "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1378                        .to_owned(),
1379                )
1380                .into());
1381            }
1382            (plan.clone(), final_agg.agg_calls().len())
1383        } else if let Some(final_agg) = plan.as_stream_hash_agg() {
1384            (
1385                if eowc {
1386                    final_agg.to_eowc_version()?
1387                } else {
1388                    plan.clone()
1389                },
1390                final_agg.agg_calls().len(),
1391            )
1392        } else if let Some(_approx_percentile_agg) = plan.as_stream_global_approx_percentile() {
1393            if eowc {
1394                return Err(ErrorCode::InvalidInputSyntax(
1395                    "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1396                        .to_owned(),
1397                )
1398                .into());
1399            }
1400            (plan.clone(), 1)
1401        } else if let Some(stream_row_merge) = plan.as_stream_row_merge() {
1402            if eowc {
1403                return Err(ErrorCode::InvalidInputSyntax(
1404                    "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1405                        .to_owned(),
1406                )
1407                .into());
1408            }
1409            (plan.clone(), stream_row_merge.base.schema().len())
1410        } else {
1411            panic!(
1412                "the root PlanNode must be StreamHashAgg, StreamSimpleAgg, StreamGlobalApproxPercentile, or StreamRowMerge"
1413            );
1414        };
1415
1416        if self.agg_calls().len() == n_final_agg_calls {
1417            // an existing `count(*)` is reused as the row count column in `StreamXxxAgg`
1418            Ok(plan)
1419        } else {
1420            // a `count(*)` was appended, so project it out of the output
1421            assert_eq!(self.agg_calls().len() + 1, n_final_agg_calls);
1422            Ok(StreamProject::new(generic::Project::with_out_col_idx(
1423                plan,
1424                0..self.schema().len(),
1425            ))
1426            // If there's no agg call, then `count(*)` will be the only column in the output besides keys.
1427            // Since it'll be pruned immediately in `StreamProject`, the update records are likely to be
1428            // no-ops. So we set the hint to instruct the executor to eliminate them.
1429            // See https://github.com/risingwavelabs/risingwave/issues/17030.
1430            .with_noop_update_hint(self.agg_calls().is_empty())
1431            .into())
1432        }
1433    }
1434
1435    fn logical_rewrite_for_stream(
1436        &self,
1437        ctx: &mut RewriteStreamContext,
1438    ) -> Result<(PlanRef, ColIndexMapping)> {
1439        let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
1440        let (agg, out_col_change) = self.rewrite_with_input(input, input_col_change);
1441        let (map, _) = out_col_change.into_parts();
1442        let out_col_change = ColIndexMapping::new(map, agg.schema().len());
1443        Ok((agg.into(), out_col_change))
1444    }
1445}
1446
1447#[cfg(test)]
1448mod tests {
1449    use risingwave_common::catalog::{Field, Schema};
1450
1451    use super::*;
1452    use crate::expr::{assert_eq_input_ref, input_ref_to_column_indices};
1453    use crate::optimizer::optimizer_context::OptimizerContext;
1454    use crate::optimizer::plan_node::LogicalValues;
1455
1456    #[tokio::test]
1457    async fn test_create() {
1458        let ty = DataType::Int32;
1459        let ctx = OptimizerContext::mock().await;
1460        let fields: Vec<Field> = vec![
1461            Field::with_name(ty.clone(), "v1"),
1462            Field::with_name(ty.clone(), "v2"),
1463            Field::with_name(ty.clone(), "v3"),
1464        ];
1465        let values = LogicalValues::new(vec![], Schema { fields }, ctx);
1466        let input = PlanRef::from(values);
1467        let input_ref_1 = InputRef::new(0, ty.clone());
1468        let input_ref_2 = InputRef::new(1, ty.clone());
1469        let input_ref_3 = InputRef::new(2, ty.clone());
1470
1471        let gen_internal_value = |select_exprs: Vec<ExprImpl>,
1472                                  group_exprs|
1473         -> (Vec<ExprImpl>, Vec<PlanAggCall>, IndexSet) {
1474            let (plan, exprs, _) = LogicalAgg::create(
1475                select_exprs,
1476                GroupBy::GroupKey(group_exprs),
1477                None,
1478                input.clone(),
1479            )
1480            .unwrap();
1481
1482            let logical_agg = plan.as_logical_agg().unwrap();
1483            let agg_calls = logical_agg.agg_calls().to_vec();
1484            let group_key = logical_agg.group_key().clone();
1485
1486            (exprs, agg_calls, group_key)
1487        };
1488
1489        // Test case: select v1 from test group by v1;
1490        {
1491            let select_exprs = vec![input_ref_1.clone().into()];
1492            let group_exprs = vec![input_ref_1.clone().into()];
1493
1494            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
1495
1496            assert_eq!(exprs.len(), 1);
1497            assert_eq_input_ref!(&exprs[0], 0);
1498
1499            assert_eq!(agg_calls.len(), 0);
1500            assert_eq!(group_key, vec![0].into());
1501        }
1502
1503        // Test case: select v1, min(v2) from test group by v1;
1504        {
1505            let min_v2 = AggCall::new(
1506                PbAggKind::Min.into(),
1507                vec![input_ref_2.clone().into()],
1508                false,
1509                OrderBy::any(),
1510                Condition::true_cond(),
1511                vec![],
1512            )
1513            .unwrap();
1514            let select_exprs = vec![input_ref_1.clone().into(), min_v2.into()];
1515            let group_exprs = vec![input_ref_1.clone().into()];
1516
1517            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
1518
1519            assert_eq!(exprs.len(), 2);
1520            assert_eq_input_ref!(&exprs[0], 0);
1521            assert_eq_input_ref!(&exprs[1], 1);
1522
1523            assert_eq!(agg_calls.len(), 1);
1524            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
1525            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
1526            assert_eq!(group_key, vec![0].into());
1527        }
1528
1529        // Test case: select v1, min(v2) + max(v3) from test group by v1;
1530        {
1531            let min_v2 = AggCall::new(
1532                PbAggKind::Min.into(),
1533                vec![input_ref_2.clone().into()],
1534                false,
1535                OrderBy::any(),
1536                Condition::true_cond(),
1537                vec![],
1538            )
1539            .unwrap();
1540            let max_v3 = AggCall::new(
1541                PbAggKind::Max.into(),
1542                vec![input_ref_3.clone().into()],
1543                false,
1544                OrderBy::any(),
1545                Condition::true_cond(),
1546                vec![],
1547            )
1548            .unwrap();
1549            let func_call =
1550                FunctionCall::new(ExprType::Add, vec![min_v2.into(), max_v3.into()]).unwrap();
1551            let select_exprs = vec![input_ref_1.clone().into(), ExprImpl::from(func_call)];
1552            let group_exprs = vec![input_ref_1.clone().into()];
1553
1554            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
1555
1556            assert_eq_input_ref!(&exprs[0], 0);
1557            if let ExprImpl::FunctionCall(func_call) = &exprs[1] {
1558                assert_eq!(func_call.func_type(), ExprType::Add);
1559                let inputs = func_call.inputs();
1560                assert_eq_input_ref!(&inputs[0], 1);
1561                assert_eq_input_ref!(&inputs[1], 2);
1562            } else {
1563                panic!("Wrong expression type!");
1564            }
1565
1566            assert_eq!(agg_calls.len(), 2);
1567            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
1568            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
1569            assert_eq!(agg_calls[1].agg_type, PbAggKind::Max.into());
1570            assert_eq!(input_ref_to_column_indices(&agg_calls[1].inputs), vec![2]);
1571            assert_eq!(group_key, vec![0].into());
1572        }
1573
1574        // Test case: select v2, min(v1 * v3) from test group by v2;
1575        {
1576            let v1_mult_v3 = FunctionCall::new(
1577                ExprType::Multiply,
1578                vec![input_ref_1.into(), input_ref_3.into()],
1579            )
1580            .unwrap();
1581            let agg_call = AggCall::new(
1582                PbAggKind::Min.into(),
1583                vec![v1_mult_v3.into()],
1584                false,
1585                OrderBy::any(),
1586                Condition::true_cond(),
1587                vec![],
1588            )
1589            .unwrap();
1590            let select_exprs = vec![input_ref_2.clone().into(), agg_call.into()];
1591            let group_exprs = vec![input_ref_2.into()];
1592
1593            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);
1594
1595            assert_eq_input_ref!(&exprs[0], 0);
1596            assert_eq_input_ref!(&exprs[1], 1);
1597
1598            assert_eq!(agg_calls.len(), 1);
1599            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
1600            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
1601            assert_eq!(group_key, vec![0].into());
1602        }
1603    }
1604
1605    /// Generate an agg plan node with the given [`DataType`] and fields.
1606    /// For example, `generate_agg_call(Int32, [v1, v2, v3])` will result in:
1607    /// ```text
1608    /// Agg(min(input_ref(2))) group by (input_ref(1))
1609    ///   TableScan(v1, v2, v3)
1610    /// ```
1611    async fn generate_agg_call(ty: DataType, fields: Vec<Field>) -> LogicalAgg {
1612        let ctx = OptimizerContext::mock().await;
1613
1614        let values = LogicalValues::new(vec![], Schema { fields }, ctx);
1615        let agg_call = PlanAggCall {
1616            agg_type: PbAggKind::Min.into(),
1617            return_type: ty.clone(),
1618            inputs: vec![InputRef::new(2, ty.clone())],
1619            distinct: false,
1620            order_by: vec![],
1621            filter: Condition::true_cond(),
1622            direct_args: vec![],
1623        };
1624        Agg::new(vec![agg_call], vec![1].into(), values.into()).into()
1625    }
1626
1627    #[tokio::test]
1628    /// Pruning
1629    /// ```text
1630    /// Agg(min(input_ref(2))) group by (input_ref(1))
1631    ///   TableScan(v1, v2, v3)
1632    /// ```
1633    /// with required columns [0,1] (all columns) will result in
1634    /// ```text
1635    /// Agg(min(input_ref(1))) group by (input_ref(0))
1636    ///   TableScan(v2, v3)
1637    /// ```
1638    async fn test_prune_all() {
1639        let ty = DataType::Int32;
1640        let fields: Vec<Field> = vec![
1641            Field::with_name(ty.clone(), "v1"),
1642            Field::with_name(ty.clone(), "v2"),
1643            Field::with_name(ty.clone(), "v3"),
1644        ];
1645        let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).await.into();
1646        // Perform the prune
1647        let required_cols = vec![0, 1];
1648        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
1649
1650        // Check the result
1651        let agg_new = plan.as_logical_agg().unwrap();
1652        assert_eq!(agg_new.group_key(), &vec![0].into());
1653
1654        assert_eq!(agg_new.agg_calls().len(), 1);
1655        let agg_call_new = agg_new.agg_calls()[0].clone();
1656        assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
1657        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
1658        assert_eq!(agg_call_new.return_type, ty);
1659
1660        let values = agg_new.input();
1661        let values = values.as_logical_values().unwrap();
1662        assert_eq!(values.schema().fields(), &fields[1..]);
1663    }
1664
1665    #[tokio::test]
1666    /// Pruning
1667    /// ```text
1668    /// Agg(min(input_ref(2))) group by (input_ref(1))
1669    ///   TableScan(v1, v2, v3)
1670    /// ```
1671    /// with required columns [1,0] (all columns, with reversed order) will result in
1672    /// ```text
1673    /// Project [input_ref(1), input_ref(0)]
1674    ///   Agg(min(input_ref(1))) group by (input_ref(0))
1675    ///     TableScan(v2, v3)
1676    /// ```
1677    async fn test_prune_all_with_order_required() {
1678        let ty = DataType::Int32;
1679        let fields: Vec<Field> = vec![
1680            Field::with_name(ty.clone(), "v1"),
1681            Field::with_name(ty.clone(), "v2"),
1682            Field::with_name(ty.clone(), "v3"),
1683        ];
1684        let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).await.into();
1685        // Perform the prune
1686        let required_cols = vec![1, 0];
1687        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
1688        // Check the result
1689        let proj = plan.as_logical_project().unwrap();
1690        assert_eq!(proj.exprs().len(), 2);
1691        assert_eq!(proj.exprs()[0].as_input_ref().unwrap().index(), 1);
1692        assert_eq!(proj.exprs()[1].as_input_ref().unwrap().index(), 0);
1693        let proj_input = proj.input();
1694        let agg_new = proj_input.as_logical_agg().unwrap();
1695        assert_eq!(agg_new.group_key(), &vec![0].into());
1696
1697        assert_eq!(agg_new.agg_calls().len(), 1);
1698        let agg_call_new = agg_new.agg_calls()[0].clone();
1699        assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
1700        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
1701        assert_eq!(agg_call_new.return_type, ty);
1702
1703        let values = agg_new.input();
1704        let values = values.as_logical_values().unwrap();
1705        assert_eq!(values.schema().fields(), &fields[1..]);
1706    }
1707
1708    #[tokio::test]
1709    /// Pruning
1710    /// ```text
1711    /// Agg(min(input_ref(2))) group by (input_ref(1))
1712    ///   TableScan(v1, v2, v3)
1713    /// ```
1714    /// with required columns [1] (group key removed) will result in
1715    /// ```text
1716    /// Project(input_ref(1))
1717    ///   Agg(min(input_ref(1))) group by (input_ref(0))
1718    ///     TableScan(v2, v3)
1719    /// ```
1720    async fn test_prune_group_key() {
1721        let ctx = OptimizerContext::mock().await;
1722        let ty = DataType::Int32;
1723        let fields: Vec<Field> = vec![
1724            Field::with_name(ty.clone(), "v1"),
1725            Field::with_name(ty.clone(), "v2"),
1726            Field::with_name(ty.clone(), "v3"),
1727        ];
1728        let values: LogicalValues = LogicalValues::new(
1729            vec![],
1730            Schema {
1731                fields: fields.clone(),
1732            },
1733            ctx,
1734        );
1735        let agg_call = PlanAggCall {
1736            agg_type: PbAggKind::Min.into(),
1737            return_type: ty.clone(),
1738            inputs: vec![InputRef::new(2, ty.clone())],
1739            distinct: false,
1740            order_by: vec![],
1741            filter: Condition::true_cond(),
1742            direct_args: vec![],
1743        };
1744        let agg: PlanRef = Agg::new(vec![agg_call], vec![1].into(), values.into()).into();
1745
1746        // Perform the prune
1747        let required_cols = vec![1];
1748        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
1749
1750        // Check the result
1751        let project = plan.as_logical_project().unwrap();
1752        assert_eq!(project.exprs().len(), 1);
1753        assert_eq_input_ref!(&project.exprs()[0], 1);
1754
1755        let agg_new = project.input();
1756        let agg_new = agg_new.as_logical_agg().unwrap();
1757        assert_eq!(agg_new.group_key(), &vec![0].into());
1758
1759        assert_eq!(agg_new.agg_calls().len(), 1);
1760        let agg_call_new = agg_new.agg_calls()[0].clone();
1761        assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
1762        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
1763        assert_eq!(agg_call_new.return_type, ty);
1764
1765        let values = agg_new.input();
1766        let values = values.as_logical_values().unwrap();
1767        assert_eq!(values.schema().fields(), &fields[1..]);
1768    }
1769
1770    #[tokio::test]
1771    /// Pruning
1772    /// ```text
1773    /// Agg(min(input_ref(2)), max(input_ref(1))) group by (input_ref(1), input_ref(2))
1774    ///   TableScan(v1, v2, v3)
1775    /// ```
1776    /// with required columns [0,3] will result in
1777    /// ```text
1778    /// Project(input_ref(0), input_ref(2))
1779    ///   Agg(max(input_ref(0))) group by (input_ref(0), input_ref(1))
1780    ///     TableScan(v2, v3)
1781    /// ```
1782    async fn test_prune_agg() {
1783        let ty = DataType::Int32;
1784        let ctx = OptimizerContext::mock().await;
1785        let fields: Vec<Field> = vec![
1786            Field::with_name(ty.clone(), "v1"),
1787            Field::with_name(ty.clone(), "v2"),
1788            Field::with_name(ty.clone(), "v3"),
1789        ];
1790        let values = LogicalValues::new(
1791            vec![],
1792            Schema {
1793                fields: fields.clone(),
1794            },
1795            ctx,
1796        );
1797
1798        let agg_calls = vec![
1799            PlanAggCall {
1800                agg_type: PbAggKind::Min.into(),
1801                return_type: ty.clone(),
1802                inputs: vec![InputRef::new(2, ty.clone())],
1803                distinct: false,
1804                order_by: vec![],
1805                filter: Condition::true_cond(),
1806                direct_args: vec![],
1807            },
1808            PlanAggCall {
1809                agg_type: PbAggKind::Max.into(),
1810                return_type: ty.clone(),
1811                inputs: vec![InputRef::new(1, ty.clone())],
1812                distinct: false,
1813                order_by: vec![],
1814                filter: Condition::true_cond(),
1815                direct_args: vec![],
1816            },
1817        ];
1818        let agg: PlanRef = Agg::new(agg_calls, vec![1, 2].into(), values.into()).into();
1819
1820        // Perform the prune
1821        let required_cols = vec![0, 3];
1822        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
1823        // Check the result
1824        let project = plan.as_logical_project().unwrap();
1825        assert_eq!(project.exprs().len(), 2);
1826        assert_eq_input_ref!(&project.exprs()[0], 0);
1827        assert_eq_input_ref!(&project.exprs()[1], 2);
1828
1829        let agg_new = project.input();
1830        let agg_new = agg_new.as_logical_agg().unwrap();
1831        assert_eq!(agg_new.group_key(), &vec![0, 1].into());
1832
1833        assert_eq!(agg_new.agg_calls().len(), 1);
1834        let agg_call_new = agg_new.agg_calls()[0].clone();
1835        assert_eq!(agg_call_new.agg_type, PbAggKind::Max.into());
1836        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![0]);
1837        assert_eq!(agg_call_new.return_type, ty);
1838
1839        let values = agg_new.input();
1840        let values = values.as_logical_values().unwrap();
1841        assert_eq!(values.schema().fields(), &fields[1..]);
1842    }
1843}