1use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::types::{DataType, ScalarImpl};
18use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
19use risingwave_common::{bail, bail_not_implemented, not_implemented};
20use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};
21
22use super::generic::{self, Agg, GenericPlanRef, PlanAggCall, ProjectBuilder};
23use super::utils::impl_distill_by_unit;
24use super::{
25 BatchHashAgg, BatchSimpleAgg, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
26 PlanBase, PlanTreeNodeUnary, PredicatePushdown, StreamHashAgg, StreamPlanRef, StreamProject,
27 StreamShare, StreamSimpleAgg, StreamStatelessSimpleAgg, ToBatch, ToStream,
28 try_enforce_locality_requirement,
29};
30use crate::error::{ErrorCode, Result, RwError};
31use crate::expr::{
32 AggCall, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, Literal,
33 OrderBy, OrderByExpr, WindowFunction,
34};
35use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
36use crate::optimizer::plan_node::generic::GenericPlanNode;
37use crate::optimizer::plan_node::stream_global_approx_percentile::StreamGlobalApproxPercentile;
38use crate::optimizer::plan_node::stream_local_approx_percentile::StreamLocalApproxPercentile;
39use crate::optimizer::plan_node::stream_row_merge::StreamRowMerge;
40use crate::optimizer::plan_node::{
41 BatchSortAgg, ColumnPruningContext, LogicalDedup, LogicalProject, PredicatePushdownContext,
42 RewriteStreamContext, ToStreamContext, gen_filter_and_pushdown,
43};
44use crate::optimizer::property::{Distribution, Order, RequiredDist};
45use crate::utils::{
46 ColIndexMapping, ColIndexMappingRewriteExt, Condition, GroupBy, IndexSet, Substitute,
47};
48
/// A group of aggregate calls together with a column index mapping.
///
/// Produced by `LogicalAgg::separate_normal_and_special_agg`, which pushes the original
/// position of every call in `calls` into the vector that `col_mapping` is built from.
pub struct AggInfo {
    /// The aggregate calls that belong to this group.
    pub calls: Vec<PlanAggCall>,
    /// Mapping built from the original aggregate-call positions of `calls`.
    pub col_mapping: ColIndexMapping,
}
53
/// Result of splitting the aggregate calls of a `LogicalAgg` into approx-percentile calls
/// and all remaining calls.
pub struct SeparatedAggInfo {
    /// Calls whose aggregate type is not `ApproxPercentile`.
    normal: AggInfo,
    /// Calls whose aggregate type is `ApproxPercentile`.
    approx: AggInfo,
}
59
/// Logical aggregation plan node: the generic [`Agg`] core paired with its logical plan base.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct LogicalAgg {
    /// Plan base; the `From<Agg<PlanRef>>` impl derives it from `core`.
    pub base: PlanBase<Logical>,
    /// Generic aggregation core (aggregate calls, group key, grouping sets, input).
    core: Agg<PlanRef>,
}
71
72impl LogicalAgg {
73 fn gen_stateless_two_phase_streaming_agg_plan(
76 &self,
77 stream_input: StreamPlanRef,
78 ) -> Result<StreamPlanRef> {
79 debug_assert!(self.group_key().is_empty());
80
81 let (
83 non_approx_percentile_col_mapping,
84 approx_percentile_col_mapping,
85 approx_percentile,
86 core,
87 ) = self.prepare_approx_percentile(stream_input)?;
88
89 if core.agg_calls.is_empty() {
90 if let Some(approx_percentile) = approx_percentile {
91 return Ok(approx_percentile);
92 };
93 bail!("expected at least one agg call");
94 }
95
96 let need_row_merge: bool = Self::need_row_merge(&approx_percentile);
97
98 let total_agg_calls = core
100 .agg_calls
101 .iter()
102 .enumerate()
103 .map(|(partial_output_idx, agg_call)| {
104 agg_call.partial_to_total_agg_call(partial_output_idx)
105 })
106 .collect_vec();
107 let local_agg = StreamStatelessSimpleAgg::new(core)?;
108 let exchange =
109 RequiredDist::single().streaming_enforce_if_not_satisfies(local_agg.into())?;
110
111 let must_output_per_barrier = need_row_merge;
112 let global_agg = new_stream_simple_agg(
113 Agg::new(total_agg_calls, IndexSet::empty(), exchange),
114 must_output_per_barrier,
115 )?;
116
117 Self::add_row_merge_if_needed(
119 approx_percentile,
120 global_agg.into(),
121 approx_percentile_col_mapping,
122 non_approx_percentile_col_mapping,
123 )
124 }
125
126 fn gen_vnode_two_phase_streaming_agg_plan(
130 &self,
131 stream_input: StreamPlanRef,
132 dist_key: &[usize],
133 ) -> Result<StreamPlanRef> {
134 let (
135 non_approx_percentile_col_mapping,
136 approx_percentile_col_mapping,
137 approx_percentile,
138 core,
139 ) = self.prepare_approx_percentile(stream_input.clone())?;
140
141 if core.agg_calls.is_empty() {
142 if let Some(approx_percentile) = approx_percentile {
143 return Ok(approx_percentile);
144 };
145 bail!("expected at least one agg call");
146 }
147 let need_row_merge = Self::need_row_merge(&approx_percentile);
148
149 let input_col_num = stream_input.schema().len(); let project = StreamProject::new(generic::Project::with_vnode_col(stream_input, dist_key));
153 let vnode_col_idx = project.base.schema().len() - 1;
154
155 let mut local_group_key = self.group_key().clone();
157 local_group_key.insert(vnode_col_idx);
158 let n_local_group_key = local_group_key.len();
159 let local_agg = new_stream_hash_agg(
160 Agg::new(core.agg_calls.clone(), local_group_key, project.into()),
161 Some(vnode_col_idx),
162 )?;
163 let local_agg_group_key_cardinality = local_agg.group_key().len();
165 let local_group_key_without_vnode =
166 &local_agg.group_key().to_vec()[..local_agg_group_key_cardinality - 1];
167 let global_group_key = local_agg
168 .i2o_col_mapping()
169 .rewrite_dist_key(local_group_key_without_vnode)
170 .expect("some input group key could not be mapped");
171
172 let global_agg = if self.group_key().is_empty() {
174 let exchange =
175 RequiredDist::single().streaming_enforce_if_not_satisfies(local_agg.into())?;
176 let must_output_per_barrier = need_row_merge;
177 let global_agg = new_stream_simple_agg(
178 Agg::new(
179 core.agg_calls
180 .iter()
181 .enumerate()
182 .map(|(partial_output_idx, agg_call)| {
183 agg_call
184 .partial_to_total_agg_call(n_local_group_key + partial_output_idx)
185 })
186 .collect(),
187 global_group_key.into_iter().collect(),
188 exchange,
189 ),
190 must_output_per_barrier,
191 )?;
192 global_agg.into()
193 } else {
194 assert!(!need_row_merge);
196 let exchange = RequiredDist::shard_by_key(input_col_num, &global_group_key)
197 .streaming_enforce_if_not_satisfies(local_agg.into())?;
198 let global_agg = new_stream_hash_agg(
201 Agg::new(
202 core.agg_calls
203 .iter()
204 .enumerate()
205 .map(|(partial_output_idx, agg_call)| {
206 agg_call
207 .partial_to_total_agg_call(n_local_group_key + partial_output_idx)
208 })
209 .collect(),
210 global_group_key.into_iter().collect(),
211 exchange,
212 ),
213 None,
214 )?;
215 global_agg.into()
216 };
217 Self::add_row_merge_if_needed(
218 approx_percentile,
219 global_agg,
220 approx_percentile_col_mapping,
221 non_approx_percentile_col_mapping,
222 )
223 }
224
225 fn gen_single_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
226 let input = RequiredDist::single().streaming_enforce_if_not_satisfies(stream_input)?;
227 let core = self.core.clone_with_input(input);
228 Ok(new_stream_simple_agg(core, false)?.into())
229 }
230
231 fn gen_shuffle_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
232 let input =
233 RequiredDist::shard_by_key(stream_input.schema().len(), &self.group_key().to_vec())
234 .streaming_enforce_if_not_satisfies(stream_input)?;
235 let core = self.core.clone_with_input(input);
236 Ok(new_stream_hash_agg(core, None)?.into())
237 }
238
239 fn gen_dist_stream_agg_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
241 use super::stream::prelude::*;
242
243 let input_dist = stream_input.distribution();
244 debug_assert!(*input_dist != Distribution::Broadcast);
245
246 if !self.group_key().is_empty() && !self.core.must_try_two_phase_agg() {
250 return self.gen_shuffle_plan(stream_input);
251 }
252
253 if self.group_key().is_empty() && !self.core.can_two_phase_agg() {
256 return self.gen_single_plan(stream_input);
257 }
258
259 debug_assert!(if !self.group_key().is_empty() {
260 self.core.must_try_two_phase_agg()
261 } else {
262 self.core.can_two_phase_agg()
263 });
264
265 if self.group_key().is_empty()
269 && self
270 .core
271 .all_local_aggs_are_stateless(stream_input.append_only())
272 && input_dist.satisfies(&RequiredDist::AnyShard)
273 {
274 return self.gen_stateless_two_phase_streaming_agg_plan(stream_input);
275 }
276
277 let stream_input =
282 if *input_dist == Distribution::SomeShard && self.core.must_try_two_phase_agg() {
283 RequiredDist::shard_by_key(
284 stream_input.schema().len(),
285 stream_input.expect_stream_key(),
286 )
287 .streaming_enforce_if_not_satisfies(stream_input)?
288 } else {
289 stream_input
290 };
291 let input_dist = stream_input.distribution();
292
293 match input_dist {
297 Distribution::HashShard(dist_key) | Distribution::UpstreamHashShard(dist_key, _)
298 if (self.group_key().is_empty()
299 || !self.core.hash_agg_dist_satisfied_by_input_dist(input_dist)) =>
300 {
301 let dist_key = dist_key.clone();
302 return self.gen_vnode_two_phase_streaming_agg_plan(stream_input, &dist_key);
303 }
304 _ => {}
305 }
306
307 if !self.group_key().is_empty() {
309 self.gen_shuffle_plan(stream_input)
310 } else {
311 self.gen_single_plan(stream_input)
312 }
313 }
314
315 fn prepare_approx_percentile(
320 &self,
321 stream_input: StreamPlanRef,
322 ) -> Result<(
323 ColIndexMapping,
324 ColIndexMapping,
325 Option<StreamPlanRef>,
326 Agg<StreamPlanRef>,
327 )> {
328 let SeparatedAggInfo { normal, approx } = self.separate_normal_and_special_agg();
329
330 let AggInfo {
331 calls: non_approx_percentile_agg_calls,
332 col_mapping: non_approx_percentile_col_mapping,
333 } = normal;
334 let AggInfo {
335 calls: approx_percentile_agg_calls,
336 col_mapping: approx_percentile_col_mapping,
337 } = approx;
338 if !self.group_key().is_empty() && !approx_percentile_agg_calls.is_empty() {
339 bail_not_implemented!(
340 "two-phase streaming approx percentile aggregation with group key, \
341 please use single phase aggregation instead"
342 );
343 }
344
345 let needs_row_merge = (!non_approx_percentile_agg_calls.is_empty()
348 && !approx_percentile_agg_calls.is_empty())
349 || approx_percentile_agg_calls.len() >= 2;
350 let input = if needs_row_merge {
351 StreamShare::new_from_input(stream_input).into()
353 } else {
354 stream_input
355 };
356 let mut core = self.core.clone_with_input(input);
357 core.agg_calls = non_approx_percentile_agg_calls;
358
359 let approx_percentile =
360 self.build_approx_percentile_aggs(core.input.clone(), &approx_percentile_agg_calls)?;
361 Ok((
362 non_approx_percentile_col_mapping,
363 approx_percentile_col_mapping,
364 approx_percentile,
365 core,
366 ))
367 }
368
    /// Returns `true` when an approx-percentile plan exists, i.e. when
    /// `add_row_merge_if_needed` will wrap the global agg in a `StreamRowMerge`.
    fn need_row_merge(approx_percentile: &Option<StreamPlanRef>) -> bool {
        approx_percentile.is_some()
    }
