1use std::collections::{BTreeMap, BTreeSet, HashMap};
16use std::{fmt, vec};
17
18use itertools::{Either, Itertools};
19use pretty_xmlish::{Pretty, StrAssocArr};
20use risingwave_common::catalog::{Field, FieldDisplay, Schema};
21use risingwave_common::types::DataType;
22use risingwave_common::util::iter_util::ZipEqFast;
23use risingwave_common::util::sort_util::{ColumnOrder, ColumnOrderDisplay, OrderType};
24use risingwave_common::util::value_encoding::DatumToProtoExt;
25use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};
26use risingwave_expr::sig::{FUNCTION_REGISTRY, FuncBuilder};
27use risingwave_pb::expr::{PbAggCall, PbConstant};
28use risingwave_pb::stream_plan::{AggCallState as PbAggCallState, agg_call_state};
29
30use super::super::utils::TableCatalogBuilder;
31use super::{GenericPlanNode, GenericPlanRef, PhysicalPlanRef, impl_distill_unit_from_fields};
32use crate::TableCatalog;
33use crate::error::{ErrorCode, Result};
34use crate::expr::{Expr, ExprRewriter, ExprVisitor, InputRef, InputRefDisplay, Literal};
35use crate::optimizer::optimizer_context::OptimizerContextRef;
36use crate::optimizer::plan_node::batch::BatchPlanNodeMetadata;
37use crate::optimizer::plan_node::{BatchPlanRef, StreamPlanNodeMetadata, StreamPlanRef};
38use crate::optimizer::property::{
39 Distribution, FunctionalDependencySet, RequiredDist, WatermarkColumns,
40};
41use crate::stream_fragmenter::BuildFragmentGraphState;
42use crate::utils::{
43 ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay, IndexRewriter,
44 IndexSet,
45};
46
/// Generic aggregation plan node, parameterized over the plan reference type `PlanRef`.
///
/// Output columns are the group-key columns (taken from the input, in `group_key` order)
/// followed by one column per entry of `agg_calls`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Agg<PlanRef> {
    /// Aggregate calls; their results follow the group-key columns in the output.
    pub agg_calls: Vec<PlanAggCall>,
    /// Indices of the input columns to group by.
    pub group_key: IndexSet,
    /// Grouping sets; empty when built with `new` unless set via `with_grouping_sets`.
    pub grouping_sets: Vec<IndexSet>,
    /// The input plan.
    pub input: PlanRef,
    /// Whether two-phase aggregation may be used. `new` initializes it from the session
    /// config's `enable_two_phase_agg`.
    pub enable_two_phase: bool,
}
61
62impl<PlanRef: GenericPlanRef> Agg<PlanRef> {
63 pub(crate) fn clone_with_input<OtherPlanRef>(&self, input: OtherPlanRef) -> Agg<OtherPlanRef> {
64 Agg {
65 agg_calls: self.agg_calls.clone(),
66 group_key: self.group_key.clone(),
67 grouping_sets: self.grouping_sets.clone(),
68 input,
69 enable_two_phase: self.enable_two_phase,
70 }
71 }
72
73 pub(crate) fn rewrite_exprs(&mut self, r: &mut dyn ExprRewriter) {
74 self.agg_calls.iter_mut().for_each(|call| {
75 call.filter = call.filter.clone().rewrite_expr(r);
76 });
77 }
78
79 pub(crate) fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
80 self.agg_calls.iter().for_each(|call| {
81 call.filter.visit_expr(v);
82 });
83 }
84
85 pub(crate) fn output_len(&self) -> usize {
86 self.group_key.len() + self.agg_calls.len()
87 }
88
89 pub fn o2i_col_mapping(&self) -> ColIndexMapping {
92 let mut map = vec![None; self.output_len()];
93 for (i, key) in self.group_key.indices().enumerate() {
94 map[i] = Some(key);
95 }
96 ColIndexMapping::new(map, self.input.schema().len())
97 }
98
99 pub fn i2o_col_mapping(&self) -> ColIndexMapping {
101 let mut map = vec![None; self.input.schema().len()];
102 for (i, key) in self.group_key.indices().enumerate() {
103 map[key] = Some(i);
104 }
105 ColIndexMapping::new(map, self.output_len())
106 }
107
108 fn two_phase_agg_forced(&self) -> bool {
109 self.ctx().session_ctx().config().force_two_phase_agg()
110 }
111
112 pub fn two_phase_agg_enabled(&self) -> bool {
113 self.enable_two_phase
114 }
115
116 pub(crate) fn can_two_phase_agg(&self) -> bool {
117 self.two_phase_agg_enabled()
118 && !self.agg_calls.is_empty()
119 && self.agg_calls.iter().all(|call| {
120 let agg_type_ok = !matches!(call.agg_type, agg_types::simply_cannot_two_phase!());
121 let order_ok = matches!(
122 call.agg_type,
123 agg_types::result_unaffected_by_order_by!()
124 | AggType::Builtin(PbAggKind::ApproxPercentile)
125 ) || call.order_by.is_empty();
126 let distinct_ok =
127 matches!(call.agg_type, agg_types::result_unaffected_by_distinct!())
128 || !call.distinct;
129 agg_type_ok && order_ok && distinct_ok
130 })
131 }
132
133 pub(crate) fn must_try_two_phase_agg(&self) -> bool {
135 self.two_phase_agg_forced() && self.can_two_phase_agg()
136 }
137
138 pub(crate) fn hash_agg_dist_satisfied_by_input_dist(&self, input_dist: &Distribution) -> bool {
142 let required_dist =
143 RequiredDist::shard_by_key(self.input.schema().len(), &self.group_key.to_vec());
144 input_dist.satisfies(&required_dist)
145 }
146
147 pub(crate) fn all_local_aggs_are_stateless(&self, stream_input_append_only: bool) -> bool {
149 self.agg_calls.iter().all(|c| {
150 matches!(c.agg_type, agg_types::single_value_state!())
151 || (matches!(c.agg_type, agg_types::single_value_state_iff_in_append_only!() if stream_input_append_only))
152 })
153 }
154
155 pub(crate) fn eowc_window_column(
156 &self,
157 input_watermark_columns: &WatermarkColumns,
158 ) -> Result<usize> {
159 let group_key_with_wtmk = self
160 .group_key
161 .indices()
162 .filter_map(|idx| {
163 input_watermark_columns
164 .get_group(idx)
165 .map(|group| (idx, group))
166 })
167 .collect::<Vec<_>>();
168
169 if group_key_with_wtmk.is_empty() {
170 return Err(ErrorCode::NotSupported(
171 "Emit-On-Window-Close mode requires a watermark column in GROUP BY.".to_owned(),
172 "Please try to GROUP BY a watermark column".to_owned(),
173 )
174 .into());
175 }
176 if group_key_with_wtmk.len() == 1
177 || group_key_with_wtmk
178 .iter()
179 .map(|(_, group)| group)
180 .all_equal()
181 {
182 return Ok(group_key_with_wtmk[0].0);
185 }
186 Err(ErrorCode::NotSupported(
187 "Emit-On-Window-Close mode requires that watermark columns in GROUP BY are derived from the same upstream column.".to_owned(),
188 "Please try to remove undesired columns from GROUP BY".to_owned(),
189 )
190 .into())
191 }
192
193 pub fn new(agg_calls: Vec<PlanAggCall>, group_key: IndexSet, input: PlanRef) -> Self {
194 let enable_two_phase = input.ctx().session_ctx().config().enable_two_phase_agg();
195 Self {
196 agg_calls,
197 group_key,
198 input,
199 grouping_sets: vec![],
200 enable_two_phase,
201 }
202 }
203
204 pub fn with_grouping_sets(mut self, grouping_sets: Vec<IndexSet>) -> Self {
205 self.grouping_sets = grouping_sets;
206 self
207 }
208
209 pub fn with_enable_two_phase(mut self, enable_two_phase: bool) -> Self {
210 self.enable_two_phase = enable_two_phase;
211 self
212 }
213}
214
215impl Agg<BatchPlanRef> {
216 pub(crate) fn input_provides_order_on_group_keys(&self) -> bool {
218 let mut input_order_prefix = IndexSet::empty();
219 for input_order_col in &self.input.order().column_orders {
220 if !self.group_key.contains(input_order_col.column_index) {
221 break;
222 }
223 input_order_prefix.insert(input_order_col.column_index);
224 }
225 self.group_key == input_order_prefix
226 }
227}
228
229impl<PlanRef: GenericPlanRef> GenericPlanNode for Agg<PlanRef> {
230 fn schema(&self) -> Schema {
231 let fields = self
232 .group_key
233 .indices()
234 .map(|i| self.input.schema().fields()[i].clone())
235 .chain(self.agg_calls.iter().map(|agg_call| {
236 let plan_agg_call_display = PlanAggCallDisplay {
237 plan_agg_call: agg_call,
238 input_schema: self.input.schema(),
239 };
240 let name = format!("{:?}", plan_agg_call_display);
241 Field::with_name(agg_call.return_type.clone(), name)
242 }))
243 .collect();
244 Schema { fields }
245 }
246
247 fn stream_key(&self) -> Option<Vec<usize>> {
248 Some((0..self.group_key.len()).collect())
249 }
250
251 fn ctx(&self) -> OptimizerContextRef {
252 self.input.ctx()
253 }
254
255 fn functional_dependency(&self) -> FunctionalDependencySet {
256 let output_len = self.output_len();
257 let _input_len = self.input.schema().len();
258 let mut fd_set =
259 FunctionalDependencySet::with_key(output_len, &(0..self.group_key.len()).collect_vec());
260 let i2o = self.i2o_col_mapping();
262 for fd in self.input.functional_dependency().as_dependencies() {
263 if let Some(fd) = i2o.rewrite_functional_dependency(fd) {
264 fd_set.add_functional_dependency(fd);
265 }
266 }
267 fd_set
268 }
269}
270
/// The state kept for one agg call of a streaming aggregation.
pub enum AggCallState {
    /// Used for agg types matched by `single_value_state!()`, and for
    /// `single_value_state_iff_in_append_only!()` types when the input is append-only.
    /// `into_prost` serializes it as an empty `ValueState`.
    Value,
    /// Input rows (or the needed columns of them) are kept in a dedicated state table.
    MaterializedInput(Box<MaterializedInputState>),
}
275
276impl AggCallState {
277 pub fn into_prost(self, state: &mut BuildFragmentGraphState) -> PbAggCallState {
278 PbAggCallState {
279 inner: Some(match self {
280 AggCallState::Value => {
281 agg_call_state::Inner::ValueState(agg_call_state::ValueState {})
282 }
283 AggCallState::MaterializedInput(s) => {
284 agg_call_state::Inner::MaterializedInputState(
285 agg_call_state::MaterializedInputState {
286 table: Some(
287 s.table
288 .with_id(state.gen_table_id_wrapped())
289 .to_internal_table_prost(),
290 ),
291 included_upstream_indices: s
292 .included_upstream_indices
293 .into_iter()
294 .map(|x| x as _)
295 .collect(),
296 table_value_indices: s
297 .table_value_indices
298 .into_iter()
299 .map(|x| x as _)
300 .collect(),
301 order_columns: s
302 .order_columns
303 .into_iter()
304 .map(|x| x.to_protobuf())
305 .collect(),
306 },
307 )
308 }
309 }),
310 }
311 }
312}
313
/// A materialized-input state table for one agg call, plus how it relates to the input.
pub struct MaterializedInputState {
    /// The state table catalog. `AggCallState::into_prost` assigns its table id.
    pub table: TableCatalog,
    /// Upstream (input) column indices stored in the table, listed in table-column order.
    pub included_upstream_indices: Vec<usize>,
    /// Indices of the table's value columns, ascending and deduplicated.
    pub table_value_indices: Vec<usize>,
    /// Order columns added after the group key, expressed over upstream column indices.
    pub order_columns: Vec<ColumnOrder>,
}
320
impl Agg<StreamPlanRef> {
    /// Infers every state table used by a streaming aggregation.
    ///
    /// Returns, in order:
    /// - the intermediate state table (`infer_intermediate_state_table`),
    /// - one `AggCallState` per agg call (`infer_stream_agg_state`),
    /// - the distinct dedup tables keyed by distinct column (`infer_distinct_dedup_tables`).
    ///
    /// `vnode_col_idx` and `window_col_idx` are input column indices and are forwarded
    /// unchanged to each helper.
    pub fn infer_tables(
        &self,
        me: impl StreamPlanNodeMetadata,
        vnode_col_idx: Option<usize>,
        window_col_idx: Option<usize>,
    ) -> (
        TableCatalog,
        Vec<AggCallState>,
        HashMap<usize, TableCatalog>,
    ) {
        (
            self.infer_intermediate_state_table(&me, vnode_col_idx, window_col_idx),
            self.infer_stream_agg_state(&me, vnode_col_idx, window_col_idx),
            self.infer_distinct_dedup_tables(&me, vnode_col_idx, window_col_idx),
        )
    }

