1use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::types::{DataType, ScalarImpl};
18use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
19use risingwave_common::{bail, bail_not_implemented, not_implemented};
20use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};
21
22use super::generic::{self, Agg, GenericPlanRef, PlanAggCall, ProjectBuilder};
23use super::utils::impl_distill_by_unit;
24use super::{
25 BatchHashAgg, BatchSimpleAgg, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef,
26 PlanBase, PlanTreeNodeUnary, PredicatePushdown, StreamHashAgg, StreamPlanRef, StreamProject,
27 StreamShare, StreamSimpleAgg, StreamStatelessSimpleAgg, ToBatch, ToStream,
28};
29use crate::error::{ErrorCode, Result, RwError};
30use crate::expr::{
31 AggCall, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, Literal,
32 OrderBy, OrderByExpr, WindowFunction,
33};
34use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
35use crate::optimizer::plan_node::generic::GenericPlanNode;
36use crate::optimizer::plan_node::stream_global_approx_percentile::StreamGlobalApproxPercentile;
37use crate::optimizer::plan_node::stream_local_approx_percentile::StreamLocalApproxPercentile;
38use crate::optimizer::plan_node::stream_row_merge::StreamRowMerge;
39use crate::optimizer::plan_node::{
40 BatchSortAgg, ColumnPruningContext, LogicalDedup, LogicalProject, PredicatePushdownContext,
41 RewriteStreamContext, ToStreamContext, gen_filter_and_pushdown,
42};
43use crate::optimizer::property::{Distribution, Order, RequiredDist};
44use crate::utils::{
45 ColIndexMapping, ColIndexMappingRewriteExt, Condition, GroupBy, IndexSet, Substitute,
46};
47
/// A subset of the aggregate calls of a [`LogicalAgg`], together with a column index mapping.
///
/// Instances are produced by `LogicalAgg::separate_normal_and_special_agg`, which builds the
/// mapping from the positions the calls had in the node's full `agg_calls` list.
pub struct AggInfo {
    /// The selected aggregate calls, in their original relative order.
    pub calls: Vec<PlanAggCall>,
    /// Mapping constructed from the original output index of each entry in `calls`.
    pub col_mapping: ColIndexMapping,
}
52
/// The aggregate calls of a [`LogicalAgg`] split into `approx_percentile` calls and all others.
pub struct SeparatedAggInfo {
    /// Every aggregate call that is not `approx_percentile`.
    normal: AggInfo,
    /// The `approx_percentile` aggregate calls.
    approx: AggInfo,
}
58
/// Logical plan node for aggregation.
///
/// It wraps a generic [`Agg`] core, which carries the aggregate calls, the group key, the
/// grouping sets and the input plan.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct LogicalAgg {
    /// Plan base derived from `core` (see `From<Agg<PlanRef>>`).
    pub base: PlanBase<Logical>,
    /// The generic aggregation description this node wraps.
    core: Agg<PlanRef>,
}
70
71impl LogicalAgg {
72 fn gen_stateless_two_phase_streaming_agg_plan(
75 &self,
76 stream_input: StreamPlanRef,
77 ) -> Result<StreamPlanRef> {
78 debug_assert!(self.group_key().is_empty());
79
80 let (
82 non_approx_percentile_col_mapping,
83 approx_percentile_col_mapping,
84 approx_percentile,
85 core,
86 ) = self.prepare_approx_percentile(stream_input.clone())?;
87
88 if core.agg_calls.is_empty() {
89 if let Some(approx_percentile) = approx_percentile {
90 return Ok(approx_percentile);
91 };
92 bail!("expected at least one agg call");
93 }
94
95 let need_row_merge: bool = Self::need_row_merge(&approx_percentile);
96
97 let total_agg_calls = core
99 .agg_calls
100 .iter()
101 .enumerate()
102 .map(|(partial_output_idx, agg_call)| {
103 agg_call.partial_to_total_agg_call(partial_output_idx)
104 })
105 .collect_vec();
106 let local_agg = StreamStatelessSimpleAgg::new(core)?;
107 let exchange =
108 RequiredDist::single().streaming_enforce_if_not_satisfies(local_agg.into())?;
109
110 let must_output_per_barrier = need_row_merge;
111 let global_agg = new_stream_simple_agg(
112 Agg::new(total_agg_calls, IndexSet::empty(), exchange),
113 must_output_per_barrier,
114 )?;
115
116 Self::add_row_merge_if_needed(
118 approx_percentile,
119 global_agg.into(),
120 approx_percentile_col_mapping,
121 non_approx_percentile_col_mapping,
122 )
123 }
124
    /// Builds a vnode-based two-phase streaming agg plan.
    ///
    /// The input is wrapped in a project created by `Project::with_vnode_col` from
    /// `dist_key`; the last column of that project is used as the vnode column. A local hash
    /// agg groups by the original group key plus the vnode column, and a global agg (simple
    /// agg when there is no group key, hash agg otherwise) combines the partial results.
    ///
    /// `approx_percentile` calls are planned separately by `prepare_approx_percentile` and
    /// combined with the global agg through `add_row_merge_if_needed`.
    ///
    /// Returns an error if there is no aggregate call at all.
    fn gen_vnode_two_phase_streaming_agg_plan(
        &self,
        stream_input: StreamPlanRef,
        dist_key: &[usize],
    ) -> Result<StreamPlanRef> {
        let (
            non_approx_percentile_col_mapping,
            approx_percentile_col_mapping,
            approx_percentile,
            core,
        ) = self.prepare_approx_percentile(stream_input.clone())?;

        // Only `approx_percentile` calls (or nothing): no local/global agg is needed.
        if core.agg_calls.is_empty() {
            if let Some(approx_percentile) = approx_percentile {
                return Ok(approx_percentile);
            };
            bail!("expected at least one agg call");
        }
        let need_row_merge = Self::need_row_merge(&approx_percentile);

        // Read the schema length before `stream_input` is moved into the project.
        let input_col_num = stream_input.schema().len();
        let project = StreamProject::new(generic::Project::with_vnode_col(stream_input, dist_key));
        // The vnode column is taken to be the last column of the project.
        let vnode_col_idx = project.base.schema().len() - 1;

        // Local phase: group by the original group key plus the vnode column.
        let mut local_group_key = self.group_key().clone();
        local_group_key.insert(vnode_col_idx);
        let n_local_group_key = local_group_key.len();
        let local_agg = new_stream_hash_agg(
            Agg::new(core.agg_calls.to_vec(), local_group_key, project.into()),
            Some(vnode_col_idx),
        )?;
        // Drop the last local group key entry (the vnode column, which has the largest index)
        // and translate the rest into output column indices of the local agg.
        let local_agg_group_key_cardinality = local_agg.group_key().len();
        let local_group_key_without_vnode =
            &local_agg.group_key().to_vec()[..local_agg_group_key_cardinality - 1];
        let global_group_key = local_agg
            .i2o_col_mapping()
            .rewrite_dist_key(local_group_key_without_vnode)
            .expect("some input group key could not be mapped");

        // Global phase.
        let global_agg = if self.group_key().is_empty() {
            // No group key: gather everything into a single distribution and use a simple agg.
            let exchange =
                RequiredDist::single().streaming_enforce_if_not_satisfies(local_agg.into())?;
            let must_output_per_barrier = need_row_merge;
            let global_agg = new_stream_simple_agg(
                Agg::new(
                    core.agg_calls
                        .iter()
                        .enumerate()
                        // Partial results follow the local group key columns in the local
                        // agg's output, hence the `n_local_group_key` offset.
                        .map(|(partial_output_idx, agg_call)| {
                            agg_call
                                .partial_to_total_agg_call(n_local_group_key + partial_output_idx)
                        })
                        .collect(),
                    global_group_key.into_iter().collect(),
                    exchange,
                ),
                must_output_per_barrier,
            )?;
            global_agg.into()
        } else {
            // `prepare_approx_percentile` rejects `approx_percentile` calls combined with a
            // group key, so no row merge can be required here.
            assert!(!need_row_merge);
            // With group key: shuffle by the global group key and use a hash agg.
            let exchange = RequiredDist::shard_by_key(input_col_num, &global_group_key)
                .streaming_enforce_if_not_satisfies(local_agg.into())?;
            let global_agg = new_stream_hash_agg(
                Agg::new(
                    core.agg_calls
                        .iter()
                        .enumerate()
                        .map(|(partial_output_idx, agg_call)| {
                            agg_call
                                .partial_to_total_agg_call(n_local_group_key + partial_output_idx)
                        })
                        .collect(),
                    global_group_key.into_iter().collect(),
                    exchange,
                ),
                None,
            )?;
            global_agg.into()
        };
        Self::add_row_merge_if_needed(
            approx_percentile,
            global_agg,
            approx_percentile_col_mapping,
            non_approx_percentile_col_mapping,
        )
    }
223
224 fn gen_single_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
225 let input = RequiredDist::single().streaming_enforce_if_not_satisfies(stream_input)?;
226 let core = self.core.clone_with_input(input);
227 Ok(new_stream_simple_agg(core, false)?.into())
228 }
229
230 fn gen_shuffle_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
231 let input =
232 RequiredDist::shard_by_key(stream_input.schema().len(), &self.group_key().to_vec())
233 .streaming_enforce_if_not_satisfies(stream_input)?;
234 let core = self.core.clone_with_input(input);
235 Ok(new_stream_hash_agg(core, None)?.into())
236 }
237
    /// Chooses a distributed streaming plan for this aggregation.
    ///
    /// Strategies are tried in this order:
    /// 1. shuffle agg, when there is a group key and `must_try_two_phase_agg()` is false;
    /// 2. single agg, when there is no group key and `can_two_phase_agg()` is false;
    /// 3. stateless two-phase simple agg, when there is no group key, all local aggs are
    ///    stateless and the input distribution satisfies `RequiredDist::AnyShard`;
    /// 4. vnode-based two-phase agg, when the input is hash-sharded and either there is no
    ///    group key or the input distribution does not already satisfy the hash agg;
    /// 5. otherwise, shuffle agg (with group key) or single agg (without).
    fn gen_dist_stream_agg_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
        use super::stream::prelude::*;

        let input_dist = stream_input.distribution();
        debug_assert!(*input_dist != Distribution::Broadcast);

        // Shuffle agg: with a group key and no obligation to try two-phase agg, always
        // prefer shuffle over the remaining strategies.
        if !self.group_key().is_empty() && !self.core.must_try_two_phase_agg() {
            return self.gen_shuffle_plan(stream_input);
        }

        // Single agg: without a group key and without two-phase support there is no choice.
        if self.group_key().is_empty() && !self.core.can_two_phase_agg() {
            return self.gen_single_plan(stream_input);
        }

        // From here on, a two-phase plan is either required (group key) or possible (no key).
        debug_assert!(if !self.group_key().is_empty() {
            self.core.must_try_two_phase_agg()
        } else {
            self.core.can_two_phase_agg()
        });

        // Stateless two-phase simple agg.
        if self.group_key().is_empty()
            && self
                .core
                .all_local_aggs_are_stateless(stream_input.append_only())
            && input_dist.satisfies(&RequiredDist::AnyShard)
        {
            return self.gen_stateless_two_phase_streaming_agg_plan(stream_input);
        }

        // A `SomeShard` input that must try two-phase agg is first re-sharded by its stream
        // key, so that the hash-shard match below has a chance to apply.
        let stream_input =
            if *input_dist == Distribution::SomeShard && self.core.must_try_two_phase_agg() {
                RequiredDist::shard_by_key(
                    stream_input.schema().len(),
                    stream_input.expect_stream_key(),
                )
                .streaming_enforce_if_not_satisfies(stream_input)?
            } else {
                stream_input
            };
        let input_dist = stream_input.distribution();

        // Vnode-based two-phase agg, available for hash-sharded inputs.
        match input_dist {
            Distribution::HashShard(dist_key) | Distribution::UpstreamHashShard(dist_key, _)
                if (self.group_key().is_empty()
                    || !self.core.hash_agg_dist_satisfied_by_input_dist(input_dist)) =>
            {
                // Clone the key so the borrow of `stream_input` ends before it is moved.
                let dist_key = dist_key.clone();
                return self.gen_vnode_two_phase_streaming_agg_plan(stream_input, &dist_key);
            }
            _ => {}
        }

        // Fallback when no two-phase plan could be generated.
        if !self.group_key().is_empty() {
            self.gen_shuffle_plan(stream_input)
        } else {
            self.gen_single_plan(stream_input)
        }
    }
