risingwave_frontend/optimizer/plan_node/
batch_simple_agg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_expr::aggregate::{AggType, PbAggKind};
16use risingwave_pb::batch_plan::SortAggNode;
17use risingwave_pb::batch_plan::plan_node::NodeBody;
18
19use super::batch::prelude::*;
20use super::generic::{self, PlanAggCall};
21use super::utils::impl_distill_by_unit;
22use super::{ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, ToBatchPb, ToDistributedBatch};
23use crate::error::Result;
24use crate::expr::{ExprRewriter, ExprVisitor};
25use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
26use crate::optimizer::plan_node::{BatchExchange, ToLocalBatch};
27use crate::optimizer::property::{Distribution, Order, RequiredDist};
28
29#[derive(Debug, Clone, PartialEq, Eq, Hash)]
30pub struct BatchSimpleAgg {
31    pub base: PlanBase<Batch>,
32    pub core: generic::Agg<PlanRef>,
33}
34
35impl BatchSimpleAgg {
36    pub fn new(core: generic::Agg<PlanRef>) -> Self {
37        let input_dist = core.input.distribution().clone();
38        let base = PlanBase::new_batch_with_core(&core, input_dist, Order::any());
39        BatchSimpleAgg { base, core }
40    }
41
42    pub fn agg_calls(&self) -> &[PlanAggCall] {
43        &self.core.agg_calls
44    }
45
46    fn two_phase_agg_enabled(&self) -> bool {
47        self.base
48            .ctx()
49            .session_ctx()
50            .config()
51            .enable_two_phase_agg()
52    }
53
54    pub(crate) fn can_two_phase_agg(&self) -> bool {
55        self.core.can_two_phase_agg()
56            && self
57                .core
58                // Ban two phase approx percentile.
59                .agg_calls
60                .iter()
61                .map(|agg_call| &agg_call.agg_type)
62                .all(|agg_type| !matches!(agg_type, AggType::Builtin(PbAggKind::ApproxPercentile)))
63            && self.two_phase_agg_enabled()
64    }
65}
66
67impl PlanTreeNodeUnary for BatchSimpleAgg {
68    fn input(&self) -> PlanRef {
69        self.core.input.clone()
70    }
71
72    fn clone_with_input(&self, input: PlanRef) -> Self {
73        Self::new(generic::Agg {
74            input,
75            ..self.core.clone()
76        })
77    }
78}
79impl_plan_tree_node_for_unary! { BatchSimpleAgg }
80impl_distill_by_unit!(BatchSimpleAgg, core, "BatchSimpleAgg");
81
82impl ToDistributedBatch for BatchSimpleAgg {
83    fn to_distributed(&self) -> Result<PlanRef> {
84        // Ensure input is distributed, batch phase might not distribute it
85        // (e.g. see distribution of BatchSeqScan::new vs BatchSeqScan::to_distributed)
86        let dist_input = self.input().to_distributed()?;
87
88        // TODO: distinct agg cannot use 2-phase agg yet.
89        if dist_input.distribution().satisfies(&RequiredDist::AnyShard) && self.can_two_phase_agg()
90        {
91            // partial agg
92            let partial_agg = self.clone_with_input(dist_input).into();
93
94            // insert exchange
95            let exchange =
96                BatchExchange::new(partial_agg, Order::any(), Distribution::Single).into();
97
98            // insert total agg
99            let total_agg_types = self
100                .core
101                .agg_calls
102                .iter()
103                .enumerate()
104                .map(|(partial_output_idx, agg_call)| {
105                    agg_call.partial_to_total_agg_call(partial_output_idx)
106                })
107                .collect();
108            let total_agg_logical =
109                generic::Agg::new(total_agg_types, self.core.group_key.clone(), exchange);
110            Ok(BatchSimpleAgg::new(total_agg_logical).into())
111        } else {
112            let new_input = self
113                .input()
114                .to_distributed_with_required(&Order::any(), &RequiredDist::single())?;
115            Ok(self.clone_with_input(new_input).into())
116        }
117    }
118}
119
120impl ToBatchPb for BatchSimpleAgg {
121    fn to_batch_prost_body(&self) -> NodeBody {
122        NodeBody::SortAgg(SortAggNode {
123            agg_calls: self
124                .agg_calls()
125                .iter()
126                .map(PlanAggCall::to_protobuf)
127                .collect(),
128            // We treat simple agg as a special sort agg without group key.
129            group_key: vec![],
130        })
131    }
132}
133
134impl ToLocalBatch for BatchSimpleAgg {
135    fn to_local(&self) -> Result<PlanRef> {
136        let new_input = self.input().to_local()?;
137
138        let new_input =
139            RequiredDist::single().enforce_if_not_satisfies(new_input, &Order::any())?;
140
141        Ok(self.clone_with_input(new_input).into())
142    }
143}
144
145impl ExprRewritable for BatchSimpleAgg {
146    fn has_rewritable_expr(&self) -> bool {
147        true
148    }
149
150    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
151        let mut core = self.core.clone();
152        core.rewrite_exprs(r);
153        Self::new(core).into()
154    }
155}
156
157impl ExprVisitable for BatchSimpleAgg {
158    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
159        self.core.visit_exprs(v);
160    }
161}