1use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::types::{DataType, ScalarImpl};
18use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
19use risingwave_common::{bail, bail_not_implemented, not_implemented};
20use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};
21
22use super::generic::{self, Agg, GenericPlanRef, PlanAggCall, ProjectBuilder};
23use super::utils::impl_distill_by_unit;
24use super::{
25 BatchHashAgg, BatchSimpleAgg, ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef,
26 PlanTreeNodeUnary, PredicatePushdown, StreamHashAgg, StreamProject, StreamShare,
27 StreamSimpleAgg, StreamStatelessSimpleAgg, ToBatch, ToStream,
28};
29use crate::error::{ErrorCode, Result, RwError};
30use crate::expr::{
31 AggCall, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, Literal,
32 OrderBy, WindowFunction,
33};
34use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
35use crate::optimizer::plan_node::generic::GenericPlanNode;
36use crate::optimizer::plan_node::stream_global_approx_percentile::StreamGlobalApproxPercentile;
37use crate::optimizer::plan_node::stream_local_approx_percentile::StreamLocalApproxPercentile;
38use crate::optimizer::plan_node::stream_row_merge::StreamRowMerge;
39use crate::optimizer::plan_node::{
40 BatchSortAgg, ColumnPruningContext, LogicalDedup, LogicalProject, PredicatePushdownContext,
41 RewriteStreamContext, ToStreamContext, gen_filter_and_pushdown,
42};
43use crate::optimizer::property::{Distribution, Order, RequiredDist};
44use crate::utils::{
45 ColIndexMapping, ColIndexMappingRewriteExt, Condition, GroupBy, IndexSet, Substitute,
46};
47
48pub struct AggInfo {
49 pub calls: Vec<PlanAggCall>,
50 pub col_mapping: ColIndexMapping,
51}
52
53pub struct SeparatedAggInfo {
55 normal: AggInfo,
56 approx: AggInfo,
57}
58
59#[derive(Clone, Debug, PartialEq, Eq, Hash)]
66pub struct LogicalAgg {
67 pub base: PlanBase<Logical>,
68 core: Agg<PlanRef>,
69}
70
71impl LogicalAgg {
72 fn gen_stateless_two_phase_streaming_agg_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
75 debug_assert!(self.group_key().is_empty());
76 let mut core = self.core.clone();
77
78 let (non_approx_percentile_col_mapping, approx_percentile_col_mapping, approx_percentile) =
80 self.prepare_approx_percentile(&mut core, stream_input.clone())?;
81
82 if core.agg_calls.is_empty() {
83 if let Some(approx_percentile) = approx_percentile {
84 return Ok(approx_percentile);
85 };
86 bail!("expected at least one agg call");
87 }
88
89 let need_row_merge: bool = Self::need_row_merge(&approx_percentile);
90
91 let total_agg_calls = core
93 .agg_calls
94 .iter()
95 .enumerate()
96 .map(|(partial_output_idx, agg_call)| {
97 agg_call.partial_to_total_agg_call(partial_output_idx)
98 })
99 .collect_vec();
100 let local_agg = StreamStatelessSimpleAgg::new(core);
101 let exchange =
102 RequiredDist::single().enforce_if_not_satisfies(local_agg.into(), &Order::any())?;
103
104 let must_output_per_barrier = need_row_merge;
105 let global_agg = new_stream_simple_agg(
106 Agg::new(total_agg_calls, IndexSet::empty(), exchange),
107 must_output_per_barrier,
108 );
109
110 Self::add_row_merge_if_needed(
112 approx_percentile,
113 global_agg.into(),
114 approx_percentile_col_mapping,
115 non_approx_percentile_col_mapping,
116 )
117 }
118
119 fn gen_vnode_two_phase_streaming_agg_plan(
123 &self,
124 stream_input: PlanRef,
125 dist_key: &[usize],
126 ) -> Result<PlanRef> {
127 let mut core = self.core.clone();
128
129 let (non_approx_percentile_col_mapping, approx_percentile_col_mapping, approx_percentile) =
130 self.prepare_approx_percentile(&mut core, stream_input.clone())?;
131
132 if core.agg_calls.is_empty() {
133 if let Some(approx_percentile) = approx_percentile {
134 return Ok(approx_percentile);
135 };
136 bail!("expected at least one agg call");
137 }
138 let need_row_merge = Self::need_row_merge(&approx_percentile);
139
140 let input_col_num = stream_input.schema().len(); let project = StreamProject::new(generic::Project::with_vnode_col(stream_input, dist_key));
144 let vnode_col_idx = project.base.schema().len() - 1;
145
146 let mut local_group_key = self.group_key().clone();
148 local_group_key.insert(vnode_col_idx);
149 let n_local_group_key = local_group_key.len();
150 let local_agg = new_stream_hash_agg(
151 Agg::new(core.agg_calls.to_vec(), local_group_key, project.into()),
152 Some(vnode_col_idx),
153 );
154 let local_agg_group_key_cardinality = local_agg.group_key().len();
156 let local_group_key_without_vnode =
157 &local_agg.group_key().to_vec()[..local_agg_group_key_cardinality - 1];
158 let global_group_key = local_agg
159 .i2o_col_mapping()
160 .rewrite_dist_key(local_group_key_without_vnode)
161 .expect("some input group key could not be mapped");
162
163 let global_agg = if self.group_key().is_empty() {
165 let exchange =
166 RequiredDist::single().enforce_if_not_satisfies(local_agg.into(), &Order::any())?;
167 let must_output_per_barrier = need_row_merge;
168 let global_agg = new_stream_simple_agg(
169 Agg::new(
170 core.agg_calls
171 .iter()
172 .enumerate()
173 .map(|(partial_output_idx, agg_call)| {
174 agg_call
175 .partial_to_total_agg_call(n_local_group_key + partial_output_idx)
176 })
177 .collect(),
178 global_group_key.into_iter().collect(),
179 exchange,
180 ),
181 must_output_per_barrier,
182 );
183 global_agg.into()
184 } else {
185 assert!(!need_row_merge);
187 let exchange = RequiredDist::shard_by_key(input_col_num, &global_group_key)
188 .enforce_if_not_satisfies(local_agg.into(), &Order::any())?;
189 let global_agg = new_stream_hash_agg(
192 Agg::new(
193 core.agg_calls
194 .iter()
195 .enumerate()
196 .map(|(partial_output_idx, agg_call)| {
197 agg_call
198 .partial_to_total_agg_call(n_local_group_key + partial_output_idx)
199 })
200 .collect(),
201 global_group_key.into_iter().collect(),
202 exchange,
203 ),
204 None,
205 );
206 global_agg.into()
207 };
208 Self::add_row_merge_if_needed(
209 approx_percentile,
210 global_agg,
211 approx_percentile_col_mapping,
212 non_approx_percentile_col_mapping,
213 )
214 }
215
216 fn gen_single_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
217 let mut core = self.core.clone();
218 let input = RequiredDist::single().enforce_if_not_satisfies(stream_input, &Order::any())?;
219 core.input = input;
220 Ok(new_stream_simple_agg(core, false).into())
221 }
222
223 fn gen_shuffle_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
224 let input =
225 RequiredDist::shard_by_key(stream_input.schema().len(), &self.group_key().to_vec())
226 .enforce_if_not_satisfies(stream_input, &Order::any())?;
227 let mut core = self.core.clone();
228 core.input = input;
229 Ok(new_stream_hash_agg(core, None).into())
230 }
231
232 fn gen_dist_stream_agg_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
234 use super::stream::prelude::*;
235
236 let input_dist = stream_input.distribution();
237 debug_assert!(*input_dist != Distribution::Broadcast);
238
239 if !self.group_key().is_empty() && !self.core.must_try_two_phase_agg() {
243 return self.gen_shuffle_plan(stream_input);
244 }
245
246 if self.group_key().is_empty() && !self.core.can_two_phase_agg() {
249 return self.gen_single_plan(stream_input);
250 }
251
252 debug_assert!(if !self.group_key().is_empty() {
253 self.core.must_try_two_phase_agg()
254 } else {
255 self.core.can_two_phase_agg()
256 });
257
258 if self.group_key().is_empty()
262 && self
263 .core
264 .all_local_aggs_are_stateless(stream_input.append_only())
265 && input_dist.satisfies(&RequiredDist::AnyShard)
266 {
267 return self.gen_stateless_two_phase_streaming_agg_plan(stream_input);
268 }
269
270 let stream_input =
275 if *input_dist == Distribution::SomeShard && self.core.must_try_two_phase_agg() {
276 RequiredDist::shard_by_key(
277 stream_input.schema().len(),
278 stream_input.expect_stream_key(),
279 )
280 .enforce_if_not_satisfies(stream_input, &Order::any())?
281 } else {
282 stream_input
283 };
284 let input_dist = stream_input.distribution();
285
286 match input_dist {
290 Distribution::HashShard(dist_key) | Distribution::UpstreamHashShard(dist_key, _)
291 if (self.group_key().is_empty()
292 || !self.core.hash_agg_dist_satisfied_by_input_dist(input_dist)) =>
293 {
294 let dist_key = dist_key.clone();
295 return self.gen_vnode_two_phase_streaming_agg_plan(stream_input, &dist_key);
296 }
297 _ => {}
298 }
299
300 if !self.group_key().is_empty() {
302 self.gen_shuffle_plan(stream_input)
303 } else {
304 self.gen_single_plan(stream_input)
305 }
306 }
307
308 fn prepare_approx_percentile(
313 &self,
314 core: &mut Agg<PlanRef>,
315 stream_input: PlanRef,
316 ) -> Result<(ColIndexMapping, ColIndexMapping, Option<PlanRef>)> {
317 let SeparatedAggInfo { normal, approx } = self.separate_normal_and_special_agg();
318
319 let AggInfo {
320 calls: non_approx_percentile_agg_calls,
321 col_mapping: non_approx_percentile_col_mapping,
322 } = normal;
323 let AggInfo {
324 calls: approx_percentile_agg_calls,
325 col_mapping: approx_percentile_col_mapping,
326 } = approx;
327 if !self.group_key().is_empty() && !approx_percentile_agg_calls.is_empty() {
328 bail_not_implemented!(
329 "two-phase streaming approx percentile aggregation with group key, \
330 please use single phase aggregation instead"
331 );
332 }
333
334 let needs_row_merge = (!non_approx_percentile_agg_calls.is_empty()
337 && !approx_percentile_agg_calls.is_empty())
338 || approx_percentile_agg_calls.len() >= 2;
339 core.input = if needs_row_merge {
340 StreamShare::new_from_input(stream_input.clone()).into()
342 } else {
343 stream_input
344 };
345 core.agg_calls = non_approx_percentile_agg_calls;
346
347 let approx_percentile =
348 self.build_approx_percentile_aggs(core.input.clone(), &approx_percentile_agg_calls)?;
349 Ok((
350 non_approx_percentile_col_mapping,
351 approx_percentile_col_mapping,
352 approx_percentile,
353 ))
354 }
355
356 fn need_row_merge(approx_percentile: &Option<PlanRef>) -> bool {
357 approx_percentile.is_some()
358 }
359
360 fn add_row_merge_if_needed(
362 approx_percentile: Option<PlanRef>,
363 global_agg: PlanRef,
364 approx_percentile_col_mapping: ColIndexMapping,
365 non_approx_percentile_col_mapping: ColIndexMapping,
366 ) -> Result<PlanRef> {
367 let need_row_merge = Self::need_row_merge(&approx_percentile);
369
370 if let Some(approx_percentile) = approx_percentile {
371 assert!(need_row_merge);
372 let row_merge = StreamRowMerge::new(
373 approx_percentile,
374 global_agg,
375 approx_percentile_col_mapping,
376 non_approx_percentile_col_mapping,
377 )?;
378 Ok(row_merge.into())
379 } else {
380 assert!(!need_row_merge);
381 Ok(global_agg)
382 }
383 }
384
385 fn separate_normal_and_special_agg(&self) -> SeparatedAggInfo {
386 let estimated_len = self.agg_calls().len() - 1;
387 let mut approx_percentile_agg_calls = Vec::with_capacity(estimated_len);
388 let mut non_approx_percentile_agg_calls = Vec::with_capacity(estimated_len);
389 let mut approx_percentile_col_mapping = Vec::with_capacity(estimated_len);
390 let mut non_approx_percentile_col_mapping = Vec::with_capacity(estimated_len);
391 for (output_idx, agg_call) in self.agg_calls().iter().enumerate() {
392 if agg_call.agg_type == AggType::Builtin(PbAggKind::ApproxPercentile) {
393 approx_percentile_agg_calls.push(agg_call.clone());
394 approx_percentile_col_mapping.push(Some(output_idx));
395 } else {
396 non_approx_percentile_agg_calls.push(agg_call.clone());
397 non_approx_percentile_col_mapping.push(Some(output_idx));
398 }
399 }
400 let normal = AggInfo {
401 calls: non_approx_percentile_agg_calls,
402 col_mapping: ColIndexMapping::new(
403 non_approx_percentile_col_mapping,
404 self.agg_calls().len(),
405 ),
406 };
407 let approx = AggInfo {
408 calls: approx_percentile_agg_calls,
409 col_mapping: ColIndexMapping::new(
410 approx_percentile_col_mapping,
411 self.agg_calls().len(),
412 ),
413 };
414 SeparatedAggInfo { normal, approx }
415 }
416
417 fn build_approx_percentile_agg(
418 &self,
419 input: PlanRef,
420 approx_percentile_agg_call: &PlanAggCall,
421 ) -> Result<PlanRef> {
422 let local_approx_percentile =
423 StreamLocalApproxPercentile::new(input, approx_percentile_agg_call);
424 let exchange = RequiredDist::single()
425 .enforce_if_not_satisfies(local_approx_percentile.into(), &Order::any())?;
426 let global_approx_percentile =
427 StreamGlobalApproxPercentile::new(exchange, approx_percentile_agg_call);
428 Ok(global_approx_percentile.into())
429 }
430
431 fn build_approx_percentile_aggs(
444 &self,
445 input: PlanRef,
446 approx_percentile_agg_call: &[PlanAggCall],
447 ) -> Result<Option<PlanRef>> {
448 if approx_percentile_agg_call.is_empty() {
449 return Ok(None);
450 }
451 let approx_percentile_plans: Vec<PlanRef> = approx_percentile_agg_call
452 .iter()
453 .map(|agg_call| self.build_approx_percentile_agg(input.clone(), agg_call))
454 .try_collect()?;
455 assert!(!approx_percentile_plans.is_empty());
456 let mut iter = approx_percentile_plans.into_iter();
457 let mut acc = iter.next().unwrap();
458 for (current_size, plan) in iter.enumerate().map(|(i, p)| (i + 1, p)) {
459 let new_size = current_size + 1;
460 let row_merge = StreamRowMerge::new(
461 acc,
462 plan,
463 ColIndexMapping::identity_or_none(current_size, new_size),
464 ColIndexMapping::new(vec![Some(current_size)], new_size),
465 )
466 .expect("failed to build row merge");
467 acc = row_merge.into();
468 }
469 Ok(Some(acc))
470 }
471
472 pub fn core(&self) -> &Agg<PlanRef> {
473 &self.core
474 }
475}
476
477pub struct LogicalAggBuilder {
482 input_proj_builder: ProjectBuilder,
484 group_key: IndexSet,
486 grouping_sets: Vec<IndexSet>,
488 agg_calls: Vec<PlanAggCall>,
490 error: Option<RwError>,
492 is_in_filter_clause: bool,
498}
499
impl LogicalAggBuilder {
    /// Creates a builder for the given `GROUP BY` clause.
    ///
    /// Every grouping expression is added to the input projection, so the group key refers to
    /// projection output columns. `GROUPING SETS` are used as given.
    /// - `ROLLUP(a, b, ..)` expands to all its prefixes, including the empty set.
    /// - `CUBE(..)` expands to the powerset of its items.
    ///
    /// For grouping sets, the group key is the union of all sets.
    ///
    /// # Errors
    /// Returns a not-implemented error ("... inside GROUP BY") when a grouping expression cannot
    /// be added to the projection.
    fn new(group_by: GroupBy, input_schema_len: usize) -> Result<Self> {
        let mut input_proj_builder = ProjectBuilder::default();

        // Adds every expression of every grouping set to the projection. Returns the union of
        // all sets (the group key) together with the sets themselves.
        let mut gen_group_key_and_grouping_sets =
            |grouping_sets: Vec<Vec<ExprImpl>>| -> Result<(IndexSet, Vec<IndexSet>)> {
                let grouping_sets: Vec<IndexSet> = grouping_sets
                    .into_iter()
                    .map(|set| {
                        set.into_iter()
                            .map(|expr| input_proj_builder.add_expr(&expr))
                            .try_collect()
                            .map_err(|err| not_implemented!("{err} inside GROUP BY"))
                    })
                    .try_collect()?;

                let group_key = grouping_sets
                    .iter()
                    .fold(FixedBitSet::with_capacity(input_schema_len), |acc, x| {
                        acc.union(&x.to_bitset()).collect()
                    });

                Ok((IndexSet::from_iter(group_key.ones()), grouping_sets))
            };

        let (group_key, grouping_sets) = match group_by {
            GroupBy::GroupKey(group_key) => {
                let group_key = group_key
                    .into_iter()
                    .map(|expr| input_proj_builder.add_expr(&expr))
                    .try_collect()
                    .map_err(|err| not_implemented!("{err} inside GROUP BY"))?;
                (group_key, vec![])
            }
            GroupBy::GroupingSets(grouping_sets) => gen_group_key_and_grouping_sets(grouping_sets)?,
            GroupBy::Rollup(rollup) => {
                // Prefixes of length 0..=len of the rollup items.
                let grouping_sets = (0..=rollup.len())
                    .map(|n| {
                        rollup
                            .iter()
                            .take(n)
                            .flat_map(|x| x.iter().cloned())
                            .collect_vec()
                    })
                    .collect_vec();
                gen_group_key_and_grouping_sets(grouping_sets)?
            }
            GroupBy::Cube(cube) => {
                // Every subset of the cube items, each flattened into one expression list.
                let grouping_sets = cube
                    .into_iter()
                    .powerset()
                    .map(|x| x.into_iter().flatten().collect_vec())
                    .collect_vec();
                gen_group_key_and_grouping_sets(grouping_sets)?
            }
        };

        Ok(LogicalAggBuilder {
            group_key,
            grouping_sets,
            agg_calls: vec![],
            error: None,
            input_proj_builder,
            is_in_filter_clause: false,
        })
    }

