risingwave_frontend/optimizer/plan_node/
batch_hash_agg.rs1use itertools::Itertools;
16use risingwave_pb::batch_plan::HashAggNode;
17use risingwave_pb::batch_plan::plan_node::NodeBody;
18
19use super::batch::prelude::*;
20use super::generic::{self, PlanAggCall};
21use super::utils::impl_distill_by_unit;
22use super::{
23 ExprRewritable, PlanBase, PlanNodeType, PlanRef, PlanTreeNodeUnary, ToBatchPb,
24 ToDistributedBatch,
25};
26use crate::error::Result;
27use crate::expr::{ExprRewriter, ExprVisitor};
28use crate::optimizer::plan_node::ToLocalBatch;
29use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
30use crate::optimizer::property::{Distribution, Order, RequiredDist};
31use crate::utils::{ColIndexMappingRewriteExt, IndexSet};
32
33#[derive(Debug, Clone, PartialEq, Eq, Hash)]
34pub struct BatchHashAgg {
35 pub base: PlanBase<Batch>,
36 core: generic::Agg<PlanRef>,
37}
38
39impl BatchHashAgg {
40 pub fn new(core: generic::Agg<PlanRef>) -> Self {
41 assert!(!core.group_key.is_empty());
42 let input = core.input.clone();
43 let input_dist = input.distribution();
44 let dist = core
45 .i2o_col_mapping()
46 .rewrite_provided_distribution(input_dist);
47 let base = PlanBase::new_batch_with_core(&core, dist, Order::any());
48 BatchHashAgg { base, core }
49 }
50
51 pub fn agg_calls(&self) -> &[PlanAggCall] {
52 &self.core.agg_calls
53 }
54
55 pub fn group_key(&self) -> &IndexSet {
56 &self.core.group_key
57 }
58
59 fn to_two_phase_agg(&self, dist_input: PlanRef) -> Result<PlanRef> {
60 let partial_agg: PlanRef = self.clone_with_input(dist_input).into();
62 debug_assert!(partial_agg.node_type() == PlanNodeType::BatchHashAgg);
63
64 let exchange = RequiredDist::shard_by_key(
66 partial_agg.schema().len(),
67 &(0..self.group_key().len()).collect_vec(),
68 )
69 .enforce_if_not_satisfies(partial_agg, &Order::any())?;
70
71 let total_agg_types = self
73 .core
74 .agg_calls
75 .iter()
76 .enumerate()
77 .map(|(partial_output_idx, agg_call)| {
78 agg_call.partial_to_total_agg_call(partial_output_idx + self.group_key().len())
79 })
80 .collect();
81 let total_agg_logical = generic::Agg::new(
82 total_agg_types,
83 (0..self.group_key().len()).collect(),
84 exchange,
85 );
86 Ok(BatchHashAgg::new(total_agg_logical).into())
87 }
88
89 fn to_shuffle_agg(&self) -> Result<PlanRef> {
90 let input = self.input();
91 let required_dist = RequiredDist::shard_by_key(
92 input.schema().len(),
93 &self.group_key().indices().collect_vec(),
94 );
95 let new_input = input.to_distributed_with_required(&Order::any(), &required_dist)?;
96 Ok(self.clone_with_input(new_input).into())
97 }
98}
99
100impl_distill_by_unit!(BatchHashAgg, core, "BatchHashAgg");
101
102impl PlanTreeNodeUnary for BatchHashAgg {
103 fn input(&self) -> PlanRef {
104 self.core.input.clone()
105 }
106
107 fn clone_with_input(&self, input: PlanRef) -> Self {
108 let mut core = self.core.clone();
109 core.input = input;
110 Self::new(core)
111 }
112}
113
114impl_plan_tree_node_for_unary! { BatchHashAgg }
115impl ToDistributedBatch for BatchHashAgg {
116 fn to_distributed(&self) -> Result<PlanRef> {
117 if self.core.must_try_two_phase_agg() {
118 let input = self.input().to_distributed()?;
119 let input_dist = input.distribution();
120 if !self.core.hash_agg_dist_satisfied_by_input_dist(input_dist)
121 && matches!(
122 input_dist,
123 Distribution::HashShard(_)
124 | Distribution::UpstreamHashShard(_, _)
125 | Distribution::SomeShard
126 )
127 {
128 return self.to_two_phase_agg(input);
129 }
130 }
131 self.to_shuffle_agg()
132 }
133}
134
135impl ToBatchPb for BatchHashAgg {
136 fn to_batch_prost_body(&self) -> NodeBody {
137 NodeBody::HashAgg(HashAggNode {
138 agg_calls: self
139 .agg_calls()
140 .iter()
141 .map(PlanAggCall::to_protobuf)
142 .collect(),
143 group_key: self.group_key().to_vec_as_u32(),
144 })
145 }
146}
147
148impl ToLocalBatch for BatchHashAgg {
149 fn to_local(&self) -> Result<PlanRef> {
150 let new_input = self.input().to_local()?;
151
152 let new_input =
153 RequiredDist::single().enforce_if_not_satisfies(new_input, &Order::any())?;
154
155 Ok(self.clone_with_input(new_input).into())
156 }
157}
158
159impl ExprRewritable for BatchHashAgg {
160 fn has_rewritable_expr(&self) -> bool {
161 true
162 }
163
164 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
165 let mut core = self.core.clone();
166 core.rewrite_exprs(r);
167 Self::new(core).into()
168 }
169}
170
171impl ExprVisitable for BatchHashAgg {
172 fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
173 self.core.visit_exprs(v);
174 }
175}