risingwave_frontend/optimizer/plan_node/
batch_hash_agg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use risingwave_pb::batch_plan::HashAggNode;
17use risingwave_pb::batch_plan::plan_node::NodeBody;
18
19use super::batch::prelude::*;
20use super::generic::{self, PlanAggCall};
21use super::utils::impl_distill_by_unit;
22use super::{
23    ExprRewritable, PlanBase, PlanNodeType, PlanRef, PlanTreeNodeUnary, ToBatchPb,
24    ToDistributedBatch,
25};
26use crate::error::Result;
27use crate::expr::{ExprRewriter, ExprVisitor};
28use crate::optimizer::plan_node::ToLocalBatch;
29use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
30use crate::optimizer::property::{Distribution, Order, RequiredDist};
31use crate::utils::{ColIndexMappingRewriteExt, IndexSet};
32
33#[derive(Debug, Clone, PartialEq, Eq, Hash)]
34pub struct BatchHashAgg {
35    pub base: PlanBase<Batch>,
36    core: generic::Agg<PlanRef>,
37}
38
39impl BatchHashAgg {
40    pub fn new(core: generic::Agg<PlanRef>) -> Self {
41        assert!(!core.group_key.is_empty());
42        let input = core.input.clone();
43        let input_dist = input.distribution();
44        let dist = core
45            .i2o_col_mapping()
46            .rewrite_provided_distribution(input_dist);
47        let base = PlanBase::new_batch_with_core(&core, dist, Order::any());
48        BatchHashAgg { base, core }
49    }
50
51    pub fn agg_calls(&self) -> &[PlanAggCall] {
52        &self.core.agg_calls
53    }
54
55    pub fn group_key(&self) -> &IndexSet {
56        &self.core.group_key
57    }
58
59    fn to_two_phase_agg(&self, dist_input: PlanRef) -> Result<PlanRef> {
60        // partial agg - follows input distribution
61        let partial_agg: PlanRef = self.clone_with_input(dist_input).into();
62        debug_assert!(partial_agg.node_type() == PlanNodeType::BatchHashAgg);
63
64        // insert exchange
65        let exchange = RequiredDist::shard_by_key(
66            partial_agg.schema().len(),
67            &(0..self.group_key().len()).collect_vec(),
68        )
69        .enforce_if_not_satisfies(partial_agg, &Order::any())?;
70
71        // insert total agg
72        let total_agg_types = self
73            .core
74            .agg_calls
75            .iter()
76            .enumerate()
77            .map(|(partial_output_idx, agg_call)| {
78                agg_call.partial_to_total_agg_call(partial_output_idx + self.group_key().len())
79            })
80            .collect();
81        let total_agg_logical = generic::Agg::new(
82            total_agg_types,
83            (0..self.group_key().len()).collect(),
84            exchange,
85        );
86        Ok(BatchHashAgg::new(total_agg_logical).into())
87    }
88
89    fn to_shuffle_agg(&self) -> Result<PlanRef> {
90        let input = self.input();
91        let required_dist = RequiredDist::shard_by_key(
92            input.schema().len(),
93            &self.group_key().indices().collect_vec(),
94        );
95        let new_input = input.to_distributed_with_required(&Order::any(), &required_dist)?;
96        Ok(self.clone_with_input(new_input).into())
97    }
98}
99
100impl_distill_by_unit!(BatchHashAgg, core, "BatchHashAgg");
101
102impl PlanTreeNodeUnary for BatchHashAgg {
103    fn input(&self) -> PlanRef {
104        self.core.input.clone()
105    }
106
107    fn clone_with_input(&self, input: PlanRef) -> Self {
108        let mut core = self.core.clone();
109        core.input = input;
110        Self::new(core)
111    }
112}
113
114impl_plan_tree_node_for_unary! { BatchHashAgg }
115impl ToDistributedBatch for BatchHashAgg {
116    fn to_distributed(&self) -> Result<PlanRef> {
117        if self.core.must_try_two_phase_agg() {
118            let input = self.input().to_distributed()?;
119            let input_dist = input.distribution();
120            if !self.core.hash_agg_dist_satisfied_by_input_dist(input_dist)
121                && matches!(
122                    input_dist,
123                    Distribution::HashShard(_)
124                        | Distribution::UpstreamHashShard(_, _)
125                        | Distribution::SomeShard
126                )
127            {
128                return self.to_two_phase_agg(input);
129            }
130        }
131        self.to_shuffle_agg()
132    }
133}
134
135impl ToBatchPb for BatchHashAgg {
136    fn to_batch_prost_body(&self) -> NodeBody {
137        NodeBody::HashAgg(HashAggNode {
138            agg_calls: self
139                .agg_calls()
140                .iter()
141                .map(PlanAggCall::to_protobuf)
142                .collect(),
143            group_key: self.group_key().to_vec_as_u32(),
144        })
145    }
146}
147
148impl ToLocalBatch for BatchHashAgg {
149    fn to_local(&self) -> Result<PlanRef> {
150        let new_input = self.input().to_local()?;
151
152        let new_input =
153            RequiredDist::single().enforce_if_not_satisfies(new_input, &Order::any())?;
154
155        Ok(self.clone_with_input(new_input).into())
156    }
157}
158
159impl ExprRewritable for BatchHashAgg {
160    fn has_rewritable_expr(&self) -> bool {
161        true
162    }
163
164    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
165        let mut core = self.core.clone();
166        core.rewrite_exprs(r);
167        Self::new(core).into()
168    }
169}
170
171impl ExprVisitable for BatchHashAgg {
172    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
173        self.core.visit_exprs(v);
174    }
175}