    /// Returns the group-key input column indices in the order used as the state tables'
    /// primary-key prefix.
    ///
    /// If `window_col_idx` is given, it must be a group-key column (asserted). It is moved to
    /// the front, and the remaining group-key columns follow in `group_key` order. Otherwise
    /// the group key is returned in its own order.
    fn get_ordered_group_key(&self, window_col_idx: Option<usize>) -> Vec<usize> {
        if let Some(window_col_idx) = window_col_idx {
            assert!(self.group_key.contains(window_col_idx));
            Either::Left(
                std::iter::once(window_col_idx).chain(
                    self.group_key
                        .indices()
                        .filter(move |&i| i != window_col_idx),
                ),
            )
        } else {
            Either::Right(self.group_key.indices())
        }
        .collect()
    }

    /// Creates a table builder pre-populated with the group-key columns.
    ///
    /// One column is added per group-key column, in `group_key` order. The primary key is
    /// those columns, all ascending, in the order from `get_ordered_group_key` (window column
    /// first when `window_col_idx` is set).
    ///
    /// Returns the builder, the upstream column indices added so far (the group key), and a
    /// map from upstream column index to table column index.
    ///
    /// NOTE(review): `_ctx` is currently unused.
    fn create_table_builder(
        &self,
        _ctx: OptimizerContextRef,
        window_col_idx: Option<usize>,
    ) -> (TableCatalogBuilder, Vec<usize>, BTreeMap<usize, usize>) {
        let mut table_builder = TableCatalogBuilder::default();

        // The builder is fresh: no columns and no primary key yet.
        assert!(table_builder.columns().is_empty());
        assert_eq!(table_builder.get_current_pk_len(), 0);

        // Upstream column indices stored in the table, in table-column order.
        let mut included_upstream_indices = vec![];
        // Upstream column index -> table column index.
        let mut column_mapping = BTreeMap::new();
        let in_fields = self.input.schema().fields();
        for idx in self.group_key.indices() {
            let tbl_col_idx = table_builder.add_column(&in_fields[idx]);
            included_upstream_indices.push(idx);
            column_mapping.insert(idx, tbl_col_idx);
        }

        // Primary key: the group-key columns, ascending, window column (if any) first.
        let ordered_group_key = self.get_ordered_group_key(window_col_idx);
        for idx in ordered_group_key {
            table_builder.add_order_column(column_mapping[&idx], OrderType::ascending());
        }

        (table_builder, included_upstream_indices, column_mapping)
    }

