risingwave_frontend/optimizer/plan_node/
batch_sort_agg.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_pb::batch_plan::SortAggNode;
16use risingwave_pb::batch_plan::plan_node::NodeBody;
17use risingwave_pb::expr::ExprNode;
18
19use super::batch::prelude::*;
20use super::generic::{self, PlanAggCall};
21use super::utils::impl_distill_by_unit;
22use super::{ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, ToBatchPb, ToDistributedBatch};
23use crate::error::Result;
24use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprVisitor, InputRef};
25use crate::optimizer::plan_node::ToLocalBatch;
26use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
27use crate::optimizer::property::{Order, RequiredDist};
28use crate::utils::{ColIndexMappingRewriteExt, IndexSet};
29
30#[derive(Debug, Clone, PartialEq, Eq, Hash)]
31pub struct BatchSortAgg {
32    pub base: PlanBase<Batch>,
33    core: generic::Agg<PlanRef>,
34    input_order: Order,
35}
36
37impl BatchSortAgg {
38    pub fn new(core: generic::Agg<PlanRef>) -> Self {
39        assert!(!core.group_key.is_empty());
40        assert!(core.input_provides_order_on_group_keys());
41
42        let input = core.input.clone();
43        let input_dist = input.distribution();
44        let dist = core
45            .i2o_col_mapping()
46            .rewrite_provided_distribution(input_dist);
47        let input_order = Order {
48            column_orders: input
49                .order()
50                .column_orders
51                .iter()
52                .filter(|o| core.group_key.indices().any(|g_k| g_k == o.column_index))
53                .cloned()
54                .collect(),
55        };
56
57        let order = core.i2o_col_mapping().rewrite_provided_order(&input_order);
58
59        let base = PlanBase::new_batch_with_core(&core, dist, order);
60
61        BatchSortAgg {
62            base,
63            core,
64            input_order,
65        }
66    }
67
68    pub fn agg_calls(&self) -> &[PlanAggCall] {
69        &self.core.agg_calls
70    }
71
72    pub fn group_key(&self) -> &IndexSet {
73        &self.core.group_key
74    }
75}
76
77impl PlanTreeNodeUnary for BatchSortAgg {
78    fn input(&self) -> PlanRef {
79        self.core.input.clone()
80    }
81
82    fn clone_with_input(&self, input: PlanRef) -> Self {
83        let mut core = self.core.clone();
84        core.input = input;
85        Self::new(core)
86    }
87}
88impl_plan_tree_node_for_unary! { BatchSortAgg }
89impl_distill_by_unit!(BatchSortAgg, core, "BatchSortAgg");
90
91impl ToDistributedBatch for BatchSortAgg {
92    fn to_distributed(&self) -> Result<PlanRef> {
93        let new_input = self.input().to_distributed_with_required(
94            &self.input_order,
95            &RequiredDist::shard_by_key(self.input().schema().len(), &self.group_key().to_vec()),
96        )?;
97        Ok(self.clone_with_input(new_input).into())
98    }
99}
100
101impl ToBatchPb for BatchSortAgg {
102    fn to_batch_prost_body(&self) -> NodeBody {
103        let input = self.input();
104        NodeBody::SortAgg(SortAggNode {
105            agg_calls: self
106                .agg_calls()
107                .iter()
108                .map(PlanAggCall::to_protobuf)
109                .collect(),
110            group_key: self
111                .group_key()
112                .indices()
113                .map(|idx| {
114                    ExprImpl::InputRef(InputRef::new(idx, input.schema()[idx].data_type()).into())
115                })
116                .map(|expr| expr.to_expr_proto())
117                .collect::<Vec<ExprNode>>(),
118        })
119    }
120}
121
122impl ToLocalBatch for BatchSortAgg {
123    fn to_local(&self) -> Result<PlanRef> {
124        let new_input = self.input().to_local()?;
125
126        let new_input =
127            RequiredDist::single().enforce_if_not_satisfies(new_input, self.input().order())?;
128
129        Ok(self.clone_with_input(new_input).into())
130    }
131}
132
133impl ExprRewritable for BatchSortAgg {
134    fn has_rewritable_expr(&self) -> bool {
135        true
136    }
137
138    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
139        let mut new_logical = self.core.clone();
140        new_logical.rewrite_exprs(r);
141        Self::new(new_logical).into()
142    }
143}
144
145impl ExprVisitable for BatchSortAgg {
146    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
147        self.core.visit_exprs(v);
148    }
149}