1use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::types::{DataType, ScalarImpl};
18use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
19use risingwave_common::{bail, bail_not_implemented, not_implemented};
20use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};
21
22use super::generic::{self, Agg, GenericPlanRef, PlanAggCall, ProjectBuilder};
23use super::utils::impl_distill_by_unit;
24use super::{
25 BatchHashAgg, BatchSimpleAgg, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
26 PlanBase, PlanTreeNodeUnary, PredicatePushdown, StreamHashAgg, StreamPlanRef, StreamProject,
27 StreamShare, StreamSimpleAgg, StreamStatelessSimpleAgg, ToBatch, ToStream,
28 try_enforce_locality_requirement,
29};
30use crate::error::{ErrorCode, Result, RwError};
31use crate::expr::{
32 AggCall, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, Literal,
33 OrderBy, OrderByExpr, WindowFunction,
34};
35use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
36use crate::optimizer::plan_node::generic::GenericPlanNode;
37use crate::optimizer::plan_node::stream_global_approx_percentile::StreamGlobalApproxPercentile;
38use crate::optimizer::plan_node::stream_local_approx_percentile::StreamLocalApproxPercentile;
39use crate::optimizer::plan_node::stream_row_merge::StreamRowMerge;
40use crate::optimizer::plan_node::{
41 BatchSortAgg, ColumnPruningContext, LogicalDedup, LogicalProject, PredicatePushdownContext,
42 RewriteStreamContext, ToStreamContext, gen_filter_and_pushdown,
43};
44use crate::optimizer::property::{Distribution, Order, RequiredDist};
45use crate::utils::{
46 ColIndexMapping, ColIndexMappingRewriteExt, Condition, GroupBy, IndexSet, Substitute,
47};
48
/// A subset of an aggregation's calls together with the mapping back to the full call list.
pub struct AggInfo {
    /// The aggregate calls belonging to this subset.
    pub calls: Vec<PlanAggCall>,
    /// Maps the position of each call in `calls` to its index in the original agg call list.
    pub col_mapping: ColIndexMapping,
}
53
/// The agg calls of a `LogicalAgg`, split into approx-percentile calls and all other calls.
pub struct SeparatedAggInfo {
    /// Every agg call that is not an `ApproxPercentile`.
    normal: AggInfo,
    /// The `ApproxPercentile` agg calls.
    approx: AggInfo,
}
59
/// Logical aggregation plan node: a logical `PlanBase` plus the generic `Agg` core that holds
/// the group key, grouping sets, agg calls and input.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct LogicalAgg {
    /// Plan base built from `core`.
    pub base: PlanBase<Logical>,
    /// The generic aggregation description over a logical input.
    core: Agg<PlanRef>,
}
71
72impl LogicalAgg {
73 fn gen_stateless_two_phase_streaming_agg_plan(
76 &self,
77 stream_input: StreamPlanRef,
78 ) -> Result<StreamPlanRef> {
79 debug_assert!(self.group_key().is_empty());
80
81 let (
83 non_approx_percentile_col_mapping,
84 approx_percentile_col_mapping,
85 approx_percentile,
86 core,
87 ) = self.prepare_approx_percentile(stream_input)?;
88
89 if core.agg_calls.is_empty() {
90 if let Some(approx_percentile) = approx_percentile {
91 return Ok(approx_percentile);
92 };
93 bail!("expected at least one agg call");
94 }
95
96 let need_row_merge: bool = Self::need_row_merge(&approx_percentile);
97
98 let total_agg_calls = core
100 .agg_calls
101 .iter()
102 .enumerate()
103 .map(|(partial_output_idx, agg_call)| {
104 agg_call.partial_to_total_agg_call(partial_output_idx)
105 })
106 .collect_vec();
107 let local_agg = StreamStatelessSimpleAgg::new(core)?;
108 let exchange =
109 RequiredDist::single().streaming_enforce_if_not_satisfies(local_agg.into())?;
110
111 let must_output_per_barrier = need_row_merge;
112 let global_agg = new_stream_simple_agg(
113 Agg::new(total_agg_calls, IndexSet::empty(), exchange),
114 must_output_per_barrier,
115 )?;
116
117 Self::add_row_merge_if_needed(
119 approx_percentile,
120 global_agg.into(),
121 approx_percentile_col_mapping,
122 non_approx_percentile_col_mapping,
123 )
124 }
125
126 fn gen_vnode_two_phase_streaming_agg_plan(
130 &self,
131 stream_input: StreamPlanRef,
132 dist_key: &[usize],
133 ) -> Result<StreamPlanRef> {
134 let (
135 non_approx_percentile_col_mapping,
136 approx_percentile_col_mapping,
137 approx_percentile,
138 core,
139 ) = self.prepare_approx_percentile(stream_input.clone())?;
140
141 if core.agg_calls.is_empty() {
142 if let Some(approx_percentile) = approx_percentile {
143 return Ok(approx_percentile);
144 };
145 bail!("expected at least one agg call");
146 }
147 let need_row_merge = Self::need_row_merge(&approx_percentile);
148
149 let input_col_num = stream_input.schema().len(); let project = StreamProject::new(generic::Project::with_vnode_col(stream_input, dist_key));
153 let vnode_col_idx = project.base.schema().len() - 1;
154
155 let mut local_group_key = self.group_key().clone();
157 local_group_key.insert(vnode_col_idx);
158 let n_local_group_key = local_group_key.len();
159 let local_agg = new_stream_hash_agg(
160 Agg::new(core.agg_calls.clone(), local_group_key, project.into()),
161 Some(vnode_col_idx),
162 )?;
163 let local_agg_group_key_cardinality = local_agg.group_key().len();
165 let local_group_key_without_vnode =
166 &local_agg.group_key().to_vec()[..local_agg_group_key_cardinality - 1];
167 let global_group_key = local_agg
168 .i2o_col_mapping()
169 .rewrite_dist_key(local_group_key_without_vnode)
170 .expect("some input group key could not be mapped");
171
172 let global_agg = if self.group_key().is_empty() {
174 let exchange =
175 RequiredDist::single().streaming_enforce_if_not_satisfies(local_agg.into())?;
176 let must_output_per_barrier = need_row_merge;
177 let global_agg = new_stream_simple_agg(
178 Agg::new(
179 core.agg_calls
180 .iter()
181 .enumerate()
182 .map(|(partial_output_idx, agg_call)| {
183 agg_call
184 .partial_to_total_agg_call(n_local_group_key + partial_output_idx)
185 })
186 .collect(),
187 global_group_key.into_iter().collect(),
188 exchange,
189 ),
190 must_output_per_barrier,
191 )?;
192 global_agg.into()
193 } else {
194 assert!(!need_row_merge);
196 let exchange = RequiredDist::shard_by_key(input_col_num, &global_group_key)
197 .streaming_enforce_if_not_satisfies(local_agg.into())?;
198 let global_agg = new_stream_hash_agg(
201 Agg::new(
202 core.agg_calls
203 .iter()
204 .enumerate()
205 .map(|(partial_output_idx, agg_call)| {
206 agg_call
207 .partial_to_total_agg_call(n_local_group_key + partial_output_idx)
208 })
209 .collect(),
210 global_group_key.into_iter().collect(),
211 exchange,
212 ),
213 None,
214 )?;
215 global_agg.into()
216 };
217 Self::add_row_merge_if_needed(
218 approx_percentile,
219 global_agg,
220 approx_percentile_col_mapping,
221 non_approx_percentile_col_mapping,
222 )
223 }
224
225 fn gen_single_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
226 let input = RequiredDist::single().streaming_enforce_if_not_satisfies(stream_input)?;
227 let core = self.core.clone_with_input(input);
228 Ok(new_stream_simple_agg(core, false)?.into())
229 }
230
231 fn gen_shuffle_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
232 let input =
233 RequiredDist::shard_by_key(stream_input.schema().len(), &self.group_key().to_vec())
234 .streaming_enforce_if_not_satisfies(stream_input)?;
235 let core = self.core.clone_with_input(input);
236 Ok(new_stream_hash_agg(core, None)?.into())
237 }
238
    /// Chooses and builds the streaming aggregation plan for a distributed input.
    ///
    /// The strategies are tried in this order:
    /// 1. shuffle (hash) agg, when there is a group key and two-phase is not mandatory;
    /// 2. single (simple) agg, when there is no group key and two-phase is not possible;
    /// 3. stateless two-phase simple agg, when there is no group key, all local aggs are
    ///    stateless and the input satisfies `RequiredDist::AnyShard`;
    /// 4. vnode-based two-phase agg, when the input is hash-sharded and either there is no
    ///    group key or the input distribution does not already satisfy the hash agg;
    /// 5. otherwise shuffle agg (with group key) or single agg (without).
    fn gen_dist_stream_agg_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
        use super::stream::prelude::*;

        let input_dist = stream_input.distribution();
        debug_assert!(*input_dist != Distribution::Broadcast);

        // Grouped aggregation without mandatory two-phase: plain shuffle agg.
        if !self.group_key().is_empty() && !self.core.must_try_two_phase_agg() {
            return self.gen_shuffle_plan(stream_input);
        }

        // Ungrouped aggregation that cannot be split into two phases: singleton agg.
        if self.group_key().is_empty() && !self.core.can_two_phase_agg() {
            return self.gen_single_plan(stream_input);
        }

        // From here on a two-phase plan is permitted (and mandatory when grouped).
        debug_assert!(if !self.group_key().is_empty() {
            self.core.must_try_two_phase_agg()
        } else {
            self.core.can_two_phase_agg()
        });

        // Ungrouped, stateless local aggs, sharded input: stateless two-phase simple agg.
        if self.group_key().is_empty()
            && self
                .core
                .all_local_aggs_are_stateless(stream_input.append_only())
            && input_dist.satisfies(&RequiredDist::AnyShard)
        {
            return self.gen_stateless_two_phase_streaming_agg_plan(stream_input);
        }

        // A `SomeShard` input has no distribution key to derive a vnode from; when two-phase
        // is mandatory, re-shard it by its stream key so the vnode-based plan below can apply.
        let stream_input =
            if *input_dist == Distribution::SomeShard && self.core.must_try_two_phase_agg() {
                RequiredDist::shard_by_key(
                    stream_input.schema().len(),
                    stream_input.expect_stream_key(),
                )
                .streaming_enforce_if_not_satisfies(stream_input)?
            } else {
                stream_input
            };
        // Re-read the distribution, since the input may have been replaced above.
        let input_dist = stream_input.distribution();

        // Hash-sharded input: vnode-based two-phase agg, unless a grouped agg is already
        // satisfied by the input distribution.
        match input_dist {
            Distribution::HashShard(dist_key) | Distribution::UpstreamHashShard(dist_key, _)
                if (self.group_key().is_empty()
                    || !self.core.hash_agg_dist_satisfied_by_input_dist(input_dist)) =>
            {
                // Clone the key so the borrow of `stream_input` ends before it is moved.
                let dist_key = dist_key.clone();
                return self.gen_vnode_two_phase_streaming_agg_plan(stream_input, &dist_key);
            }
            _ => {}
        }

        // Fallback: single-phase plans.
        if !self.group_key().is_empty() {
            self.gen_shuffle_plan(stream_input)
        } else {
            self.gen_single_plan(stream_input)
        }
    }