    /// Infers the state kept for each agg call of a streaming aggregation.
    ///
    /// - `AggCallState::Value` for `single_value_state!()` types, and for
    ///   `single_value_state_iff_in_append_only!()` types when the input is append-only.
    /// - A `MaterializedInputState` table for `materialized_input_state!()` types. The
    ///   layout is described at `gen_materialized_input_state` below.
    ///
    /// # Panics
    /// On agg types that should already have been rewritten (`rewritten!()`), and on the
    /// `Unspecified` / `UserDefined` / `WrapScalar` builtin kinds.
    pub fn infer_stream_agg_state(
        &self,
        me: impl StreamPlanNodeMetadata,
        vnode_col_idx: Option<usize>,
        window_col_idx: Option<usize>,
    ) -> Vec<AggCallState> {
        let in_fields = self.input.schema().fields().to_vec();
        let in_pks = self.input.expect_stream_key().to_vec();
        let in_append_only = self.input.append_only();
        let in_dist_key = self.input.distribution().dist_column_indices().to_vec();

        // Builds one materialized-input state table.
        //
        // Columns are: the group key (from `create_table_builder`), then `sort_keys`,
        // `extra_keys` and `include_keys`, in that order. Each upstream column appears at most
        // once. A column that is already present (e.g. a group-key column) is reused and does
        // not get another order column.
        let gen_materialized_input_state = |sort_keys: Vec<(OrderType, usize)>,
                                            extra_keys: Vec<usize>,
                                            include_keys: Vec<usize>|
         -> MaterializedInputState {
            let (mut table_builder, mut included_upstream_indices, mut column_mapping) =
                self.create_table_builder(me.ctx(), window_col_idx);
            // Read prefix = the group-key part of the primary key.
            let read_prefix_len_hint = table_builder.get_current_pk_len();

            let mut order_columns = vec![];
            let mut table_value_indices = BTreeSet::new();
            // Adds `upstream_idx` unless already present. A newly added column with
            // `Some(order_type)` also joins the primary key and `order_columns`. In all
            // cases its table column is recorded as a value column.
            let mut add_column =
                |upstream_idx, order_type, table_builder: &mut TableCatalogBuilder| {
                    column_mapping.entry(upstream_idx).or_insert_with(|| {
                        let table_col_idx = table_builder.add_column(&in_fields[upstream_idx]);
                        if let Some(order_type) = order_type {
                            table_builder.add_order_column(table_col_idx, order_type);
                            order_columns.push(ColumnOrder::new(upstream_idx, order_type));
                        }
                        included_upstream_indices.push(upstream_idx);
                        table_col_idx
                    });
                    table_value_indices.insert(column_mapping[&upstream_idx]);
                };

            for (order_type, idx) in sort_keys {
                add_column(idx, Some(order_type), &mut table_builder);
            }
            for idx in extra_keys {
                add_column(idx, Some(OrderType::ascending()), &mut table_builder);
            }
            for idx in include_keys {
                add_column(idx, None, &mut table_builder);
            }

            // Maps upstream column indices to their position in this table.
            let mapping =
                ColIndexMapping::with_included_columns(&included_upstream_indices, in_fields.len());
            let tb_dist = mapping.rewrite_dist_key(&in_dist_key);
            if let Some(tb_vnode_idx) = vnode_col_idx.and_then(|idx| mapping.try_map(idx)) {
                table_builder.set_vnode_col_idx(tb_vnode_idx);
            }

            // `BTreeSet` yields the value columns sorted and deduplicated.
            let table_value_indices = table_value_indices.into_iter().collect_vec();
            table_builder.set_value_indices(table_value_indices.clone());

            MaterializedInputState {
                // NOTE(review): an empty distribution key is used when `rewrite_dist_key`
                // returns `None` — presumably when the input dist key is not fully present
                // in this table; confirm.
                table: table_builder.build(tb_dist.unwrap_or_default(), read_prefix_len_hint),
                included_upstream_indices,
                table_value_indices,
                order_columns,
            }
        };

        self.agg_calls
            .iter()
            .map(|agg_call| match agg_call.agg_type {
                agg_types::single_value_state_iff_in_append_only!() if in_append_only => {
                    AggCallState::Value
                }
                agg_types::single_value_state!() => AggCallState::Value,
                agg_types::materialized_input_state!() => {
                    // Sort order of the materialized input.
                    let sort_keys = {
                        match agg_call.agg_type {
                            // min / max sort on their first argument (ascending / descending).
                            AggType::Builtin(PbAggKind::Min) => {
                                vec![(OrderType::ascending(), agg_call.inputs[0].index)]
                            }
                            AggType::Builtin(PbAggKind::Max) => {
                                vec![(OrderType::descending(), agg_call.inputs[0].index)]
                            }
                            AggType::Builtin(
                                PbAggKind::FirstValue
                                | PbAggKind::LastValue
                                | PbAggKind::StringAgg
                                | PbAggKind::ArrayAgg
                                | PbAggKind::JsonbAgg
                                | PbAggKind::PercentileCont
                                | PbAggKind::PercentileDisc
                                | PbAggKind::Mode,
                            )
                            | AggType::WrapScalar(_) => {
                                if agg_call.order_by.is_empty() {
                                    me.ctx().warn_to_user(format!(
                                        "{} without ORDER BY may produce non-deterministic result",
                                        agg_call.agg_type,
                                    ));
                                }
                                // Sort by the call's ORDER BY; `last_value` reverses every
                                // direction.
                                agg_call
                                    .order_by
                                    .iter()
                                    .map(|o| {
                                        (
                                            if matches!(
                                                agg_call.agg_type,
                                                AggType::Builtin(PbAggKind::LastValue)
                                            ) {
                                                o.order_type.reverse()
                                            } else {
                                                o.order_type
                                            },
                                            o.column_index,
                                        )
                                    })
                                    .collect()
                            }
                            // Also sorted by ORDER BY, but without the non-determinism warning.
                            AggType::Builtin(PbAggKind::JsonbObjectAgg) => agg_call
                                .order_by
                                .iter()
                                .map(|o| (o.order_type, o.column_index))
                                .collect(),
                            _ => unreachable!(),
                        }
                    };

                    // Additional primary-key columns after the sort keys.
                    let extra_keys = if agg_call.distinct {
                        // DISTINCT: key by the distinct (first) argument.
                        let distinct_key = agg_call.inputs[0].index;
                        vec![distinct_key]
                    } else {
                        // Otherwise append the input stream key.
                        in_pks.clone()
                    };

                    // These kinds also store all of their arguments in the table.
                    let include_keys = match agg_call.agg_type {
                        AggType::Builtin(
                            PbAggKind::FirstValue
                            | PbAggKind::LastValue
                            | PbAggKind::StringAgg
                            | PbAggKind::ArrayAgg
                            | PbAggKind::JsonbAgg
                            | PbAggKind::JsonbObjectAgg
                            | PbAggKind::PercentileCont
                            | PbAggKind::PercentileDisc
                            | PbAggKind::Mode,
                        )
                        | AggType::WrapScalar(_) => {
                            agg_call.inputs.iter().map(|i| i.index).collect()
                        }
                        _ => vec![],
                    };

                    let state = gen_materialized_input_state(sort_keys, extra_keys, include_keys);
                    AggCallState::MaterializedInput(Box::new(state))
                }
                agg_types::rewritten!() => {
                    unreachable!("should have been rewritten")
                }
                AggType::Builtin(
                    PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar,
                ) => {
                    unreachable!("invalid agg kind")
                }
            })
            .collect()
    }