313
    /// Splits off the `approx_percentile` calls and plans them on their own.
    ///
    /// Returns, in order:
    /// - the column mapping of the non-`approx_percentile` calls,
    /// - the column mapping of the `approx_percentile` calls,
    /// - the plan computing all `approx_percentile` calls, or `None` if there are none,
    /// - a copy of the agg core over `stream_input` that keeps only the
    ///   non-`approx_percentile` calls.
    ///
    /// Fails with a "not implemented" error when `approx_percentile` calls are combined
    /// with a non-empty group key.
    fn prepare_approx_percentile(
        &self,
        stream_input: StreamPlanRef,
    ) -> Result<(
        ColIndexMapping,
        ColIndexMapping,
        Option<StreamPlanRef>,
        Agg<StreamPlanRef>,
    )> {
        let SeparatedAggInfo { normal, approx } = self.separate_normal_and_special_agg();

        let AggInfo {
            calls: non_approx_percentile_agg_calls,
            col_mapping: non_approx_percentile_col_mapping,
        } = normal;
        let AggInfo {
            calls: approx_percentile_agg_calls,
            col_mapping: approx_percentile_col_mapping,
        } = approx;
        if !self.group_key().is_empty() && !approx_percentile_agg_calls.is_empty() {
            bail_not_implemented!(
                "two-phase streaming approx percentile aggregation with group key, \
                 please use single phase aggregation instead"
            );
        }

        // The input ends up with more than one consumer when both kinds of calls are present,
        // or when there are at least two `approx_percentile` calls.
        let needs_row_merge = (!non_approx_percentile_agg_calls.is_empty()
            && !approx_percentile_agg_calls.is_empty())
            || approx_percentile_agg_calls.len() >= 2;
        let input = if needs_row_merge {
            // Share the input between its consumers.
            StreamShare::new_from_input(stream_input.clone()).into()
        } else {
            stream_input
        };
        let mut core = self.core.clone_with_input(input);
        core.agg_calls = non_approx_percentile_agg_calls;

        // The `approx_percentile` plans read from the same (possibly shared) input as `core`.
        let approx_percentile =
            self.build_approx_percentile_aggs(core.input.clone(), &approx_percentile_agg_calls)?;
        Ok((
            non_approx_percentile_col_mapping,
            approx_percentile_col_mapping,
            approx_percentile,
            core,
        ))
    }
367
    /// Returns `true` if an `approx_percentile` plan exists, i.e. `approx_percentile` is
    /// `Some`. Callers use this to decide whether a row merge will be added.
    fn need_row_merge(approx_percentile: &Option<StreamPlanRef>) -> bool {
        approx_percentile.is_some()
    }
371
372 fn add_row_merge_if_needed(
374 approx_percentile: Option<StreamPlanRef>,
375 global_agg: StreamPlanRef,
376 approx_percentile_col_mapping: ColIndexMapping,
377 non_approx_percentile_col_mapping: ColIndexMapping,
378 ) -> Result<StreamPlanRef> {
379 let need_row_merge = Self::need_row_merge(&approx_percentile);
381
382 if let Some(approx_percentile) = approx_percentile {
383 assert!(need_row_merge);
384 let row_merge = StreamRowMerge::new(
385 approx_percentile,
386 global_agg,
387 approx_percentile_col_mapping,
388 non_approx_percentile_col_mapping,
389 )?;
390 Ok(row_merge.into())
391 } else {
392 assert!(!need_row_merge);
393 Ok(global_agg)
394 }
395 }
396
397 fn separate_normal_and_special_agg(&self) -> SeparatedAggInfo {
398 let estimated_len = self.agg_calls().len() - 1;
399 let mut approx_percentile_agg_calls = Vec::with_capacity(estimated_len);
400 let mut non_approx_percentile_agg_calls = Vec::with_capacity(estimated_len);
401 let mut approx_percentile_col_mapping = Vec::with_capacity(estimated_len);
402 let mut non_approx_percentile_col_mapping = Vec::with_capacity(estimated_len);
403 for (output_idx, agg_call) in self.agg_calls().iter().enumerate() {
404 if agg_call.agg_type == AggType::Builtin(PbAggKind::ApproxPercentile) {
405 approx_percentile_agg_calls.push(agg_call.clone());
406 approx_percentile_col_mapping.push(Some(output_idx));
407 } else {
408 non_approx_percentile_agg_calls.push(agg_call.clone());
409 non_approx_percentile_col_mapping.push(Some(output_idx));
410 }
411 }
412 let normal = AggInfo {
413 calls: non_approx_percentile_agg_calls,
414 col_mapping: ColIndexMapping::new(
415 non_approx_percentile_col_mapping,
416 self.agg_calls().len(),
417 ),
418 };
419 let approx = AggInfo {
420 calls: approx_percentile_agg_calls,
421 col_mapping: ColIndexMapping::new(
422 approx_percentile_col_mapping,
423 self.agg_calls().len(),
424 ),
425 };
426 SeparatedAggInfo { normal, approx }
427 }
428
429 fn build_approx_percentile_agg(
430 &self,
431 input: StreamPlanRef,
432 approx_percentile_agg_call: &PlanAggCall,
433 ) -> Result<StreamPlanRef> {
434 let local_approx_percentile =
435 StreamLocalApproxPercentile::new(input, approx_percentile_agg_call)?;
436 let exchange = RequiredDist::single()
437 .streaming_enforce_if_not_satisfies(local_approx_percentile.into())?;
438 let global_approx_percentile =
439 StreamGlobalApproxPercentile::new(exchange, approx_percentile_agg_call);
440 Ok(global_approx_percentile.into())
441 }
442
    /// Plans all `approx_percentile` calls over `input` and chains them into one plan.
    ///
    /// Returns `Ok(None)` when there are no such calls. Otherwise one plan is built per call
    /// (see `build_approx_percentile_agg`), and the plans are combined left to right with
    /// [`StreamRowMerge`] nodes.
    fn build_approx_percentile_aggs(
        &self,
        input: StreamPlanRef,
        approx_percentile_agg_call: &[PlanAggCall],
    ) -> Result<Option<StreamPlanRef>> {
        if approx_percentile_agg_call.is_empty() {
            return Ok(None);
        }
        // Build every per-call plan first; any failure aborts before merging starts.
        let approx_percentile_plans: Vec<_> = approx_percentile_agg_call
            .iter()
            .map(|agg_call| self.build_approx_percentile_agg(input.clone(), agg_call))
            .try_collect()?;
        assert!(!approx_percentile_plans.is_empty());
        let mut iter = approx_percentile_plans.into_iter();
        let mut acc = iter.next().unwrap();
        // `current_size` counts the plans already merged into `acc`, starting at 1.
        for (current_size, plan) in iter.enumerate().map(|(i, p)| (i + 1, p)) {
            let new_size = current_size + 1;
            let row_merge = StreamRowMerge::new(
                acc,
                plan,
                // Mapping for `acc`, built by `identity_or_none(current_size, new_size)`.
                ColIndexMapping::identity_or_none(current_size, new_size),
                // Mapping for the newly merged plan: its single column goes to `current_size`.
                ColIndexMapping::new(vec![Some(current_size)], new_size),
            )?;
            acc = row_merge.into();
        }
        Ok(Some(acc))
    }