314
315 fn prepare_approx_percentile(
320 &self,
321 stream_input: StreamPlanRef,
322 ) -> Result<(
323 ColIndexMapping,
324 ColIndexMapping,
325 Option<StreamPlanRef>,
326 Agg<StreamPlanRef>,
327 )> {
328 let SeparatedAggInfo { normal, approx } = self.separate_normal_and_special_agg();
329
330 let AggInfo {
331 calls: non_approx_percentile_agg_calls,
332 col_mapping: non_approx_percentile_col_mapping,
333 } = normal;
334 let AggInfo {
335 calls: approx_percentile_agg_calls,
336 col_mapping: approx_percentile_col_mapping,
337 } = approx;
338 if !self.group_key().is_empty() && !approx_percentile_agg_calls.is_empty() {
339 bail_not_implemented!(
340 "two-phase streaming approx percentile aggregation with group key, \
341 please use single phase aggregation instead"
342 );
343 }
344
345 let needs_row_merge = (!non_approx_percentile_agg_calls.is_empty()
348 && !approx_percentile_agg_calls.is_empty())
349 || approx_percentile_agg_calls.len() >= 2;
350 let input = if needs_row_merge {
351 StreamShare::new_from_input(stream_input).into()
353 } else {
354 stream_input
355 };
356 let mut core = self.core.clone_with_input(input);
357 core.agg_calls = non_approx_percentile_agg_calls;
358
359 let approx_percentile =
360 self.build_approx_percentile_aggs(core.input.clone(), &approx_percentile_agg_calls)?;
361 Ok((
362 non_approx_percentile_col_mapping,
363 approx_percentile_col_mapping,
364 approx_percentile,
365 core,
366 ))
367 }
368
    /// Returns `true` when an approx-percentile sub-plan exists, i.e. when the normal agg
    /// output must be combined with it through a row merge.
    fn need_row_merge(approx_percentile: &Option<StreamPlanRef>) -> bool {
        approx_percentile.is_some()
    }
372
373 fn add_row_merge_if_needed(
375 approx_percentile: Option<StreamPlanRef>,
376 global_agg: StreamPlanRef,
377 approx_percentile_col_mapping: ColIndexMapping,
378 non_approx_percentile_col_mapping: ColIndexMapping,
379 ) -> Result<StreamPlanRef> {
380 let need_row_merge = Self::need_row_merge(&approx_percentile);
382
383 if let Some(approx_percentile) = approx_percentile {
384 assert!(need_row_merge);
385 let row_merge = StreamRowMerge::new(
386 approx_percentile,
387 global_agg,
388 approx_percentile_col_mapping,
389 non_approx_percentile_col_mapping,
390 )?;
391 Ok(row_merge.into())
392 } else {
393 assert!(!need_row_merge);
394 Ok(global_agg)
395 }
396 }
397
398 fn separate_normal_and_special_agg(&self) -> SeparatedAggInfo {
399 let estimated_len = self.agg_calls().len() - 1;
400 let mut approx_percentile_agg_calls = Vec::with_capacity(estimated_len);
401 let mut non_approx_percentile_agg_calls = Vec::with_capacity(estimated_len);
402 let mut approx_percentile_col_mapping = Vec::with_capacity(estimated_len);
403 let mut non_approx_percentile_col_mapping = Vec::with_capacity(estimated_len);
404 for (output_idx, agg_call) in self.agg_calls().iter().enumerate() {
405 if agg_call.agg_type == AggType::Builtin(PbAggKind::ApproxPercentile) {
406 approx_percentile_agg_calls.push(agg_call.clone());
407 approx_percentile_col_mapping.push(Some(output_idx));
408 } else {
409 non_approx_percentile_agg_calls.push(agg_call.clone());
410 non_approx_percentile_col_mapping.push(Some(output_idx));
411 }
412 }
413 let normal = AggInfo {
414 calls: non_approx_percentile_agg_calls,
415 col_mapping: ColIndexMapping::new(
416 non_approx_percentile_col_mapping,
417 self.agg_calls().len(),
418 ),
419 };
420 let approx = AggInfo {
421 calls: approx_percentile_agg_calls,
422 col_mapping: ColIndexMapping::new(
423 approx_percentile_col_mapping,
424 self.agg_calls().len(),
425 ),
426 };
427 SeparatedAggInfo { normal, approx }
428 }
429
430 fn build_approx_percentile_agg(
431 &self,
432 input: StreamPlanRef,
433 approx_percentile_agg_call: &PlanAggCall,
434 ) -> Result<StreamPlanRef> {
435 let local_approx_percentile =
436 StreamLocalApproxPercentile::new(input, approx_percentile_agg_call)?;
437 let exchange = RequiredDist::single()
438 .streaming_enforce_if_not_satisfies(local_approx_percentile.into())?;
439 let global_approx_percentile =
440 StreamGlobalApproxPercentile::new(exchange, approx_percentile_agg_call);
441 Ok(global_approx_percentile.into())
442 }
443
444 fn build_approx_percentile_aggs(
457 &self,
458 input: StreamPlanRef,
459 approx_percentile_agg_call: &[PlanAggCall],
460 ) -> Result<Option<StreamPlanRef>> {
461 if approx_percentile_agg_call.is_empty() {
462 return Ok(None);
463 }
464 let approx_percentile_plans: Vec<_> = approx_percentile_agg_call
465 .iter()
466 .map(|agg_call| self.build_approx_percentile_agg(input.clone(), agg_call))
467 .try_collect()?;
468 assert!(!approx_percentile_plans.is_empty());
469 let mut iter = approx_percentile_plans.into_iter();
470 let mut acc = iter.next().unwrap();
471 for (current_size, plan) in iter.enumerate().map(|(i, p)| (i + 1, p)) {
472 let new_size = current_size + 1;
473 let row_merge = StreamRowMerge::new(
474 acc,
475 plan,
476 ColIndexMapping::identity_or_none(current_size, new_size),
477 ColIndexMapping::new(vec![Some(current_size)], new_size),
478 )?;
479 acc = row_merge.into();
480 }
481 Ok(Some(acc))
482 }
483
    /// Returns the generic `Agg` core of this node.
    pub fn core(&self) -> &Agg<PlanRef> {
        &self.core
    }