    /// Infers the intermediate state table.
    ///
    /// The table has one column per output column of `me`: the group key first (primary key,
    /// built by `create_table_builder`), then one value column per agg call. Each agg column
    /// starts with the output field's type and is then adjusted:
    /// - user-defined aggregates: `Bytea`;
    /// - wrapped scalars: unchanged;
    /// - builtins: the signature's append-only state type (append-only input and signature)
    ///   or retractable state type (non-append-only signature), when one is registered.
    pub fn infer_intermediate_state_table(
        &self,
        me: impl GenericPlanRef,
        vnode_col_idx: Option<usize>,
        window_col_idx: Option<usize>,
    ) -> TableCatalog {
        let mut out_fields = me.schema().fields().to_vec();

        let in_append_only = self.input.append_only();
        // NOTE(review): `zip_eq_fast` presumably requires exactly one output field per agg
        // call after the group key — confirm against `ZipEqFast`.
        for (agg_call, field) in self
            .agg_calls
            .iter()
            .zip_eq_fast(&mut out_fields[self.group_key.len()..])
        {
            let agg_kind = match agg_call.agg_type {
                AggType::UserDefined(_) => {
                    // User-defined aggregate state is stored as bytes.
                    field.data_type = DataType::Bytea;
                    continue;
                }
                AggType::WrapScalar(_) => {
                    // Keeps the output field's type.
                    continue;
                }
                AggType::Builtin(kind) => kind,
            };
            let sig = FUNCTION_REGISTRY
                .get(
                    agg_kind,
                    &agg_call
                        .inputs
                        .iter()
                        .map(|input| input.data_type.clone())
                        .collect_vec(),
                    &agg_call.return_type,
                )
                .expect("agg not found");
            match (in_append_only, sig.is_append_only()) {
                (false, true) => {
                    // Non-append-only input with an append-only signature: the field type is
                    // left unchanged.
                }
                (true, true) => {
                    // Use the append-only state type, if the signature declares one.
                    if let FuncBuilder::Aggregate {
                        append_only_state_type: Some(state_type),
                        ..
                    } = &sig.build
                    {
                        field.data_type = state_type.clone();
                    }
                }
                (_, false) => {
                    // Use the retractable state type, if the signature declares one.
                    if let FuncBuilder::Aggregate {
                        retractable_state_type: Some(state_type),
                        ..
                    } = &sig.build
                    {
                        field.data_type = state_type.clone();
                    }
                }
            }
        }
        let in_dist_key = self.input.distribution().dist_column_indices().to_vec();
        let n_group_key_cols = self.group_key.len();

        let (mut table_builder, _, _) = self.create_table_builder(me.ctx(), window_col_idx);
        let read_prefix_len_hint = table_builder.get_current_pk_len();

        // One value column per agg call.
        for field in out_fields.iter().skip(n_group_key_cols) {
            table_builder.add_column(field);
        }

        // The input-to-output mapping sends each group-key column to its position in the
        // group key. That matches its column here, since this table also starts with the group
        // key in `group_key` order.
        let mapping = self.i2o_col_mapping();
        let tb_dist = mapping.rewrite_dist_key(&in_dist_key).unwrap_or_default();
        if let Some(tb_vnode_idx) = vnode_col_idx.and_then(|idx| mapping.try_map(idx)) {
            table_builder.set_vnode_col_idx(tb_vnode_idx);
        }

        // Values are the agg columns only.
        table_builder.set_value_indices((n_group_key_cols..out_fields.len()).collect());
        table_builder.build(tb_dist, read_prefix_len_hint)
    }