482
    /// Returns a reference to the generic [`Agg`] core wrapped by this node.
    pub fn core(&self) -> &Agg<PlanRef> {
        &self.core
    }
486}
487
/// Builder that collects group keys and aggregate calls while rewriting expressions, and
/// finally produces a [`LogicalAgg`] on top of a `LogicalProject` over the input
/// (see [`LogicalAggBuilder::build`]).
pub struct LogicalAggBuilder {
    /// Builder of the project placed below the agg; group-by expressions and agg call
    /// arguments are added to it.
    input_proj_builder: ProjectBuilder,
    /// Group key, as column indices into the project's output.
    group_key: IndexSet,
    /// Grouping sets, each as column indices into the project's output. Empty for a plain
    /// `GROUP BY` key.
    grouping_sets: Vec<IndexSet>,
    /// Aggregate calls collected so far.
    agg_calls: Vec<PlanAggCall>,
    /// Error recorded during expression rewriting; taken out by `rewrite_with_error`.
    error: Option<RwError>,
    /// Whether the rewriter is currently inside an agg call's filter clause. While set,
    /// input refs that are not group keys are added to the input project instead of being
    /// reported as errors.
    is_in_filter_clause: bool,
}
510
511impl LogicalAggBuilder {
    /// Creates a builder from the `GROUP BY` clause.
    ///
    /// All group-by expressions are added to the input project builder. `ROLLUP` and `CUBE`
    /// are expanded into grouping sets; for any form of grouping sets the group key is the
    /// union of all sets. A plain group key leaves `grouping_sets` empty.
    ///
    /// Returns a "not implemented" error if an expression is rejected by the project builder.
    fn new(group_by: GroupBy, input_schema_len: usize) -> Result<Self> {
        let mut input_proj_builder = ProjectBuilder::default();

        // Registers every expression of every grouping set in the project builder and
        // derives the overall group key from the resulting column indices.
        let mut gen_group_key_and_grouping_sets =
            |grouping_sets: Vec<Vec<ExprImpl>>| -> Result<(IndexSet, Vec<IndexSet>)> {
                let grouping_sets: Vec<IndexSet> = grouping_sets
                    .into_iter()
                    .map(|set| {
                        set.into_iter()
                            .map(|expr| input_proj_builder.add_expr(&expr))
                            .try_collect()
                            .map_err(|err| not_implemented!("{err} inside GROUP BY"))
                    })
                    .try_collect()?;

                // The group key is the union of all grouping sets.
                let group_key = grouping_sets
                    .iter()
                    .fold(FixedBitSet::with_capacity(input_schema_len), |acc, x| {
                        acc.union(&x.to_bitset()).collect()
                    });

                Ok((IndexSet::from_iter(group_key.ones()), grouping_sets))
            };

        let (group_key, grouping_sets) = match group_by {
            GroupBy::GroupKey(group_key) => {
                let group_key = group_key
                    .into_iter()
                    .map(|expr| input_proj_builder.add_expr(&expr))
                    .try_collect()
                    .map_err(|err| not_implemented!("{err} inside GROUP BY"))?;
                (group_key, vec![])
            }
            GroupBy::GroupingSets(grouping_sets) => gen_group_key_and_grouping_sets(grouping_sets)?,
            GroupBy::Rollup(rollup) => {
                // ROLLUP(a, b, c) expands to the prefixes (), (a), (a, b), (a, b, c).
                let grouping_sets = (0..=rollup.len())
                    .map(|n| {
                        rollup
                            .iter()
                            .take(n)
                            .flat_map(|x| x.iter().cloned())
                            .collect_vec()
                    })
                    .collect_vec();
                gen_group_key_and_grouping_sets(grouping_sets)?
            }
            GroupBy::Cube(cube) => {
                // CUBE expands to every subset (the powerset) of its elements.
                let grouping_sets = cube
                    .into_iter()
                    .powerset()
                    .map(|x| x.into_iter().flatten().collect_vec())
                    .collect_vec();
                gen_group_key_and_grouping_sets(grouping_sets)?
            }
        };

        Ok(LogicalAggBuilder {
            group_key,
            grouping_sets,
            agg_calls: vec![],
            error: None,
            input_proj_builder,
            is_in_filter_clause: false,
        })
    }
580
581 pub fn build(self, input: PlanRef) -> LogicalAgg {
582 let logical_project = LogicalProject::with_core(self.input_proj_builder.build(input));
584
585 Agg::new(self.agg_calls, self.group_key, logical_project.into())
587 .with_grouping_sets(self.grouping_sets)
588 .into()
589 }
590
591 fn rewrite_with_error(&mut self, expr: ExprImpl) -> Result<ExprImpl> {
592 let rewritten_expr = self.rewrite_expr(expr);
593 if let Some(error) = self.error.take() {
594 return Err(error);
595 }
596 Ok(rewritten_expr)
597 }
598
599 pub fn try_as_group_expr(&self, expr: &ExprImpl) -> Option<usize> {
601 if let Some(input_index) = self.input_proj_builder.expr_index(expr)
602 && let Some(index) = self
603 .group_key
604 .indices()
605 .position(|group_key| group_key == input_index)
606 {
607 return Some(index);
608 }
609 None
610 }
611
    /// Offset of the first agg call in the agg's output: agg calls are indexed right after
    /// the group key columns.
    fn schema_agg_start_offset(&self) -> usize {
        self.group_key.len()
    }
615
    /// Rewrites `agg_call` into simpler aggregate calls where needed and registers those via
    /// `push_agg_call`, returning the expression that computes the original call's result.
    ///
    /// - `avg` becomes `sum / count`.
    /// - `stddev_pop`, `stddev_samp`, `var_pop` and `var_samp` are expressed with `sum(x * x)`,
    ///   `sum(x)` and `count(x)`.
    /// - `approx_percentile` drops its `ORDER BY`; for descending order the percentile `p`
    ///   is replaced by `1 - p`.
    /// - `arg_min` / `arg_max` become `first_value` ordered by the comparison argument.
    /// - Every other call is pushed unchanged.
    pub(crate) fn general_rewrite_agg_call(
        agg_call: AggCall,
        mut push_agg_call: impl FnMut(AggCall) -> Result<InputRef>,
    ) -> Result<ExprImpl> {
        match agg_call.agg_type {
            // avg(x) => sum(x)::<return type> / count(x)
            AggType::Builtin(PbAggKind::Avg) => {
                assert_eq!(agg_call.args.len(), 1);

                let sum = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Sum.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?)
                .cast_explicit(&agg_call.return_type())?;

                let count = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Count.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?);

                Ok(FunctionCall::new(ExprType::Divide, Vec::from([sum, count]))?.into())
            }
            // With sum_sq = sum(x * x), sum = sum(x), count = count(x):
            //   numerator   = greatest(sum_sq - sum * sum / count, 0)
            //   var_pop     = numerator / count
            //   var_samp    = numerator / (count - 1)
            //   stddev_*    = sqrt(var_*)
            // The result is NULL when count = 0 (pop) or count <= 1 (samp).
            AggType::Builtin(
                kind @ (PbAggKind::StddevPop
                | PbAggKind::StddevSamp
                | PbAggKind::VarPop
                | PbAggKind::VarSamp),
            ) => {
                let arg = agg_call.args().iter().exactly_one().unwrap();
                let squared_arg = ExprImpl::from(FunctionCall::new(
                    ExprType::Multiply,
                    vec![arg.clone(), arg.clone()],
                )?);

                let sum_of_sq = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Sum.into(),
                    vec![squared_arg],
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?)
                .cast_explicit(&agg_call.return_type())?;

                let sum = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Sum.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?)
                .cast_explicit(&agg_call.return_type())?;

                let count = ExprImpl::from(push_agg_call(AggCall::new(
                    PbAggKind::Count.into(),
                    agg_call.args.clone(),
                    agg_call.distinct,
                    agg_call.order_by.clone(),
                    agg_call.filter.clone(),
                    agg_call.direct_args.clone(),
                )?)?);

                let zero = ExprImpl::literal_int(0);
                let one = ExprImpl::literal_int(1);

                let squared_sum = ExprImpl::from(FunctionCall::new(
                    ExprType::Multiply,
                    vec![sum.clone(), sum],
                )?);

                let raw_numerator = ExprImpl::from(FunctionCall::new(
                    ExprType::Subtract,
                    vec![
                        sum_of_sq,
                        ExprImpl::from(FunctionCall::new(
                            ExprType::Divide,
                            vec![squared_sum, count.clone()],
                        )?),
                    ],
                )?);

                // Clamp the numerator at zero, so that it can never be negative
                // (presumably guarding against numerical inaccuracy).
                let numerator_type = raw_numerator.return_type();
                let numerator = ExprImpl::from(FunctionCall::new(
                    ExprType::Greatest,
                    vec![raw_numerator, zero.clone().cast_explicit(&numerator_type)?],
                )?);

                let denominator = match kind {
                    PbAggKind::VarPop | PbAggKind::StddevPop => count.clone(),
                    PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from(
                        FunctionCall::new(ExprType::Subtract, vec![count.clone(), one.clone()])?,
                    ),
                    // `kind` is bound by the pattern above to exactly these four variants.
                    _ => unreachable!(),
                };

                let mut target = ExprImpl::from(FunctionCall::new(
                    ExprType::Divide,
                    vec![numerator, denominator],
                )?);

                // Standard deviation is the square root of the variance.
                if matches!(kind, PbAggKind::StddevPop | PbAggKind::StddevSamp) {
                    target = ExprImpl::from(FunctionCall::new(ExprType::Sqrt, vec![target])?);
                }

                let null = ExprImpl::from(Literal::new(None, agg_call.return_type()));
                // Condition under which the denominator would be zero.
                let case_cond = match kind {
                    PbAggKind::VarPop | PbAggKind::StddevPop => {
                        ExprImpl::from(FunctionCall::new(ExprType::Equal, vec![count, zero])?)
                    }
                    PbAggKind::VarSamp | PbAggKind::StddevSamp => ExprImpl::from(
                        FunctionCall::new(ExprType::LessThanOrEqual, vec![count, one])?,
                    ),
                    _ => unreachable!(),
                };

                // `CASE` over `[cond, NULL, target]`: NULL when the condition holds, `target`
                // otherwise (assuming the usual when/then/else argument order).
                Ok(ExprImpl::from(FunctionCall::new(
                    ExprType::Case,
                    vec![case_cond, null, target],
                )?))
            }
            AggType::Builtin(PbAggKind::ApproxPercentile) => {
                if agg_call.order_by.sort_exprs[0].order_type == OrderType::descending() {
                    // Descending order: replace percentile `p` with `1 - p` and drop the
                    // ordering.
                    let prev_percentile = agg_call.direct_args[0].clone();
                    let new_percentile = 1.0
                        - prev_percentile
                            .get_data()
                            .as_ref()
                            .unwrap()
                            .as_float64()
                            .into_inner();
                    let new_percentile = Some(ScalarImpl::Float64(new_percentile.into()));
                    let new_percentile = Literal::new(new_percentile, DataType::Float64);
                    let new_direct_args = vec![new_percentile, agg_call.direct_args[1].clone()];

                    let new_agg_call = AggCall {
                        order_by: OrderBy::any(),
                        direct_args: new_direct_args,
                        ..agg_call
                    };
                    Ok(push_agg_call(new_agg_call)?.into())
                } else {
                    // Other orderings: only the ordering is dropped.
                    let new_agg_call = AggCall {
                        order_by: OrderBy::any(),
                        ..agg_call
                    };
                    Ok(push_agg_call(new_agg_call)?.into())
                }
            }
            // arg_min(v, c) / arg_max(v, c) => first_value(v ORDER BY c ASC / DESC, <original
            // order by>) with an extra filter requiring every argument to be non-NULL.
            AggType::Builtin(PbAggKind::ArgMin | PbAggKind::ArgMax) => {
                let mut agg_call = agg_call;

                // Reject comparison argument types listed in the error message below.
                let comparison_arg_type = agg_call.args[1].return_type();
                match comparison_arg_type {
                    DataType::Struct(_)
                    | DataType::List(_)
                    | DataType::Map(_)
                    | DataType::Vector(_)
                    | DataType::Jsonb => {
                        bail!(format!(
                            "{} does not support struct, array, map, vector, jsonb for comparison argument, got {}",
                            agg_call.agg_type.to_string(),
                            comparison_arg_type
                        ));
                    }
                    _ => {}
                }

                let not_null_exprs: Vec<ExprImpl> = agg_call
                    .args
                    .iter()
                    .map(|arg| -> Result<ExprImpl> {
                        Ok(FunctionCall::new(ExprType::IsNotNull, vec![arg.clone()])?.into())
                    })
                    .try_collect()?;

                // The comparison argument becomes the leading sort expression.
                let comparison_expr = agg_call.args[1].clone();
                let mut order_exprs = vec![OrderByExpr {
                    expr: comparison_expr,
                    order_type: if agg_call.agg_type == AggType::Builtin(PbAggKind::ArgMin) {
                        OrderType::ascending()
                    } else {
                        OrderType::descending()
                    },
                }];

                // The call's own ORDER BY expressions follow the comparison argument.
                order_exprs.extend(agg_call.order_by.sort_exprs);

                let order_by = OrderBy::new(order_exprs);

                let filter = agg_call.filter.clone().and(Condition {
                    conjunctions: not_null_exprs,
                });

                // Keep only the value argument.
                agg_call.args.truncate(1);

                let new_agg_call = AggCall {
                    agg_type: AggType::Builtin(PbAggKind::FirstValue),
                    order_by,
                    filter,
                    ..agg_call
                };
                Ok(push_agg_call(new_agg_call)?.into())
            }
            _ => Ok(push_agg_call(agg_call)?.into()),
        }
    }