372
373 fn add_row_merge_if_needed(
375 approx_percentile: Option<StreamPlanRef>,
376 global_agg: StreamPlanRef,
377 approx_percentile_col_mapping: ColIndexMapping,
378 non_approx_percentile_col_mapping: ColIndexMapping,
379 ) -> Result<StreamPlanRef> {
380 let need_row_merge = Self::need_row_merge(&approx_percentile);
382
383 if let Some(approx_percentile) = approx_percentile {
384 assert!(need_row_merge);
385 let row_merge = StreamRowMerge::new(
386 approx_percentile,
387 global_agg,
388 approx_percentile_col_mapping,
389 non_approx_percentile_col_mapping,
390 )?;
391 Ok(row_merge.into())
392 } else {
393 assert!(!need_row_merge);
394 Ok(global_agg)
395 }
396 }
397
398 fn separate_normal_and_special_agg(&self) -> SeparatedAggInfo {
399 let estimated_len = self.agg_calls().len() - 1;
400 let mut approx_percentile_agg_calls = Vec::with_capacity(estimated_len);
401 let mut non_approx_percentile_agg_calls = Vec::with_capacity(estimated_len);
402 let mut approx_percentile_col_mapping = Vec::with_capacity(estimated_len);
403 let mut non_approx_percentile_col_mapping = Vec::with_capacity(estimated_len);
404 for (output_idx, agg_call) in self.agg_calls().iter().enumerate() {
405 if agg_call.agg_type == AggType::Builtin(PbAggKind::ApproxPercentile) {
406 approx_percentile_agg_calls.push(agg_call.clone());
407 approx_percentile_col_mapping.push(Some(output_idx));
408 } else {
409 non_approx_percentile_agg_calls.push(agg_call.clone());
410 non_approx_percentile_col_mapping.push(Some(output_idx));
411 }
412 }
413 let normal = AggInfo {
414 calls: non_approx_percentile_agg_calls,
415 col_mapping: ColIndexMapping::new(
416 non_approx_percentile_col_mapping,
417 self.agg_calls().len(),
418 ),
419 };
420 let approx = AggInfo {
421 calls: approx_percentile_agg_calls,
422 col_mapping: ColIndexMapping::new(
423 approx_percentile_col_mapping,
424 self.agg_calls().len(),
425 ),
426 };
427 SeparatedAggInfo { normal, approx }
428 }
429
430 fn build_approx_percentile_agg(
431 &self,
432 input: StreamPlanRef,
433 approx_percentile_agg_call: &PlanAggCall,
434 ) -> Result<StreamPlanRef> {
435 let local_approx_percentile =
436 StreamLocalApproxPercentile::new(input, approx_percentile_agg_call)?;
437 let exchange = RequiredDist::single()
438 .streaming_enforce_if_not_satisfies(local_approx_percentile.into())?;
439 let global_approx_percentile =
440 StreamGlobalApproxPercentile::new(exchange, approx_percentile_agg_call);
441 Ok(global_approx_percentile.into())
442 }
443
444 fn build_approx_percentile_aggs(
457 &self,
458 input: StreamPlanRef,
459 approx_percentile_agg_call: &[PlanAggCall],
460 ) -> Result<Option<StreamPlanRef>> {
461 if approx_percentile_agg_call.is_empty() {
462 return Ok(None);
463 }
464 let approx_percentile_plans: Vec<_> = approx_percentile_agg_call
465 .iter()
466 .map(|agg_call| self.build_approx_percentile_agg(input.clone(), agg_call))
467 .try_collect()?;
468 assert!(!approx_percentile_plans.is_empty());
469 let mut iter = approx_percentile_plans.into_iter();
470 let mut acc = iter.next().unwrap();
471 for (current_size, plan) in iter.enumerate().map(|(i, p)| (i + 1, p)) {
472 let new_size = current_size + 1;
473 let row_merge = StreamRowMerge::new(
474 acc,
475 plan,
476 ColIndexMapping::identity_or_none(current_size, new_size),
477 ColIndexMapping::new(vec![Some(current_size)], new_size),
478 )?;
479 acc = row_merge.into();
480 }
481 Ok(Some(acc))
482 }
483
    /// Returns a reference to the generic aggregation core of this node.
    pub fn core(&self) -> &Agg<PlanRef> {
        &self.core
    }