487}
488
/// Expression rewriter that collects everything needed to build a `LogicalAgg`: the input
/// projection, the group key, the grouping sets and the agg calls.
pub struct LogicalAggBuilder {
    /// Builder of the projection that feeds the aggregation.
    input_proj_builder: ProjectBuilder,
    /// Group key, as column indices in the input projection's output.
    group_key: IndexSet,
    /// Grouping sets, as column indices in the input projection's output
    /// (empty for a plain `GROUP BY` key).
    grouping_sets: Vec<IndexSet>,
    /// Agg calls collected so far (deduplicated).
    agg_calls: Vec<PlanAggCall>,
    /// Error recorded while rewriting an expression; taken by `rewrite_with_error`.
    error: Option<RwError>,
    /// `true` while the filter clause of an agg call is being rewritten. Input refs there may
    /// reference any input column; elsewhere they must match a group key expression.
    is_in_filter_clause: bool,
}
511
512impl LogicalAggBuilder {
    /// Creates a builder from the `GROUP BY` clause.
    ///
    /// Every grouping expression is added to the input projection; the group key and grouping
    /// sets store the resulting projection column indices. `input_schema_len` is only used as
    /// the initial capacity of the bitset that unions the grouping sets.
    ///
    /// Returns a not-implemented error if a grouping expression cannot be added to the
    /// projection.
    fn new(group_by: GroupBy, input_schema_len: usize) -> Result<Self> {
        let mut input_proj_builder = ProjectBuilder::default();

        // Registers every expression of every grouping set in the input projection and derives
        // the overall group key as the union of all grouping sets.
        let mut gen_group_key_and_grouping_sets =
            |grouping_sets: Vec<Vec<ExprImpl>>| -> Result<(IndexSet, Vec<IndexSet>)> {
                let grouping_sets: Vec<IndexSet> = grouping_sets
                    .into_iter()
                    .map(|set| {
                        set.into_iter()
                            .map(|expr| input_proj_builder.add_expr(&expr))
                            .try_collect()
                            .map_err(|err| not_implemented!("{err} inside GROUP BY"))
                    })
                    .try_collect()?;

                // Union of all grouping sets.
                let group_key = grouping_sets
                    .iter()
                    .fold(FixedBitSet::with_capacity(input_schema_len), |acc, x| {
                        acc.union(&x.to_bitset()).collect()
                    });

                Ok((IndexSet::from_iter(group_key.ones()), grouping_sets))
            };

        let (group_key, grouping_sets) = match group_by {
            // Plain `GROUP BY a, b`: a group key and no grouping sets.
            GroupBy::GroupKey(group_key) => {
                let group_key = group_key
                    .into_iter()
                    .map(|expr| input_proj_builder.add_expr(&expr))
                    .try_collect()
                    .map_err(|err| not_implemented!("{err} inside GROUP BY"))?;
                (group_key, vec![])
            }
            GroupBy::GroupingSets(grouping_sets) => gen_group_key_and_grouping_sets(grouping_sets)?,
            // Rollup expands to every prefix of the rollup list, including the empty prefix.
            GroupBy::Rollup(rollup) => {
                let grouping_sets = (0..=rollup.len())
                    .map(|n| {
                        rollup
                            .iter()
                            .take(n)
                            .flat_map(|x| x.iter().cloned())
                            .collect_vec()
                    })
                    .collect_vec();
                gen_group_key_and_grouping_sets(grouping_sets)?
            }
            // Cube expands to every subset (the powerset) of the cube list.
            GroupBy::Cube(cube) => {
                let grouping_sets = cube
                    .into_iter()
                    .powerset()
                    .map(|x| x.into_iter().flatten().collect_vec())
                    .collect_vec();
                gen_group_key_and_grouping_sets(grouping_sets)?
            }
        };

        Ok(LogicalAggBuilder {
            group_key,
            grouping_sets,
            agg_calls: vec![],
            error: None,
            input_proj_builder,
            is_in_filter_clause: false,
        })
    }
581
582 pub fn build(self, input: PlanRef) -> LogicalAgg {
583 let logical_project = LogicalProject::with_core(self.input_proj_builder.build(input));
585
586 Agg::new(self.agg_calls, self.group_key, logical_project.into())
588 .with_grouping_sets(self.grouping_sets)
589 .into()
590 }
591
592 fn rewrite_with_error(&mut self, expr: ExprImpl) -> Result<ExprImpl> {
593 let rewritten_expr = self.rewrite_expr(expr);
594 if let Some(error) = self.error.take() {
595 return Err(error);
596 }
597 Ok(rewritten_expr)
598 }
599
600 pub fn try_as_group_expr(&self, expr: &ExprImpl) -> Option<usize> {
602 if let Some(input_index) = self.input_proj_builder.expr_index(expr)
603 && let Some(index) = self
604 .group_key
605 .indices()
606 .position(|group_key| group_key == input_index)
607 {
608 return Some(index);
609 }
610 None
611 }
612
    /// Index of the first agg call column in the agg output schema, i.e. the number of group
    /// key columns that precede the agg calls.
    fn schema_agg_start_offset(&self) -> usize {
        self.group_key.len()
    }
616
    /// Rewrites an aggregate call into an expression over simpler aggregate calls.
    ///
    /// `push_agg_call` registers an aggregate call and returns the `InputRef` to its output;
    /// the returned expression is built on top of those references.
    ///
    /// Rewrites performed:
    /// - `Avg` -> `CASE WHEN count = 0 THEN NULL ELSE sum / count`;
    /// - `StddevPop` / `StddevSamp` / `VarPop` / `VarSamp` -> expressions over `sum(x * x)`,
    ///   `sum(x)` and `count(x)`;
    /// - `ApproxPercentile` -> ordering dropped, percentile flipped to `1 - p` for descending;
    /// - `ArgMin` / `ArgMax` -> `FirstValue` ordered by the comparison argument;
    /// - everything else is pushed unchanged.
    pub(crate) fn general_rewrite_agg_call(
        agg_call: AggCall,
        mut push_agg_call: impl FnMut(AggCall) -> Result<InputRef>,
    ) -> Result<ExprImpl> {
        match agg_call.agg_type {
            AggType::Builtin(PbAggKind::Avg) => {
                assert_eq!(agg_call.args.len(), 1);
                let return_type = agg_call.return_type();

                // `sum` is cast to the return type of `avg` before the division.
                let sum = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Sum.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?)
                .cast_explicit(&agg_call.return_type())?;

                let count = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Count.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args,
                )?)?);

                let target = ExprImpl::from(FunctionCall::new(
                    ExprType::Divide,
                    Vec::from([sum, count.clone()]),
                )?);
                // Yield NULL instead of dividing when no rows were counted.
                let null = ExprImpl::from(Literal::new(None, return_type));
                let zero = ExprImpl::literal_int(0);
                let case_cond =
                    ExprImpl::from(FunctionCall::new(ExprType::Equal, vec![count, zero])?);
                Ok(ExprImpl::from(FunctionCall::new(
                    ExprType::Case,
                    vec![case_cond, null, target],
                )?))
            }
            // Variance / standard deviation:
            //   numerator   = greatest(sum(x * x) - sum(x)^2 / count(x), 0)
            //   denominator = count(x)       for the population variants
            //               = count(x) - 1   for the sample variants
            //   result      = numerator / denominator, wrapped in sqrt for stddev
            // The result is NULL when count = 0 (population) or count <= 1 (sample).
            AggType::Builtin(
                kind @ (PbAggKind::StddevPop
                | PbAggKind::StddevSamp
                | PbAggKind::VarPop
                | PbAggKind::VarSamp),
            ) => {
                let arg = agg_call.args().iter().exactly_one().unwrap();
                let squared_arg = ExprImpl::from(FunctionCall::new(
                    ExprType::Multiply,
                    vec![arg.clone(), arg.clone()],
                )?);

                // sum(x * x)
                let sum_of_sq = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Sum.into(),
                    vec![squared_arg],
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?)
                .cast_explicit(&agg_call.return_type())?;

                // sum(x)
                let sum = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Sum.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?)
                .cast_explicit(&agg_call.return_type())?;

                // count(x)
                let count = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Count.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?);

                let zero = ExprImpl::literal_int(0);
                let one = ExprImpl::literal_int(1);

                // sum(x)^2
                let squared_sum = ExprImpl::from(FunctionCall::new(
                    ExprType::Multiply,
                    vec![sum.clone(), sum],
                )?);

                // sum(x * x) - sum(x)^2 / count(x)
                let raw_numerator = ExprImpl::from(FunctionCall::new(
                    ExprType::Subtract,
                    vec![
                        sum_of_sq,
                        ExprImpl::from(FunctionCall::new(
                            ExprType::Divide,
                            vec![squared_sum, count.clone()],
                        )?),
                    ],
                )?);

                // Clamp the numerator at zero so it can never be negative.
                let numerator_type = raw_numerator.return_type();
                let numerator = ExprImpl::from(FunctionCall::new(
                    ExprType::Greatest,
                    vec![raw_numerator, zero.clone().cast_explicit(&numerator_type)?],
                )?);

                let denominator = match kind {
                    PbAggKind::VarPop | PbAggKind::StddevPop => count.clone(),
                    PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from(
                        FunctionCall::new(ExprType::Subtract, vec![count.clone(), one.clone()])?,
                    ),
                    _ => unreachable!(),
                };

                let mut target = ExprImpl::from(FunctionCall::new(
                    ExprType::Divide,
                    vec![numerator, denominator],
                )?);

                // Standard deviation is the square root of the variance.
                if matches!(kind, PbAggKind::StddevPop | PbAggKind::StddevSamp) {
                    target = ExprImpl::from(FunctionCall::new(ExprType::Sqrt, vec![target])?);
                }

                // Guard against a zero denominator by returning NULL instead.
                let null = ExprImpl::from(Literal::new(None, agg_call.return_type()));
                let case_cond = match kind {
                    PbAggKind::VarPop | PbAggKind::StddevPop => {
                        ExprImpl::from(FunctionCall::new(ExprType::Equal, vec![count, zero])?)
                    }
                    PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from(
                        FunctionCall::new(ExprType::LessThanOrEqual, vec![count, one])?,
                    ),
                    _ => unreachable!(),
                };

                Ok(ExprImpl::from(FunctionCall::new(
                    ExprType::Case,
                    vec![case_cond, null, target],
                )?))
            }
            AggType::Builtin(PbAggKind::ApproxPercentile) => {
                // NOTE(review): indexes `sort_exprs[0]` and `direct_args[0..=1]` directly, so
                // this presumably relies on earlier validation of the call shape -- confirm.
                if agg_call.order_by.sort_exprs[0].order_type == OrderType::descending() {
                    // Descending order: use the complementary percentile `1 - p`.
                    let prev_percentile = agg_call.direct_args[0].clone();
                    let new_percentile = 1.0
                        - prev_percentile
                            .get_data()
                            .as_ref()
                            .unwrap()
                            .as_float64()
                            .into_inner();
                    let new_percentile = Some(ScalarImpl::Float64(new_percentile.into()));
                    let new_percentile = Literal::new(new_percentile, DataType::Float64);
                    // The second direct argument is passed through unchanged.
                    let new_direct_args = vec![new_percentile, agg_call.direct_args[1].clone()];

                    let new_agg_call = AggCall {
                        order_by: OrderBy::any(),
                        direct_args: new_direct_args,
                        ..agg_call
                    };
                    Ok(push_agg_call(new_agg_call)?.into())
                } else {
                    // Otherwise only the ordering is dropped.
                    let new_agg_call = AggCall {
                        order_by: OrderBy::any(),
                        ..agg_call
                    };
                    Ok(push_agg_call(new_agg_call)?.into())
                }
            }
            AggType::Builtin(PbAggKind::ArgMin | PbAggKind::ArgMax) => {
                let mut agg_call = agg_call;

                // The second argument is the value being compared; reject the types listed in
                // the error message below.
                let comparison_arg_type = agg_call.args[1].return_type();
                match comparison_arg_type {
                    DataType::Struct(_)
                    | DataType::List(_)
                    | DataType::Map(_)
                    | DataType::Vector(_)
                    | DataType::Jsonb => {
                        bail!(format!(
                            "{} does not support struct, array, map, vector, jsonb for comparison argument, got {}",
                            agg_call.agg_type.to_string(),
                            comparison_arg_type
                        ));
                    }
                    _ => {}
                }

                // `IS NOT NULL` on every argument; these are added to the filter below.
                let not_null_exprs: Vec<ExprImpl> = agg_call
                    .args
                    .iter()
                    .map(|arg| -> Result<ExprImpl> {
                        Ok(FunctionCall::new(ExprType::IsNotNull, vec![arg.clone()])?.into())
                    })
                    .try_collect()?;

                // Order by the comparison argument first (ascending for `ArgMin`, descending
                // for `ArgMax`), followed by the call's own ORDER BY expressions.
                let comparison_expr = agg_call.args[1].clone();
                let mut order_exprs = vec![OrderByExpr {
                    expr: comparison_expr,
                    order_type: if agg_call.agg_type == AggType::Builtin(PbAggKind::ArgMin) {
                        OrderType::ascending()
                    } else {
                        OrderType::descending()
                    },
                }];

                order_exprs.extend(agg_call.order_by.sort_exprs);

                let order_by = OrderBy::new(order_exprs);

                let filter = agg_call.filter.clone().and(Condition {
                    conjunctions: not_null_exprs,
                });

                // Only the first argument (the value to return) is kept.
                agg_call.args.truncate(1);

                let new_agg_call = AggCall {
                    agg_type: AggType::Builtin(PbAggKind::FirstValue),
                    order_by,
                    filter,
                    ..agg_call
                };
                Ok(push_agg_call(new_agg_call)?.into())
            }
            // All other aggregates need no rewriting.
            _ => Ok(push_agg_call(agg_call)?.into()),
        }
    }