843
    /// Registers `agg_call` in the builder and returns an `InputRef` to its output column.
    ///
    /// Arguments and `ORDER BY` expressions are added to the input project; the filter is
    /// rewritten by this builder. If an identical plan agg call was already registered, an
    /// `InputRef` to the existing one is returned and nothing new is added.
    ///
    /// Returns a "not implemented" error if the project builder rejects an expression.
    fn push_agg_call(&mut self, agg_call: AggCall) -> Result<InputRef> {
        let AggCall {
            agg_type,
            return_type,
            args,
            distinct,
            order_by,
            filter,
            direct_args,
        } = agg_call;

        // While the flag is set, `rewrite_input_ref` adds non-group-key columns to the input
        // project instead of reporting an error.
        self.is_in_filter_clause = true;
        let filter = filter.rewrite_expr(self);
        self.is_in_filter_clause = false;

        let args: Vec<_> = args
            .iter()
            .map(|expr| {
                let index = self.input_proj_builder.add_expr(expr)?;
                Ok(InputRef::new(index, expr.return_type()))
            })
            .try_collect()
            .map_err(|err: &'static str| not_implemented!("{err} inside aggregation calls"))?;

        let order_by: Vec<_> = order_by
            .sort_exprs
            .iter()
            .map(|e| {
                let index = self.input_proj_builder.add_expr(&e.expr)?;
                Ok(ColumnOrder::new(index, e.order_type))
            })
            .try_collect()
            .map_err(|err: &'static str| {
                not_implemented!("{err} inside aggregation calls order by")
            })?;

        let plan_agg_call = PlanAggCall {
            agg_type,
            return_type: return_type.clone(),
            inputs: args,
            distinct,
            order_by,
            filter,
            direct_args,
        };

        // Deduplicate: reuse the output column of an identical, already registered call.
        if let Some((pos, existing)) = self
            .agg_calls
            .iter()
            .find_position(|&c| c == &plan_agg_call)
        {
            return Ok(InputRef::new(
                self.schema_agg_start_offset() + pos,
                existing.return_type.clone(),
            ));
        }
        let index = self.schema_agg_start_offset() + self.agg_calls.len();
        self.agg_calls.push(plan_agg_call);
        Ok(InputRef::new(index, return_type))
    }
