risingwave_frontend/optimizer/plan_node/batch_sort_agg.rs

use risingwave_pb::batch_plan::SortAggNode;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::expr::ExprNode;

use super::batch::prelude::*;
use super::generic::{self, PlanAggCall};
use super::utils::impl_distill_by_unit;
use super::{
    BatchPlanRef as PlanRef, ExprRewritable, PlanBase, PlanTreeNodeUnary, ToBatchPb,
    ToDistributedBatch,
};
use crate::error::Result;
use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprVisitor, InputRef};
use crate::optimizer::plan_node::ToLocalBatch;
use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
use crate::optimizer::property::{Order, RequiredDist};
use crate::utils::{ColIndexMappingRewriteExt, IndexSet};

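/// Sort-based aggregation for batch queries. The input must be sorted on all
/// group-key columns, so each group arrives contiguously and can be folded in
/// a single pass, with no hash table.
///
/// A hypothetical query shape that could plan to this node, assuming the scan
/// of `t` already provides an order on `v1`:
///
/// ```text
/// SELECT v1, sum(v2) FROM t GROUP BY v1
///
/// BatchSortAgg { group_key: [v1], aggs: [sum(v2)] }
/// └─ input ordered on [v1]
/// ```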
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BatchSortAgg {
    pub base: PlanBase<Batch>,
    core: generic::Agg<PlanRef>,
    input_order: Order,
}

impl BatchSortAgg {
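    /// Builds the node and derives its properties from the input: distribution
    /// and order are mapped from input to output columns via `i2o_col_mapping`,
    /// and `input_order` records exactly the input order columns that are group
    /// keys, i.e. the ordering this node relies on.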
    pub fn new(core: generic::Agg<PlanRef>) -> Self {
        // Sort agg only applies when there is a group key and the input
        // already provides an order on every group-key column.
        assert!(!core.group_key.is_empty());
        assert!(core.input_provides_order_on_group_keys());

        let input = core.input.clone();
        let input_dist = input.distribution();
        // Map the input's distribution onto the output columns.
        let dist = core
            .i2o_col_mapping()
            .rewrite_provided_distribution(input_dist);
        // Of the input's order columns, keep those that are group keys; this
        // is the ordering the node actually relies on.
        let input_order = Order {
            column_orders: input
                .order()
                .column_orders
                .iter()
                .filter(|o| core.group_key.indices().any(|g_k| g_k == o.column_index))
                .cloned()
                .collect(),
        };

        // The node provides the same order, expressed in output columns.
        let order = core.i2o_col_mapping().rewrite_provided_order(&input_order);

        let base = PlanBase::new_batch_with_core(&core, dist, order);

        BatchSortAgg {
            base,
            core,
            input_order,
        }
    }

    pub fn agg_calls(&self) -> &[PlanAggCall] {
        &self.core.agg_calls
    }

    pub fn group_key(&self) -> &IndexSet {
        &self.core.group_key
    }
}

impl PlanTreeNodeUnary<Batch> for BatchSortAgg {
    fn input(&self) -> PlanRef {
        self.core.input.clone()
    }

    fn clone_with_input(&self, input: PlanRef) -> Self {
        let mut core = self.core.clone();
        core.input = input;
        // Re-run `new` so distribution and order are rederived for the new input.
        Self::new(core)
    }
}
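// `impl_plan_tree_node_for_unary!` hooks the node into the generic plan-tree
// machinery; `impl_distill_by_unit!` renders the core's fields under the name
// "BatchSortAgg" in plan pretty-printing (e.g. `EXPLAIN` output).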
impl_plan_tree_node_for_unary! { Batch, BatchSortAgg }
impl_distill_by_unit!(BatchSortAgg, core, "BatchSortAgg");

impl ToDistributedBatch for BatchSortAgg {
    fn to_distributed(&self) -> Result<PlanRef> {
        // Require the input to be hash-sharded on the group keys and to keep
        // the group-key ordering, so each partition can sort-aggregate its
        // groups independently.
        let new_input = self.input().to_distributed_with_required(
            &self.input_order,
            &RequiredDist::shard_by_key(self.input().schema().len(), &self.group_key().to_vec()),
        )?;
        Ok(self.clone_with_input(new_input).into())
    }
}

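// Serialization into the `SortAggNode` protobuf consumed by the batch executor.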
impl ToBatchPb for BatchSortAgg {
    fn to_batch_prost_body(&self) -> NodeBody {
        let input = self.input();
        NodeBody::SortAgg(SortAggNode {
            agg_calls: self
                .agg_calls()
                .iter()
                .map(PlanAggCall::to_protobuf)
                .collect(),
            // Group keys are encoded as `InputRef` expressions, typed from the
            // input schema.
            group_key: self
                .group_key()
                .indices()
                .map(|idx| {
                    ExprImpl::InputRef(InputRef::new(idx, input.schema()[idx].data_type()).into())
                })
                .map(|expr| expr.to_expr_proto())
                .collect::<Vec<ExprNode>>(),
        })
    }
}

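// Conversion for local execution: everything is gathered to a single partition.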
impl ToLocalBatch for BatchSortAgg {
    fn to_local(&self) -> Result<PlanRef> {
        let new_input = self.input().to_local()?;

        // Enforce a singleton distribution; the required order is passed along
        // so that any inserted exchange preserves the input's sort order.
        let new_input = RequiredDist::single()
            .batch_enforce_if_not_satisfies(new_input, self.input().order())?;

        Ok(self.clone_with_input(new_input).into())
    }
}

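// Expression rewriting and visiting both delegate to the logical core; after a
// rewrite the node is rebuilt via `Self::new`, which rederives its properties.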
impl ExprRewritable<Batch> for BatchSortAgg {
    fn has_rewritable_expr(&self) -> bool {
        true
    }

    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
        let mut new_logical = self.core.clone();
        new_logical.rewrite_exprs(r);
        Self::new(new_logical).into()
    }
}

impl ExprVisitable for BatchSortAgg {
    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
        self.core.visit_exprs(v);
    }
}