856
857 fn push_agg_call(&mut self, agg_call: AggCall) -> Result<InputRef> {
861 let AggCall {
862 agg_type,
863 return_type,
864 args,
865 distinct,
866 order_by,
867 filter,
868 direct_args,
869 } = agg_call;
870
871 self.is_in_filter_clause = true;
872 let filter = filter.rewrite_expr(self);
875 self.is_in_filter_clause = false;
876
877 let args: Vec<_> = args
878 .iter()
879 .map(|expr| {
880 let index = self.input_proj_builder.add_expr(expr)?;
881 Ok(InputRef::new(index, expr.return_type()))
882 })
883 .try_collect()
884 .map_err(|err: &'static str| not_implemented!("{err} inside aggregation calls"))?;
885
886 let order_by: Vec<_> = order_by
887 .sort_exprs
888 .iter()
889 .map(|e| {
890 let index = self.input_proj_builder.add_expr(&e.expr)?;
891 Ok(ColumnOrder::new(index, e.order_type))
892 })
893 .try_collect()
894 .map_err(|err: &'static str| {
895 not_implemented!("{err} inside aggregation calls order by")
896 })?;
897
898 let plan_agg_call = PlanAggCall {
899 agg_type,
900 return_type: return_type.clone(),
901 inputs: args,
902 distinct,
903 order_by,
904 filter,
905 direct_args,
906 };
907
908 if let Some((pos, existing)) = self
909 .agg_calls
910 .iter()
911 .find_position(|&c| c == &plan_agg_call)
912 {
913 return Ok(InputRef::new(
914 self.schema_agg_start_offset() + pos,
915 existing.return_type.clone(),
916 ));
917 }
918 let index = self.schema_agg_start_offset() + self.agg_calls.len();
919 self.agg_calls.push(plan_agg_call);
920 Ok(InputRef::new(index, return_type))
921 }
922
923 fn try_rewrite_agg_call(&mut self, mut agg_call: AggCall) -> Result<ExprImpl> {
930 if matches!(agg_call.agg_type, agg_types::must_have_order_by!())
931 && agg_call.order_by.sort_exprs.is_empty()
932 {
933 return Err(ErrorCode::InvalidInputSyntax(format!(
934 "Aggregation function {} requires ORDER BY clause",
935 agg_call.agg_type
936 ))
937 .into());
938 }
939
940 if matches!(
942 agg_call.agg_type,
943 agg_types::result_unaffected_by_order_by!()
944 ) {
945 agg_call.order_by = OrderBy::any();
946 }
947 if matches!(
949 agg_call.agg_type,
950 agg_types::result_unaffected_by_distinct!()
951 ) {
952 agg_call.distinct = false;
953 }
954
955 if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Grouping)) {
956 if self.grouping_sets.is_empty() {
957 return Err(ErrorCode::NotSupported(
958 "GROUPING must be used in a query with grouping sets".into(),
959 "try to use grouping sets instead".into(),
960 )
961 .into());
962 }
963 if agg_call.args.len() >= 32 {
964 return Err(ErrorCode::InvalidInputSyntax(
965 "GROUPING must have fewer than 32 arguments".into(),
966 )
967 .into());
968 }
969 if agg_call
970 .args
971 .iter()
972 .any(|x| self.try_as_group_expr(x).is_none())
973 {
974 return Err(ErrorCode::InvalidInputSyntax(
975 "arguments to GROUPING must be grouping expressions of the associated query level"
976 .into(),
977 ).into());
978 }
979 }
980
981 Self::general_rewrite_agg_call(agg_call, |agg_call| self.push_agg_call(agg_call))
982 }
983}
984
985impl ExprRewriter for LogicalAggBuilder {
986 fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl {
987 let dummy = Literal::new(None, agg_call.return_type()).into();
988 match self.try_rewrite_agg_call(agg_call) {
989 Ok(expr) => expr,
990 Err(err) => {
991 self.error = Some(err);
992 dummy
993 }
994 }
995 }
996
997 fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
1000 let expr = func_call.into();
1001 if let Some(group_key) = self.try_as_group_expr(&expr) {
1002 InputRef::new(group_key, expr.return_type()).into()
1003 } else {
1004 let (func_type, inputs, ret) = expr.into_function_call().unwrap().decompose();
1005 let inputs = inputs
1006 .into_iter()
1007 .map(|expr| self.rewrite_expr(expr))
1008 .collect();
1009 FunctionCall::new_unchecked(func_type, inputs, ret).into()
1010 }
1011 }
1012
1013 fn rewrite_window_function(&mut self, window_func: WindowFunction) -> ExprImpl {
1016 let WindowFunction {
1017 args,
1018 partition_by,
1019 order_by,
1020 ..
1021 } = window_func;
1022 let args = args
1023 .into_iter()
1024 .map(|expr| self.rewrite_expr(expr))
1025 .collect();
1026 let partition_by = partition_by
1027 .into_iter()
1028 .map(|expr| self.rewrite_expr(expr))
1029 .collect();
1030 let order_by = order_by.rewrite_expr(self);
1031 WindowFunction {
1032 args,
1033 partition_by,
1034 order_by,
1035 ..window_func
1036 }
1037 .into()
1038 }
1039
1040 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
1042 let expr = input_ref.into();
1043 if let Some(group_key) = self.try_as_group_expr(&expr) {
1044 InputRef::new(group_key, expr.return_type()).into()
1045 } else if self.is_in_filter_clause {
1046 InputRef::new(
1047 self.input_proj_builder.add_expr(&expr).unwrap(),
1048 expr.return_type(),
1049 )
1050 .into()
1051 } else {
1052 self.error = Some(
1053 ErrorCode::InvalidInputSyntax(
1054 "column must appear in the GROUP BY clause or be used in an aggregate function"
1055 .into(),
1056 )
1057 .into(),
1058 );
1059 expr
1060 }
1061 }
1062
1063 fn rewrite_subquery(&mut self, subquery: crate::expr::Subquery) -> ExprImpl {
1064 if subquery.is_correlated_by_depth(0) {
1065 self.error = Some(
1066 not_implemented!(
1067 issue = 2275,
1068 "correlated subquery in HAVING or SELECT with agg",
1069 )
1070 .into(),
1071 );
1072 }
1073 subquery.into()
1074 }
1075}
1076
1077impl From<Agg<PlanRef>> for LogicalAgg {
1078 fn from(core: Agg<PlanRef>) -> Self {
1079 let base = PlanBase::new_logical_with_core(&core);
1080 Self { base, core }
1081 }
1082}
1083
1084impl From<Agg<PlanRef>> for PlanRef {
1086 fn from(core: Agg<PlanRef>) -> Self {
1087 LogicalAgg::from(core).into()
1088 }
1089}
1090
1091impl LogicalAgg {
    /// Returns the input-to-output column index mapping of the underlying `Agg` core.
    pub fn i2o_col_mapping(&self) -> ColIndexMapping {
        self.core.i2o_col_mapping()
    }
1096
1097 pub fn create(
1106 select_exprs: Vec<ExprImpl>,
1107 group_by: GroupBy,
1108 having: Option<ExprImpl>,
1109 input: PlanRef,
1110 ) -> Result<(PlanRef, Vec<ExprImpl>, Option<ExprImpl>)> {
1111 let mut agg_builder = LogicalAggBuilder::new(group_by, input.schema().len())?;
1112
1113 let rewritten_select_exprs = select_exprs
1114 .into_iter()
1115 .map(|expr| agg_builder.rewrite_with_error(expr))
1116 .collect::<Result<_>>()?;
1117 let rewritten_having = having
1118 .map(|expr| agg_builder.rewrite_with_error(expr))
1119 .transpose()?;
1120
1121 Ok((
1122 agg_builder.build(input).into(),
1123 rewritten_select_exprs,
1124 rewritten_having,
1125 ))
1126 }
1127
    /// Returns the aggregate calls of this node.
    pub fn agg_calls(&self) -> &Vec<PlanAggCall> {
        &self.core.agg_calls
    }
1132
    /// Returns the group key as a set of input column indices.
    pub fn group_key(&self) -> &IndexSet {
        &self.core.group_key
    }
1137
    /// Returns the grouping sets, each a set of input column indices.
    pub fn grouping_sets(&self) -> &Vec<IndexSet> {
        &self.core.grouping_sets
    }
1141
    /// Consumes the node and returns the parts of its `Agg` core: the agg calls, the group
    /// key, the grouping sets, the input plan, and a `bool` flag.
    // NOTE(review): the `bool` is presumably `enable_two_phase` -- confirm against
    // `Agg::decompose`.
    pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) {
        self.core.decompose()
    }