909
    /// Validates and normalizes `agg_call`, then rewrites and registers it through
    /// `general_rewrite_agg_call` and `push_agg_call`.
    ///
    /// Returns an error when
    /// - the aggregate requires `ORDER BY` but none is given, or
    /// - `GROUPING` is used without grouping sets, with 32 or more arguments, or with an
    ///   argument that is not a grouping expression.
    fn try_rewrite_agg_call(&mut self, mut agg_call: AggCall) -> Result<ExprImpl> {
        if matches!(agg_call.agg_type, agg_types::must_have_order_by!())
            && agg_call.order_by.sort_exprs.is_empty()
        {
            return Err(ErrorCode::InvalidInputSyntax(format!(
                "Aggregation function {} requires ORDER BY clause",
                agg_call.agg_type
            ))
            .into());
        }

        // Drop an `ORDER BY` that cannot influence the result.
        if matches!(
            agg_call.agg_type,
            agg_types::result_unaffected_by_order_by!()
        ) {
            agg_call.order_by = OrderBy::any();
        }
        // Drop a `DISTINCT` that cannot influence the result.
        if matches!(
            agg_call.agg_type,
            agg_types::result_unaffected_by_distinct!()
        ) {
            agg_call.distinct = false;
        }

        if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Grouping)) {
            if self.grouping_sets.is_empty() {
                return Err(ErrorCode::NotSupported(
                    "GROUPING must be used in a query with grouping sets".into(),
                    "try to use grouping sets instead".into(),
                )
                .into());
            }
            if agg_call.args.len() >= 32 {
                return Err(ErrorCode::InvalidInputSyntax(
                    "GROUPING must have fewer than 32 arguments".into(),
                )
                .into());
            }
            if agg_call
                .args
                .iter()
                .any(|x| self.try_as_group_expr(x).is_none())
            {
                return Err(ErrorCode::InvalidInputSyntax(
                    "arguments to GROUPING must be grouping expressions of the associated query level"
                        .into(),
                ).into());
            }
        }

        Self::general_rewrite_agg_call(agg_call, |agg_call| self.push_agg_call(agg_call))
    }
970}
971
/// Rewrites select/having expressions so that they refer to the output of the aggregation.
///
/// The trait methods cannot return errors, so failures are stored in `self.error` and are
/// expected to be picked up by `rewrite_with_error`.
impl ExprRewriter for LogicalAggBuilder {
    /// Replaces an agg call with the expression produced by `try_rewrite_agg_call`.
    ///
    /// On failure the error is recorded and a NULL literal of the call's return type is
    /// returned as a placeholder.
    fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl {
        let dummy = Literal::new(None, agg_call.return_type()).into();
        match self.try_rewrite_agg_call(agg_call) {
            Ok(expr) => expr,
            Err(err) => {
                self.error = Some(err);
                dummy
            }
        }
    }

    /// A function call that is itself a group-by expression becomes an `InputRef` to the
    /// group key column; otherwise its inputs are rewritten recursively.
    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
        let expr = func_call.into();
        if let Some(group_key) = self.try_as_group_expr(&expr) {
            InputRef::new(group_key, expr.return_type()).into()
        } else {
            let (func_type, inputs, ret) = expr.into_function_call().unwrap().decompose();
            let inputs = inputs
                .into_iter()
                .map(|expr| self.rewrite_expr(expr))
                .collect();
            FunctionCall::new_unchecked(func_type, inputs, ret).into()
        }
    }

    /// Rewrites the arguments, `PARTITION BY` and `ORDER BY` expressions of a window
    /// function, keeping its remaining fields unchanged.
    fn rewrite_window_function(&mut self, window_func: WindowFunction) -> ExprImpl {
        let WindowFunction {
            args,
            partition_by,
            order_by,
            ..
        } = window_func;
        let args = args
            .into_iter()
            .map(|expr| self.rewrite_expr(expr))
            .collect();
        let partition_by = partition_by
            .into_iter()
            .map(|expr| self.rewrite_expr(expr))
            .collect();
        let order_by = order_by.rewrite_expr(self);
        WindowFunction {
            args,
            partition_by,
            order_by,
            ..window_func
        }
        .into()
    }

    /// An input ref must refer to a group key column, in which case it is redirected to
    /// that column. Inside an agg filter clause any other column is added to the input
    /// project; elsewhere it is an error.
    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
        let expr = input_ref.into();
        if let Some(group_key) = self.try_as_group_expr(&expr) {
            InputRef::new(group_key, expr.return_type()).into()
        } else if self.is_in_filter_clause {
            InputRef::new(
                self.input_proj_builder.add_expr(&expr).unwrap(),
                expr.return_type(),
            )
            .into()
        } else {
            self.error = Some(
                ErrorCode::InvalidInputSyntax(
                    "column must appear in the GROUP BY clause or be used in an aggregate function"
                        .into(),
                )
                .into(),
            );
            expr
        }
    }

    /// Subqueries are returned unchanged; one that is correlated at depth 0 is recorded as
    /// a "not implemented" error.
    fn rewrite_subquery(&mut self, subquery: crate::expr::Subquery) -> ExprImpl {
        if subquery.is_correlated_by_depth(0) {
            self.error = Some(
                not_implemented!(
                    issue = 2275,
                    "correlated subquery in HAVING or SELECT with agg",
                )
                .into(),
            );
        }
        subquery.into()
    }
}
1063
impl From<Agg<PlanRef>> for LogicalAgg {
    /// Wraps a generic agg core into a [`LogicalAgg`], deriving the plan base from the core.
    fn from(core: Agg<PlanRef>) -> Self {
        let base = PlanBase::new_logical_with_core(&core);
        Self { base, core }
    }
}
1070
impl From<Agg<PlanRef>> for PlanRef {
    /// Converts a generic agg core into a plan reference by way of [`LogicalAgg`].
    fn from(core: Agg<PlanRef>) -> Self {
        LogicalAgg::from(core).into()
    }
}
1077
1078impl LogicalAgg {
    /// Returns the input-to-output column index mapping of the underlying agg core.
    pub fn i2o_col_mapping(&self) -> ColIndexMapping {
        self.core.i2o_col_mapping()
    }
1083
1084 pub fn create(
1093 select_exprs: Vec<ExprImpl>,
1094 group_by: GroupBy,
1095 having: Option<ExprImpl>,
1096 input: PlanRef,
1097 ) -> Result<(PlanRef, Vec<ExprImpl>, Option<ExprImpl>)> {
1098 let mut agg_builder = LogicalAggBuilder::new(group_by, input.schema().len())?;
1099
1100 let rewritten_select_exprs = select_exprs
1101 .into_iter()
1102 .map(|expr| agg_builder.rewrite_with_error(expr))
1103 .collect::<Result<_>>()?;
1104 let rewritten_having = having
1105 .map(|expr| agg_builder.rewrite_with_error(expr))
1106 .transpose()?;
1107
1108 Ok((
1109 agg_builder.build(input).into(),
1110 rewritten_select_exprs,
1111 rewritten_having,
1112 ))
1113 }
1114
    /// Returns the aggregate calls of this node.
    pub fn agg_calls(&self) -> &Vec<PlanAggCall> {
        &self.core.agg_calls
    }
1119
    /// Returns the group key of this node, as a set of input column indices.
    pub fn group_key(&self) -> &IndexSet {
        &self.core.group_key
    }
1124
    /// Returns the grouping sets of this node.
    pub fn grouping_sets(&self) -> &Vec<IndexSet> {
        &self.core.grouping_sets
    }
1128
    /// Consumes the node and returns the parts of its core, as produced by `Agg::decompose`:
    /// aggregate calls, group key, grouping sets, input plan and a boolean flag
    /// (presumably `enable_two_phase` — confirm against `Agg::decompose`).
    pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) {
        self.core.decompose()
    }