    /// Builds the input `LogicalProject` from the collected expressions and places the
    /// aggregation on top of it.
    pub fn build(self, input: PlanRef) -> LogicalAgg {
        let logical_project = LogicalProject::with_core(self.input_proj_builder.build(input));

        Agg::new(self.agg_calls, self.group_key, logical_project.into())
            .with_grouping_sets(self.grouping_sets)
            .into()
    }

    /// Rewrites `expr` against the aggregation output.
    ///
    /// Returns any error that the rewriter methods recorded in `self.error`.
    fn rewrite_with_error(&mut self, expr: ExprImpl) -> Result<ExprImpl> {
        let rewritten_expr = self.rewrite_expr(expr);
        if let Some(error) = self.error.take() {
            return Err(error);
        }
        Ok(rewritten_expr)
    }

    /// Checks whether `expr` is a grouping expression.
    ///
    /// If `expr` is already in the input projection and its projected column is part of the
    /// group key, returns that column's position within the group key. That position is also
    /// the output column index, since group key columns come first in the output. Otherwise
    /// returns `None`.
    pub fn try_as_group_expr(&self, expr: &ExprImpl) -> Option<usize> {
        if let Some(input_index) = self.input_proj_builder.expr_index(expr) {
            if let Some(index) = self
                .group_key
                .indices()
                .position(|group_key| group_key == input_index)
            {
                return Some(index);
            }
        }
        None
    }