487}
488
/// Builder that rewrites SELECT/HAVING expressions containing aggregate calls and collects
/// everything needed to construct a [`LogicalAgg`] on top of an input projection.
pub struct LogicalAggBuilder {
    /// Builds the projection placed beneath the aggregation (see `build`); group-by
    /// expressions, aggregate arguments and order-by expressions are registered here.
    input_proj_builder: ProjectBuilder,
    /// Indices, within the input projection, of the group key columns.
    group_key: IndexSet,
    /// Grouping sets as indices into the input projection; empty for a plain `GROUP BY`.
    grouping_sets: Vec<IndexSet>,
    /// Aggregate calls collected so far; `push_agg_call` reuses an equal existing call.
    agg_calls: Vec<PlanAggCall>,
    /// Error recorded while rewriting; taken and returned by `rewrite_with_error`.
    error: Option<RwError>,
    /// `true` while `push_agg_call` rewrites an aggregate's filter; in that state
    /// `rewrite_input_ref` accepts columns that are not part of the group key.
    is_in_filter_clause: bool,
}
511
512impl LogicalAggBuilder {
    /// Creates a builder for the given `GROUP BY` specification.
    ///
    /// Every group-by expression is registered in the input projection builder. For
    /// `GROUPING SETS`, `ROLLUP` and `CUBE` the group key is the union of all grouping sets.
    ///
    /// # Errors
    ///
    /// Returns a "not implemented" error when the projection builder rejects a group-by
    /// expression.
    fn new(group_by: GroupBy, input_schema_len: usize) -> Result<Self> {
        let mut input_proj_builder = ProjectBuilder::default();

        // Shared by GROUPING SETS / ROLLUP / CUBE: registers the expressions of every set
        // and derives the overall group key from them.
        let mut gen_group_key_and_grouping_sets =
            |grouping_sets: Vec<Vec<ExprImpl>>| -> Result<(IndexSet, Vec<IndexSet>)> {
                let grouping_sets: Vec<IndexSet> = grouping_sets
                    .into_iter()
                    .map(|set| {
                        set.into_iter()
                            .map(|expr| input_proj_builder.add_expr(&expr))
                            .try_collect()
                            .map_err(|err| not_implemented!("{err} inside GROUP BY"))
                    })
                    .try_collect()?;

                // Union of all grouping sets.
                let group_key = grouping_sets
                    .iter()
                    .fold(FixedBitSet::with_capacity(input_schema_len), |acc, x| {
                        acc.union(&x.to_bitset()).collect()
                    });

                Ok((IndexSet::from_iter(group_key.ones()), grouping_sets))
            };

        let (group_key, grouping_sets) = match group_by {
            // Plain GROUP BY: no grouping sets.
            GroupBy::GroupKey(group_key) => {
                let group_key = group_key
                    .into_iter()
                    .map(|expr| input_proj_builder.add_expr(&expr))
                    .try_collect()
                    .map_err(|err| not_implemented!("{err} inside GROUP BY"))?;
                (group_key, vec![])
            }
            GroupBy::GroupingSets(grouping_sets) => gen_group_key_and_grouping_sets(grouping_sets)?,
            // ROLLUP: one grouping set per prefix of the rollup list, from empty to full.
            GroupBy::Rollup(rollup) => {
                let grouping_sets = (0..=rollup.len())
                    .map(|n| {
                        rollup
                            .iter()
                            .take(n)
                            .flat_map(|x| x.iter().cloned())
                            .collect_vec()
                    })
                    .collect_vec();
                gen_group_key_and_grouping_sets(grouping_sets)?
            }
            // CUBE: one grouping set per subset (powerset) of the cube list.
            GroupBy::Cube(cube) => {
                let grouping_sets = cube
                    .into_iter()
                    .powerset()
                    .map(|x| x.into_iter().flatten().collect_vec())
                    .collect_vec();
                gen_group_key_and_grouping_sets(grouping_sets)?
            }
        };

        Ok(LogicalAggBuilder {
            group_key,
            grouping_sets,
            agg_calls: vec![],
            error: None,
            input_proj_builder,
            is_in_filter_clause: false,
        })
    }
581
582 pub fn build(self, input: PlanRef) -> LogicalAgg {
583 let logical_project = LogicalProject::with_core(self.input_proj_builder.build(input));
585
586 Agg::new(self.agg_calls, self.group_key, logical_project.into())
588 .with_grouping_sets(self.grouping_sets)
589 .into()
590 }
591
592 fn rewrite_with_error(&mut self, expr: ExprImpl) -> Result<ExprImpl> {
593 let rewritten_expr = self.rewrite_expr(expr);
594 if let Some(error) = self.error.take() {
595 return Err(error);
596 }
597 Ok(rewritten_expr)
598 }
599
600 pub fn try_as_group_expr(&self, expr: &ExprImpl) -> Option<usize> {
602 if let Some(input_index) = self.input_proj_builder.expr_index(expr)
603 && let Some(index) = self
604 .group_key
605 .indices()
606 .position(|group_key| group_key == input_index)
607 {
608 return Some(index);
609 }
610 None
611 }
612
    /// Offset of the first aggregate call in the output schema: the group key columns
    /// come first, so this is the number of group key columns.
    fn schema_agg_start_offset(&self) -> usize {
        self.group_key.len()
    }
616
    /// Rewrites an aggregate call into an expression over simpler aggregate calls.
    ///
    /// Every aggregate call that has to be evaluated is handed to `push_agg_call`, which
    /// returns the `InputRef` under which its result is available. Rewrites performed:
    ///
    /// - `avg(x)` -> `sum(x)::return_type / count(x)`
    /// - `stddev_pop` / `stddev_samp` / `var_pop` / `var_samp` -> an expression over
    ///   `sum(x * x)`, `sum(x)` and `count(x)`
    /// - `approx_percentile` -> same call with the ordering dropped (a descending order is
    ///   folded into the percentile as `1 - p`)
    /// - `arg_min` / `arg_max` -> `first_value` ordered by the comparison argument
    /// - everything else is pushed unchanged.
    pub(crate) fn general_rewrite_agg_call(
        agg_call: AggCall,
        mut push_agg_call: impl FnMut(AggCall) -> Result<InputRef>,
    ) -> Result<ExprImpl> {
        match agg_call.agg_type {
            AggType::Builtin(PbAggKind::Avg) => {
                assert_eq!(agg_call.args.len(), 1);

                // The sum is cast to the return type of `avg` before dividing.
                let sum = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Sum.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?)
                .cast_explicit(&agg_call.return_type())?;

                let count = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Count.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args,
                )?)?);

                Ok(FunctionCall::new(ExprType::Divide, Vec::from([sum, count]))?.into())
            }
            // Variance family. With `n = count(x)`:
            //   numerator   = greatest(sum(x * x) - sum(x) * sum(x) / n, 0)
            //   denominator = n        (population variants)
            //               = n - 1    (sample variants)
            //   stddev_*    = sqrt(numerator / denominator)
            // The result is NULL when `n = 0` (population) or `n <= 1` (sample).
            AggType::Builtin(
                kind @ (PbAggKind::StddevPop
                | PbAggKind::StddevSamp
                | PbAggKind::VarPop
                | PbAggKind::VarSamp),
            ) => {
                let arg = agg_call.args().iter().exactly_one().unwrap();
                let squared_arg = ExprImpl::from(FunctionCall::new(
                    ExprType::Multiply,
                    vec![arg.clone(), arg.clone()],
                )?);

                let sum_of_sq = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Sum.into(),
                    vec![squared_arg],
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?)
                .cast_explicit(&agg_call.return_type())?;

                let sum = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Sum.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?)
                .cast_explicit(&agg_call.return_type())?;

                let count = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Count.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?);

                let zero = ExprImpl::literal_int(0);
                let one = ExprImpl::literal_int(1);

                let squared_sum = ExprImpl::from(FunctionCall::new(
                    ExprType::Multiply,
                    vec![sum.clone(), sum],
                )?);

                let raw_numerator = ExprImpl::from(FunctionCall::new(
                    ExprType::Subtract,
                    vec![
                        sum_of_sq,
                        ExprImpl::from(FunctionCall::new(
                            ExprType::Divide,
                            vec![squared_sum, count.clone()],
                        )?),
                    ],
                )?);

                // Clamp the numerator at zero so it can never be negative.
                let numerator_type = raw_numerator.return_type();
                let numerator = ExprImpl::from(FunctionCall::new(
                    ExprType::Greatest,
                    vec![raw_numerator, zero.clone().cast_explicit(&numerator_type)?],
                )?);

                let denominator = match kind {
                    PbAggKind::VarPop | PbAggKind::StddevPop => count.clone(),
                    PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from(
                        FunctionCall::new(ExprType::Subtract, vec![count.clone(), one.clone()])?,
                    ),
                    _ => unreachable!(),
                };

                let mut target = ExprImpl::from(FunctionCall::new(
                    ExprType::Divide,
                    vec![numerator, denominator],
                )?);

                if matches!(kind, PbAggKind::StddevPop | PbAggKind::StddevSamp) {
                    target = ExprImpl::from(FunctionCall::new(ExprType::Sqrt, vec![target])?);
                }

                // CASE WHEN <too few rows> THEN NULL ELSE <target> END
                let null = ExprImpl::from(Literal::new(None, agg_call.return_type()));
                let case_cond = match kind {
                    PbAggKind::VarPop | PbAggKind::StddevPop => {
                        ExprImpl::from(FunctionCall::new(ExprType::Equal, vec![count, zero])?)
                    }
                    PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from(
                        FunctionCall::new(ExprType::LessThanOrEqual, vec![count, one])?,
                    ),
                    _ => unreachable!(),
                };

                Ok(ExprImpl::from(FunctionCall::new(
                    ExprType::Case,
                    vec![case_cond, null, target],
                )?))
            }
            AggType::Builtin(PbAggKind::ApproxPercentile) => {
                // NOTE(review): indexes `sort_exprs[0]` and `direct_args[0..=1]` and unwraps
                // the percentile literal without checks — presumably validated before this
                // point; confirm against the callers.
                if agg_call.order_by.sort_exprs[0].order_type == OrderType::descending() {
                    // Descending order: replace the percentile `p` by `1 - p`.
                    let prev_percentile = agg_call.direct_args[0].clone();
                    let new_percentile = 1.0
                        - prev_percentile
                            .get_data()
                            .as_ref()
                            .unwrap()
                            .as_float64()
                            .into_inner();
                    let new_percentile = Some(ScalarImpl::Float64(new_percentile.into()));
                    let new_percentile = Literal::new(new_percentile, DataType::Float64);
                    let new_direct_args = vec![new_percentile, agg_call.direct_args[1].clone()];

                    let new_agg_call = AggCall {
                        order_by: OrderBy::any(),
                        direct_args: new_direct_args,
                        ..agg_call
                    };
                    Ok(push_agg_call(new_agg_call)?.into())
                } else {
                    let new_agg_call = AggCall {
                        order_by: OrderBy::any(),
                        ..agg_call
                    };
                    Ok(push_agg_call(new_agg_call)?.into())
                }
            }
            // arg_min(v, c) / arg_max(v, c) become
            // first_value(v ORDER BY c ASC|DESC, <original order by>)
            //   FILTER (<original filter> AND v IS NOT NULL AND c IS NOT NULL).
            AggType::Builtin(PbAggKind::ArgMin | PbAggKind::ArgMax) => {
                let mut agg_call = agg_call;

                // The second argument is the comparison key; reject the types listed below.
                let comparison_arg_type = agg_call.args[1].return_type();
                match comparison_arg_type {
                    DataType::Struct(_)
                    | DataType::List(_)
                    | DataType::Map(_)
                    | DataType::Vector(_)
                    | DataType::Jsonb => {
                        bail!(format!(
                            "{} does not support struct, array, map, vector, jsonb for comparison argument, got {}",
                            agg_call.agg_type.to_string(),
                            comparison_arg_type
                        ));
                    }
                    _ => {}
                }

                // One `IS NOT NULL` predicate per argument.
                let not_null_exprs: Vec<ExprImpl> = agg_call
                    .args
                    .iter()
                    .map(|arg| -> Result<ExprImpl> {
                        Ok(FunctionCall::new(ExprType::IsNotNull, vec![arg.clone()])?.into())
                    })
                    .try_collect()?;

                // Order by the comparison key first, then by the original ORDER BY.
                let comparison_expr = agg_call.args[1].clone();
                let mut order_exprs = vec![OrderByExpr {
                    expr: comparison_expr,
                    order_type: if agg_call.agg_type == AggType::Builtin(PbAggKind::ArgMin) {
                        OrderType::ascending()
                    } else {
                        OrderType::descending()
                    },
                }];

                order_exprs.extend(agg_call.order_by.sort_exprs);

                let order_by = OrderBy::new(order_exprs);

                let filter = agg_call.filter.clone().and(Condition {
                    conjunctions: not_null_exprs,
                });

                // Keep only the value argument.
                agg_call.args.truncate(1);

                let new_agg_call = AggCall {
                    agg_type: AggType::Builtin(PbAggKind::FirstValue),
                    order_by,
                    filter,
                    ..agg_call
                };
                Ok(push_agg_call(new_agg_call)?.into())
            }
            _ => Ok(push_agg_call(agg_call)?.into()),
        }
    }