1132
    /// Rebuilds this agg on top of a new `input`, using the given `agg_calls` and remapping
    /// every input column index through `input_col_change`.
    ///
    /// Returns the new node together with the mapping from this node's output columns to
    /// the new node's output columns.
    #[must_use]
    pub fn rewrite_with_input_agg(
        &self,
        input: PlanRef,
        agg_calls: &[PlanAggCall],
        mut input_col_change: ColIndexMapping,
    ) -> (Self, ColIndexMapping) {
        // Remap the inputs, the order-by columns and the filter of every agg call.
        let agg_calls = agg_calls
            .iter()
            .cloned()
            .map(|mut agg_call| {
                agg_call.inputs.iter_mut().for_each(|i| {
                    *i = InputRef::new(input_col_change.map(i.index()), i.return_type())
                });
                agg_call.order_by.iter_mut().for_each(|o| {
                    o.column_index = input_col_change.map(o.column_index);
                });
                agg_call.filter = agg_call.filter.rewrite_expr(&mut input_col_change);
                agg_call
            })
            .collect();
        // Remapped group key, in the iteration order of the old group key.
        let group_key_in_vec: Vec<usize> = self
            .group_key()
            .indices()
            .map(|key| input_col_change.map(key))
            .collect();
        // The same keys collected into an `IndexSet`, whose iteration order may differ.
        let group_key: IndexSet = group_key_in_vec.iter().cloned().collect();
        let grouping_sets = self
            .grouping_sets()
            .iter()
            .map(|set| set.indices().map(|key| input_col_change.map(key)).collect())
            .collect();

        let new_agg = Agg::new(agg_calls, group_key.clone(), input)
            .with_grouping_sets(grouping_sets)
            .with_enable_two_phase(self.core().enable_two_phase);

        // Remapping may reorder the group key columns in the output, so look up where each
        // old group key column ended up in the new group key.
        let mut out_col_change = vec![];
        for idx in group_key_in_vec {
            let pos = group_key.indices().position(|x| x == idx).unwrap();
            out_col_change.push(pos);
        }
        // The columns after the group key keep their positions.
        for i in (group_key.len())..new_agg.schema().len() {
            out_col_change.push(i);
        }
        let out_col_change =
            ColIndexMapping::with_remaining_columns(&out_col_change, new_agg.schema().len());

        (new_agg.into(), out_col_change)
    }
1187}
1188
impl PlanTreeNodeUnary<Logical> for LogicalAgg {
    /// Returns the input plan of this node.
    fn input(&self) -> PlanRef {
        self.core.input.clone()
    }

    /// Returns a copy of this node over a different `input`, keeping the aggregate calls,
    /// group key, grouping sets and the `enable_two_phase` setting.
    fn clone_with_input(&self, input: PlanRef) -> Self {
        Agg::new(self.agg_calls().to_vec(), self.group_key().clone(), input)
            .with_grouping_sets(self.grouping_sets().clone())
            .with_enable_two_phase(self.core().enable_two_phase)
            .into()
    }