    /// Output index of the first aggregate call: agg call columns follow the group key.
    fn schema_agg_start_offset(&self) -> usize {
        self.group_key.len()
    }

    /// Rewrites an aggregate call into an expression over simpler aggregate calls.
    ///
    /// Each simpler call is registered through `push_agg_call`, which returns a reference to
    /// that call's output column.
    /// - `AVG(x)` -> `SUM(x)::ret / COUNT(x)`.
    /// - `VAR_POP`/`VAR_SAMP`/`STDDEV_POP`/`STDDEV_SAMP(x)` -> an expression over
    ///   `SUM(x * x)`, `SUM(x)` and `COUNT(x)`, returning NULL when `COUNT(x)` is 0 (pop) or
    ///   `<= 1` (samp).
    /// - `APPROX_PERCENTILE` drops its ORDER BY. With a descending ORDER BY, the percentile
    ///   (first direct arg) is replaced by `1 - p`.
    /// - Every other call is pushed unchanged.
    ///
    /// The order of `push_agg_call` invocations determines the agg output column order.
    pub(crate) fn general_rewrite_agg_call(
        agg_call: AggCall,
        mut push_agg_call: impl FnMut(AggCall) -> Result<InputRef>,
    ) -> Result<ExprImpl> {
        match agg_call.agg_type {
            AggType::Builtin(PbAggKind::Avg) => {
                assert_eq!(agg_call.args.len(), 1);

                let sum = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Sum.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?)
                .cast_explicit(agg_call.return_type())?;

                let count = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Count.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?);

                Ok(FunctionCall::new(ExprType::Divide, Vec::from([sum, count]))?.into())
            }
            AggType::Builtin(
                kind @ (PbAggKind::StddevPop
                | PbAggKind::StddevSamp
                | PbAggKind::VarPop
                | PbAggKind::VarSamp),
            ) => {
                let arg = agg_call.args().iter().exactly_one().unwrap();
                let squared_arg = ExprImpl::from(FunctionCall::new(
                    ExprType::Multiply,
                    vec![arg.clone(), arg.clone()],
                )?);

                // SUM(x * x)
                let sum_of_sq = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Sum.into(),
                    vec![squared_arg],
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?)
                .cast_explicit(agg_call.return_type())?;