844
845 fn push_agg_call(&mut self, agg_call: AggCall) -> Result<InputRef> {
849 let AggCall {
850 agg_type,
851 return_type,
852 args,
853 distinct,
854 order_by,
855 filter,
856 direct_args,
857 } = agg_call;
858
859 self.is_in_filter_clause = true;
860 let filter = filter.rewrite_expr(self);
863 self.is_in_filter_clause = false;
864
865 let args: Vec<_> = args
866 .iter()
867 .map(|expr| {
868 let index = self.input_proj_builder.add_expr(expr)?;
869 Ok(InputRef::new(index, expr.return_type()))
870 })
871 .try_collect()
872 .map_err(|err: &'static str| not_implemented!("{err} inside aggregation calls"))?;
873
874 let order_by: Vec<_> = order_by
875 .sort_exprs
876 .iter()
877 .map(|e| {
878 let index = self.input_proj_builder.add_expr(&e.expr)?;
879 Ok(ColumnOrder::new(index, e.order_type))
880 })
881 .try_collect()
882 .map_err(|err: &'static str| {
883 not_implemented!("{err} inside aggregation calls order by")
884 })?;
885
886 let plan_agg_call = PlanAggCall {
887 agg_type,
888 return_type: return_type.clone(),
889 inputs: args,
890 distinct,
891 order_by,
892 filter,
893 direct_args,
894 };
895
896 if let Some((pos, existing)) = self
897 .agg_calls
898 .iter()
899 .find_position(|&c| c == &plan_agg_call)
900 {
901 return Ok(InputRef::new(
902 self.schema_agg_start_offset() + pos,
903 existing.return_type.clone(),
904 ));
905 }
906 let index = self.schema_agg_start_offset() + self.agg_calls.len();
907 self.agg_calls.push(plan_agg_call);
908 Ok(InputRef::new(index, return_type))
909 }
910
911 fn try_rewrite_agg_call(&mut self, mut agg_call: AggCall) -> Result<ExprImpl> {
918 if matches!(agg_call.agg_type, agg_types::must_have_order_by!())
919 && agg_call.order_by.sort_exprs.is_empty()
920 {
921 return Err(ErrorCode::InvalidInputSyntax(format!(
922 "Aggregation function {} requires ORDER BY clause",
923 agg_call.agg_type
924 ))
925 .into());
926 }
927
928 if matches!(
930 agg_call.agg_type,
931 agg_types::result_unaffected_by_order_by!()
932 ) {
933 agg_call.order_by = OrderBy::any();
934 }
935 if matches!(
937 agg_call.agg_type,
938 agg_types::result_unaffected_by_distinct!()
939 ) {
940 agg_call.distinct = false;
941 }
942
943 if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Grouping)) {
944 if self.grouping_sets.is_empty() {
945 return Err(ErrorCode::NotSupported(
946 "GROUPING must be used in a query with grouping sets".into(),
947 "try to use grouping sets instead".into(),
948 )
949 .into());
950 }
951 if agg_call.args.len() >= 32 {
952 return Err(ErrorCode::InvalidInputSyntax(
953 "GROUPING must have fewer than 32 arguments".into(),
954 )
955 .into());
956 }
957 if agg_call
958 .args
959 .iter()
960 .any(|x| self.try_as_group_expr(x).is_none())
961 {
962 return Err(ErrorCode::InvalidInputSyntax(
963 "arguments to GROUPING must be grouping expressions of the associated query level"
964 .into(),
965 ).into());
966 }
967 }
968
969 Self::general_rewrite_agg_call(agg_call, |agg_call| self.push_agg_call(agg_call))
970 }
971}
972
973impl ExprRewriter for LogicalAggBuilder {
974 fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl {
975 let dummy = Literal::new(None, agg_call.return_type()).into();
976 match self.try_rewrite_agg_call(agg_call) {
977 Ok(expr) => expr,
978 Err(err) => {
979 self.error = Some(err);
980 dummy
981 }
982 }
983 }
984
985 fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
988 let expr = func_call.into();
989 if let Some(group_key) = self.try_as_group_expr(&expr) {
990 InputRef::new(group_key, expr.return_type()).into()
991 } else {
992 let (func_type, inputs, ret) = expr.into_function_call().unwrap().decompose();
993 let inputs = inputs
994 .into_iter()
995 .map(|expr| self.rewrite_expr(expr))
996 .collect();
997 FunctionCall::new_unchecked(func_type, inputs, ret).into()
998 }
999 }
1000
1001 fn rewrite_window_function(&mut self, window_func: WindowFunction) -> ExprImpl {
1004 let WindowFunction {
1005 args,
1006 partition_by,
1007 order_by,
1008 ..
1009 } = window_func;
1010 let args = args
1011 .into_iter()
1012 .map(|expr| self.rewrite_expr(expr))
1013 .collect();
1014 let partition_by = partition_by
1015 .into_iter()
1016 .map(|expr| self.rewrite_expr(expr))
1017 .collect();
1018 let order_by = order_by.rewrite_expr(self);
1019 WindowFunction {
1020 args,
1021 partition_by,
1022 order_by,
1023 ..window_func
1024 }
1025 .into()
1026 }
1027
1028 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
1030 let expr = input_ref.into();
1031 if let Some(group_key) = self.try_as_group_expr(&expr) {
1032 InputRef::new(group_key, expr.return_type()).into()
1033 } else if self.is_in_filter_clause {
1034 InputRef::new(
1035 self.input_proj_builder.add_expr(&expr).unwrap(),
1036 expr.return_type(),
1037 )
1038 .into()
1039 } else {
1040 self.error = Some(
1041 ErrorCode::InvalidInputSyntax(
1042 "column must appear in the GROUP BY clause or be used in an aggregate function"
1043 .into(),
1044 )
1045 .into(),
1046 );
1047 expr
1048 }
1049 }
1050
1051 fn rewrite_subquery(&mut self, subquery: crate::expr::Subquery) -> ExprImpl {
1052 if subquery.is_correlated_by_depth(0) {
1053 self.error = Some(
1054 not_implemented!(
1055 issue = 2275,
1056 "correlated subquery in HAVING or SELECT with agg",
1057 )
1058 .into(),
1059 );
1060 }
1061 subquery.into()
1062 }
1063}
1064
1065impl From<Agg<PlanRef>> for LogicalAgg {
1066 fn from(core: Agg<PlanRef>) -> Self {
1067 let base = PlanBase::new_logical_with_core(&core);
1068 Self { base, core }
1069 }
1070}
1071
1072impl From<Agg<PlanRef>> for PlanRef {
1074 fn from(core: Agg<PlanRef>) -> Self {
1075 LogicalAgg::from(core).into()
1076 }
1077}
1078
1079impl LogicalAgg {
    /// Returns the column index mapping produced by the core's `i2o_col_mapping`
    /// (input-to-output, judging by the name).
    pub fn i2o_col_mapping(&self) -> ColIndexMapping {
        self.core.i2o_col_mapping()
    }
1084
1085 pub fn create(
1094 select_exprs: Vec<ExprImpl>,
1095 group_by: GroupBy,
1096 having: Option<ExprImpl>,
1097 input: PlanRef,
1098 ) -> Result<(PlanRef, Vec<ExprImpl>, Option<ExprImpl>)> {
1099 let mut agg_builder = LogicalAggBuilder::new(group_by, input.schema().len())?;
1100
1101 let rewritten_select_exprs = select_exprs
1102 .into_iter()
1103 .map(|expr| agg_builder.rewrite_with_error(expr))
1104 .collect::<Result<_>>()?;
1105 let rewritten_having = having
1106 .map(|expr| agg_builder.rewrite_with_error(expr))
1107 .transpose()?;
1108
1109 Ok((
1110 agg_builder.build(input).into(),
1111 rewritten_select_exprs,
1112 rewritten_having,
1113 ))
1114 }
1115
    /// Returns the aggregate calls of this node.
    pub fn agg_calls(&self) -> &Vec<PlanAggCall> {
        &self.core.agg_calls
    }
1120
    /// Returns the group key of this node.
    pub fn group_key(&self) -> &IndexSet {
        &self.core.group_key
    }
1125
    /// Returns the grouping sets of this node.
    pub fn grouping_sets(&self) -> &Vec<IndexSet> {
        &self.core.grouping_sets
    }
1129
    /// Consumes the node and returns the parts produced by the core's `decompose`.
    ///
    /// NOTE(review): the tuple presumably holds the aggregate calls, group key, grouping
    /// sets, input and the two-phase flag, in that order — confirm against `Agg::decompose`.
    pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) {
        self.core.decompose()
    }
