risingwave_frontend/optimizer/plan_node/
batch_sort_agg.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use risingwave_pb::batch_plan::SortAggNode;
16use risingwave_pb::batch_plan::plan_node::NodeBody;
17use risingwave_pb::expr::ExprNode;
18
19use super::batch::prelude::*;
20use super::generic::{self, PlanAggCall};
21use super::utils::impl_distill_by_unit;
22use super::{
23    BatchPlanRef as PlanRef, ExprRewritable, PlanBase, PlanTreeNodeUnary, ToBatchPb,
24    ToDistributedBatch,
25};
26use crate::error::Result;
27use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprVisitor, InputRef};
28use crate::optimizer::plan_node::ToLocalBatch;
29use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
30use crate::optimizer::property::{Order, RequiredDist};
31use crate::utils::{ColIndexMappingRewriteExt, IndexSet};
32
/// Batch plan node for sort-based aggregation.
///
/// Construction (see `BatchSortAgg::new`) asserts that the group key is non-empty and that the
/// input is already ordered on the group-key columns.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BatchSortAgg {
    /// Common plan-node state, built from `core` together with the derived distribution and order.
    pub base: PlanBase<Batch>,
    /// The generic aggregation: input plan, group key and aggregate calls.
    core: generic::Agg<PlanRef>,
    /// The input's order restricted to the group-key columns; this is the order requested from
    /// the input when converting to a distributed plan.
    input_order: Order,
}
39
40impl BatchSortAgg {
41    pub fn new(core: generic::Agg<PlanRef>) -> Self {
42        assert!(!core.group_key.is_empty());
43        assert!(core.input_provides_order_on_group_keys());
44
45        let input = core.input.clone();
46        let input_dist = input.distribution();
47        let dist = core
48            .i2o_col_mapping()
49            .rewrite_provided_distribution(input_dist);
50        let input_order = Order {
51            column_orders: input
52                .order()
53                .column_orders
54                .iter()
55                .filter(|o| core.group_key.indices().any(|g_k| g_k == o.column_index))
56                .cloned()
57                .collect(),
58        };
59
60        let order = core.i2o_col_mapping().rewrite_provided_order(&input_order);
61
62        let base = PlanBase::new_batch_with_core(&core, dist, order);
63
64        BatchSortAgg {
65            base,
66            core,
67            input_order,
68        }
69    }
70
71    pub fn agg_calls(&self) -> &[PlanAggCall] {
72        &self.core.agg_calls
73    }
74
75    pub fn group_key(&self) -> &IndexSet {
76        &self.core.group_key
77    }
78}
79
80impl PlanTreeNodeUnary<Batch> for BatchSortAgg {
81    fn input(&self) -> PlanRef {
82        self.core.input.clone()
83    }
84
85    fn clone_with_input(&self, input: PlanRef) -> Self {
86        let mut core = self.core.clone();
87        core.input = input;
88        Self::new(core)
89    }
90}
// Generate the plan-tree-node plumbing for this single-input batch node.
impl_plan_tree_node_for_unary! { Batch, BatchSortAgg }
// Implement distill by delegating to `core`, labelled "BatchSortAgg"
// (presumably used for plan pretty-printing).
impl_distill_by_unit!(BatchSortAgg, core, "BatchSortAgg");
93
94impl ToDistributedBatch for BatchSortAgg {
95    fn to_distributed(&self) -> Result<PlanRef> {
96        let new_input = self.input().to_distributed_with_required(
97            &self.input_order,
98            &RequiredDist::shard_by_key(self.input().schema().len(), &self.group_key().to_vec()),
99        )?;
100        Ok(self.clone_with_input(new_input).into())
101    }
102}
103
104impl ToBatchPb for BatchSortAgg {
105    fn to_batch_prost_body(&self) -> NodeBody {
106        let input = self.input();
107        NodeBody::SortAgg(SortAggNode {
108            agg_calls: self
109                .agg_calls()
110                .iter()
111                .map(PlanAggCall::to_protobuf)
112                .collect(),
113            group_key: self
114                .group_key()
115                .indices()
116                .map(|idx| {
117                    ExprImpl::InputRef(InputRef::new(idx, input.schema()[idx].data_type()).into())
118                })
119                .map(|expr| expr.to_expr_proto())
120                .collect::<Vec<ExprNode>>(),
121        })
122    }
123}
124
125impl ToLocalBatch for BatchSortAgg {
126    fn to_local(&self) -> Result<PlanRef> {
127        let new_input = self.input().to_local()?;
128
129        let new_input = RequiredDist::single()
130            .batch_enforce_if_not_satisfies(new_input, self.input().order())?;
131
132        Ok(self.clone_with_input(new_input).into())
133    }
134}
135
136impl ExprRewritable<Batch> for BatchSortAgg {
137    fn has_rewritable_expr(&self) -> bool {
138        true
139    }
140
141    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
142        let mut new_logical = self.core.clone();
143        new_logical.rewrite_exprs(r);
144        Self::new(new_logical).into()
145    }
146}
147
impl ExprVisitable for BatchSortAgg {
    /// Visits the expressions held by the aggregation core. This node's other fields
    /// (`base`, `input_order`) are not passed to the visitor.
    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
        self.core.visit_exprs(v);
    }
}