1145
1146 #[must_use]
1147 pub fn rewrite_with_input_agg(
1148 &self,
1149 input: PlanRef,
1150 agg_calls: &[PlanAggCall],
1151 mut input_col_change: ColIndexMapping,
1152 ) -> (Self, ColIndexMapping) {
1153 let agg_calls = agg_calls
1154 .iter()
1155 .cloned()
1156 .map(|mut agg_call| {
1157 agg_call.inputs.iter_mut().for_each(|i| {
1158 *i = InputRef::new(input_col_change.map(i.index()), i.return_type())
1159 });
1160 agg_call.order_by.iter_mut().for_each(|o| {
1161 o.column_index = input_col_change.map(o.column_index);
1162 });
1163 agg_call.filter = agg_call.filter.rewrite_expr(&mut input_col_change);
1164 agg_call
1165 })
1166 .collect();
1167 let group_key_in_vec: Vec<usize> = self
1169 .group_key()
1170 .indices()
1171 .map(|key| input_col_change.map(key))
1172 .collect();
1173 let group_key: IndexSet = group_key_in_vec.iter().cloned().collect();
1175 let grouping_sets = self
1176 .grouping_sets()
1177 .iter()
1178 .map(|set| set.indices().map(|key| input_col_change.map(key)).collect())
1179 .collect();
1180
1181 let new_agg = Agg::new(agg_calls, group_key.clone(), input)
1182 .with_grouping_sets(grouping_sets)
1183 .with_enable_two_phase(self.core().enable_two_phase);
1184
1185 let mut out_col_change = vec![];
1188 for idx in group_key_in_vec {
1189 let pos = group_key.indices().position(|x| x == idx).unwrap();
1190 out_col_change.push(pos);
1191 }
1192 for i in (group_key.len())..new_agg.schema().len() {
1193 out_col_change.push(i);
1194 }
1195 let out_col_change =
1196 ColIndexMapping::with_remaining_columns(&out_col_change, new_agg.schema().len());
1197
1198 (new_agg.into(), out_col_change)
1199 }
1200}
1201
1202impl PlanTreeNodeUnary<Logical> for LogicalAgg {
1203 fn input(&self) -> PlanRef {
1204 self.core.input.clone()
1205 }
1206
1207 fn clone_with_input(&self, input: PlanRef) -> Self {
1208 Agg::new(self.agg_calls().clone(), self.group_key().clone(), input)
1209 .with_grouping_sets(self.grouping_sets().clone())
1210 .with_enable_two_phase(self.core().enable_two_phase)
1211 .into()
1212 }
1213
1214 fn rewrite_with_input(
1215 &self,
1216 input: PlanRef,
1217 input_col_change: ColIndexMapping,
1218 ) -> (Self, ColIndexMapping) {
1219 self.rewrite_with_input_agg(input, self.agg_calls(), input_col_change)
1220 }
1221}
1222
// Macro-generated boilerplate for `LogicalAgg`; both macros are defined elsewhere.
// Judging by its name, this one derives the generic plan-tree-node impl from the
// `PlanTreeNodeUnary<Logical>` impl above — TODO confirm against the macro definition.
impl_plan_tree_node_for_unary! { Logical, LogicalAgg }
// Judging by its name and arguments, this one implements the "distill" (plan display)
// behaviour by delegating to the `core` field under the label "LogicalAgg" — TODO confirm.
impl_distill_by_unit!(LogicalAgg, core, "LogicalAgg");
1225
1226impl ExprRewritable<Logical> for LogicalAgg {
1227 fn has_rewritable_expr(&self) -> bool {
1228 true
1229 }
1230
1231 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
1232 let mut core = self.core.clone();
1233 core.rewrite_exprs(r);
1234 Self {
1235 base: self.base.clone_with_new_plan_id(),
1236 core,
1237 }
1238 .into()
1239 }
1240}
1241
1242impl ExprVisitable for LogicalAgg {
1243 fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
1244 self.core.visit_exprs(v);
1245 }
1246}
1247
1248impl ColPrunable for LogicalAgg {
1249 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
1250 let group_key_required_cols = self.group_key().to_bitset();
1251
1252 let (agg_call_required_cols, agg_calls) = {
1253 let input_cnt = self.input().schema().len();
1254 let mut tmp = FixedBitSet::with_capacity(input_cnt);
1255 let group_key_cardinality = self.group_key().len();
1256 let new_agg_calls = required_cols
1257 .iter()
1258 .filter(|&&index| index >= group_key_cardinality)
1259 .map(|&index| {
1260 let index = index - group_key_cardinality;
1261 let agg_call = self.agg_calls()[index].clone();
1262 tmp.extend(agg_call.inputs.iter().map(|x| x.index()));
1263 tmp.extend(agg_call.order_by.iter().map(|x| x.column_index));
1264 for i in &agg_call.filter.conjunctions {
1266 tmp.union_with(&i.collect_input_refs(input_cnt));
1267 }
1268 agg_call
1269 })
1270 .collect_vec();
1271 (tmp, new_agg_calls)
1272 };
1273
1274 let input_required_cols = {
1275 let mut tmp = FixedBitSet::with_capacity(self.input().schema().len());
1276 tmp.union_with(&group_key_required_cols);
1277 tmp.union_with(&agg_call_required_cols);
1278 tmp.ones().collect_vec()
1279 };
1280 let input_col_change = ColIndexMapping::with_remaining_columns(
1281 &input_required_cols,
1282 self.input().schema().len(),
1283 );
1284 let agg = {
1285 let input = self.input().prune_col(&input_required_cols, ctx);
1286 let (agg, output_col_change) =
1287 self.rewrite_with_input_agg(input, &agg_calls, input_col_change);
1288 assert!(output_col_change.is_identity());
1289 agg
1290 };
1291 let new_output_cols = {
1292 let group_key_cardinality = agg.group_key().len();
1294 let mut tmp = (0..group_key_cardinality).collect_vec();
1295 tmp.extend(
1296 required_cols
1297 .iter()
1298 .filter(|&&index| index >= group_key_cardinality),
1299 );
1300 tmp
1301 };
1302 if new_output_cols == required_cols {
1303 agg.into()
1305 } else {
1306 let mapping =
1309 &ColIndexMapping::with_remaining_columns(&new_output_cols, self.schema().len());
1310 let output_required_cols = required_cols
1311 .iter()
1312 .map(|&idx| mapping.map(idx))
1313 .collect_vec();
1314 let src_size = agg.schema().len();
1315 LogicalProject::with_mapping(
1316 agg.into(),
1317 ColIndexMapping::with_remaining_columns(&output_required_cols, src_size),
1318 )
1319 .into()
1320 }
1321 }
1322}
1323
1324impl PredicatePushdown for LogicalAgg {
1325 fn predicate_pushdown(
1326 &self,
1327 predicate: Condition,
1328 ctx: &mut PredicatePushdownContext,
1329 ) -> PlanRef {
1330 let num_group_key = self.group_key().len();
1331 let num_agg_calls = self.agg_calls().len();
1332 assert!(num_group_key + num_agg_calls == self.schema().len());
1333
1334 if num_group_key == 0 {
1342 return gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx);
1343 }
1344
1345 let mut agg_call_columns = FixedBitSet::with_capacity(num_group_key + num_agg_calls);
1347 agg_call_columns.insert_range(num_group_key..num_group_key + num_agg_calls);
1348 let (agg_call_pred, pushed_predicate) = predicate.split_disjoint(&agg_call_columns);
1349
1350 let mut subst = Substitute {
1352 mapping: self
1353 .group_key()
1354 .indices()
1355 .enumerate()
1356 .map(|(i, group_key)| {
1357 InputRef::new(group_key, self.schema().fields()[i].data_type()).into()
1358 })
1359 .collect(),
1360 };
1361 let pushed_predicate = pushed_predicate.rewrite_expr(&mut subst);
1362
1363 gen_filter_and_pushdown(self, agg_call_pred, pushed_predicate, ctx)
1364 }
1365}
1366
1367impl ToBatch for LogicalAgg {
1368 fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1369 self.to_batch_with_order_required(&Order::any())
1370 }
1371
1372 fn to_batch_with_order_required(
1375 &self,
1376 required_order: &Order,
1377 ) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1378 let input = self.input().to_batch()?;
1379 let new_logical = self.core.clone_with_input(input);
1380 let agg_plan = if self.group_key().is_empty() {
1381 BatchSimpleAgg::new(new_logical).into()
1382 } else if self.ctx().session_ctx().config().batch_enable_sort_agg()
1383 && new_logical.input_provides_order_on_group_keys()
1384 {
1385 BatchSortAgg::new(new_logical).into()
1386 } else {
1387 BatchHashAgg::new(new_logical).into()
1388 };
1389 required_order.enforce_if_not_satisfies(agg_plan)
1390 }
1391}
1392
1393fn find_or_append_row_count(mut logical: Agg<StreamPlanRef>) -> (Agg<StreamPlanRef>, usize) {
1394 let count_star = PlanAggCall::count_star();
1397 let row_count_idx = if let Some((idx, _)) = logical
1398 .agg_calls
1399 .iter()
1400 .find_position(|&c| c == &count_star)
1401 {
1402 idx
1403 } else {
1404 let idx = logical.agg_calls.len();
1405 logical.agg_calls.push(count_star);
1406 idx
1407 };
1408 (logical, row_count_idx)
1409}
1410
1411fn new_stream_simple_agg(
1412 core: Agg<StreamPlanRef>,
1413 must_output_per_barrier: bool,
1414) -> Result<StreamSimpleAgg> {
1415 let (logical, row_count_idx) = find_or_append_row_count(core);
1416 StreamSimpleAgg::new(logical, row_count_idx, must_output_per_barrier)
1417}
1418
1419fn new_stream_hash_agg(
1420 core: Agg<StreamPlanRef>,
1421 vnode_col_idx: Option<usize>,
1422) -> Result<StreamHashAgg> {
1423 let (logical, row_count_idx) = find_or_append_row_count(core);
1424 StreamHashAgg::new(logical, vnode_col_idx, row_count_idx)
1425}
1426
1427impl ToStream for LogicalAgg {
1428 fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<StreamPlanRef> {
1429 use super::stream::prelude::*;
1430
1431 let eowc = ctx.emit_on_window_close();
1432 let input = self.input();
1433
1434 let stream_input = input.to_stream(ctx)?;
1435
1436 if stream_input.append_only() && self.agg_calls().is_empty() && !self.group_key().is_empty()
1438 {
1439 let group_key = self.group_key().to_vec();
1440 let input_schema_len = input.schema().len();
1441 let dedup: PlanRef = LogicalDedup::new(input, group_key.clone()).into();
1442 let project = LogicalProject::with_mapping(
1443 dedup,
1444 ColIndexMapping::with_remaining_columns(&group_key, input_schema_len),
1445 );
1446 return project.to_stream(ctx);
1447 }
1448
1449 if self.agg_calls().iter().any(|call| {
1450 matches!(
1451 call.agg_type,
1452 AggType::Builtin(PbAggKind::ApproxCountDistinct)
1453 )
1454 }) {
1455 if stream_input.append_only() {
1456 self.core.ctx().session_ctx().notice_to_user(
1457 "Streaming `APPROX_COUNT_DISTINCT` is still a preview feature and subject to change. Please do not use it in production environment.",
1458 );
1459 } else {
1460 bail_not_implemented!(
1461 "Streaming `APPROX_COUNT_DISTINCT` is only supported in append-only stream"
1462 );
1463 }
1464 }
1465
1466 let plan = self.gen_dist_stream_agg_plan(stream_input)?;
1467
1468 let (plan, n_final_agg_calls) = if let Some(final_agg) = plan.as_stream_simple_agg() {
1469 if eowc {
1470 return Err(ErrorCode::InvalidInputSyntax(
1471 "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1472 .to_owned(),
1473 )
1474 .into());
1475 }
1476 (plan.clone(), final_agg.agg_calls().len())
1477 } else if let Some(final_agg) = plan.as_stream_hash_agg() {
1478 (
1479 if eowc {
1480 final_agg.to_eowc_version()?
1481 } else {
1482 plan.clone()
1483 },
1484 final_agg.agg_calls().len(),
1485 )
1486 } else if let Some(_approx_percentile_agg) = plan.as_stream_global_approx_percentile() {
1487 if eowc {
1488 return Err(ErrorCode::InvalidInputSyntax(
1489 "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1490 .to_owned(),
1491 )
1492 .into());
1493 }
1494 (plan.clone(), 1)
1495 } else if let Some(stream_row_merge) = plan.as_stream_row_merge() {
1496 if eowc {
1497 return Err(ErrorCode::InvalidInputSyntax(
1498 "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1499 .to_owned(),
1500 )
1501 .into());
1502 }
1503 (plan.clone(), stream_row_merge.base.schema().len())
1504 } else {
1505 panic!(
1506 "the root PlanNode must be StreamHashAgg, StreamSimpleAgg, StreamGlobalApproxPercentile, or StreamRowMerge"
1507 );
1508 };
1509
1510 if self.agg_calls().len() == n_final_agg_calls {
1511 Ok(plan)
1513 } else {
1514 assert_eq!(self.agg_calls().len() + 1, n_final_agg_calls);
1516
1517 let mut project = StreamProject::new(generic::Project::with_out_col_idx(
1518 plan,
1519 0..self.schema().len(),
1520 ));
1521 if self.agg_calls().is_empty() {
1526 project = project.with_noop_update_hint(true);
1527 }
1528 Ok(project.into())
1529 }
1530 }
1531
1532 fn try_better_locality(&self, columns: &[usize]) -> Option<PlanRef> {
1533 if columns.is_empty() {
1534 return None;
1535 }
1536
1537 let group_key = self.group_key().to_vec();
1539 if columns.len() > group_key.len() || columns != &group_key[..columns.len()] {
1540 return None;
1541 }
1542
1543 Some(self.clone_with_input(self.input()).into())
1548 }
1549
1550 fn logical_rewrite_for_stream(
1551 &self,
1552 ctx: &mut RewriteStreamContext,
1553 ) -> Result<(PlanRef, ColIndexMapping)> {
1554 let logical_input = if self.group_key().is_empty() {
1555 self.input()
1556 } else {
1557 try_enforce_locality_requirement(self.input(), &self.group_key().to_vec())
1558 };
1559 let (input, input_col_change) = logical_input.logical_rewrite_for_stream(ctx)?;
1560 let (agg, out_col_change) = self.rewrite_with_input(input, input_col_change);
1561 let (map, _) = out_col_change.into_parts();
1562 let out_col_change = ColIndexMapping::new(map, agg.schema().len());
1563 Ok((agg.into(), out_col_change))
1564 }
1565}
1566
#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};

    use super::*;
    use crate::expr::{assert_eq_input_ref, input_ref_to_column_indices};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;

    /// Input schema shared by all tests: three columns `v1`, `v2`, `v3` of type `ty`.
    fn three_fields(ty: &DataType) -> Vec<Field> {
        vec![
            Field::with_name(ty.clone(), "v1"),
            Field::with_name(ty.clone(), "v2"),
            Field::with_name(ty.clone(), "v3"),
        ]
    }

    /// Builds the aggregate expression `kind(arg)`: not distinct, no ordering, no filter,
    /// no direct arguments.
    fn plain_agg_call(kind: PbAggKind, arg: ExprImpl) -> AggCall {
        AggCall::new(
            kind.into(),
            vec![arg],
            false,
            OrderBy::any(),
            Condition::true_cond(),
            vec![],
        )
        .unwrap()
    }

    /// Builds a plan-level aggregate call `kind(input column input_idx)` of type `ty`:
    /// not distinct, no ordering, no filter, no direct arguments.
    fn plain_plan_agg_call(kind: PbAggKind, input_idx: usize, ty: &DataType) -> PlanAggCall {
        PlanAggCall {
            agg_type: kind.into(),
            return_type: ty.clone(),
            inputs: vec![InputRef::new(input_idx, ty.clone())],
            distinct: false,
            order_by: vec![],
            filter: Condition::true_cond(),
            direct_args: vec![],
        }
    }

    /// Asserts the shape expected after pruning a "`min` over input column 2, grouped by
    /// input column 1" aggregation: group key `[0]`, one `Min` call over column 1
    /// returning `ty`, and an input `LogicalValues` left with `fields[1..]`.
    fn assert_pruned_min_agg(agg_new: &LogicalAgg, ty: &DataType, fields: &[Field]) {
        assert_eq!(agg_new.group_key(), &vec![0].into());

        assert_eq!(agg_new.agg_calls().len(), 1);
        let agg_call_new = agg_new.agg_calls()[0].clone();
        assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
        assert_eq!(agg_call_new.return_type, *ty);

        let values = agg_new.input();
        let values = values.as_logical_values().unwrap();
        assert_eq!(values.schema().fields(), &fields[1..]);
    }

    /// `LogicalAgg::create` must extract aggregate calls and group keys from the select
    /// list and rewrite the select expressions to reference the aggregation's output.
    #[tokio::test]
    async fn test_create() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock();
        let fields = three_fields(&ty);
        let values = LogicalValues::new(vec![], Schema { fields }, ctx);
        let input = PlanRef::from(values);
        let v1 = InputRef::new(0, ty.clone());
        let v2 = InputRef::new(1, ty.clone());
        let v3 = InputRef::new(2, ty);

        // Runs `LogicalAgg::create` over `input` and returns the rewritten select
        // expressions plus the aggregate calls and group key of the created node.
        let gen_internal_value = |select_exprs: Vec<ExprImpl>,
                                  group_exprs|
         -> (Vec<ExprImpl>, Vec<PlanAggCall>, IndexSet) {
            let (plan, exprs, _) = LogicalAgg::create(
                select_exprs,
                GroupBy::GroupKey(group_exprs),
                None,
                input.clone(),
            )
            .unwrap();

            let created = plan.as_logical_agg().unwrap();
            (
                exprs,
                created.agg_calls().clone(),
                created.group_key().clone(),
            )
        };

        // Select `v1`, group by `v1`: no aggregate call at all.
        {
            let select_exprs = vec![v1.clone().into()];
            let group_exprs = vec![v1.clone().into()];

            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);

            assert_eq!(exprs.len(), 1);
            assert_eq_input_ref!(&exprs[0], 0);

            assert_eq!(agg_calls.len(), 0);
            assert_eq!(group_key, vec![0].into());
        }

        // Select `v1, min(v2)`, group by `v1`.
        {
            let min_v2 = plain_agg_call(PbAggKind::Min, v2.clone().into());
            let select_exprs = vec![v1.clone().into(), min_v2.into()];
            let group_exprs = vec![v1.clone().into()];

            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);

            assert_eq!(exprs.len(), 2);
            assert_eq_input_ref!(&exprs[0], 0);
            assert_eq_input_ref!(&exprs[1], 1);

            assert_eq!(agg_calls.len(), 1);
            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
            assert_eq!(group_key, vec![0].into());
        }

        // Select `v1, min(v2) + max(v3)`, group by `v1`: the addition stays in the select
        // list and refers to the two aggregate outputs.
        {
            let min_v2 = plain_agg_call(PbAggKind::Min, v2.clone().into());
            let max_v3 = plain_agg_call(PbAggKind::Max, v3.clone().into());
            let sum =
                FunctionCall::new(ExprType::Add, vec![min_v2.into(), max_v3.into()]).unwrap();
            let select_exprs = vec![v1.clone().into(), ExprImpl::from(sum)];
            let group_exprs = vec![v1.clone().into()];

            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);

            assert_eq_input_ref!(&exprs[0], 0);
            match &exprs[1] {
                ExprImpl::FunctionCall(func_call) => {
                    assert_eq!(func_call.func_type(), ExprType::Add);
                    let inputs = func_call.inputs();
                    assert_eq_input_ref!(&inputs[0], 1);
                    assert_eq_input_ref!(&inputs[1], 2);
                }
                _ => panic!("Wrong expression type!"),
            }

            assert_eq!(agg_calls.len(), 2);
            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
            assert_eq!(agg_calls[1].agg_type, PbAggKind::Max.into());
            assert_eq!(input_ref_to_column_indices(&agg_calls[1].inputs), vec![2]);
            assert_eq!(group_key, vec![0].into());
        }

        // Select `v2, min(v1 * v3)`, group by `v2`: the aggregate argument is an
        // expression rather than a plain column.
        {
            let v1_mult_v3 =
                FunctionCall::new(ExprType::Multiply, vec![v1.into(), v3.into()]).unwrap();
            let agg_call = plain_agg_call(PbAggKind::Min, v1_mult_v3.into());
            let select_exprs = vec![v2.clone().into(), agg_call.into()];
            let group_exprs = vec![v2.into()];

            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);

            assert_eq_input_ref!(&exprs[0], 0);
            assert_eq_input_ref!(&exprs[1], 1);

            assert_eq!(agg_calls.len(), 1);
            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
            assert_eq!(group_key, vec![0].into());
        }
    }

    /// Builds an aggregation over an empty `LogicalValues` with schema `fields`:
    /// `min` over input column 2 (of type `ty`), grouped by input column 1.
    fn generate_agg_call(ty: DataType, fields: Vec<Field>) -> LogicalAgg {
        let ctx = OptimizerContext::mock();

        let values = LogicalValues::new(vec![], Schema { fields }, ctx);
        let agg_call = plain_plan_agg_call(PbAggKind::Min, 2, &ty);
        Agg::new(vec![agg_call], vec![1].into(), values.into()).into()
    }

    /// Requiring both output columns in their natural order keeps the aggregation as
    /// the plan root and only prunes the unused input column.
    #[tokio::test]
    async fn test_prune_all() {
        let ty = DataType::Int32;
        let fields = three_fields(&ty);
        let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).into();

        let required_cols = vec![0, 1];
        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));

        let agg_new = plan.as_logical_agg().unwrap();
        assert_pruned_min_agg(agg_new, &ty, &fields);
    }

    /// Requiring both output columns in swapped order adds a projection that reorders
    /// them on top of the pruned aggregation.
    #[tokio::test]
    async fn test_prune_all_with_order_required() {
        let ty = DataType::Int32;
        let fields = three_fields(&ty);
        let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).into();

        let required_cols = vec![1, 0];
        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));

        let proj = plan.as_logical_project().unwrap();
        assert_eq!(proj.exprs().len(), 2);
        assert_eq!(proj.exprs()[0].as_input_ref().unwrap().index(), 1);
        assert_eq!(proj.exprs()[1].as_input_ref().unwrap().index(), 0);

        let proj_input = proj.input();
        let agg_new = proj_input.as_logical_agg().unwrap();
        assert_pruned_min_agg(agg_new, &ty, &fields);
    }

    /// Requiring only the aggregate output still keeps the group key inside the
    /// aggregation; a projection on top drops it from the result.
    #[tokio::test]
    async fn test_prune_group_key() {
        let ctx = OptimizerContext::mock();
        let ty = DataType::Int32;
        let fields = three_fields(&ty);
        let values: LogicalValues = LogicalValues::new(
            vec![],
            Schema {
                fields: fields.clone(),
            },
            ctx,
        );
        let agg_call = plain_plan_agg_call(PbAggKind::Min, 2, &ty);
        let agg: PlanRef = Agg::new(vec![agg_call], vec![1].into(), values.into()).into();

        let required_cols = vec![1];
        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));

        let project = plan.as_logical_project().unwrap();
        assert_eq!(project.exprs().len(), 1);
        assert_eq_input_ref!(&project.exprs()[0], 1);

        let agg_new = project.input();
        let agg_new = agg_new.as_logical_agg().unwrap();
        assert_pruned_min_agg(agg_new, &ty, &fields);
    }

    /// With two group keys and two aggregate calls, requiring output columns 0 and 3
    /// drops the first aggregate call; a projection then picks columns 0 and 2 of the
    /// pruned aggregation.
    #[tokio::test]
    async fn test_prune_agg() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock();
        let fields = three_fields(&ty);
        let values = LogicalValues::new(
            vec![],
            Schema {
                fields: fields.clone(),
            },
            ctx,
        );

        let agg_calls = vec![
            plain_plan_agg_call(PbAggKind::Min, 2, &ty),
            plain_plan_agg_call(PbAggKind::Max, 1, &ty),
        ];
        let agg: PlanRef = Agg::new(agg_calls, vec![1, 2].into(), values.into()).into();

        let required_cols = vec![0, 3];
        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));

        let project = plan.as_logical_project().unwrap();
        assert_eq!(project.exprs().len(), 2);
        assert_eq_input_ref!(&project.exprs()[0], 0);
        assert_eq_input_ref!(&project.exprs()[1], 2);

        let agg_new = project.input();
        let agg_new = agg_new.as_logical_agg().unwrap();
        assert_eq!(agg_new.group_key(), &vec![0, 1].into());

        assert_eq!(agg_new.agg_calls().len(), 1);
        let agg_call_new = agg_new.agg_calls()[0].clone();
        assert_eq!(agg_call_new.agg_type, PbAggKind::Max.into());
        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![0]);
        assert_eq!(agg_call_new.return_type, ty);

        let values = agg_new.input();
        let values = values.as_logical_values().unwrap();
        assert_eq!(values.schema().fields(), &fields[1..]);
    }
}