                // SUM(x)
                let sum = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Sum.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?)
                .cast_explicit(agg_call.return_type())?;

                // COUNT(x)
                let count = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Count.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?);

                let zero = ExprImpl::literal_int(0);
                let one = ExprImpl::literal_int(1);

                let squared_sum = ExprImpl::from(FunctionCall::new(
                    ExprType::Multiply,
                    vec![sum.clone(), sum],
                )?);

                // SUM(x*x) - SUM(x)^2 / COUNT(x)
                let raw_numerator = ExprImpl::from(FunctionCall::new(
                    ExprType::Subtract,
                    vec![
                        sum_of_sq,
                        ExprImpl::from(FunctionCall::new(
                            ExprType::Divide,
                            vec![squared_sum, count.clone()],
                        )?),
                    ],
                )?);

                // Clamp the numerator at 0, presumably to absorb negative values caused by
                // rounding; the zero literal is cast to the numerator's type.
                let numerator_type = raw_numerator.return_type();
                let numerator = ExprImpl::from(FunctionCall::new(
                    ExprType::Greatest,
                    vec![raw_numerator, zero.clone().cast_explicit(numerator_type)?],
                )?);

                // Population variants divide by COUNT(x), sample variants by COUNT(x) - 1.
                let denominator = match kind {
                    PbAggKind::VarPop | PbAggKind::StddevPop => count.clone(),
                    PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from(
                        FunctionCall::new(ExprType::Subtract, vec![count.clone(), one.clone()])?,
                    ),
                    _ => unreachable!(),
                };

                let mut target = ExprImpl::from(FunctionCall::new(
                    ExprType::Divide,
                    vec![numerator, denominator],
                )?);

                // Standard deviation is the square root of the variance.
                if matches!(kind, PbAggKind::StddevPop | PbAggKind::StddevSamp) {
                    target = ExprImpl::from(FunctionCall::new(ExprType::Sqrt, vec![target])?);
                }

                // CASE WHEN <too few rows> THEN NULL ELSE target
                let null = ExprImpl::from(Literal::new(None, agg_call.return_type()));
                let case_cond = match kind {
                    PbAggKind::VarPop | PbAggKind::StddevPop => {
                        ExprImpl::from(FunctionCall::new(ExprType::Equal, vec![count, zero])?)
                    }
                    PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from(
                        FunctionCall::new(ExprType::LessThanOrEqual, vec![count, one])?,
                    ),
                    _ => unreachable!(),
                };

                Ok(ExprImpl::from(FunctionCall::new(
                    ExprType::Case,
                    vec![case_cond, null, target],
                )?))
            }
            AggType::Builtin(PbAggKind::ApproxPercentile) => {
                // NOTE(review): indexes `sort_exprs[0]` and `direct_args[0..2]` directly;
                // assumes the caller has ensured an ORDER BY and two direct args are present.
                if agg_call.order_by.sort_exprs[0].order_type == OrderType::descending() {
                    // Descending order: compute the (1 - p) percentile in ascending order.
                    let prev_percentile = agg_call.direct_args[0].clone();
                    let new_percentile = 1.0
                        - prev_percentile
                            .get_data()
                            .as_ref()
                            .unwrap()
                            .as_float64()
                            .into_inner();
                    let new_percentile = Some(ScalarImpl::Float64(new_percentile.into()));
                    let new_percentile = Literal::new(new_percentile, DataType::Float64);
                    let new_direct_args = vec![new_percentile, agg_call.direct_args[1].clone()];

                    let new_agg_call = AggCall {
                        order_by: OrderBy::any(),
                        direct_args: new_direct_args,
                        ..agg_call
                    };
                    Ok(push_agg_call(new_agg_call)?.into())
                } else {
                    let new_agg_call = AggCall {
                        order_by: OrderBy::any(),
                        ..agg_call
                    };
                    Ok(push_agg_call(new_agg_call)?.into())
                }
            }
            _ => Ok(push_agg_call(agg_call)?.into()),
        }
    }

    /// Registers `agg_call` and returns a reference to its output column.
    ///
    /// - The FILTER clause is rewritten with `is_in_filter_clause` set, so non-group columns are
    ///   allowed there.
    /// - Arguments and ORDER BY expressions are added to the input projection.
    /// - An identical existing call is reused instead of being added again.
    ///
    /// # Errors
    /// Returns a not-implemented error when an argument or ORDER BY expression cannot be added
    /// to the projection.
    fn push_agg_call(&mut self, agg_call: AggCall) -> Result<InputRef> {
        let AggCall {
            agg_type,
            return_type,
            args,
            distinct,
            order_by,
            filter,
            direct_args,
        } = agg_call;

        self.is_in_filter_clause = true;
        let filter = filter.rewrite_expr(self);
        self.is_in_filter_clause = false;

        let args: Vec<_> = args
            .iter()
            .map(|expr| {
                let index = self.input_proj_builder.add_expr(expr)?;
                Ok(InputRef::new(index, expr.return_type()))
            })
            .try_collect()
            .map_err(|err: &'static str| not_implemented!("{err} inside aggregation calls"))?;

        let order_by: Vec<_> = order_by
            .sort_exprs
            .iter()
            .map(|e| {
                let index = self.input_proj_builder.add_expr(&e.expr)?;
                Ok(ColumnOrder::new(index, e.order_type))
            })
            .try_collect()
            .map_err(|err: &'static str| {
                not_implemented!("{err} inside aggregation calls order by")
            })?;

        let plan_agg_call = PlanAggCall {
            agg_type,
            return_type: return_type.clone(),
            inputs: args,
            distinct,
            order_by,
            filter,
            direct_args,
        };

        // Deduplicate: reuse the output column of an identical existing call.
        if let Some((pos, existing)) = self
            .agg_calls
            .iter()
            .find_position(|&c| c == &plan_agg_call)
        {
            return Ok(InputRef::new(
                self.schema_agg_start_offset() + pos,
                existing.return_type.clone(),
            ));
        }
        let index = self.schema_agg_start_offset() + self.agg_calls.len();
        self.agg_calls.push(plan_agg_call);
        Ok(InputRef::new(index, return_type))
    }

    /// Validates and normalizes `agg_call`, then rewrites it via `general_rewrite_agg_call`.
    ///
    /// - Rejects missing ORDER BY for kinds in `agg_types::must_have_order_by!()`.
    /// - Drops ORDER BY / DISTINCT for kinds whose result they do not affect.
    /// - For `GROUPING`, requires grouping sets, fewer than 32 arguments, and that every
    ///   argument is a grouping expression.
    fn try_rewrite_agg_call(&mut self, mut agg_call: AggCall) -> Result<ExprImpl> {
        if matches!(agg_call.agg_type, agg_types::must_have_order_by!())
            && agg_call.order_by.sort_exprs.is_empty()
        {
            return Err(ErrorCode::InvalidInputSyntax(format!(
                "Aggregation function {} requires ORDER BY clause",
                agg_call.agg_type
            ))
            .into());
        }

        if matches!(
            agg_call.agg_type,
            agg_types::result_unaffected_by_order_by!()
        ) {
            agg_call.order_by = OrderBy::any();
        }
        if matches!(
            agg_call.agg_type,
            agg_types::result_unaffected_by_distinct!()
        ) {
            agg_call.distinct = false;
        }

        if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Grouping)) {
            if self.grouping_sets.is_empty() {
                return Err(ErrorCode::NotSupported(
                    "GROUPING must be used in a query with grouping sets".into(),
                    "try to use grouping sets instead".into(),
                )
                .into());
            }
            if agg_call.args.len() >= 32 {
                return Err(ErrorCode::InvalidInputSyntax(
                    "GROUPING must have fewer than 32 arguments".into(),
                )
                .into());
            }
            if agg_call
                .args
                .iter()
                .any(|x| self.try_as_group_expr(x).is_none())
            {
                return Err(ErrorCode::InvalidInputSyntax(
                    "arguments to GROUPING must be grouping expressions of the associated query level"
                        .into(),
                ).into());
            }
        }

        Self::general_rewrite_agg_call(agg_call, |agg_call| self.push_agg_call(agg_call))
    }
}
906
907impl ExprRewriter for LogicalAggBuilder {
908 fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl {
909 let dummy = Literal::new(None, agg_call.return_type()).into();
910 match self.try_rewrite_agg_call(agg_call) {
911 Ok(expr) => expr,
912 Err(err) => {
913 self.error = Some(err);
914 dummy
915 }
916 }
917 }
918
919 fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
922 let expr = func_call.into();
923 if let Some(group_key) = self.try_as_group_expr(&expr) {
924 InputRef::new(group_key, expr.return_type()).into()
925 } else {
926 let (func_type, inputs, ret) = expr.into_function_call().unwrap().decompose();
927 let inputs = inputs
928 .into_iter()
929 .map(|expr| self.rewrite_expr(expr))
930 .collect();
931 FunctionCall::new_unchecked(func_type, inputs, ret).into()
932 }
933 }
934
935 fn rewrite_window_function(&mut self, window_func: WindowFunction) -> ExprImpl {
938 let WindowFunction {
939 args,
940 partition_by,
941 order_by,
942 ..
943 } = window_func;
944 let args = args
945 .into_iter()
946 .map(|expr| self.rewrite_expr(expr))
947 .collect();
948 let partition_by = partition_by
949 .into_iter()
950 .map(|expr| self.rewrite_expr(expr))
951 .collect();
952 let order_by = order_by.rewrite_expr(self);
953 WindowFunction {
954 args,
955 partition_by,
956 order_by,
957 ..window_func
958 }
959 .into()
960 }
961
962 fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
964 let expr = input_ref.into();
965 if let Some(group_key) = self.try_as_group_expr(&expr) {
966 InputRef::new(group_key, expr.return_type()).into()
967 } else if self.is_in_filter_clause {
968 InputRef::new(
969 self.input_proj_builder.add_expr(&expr).unwrap(),
970 expr.return_type(),
971 )
972 .into()
973 } else {
974 self.error = Some(
975 ErrorCode::InvalidInputSyntax(
976 "column must appear in the GROUP BY clause or be used in an aggregate function"
977 .into(),
978 )
979 .into(),
980 );
981 expr
982 }
983 }
984
985 fn rewrite_subquery(&mut self, subquery: crate::expr::Subquery) -> ExprImpl {
986 if subquery.is_correlated(0) {
987 self.error = Some(
988 not_implemented!(
989 issue = 2275,
990 "correlated subquery in HAVING or SELECT with agg",
991 )
992 .into(),
993 );
994 }
995 subquery.into()
996 }
997}
998
999impl From<Agg<PlanRef>> for LogicalAgg {
1000 fn from(core: Agg<PlanRef>) -> Self {
1001 let base = PlanBase::new_logical_with_core(&core);
1002 Self { base, core }
1003 }
1004}
1005
1006impl From<Agg<PlanRef>> for PlanRef {
1008 fn from(core: Agg<PlanRef>) -> Self {
1009 LogicalAgg::from(core).into()
1010 }
1011}
1012
1013impl LogicalAgg {
1014 pub fn i2o_col_mapping(&self) -> ColIndexMapping {
1016 self.core.i2o_col_mapping()
1017 }
1018
1019 pub fn create(
1028 select_exprs: Vec<ExprImpl>,
1029 group_by: GroupBy,
1030 having: Option<ExprImpl>,
1031 input: PlanRef,
1032 ) -> Result<(PlanRef, Vec<ExprImpl>, Option<ExprImpl>)> {
1033 let mut agg_builder = LogicalAggBuilder::new(group_by, input.schema().len())?;
1034
1035 let rewritten_select_exprs = select_exprs
1036 .into_iter()
1037 .map(|expr| agg_builder.rewrite_with_error(expr))
1038 .collect::<Result<_>>()?;
1039 let rewritten_having = having
1040 .map(|expr| agg_builder.rewrite_with_error(expr))
1041 .transpose()?;
1042
1043 Ok((
1044 agg_builder.build(input).into(),
1045 rewritten_select_exprs,
1046 rewritten_having,
1047 ))
1048 }
1049
1050 pub fn agg_calls(&self) -> &Vec<PlanAggCall> {
1052 &self.core.agg_calls
1053 }
1054
1055 pub fn group_key(&self) -> &IndexSet {
1057 &self.core.group_key
1058 }
1059
1060 pub fn grouping_sets(&self) -> &Vec<IndexSet> {
1061 &self.core.grouping_sets
1062 }
1063
1064 pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) {
1065 self.core.decompose()
1066 }
1067
1068 #[must_use]
1069 pub fn rewrite_with_input_agg(
1070 &self,
1071 input: PlanRef,
1072 agg_calls: &[PlanAggCall],
1073 mut input_col_change: ColIndexMapping,
1074 ) -> (Self, ColIndexMapping) {
1075 let agg_calls = agg_calls
1076 .iter()
1077 .cloned()
1078 .map(|mut agg_call| {
1079 agg_call.inputs.iter_mut().for_each(|i| {
1080 *i = InputRef::new(input_col_change.map(i.index()), i.return_type())
1081 });
1082 agg_call.order_by.iter_mut().for_each(|o| {
1083 o.column_index = input_col_change.map(o.column_index);
1084 });
1085 agg_call.filter = agg_call.filter.rewrite_expr(&mut input_col_change);
1086 agg_call
1087 })
1088 .collect();
1089 let group_key_in_vec: Vec<usize> = self
1091 .group_key()
1092 .indices()
1093 .map(|key| input_col_change.map(key))
1094 .collect();
1095 let group_key: IndexSet = group_key_in_vec.iter().cloned().collect();
1097 let grouping_sets = self
1098 .grouping_sets()
1099 .iter()
1100 .map(|set| set.indices().map(|key| input_col_change.map(key)).collect())
1101 .collect();
1102
1103 let new_agg = Agg::new(agg_calls, group_key.clone(), input)
1104 .with_grouping_sets(grouping_sets)
1105 .with_enable_two_phase(self.core().enable_two_phase);
1106
1107 let mut out_col_change = vec![];
1110 for idx in group_key_in_vec {
1111 let pos = group_key.indices().position(|x| x == idx).unwrap();
1112 out_col_change.push(pos);
1113 }
1114 for i in (group_key.len())..new_agg.schema().len() {
1115 out_col_change.push(i);
1116 }
1117 let out_col_change =
1118 ColIndexMapping::with_remaining_columns(&out_col_change, new_agg.schema().len());
1119
1120 (new_agg.into(), out_col_change)
1121 }
1122}
1123
1124impl PlanTreeNodeUnary for LogicalAgg {
1125 fn input(&self) -> PlanRef {
1126 self.core.input.clone()
1127 }
1128
1129 fn clone_with_input(&self, input: PlanRef) -> Self {
1130 Agg::new(self.agg_calls().to_vec(), self.group_key().clone(), input)
1131 .with_grouping_sets(self.grouping_sets().clone())
1132 .with_enable_two_phase(self.core().enable_two_phase)
1133 .into()
1134 }
1135
1136 fn rewrite_with_input(
1137 &self,
1138 input: PlanRef,
1139 input_col_change: ColIndexMapping,
1140 ) -> (Self, ColIndexMapping) {
1141 self.rewrite_with_input_agg(input, self.agg_calls(), input_col_change)
1142 }
1143}
1144
1145impl_plan_tree_node_for_unary! {LogicalAgg}
1146impl_distill_by_unit!(LogicalAgg, core, "LogicalAgg");
1147
1148impl ExprRewritable for LogicalAgg {
1149 fn has_rewritable_expr(&self) -> bool {
1150 true
1151 }
1152
1153 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
1154 let mut core = self.core.clone();
1155 core.rewrite_exprs(r);
1156 Self {
1157 base: self.base.clone_with_new_plan_id(),
1158 core,
1159 }
1160 .into()
1161 }
1162}
1163
1164impl ExprVisitable for LogicalAgg {
1165 fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
1166 self.core.visit_exprs(v);
1167 }
1168}
1169
1170impl ColPrunable for LogicalAgg {
1171 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
1172 let group_key_required_cols = self.group_key().to_bitset();
1173
1174 let (agg_call_required_cols, agg_calls) = {
1175 let input_cnt = self.input().schema().len();
1176 let mut tmp = FixedBitSet::with_capacity(input_cnt);
1177 let group_key_cardinality = self.group_key().len();
1178 let new_agg_calls = required_cols
1179 .iter()
1180 .filter(|&&index| index >= group_key_cardinality)
1181 .map(|&index| {
1182 let index = index - group_key_cardinality;
1183 let agg_call = self.agg_calls()[index].clone();
1184 tmp.extend(agg_call.inputs.iter().map(|x| x.index()));
1185 tmp.extend(agg_call.order_by.iter().map(|x| x.column_index));
1186 for i in &agg_call.filter.conjunctions {
1188 tmp.union_with(&i.collect_input_refs(input_cnt));
1189 }
1190 agg_call
1191 })
1192 .collect_vec();
1193 (tmp, new_agg_calls)
1194 };
1195
1196 let input_required_cols = {
1197 let mut tmp = FixedBitSet::with_capacity(self.input().schema().len());
1198 tmp.union_with(&group_key_required_cols);
1199 tmp.union_with(&agg_call_required_cols);
1200 tmp.ones().collect_vec()
1201 };
1202 let input_col_change = ColIndexMapping::with_remaining_columns(
1203 &input_required_cols,
1204 self.input().schema().len(),
1205 );
1206 let agg = {
1207 let input = self.input().prune_col(&input_required_cols, ctx);
1208 let (agg, output_col_change) =
1209 self.rewrite_with_input_agg(input, &agg_calls, input_col_change);
1210 assert!(output_col_change.is_identity());
1211 agg
1212 };
1213 let new_output_cols = {
1214 let group_key_cardinality = agg.group_key().len();
1216 let mut tmp = (0..group_key_cardinality).collect_vec();
1217 tmp.extend(
1218 required_cols
1219 .iter()
1220 .filter(|&&index| index >= group_key_cardinality),
1221 );
1222 tmp
1223 };
1224 if new_output_cols == required_cols {
1225 agg.into()
1227 } else {
1228 let mapping =
1231 &ColIndexMapping::with_remaining_columns(&new_output_cols, self.schema().len());
1232 let output_required_cols = required_cols
1233 .iter()
1234 .map(|&idx| mapping.map(idx))
1235 .collect_vec();
1236 let src_size = agg.schema().len();
1237 LogicalProject::with_mapping(
1238 agg.into(),
1239 ColIndexMapping::with_remaining_columns(&output_required_cols, src_size),
1240 )
1241 .into()
1242 }
1243 }
1244}
1245
1246impl PredicatePushdown for LogicalAgg {
1247 fn predicate_pushdown(
1248 &self,
1249 predicate: Condition,
1250 ctx: &mut PredicatePushdownContext,
1251 ) -> PlanRef {
1252 let num_group_key = self.group_key().len();
1253 let num_agg_calls = self.agg_calls().len();
1254 assert!(num_group_key + num_agg_calls == self.schema().len());
1255
1256 if num_group_key == 0 {
1264 return gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx);
1265 }
1266
1267 let mut agg_call_columns = FixedBitSet::with_capacity(num_group_key + num_agg_calls);
1269 agg_call_columns.insert_range(num_group_key..num_group_key + num_agg_calls);
1270 let (agg_call_pred, pushed_predicate) = predicate.split_disjoint(&agg_call_columns);
1271
1272 let mut subst = Substitute {
1274 mapping: self
1275 .group_key()
1276 .indices()
1277 .enumerate()
1278 .map(|(i, group_key)| {
1279 InputRef::new(group_key, self.schema().fields()[i].data_type()).into()
1280 })
1281 .collect(),
1282 };
1283 let pushed_predicate = pushed_predicate.rewrite_expr(&mut subst);
1284
1285 gen_filter_and_pushdown(self, agg_call_pred, pushed_predicate, ctx)
1286 }
1287}
1288
1289impl ToBatch for LogicalAgg {
1290 fn to_batch(&self) -> Result<PlanRef> {
1291 self.to_batch_with_order_required(&Order::any())
1292 }
1293
1294 fn to_batch_with_order_required(&self, required_order: &Order) -> Result<PlanRef> {
1297 let input = self.input().to_batch()?;
1298 let new_logical = Agg {
1299 input,
1300 ..self.core.clone()
1301 };
1302 let agg_plan = if self.group_key().is_empty() {
1303 BatchSimpleAgg::new(new_logical).into()
1304 } else if self.ctx().session_ctx().config().batch_enable_sort_agg()
1305 && new_logical.input_provides_order_on_group_keys()
1306 {
1307 BatchSortAgg::new(new_logical).into()
1308 } else {
1309 BatchHashAgg::new(new_logical).into()
1310 };
1311 required_order.enforce_if_not_satisfies(agg_plan)
1312 }
1313}
1314
1315fn find_or_append_row_count(mut logical: Agg<PlanRef>) -> (Agg<PlanRef>, usize) {
1316 let count_star = PlanAggCall::count_star();
1319 let row_count_idx = if let Some((idx, _)) = logical
1320 .agg_calls
1321 .iter()
1322 .find_position(|&c| c == &count_star)
1323 {
1324 idx
1325 } else {
1326 let idx = logical.agg_calls.len();
1327 logical.agg_calls.push(count_star);
1328 idx
1329 };
1330 (logical, row_count_idx)
1331}
1332
1333fn new_stream_simple_agg(core: Agg<PlanRef>, must_output_per_barrier: bool) -> StreamSimpleAgg {
1334 let (logical, row_count_idx) = find_or_append_row_count(core);
1335 StreamSimpleAgg::new(logical, row_count_idx, must_output_per_barrier)
1336}
1337
1338fn new_stream_hash_agg(core: Agg<PlanRef>, vnode_col_idx: Option<usize>) -> StreamHashAgg {
1339 let (logical, row_count_idx) = find_or_append_row_count(core);
1340 StreamHashAgg::new(logical, vnode_col_idx, row_count_idx)
1341}
1342
1343impl ToStream for LogicalAgg {
1344 fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
1345 use super::stream::prelude::*;
1346
1347 for agg_call in self.agg_calls() {
1348 if matches!(agg_call.agg_type, agg_types::unimplemented_in_stream!()) {
1349 bail_not_implemented!("{} aggregation in materialized view", agg_call.agg_type);
1350 }
1351 }
1352 let eowc = ctx.emit_on_window_close();
1353 let stream_input = self.input().to_stream(ctx)?;
1354
1355 if stream_input.append_only() && self.agg_calls().is_empty() && !self.group_key().is_empty()
1357 {
1358 let input = if self.group_key().len() != self.input().schema().len() {
1359 let cols = &self.group_key().to_vec();
1360 LogicalProject::with_mapping(
1361 self.input(),
1362 ColIndexMapping::with_remaining_columns(cols, self.input().schema().len()),
1363 )
1364 .into()
1365 } else {
1366 self.input()
1367 };
1368 let input_schema_len = input.schema().len();
1369 let logical_dedup = LogicalDedup::new(input, (0..input_schema_len).collect());
1370 return logical_dedup.to_stream(ctx);
1371 }
1372
1373 if self.agg_calls().iter().any(|call| {
1374 matches!(
1375 call.agg_type,
1376 AggType::Builtin(PbAggKind::ApproxCountDistinct)
1377 )
1378 }) {
1379 if stream_input.append_only() {
1380 self.core.ctx().session_ctx().notice_to_user(
1381 "Streaming `APPROX_COUNT_DISTINCT` is still a preview feature and subject to change. Please do not use it in production environment.",
1382 );
1383 } else {
1384 bail_not_implemented!(
1385 "Streaming `APPROX_COUNT_DISTINCT` is only supported in append-only stream"
1386 );
1387 }
1388 }
1389
1390 let plan = self.gen_dist_stream_agg_plan(stream_input)?;
1391
1392 let (plan, n_final_agg_calls) = if let Some(final_agg) = plan.as_stream_simple_agg() {
1393 if eowc {
1394 return Err(ErrorCode::InvalidInputSyntax(
1395 "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1396 .to_owned(),
1397 )
1398 .into());
1399 }
1400 (plan.clone(), final_agg.agg_calls().len())
1401 } else if let Some(final_agg) = plan.as_stream_hash_agg() {
1402 (
1403 if eowc {
1404 final_agg.to_eowc_version()?
1405 } else {
1406 plan.clone()
1407 },
1408 final_agg.agg_calls().len(),
1409 )
1410 } else if let Some(_approx_percentile_agg) = plan.as_stream_global_approx_percentile() {
1411 if eowc {
1412 return Err(ErrorCode::InvalidInputSyntax(
1413 "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1414 .to_owned(),
1415 )
1416 .into());
1417 }
1418 (plan.clone(), 1)
1419 } else if let Some(stream_row_merge) = plan.as_stream_row_merge() {
1420 if eowc {
1421 return Err(ErrorCode::InvalidInputSyntax(
1422 "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1423 .to_owned(),
1424 )
1425 .into());
1426 }
1427 (plan.clone(), stream_row_merge.base.schema().len())
1428 } else {
1429 panic!(
1430 "the root PlanNode must be StreamHashAgg, StreamSimpleAgg, StreamGlobalApproxPercentile, or StreamRowMerge"
1431 );
1432 };
1433
1434 if self.agg_calls().len() == n_final_agg_calls {
1435 Ok(plan)
1437 } else {
1438 assert_eq!(self.agg_calls().len() + 1, n_final_agg_calls);
1440 Ok(StreamProject::new(generic::Project::with_out_col_idx(
1441 plan,
1442 0..self.schema().len(),
1443 ))
1444 .with_noop_update_hint(self.agg_calls().is_empty())
1449 .into())
1450 }
1451 }
1452
1453 fn logical_rewrite_for_stream(
1454 &self,
1455 ctx: &mut RewriteStreamContext,
1456 ) -> Result<(PlanRef, ColIndexMapping)> {
1457 let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
1458 let (agg, out_col_change) = self.rewrite_with_input(input, input_col_change);
1459 let (map, _) = out_col_change.into_parts();
1460 let out_col_change = ColIndexMapping::new(map, agg.schema().len());
1461 Ok((agg.into(), out_col_change))
1462 }
1463}
1464
#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};

    use super::*;
    use crate::expr::{assert_eq_input_ref, input_ref_to_column_indices};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;

    /// Checks how `LogicalAgg::create` splits select / group-by expressions into a group key,
    /// agg calls and rewritten output expressions, over an input `Values(v1, v2, v3)`.
    #[tokio::test]
    async fn test_create() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = vec![
            Field::with_name(ty.clone(), "v1"),
            Field::with_name(ty.clone(), "v2"),
            Field::with_name(ty.clone(), "v3"),
        ];
        let values = LogicalValues::new(vec![], Schema { fields }, ctx);
        let input = PlanRef::from(values);
        let input_ref_1 = InputRef::new(0, ty.clone());
        let input_ref_2 = InputRef::new(1, ty.clone());
        let input_ref_3 = InputRef::new(2, ty.clone());

        // Builds a `LogicalAgg` from the given select and group-by expressions and returns the
        // rewritten select expressions, the agg calls and the group key.
        let gen_internal_value = |select_exprs: Vec<ExprImpl>,
                                  group_exprs|
         -> (Vec<ExprImpl>, Vec<PlanAggCall>, IndexSet) {
            let (plan, exprs, _) = LogicalAgg::create(
                select_exprs,
                GroupBy::GroupKey(group_exprs),
                None,
                input.clone(),
            )
            .unwrap();

            let logical_agg = plan.as_logical_agg().unwrap();
            let agg_calls = logical_agg.agg_calls().to_vec();
            let group_key = logical_agg.group_key().clone();

            (exprs, agg_calls, group_key)
        };

        // Case 1: group key only, i.e. `SELECT v1 ... GROUP BY v1`.
        {
            let select_exprs = vec![input_ref_1.clone().into()];
            let group_exprs = vec![input_ref_1.clone().into()];

            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);

            assert_eq!(exprs.len(), 1);
            assert_eq_input_ref!(&exprs[0], 0);

            assert_eq!(agg_calls.len(), 0);
            assert_eq!(group_key, vec![0].into());
        }

        // Case 2: one agg call, i.e. `SELECT v1, min(v2) ... GROUP BY v1`.
        {
            let min_v2 = AggCall::new(
                PbAggKind::Min.into(),
                vec![input_ref_2.clone().into()],
                false,
                OrderBy::any(),
                Condition::true_cond(),
                vec![],
            )
            .unwrap();
            let select_exprs = vec![input_ref_1.clone().into(), min_v2.into()];
            let group_exprs = vec![input_ref_1.clone().into()];

            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);

            assert_eq!(exprs.len(), 2);
            assert_eq_input_ref!(&exprs[0], 0);
            assert_eq_input_ref!(&exprs[1], 1);

            assert_eq!(agg_calls.len(), 1);
            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
            assert_eq!(group_key, vec![0].into());
        }

        // Case 3: expression over two agg calls, i.e.
        // `SELECT v1, min(v2) + max(v3) ... GROUP BY v1`. The select expression keeps the `+`
        // and references the two agg outputs (columns 1 and 2).
        {
            let min_v2 = AggCall::new(
                PbAggKind::Min.into(),
                vec![input_ref_2.clone().into()],
                false,
                OrderBy::any(),
                Condition::true_cond(),
                vec![],
            )
            .unwrap();
            let max_v3 = AggCall::new(
                PbAggKind::Max.into(),
                vec![input_ref_3.clone().into()],
                false,
                OrderBy::any(),
                Condition::true_cond(),
                vec![],
            )
            .unwrap();
            let func_call =
                FunctionCall::new(ExprType::Add, vec![min_v2.into(), max_v3.into()]).unwrap();
            let select_exprs = vec![input_ref_1.clone().into(), ExprImpl::from(func_call)];
            let group_exprs = vec![input_ref_1.clone().into()];

            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);

            assert_eq_input_ref!(&exprs[0], 0);
            if let ExprImpl::FunctionCall(func_call) = &exprs[1] {
                assert_eq!(func_call.func_type(), ExprType::Add);
                let inputs = func_call.inputs();
                assert_eq_input_ref!(&inputs[0], 1);
                assert_eq_input_ref!(&inputs[1], 2);
            } else {
                panic!("Wrong expression type!");
            }

            assert_eq!(agg_calls.len(), 2);
            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
            assert_eq!(agg_calls[1].agg_type, PbAggKind::Max.into());
            assert_eq!(input_ref_to_column_indices(&agg_calls[1].inputs), vec![2]);
            assert_eq!(group_key, vec![0].into());
        }

        // Case 4: agg over an expression, i.e. `SELECT v2, min(v1 * v3) ... GROUP BY v2`.
        // The asserts expect the group key at agg input column 0 and the `min` argument at
        // column 1, i.e. the agg input is re-indexed (presumably via a pre-projection).
        {
            let v1_mult_v3 = FunctionCall::new(
                ExprType::Multiply,
                vec![input_ref_1.into(), input_ref_3.into()],
            )
            .unwrap();
            let agg_call = AggCall::new(
                PbAggKind::Min.into(),
                vec![v1_mult_v3.into()],
                false,
                OrderBy::any(),
                Condition::true_cond(),
                vec![],
            )
            .unwrap();
            let select_exprs = vec![input_ref_2.clone().into(), agg_call.into()];
            let group_exprs = vec![input_ref_2.into()];

            let (exprs, agg_calls, group_key) = gen_internal_value(select_exprs, group_exprs);

            assert_eq_input_ref!(&exprs[0], 0);
            assert_eq_input_ref!(&exprs[1], 1);

            assert_eq!(agg_calls.len(), 1);
            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
            assert_eq!(group_key, vec![0].into());
        }
    }

    /// Builds the agg used by the pruning tests, given the column type `ty` and input `fields`:
    ///
    /// ```text
    /// Agg(min(input_ref(2))) group by (input_ref(1))
    ///   Values(fields)
    /// ```
    async fn generate_agg_call(ty: DataType, fields: Vec<Field>) -> LogicalAgg {
        let ctx = OptimizerContext::mock().await;

        let values = LogicalValues::new(vec![], Schema { fields }, ctx);
        let agg_call = PlanAggCall {
            agg_type: PbAggKind::Min.into(),
            return_type: ty.clone(),
            inputs: vec![InputRef::new(2, ty.clone())],
            distinct: false,
            order_by: vec![],
            filter: Condition::true_cond(),
            direct_args: vec![],
        };
        Agg::new(vec![agg_call], vec![1].into(), values.into()).into()
    }

    /// Pruning with all output columns required, in order (`[0, 1]`):
    ///
    /// ```text
    /// Agg(min(input_ref(2))) group by (input_ref(1))
    ///   Values(v1, v2, v3)
    /// ```
    /// becomes, with no project on top:
    /// ```text
    /// Agg(min(input_ref(1))) group by (input_ref(0))
    ///   Values(v2, v3)
    /// ```
    #[tokio::test]
    async fn test_prune_all() {
        let ty = DataType::Int32;
        let fields: Vec<Field> = vec![
            Field::with_name(ty.clone(), "v1"),
            Field::with_name(ty.clone(), "v2"),
            Field::with_name(ty.clone(), "v3"),
        ];
        let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).await.into();
        // Required columns are in the agg's natural output order.
        let required_cols = vec![0, 1];
        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));

        // No project is needed: the agg itself is returned.
        let agg_new = plan.as_logical_agg().unwrap();
        assert_eq!(agg_new.group_key(), &vec![0].into());

        assert_eq!(agg_new.agg_calls().len(), 1);
        let agg_call_new = agg_new.agg_calls()[0].clone();
        assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
        assert_eq!(agg_call_new.return_type, ty);

        // `v1` is unused and pruned from the input.
        let values = agg_new.input();
        let values = values.as_logical_values().unwrap();
        assert_eq!(values.schema().fields(), &fields[1..]);
    }

    /// Same agg as `test_prune_all`, but the required columns are reordered (`[1, 0]`), so a
    /// project is expected on top:
    ///
    /// ```text
    /// Project(input_ref(1), input_ref(0))
    ///   Agg(min(input_ref(1))) group by (input_ref(0))
    ///     Values(v2, v3)
    /// ```
    #[tokio::test]
    async fn test_prune_all_with_order_required() {
        let ty = DataType::Int32;
        let fields: Vec<Field> = vec![
            Field::with_name(ty.clone(), "v1"),
            Field::with_name(ty.clone(), "v2"),
            Field::with_name(ty.clone(), "v3"),
        ];
        let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).await.into();
        // Required columns swap the agg's natural output order.
        let required_cols = vec![1, 0];
        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
        let proj = plan.as_logical_project().unwrap();
        // The project restores the requested order.
        assert_eq!(proj.exprs().len(), 2);
        assert_eq!(proj.exprs()[0].as_input_ref().unwrap().index(), 1);
        assert_eq!(proj.exprs()[1].as_input_ref().unwrap().index(), 0);
        let proj_input = proj.input();
        let agg_new = proj_input.as_logical_agg().unwrap();
        assert_eq!(agg_new.group_key(), &vec![0].into());

        assert_eq!(agg_new.agg_calls().len(), 1);
        let agg_call_new = agg_new.agg_calls()[0].clone();
        assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
        assert_eq!(agg_call_new.return_type, ty);

        let values = agg_new.input();
        let values = values.as_logical_values().unwrap();
        assert_eq!(values.schema().fields(), &fields[1..]);
    }

    /// Pruning away the group-key output (only `[1]` required). The group key must survive
    /// inside the agg, so a project drops it on top:
    ///
    /// ```text
    /// Project(input_ref(1))
    ///   Agg(min(input_ref(1))) group by (input_ref(0))
    ///     Values(v2, v3)
    /// ```
    #[tokio::test]
    async fn test_prune_group_key() {
        let ctx = OptimizerContext::mock().await;
        let ty = DataType::Int32;
        let fields: Vec<Field> = vec![
            Field::with_name(ty.clone(), "v1"),
            Field::with_name(ty.clone(), "v2"),
            Field::with_name(ty.clone(), "v3"),
        ];
        let values: LogicalValues = LogicalValues::new(
            vec![],
            Schema {
                fields: fields.clone(),
            },
            ctx,
        );
        let agg_call = PlanAggCall {
            agg_type: PbAggKind::Min.into(),
            return_type: ty.clone(),
            inputs: vec![InputRef::new(2, ty.clone())],
            distinct: false,
            order_by: vec![],
            filter: Condition::true_cond(),
            direct_args: vec![],
        };
        let agg: PlanRef = Agg::new(vec![agg_call], vec![1].into(), values.into()).into();

        // Only the agg call output is required.
        let required_cols = vec![1];
        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));

        let project = plan.as_logical_project().unwrap();
        // The project keeps only the agg call output.
        assert_eq!(project.exprs().len(), 1);
        assert_eq_input_ref!(&project.exprs()[0], 1);

        let agg_new = project.input();
        let agg_new = agg_new.as_logical_agg().unwrap();
        assert_eq!(agg_new.group_key(), &vec![0].into());

        assert_eq!(agg_new.agg_calls().len(), 1);
        let agg_call_new = agg_new.agg_calls()[0].clone();
        assert_eq!(agg_call_new.agg_type, PbAggKind::Min.into());
        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![1]);
        assert_eq!(agg_call_new.return_type, ty);

        let values = agg_new.input();
        let values = values.as_logical_values().unwrap();
        assert_eq!(values.schema().fields(), &fields[1..]);
    }

    /// Pruning one of two agg calls and one of two group-key outputs (`[0, 3]` required):
    ///
    /// ```text
    /// Agg(min(input_ref(2)), max(input_ref(1))) group by (input_ref(1), input_ref(2))
    ///   Values(v1, v2, v3)
    /// ```
    /// becomes:
    /// ```text
    /// Project(input_ref(0), input_ref(2))
    ///   Agg(max(input_ref(0))) group by (input_ref(0), input_ref(1))
    ///     Values(v2, v3)
    /// ```
    #[tokio::test]
    async fn test_prune_agg() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = vec![
            Field::with_name(ty.clone(), "v1"),
            Field::with_name(ty.clone(), "v2"),
            Field::with_name(ty.clone(), "v3"),
        ];
        let values = LogicalValues::new(
            vec![],
            Schema {
                fields: fields.clone(),
            },
            ctx,
        );

        let agg_calls = vec![
            PlanAggCall {
                agg_type: PbAggKind::Min.into(),
                return_type: ty.clone(),
                inputs: vec![InputRef::new(2, ty.clone())],
                distinct: false,
                order_by: vec![],
                filter: Condition::true_cond(),
                direct_args: vec![],
            },
            PlanAggCall {
                agg_type: PbAggKind::Max.into(),
                return_type: ty.clone(),
                inputs: vec![InputRef::new(1, ty.clone())],
                distinct: false,
                order_by: vec![],
                filter: Condition::true_cond(),
                direct_args: vec![],
            },
        ];
        let agg: PlanRef = Agg::new(agg_calls, vec![1, 2].into(), values.into()).into();

        // Output column 0 is the first group key; output column 3 is the `max` call.
        let required_cols = vec![0, 3];
        let plan = agg.prune_col(&required_cols, &mut ColumnPruningContext::new(agg.clone()));
        let project = plan.as_logical_project().unwrap();
        // In the new agg, the kept `max` call sits after both group keys (column 2).
        assert_eq!(project.exprs().len(), 2);
        assert_eq_input_ref!(&project.exprs()[0], 0);
        assert_eq_input_ref!(&project.exprs()[1], 2);

        let agg_new = project.input();
        let agg_new = agg_new.as_logical_agg().unwrap();
        assert_eq!(agg_new.group_key(), &vec![0, 1].into());

        assert_eq!(agg_new.agg_calls().len(), 1);
        let agg_call_new = agg_new.agg_calls()[0].clone();
        assert_eq!(agg_call_new.agg_type, PbAggKind::Max.into());
        assert_eq!(input_ref_to_column_indices(&agg_call_new.inputs), vec![0]);
        assert_eq!(agg_call_new.return_type, ty);

        let values = agg_new.input();
        let values = values.as_logical_values().unwrap();
        assert_eq!(values.schema().fields(), &fields[1..]);
    }
}