1use std::collections::{BTreeMap, BTreeSet, HashMap};
16use std::{fmt, vec};
17
18use itertools::{Either, Itertools};
19use pretty_xmlish::{Pretty, StrAssocArr};
20use risingwave_common::catalog::{Field, FieldDisplay, Schema};
21use risingwave_common::types::DataType;
22use risingwave_common::util::iter_util::ZipEqFast;
23use risingwave_common::util::sort_util::{ColumnOrder, ColumnOrderDisplay, OrderType};
24use risingwave_common::util::value_encoding::DatumToProtoExt;
25use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};
26use risingwave_expr::sig::{FUNCTION_REGISTRY, FuncBuilder};
27use risingwave_pb::expr::{PbAggCall, PbConstant};
28use risingwave_pb::stream_plan::{AggCallState as PbAggCallState, agg_call_state};
29
30use super::super::utils::TableCatalogBuilder;
31use super::{GenericPlanNode, GenericPlanRef, PhysicalPlanRef, impl_distill_unit_from_fields};
32use crate::TableCatalog;
33use crate::error::{ErrorCode, Result};
34use crate::expr::{Expr, ExprRewriter, ExprVisitor, InputRef, InputRefDisplay, Literal};
35use crate::optimizer::optimizer_context::OptimizerContextRef;
36use crate::optimizer::plan_node::batch::BatchPlanNodeMetadata;
37use crate::optimizer::plan_node::{BatchPlanRef, StreamPlanNodeMetadata, StreamPlanRef};
38use crate::optimizer::property::{
39 Distribution, FunctionalDependencySet, RequiredDist, WatermarkColumns,
40};
41use crate::stream_fragmenter::BuildFragmentGraphState;
42use crate::utils::{
43 ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay, IndexRewriter,
44 IndexSet,
45};
46
/// Generic aggregation plan node, used with both batch and stream plan references.
///
/// The output schema lists the group key columns first (in the iteration order of
/// `group_key.indices()`), followed by one column per entry of `agg_calls`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Agg<PlanRef> {
    /// Aggregate calls of this node; each call contributes one output column.
    pub agg_calls: Vec<PlanAggCall>,
    /// Indices of the input columns to group by.
    pub group_key: IndexSet,
    /// Grouping sets attached via `with_grouping_sets`; empty when built through `new`.
    pub grouping_sets: Vec<IndexSet>,
    /// The plan node whose output is aggregated.
    pub input: PlanRef,
    /// Whether two-phase aggregation is enabled for this node. `new` reads the initial value
    /// from the session config; `with_enable_two_phase` overrides it.
    pub enable_two_phase: bool,
}
61
62impl<PlanRef: GenericPlanRef> Agg<PlanRef> {
63 pub(crate) fn clone_with_input<OtherPlanRef>(&self, input: OtherPlanRef) -> Agg<OtherPlanRef> {
64 Agg {
65 agg_calls: self.agg_calls.clone(),
66 group_key: self.group_key.clone(),
67 grouping_sets: self.grouping_sets.clone(),
68 input,
69 enable_two_phase: self.enable_two_phase,
70 }
71 }
72
73 pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
74 self.agg_calls.iter_mut().for_each(|call| {
75 call.filter = call.filter.clone().rewrite_expr(r);
76 });
77 }
78
79 pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
80 self.agg_calls.iter().for_each(|call| {
81 call.filter.visit_expr(v);
82 });
83 }
84
85 pub(crate) fn output_len(&self) -> usize {
86 self.group_key.len() + self.agg_calls.len()
87 }
88
89 pub fn o2i_col_mapping(&self) -> ColIndexMapping {
92 let mut map = vec![None; self.output_len()];
93 for (i, key) in self.group_key.indices().enumerate() {
94 map[i] = Some(key);
95 }
96 ColIndexMapping::new(map, self.input.schema().len())
97 }
98
99 pub fn i2o_col_mapping(&self) -> ColIndexMapping {
101 let mut map = vec![None; self.input.schema().len()];
102 for (i, key) in self.group_key.indices().enumerate() {
103 map[key] = Some(i);
104 }
105 ColIndexMapping::new(map, self.output_len())
106 }
107
108 fn two_phase_agg_forced(&self) -> bool {
109 self.ctx().session_ctx().config().force_two_phase_agg()
110 }
111
112 pub fn two_phase_agg_enabled(&self) -> bool {
113 self.enable_two_phase
114 }
115
116 pub(crate) fn can_two_phase_agg(&self) -> bool {
117 self.two_phase_agg_enabled()
118 && !self.agg_calls.is_empty()
119 && self.agg_calls.iter().all(|call| {
120 let agg_type_ok = !matches!(call.agg_type, agg_types::simply_cannot_two_phase!());
121 let order_ok = matches!(
122 call.agg_type,
123 agg_types::result_unaffected_by_order_by!()
124 | AggType::Builtin(PbAggKind::ApproxPercentile)
125 ) || call.order_by.is_empty();
126 let distinct_ok =
127 matches!(call.agg_type, agg_types::result_unaffected_by_distinct!())
128 || !call.distinct;
129 agg_type_ok && order_ok && distinct_ok
130 })
131 }
132
133 pub(crate) fn must_try_two_phase_agg(&self) -> bool {
135 self.two_phase_agg_forced() && self.can_two_phase_agg()
136 }
137
138 pub(crate) fn hash_agg_dist_satisfied_by_input_dist(&self, input_dist: &Distribution) -> bool {
142 let required_dist =
143 RequiredDist::shard_by_key(self.input.schema().len(), &self.group_key.to_vec());
144 input_dist.satisfies(&required_dist)
145 }
146
147 pub(crate) fn all_local_aggs_are_stateless(&self, stream_input_append_only: bool) -> bool {
149 self.agg_calls.iter().all(|c| {
150 matches!(c.agg_type, agg_types::single_value_state!())
151 || (matches!(c.agg_type, agg_types::single_value_state_iff_in_append_only!() if stream_input_append_only))
152 })
153 }
154
155 pub(crate) fn eowc_window_column(
156 &self,
157 input_watermark_columns: &WatermarkColumns,
158 ) -> Result<usize> {
159 let group_key_with_wtmk = self
160 .group_key
161 .indices()
162 .filter_map(|idx| {
163 input_watermark_columns
164 .get_group(idx)
165 .map(|group| (idx, group))
166 })
167 .collect::<Vec<_>>();
168
169 if group_key_with_wtmk.is_empty() {
170 return Err(ErrorCode::NotSupported(
171 "Emit-On-Window-Close mode requires a watermark column in GROUP BY.".to_owned(),
172 "Please try to GROUP BY a watermark column".to_owned(),
173 )
174 .into());
175 }
176 if group_key_with_wtmk.len() == 1
177 || group_key_with_wtmk
178 .iter()
179 .map(|(_, group)| group)
180 .all_equal()
181 {
182 return Ok(group_key_with_wtmk[0].0);
185 }
186 Err(ErrorCode::NotSupported(
187 "Emit-On-Window-Close mode requires that watermark columns in GROUP BY are derived from the same upstream column.".to_owned(),
188 "Please try to remove undesired columns from GROUP BY".to_owned(),
189 )
190 .into())
191 }
192
193 pub fn new(agg_calls: Vec<PlanAggCall>, group_key: IndexSet, input: PlanRef) -> Self {
194 let enable_two_phase = input.ctx().session_ctx().config().enable_two_phase_agg();
195 Self {
196 agg_calls,
197 group_key,
198 input,
199 grouping_sets: vec![],
200 enable_two_phase,
201 }
202 }
203
204 pub fn with_grouping_sets(mut self, grouping_sets: Vec<IndexSet>) -> Self {
205 self.grouping_sets = grouping_sets;
206 self
207 }
208
209 pub fn with_enable_two_phase(mut self, enable_two_phase: bool) -> Self {
210 self.enable_two_phase = enable_two_phase;
211 self
212 }
213}
214
215impl Agg<BatchPlanRef> {
216 pub(crate) fn input_provides_order_on_group_keys(&self) -> bool {
218 let mut input_order_prefix = IndexSet::empty();
219 for input_order_col in &self.input.order().column_orders {
220 if !self.group_key.contains(input_order_col.column_index) {
221 break;
222 }
223 input_order_prefix.insert(input_order_col.column_index);
224 }
225 self.group_key == input_order_prefix
226 }
227}
228
229impl<PlanRef: GenericPlanRef> GenericPlanNode for Agg<PlanRef> {
230 fn schema(&self) -> Schema {
231 let fields = self
232 .group_key
233 .indices()
234 .map(|i| self.input.schema().fields()[i].clone())
235 .chain(self.agg_calls.iter().map(|agg_call| {
236 let plan_agg_call_display = PlanAggCallDisplay {
237 plan_agg_call: agg_call,
238 input_schema: self.input.schema(),
239 };
240 let name = format!("{:?}", plan_agg_call_display);
241 Field::with_name(agg_call.return_type.clone(), name)
242 }))
243 .collect();
244 Schema { fields }
245 }
246
247 fn stream_key(&self) -> Option<Vec<usize>> {
248 Some((0..self.group_key.len()).collect())
249 }
250
251 fn ctx(&self) -> OptimizerContextRef {
252 self.input.ctx()
253 }
254
255 fn functional_dependency(&self) -> FunctionalDependencySet {
256 let output_len = self.output_len();
257 let _input_len = self.input.schema().len();
258 let mut fd_set =
259 FunctionalDependencySet::with_key(output_len, &(0..self.group_key.len()).collect_vec());
260 let i2o = self.i2o_col_mapping();
262 for fd in self.input.functional_dependency().as_dependencies() {
263 if let Some(fd) = i2o.rewrite_functional_dependency(fd) {
264 fd_set.add_functional_dependency(fd);
265 }
266 }
267 fd_set
268 }
269}
270
/// State kind inferred for a single aggregate call by `Agg::infer_stream_agg_state`.
pub enum AggCallState {
    /// Serialized as an empty `ValueState` message; carries no table.
    Value,
    /// State backed by a table; see [`MaterializedInputState`] for its description.
    MaterializedInput(Box<MaterializedInputState>),
}
275
276impl AggCallState {
277 pub fn into_prost(self, state: &mut BuildFragmentGraphState) -> PbAggCallState {
278 PbAggCallState {
279 inner: Some(match self {
280 AggCallState::Value => {
281 agg_call_state::Inner::ValueState(agg_call_state::ValueState {})
282 }
283 AggCallState::MaterializedInput(s) => {
284 agg_call_state::Inner::MaterializedInputState(
285 agg_call_state::MaterializedInputState {
286 table: Some(
287 s.table
288 .with_id(state.gen_table_id_wrapped())
289 .to_internal_table_prost(),
290 ),
291 included_upstream_indices: s
292 .included_upstream_indices
293 .into_iter()
294 .map(|x| x as _)
295 .collect(),
296 table_value_indices: s
297 .table_value_indices
298 .into_iter()
299 .map(|x| x as _)
300 .collect(),
301 order_columns: s
302 .order_columns
303 .into_iter()
304 .map(|x| x.to_protobuf())
305 .collect(),
306 },
307 )
308 }
309 }),
310 }
311 }
312}
313
/// Description of a table-backed aggregate state, as produced by
/// `Agg::infer_stream_agg_state`.
pub struct MaterializedInputState {
    /// Catalog of the state table.
    pub table: TableCatalog,
    /// Upstream (input) column index of each table column, in the order the columns were added
    /// to the table: group key columns first, then the columns added for the aggregate call.
    pub included_upstream_indices: Vec<usize>,
    /// Table column indices registered as the table's value indices, in ascending order.
    pub table_value_indices: Vec<usize>,
    /// Ordering of the columns added to the table's order key on top of the group key,
    /// expressed with upstream column indices.
    pub order_columns: Vec<ColumnOrder>,
}
320
321impl Agg<StreamPlanRef> {
322 pub fn infer_tables(
323 &self,
324 me: impl StreamPlanNodeMetadata,
325 vnode_col_idx: Option<usize>,
326 window_col_idx: Option<usize>,
327 ) -> (
328 TableCatalog,
329 Vec<AggCallState>,
330 HashMap<usize, TableCatalog>,
331 ) {
332 (
333 self.infer_intermediate_state_table(&me, vnode_col_idx, window_col_idx),
334 self.infer_stream_agg_state(&me, vnode_col_idx, window_col_idx),
335 self.infer_distinct_dedup_tables(&me, vnode_col_idx, window_col_idx),
336 )
337 }
338
339 fn get_ordered_group_key(&self, window_col_idx: Option<usize>) -> Vec<usize> {
340 if let Some(window_col_idx) = window_col_idx {
341 assert!(self.group_key.contains(window_col_idx));
342 Either::Left(
343 std::iter::once(window_col_idx).chain(
344 self.group_key
345 .indices()
346 .filter(move |&i| i != window_col_idx),
347 ),
348 )
349 } else {
350 Either::Right(self.group_key.indices())
351 }
352 .collect()
353 }
354
355 fn create_table_builder(
363 &self,
364 _ctx: OptimizerContextRef,
365 window_col_idx: Option<usize>,
366 ) -> (TableCatalogBuilder, Vec<usize>, BTreeMap<usize, usize>) {
367 let mut table_builder = TableCatalogBuilder::default();
370
371 assert!(table_builder.columns().is_empty());
372 assert_eq!(table_builder.get_current_pk_len(), 0);
373
374 let mut included_upstream_indices = vec![];
376 let mut column_mapping = BTreeMap::new();
377 let in_fields = self.input.schema().fields();
378 for idx in self.group_key.indices() {
379 let tbl_col_idx = table_builder.add_column(&in_fields[idx]);
380 included_upstream_indices.push(idx);
381 column_mapping.insert(idx, tbl_col_idx);
382 }
383
384 let ordered_group_key = self.get_ordered_group_key(window_col_idx);
386 for idx in ordered_group_key {
387 table_builder.add_order_column(column_mapping[&idx], OrderType::ascending());
388 }
389
390 (table_builder, included_upstream_indices, column_mapping)
391 }
392
393 pub fn infer_stream_agg_state(
395 &self,
396 me: impl StreamPlanNodeMetadata,
397 vnode_col_idx: Option<usize>,
398 window_col_idx: Option<usize>,
399 ) -> Vec<AggCallState> {
400 let in_fields = self.input.schema().fields().to_vec();
401 let in_pks = self.input.expect_stream_key().to_vec();
402 let in_append_only = self.input.append_only();
403 let in_dist_key = self.input.distribution().dist_column_indices().to_vec();
404
405 let gen_materialized_input_state = |sort_keys: Vec<(OrderType, usize)>,
406 extra_keys: Vec<usize>,
407 include_keys: Vec<usize>|
408 -> MaterializedInputState {
409 let (mut table_builder, mut included_upstream_indices, mut column_mapping) =
410 self.create_table_builder(me.ctx(), window_col_idx);
411 let read_prefix_len_hint = table_builder.get_current_pk_len();
412
413 let mut order_columns = vec![];
414 let mut table_value_indices = BTreeSet::new(); let mut add_column =
416 |upstream_idx, order_type, table_builder: &mut TableCatalogBuilder| {
417 column_mapping.entry(upstream_idx).or_insert_with(|| {
418 let table_col_idx = table_builder.add_column(&in_fields[upstream_idx]);
419 if let Some(order_type) = order_type {
420 table_builder.add_order_column(table_col_idx, order_type);
421 order_columns.push(ColumnOrder::new(upstream_idx, order_type));
422 }
423 included_upstream_indices.push(upstream_idx);
424 table_col_idx
425 });
426 table_value_indices.insert(column_mapping[&upstream_idx]);
427 };
428
429 for (order_type, idx) in sort_keys {
430 add_column(idx, Some(order_type), &mut table_builder);
431 }
432 for idx in extra_keys {
433 add_column(idx, Some(OrderType::ascending()), &mut table_builder);
434 }
435 for idx in include_keys {
436 add_column(idx, None, &mut table_builder);
437 }
438
439 let mapping =
440 ColIndexMapping::with_included_columns(&included_upstream_indices, in_fields.len());
441 let tb_dist = mapping.rewrite_dist_key(&in_dist_key);
442 if let Some(tb_vnode_idx) = vnode_col_idx.and_then(|idx| mapping.try_map(idx)) {
443 table_builder.set_vnode_col_idx(tb_vnode_idx);
444 }
445
446 let table_value_indices = table_value_indices.into_iter().collect_vec();
448 table_builder.set_value_indices(table_value_indices.clone());
449
450 MaterializedInputState {
451 table: table_builder.build(tb_dist.unwrap_or_default(), read_prefix_len_hint),
452 included_upstream_indices,
453 table_value_indices,
454 order_columns,
455 }
456 };
457
458 self.agg_calls
459 .iter()
460 .map(|agg_call| match agg_call.agg_type {
461 agg_types::single_value_state_iff_in_append_only!() if in_append_only => {
462 AggCallState::Value
463 }
464 agg_types::single_value_state!() => AggCallState::Value,
465 agg_types::materialized_input_state!() => {
466 let sort_keys = {
468 match agg_call.agg_type {
469 AggType::Builtin(PbAggKind::Min) => {
470 vec![(OrderType::ascending(), agg_call.inputs[0].index)]
471 }
472 AggType::Builtin(PbAggKind::Max) => {
473 vec![(OrderType::descending(), agg_call.inputs[0].index)]
474 }
475 AggType::Builtin(
476 PbAggKind::FirstValue
477 | PbAggKind::LastValue
478 | PbAggKind::StringAgg
479 | PbAggKind::ArrayAgg
480 | PbAggKind::JsonbAgg,
481 )
482 | AggType::WrapScalar(_) => {
483 if agg_call.order_by.is_empty() {
484 me.ctx().warn_to_user(format!(
485 "{} without ORDER BY may produce non-deterministic result",
486 agg_call.agg_type,
487 ));
488 }
489 agg_call
490 .order_by
491 .iter()
492 .map(|o| {
493 (
494 if matches!(
495 agg_call.agg_type,
496 AggType::Builtin(PbAggKind::LastValue)
497 ) {
498 o.order_type.reverse()
499 } else {
500 o.order_type
501 },
502 o.column_index,
503 )
504 })
505 .collect()
506 }
507 AggType::Builtin(PbAggKind::JsonbObjectAgg) => agg_call
508 .order_by
509 .iter()
510 .map(|o| (o.order_type, o.column_index))
511 .collect(),
512 _ => unreachable!(),
513 }
514 };
515
516 let extra_keys = if agg_call.distinct {
518 let distinct_key = agg_call.inputs[0].index;
520 vec![distinct_key]
521 } else {
522 in_pks.clone()
524 };
525
526 let include_keys = match agg_call.agg_type {
528 AggType::Builtin(
530 PbAggKind::FirstValue
531 | PbAggKind::LastValue
532 | PbAggKind::StringAgg
533 | PbAggKind::ArrayAgg
534 | PbAggKind::JsonbAgg
535 | PbAggKind::JsonbObjectAgg,
536 )
537 | AggType::WrapScalar(_) => {
538 agg_call.inputs.iter().map(|i| i.index).collect()
539 }
540 _ => vec![],
541 };
542
543 let state = gen_materialized_input_state(sort_keys, extra_keys, include_keys);
544 AggCallState::MaterializedInput(Box::new(state))
545 }
546 agg_types::rewritten!() => {
547 unreachable!("should have been rewritten")
548 }
549 agg_types::unimplemented_in_stream!() => {
550 unreachable!("should have been banned")
551 }
552 AggType::Builtin(
553 PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar,
554 ) => {
555 unreachable!("invalid agg kind")
556 }
557 })
558 .collect()
559 }
560
561 pub fn infer_intermediate_state_table(
564 &self,
565 me: impl GenericPlanRef,
566 vnode_col_idx: Option<usize>,
567 window_col_idx: Option<usize>,
568 ) -> TableCatalog {
569 let mut out_fields = me.schema().fields().to_vec();
570
571 let in_append_only = self.input.append_only();
573 for (agg_call, field) in self
574 .agg_calls
575 .iter()
576 .zip_eq_fast(&mut out_fields[self.group_key.len()..])
577 {
578 let agg_kind = match agg_call.agg_type {
579 AggType::UserDefined(_) => {
580 field.data_type = DataType::Bytea;
582 continue;
583 }
584 AggType::WrapScalar(_) => {
585 continue;
587 }
588 AggType::Builtin(kind) => kind,
589 };
590 let sig = FUNCTION_REGISTRY
591 .get(
592 agg_kind,
593 &agg_call
594 .inputs
595 .iter()
596 .map(|input| input.data_type.clone())
597 .collect_vec(),
598 &agg_call.return_type,
599 )
600 .expect("agg not found");
601 match (in_append_only, sig.is_append_only()) {
604 (false, true) => {
605 }
609 (true, true) => {
610 if let FuncBuilder::Aggregate {
612 append_only_state_type: Some(state_type),
613 ..
614 } = &sig.build
615 {
616 field.data_type = state_type.clone();
617 }
618 }
619 (_, false) => {
620 if let FuncBuilder::Aggregate {
622 retractable_state_type: Some(state_type),
623 ..
624 } = &sig.build
625 {
626 field.data_type = state_type.clone();
627 }
628 }
629 }
630 }
631 let in_dist_key = self.input.distribution().dist_column_indices().to_vec();
632 let n_group_key_cols = self.group_key.len();
633
634 let (mut table_builder, _, _) = self.create_table_builder(me.ctx(), window_col_idx);
635 let read_prefix_len_hint = table_builder.get_current_pk_len();
636
637 for field in out_fields.iter().skip(n_group_key_cols) {
638 table_builder.add_column(field);
639 }
640
641 let mapping = self.i2o_col_mapping();
642 let tb_dist = mapping.rewrite_dist_key(&in_dist_key).unwrap_or_default();
643 if let Some(tb_vnode_idx) = vnode_col_idx.and_then(|idx| mapping.try_map(idx)) {
644 table_builder.set_vnode_col_idx(tb_vnode_idx);
645 }
646
647 table_builder.set_value_indices((n_group_key_cols..out_fields.len()).collect());
650 table_builder.build(tb_dist, read_prefix_len_hint)
651 }
652
653 pub fn infer_distinct_dedup_tables(
660 &self,
661 me: impl GenericPlanRef,
662 vnode_col_idx: Option<usize>,
663 window_col_idx: Option<usize>,
664 ) -> HashMap<usize, TableCatalog> {
665 let in_dist_key = self.input.distribution().dist_column_indices().to_vec();
666 let in_fields = self.input.schema().fields();
667
668 self.agg_calls
669 .iter()
670 .enumerate()
671 .filter(|(_, call)| call.distinct) .into_group_map_by(|(_, call)| call.inputs[0].index) .into_iter()
674 .map(|(distinct_col, indices_and_calls)| {
675 let (mut table_builder, mut key_cols, _) =
676 self.create_table_builder(me.ctx(), window_col_idx);
677 let table_col_idx = table_builder.add_column(&in_fields[distinct_col]);
678 table_builder.add_order_column(table_col_idx, OrderType::ascending());
679 key_cols.push(distinct_col);
680
681 let read_prefix_len_hint = table_builder.get_current_pk_len();
682
683 for (call_index, _) in indices_and_calls {
687 table_builder.add_column(&Field {
688 data_type: DataType::Int64,
689 name: format!("count_for_agg_call_{}", call_index),
690 });
691 }
692 table_builder
693 .set_value_indices((key_cols.len()..table_builder.columns().len()).collect());
694
695 let mapping = ColIndexMapping::with_included_columns(&key_cols, in_fields.len());
696 if let Some(idx) = vnode_col_idx.and_then(|idx| mapping.try_map(idx)) {
697 table_builder.set_vnode_col_idx(idx);
698 }
699 let dist_key = mapping.rewrite_dist_key(&in_dist_key).unwrap_or_default();
700 let table = table_builder.build(dist_key, read_prefix_len_hint);
701 (distinct_col, table)
702 })
703 .collect()
704 }
705}
706
707impl<PlanRef: GenericPlanRef> Agg<PlanRef> {
708 pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) {
709 (
710 self.agg_calls,
711 self.group_key,
712 self.grouping_sets,
713 self.input,
714 self.enable_two_phase,
715 )
716 }
717
718 pub fn fields_pretty<'a>(&self) -> StrAssocArr<'a> {
719 let last = ("aggs", self.agg_calls_pretty());
720 if !self.group_key.is_empty() {
721 let first = ("group_key", self.group_key_pretty());
722 vec![first, last]
723 } else {
724 vec![last]
725 }
726 }
727
728 fn agg_calls_pretty<'a>(&self) -> Pretty<'a> {
729 let f = |plan_agg_call| {
730 Pretty::debug(&PlanAggCallDisplay {
731 plan_agg_call,
732 input_schema: self.input.schema(),
733 })
734 };
735 Pretty::Array(self.agg_calls.iter().map(f).collect())
736 }
737
738 fn group_key_pretty<'a>(&self) -> Pretty<'a> {
739 let f = |i| Pretty::display(&FieldDisplay(self.input.schema().fields.get(i).unwrap()));
740 Pretty::Array(self.group_key.indices().map(f).collect())
741 }
742}
743
// NOTE(review): the macro is defined outside this file; judging by its name it presumably
// derives the distill (pretty-print) implementation for `Agg` from `fields_pretty` — confirm
// against the macro definition.
impl_distill_unit_from_fields!(Agg, GenericPlanRef);
745
/// An aggregate call as used inside plan nodes: its arguments are references to input columns
/// (`InputRef`) rather than arbitrary expressions.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct PlanAggCall {
    /// Type of the aggregate function.
    pub agg_type: AggType,

    /// Data type of the value returned by the call; used as the type of its output column.
    pub return_type: DataType,

    /// References to the input columns passed as arguments. May be empty (see `count_star`).
    /// For `DISTINCT` calls the first input is treated as the distinct column.
    pub inputs: Vec<InputRef>,

    /// Whether the call is `DISTINCT`.
    pub distinct: bool,
    /// `ORDER BY` of the call, expressed on input column indices.
    pub order_by: Vec<ColumnOrder>,
    /// Filter condition attached to the call; rendered as `filter(...)` unless always true.
    pub filter: Condition,
    /// Literal arguments of the call, serialized as constants in `to_protobuf`.
    pub direct_args: Vec<Literal>,
}
774
775impl fmt::Debug for PlanAggCall {
776 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
777 write!(f, "{}", self.agg_type)?;
778 if !self.inputs.is_empty() {
779 write!(f, "(")?;
780 for (idx, input) in self.inputs.iter().enumerate() {
781 if idx == 0 && self.distinct {
782 write!(f, "distinct ")?;
783 }
784 write!(f, "{:?}", input)?;
785 if idx != (self.inputs.len() - 1) {
786 write!(f, ",")?;
787 }
788 }
789 if !self.order_by.is_empty() {
790 let clause_text = self.order_by.iter().map(|e| format!("{:?}", e)).join(", ");
791 write!(f, " order_by({})", clause_text)?;
792 }
793 write!(f, ")")?;
794 }
795 if !self.filter.always_true() {
796 write!(
797 f,
798 " filter({:?})",
799 self.filter.as_expr_unless_true().unwrap()
800 )?;
801 }
802 Ok(())
803 }
804}
805
806impl PlanAggCall {
807 pub fn rewrite_input_index(&mut self, mapping: ColIndexMapping) {
808 self.inputs.iter_mut().for_each(|x| {
810 x.index = mapping.map(x.index);
811 });
812
813 self.order_by.iter_mut().for_each(|x| {
815 x.column_index = mapping.map(x.column_index);
816 });
817
818 let mut rewriter = IndexRewriter::new(mapping);
820 self.filter.conjunctions.iter_mut().for_each(|x| {
821 *x = rewriter.rewrite_expr(x.clone());
822 });
823 }
824
825 pub fn to_protobuf(&self) -> PbAggCall {
826 PbAggCall {
827 kind: match &self.agg_type {
828 AggType::Builtin(kind) => *kind,
829 AggType::UserDefined(_) => PbAggKind::UserDefined,
830 AggType::WrapScalar(_) => PbAggKind::WrapScalar,
831 }
832 .into(),
833 return_type: Some(self.return_type.to_protobuf()),
834 args: self.inputs.iter().map(InputRef::to_proto).collect(),
835 distinct: self.distinct,
836 order_by: self
837 .order_by
838 .iter()
839 .copied()
840 .map(ColumnOrder::to_protobuf)
841 .collect(),
842 filter: self.filter.as_expr_unless_true().map(|x| x.to_expr_proto()),
843 direct_args: self
844 .direct_args
845 .iter()
846 .map(|x| PbConstant {
847 datum: Some(x.get_data().to_protobuf()),
848 r#type: Some(x.return_type().to_protobuf()),
849 })
850 .collect(),
851 udf: match &self.agg_type {
852 AggType::UserDefined(udf) => Some(udf.clone()),
853 _ => None,
854 },
855 scalar: match &self.agg_type {
856 AggType::WrapScalar(expr) => Some(expr.clone()),
857 _ => None,
858 },
859 }
860 }
861
862 pub fn partial_to_total_agg_call(&self, partial_output_idx: usize) -> PlanAggCall {
863 let total_agg_type = self
864 .agg_type
865 .partial_to_total()
866 .expect("unsupported kinds shouldn't get here");
867 PlanAggCall {
868 agg_type: total_agg_type,
869 inputs: vec![InputRef::new(partial_output_idx, self.return_type.clone())],
870 order_by: vec![], filter: Condition::true_cond(),
872 ..self.clone()
873 }
874 }
875
876 pub fn count_star() -> Self {
877 PlanAggCall {
878 agg_type: PbAggKind::Count.into(),
879 return_type: DataType::Int64,
880 inputs: vec![],
881 distinct: false,
882 order_by: vec![],
883 filter: Condition::true_cond(),
884 direct_args: vec![],
885 }
886 }
887
888 pub fn with_condition(mut self, filter: Condition) -> Self {
889 self.filter = filter;
890 self
891 }
892
893 pub fn input_indices(&self) -> Vec<usize> {
894 self.inputs.iter().map(|input| input.index()).collect()
895 }
896}
897
/// Debug-formatting adapter for a [`PlanAggCall`] that renders its inputs, `ORDER BY` columns
/// and filter against the schema of the aggregation's input.
pub struct PlanAggCallDisplay<'a> {
    /// The aggregate call to render.
    pub plan_agg_call: &'a PlanAggCall,
    /// Schema the call's column indices refer to.
    pub input_schema: &'a Schema,
}
902
903impl fmt::Debug for PlanAggCallDisplay<'_> {
904 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
905 let that = self.plan_agg_call;
906 write!(f, "{}", that.agg_type)?;
907 if !that.inputs.is_empty() {
908 write!(f, "(")?;
909 for (idx, input) in that.inputs.iter().enumerate() {
910 if idx == 0 && that.distinct {
911 write!(f, "distinct ")?;
912 }
913 write!(
914 f,
915 "{}",
916 InputRefDisplay {
917 input_ref: input,
918 input_schema: self.input_schema
919 }
920 )?;
921 if idx != (that.inputs.len() - 1) {
922 write!(f, ", ")?;
923 }
924 }
925 if !that.order_by.is_empty() {
926 write!(
927 f,
928 " order_by({})",
929 that.order_by.iter().format_with(", ", |o, f| {
930 f(&ColumnOrderDisplay {
931 column_order: o,
932 input_schema: self.input_schema,
933 })
934 })
935 )?;
936 }
937 write!(f, ")")?;
938 }
939
940 if !that.filter.always_true() {
941 write!(
942 f,
943 " filter({:?})",
944 ConditionDisplay {
945 condition: &that.filter,
946 input_schema: self.input_schema,
947 }
948 )?;
949 }
950 Ok(())
951 }
952}