    /// Rebuilds this node over `input` with its own aggregate calls, remapping input column
    /// indices through `input_col_change` (see `rewrite_with_input_agg`).
    fn rewrite_with_input(
        &self,
        input: PlanRef,
        input_col_change: ColIndexMapping,
    ) -> (Self, ColIndexMapping) {
        self.rewrite_with_input_agg(input, self.agg_calls(), input_col_change)
    }
}
1209
// NOTE(review): presumably derives the generic plan tree node impl from the
// `PlanTreeNodeUnary` impl above — confirm against the macro definition.
impl_plan_tree_node_for_unary! { Logical, LogicalAgg }
// NOTE(review): presumably implements the explain/distill output by delegating to `core`
// under the name "LogicalAgg" — confirm against the macro definition.
impl_distill_by_unit!(LogicalAgg, core, "LogicalAgg");
1212
1213impl ExprRewritable<Logical> for LogicalAgg {
1214 fn has_rewritable_expr(&self) -> bool {
1215 true
1216 }
1217
1218 fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
1219 let mut core = self.core.clone();
1220 core.rewrite_exprs(r);
1221 Self {
1222 base: self.base.clone_with_new_plan_id(),
1223 core,
1224 }
1225 .into()
1226 }
1227}
1228
/// Read-only expression traversal for [`LogicalAgg`].
impl ExprVisitable for LogicalAgg {
    /// Forwards the visitor `v` to the core aggregation node.
    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
        self.core.visit_exprs(v);
    }
}
1234
1235impl ColPrunable for LogicalAgg {
1236 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
1237 let group_key_required_cols = self.group_key().to_bitset();
1238
1239 let (agg_call_required_cols, agg_calls) = {
1240 let input_cnt = self.input().schema().len();
1241 let mut tmp = FixedBitSet::with_capacity(input_cnt);
1242 let group_key_cardinality = self.group_key().len();
1243 let new_agg_calls = required_cols
1244 .iter()
1245 .filter(|&&index| index >= group_key_cardinality)
1246 .map(|&index| {
1247 let index = index - group_key_cardinality;
1248 let agg_call = self.agg_calls()[index].clone();
1249 tmp.extend(agg_call.inputs.iter().map(|x| x.index()));
1250 tmp.extend(agg_call.order_by.iter().map(|x| x.column_index));
1251 for i in &agg_call.filter.conjunctions {
1253 tmp.union_with(&i.collect_input_refs(input_cnt));
1254 }
1255 agg_call
1256 })
1257 .collect_vec();
1258 (tmp, new_agg_calls)
1259 };
1260
1261 let input_required_cols = {
1262 let mut tmp = FixedBitSet::with_capacity(self.input().schema().len());
1263 tmp.union_with(&group_key_required_cols);
1264 tmp.union_with(&agg_call_required_cols);
1265 tmp.ones().collect_vec()
1266 };
1267 let input_col_change = ColIndexMapping::with_remaining_columns(
1268 &input_required_cols,
1269 self.input().schema().len(),
1270 );
1271 let agg = {
1272 let input = self.input().prune_col(&input_required_cols, ctx);
1273 let (agg, output_col_change) =
1274 self.rewrite_with_input_agg(input, &agg_calls, input_col_change);
1275 assert!(output_col_change.is_identity());
1276 agg
1277 };
1278 let new_output_cols = {
1279 let group_key_cardinality = agg.group_key().len();
1281 let mut tmp = (0..group_key_cardinality).collect_vec();
1282 tmp.extend(
1283 required_cols
1284 .iter()
1285 .filter(|&&index| index >= group_key_cardinality),
1286 );
1287 tmp
1288 };
1289 if new_output_cols == required_cols {
1290 agg.into()
1292 } else {
1293 let mapping =
1296 &ColIndexMapping::with_remaining_columns(&new_output_cols, self.schema().len());
1297 let output_required_cols = required_cols
1298 .iter()
1299 .map(|&idx| mapping.map(idx))
1300 .collect_vec();
1301 let src_size = agg.schema().len();
1302 LogicalProject::with_mapping(
1303 agg.into(),
1304 ColIndexMapping::with_remaining_columns(&output_required_cols, src_size),
1305 )
1306 .into()
1307 }
1308 }
1309}
1310
1311impl PredicatePushdown for LogicalAgg {
1312 fn predicate_pushdown(
1313 &self,
1314 predicate: Condition,
1315 ctx: &mut PredicatePushdownContext,
1316 ) -> PlanRef {
1317 let num_group_key = self.group_key().len();
1318 let num_agg_calls = self.agg_calls().len();
1319 assert!(num_group_key + num_agg_calls == self.schema().len());
1320
1321 if num_group_key == 0 {
1329 return gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx);
1330 }
1331
1332 let mut agg_call_columns = FixedBitSet::with_capacity(num_group_key + num_agg_calls);
1334 agg_call_columns.insert_range(num_group_key..num_group_key + num_agg_calls);
1335 let (agg_call_pred, pushed_predicate) = predicate.split_disjoint(&agg_call_columns);
1336
1337 let mut subst = Substitute {
1339 mapping: self
1340 .group_key()
1341 .indices()
1342 .enumerate()
1343 .map(|(i, group_key)| {
1344 InputRef::new(group_key, self.schema().fields()[i].data_type()).into()
1345 })
1346 .collect(),
1347 };
1348 let pushed_predicate = pushed_predicate.rewrite_expr(&mut subst);
1349
1350 gen_filter_and_pushdown(self, agg_call_pred, pushed_predicate, ctx)
1351 }
1352}
1353
1354impl ToBatch for LogicalAgg {
1355 fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1356 self.to_batch_with_order_required(&Order::any())
1357 }
1358
1359 fn to_batch_with_order_required(
1362 &self,
1363 required_order: &Order,
1364 ) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
1365 let input = self.input().to_batch()?;
1366 let new_logical = self.core.clone_with_input(input);
1367 let agg_plan = if self.group_key().is_empty() {
1368 BatchSimpleAgg::new(new_logical).into()
1369 } else if self.ctx().session_ctx().config().batch_enable_sort_agg()
1370 && new_logical.input_provides_order_on_group_keys()
1371 {
1372 BatchSortAgg::new(new_logical).into()
1373 } else {
1374 BatchHashAgg::new(new_logical).into()
1375 };
1376 required_order.enforce_if_not_satisfies(agg_plan)
1377 }
1378}
1379
1380fn find_or_append_row_count(mut logical: Agg<StreamPlanRef>) -> (Agg<StreamPlanRef>, usize) {
1381 let count_star = PlanAggCall::count_star();
1384 let row_count_idx = if let Some((idx, _)) = logical
1385 .agg_calls
1386 .iter()
1387 .find_position(|&c| c == &count_star)
1388 {
1389 idx
1390 } else {
1391 let idx = logical.agg_calls.len();
1392 logical.agg_calls.push(count_star);
1393 idx
1394 };
1395 (logical, row_count_idx)
1396}
1397
1398fn new_stream_simple_agg(
1399 core: Agg<StreamPlanRef>,
1400 must_output_per_barrier: bool,
1401) -> Result<StreamSimpleAgg> {
1402 let (logical, row_count_idx) = find_or_append_row_count(core);
1403 StreamSimpleAgg::new(logical, row_count_idx, must_output_per_barrier)
1404}
1405
1406fn new_stream_hash_agg(
1407 core: Agg<StreamPlanRef>,
1408 vnode_col_idx: Option<usize>,
1409) -> Result<StreamHashAgg> {
1410 let (logical, row_count_idx) = find_or_append_row_count(core);
1411 StreamHashAgg::new(logical, vnode_col_idx, row_count_idx)
1412}
1413
1414impl ToStream for LogicalAgg {
1415 fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<StreamPlanRef> {
1416 use super::stream::prelude::*;
1417
1418 for agg_call in self.agg_calls() {
1419 if matches!(agg_call.agg_type, agg_types::unimplemented_in_stream!()) {
1420 bail_not_implemented!("{} aggregation in materialized view", agg_call.agg_type);
1421 }
1422 }
1423 let eowc = ctx.emit_on_window_close();
1424 let input = self
1425 .input()
1426 .try_better_locality(&self.group_key().to_vec())
1427 .unwrap_or_else(|| self.input());
1428 let stream_input = input.to_stream(ctx)?;
1429
1430 if stream_input.append_only() && self.agg_calls().is_empty() && !self.group_key().is_empty()
1432 {
1433 let input = if self.group_key().len() != self.input().schema().len() {
1434 let cols = &self.group_key().to_vec();
1435 LogicalProject::with_mapping(
1436 self.input(),
1437 ColIndexMapping::with_remaining_columns(cols, self.input().schema().len()),
1438 )
1439 .into()
1440 } else {
1441 self.input()
1442 };
1443 let input_schema_len = input.schema().len();
1444 let logical_dedup = LogicalDedup::new(input, (0..input_schema_len).collect());
1445 return logical_dedup.to_stream(ctx);
1446 }
1447
1448 if self.agg_calls().iter().any(|call| {
1449 matches!(
1450 call.agg_type,
1451 AggType::Builtin(PbAggKind::ApproxCountDistinct)
1452 )
1453 }) {
1454 if stream_input.append_only() {
1455 self.core.ctx().session_ctx().notice_to_user(
1456 "Streaming `APPROX_COUNT_DISTINCT` is still a preview feature and subject to change. Please do not use it in production environment.",
1457 );
1458 } else {
1459 bail_not_implemented!(
1460 "Streaming `APPROX_COUNT_DISTINCT` is only supported in append-only stream"
1461 );
1462 }
1463 }
1464
1465 let plan = self.gen_dist_stream_agg_plan(stream_input)?;
1466
1467 let (plan, n_final_agg_calls) = if let Some(final_agg) = plan.as_stream_simple_agg() {
1468 if eowc {
1469 return Err(ErrorCode::InvalidInputSyntax(
1470 "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1471 .to_owned(),
1472 )
1473 .into());
1474 }
1475 (plan.clone(), final_agg.agg_calls().len())
1476 } else if let Some(final_agg) = plan.as_stream_hash_agg() {
1477 (
1478 if eowc {
1479 final_agg.to_eowc_version()?
1480 } else {
1481 plan.clone()
1482 },
1483 final_agg.agg_calls().len(),
1484 )
1485 } else if let Some(_approx_percentile_agg) = plan.as_stream_global_approx_percentile() {
1486 if eowc {
1487 return Err(ErrorCode::InvalidInputSyntax(
1488 "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1489 .to_owned(),
1490 )
1491 .into());
1492 }
1493 (plan.clone(), 1)
1494 } else if let Some(stream_row_merge) = plan.as_stream_row_merge() {
1495 if eowc {
1496 return Err(ErrorCode::InvalidInputSyntax(
1497 "`EMIT ON WINDOW CLOSE` cannot be used for aggregation without `GROUP BY`"
1498 .to_owned(),
1499 )
1500 .into());
1501 }
1502 (plan.clone(), stream_row_merge.base.schema().len())
1503 } else {
1504 panic!(
1505 "the root PlanNode must be StreamHashAgg, StreamSimpleAgg, StreamGlobalApproxPercentile, or StreamRowMerge"
1506 );
1507 };
1508
1509 if self.agg_calls().len() == n_final_agg_calls {
1510 Ok(plan)
1512 } else {
1513 assert_eq!(self.agg_calls().len() + 1, n_final_agg_calls);
1515
1516 let mut project = StreamProject::new(generic::Project::with_out_col_idx(
1517 plan,
1518 0..self.schema().len(),
1519 ));
1520 if self.agg_calls().is_empty() {
1525 project = project.with_noop_update_hint(true);
1526 }
1527 Ok(project.into())
1528 }
1529 }
1530
1531 fn logical_rewrite_for_stream(
1532 &self,
1533 ctx: &mut RewriteStreamContext,
1534 ) -> Result<(PlanRef, ColIndexMapping)> {
1535 let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
1536 let (agg, out_col_change) = self.rewrite_with_input(input, input_col_change);
1537 let (map, _) = out_col_change.into_parts();
1538 let out_col_change = ColIndexMapping::new(map, agg.schema().len());
1539 Ok((agg.into(), out_col_change))
1540 }
1541}
1542
#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};

    use super::*;
    use crate::expr::{assert_eq_input_ref, input_ref_to_column_indices};
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::LogicalValues;

    /// Three `Int32` fields named `v1`, `v2` and `v3`.
    fn int32_fields() -> Vec<Field> {
        ["v1", "v2", "v3"]
            .iter()
            .map(|name| Field::with_name(DataType::Int32, *name))
            .collect()
    }

    /// Builds a bound aggregate call of `kind` over a single argument: not distinct, no
    /// ordering, always-true filter, no direct arguments.
    fn unary_agg_call(kind: PbAggKind, arg: ExprImpl) -> AggCall {
        AggCall::new(
            kind.into(),
            vec![arg],
            false,
            OrderBy::any(),
            Condition::true_cond(),
            vec![],
        )
        .unwrap()
    }

    /// Builds a plan-level aggregate call of `kind` over input column `input_idx`, where both
    /// the argument and the result have type `ty`.
    fn plan_agg_call(kind: PbAggKind, input_idx: usize, ty: &DataType) -> PlanAggCall {
        PlanAggCall {
            agg_type: kind.into(),
            return_type: ty.clone(),
            inputs: vec![InputRef::new(input_idx, ty.clone())],
            distinct: false,
            order_by: vec![],
            filter: Condition::true_cond(),
            direct_args: vec![],
        }
    }

    /// Runs column pruning on `agg` for `required_cols`, using `agg` itself as the root of
    /// the pruning context.
    fn prune(agg: &PlanRef, required_cols: &[usize]) -> PlanRef {
        agg.prune_col(required_cols, &mut ColumnPruningContext::new(agg.clone()))
    }

    /// Asserts that `agg` has exactly one agg call of `kind` reading `input_cols` and
    /// returning `ty`.
    fn assert_single_agg_call(
        agg: &LogicalAgg,
        kind: PbAggKind,
        input_cols: Vec<usize>,
        ty: &DataType,
    ) {
        assert_eq!(agg.agg_calls().len(), 1);
        let agg_call = agg.agg_calls()[0].clone();
        assert_eq!(agg_call.agg_type, kind.into());
        assert_eq!(input_ref_to_column_indices(&agg_call.inputs), input_cols);
        assert_eq!(agg_call.return_type, *ty);
    }

    /// Asserts that the input of `agg` is a `LogicalValues` whose fields equal `expected`.
    fn assert_values_fields(agg: &LogicalAgg, expected: &[Field]) {
        let values = agg.input();
        let values = values.as_logical_values().unwrap();
        assert_eq!(values.schema().fields(), expected);
    }

    /// Checks what `LogicalAgg::create` extracts (rewritten select exprs, agg calls and
    /// group key) for several select lists grouped by a single column.
    #[tokio::test]
    async fn test_create() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let values = LogicalValues::new(
            vec![],
            Schema {
                fields: int32_fields(),
            },
            ctx,
        );
        let input = PlanRef::from(values);
        let v1 = InputRef::new(0, ty.clone());
        let v2 = InputRef::new(1, ty.clone());
        let v3 = InputRef::new(2, ty.clone());

        // Creates the agg and returns the pieces the assertions below inspect.
        let create_agg = |select_exprs: Vec<ExprImpl>,
                          group_exprs|
         -> (Vec<ExprImpl>, Vec<PlanAggCall>, IndexSet) {
            let (plan, exprs, _) = LogicalAgg::create(
                select_exprs,
                GroupBy::GroupKey(group_exprs),
                None,
                input.clone(),
            )
            .unwrap();

            let logical_agg = plan.as_logical_agg().unwrap();
            (
                exprs,
                logical_agg.agg_calls().to_vec(),
                logical_agg.group_key().clone(),
            )
        };

        // Select: v1; group by: v1.
        {
            let select_exprs = vec![v1.clone().into()];
            let group_exprs = vec![v1.clone().into()];

            let (exprs, agg_calls, group_key) = create_agg(select_exprs, group_exprs);

            assert_eq!(exprs.len(), 1);
            assert_eq_input_ref!(&exprs[0], 0);

            assert_eq!(agg_calls.len(), 0);
            assert_eq!(group_key, vec![0].into());
        }

        // Select: v1, min(v2); group by: v1.
        {
            let min_v2 = unary_agg_call(PbAggKind::Min, v2.clone().into());
            let select_exprs = vec![v1.clone().into(), min_v2.into()];
            let group_exprs = vec![v1.clone().into()];

            let (exprs, agg_calls, group_key) = create_agg(select_exprs, group_exprs);

            assert_eq!(exprs.len(), 2);
            assert_eq_input_ref!(&exprs[0], 0);
            assert_eq_input_ref!(&exprs[1], 1);

            assert_eq!(agg_calls.len(), 1);
            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
            assert_eq!(group_key, vec![0].into());
        }

        // Select: v1, min(v2) + max(v3); group by: v1.
        {
            let min_v2 = unary_agg_call(PbAggKind::Min, v2.clone().into());
            let max_v3 = unary_agg_call(PbAggKind::Max, v3.clone().into());
            let sum_of_aggs =
                FunctionCall::new(ExprType::Add, vec![min_v2.into(), max_v3.into()]).unwrap();
            let select_exprs = vec![v1.clone().into(), ExprImpl::from(sum_of_aggs)];
            let group_exprs = vec![v1.clone().into()];

            let (exprs, agg_calls, group_key) = create_agg(select_exprs, group_exprs);

            assert_eq_input_ref!(&exprs[0], 0);
            match &exprs[1] {
                ExprImpl::FunctionCall(func_call) => {
                    assert_eq!(func_call.func_type(), ExprType::Add);
                    let inputs = func_call.inputs();
                    assert_eq_input_ref!(&inputs[0], 1);
                    assert_eq_input_ref!(&inputs[1], 2);
                }
                _ => panic!("Wrong expression type!"),
            }

            assert_eq!(agg_calls.len(), 2);
            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
            assert_eq!(agg_calls[1].agg_type, PbAggKind::Max.into());
            assert_eq!(input_ref_to_column_indices(&agg_calls[1].inputs), vec![2]);
            assert_eq!(group_key, vec![0].into());
        }

        // Select: v2, min(v1 * v3); group by: v2.
        {
            let v1_mult_v3 = FunctionCall::new(
                ExprType::Multiply,
                vec![v1.clone().into(), v3.clone().into()],
            )
            .unwrap();
            let min_of_product = unary_agg_call(PbAggKind::Min, v1_mult_v3.into());
            let select_exprs = vec![v2.clone().into(), min_of_product.into()];
            let group_exprs = vec![v2.clone().into()];

            let (exprs, agg_calls, group_key) = create_agg(select_exprs, group_exprs);

            assert_eq_input_ref!(&exprs[0], 0);
            assert_eq_input_ref!(&exprs[1], 1);

            assert_eq!(agg_calls.len(), 1);
            assert_eq!(agg_calls[0].agg_type, PbAggKind::Min.into());
            assert_eq!(input_ref_to_column_indices(&agg_calls[0].inputs), vec![1]);
            assert_eq!(group_key, vec![0].into());
        }
    }

    /// Builds `Agg(min(input_ref(2))) group by (input_ref(1))` on top of an empty
    /// `LogicalValues` with the given `fields`; `ty` is the agg call's argument and
    /// return type.
    async fn generate_agg_call(ty: DataType, fields: Vec<Field>) -> LogicalAgg {
        let ctx = OptimizerContext::mock().await;

        let values = LogicalValues::new(vec![], Schema { fields }, ctx);
        let min_of_third_col = plan_agg_call(PbAggKind::Min, 2, &ty);
        Agg::new(vec![min_of_third_col], vec![1].into(), values.into()).into()
    }

    /// Requiring `[0, 1]` (every output column, in order) keeps the agg as the plan root
    /// and only prunes the unused input column `v1`.
    #[tokio::test]
    async fn test_prune_all() {
        let ty = DataType::Int32;
        let fields = int32_fields();
        let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).await.into();

        let plan = prune(&agg, &[0, 1]);

        let agg_new = plan.as_logical_agg().unwrap();
        assert_eq!(agg_new.group_key(), &vec![0].into());
        assert_single_agg_call(agg_new, PbAggKind::Min, vec![1], &ty);
        assert_values_fields(agg_new, &fields[1..]);
    }

    /// Requiring `[1, 0]` (every output column, swapped) adds a reordering project on top
    /// of the pruned agg.
    #[tokio::test]
    async fn test_prune_all_with_order_required() {
        let ty = DataType::Int32;
        let fields = int32_fields();
        let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).await.into();

        let plan = prune(&agg, &[1, 0]);

        let proj = plan.as_logical_project().unwrap();
        assert_eq!(proj.exprs().len(), 2);
        assert_eq!(proj.exprs()[0].as_input_ref().unwrap().index(), 1);
        assert_eq!(proj.exprs()[1].as_input_ref().unwrap().index(), 0);

        let proj_input = proj.input();
        let agg_new = proj_input.as_logical_agg().unwrap();
        assert_eq!(agg_new.group_key(), &vec![0].into());
        assert_single_agg_call(agg_new, PbAggKind::Min, vec![1], &ty);
        assert_values_fields(agg_new, &fields[1..]);
    }

    /// Requiring only the agg-call output (`[1]`) still keeps the group key inside the agg;
    /// a project on top selects the single requested column.
    #[tokio::test]
    async fn test_prune_group_key() {
        let ty = DataType::Int32;
        let fields = int32_fields();
        let agg: PlanRef = generate_agg_call(ty.clone(), fields.clone()).await.into();

        let plan = prune(&agg, &[1]);

        let project = plan.as_logical_project().unwrap();
        assert_eq!(project.exprs().len(), 1);
        assert_eq_input_ref!(&project.exprs()[0], 1);

        let agg_new = project.input();
        let agg_new = agg_new.as_logical_agg().unwrap();
        assert_eq!(agg_new.group_key(), &vec![0].into());
        assert_single_agg_call(agg_new, PbAggKind::Min, vec![1], &ty);
        assert_values_fields(agg_new, &fields[1..]);
    }

    /// With `min(input_ref(2)), max(input_ref(1))` grouped by `(1, 2)`, requiring `[0, 3]`
    /// drops the `min` call, keeps both group keys, and projects out the second key.
    #[tokio::test]
    async fn test_prune_agg() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields = int32_fields();
        let values = LogicalValues::new(
            vec![],
            Schema {
                fields: fields.clone(),
            },
            ctx,
        );

        let agg_calls = vec![
            plan_agg_call(PbAggKind::Min, 2, &ty),
            plan_agg_call(PbAggKind::Max, 1, &ty),
        ];
        let agg: PlanRef = Agg::new(agg_calls, vec![1, 2].into(), values.into()).into();

        let plan = prune(&agg, &[0, 3]);

        let project = plan.as_logical_project().unwrap();
        assert_eq!(project.exprs().len(), 2);
        assert_eq_input_ref!(&project.exprs()[0], 0);
        assert_eq_input_ref!(&project.exprs()[1], 2);

        let agg_new = project.input();
        let agg_new = agg_new.as_logical_agg().unwrap();
        assert_eq!(agg_new.group_key(), &vec![0, 1].into());
        assert_single_agg_call(agg_new, PbAggKind::Max, vec![0], &ty);
        assert_values_fields(agg_new, &fields[1..]);
    }
}