1133
1134 #[must_use]
1135 pub fn rewrite_with_input_agg(
1136 &self,
1137 input: PlanRef,
1138 agg_calls: &[PlanAggCall],
1139 mut input_col_change: ColIndexMapping,
1140 ) -> (Self, ColIndexMapping) {
1141 let agg_calls = agg_calls
1142 .iter()
1143 .cloned()
1144 .map(|mut agg_call| {
1145 agg_call.inputs.iter_mut().for_each(|i| {
1146 *i = InputRef::new(input_col_change.map(i.index()), i.return_type())
1147 });
1148 agg_call.order_by.iter_mut().for_each(|o| {
1149 o.column_index = input_col_change.map(o.column_index);
1150 });
1151 agg_call.filter = agg_call.filter.rewrite_expr(&mut input_col_change);
1152 agg_call
1153 })
1154 .collect();
1155 let group_key_in_vec: Vec<usize> = self
1157 .group_key()
1158 .indices()
1159 .map(|key| input_col_change.map(key))
1160 .collect();
1161 let group_key: IndexSet = group_key_in_vec.iter().cloned().collect();
1163 let grouping_sets = self
1164 .grouping_sets()
1165 .iter()
1166 .map(|set| set.indices().map(|key| input_col_change.map(key)).collect())
1167 .collect();
1168
1169 let new_agg = Agg::new(agg_calls, group_key.clone(), input)
1170 .with_grouping_sets(grouping_sets)
1171 .with_enable_two_phase(self.core().enable_two_phase);
1172
1173 let mut out_col_change = vec![];
1176 for idx in group_key_in_vec {
1177 let pos = group_key.indices().position(|x| x == idx).unwrap();
1178 out_col_change.push(pos);
1179 }
1180 for i in (group_key.len())..new_agg.schema().len() {
1181 out_col_change.push(i);
1182 }
1183 let out_col_change =
1184 ColIndexMapping::with_remaining_columns(&out_col_change, new_agg.schema().len());
1185
1186 (new_agg.into(), out_col_change)
1187 }
1188}
1189
1190impl PlanTreeNodeUnary<Logical> for LogicalAgg {
1191 fn input(&self) -> PlanRef {
1192 self.core.input.clone()
1193 }
1194
1195 fn clone_with_input(&self, input: PlanRef) -> Self {
1196 Agg::new(self.agg_calls().clone(), self.group_key().clone(), input)
1197 .with_grouping_sets(self.grouping_sets().clone())
1198 .with_enable_two_phase(self.core().enable_two_phase)
1199 .into()
1200 }
1201
1202 fn rewrite_with_input(
1203 &self,
1204 input: PlanRef,
1205 input_col_change: ColIndexMapping,
1206 ) -> (Self, ColIndexMapping) {
1207 self.rewrite_with_input_agg(input, self.agg_calls(), input_col_change)
1208 }
1209}
1210
// Macro-generated trait implementations for `LogicalAgg`.
// NOTE(review): presumably derives the generic plan-tree-node impl from the
// `PlanTreeNodeUnary<Logical>` impl of `LogicalAgg` -- confirm against the macro definition.
impl_plan_tree_node_for_unary! { Logical, LogicalAgg }
// NOTE(review): presumably implements the explain/distill output by delegating to the `core`
// field under the display name "LogicalAgg" -- confirm against the macro definition.
impl_distill_by_unit!(LogicalAgg, core, "LogicalAgg");
1213
1214impl ExprRewritable<Logical> for LogicalAgg {
1215 fn has_rewritable_expr(&self) -> bool {
1216 true
1217 }
1218
1219 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
1220 let mut core = self.core.clone();
1221 core.rewrite_exprs(r);
1222 Self {
1223 base: self.base.clone_with_new_plan_id(),
1224 core,
1225 }
1226 .into()
1227 }
1228}
1229
1230impl ExprVisitable for LogicalAgg {
1231 fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
1232 self.core.visit_exprs(v);
1233 }
1234}
1235
1236impl ColPrunable for LogicalAgg {
1237 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
1238 let group_key_required_cols = self.group_key().to_bitset();
1239
1240 let (agg_call_required_cols, agg_calls) = {
1241 let input_cnt = self.input().schema().len();
1242 let mut tmp = FixedBitSet::with_capacity(input_cnt);
1243 let group_key_cardinality = self.group_key().len();
1244 let new_agg_calls = required_cols
1245 .iter()
1246 .filter(|&&index| index >= group_key_cardinality)
1247 .map(|&index| {
1248 let index = index - group_key_cardinality;
1249 let agg_call = self.agg_calls()[index].clone();
1250 tmp.extend(agg_call.inputs.iter().map(|x| x.index()));
1251 tmp.extend(agg_call.order_by.iter().map(|x| x.column_index));
1252 for i in &agg_call.filter.conjunctions {
1254 tmp.union_with(&i.collect_input_refs(input_cnt));
1255 }
1256 agg_call
1257 })
1258 .collect_vec();
1259 (tmp, new_agg_calls)
1260 };
1261
1262 let input_required_cols = {
1263 let mut tmp = FixedBitSet::with_capacity(self.input().schema().len());
1264 tmp.union_with(&group_key_required_cols);
1265 tmp.union_with(&agg_call_required_cols);
1266 tmp.ones().collect_vec()
1267 };
1268 let input_col_change = ColIndexMapping::with_remaining_columns(
1269 &input_required_cols,
1270 self.input().schema().len(),
1271 );
1272 let agg = {
1273 let input = self.input().prune_col(&input_required_cols, ctx);
1274 let (agg, output_col_change) =
1275 self.rewrite_with_input_agg(input, &agg_calls, input_col_change);
1276 assert!(output_col_change.is_identity());
1277 agg
1278 };
1279 let new_output_cols = {
1280 let group_key_cardinality = agg.group_key().len();
1282 let mut tmp = (0..group_key_cardinality).collect_vec();
1283 tmp.extend(
1284 required_cols
1285 .iter()
1286 .filter(|&&index| index >= group_key_cardinality),
1287 );
1288 tmp
1289 };
1290 if new_output_cols == required_cols {
1291 agg.into()
1293 } else {
1294 let mapping =
1297 &ColIndexMapping::with_remaining_columns(&new_output_cols, self.schema().len());
1298 let output_required_cols = required_cols
1299 .iter()
1300 .map(|&idx| mapping.map(idx))
1301 .collect_vec();
1302 let src_size = agg.schema().len();
1303 LogicalProject::with_mapping(
1304 agg.into(),
1305 ColIndexMapping::with_remaining_columns(&output_required_cols, src_size),
1306 )
1307 .into()
1308 }
1309 }
1310}
1311
1312impl PredicatePushdown for LogicalAgg {
1313 fn predicate_pushdown(
1314 &self,
1315 predicate: Condition,
1316 ctx: &mut PredicatePushdownContext,
1317 ) -> PlanRef {
1318 let num_group_key = self.group_key().len();
1319 let num_agg_calls = self.agg_calls().len();
1320 assert!(num_group_key + num_agg_calls == self.schema().len());
1321
1322 if num_group_key == 0 {
1330 return gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx);
1331 }
1332
1333 let mut agg_call_columns = FixedBitSet::with_capacity(num_group_key + num_agg_calls);
1335 agg_call_columns.insert_range(num_group_key..num_group_key + num_agg_calls);
1336 let (agg_call_pred, pushed_predicate) = predicate.split_disjoint(&agg_call_columns);
1337
1338 let mut subst = Substitute {
1340 mapping: self
1341 .group_key()
1342 .indices()
1343 .enumerate()
1344 .map(|(i, group_key)| {
1345 InputRef::new(group_key, self.schema().fields()[i].data_type()).into()
1346 })
1347 .collect(),
1348 };
1349 let pushed_predicate = pushed_predicate.rewrite_expr(&mut subst);
1350
1351 gen_filter_and_pushdown(self, agg_call_pred, pushed_predicate, ctx)
1352 }
1353}
1354
1355impl ToBatch for LogicalAgg {
1356 fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1357 self.to_batch_with_order_required(&Order::any())
1358 }
1359
1360 fn to_batch_with_order_required(
1363 &self,
1364 required_order: &Order,
1365 ) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1366 let input = self.input().to_batch()?;
1367 let new_logical = self.core.clone_with_input(input);
1368 let agg_plan = if self.group_key().is_empty() {
1369 BatchSimpleAgg::new(new_logical).into()
1370 } else if self.ctx().session_ctx().config().batch_enable_sort_agg()
1371 && new_logical.input_provides_order_on_group_keys()
1372 {
1373 BatchSortAgg::new(new_logical).into()
1374 } else {
1375 BatchHashAgg::new(new_logical).into()
1376 };
1377 required_order.enforce_if_not_satisfies(agg_plan)
1378 }
1379}
1380
1381fn find_or_append_row_count(mut logical: Agg<StreamPlanRef>) -> (Agg<StreamPlanRef>, usize) {
1382 let count_star = PlanAggCall::count_star();
1385 let row_count_idx = if let Some((idx, _)) = logical
1386 .agg_calls
1387 .iter()
1388 .find_position(|&c| c == &count_star)
1389 {
1390 idx
1391 } else {
1392 let idx = logical.agg_calls.len();
1393 logical.agg_calls.push(count_star);
1394 idx
1395 };
1396 (logical, row_count_idx)
1397}
1398
1399fn new_stream_simple_agg(
1400 core: Agg<StreamPlanRef>,
1401 must_output_per_barrier: bool,
1402) -> Result<StreamSimpleAgg> {
1403 let (logical, row_count_idx) = find_or_append_row_count(core);
1404 StreamSimpleAgg::new(logical, row_count_idx, must_output_per_barrier)
1405}
1406
1407fn new_stream_hash_agg(
1408 core: Agg<StreamPlanRef>,
1409 vnode_col_idx: Option<usize>,
1410) -> Result<StreamHashAgg> {
1411 let (logical, row_count_idx) = find_or_append_row_count(core);
1412 StreamHashAgg::new(logical, vnode_col_idx, row_count_idx)
1413}
1414
1415impl ToStream for LogicalAgg {
1416 fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<StreamPlanRef> {
1417 use super::stream::prelude::*;
1418
1419 let eowc = ctx.emit_on_window_close();
1420 let input = if self.group_key().is_empty() {
1421 self.input()
1422 } else {
1423 try_enforce_locality_requirement(self.input(), &self.group_key().to_vec())
1424 };
1425
1426 let stream_input = input.to_stream(ctx)?;
1427
1428 if stream_input.append_only() && self.agg_calls().is_empty() && !self.group_key().is_empty()
1430 {
1431 let input = if self.group_key().len() != self.input().schema().len() {
1432 let cols = &self.group_key().to_vec();
1433 LogicalProject::with_mapping(
1434 self.input(),
1435 ColIndexMapping::with_remaining_columns(cols, self.input().schema().len()),
1436 )
1437 .into()
1438 } else {
1439 self.input()
1440 };
1441 let input_schema_len = input.schema().len();
1442 let logical_dedup = LogicalDedup::new(input, (0..input_schema_len).collect());
1443 return logical_dedup.to_stream(ctx);
1444 }
1445
1446 if self.agg_calls().iter().any(|call| {
1447 matches!(
1448 call.agg_type,
1449 AggType::Builtin(PbAggKind::ApproxCountDistinct)
1450 )
1451 }) {
1452 if stream_input.append_only() {
1453 self.core.ctx().session_ctx().notice_to_user(
1454 "Streaming `APPROX_COUNT_DISTINCT` is still a preview feature and subject to change. Please do not use it in production environment.",
1455 );
1456 } else {
1457 bail_not_implemented!(
1458 "Streaming `APPROX_COUNT_DISTINCT` is only supported in append-only stream"
1459 );
1460 }
1461 }
1462
1463 let plan = self.gen_dist_stream_agg_plan(stream_input)?;
1464
1465 let (plan, n_final_agg_calls) = if let Some(final_agg) = plan.as_stream_simple_agg() {
1466 if eowc {
1467 return Err(ErrorCode::InvalidInputSyntax(
1468 "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1469 .to_owned(),
1470 )
1471 .into());
1472 }
1473 (plan.clone(), final_agg.agg_calls().len())
1474 } else if let Some(final_agg) = plan.as_stream_hash_agg() {
1475 (
1476 if eowc {
1477 final_agg.to_eowc_version()?
1478 } else {
1479 plan.clone()
1480 },
1481 final_agg.agg_calls().len(),
1482 )
1483 } else if let Some(_approx_percentile_agg) = plan.as_stream_global_approx_percentile() {
1484 if eowc {
1485 return Err(ErrorCode::InvalidInputSyntax(
1486 "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1487 .to_owned(),
1488 )
1489 .into());
1490 }
1491 (plan.clone(), 1)
1492 } else if let Some(stream_row_merge) = plan.as_stream_row_merge() {
1493 if eowc {
1494 return Err(ErrorCode::InvalidInputSyntax(
1495 "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1496 .to_owned(),
1497 )
1498 .into());
1499 }
1500 (plan.clone(), stream_row_merge.base.schema().len())
1501 } else {
1502 panic!(
1503 "the root PlanNode must be StreamHashAgg, StreamSimpleAgg, StreamGlobalApproxPercentile, or StreamRowMerge"
1504 );
1505 };
1506
1507 if self.agg_calls().len() == n_final_agg_calls {
1508 Ok(plan)
1510 } else {
1511 assert_eq!(self.agg_calls().len() + 1, n_final_agg_calls);
1513
1514 let mut project = StreamProject::new(generic::Project::with_out_col_idx(
1515 plan,
1516 0..self.schema().len(),
1517 ));
1518 if self.agg_calls().is_empty() {
1523 project = project.with_noop_update_hint(true);
1524 }
1525 Ok(project.into())
1526 }
1527 }
1528
1529 fn logical_rewrite_for_stream(
1530 &self,
1531 ctx: &mut RewriteStreamContext,
1532 ) -> Result<(PlanRef, ColIndexMapping)> {
1533 let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
1534 let (agg, out_col_change) = self.rewrite_with_input(input, input_col_change);
1535 let (map, _) = out_col_change.into_parts();
1536 let out_col_change = ColIndexMapping::new(map, agg.schema().len());
1537 Ok((agg.into(), out_col_change))
1538 }
1539}
1540
#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};

    use super::*;
    use crate::expr::{assert_eq_input_ref, input_ref_to_column_indices};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;

    /// Columns `v1`, `v2`, `v3`, each of type `ty`.
    fn three_columns(ty: &DataType) -> Vec<Field> {
        vec![
            Field::with_name(ty.clone(), "v1"),
            Field::with_name(ty.clone(), "v2"),
            Field::with_name(ty.clone(), "v3"),
        ]
    }

    /// Expression-level `kind(arg)`: not distinct, any order, always-true filter, no direct
    /// arguments.
    fn bare_agg_call(kind: PbAggKind, arg: ExprImpl) -> AggCall {
        AggCall::new(
            kind.into(),
            vec![arg],
            false,
            OrderBy::any(),
            Condition::true_cond(),
            vec![],
        )
        .unwrap()
    }

    /// Plan-level `kind(input_ref(input_idx))` returning `ty`: not distinct, unordered,
    /// always-true filter, no direct arguments.
    fn bare_plan_agg_call(kind: PbAggKind, input_idx: usize, ty: &DataType) -> PlanAggCall {
        PlanAggCall {
            agg_type: kind.into(),
            return_type: ty.clone(),
            inputs: vec![InputRef::new(input_idx, ty.clone())],
            distinct: false,
            order_by: vec![],
            filter: Condition::true_cond(),
            direct_args: vec![],
        }
    }

    /// Checks a pruned aggregation: its group key, its single aggregate call (kind, input
    /// columns, return type) and that its `LogicalValues` input kept only `fields[1..]`.
    fn assert_pruned_agg(
        agg: &LogicalAgg,
        expected_group_key: Vec<usize>,
        expected_kind: PbAggKind,
        expected_inputs: Vec<usize>,
        ty: &DataType,
        fields: &[Field],
    ) {
        assert_eq!(agg.group_key(), &expected_group_key.into());

        assert_eq!(agg.agg_calls().len(), 1);
        let call = agg.agg_calls()[0].clone();
        assert_eq!(call.agg_type, expected_kind.into());
        assert_eq!(input_ref_to_column_indices(&call.inputs), expected_inputs);
        assert_eq!(call.return_type, *ty);

        let values = agg.input();
        let values = values.as_logical_values().unwrap();
        assert_eq!(values.schema().fields(), &fields[1..]);
    }

    /// `LogicalAgg::create` must split select/group expressions into group keys, aggregate
    /// calls and rewritten select expressions that reference the aggregation's output.
    #[tokio::test]
    async fn test_create() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let values = LogicalValues::new(
            vec![],
            Schema {
                fields: three_columns(&ty),
            },
            ctx,
        );
        let input = PlanRef::from(values);
        let v1 = InputRef::new(0, ty.clone());
        let v2 = InputRef::new(1, ty.clone());
        let v3 = InputRef::new(2, ty.clone());

        // Plans the aggregation and returns (rewritten select exprs, agg calls, group key).
        let plan_agg = |select_exprs: Vec<ExprImpl>,
                        group_exprs|
         -> (Vec<ExprImpl>, Vec<PlanAggCall>, IndexSet) {
            let (plan, exprs, _) = LogicalAgg::create(
                select_exprs,
                GroupBy::GroupKey(group_exprs),
                None,
                input.clone(),
            )
            .unwrap();
            let logical_agg = plan.as_logical_agg().unwrap();
            (
                exprs,
                logical_agg.agg_calls().clone(),
                logical_agg.group_key().clone(),
            )
        };

        // select v1 ... group by v1
        {
            let (exprs, agg_calls, group_key) =
                plan_agg(vec![v1.clone().into()], vec![v1.clone().into()]);

            assert_eq!(exprs.len(), 1);
            assert_eq_input_ref!(&exprs[0], 0);

            assert_eq!(agg_calls.len(), 0);
            assert_eq!(group_key, vec![0].into());
        }

        // select v1, min(v2) ... group by v1
        {
            let min_v2 = bare_agg_call(PbAggKind::Min, v2.clone().into());

            let (exprs, agg_calls, group_key) = plan_agg(
                vec![v1.clone().into(), min_v2.into()],
                vec![v1.clone().into()],
            );

            assert_eq!(exprs.len(), 2);
            assert_eq_input_ref!(&exprs[0], 0);
            assert_eq_input_ref!(&exprs[1], 1);

            assert_eq!(agg_calls.len(), 1);
            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
            assert_eq!(group_key, vec![0].into());
        }

        // select v1, min(v2) + max(v3) ... group by v1
        {
            let min_v2 = bare_agg_call(PbAggKind::Min, v2.clone().into());
            let max_v3 = bare_agg_call(PbAggKind::Max, v3.clone().into());
            let sum =
                FunctionCall::new(ExprType::Add, vec![min_v2.into(), max_v3.into()]).unwrap();

            let (exprs, agg_calls, group_key) = plan_agg(
                vec![v1.clone().into(), ExprImpl::from(sum)],
                vec![v1.clone().into()],
            );

            assert_eq_input_ref!(&exprs[0], 0);
            match &exprs[1] {
                ExprImpl::FunctionCall(func_call) => {
                    assert_eq!(func_call.func_type(), ExprType::Add);
                    let operands = func_call.inputs();
                    assert_eq_input_ref!(&operands[0], 1);
                    assert_eq_input_ref!(&operands[1], 2);
                }
                _ => panic!("Wrong expression type!"),
            }

            assert_eq!(agg_calls.len(), 2);
            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
            assert_eq!(agg_calls[1].agg_type, PbAggKind::Max.into());
            assert_eq!(input_ref_to_column_indices(&agg_calls[1].inputs), vec![2]);
            assert_eq!(group_key, vec![0].into());
        }

        // select v2, min(v1 * v3) ... group by v2
        {
            let v1_mult_v3 =
                FunctionCall::new(ExprType::Multiply, vec![v1.into(), v3.into()]).unwrap();
            let min_product = bare_agg_call(PbAggKind::Min, v1_mult_v3.into());

            let (exprs, agg_calls, group_key) = plan_agg(
                vec![v2.clone().into(), min_product.into()],
                vec![v2.into()],
            );

            assert_eq_input_ref!(&exprs[0], 0);
            assert_eq_input_ref!(&exprs[1], 1);

            assert_eq!(agg_calls.len(), 1);
            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
            assert_eq!(group_key, vec![0].into());
        }
    }

    /// Builds `Agg(min(input_ref(2))) group by (input_ref(1))` over empty `LogicalValues`
    /// with the given `fields`; `ty` is used for the call's input and return type.
    async fn generate_agg_call(ty: DataType, fields: Vec<Field>) -> LogicalAgg {
        let ctx = OptimizerContext::mock().await;

        let values = LogicalValues::new(vec![], Schema { fields }, ctx);
        let min_of_third = bare_plan_agg_call(PbAggKind::Min, 2, &ty);
        Agg::new(vec![min_of_third], vec![1].into(), values.into()).into()
    }

    /// Requiring every output column in natural order yields a bare aggregation (no project)
    /// whose input lost the unused `v1`.
    #[tokio::test]
    async fn test_prune_all() {
        let ty = DataType::Int32;
        let fields = three_columns(&ty);
        let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).await.into();

        let plan = agg.prune_col(&[0, 1], &mut ColumnPruningContext::new(agg.clone()));

        let agg_new = plan.as_logical_agg().unwrap();
        assert_pruned_agg(agg_new, vec![0], PbAggKind::Min, vec![1], &ty, &fields);
    }

    /// Requiring every output column in swapped order adds a reordering project on top of
    /// the pruned aggregation.
    #[tokio::test]
    async fn test_prune_all_with_order_required() {
        let ty = DataType::Int32;
        let fields = three_columns(&ty);
        let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).await.into();

        let plan = agg.prune_col(&[1, 0], &mut ColumnPruningContext::new(agg.clone()));

        let proj = plan.as_logical_project().unwrap();
        let projected: Vec<usize> = proj
            .exprs()
            .iter()
            .map(|expr| expr.as_input_ref().unwrap().index())
            .collect();
        assert_eq!(projected, vec![1, 0]);

        let proj_input = proj.input();
        let agg_new = proj_input.as_logical_agg().unwrap();
        assert_pruned_agg(agg_new, vec![0], PbAggKind::Min, vec![1], &ty, &fields);
    }

    /// Requiring only the aggregate column keeps the group key inside the aggregation and
    /// drops it with a project on top.
    #[tokio::test]
    async fn test_prune_group_key() {
        let ctx = OptimizerContext::mock().await;
        let ty = DataType::Int32;
        let fields = three_columns(&ty);
        let values: LogicalValues = LogicalValues::new(
            vec![],
            Schema {
                fields: fields.clone(),
            },
            ctx,
        );
        let min_of_third = bare_plan_agg_call(PbAggKind::Min, 2, &ty);
        let agg: PlanRef = Agg::new(vec![min_of_third], vec![1].into(), values.into()).into();

        let plan = agg.prune_col(&[1], &mut ColumnPruningContext::new(agg.clone()));

        let project = plan.as_logical_project().unwrap();
        assert_eq!(project.exprs().len(), 1);
        assert_eq_input_ref!(&project.exprs()[0], 1);

        let agg_new = project.input();
        let agg_new = agg_new.as_logical_agg().unwrap();
        assert_pruned_agg(agg_new, vec![0], PbAggKind::Min, vec![1], &ty, &fields);
    }

    /// Requiring one group key and one of two aggregate calls drops the other call, keeps
    /// both group keys in the aggregation, and projects away the unrequested key.
    #[tokio::test]
    async fn test_prune_agg() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields = three_columns(&ty);
        let values = LogicalValues::new(
            vec![],
            Schema {
                fields: fields.clone(),
            },
            ctx,
        );

        let agg_calls = vec![
            bare_plan_agg_call(PbAggKind::Min, 2, &ty),
            bare_plan_agg_call(PbAggKind::Max, 1, &ty),
        ];
        let agg: PlanRef = Agg::new(agg_calls, vec![1, 2].into(), values.into()).into();

        let plan = agg.prune_col(&[0, 3], &mut ColumnPruningContext::new(agg.clone()));

        let project = plan.as_logical_project().unwrap();
        assert_eq!(project.exprs().len(), 2);
        assert_eq_input_ref!(&project.exprs()[0], 0);
        assert_eq_input_ref!(&project.exprs()[1], 2);

        let agg_new = project.input();
        let agg_new = agg_new.as_logical_agg().unwrap();
        assert_pruned_agg(agg_new, vec![0, 1], PbAggKind::Max, vec![0], &ty, &fields);
    }
}