1use std::collections::{BTreeMap, BTreeSet, HashMap};
16use std::{fmt, vec};
17
18use itertools::{Either, Itertools};
19use pretty_xmlish::{Pretty, StrAssocArr};
20use risingwave_common::catalog::{Field, FieldDisplay, Schema};
21use risingwave_common::types::DataType;
22use risingwave_common::util::iter_util::ZipEqFast;
23use risingwave_common::util::sort_util::{ColumnOrder, ColumnOrderDisplay, OrderType};
24use risingwave_common::util::value_encoding::DatumToProtoExt;
25use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};
26use risingwave_expr::sig::{FUNCTION_REGISTRY, FuncBuilder};
27use risingwave_pb::expr::{PbAggCall, PbConstant};
28use risingwave_pb::stream_plan::{AggCallState as PbAggCallState, agg_call_state};
29
30use super::super::utils::TableCatalogBuilder;
31use super::{GenericPlanNode, GenericPlanRef, impl_distill_unit_from_fields, stream};
32use crate::TableCatalog;
33use crate::error::{ErrorCode, Result};
34use crate::expr::{Expr, ExprRewriter, ExprVisitor, InputRef, InputRefDisplay, Literal};
35use crate::optimizer::optimizer_context::OptimizerContextRef;
36use crate::optimizer::plan_node::batch::BatchPlanRef;
37use crate::optimizer::property::{
38 Distribution, FunctionalDependencySet, RequiredDist, WatermarkColumns,
39};
40use crate::stream_fragmenter::BuildFragmentGraphState;
41use crate::utils::{
42 ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay, IndexRewriter,
43 IndexSet,
44};
45
/// Generic aggregation plan node, parameterized over the plan-reference type of its input.
///
/// The output schema lists the group key columns first (in the order yielded by
/// `group_key.indices()`), followed by one column per aggregate call.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Agg<PlanRef> {
    /// Aggregate calls; each one contributes exactly one output column.
    pub agg_calls: Vec<PlanAggCall>,
    /// Indices of the input columns that form the group key.
    pub group_key: IndexSet,
    /// Grouping sets attached via `with_grouping_sets`; empty when built by `new`.
    pub grouping_sets: Vec<IndexSet>,
    /// The input plan.
    pub input: PlanRef,
    /// Whether two-phase aggregation is enabled. `new` initializes it from the session
    /// config; `with_enable_two_phase` overrides it.
    pub enable_two_phase: bool,
}
60
61impl<PlanRef: GenericPlanRef> Agg<PlanRef> {
62 pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
63 self.agg_calls.iter_mut().for_each(|call| {
64 call.filter = call.filter.clone().rewrite_expr(r);
65 });
66 }
67
68 pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
69 self.agg_calls.iter().for_each(|call| {
70 call.filter.visit_expr(v);
71 });
72 }
73
74 pub(crate) fn output_len(&self) -> usize {
75 self.group_key.len() + self.agg_calls.len()
76 }
77
78 pub fn o2i_col_mapping(&self) -> ColIndexMapping {
81 let mut map = vec![None; self.output_len()];
82 for (i, key) in self.group_key.indices().enumerate() {
83 map[i] = Some(key);
84 }
85 ColIndexMapping::new(map, self.input.schema().len())
86 }
87
88 pub fn i2o_col_mapping(&self) -> ColIndexMapping {
90 let mut map = vec![None; self.input.schema().len()];
91 for (i, key) in self.group_key.indices().enumerate() {
92 map[key] = Some(i);
93 }
94 ColIndexMapping::new(map, self.output_len())
95 }
96
97 fn two_phase_agg_forced(&self) -> bool {
98 self.ctx().session_ctx().config().force_two_phase_agg()
99 }
100
101 pub fn two_phase_agg_enabled(&self) -> bool {
102 self.enable_two_phase
103 }
104
105 pub(crate) fn can_two_phase_agg(&self) -> bool {
106 self.two_phase_agg_enabled()
107 && !self.agg_calls.is_empty()
108 && self.agg_calls.iter().all(|call| {
109 let agg_type_ok = !matches!(call.agg_type, agg_types::simply_cannot_two_phase!());
110 let order_ok = matches!(
111 call.agg_type,
112 agg_types::result_unaffected_by_order_by!()
113 | AggType::Builtin(PbAggKind::ApproxPercentile)
114 ) || call.order_by.is_empty();
115 let distinct_ok =
116 matches!(call.agg_type, agg_types::result_unaffected_by_distinct!())
117 || !call.distinct;
118 agg_type_ok && order_ok && distinct_ok
119 })
120 }
121
122 pub(crate) fn must_try_two_phase_agg(&self) -> bool {
124 self.two_phase_agg_forced() && self.can_two_phase_agg()
125 }
126
127 pub(crate) fn hash_agg_dist_satisfied_by_input_dist(&self, input_dist: &Distribution) -> bool {
131 let required_dist =
132 RequiredDist::shard_by_key(self.input.schema().len(), &self.group_key.to_vec());
133 input_dist.satisfies(&required_dist)
134 }
135
136 pub(crate) fn all_local_aggs_are_stateless(&self, stream_input_append_only: bool) -> bool {
138 self.agg_calls.iter().all(|c| {
139 matches!(c.agg_type, agg_types::single_value_state!())
140 || (matches!(c.agg_type, agg_types::single_value_state_iff_in_append_only!() if stream_input_append_only))
141 })
142 }
143
144 pub(crate) fn eowc_window_column(
145 &self,
146 input_watermark_columns: &WatermarkColumns,
147 ) -> Result<usize> {
148 let group_key_with_wtmk = self
149 .group_key
150 .indices()
151 .filter_map(|idx| {
152 input_watermark_columns
153 .get_group(idx)
154 .map(|group| (idx, group))
155 })
156 .collect::<Vec<_>>();
157
158 if group_key_with_wtmk.is_empty() {
159 return Err(ErrorCode::NotSupported(
160 "Emit-On-Window-Close mode requires a watermark column in GROUP BY.".to_owned(),
161 "Please try to GROUP BY a watermark column".to_owned(),
162 )
163 .into());
164 }
165 if group_key_with_wtmk.len() == 1
166 || group_key_with_wtmk
167 .iter()
168 .map(|(_, group)| group)
169 .all_equal()
170 {
171 return Ok(group_key_with_wtmk[0].0);
174 }
175 Err(ErrorCode::NotSupported(
176 "Emit-On-Window-Close mode requires that watermark columns in GROUP BY are derived from the same upstream column.".to_owned(),
177 "Please try to remove undesired columns from GROUP BY".to_owned(),
178 )
179 .into())
180 }
181
182 pub fn new(agg_calls: Vec<PlanAggCall>, group_key: IndexSet, input: PlanRef) -> Self {
183 let enable_two_phase = input.ctx().session_ctx().config().enable_two_phase_agg();
184 Self {
185 agg_calls,
186 group_key,
187 input,
188 grouping_sets: vec![],
189 enable_two_phase,
190 }
191 }
192
193 pub fn with_grouping_sets(mut self, grouping_sets: Vec<IndexSet>) -> Self {
194 self.grouping_sets = grouping_sets;
195 self
196 }
197
198 pub fn with_enable_two_phase(mut self, enable_two_phase: bool) -> Self {
199 self.enable_two_phase = enable_two_phase;
200 self
201 }
202}
203
204impl<PlanRef: BatchPlanRef> Agg<PlanRef> {
205 pub(crate) fn input_provides_order_on_group_keys(&self) -> bool {
207 let mut input_order_prefix = IndexSet::empty();
208 for input_order_col in &self.input.order().column_orders {
209 if !self.group_key.contains(input_order_col.column_index) {
210 break;
211 }
212 input_order_prefix.insert(input_order_col.column_index);
213 }
214 self.group_key == input_order_prefix
215 }
216}
217
218impl<PlanRef: GenericPlanRef> GenericPlanNode for Agg<PlanRef> {
219 fn schema(&self) -> Schema {
220 let fields = self
221 .group_key
222 .indices()
223 .map(|i| self.input.schema().fields()[i].clone())
224 .chain(self.agg_calls.iter().map(|agg_call| {
225 let plan_agg_call_display = PlanAggCallDisplay {
226 plan_agg_call: agg_call,
227 input_schema: self.input.schema(),
228 };
229 let name = format!("{:?}", plan_agg_call_display);
230 Field::with_name(agg_call.return_type.clone(), name)
231 }))
232 .collect();
233 Schema { fields }
234 }
235
236 fn stream_key(&self) -> Option<Vec<usize>> {
237 Some((0..self.group_key.len()).collect())
238 }
239
240 fn ctx(&self) -> OptimizerContextRef {
241 self.input.ctx()
242 }
243
244 fn functional_dependency(&self) -> FunctionalDependencySet {
245 let output_len = self.output_len();
246 let _input_len = self.input.schema().len();
247 let mut fd_set =
248 FunctionalDependencySet::with_key(output_len, &(0..self.group_key.len()).collect_vec());
249 let i2o = self.i2o_col_mapping();
251 for fd in self.input.functional_dependency().as_dependencies() {
252 if let Some(fd) = i2o.rewrite_functional_dependency(fd) {
253 fd_set.add_functional_dependency(fd);
254 }
255 }
256 fd_set
257 }
258}
259
/// How the state of one streaming aggregate call is kept.
pub enum AggCallState {
    /// No per-call state table is described; serialized as an empty `ValueState`.
    Value,
    /// The call keeps input rows in a dedicated state table described by the boxed
    /// `MaterializedInputState`.
    MaterializedInput(Box<MaterializedInputState>),
}
264
265impl AggCallState {
266 pub fn into_prost(self, state: &mut BuildFragmentGraphState) -> PbAggCallState {
267 PbAggCallState {
268 inner: Some(match self {
269 AggCallState::Value => {
270 agg_call_state::Inner::ValueState(agg_call_state::ValueState {})
271 }
272 AggCallState::MaterializedInput(s) => {
273 agg_call_state::Inner::MaterializedInputState(
274 agg_call_state::MaterializedInputState {
275 table: Some(
276 s.table
277 .with_id(state.gen_table_id_wrapped())
278 .to_internal_table_prost(),
279 ),
280 included_upstream_indices: s
281 .included_upstream_indices
282 .into_iter()
283 .map(|x| x as _)
284 .collect(),
285 table_value_indices: s
286 .table_value_indices
287 .into_iter()
288 .map(|x| x as _)
289 .collect(),
290 order_columns: s
291 .order_columns
292 .into_iter()
293 .map(|x| x.to_protobuf())
294 .collect(),
295 },
296 )
297 }
298 }),
299 }
300 }
301}
302
/// Description of the state table backing one materialized-input aggregate call.
pub struct MaterializedInputState {
    /// Catalog of the state table.
    pub table: TableCatalog,
    /// Upstream (input) column indices stored in the table, in the order the columns were
    /// added to it.
    pub included_upstream_indices: Vec<usize>,
    /// Table column indices registered as the table's value indices.
    pub table_value_indices: Vec<usize>,
    /// Sort columns added on top of the group key, expressed as upstream column index plus
    /// order type.
    pub order_columns: Vec<ColumnOrder>,
}
309
310impl<PlanRef: stream::StreamPlanRef> Agg<PlanRef> {
311 pub fn infer_tables(
312 &self,
313 me: impl stream::StreamPlanRef,
314 vnode_col_idx: Option<usize>,
315 window_col_idx: Option<usize>,
316 ) -> (
317 TableCatalog,
318 Vec<AggCallState>,
319 HashMap<usize, TableCatalog>,
320 ) {
321 (
322 self.infer_intermediate_state_table(&me, vnode_col_idx, window_col_idx),
323 self.infer_stream_agg_state(&me, vnode_col_idx, window_col_idx),
324 self.infer_distinct_dedup_tables(&me, vnode_col_idx, window_col_idx),
325 )
326 }
327
328 fn get_ordered_group_key(&self, window_col_idx: Option<usize>) -> Vec<usize> {
329 if let Some(window_col_idx) = window_col_idx {
330 assert!(self.group_key.contains(window_col_idx));
331 Either::Left(
332 std::iter::once(window_col_idx).chain(
333 self.group_key
334 .indices()
335 .filter(move |&i| i != window_col_idx),
336 ),
337 )
338 } else {
339 Either::Right(self.group_key.indices())
340 }
341 .collect()
342 }
343
344 fn create_table_builder(
352 &self,
353 _ctx: OptimizerContextRef,
354 window_col_idx: Option<usize>,
355 ) -> (TableCatalogBuilder, Vec<usize>, BTreeMap<usize, usize>) {
356 let mut table_builder = TableCatalogBuilder::default();
359
360 assert!(table_builder.columns().is_empty());
361 assert_eq!(table_builder.get_current_pk_len(), 0);
362
363 let mut included_upstream_indices = vec![];
365 let mut column_mapping = BTreeMap::new();
366 let in_fields = self.input.schema().fields();
367 for idx in self.group_key.indices() {
368 let tbl_col_idx = table_builder.add_column(&in_fields[idx]);
369 included_upstream_indices.push(idx);
370 column_mapping.insert(idx, tbl_col_idx);
371 }
372
373 let ordered_group_key = self.get_ordered_group_key(window_col_idx);
375 for idx in ordered_group_key {
376 table_builder.add_order_column(column_mapping[&idx], OrderType::ascending());
377 }
378
379 (table_builder, included_upstream_indices, column_mapping)
380 }
381
/// Infers, for each aggregate call, how its streaming state is kept.
///
/// Calls whose kind needs only a single value get `AggCallState::Value`. Calls whose kind
/// needs materialized input get a dedicated state table whose pk is the group key, then the
/// call's sort keys, then extra keys that make each row unique.
///
/// # Panics
///
/// Panics if the input has no stream key, or if an aggregate kind reaches this point that
/// should have been rewritten or rejected earlier.
pub fn infer_stream_agg_state(
    &self,
    me: impl stream::StreamPlanRef,
    vnode_col_idx: Option<usize>,
    window_col_idx: Option<usize>,
) -> Vec<AggCallState> {
    let in_fields = self.input.schema().fields().to_vec();
    let in_pks = self.input.stream_key().unwrap().to_vec();
    let in_append_only = self.input.append_only();
    let in_dist_key = self.input.distribution().dist_column_indices().to_vec();

    // Builds the state table description for one materialized-input call.
    // - `sort_keys`: (order type, upstream column) pairs that order rows within a group.
    // - `extra_keys`: upstream columns appended to the pk in ascending order.
    // - `include_keys`: upstream columns stored as plain value columns.
    let gen_materialized_input_state = |sort_keys: Vec<(OrderType, usize)>,
                                        extra_keys: Vec<usize>,
                                        include_keys: Vec<usize>|
     -> MaterializedInputState {
        // The builder already contains the group key columns and their pk ordering.
        let (mut table_builder, mut included_upstream_indices, mut column_mapping) =
            self.create_table_builder(me.ctx(), window_col_idx);
        // The read prefix is the group key part of the pk.
        let read_prefix_len_hint = table_builder.get_current_pk_len();

        let mut order_columns = vec![];
        // Table column indices of the value columns.
        let mut table_value_indices = BTreeSet::new();
        // Adds an upstream column to the table unless it is already mapped. The order type
        // is only applied when the column is newly added, so a column that is already
        // present (e.g. a group key column) keeps its existing pk position. The column is
        // marked as a value column in either case.
        let mut add_column =
            |upstream_idx, order_type, table_builder: &mut TableCatalogBuilder| {
                column_mapping.entry(upstream_idx).or_insert_with(|| {
                    let table_col_idx = table_builder.add_column(&in_fields[upstream_idx]);
                    if let Some(order_type) = order_type {
                        table_builder.add_order_column(table_col_idx, order_type);
                        order_columns.push(ColumnOrder::new(upstream_idx, order_type));
                    }
                    included_upstream_indices.push(upstream_idx);
                    table_col_idx
                });
                table_value_indices.insert(column_mapping[&upstream_idx]);
            };

        // Insertion order defines the pk order after the group key.
        for (order_type, idx) in sort_keys {
            add_column(idx, Some(order_type), &mut table_builder);
        }
        for idx in extra_keys {
            add_column(idx, Some(OrderType::ascending()), &mut table_builder);
        }
        for idx in include_keys {
            add_column(idx, None, &mut table_builder);
        }

        // Translate the input distribution key and vnode column into table column indices.
        let mapping =
            ColIndexMapping::with_included_columns(&included_upstream_indices, in_fields.len());
        let tb_dist = mapping.rewrite_dist_key(&in_dist_key);
        if let Some(tb_vnode_idx) = vnode_col_idx.and_then(|idx| mapping.try_map(idx)) {
            table_builder.set_vnode_col_idx(tb_vnode_idx);
        }

        // The `BTreeSet` yields the value indices in ascending order.
        let table_value_indices = table_value_indices.into_iter().collect_vec();
        table_builder.set_value_indices(table_value_indices.clone());

        MaterializedInputState {
            // An input distribution key that cannot be mapped yields an empty dist key.
            table: table_builder.build(tb_dist.unwrap_or_default(), read_prefix_len_hint),
            included_upstream_indices,
            table_value_indices,
            order_columns,
        }
    };

    self.agg_calls
        .iter()
        .map(|agg_call| match agg_call.agg_type {
            // These kinds need only a single value when the input is append-only.
            agg_types::single_value_state_iff_in_append_only!() if in_append_only => {
                AggCallState::Value
            }
            agg_types::single_value_state!() => AggCallState::Value,
            agg_types::materialized_input_state!() => {
                // Columns that order the rows within a group.
                let sort_keys = {
                    match agg_call.agg_type {
                        // Sort by the aggregated column itself: ascending for `min`,
                        // descending for `max`.
                        AggType::Builtin(PbAggKind::Min) => {
                            vec![(OrderType::ascending(), agg_call.inputs[0].index)]
                        }
                        AggType::Builtin(PbAggKind::Max) => {
                            vec![(OrderType::descending(), agg_call.inputs[0].index)]
                        }
                        AggType::Builtin(
                            PbAggKind::FirstValue
                            | PbAggKind::LastValue
                            | PbAggKind::StringAgg
                            | PbAggKind::ArrayAgg
                            | PbAggKind::JsonbAgg,
                        )
                        | AggType::WrapScalar(_) => {
                            if agg_call.order_by.is_empty() {
                                me.ctx().warn_to_user(format!(
                                    "{} without ORDER BY may produce non-deterministic result",
                                    agg_call.agg_type,
                                ));
                            }
                            agg_call
                                .order_by
                                .iter()
                                .map(|o| {
                                    (
                                        // `last_value` stores rows in reversed ORDER BY order.
                                        if matches!(
                                            agg_call.agg_type,
                                            AggType::Builtin(PbAggKind::LastValue)
                                        ) {
                                            o.order_type.reverse()
                                        } else {
                                            o.order_type
                                        },
                                        o.column_index,
                                    )
                                })
                                .collect()
                        }
                        // Same as above, but without the missing-ORDER-BY warning.
                        AggType::Builtin(PbAggKind::JsonbObjectAgg) => agg_call
                            .order_by
                            .iter()
                            .map(|o| (o.order_type, o.column_index))
                            .collect(),
                        _ => unreachable!(),
                    }
                };

                // Columns that make each stored row unique.
                let extra_keys = if agg_call.distinct {
                    // For DISTINCT calls, the distinct column (the first input) is used.
                    let distinct_key = agg_call.inputs[0].index;
                    vec![distinct_key]
                } else {
                    // Otherwise the input's stream key is used.
                    in_pks.clone()
                };

                // Other columns that must be stored in the state table.
                let include_keys = match agg_call.agg_type {
                    // These kinds store all of their input columns.
                    AggType::Builtin(
                        PbAggKind::FirstValue
                        | PbAggKind::LastValue
                        | PbAggKind::StringAgg
                        | PbAggKind::ArrayAgg
                        | PbAggKind::JsonbAgg
                        | PbAggKind::JsonbObjectAgg,
                    )
                    | AggType::WrapScalar(_) => {
                        agg_call.inputs.iter().map(|i| i.index).collect()
                    }
                    _ => vec![],
                };

                let state = gen_materialized_input_state(sort_keys, extra_keys, include_keys);
                AggCallState::MaterializedInput(Box::new(state))
            }
            agg_types::rewritten!() => {
                unreachable!("should have been rewritten")
            }
            agg_types::unimplemented_in_stream!() => {
                unreachable!("should have been banned")
            }
            AggType::Builtin(
                PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar,
            ) => {
                unreachable!("invalid agg kind")
            }
        })
        .collect()
}
549
550 pub fn infer_intermediate_state_table(
553 &self,
554 me: impl GenericPlanRef,
555 vnode_col_idx: Option<usize>,
556 window_col_idx: Option<usize>,
557 ) -> TableCatalog {
558 let mut out_fields = me.schema().fields().to_vec();
559
560 let in_append_only = self.input.append_only();
562 for (agg_call, field) in self
563 .agg_calls
564 .iter()
565 .zip_eq_fast(&mut out_fields[self.group_key.len()..])
566 {
567 let agg_kind = match agg_call.agg_type {
568 AggType::UserDefined(_) => {
569 field.data_type = DataType::Bytea;
571 continue;
572 }
573 AggType::WrapScalar(_) => {
574 continue;
576 }
577 AggType::Builtin(kind) => kind,
578 };
579 let sig = FUNCTION_REGISTRY
580 .get(
581 agg_kind,
582 &agg_call
583 .inputs
584 .iter()
585 .map(|input| input.data_type.clone())
586 .collect_vec(),
587 &agg_call.return_type,
588 )
589 .expect("agg not found");
590 match (in_append_only, sig.is_append_only()) {
593 (false, true) => {
594 }
598 (true, true) => {
599 if let FuncBuilder::Aggregate {
601 append_only_state_type: Some(state_type),
602 ..
603 } = &sig.build
604 {
605 field.data_type = state_type.clone();
606 }
607 }
608 (_, false) => {
609 if let FuncBuilder::Aggregate {
611 retractable_state_type: Some(state_type),
612 ..
613 } = &sig.build
614 {
615 field.data_type = state_type.clone();
616 }
617 }
618 }
619 }
620 let in_dist_key = self.input.distribution().dist_column_indices().to_vec();
621 let n_group_key_cols = self.group_key.len();
622
623 let (mut table_builder, _, _) = self.create_table_builder(me.ctx(), window_col_idx);
624 let read_prefix_len_hint = table_builder.get_current_pk_len();
625
626 for field in out_fields.iter().skip(n_group_key_cols) {
627 table_builder.add_column(field);
628 }
629
630 let mapping = self.i2o_col_mapping();
631 let tb_dist = mapping.rewrite_dist_key(&in_dist_key).unwrap_or_default();
632 if let Some(tb_vnode_idx) = vnode_col_idx.and_then(|idx| mapping.try_map(idx)) {
633 table_builder.set_vnode_col_idx(tb_vnode_idx);
634 }
635
636 table_builder.set_value_indices((n_group_key_cols..out_fields.len()).collect());
639 table_builder.build(tb_dist, read_prefix_len_hint)
640 }
641
642 pub fn infer_distinct_dedup_tables(
649 &self,
650 me: impl GenericPlanRef,
651 vnode_col_idx: Option<usize>,
652 window_col_idx: Option<usize>,
653 ) -> HashMap<usize, TableCatalog> {
654 let in_dist_key = self.input.distribution().dist_column_indices().to_vec();
655 let in_fields = self.input.schema().fields();
656
657 self.agg_calls
658 .iter()
659 .enumerate()
660 .filter(|(_, call)| call.distinct) .into_group_map_by(|(_, call)| call.inputs[0].index) .into_iter()
663 .map(|(distinct_col, indices_and_calls)| {
664 let (mut table_builder, mut key_cols, _) =
665 self.create_table_builder(me.ctx(), window_col_idx);
666 let table_col_idx = table_builder.add_column(&in_fields[distinct_col]);
667 table_builder.add_order_column(table_col_idx, OrderType::ascending());
668 key_cols.push(distinct_col);
669
670 let read_prefix_len_hint = table_builder.get_current_pk_len();
671
672 for (call_index, _) in indices_and_calls {
676 table_builder.add_column(&Field {
677 data_type: DataType::Int64,
678 name: format!("count_for_agg_call_{}", call_index),
679 });
680 }
681 table_builder
682 .set_value_indices((key_cols.len()..table_builder.columns().len()).collect());
683
684 let mapping = ColIndexMapping::with_included_columns(&key_cols, in_fields.len());
685 if let Some(idx) = vnode_col_idx.and_then(|idx| mapping.try_map(idx)) {
686 table_builder.set_vnode_col_idx(idx);
687 }
688 let dist_key = mapping.rewrite_dist_key(&in_dist_key).unwrap_or_default();
689 let table = table_builder.build(dist_key, read_prefix_len_hint);
690 (distinct_col, table)
691 })
692 .collect()
693 }
694
695 pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) {
696 (
697 self.agg_calls,
698 self.group_key,
699 self.grouping_sets,
700 self.input,
701 self.enable_two_phase,
702 )
703 }
704
705 pub fn fields_pretty<'a>(&self) -> StrAssocArr<'a> {
706 let last = ("aggs", self.agg_calls_pretty());
707 if !self.group_key.is_empty() {
708 let first = ("group_key", self.group_key_pretty());
709 vec![first, last]
710 } else {
711 vec![last]
712 }
713 }
714
715 fn agg_calls_pretty<'a>(&self) -> Pretty<'a> {
716 let f = |plan_agg_call| {
717 Pretty::debug(&PlanAggCallDisplay {
718 plan_agg_call,
719 input_schema: self.input.schema(),
720 })
721 };
722 Pretty::Array(self.agg_calls.iter().map(f).collect())
723 }
724
725 fn group_key_pretty<'a>(&self) -> Pretty<'a> {
726 let f = |i| Pretty::display(&FieldDisplay(self.input.schema().fields.get(i).unwrap()));
727 Pretty::Array(self.group_key.indices().map(f).collect())
728 }
729}
730
// NOTE(review): presumably generates the plan-printing (`Distill`) impl for `Agg` over
// `stream::StreamPlanRef` on top of `fields_pretty` — confirm against the macro definition.
impl_distill_unit_from_fields!(Agg, stream::StreamPlanRef);
732
/// An aggregate call as it appears in a plan: arguments are input column references rather
/// than arbitrary expressions.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct PlanAggCall {
    /// Kind of aggregate function (builtin, user-defined, or wrapped scalar).
    pub agg_type: AggType,

    /// Data type of the aggregation result.
    pub return_type: DataType,

    /// Input column references passed as arguments. Empty for calls such as the one built
    /// by `count_star`.
    pub inputs: Vec<InputRef>,

    /// Whether the call is `DISTINCT`.
    pub distinct: bool,
    /// `ORDER BY` columns inside the call, expressed over input column indices.
    pub order_by: Vec<ColumnOrder>,
    /// Filter condition of the call; an always-true condition means no filter.
    pub filter: Condition,
    /// Literal arguments; `to_protobuf` serializes each as a constant with datum and type.
    pub direct_args: Vec<Literal>,
}
761
762impl fmt::Debug for PlanAggCall {
763 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
764 write!(f, "{}", self.agg_type)?;
765 if !self.inputs.is_empty() {
766 write!(f, "(")?;
767 for (idx, input) in self.inputs.iter().enumerate() {
768 if idx == 0 && self.distinct {
769 write!(f, "distinct ")?;
770 }
771 write!(f, "{:?}", input)?;
772 if idx != (self.inputs.len() - 1) {
773 write!(f, ",")?;
774 }
775 }
776 if !self.order_by.is_empty() {
777 let clause_text = self.order_by.iter().map(|e| format!("{:?}", e)).join(", ");
778 write!(f, " order_by({})", clause_text)?;
779 }
780 write!(f, ")")?;
781 }
782 if !self.filter.always_true() {
783 write!(
784 f,
785 " filter({:?})",
786 self.filter.as_expr_unless_true().unwrap()
787 )?;
788 }
789 Ok(())
790 }
791}
792
793impl PlanAggCall {
794 pub fn rewrite_input_index(&mut self, mapping: ColIndexMapping) {
795 self.inputs.iter_mut().for_each(|x| {
797 x.index = mapping.map(x.index);
798 });
799
800 self.order_by.iter_mut().for_each(|x| {
802 x.column_index = mapping.map(x.column_index);
803 });
804
805 let mut rewriter = IndexRewriter::new(mapping);
807 self.filter.conjunctions.iter_mut().for_each(|x| {
808 *x = rewriter.rewrite_expr(x.clone());
809 });
810 }
811
812 pub fn to_protobuf(&self) -> PbAggCall {
813 PbAggCall {
814 kind: match &self.agg_type {
815 AggType::Builtin(kind) => *kind,
816 AggType::UserDefined(_) => PbAggKind::UserDefined,
817 AggType::WrapScalar(_) => PbAggKind::WrapScalar,
818 }
819 .into(),
820 return_type: Some(self.return_type.to_protobuf()),
821 args: self.inputs.iter().map(InputRef::to_proto).collect(),
822 distinct: self.distinct,
823 order_by: self.order_by.iter().map(ColumnOrder::to_protobuf).collect(),
824 filter: self.filter.as_expr_unless_true().map(|x| x.to_expr_proto()),
825 direct_args: self
826 .direct_args
827 .iter()
828 .map(|x| PbConstant {
829 datum: Some(x.get_data().to_protobuf()),
830 r#type: Some(x.return_type().to_protobuf()),
831 })
832 .collect(),
833 udf: match &self.agg_type {
834 AggType::UserDefined(udf) => Some(udf.clone()),
835 _ => None,
836 },
837 scalar: match &self.agg_type {
838 AggType::WrapScalar(expr) => Some(expr.clone()),
839 _ => None,
840 },
841 }
842 }
843
844 pub fn partial_to_total_agg_call(&self, partial_output_idx: usize) -> PlanAggCall {
845 let total_agg_type = self
846 .agg_type
847 .partial_to_total()
848 .expect("unsupported kinds shouldn't get here");
849 PlanAggCall {
850 agg_type: total_agg_type,
851 inputs: vec![InputRef::new(partial_output_idx, self.return_type.clone())],
852 order_by: vec![], filter: Condition::true_cond(),
854 ..self.clone()
855 }
856 }
857
858 pub fn count_star() -> Self {
859 PlanAggCall {
860 agg_type: PbAggKind::Count.into(),
861 return_type: DataType::Int64,
862 inputs: vec![],
863 distinct: false,
864 order_by: vec![],
865 filter: Condition::true_cond(),
866 direct_args: vec![],
867 }
868 }
869
870 pub fn with_condition(mut self, filter: Condition) -> Self {
871 self.filter = filter;
872 self
873 }
874
875 pub fn input_indices(&self) -> Vec<usize> {
876 self.inputs.iter().map(|input| input.index()).collect()
877 }
878}
879
/// Debug-formatting adapter that renders a `PlanAggCall` against an input schema, so that
/// input references, ORDER BY columns and the filter are displayed through that schema.
pub struct PlanAggCallDisplay<'a> {
    /// The aggregate call to render.
    pub plan_agg_call: &'a PlanAggCall,
    /// Schema of the aggregation's input, used to resolve column references.
    pub input_schema: &'a Schema,
}
884
885impl fmt::Debug for PlanAggCallDisplay<'_> {
886 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
887 let that = self.plan_agg_call;
888 write!(f, "{}", that.agg_type)?;
889 if !that.inputs.is_empty() {
890 write!(f, "(")?;
891 for (idx, input) in that.inputs.iter().enumerate() {
892 if idx == 0 && that.distinct {
893 write!(f, "distinct ")?;
894 }
895 write!(
896 f,
897 "{}",
898 InputRefDisplay {
899 input_ref: input,
900 input_schema: self.input_schema
901 }
902 )?;
903 if idx != (that.inputs.len() - 1) {
904 write!(f, ", ")?;
905 }
906 }
907 if !that.order_by.is_empty() {
908 write!(
909 f,
910 " order_by({})",
911 that.order_by.iter().format_with(", ", |o, f| {
912 f(&ColumnOrderDisplay {
913 column_order: o,
914 input_schema: self.input_schema,
915 })
916 })
917 )?;
918 }
919 write!(f, ")")?;
920 }
921
922 if !that.filter.always_true() {
923 write!(
924 f,
925 " filter({:?})",
926 ConditionDisplay {
927 condition: &that.filter,
928 input_schema: self.input_schema,
929 }
930 )?;
931 }
932 Ok(())
933 }
934}