    /// Infers the dedup tables for `DISTINCT` agg calls, keyed by distinct column.
    ///
    /// Distinct calls are grouped by their first argument, and each distinct column gets one
    /// table. The primary key is the group key followed by the distinct column (ascending).
    /// Then there is one `Int64` column named `count_for_agg_call_{i}` per agg call `i` that
    /// shares that distinct column. The count columns are the table's value columns.
    pub fn infer_distinct_dedup_tables(
        &self,
        me: impl GenericPlanRef,
        vnode_col_idx: Option<usize>,
        window_col_idx: Option<usize>,
    ) -> HashMap<usize, TableCatalog> {
        let in_dist_key = self.input.distribution().dist_column_indices().to_vec();
        let in_fields = self.input.schema().fields();

        self.agg_calls
            .iter()
            .enumerate()
            // Only distinct calls need a dedup table.
            .filter(|(_, call)| call.distinct)
            // One table per distinct column.
            .into_group_map_by(|(_, call)| call.inputs[0].index)
            .into_iter()
            .map(|(distinct_col, indices_and_calls)| {
                let (mut table_builder, mut key_cols, _) =
                    self.create_table_builder(me.ctx(), window_col_idx);
                let table_col_idx = table_builder.add_column(&in_fields[distinct_col]);
                table_builder.add_order_column(table_col_idx, OrderType::ascending());
                key_cols.push(distinct_col);

                // Read prefix = group key + distinct column.
                let read_prefix_len_hint = table_builder.get_current_pk_len();

                // One count column per agg call sharing this distinct column.
                for (call_index, _) in indices_and_calls {
                    table_builder.add_column(&Field {
                        data_type: DataType::Int64,
                        name: format!("count_for_agg_call_{}", call_index),
                    });
                }
                table_builder
                    .set_value_indices((key_cols.len()..table_builder.columns().len()).collect());

                // Maps upstream key columns to their position in this table.
                let mapping = ColIndexMapping::with_included_columns(&key_cols, in_fields.len());
                if let Some(idx) = vnode_col_idx.and_then(|idx| mapping.try_map(idx)) {
                    table_builder.set_vnode_col_idx(idx);
                }
                let dist_key = mapping.rewrite_dist_key(&in_dist_key).unwrap_or_default();
                let table = table_builder.build(dist_key, read_prefix_len_hint);
                (distinct_col, table)
            })
            .collect()
    }
}
709
710impl<PlanRef: GenericPlanRef> Agg<PlanRef> {
711 pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) {
712 (
713 self.agg_calls,
714 self.group_key,
715 self.grouping_sets,
716 self.input,
717 self.enable_two_phase,
718 )
719 }
720
721 pub fn fields_pretty<'a>(&self) -> StrAssocArr<'a> {
722 let last = ("aggs", self.agg_calls_pretty());
723 if !self.group_key.is_empty() {
724 let first = ("group_key", self.group_key_pretty());
725 vec![first, last]
726 } else {
727 vec![last]
728 }
729 }
730
731 fn agg_calls_pretty<'a>(&self) -> Pretty<'a> {
732 let f = |plan_agg_call| {
733 Pretty::debug(&PlanAggCallDisplay {
734 plan_agg_call,
735 input_schema: self.input.schema(),
736 })
737 };
738 Pretty::Array(self.agg_calls.iter().map(f).collect())
739 }
740
741 fn group_key_pretty<'a>(&self) -> Pretty<'a> {
742 let f = |i| Pretty::display(&FieldDisplay(self.input.schema().fields.get(i).unwrap()));
743 Pretty::Array(self.group_key.indices().map(f).collect())
744 }
745}
746
// NOTE(review): presumably generates the distill (EXPLAIN) implementation for `Agg` from
// `fields_pretty`; confirm against the `impl_distill_unit_from_fields!` definition.
impl_distill_unit_from_fields!(Agg, GenericPlanRef);
748
/// One aggregate call inside an aggregation plan node.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct PlanAggCall {
    /// The aggregate: a builtin kind, a user-defined aggregate, or a wrapped scalar
    /// expression.
    pub agg_type: AggType,

    /// Data type of the aggregate's result.
    pub return_type: DataType,

    /// Arguments, as references to input columns. For a `DISTINCT` call the first argument
    /// is the column being deduplicated.
    pub inputs: Vec<InputRef>,

    /// Whether the call is `DISTINCT`.
    pub distinct: bool,
    /// `ORDER BY` inside the call, over input columns.
    pub order_by: Vec<ColumnOrder>,
    /// `FILTER` condition; always-true when there is none.
    pub filter: Condition,
    /// Constant (literal) arguments, serialized as `PbConstant`s.
    pub direct_args: Vec<Literal>,
}
777
778impl fmt::Debug for PlanAggCall {
779 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
780 write!(f, "{}", self.agg_type)?;
781 if !self.inputs.is_empty() {
782 write!(f, "(")?;
783 for (idx, input) in self.inputs.iter().enumerate() {
784 if idx == 0 && self.distinct {
785 write!(f, "distinct ")?;
786 }
787 write!(f, "{:?}", input)?;
788 if idx != (self.inputs.len() - 1) {
789 write!(f, ",")?;
790 }
791 }
792 if !self.order_by.is_empty() {
793 let clause_text = self.order_by.iter().map(|e| format!("{:?}", e)).join(", ");
794 write!(f, " order_by({})", clause_text)?;
795 }
796 write!(f, ")")?;
797 }
798 if !self.filter.always_true() {
799 write!(
800 f,
801 " filter({:?})",
802 self.filter.as_expr_unless_true().unwrap()
803 )?;
804 }
805 Ok(())
806 }
807}
808
809impl PlanAggCall {
810 pub fn rewrite_input_index(&mut self, mapping: ColIndexMapping) {
811 self.inputs.iter_mut().for_each(|x| {
813 x.index = mapping.map(x.index);
814 });
815
816 self.order_by.iter_mut().for_each(|x| {
818 x.column_index = mapping.map(x.column_index);
819 });
820
821 let mut rewriter = IndexRewriter::new(mapping);
823 self.filter.conjunctions.iter_mut().for_each(|x| {
824 *x = rewriter.rewrite_expr(x.clone());
825 });
826 }
827
828 pub fn to_protobuf(&self) -> PbAggCall {
829 PbAggCall {
830 kind: match &self.agg_type {
831 AggType::Builtin(kind) => *kind,
832 AggType::UserDefined(_) => PbAggKind::UserDefined,
833 AggType::WrapScalar(_) => PbAggKind::WrapScalar,
834 }
835 .into(),
836 return_type: Some(self.return_type.to_protobuf()),
837 args: self.inputs.iter().map(InputRef::to_proto).collect(),
838 distinct: self.distinct,
839 order_by: self
840 .order_by
841 .iter()
842 .copied()
843 .map(ColumnOrder::to_protobuf)
844 .collect(),
845 filter: self.filter.as_expr_unless_true().map(|x| x.to_expr_proto()),
846 direct_args: self
847 .direct_args
848 .iter()
849 .map(|x| PbConstant {
850 datum: Some(x.get_data().to_protobuf()),
851 r#type: Some(x.return_type().to_protobuf()),
852 })
853 .collect(),
854 udf: match &self.agg_type {
855 AggType::UserDefined(udf) => Some(udf.clone()),
856 _ => None,
857 },
858 scalar: match &self.agg_type {
859 AggType::WrapScalar(expr) => Some(expr.clone()),
860 _ => None,
861 },
862 }
863 }
864
865 pub fn partial_to_total_agg_call(&self, partial_output_idx: usize) -> PlanAggCall {
866 let total_agg_type = self
867 .agg_type
868 .partial_to_total()
869 .expect("unsupported kinds shouldn't get here");
870 PlanAggCall {
871 agg_type: total_agg_type,
872 inputs: vec![InputRef::new(partial_output_idx, self.return_type.clone())],
873 order_by: vec![], filter: Condition::true_cond(),
875 ..self.clone()
876 }
877 }
878
879 pub fn count_star() -> Self {
880 PlanAggCall {
881 agg_type: PbAggKind::Count.into(),
882 return_type: DataType::Int64,
883 inputs: vec![],
884 distinct: false,
885 order_by: vec![],
886 filter: Condition::true_cond(),
887 direct_args: vec![],
888 }
889 }
890
891 pub fn with_condition(mut self, filter: Condition) -> Self {
892 self.filter = filter;
893 self
894 }
895
896 pub fn input_indices(&self) -> Vec<usize> {
897 self.inputs.iter().map(|input| input.index()).collect()
898 }
899}
900
/// `Debug` adapter that prints a `PlanAggCall` with its column references (arguments,
/// `ORDER BY` and `FILTER`) rendered against `input_schema`.
pub struct PlanAggCallDisplay<'a> {
    /// The call to print.
    pub plan_agg_call: &'a PlanAggCall,
    /// Schema of the aggregation's input, used to render column references.
    pub input_schema: &'a Schema,
}
905
906impl fmt::Debug for PlanAggCallDisplay<'_> {
907 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
908 let that = self.plan_agg_call;
909 write!(f, "{}", that.agg_type)?;
910 if !that.inputs.is_empty() {
911 write!(f, "(")?;
912 for (idx, input) in that.inputs.iter().enumerate() {
913 if idx == 0 && that.distinct {
914 write!(f, "distinct ")?;
915 }
916 write!(
917 f,
918 "{}",
919 InputRefDisplay {
920 input_ref: input,
921 input_schema: self.input_schema
922 }
923 )?;
924 if idx != (that.inputs.len() - 1) {
925 write!(f, ", ")?;
926 }
927 }
928 if !that.order_by.is_empty() {
929 write!(
930 f,
931 " order_by({})",
932 that.order_by.iter().format_with(", ", |o, f| {
933 f(&ColumnOrderDisplay {
934 column_order: o,
935 input_schema: self.input_schema,
936 })
937 })
938 )?;
939 }
940 write!(f, ")")?;
941 }
942
943 if !that.filter.always_true() {
944 write!(
945 f,
946 " filter({:?})",
947 ConditionDisplay {
948 condition: &that.filter,
949 input_schema: self.input_schema,
950 }
951 )?;
952 }
953 Ok(())
954 }
955}