risingwave_frontend/optimizer/plan_node/
batch_simple_agg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_expr::aggregate::{AggType, PbAggKind};
16use risingwave_pb::batch_plan::SortAggNode;
17use risingwave_pb::batch_plan::plan_node::NodeBody;
18
19use super::batch::prelude::*;
20use super::generic::{self, PlanAggCall};
21use super::utils::impl_distill_by_unit;
22use super::{
23    BatchPlanRef as PlanRef, ExprRewritable, PlanBase, PlanTreeNodeUnary, ToBatchPb,
24    ToDistributedBatch,
25};
26use crate::error::Result;
27use crate::expr::{ExprRewriter, ExprVisitor};
28use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
29use crate::optimizer::plan_node::{BatchExchange, ToLocalBatch};
30use crate::optimizer::property::{Distribution, Order, RequiredDist};
31
32#[derive(Debug, Clone, PartialEq, Eq, Hash)]
33pub struct BatchSimpleAgg {
34    pub base: PlanBase<Batch>,
35    pub core: generic::Agg<PlanRef>,
36}
37
38impl BatchSimpleAgg {
39    pub fn new(core: generic::Agg<PlanRef>) -> Self {
40        let input_dist = core.input.distribution().clone();
41        let base = PlanBase::new_batch_with_core(&core, input_dist, Order::any());
42        BatchSimpleAgg { base, core }
43    }
44
45    pub fn agg_calls(&self) -> &[PlanAggCall] {
46        &self.core.agg_calls
47    }
48
49    fn two_phase_agg_enabled(&self) -> bool {
50        self.base
51            .ctx()
52            .session_ctx()
53            .config()
54            .enable_two_phase_agg()
55    }
56
57    pub(crate) fn can_two_phase_agg(&self) -> bool {
58        self.core.can_two_phase_agg()
59            && self
60                .core
61                // Ban two phase approx percentile.
62                .agg_calls
63                .iter()
64                .map(|agg_call| &agg_call.agg_type)
65                .all(|agg_type| !matches!(agg_type, AggType::Builtin(PbAggKind::ApproxPercentile)))
66            && self.two_phase_agg_enabled()
67    }
68}
69
70impl PlanTreeNodeUnary<Batch> for BatchSimpleAgg {
71    fn input(&self) -> PlanRef {
72        self.core.input.clone()
73    }
74
75    fn clone_with_input(&self, input: PlanRef) -> Self {
76        Self::new(generic::Agg {
77            input,
78            ..self.core.clone()
79        })
80    }
81}
82impl_plan_tree_node_for_unary! { Batch, BatchSimpleAgg }
83impl_distill_by_unit!(BatchSimpleAgg, core, "BatchSimpleAgg");
84
85impl ToDistributedBatch for BatchSimpleAgg {
86    fn to_distributed(&self) -> Result<PlanRef> {
87        // Ensure input is distributed, batch phase might not distribute it
88        // (e.g. see distribution of BatchSeqScan::new vs BatchSeqScan::to_distributed)
89        let dist_input = self.input().to_distributed()?;
90
91        // TODO: distinct agg cannot use 2-phase agg yet.
92        if dist_input.distribution().satisfies(&RequiredDist::AnyShard) && self.can_two_phase_agg()
93        {
94            // partial agg
95            let partial_agg = self.clone_with_input(dist_input).into();
96
97            // insert exchange
98            let exchange =
99                BatchExchange::new(partial_agg, Order::any(), Distribution::Single).into();
100
101            // insert total agg
102            let total_agg_types = self
103                .core
104                .agg_calls
105                .iter()
106                .enumerate()
107                .map(|(partial_output_idx, agg_call)| {
108                    agg_call.partial_to_total_agg_call(partial_output_idx)
109                })
110                .collect();
111            let total_agg_logical =
112                generic::Agg::new(total_agg_types, self.core.group_key.clone(), exchange);
113            Ok(BatchSimpleAgg::new(total_agg_logical).into())
114        } else {
115            let new_input = self
116                .input()
117                .to_distributed_with_required(&Order::any(), &RequiredDist::single())?;
118            Ok(self.clone_with_input(new_input).into())
119        }
120    }
121}
122
123impl ToBatchPb for BatchSimpleAgg {
124    fn to_batch_prost_body(&self) -> NodeBody {
125        NodeBody::SortAgg(SortAggNode {
126            agg_calls: self
127                .agg_calls()
128                .iter()
129                .map(PlanAggCall::to_protobuf)
130                .collect(),
131            // We treat simple agg as a special sort agg without group key.
132            group_key: vec![],
133        })
134    }
135}
136
137impl ToLocalBatch for BatchSimpleAgg {
138    fn to_local(&self) -> Result<PlanRef> {
139        let new_input = self.input().to_local()?;
140
141        let new_input =
142            RequiredDist::single().batch_enforce_if_not_satisfies(new_input, &Order::any())?;
143
144        Ok(self.clone_with_input(new_input).into())
145    }
146}
147
148impl ExprRewritable<Batch> for BatchSimpleAgg {
149    fn has_rewritable_expr(&self) -> bool {
150        true
151    }
152
153    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
154        let mut core = self.core.clone();
155        core.rewrite_exprs(r);
156        Self::new(core).into()
157    }
158}
159
160impl ExprVisitable for BatchSimpleAgg {
161    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
162        self.core.visit_exprs(v);
163    }
164}