risingwave_frontend/optimizer/rule/
agg_call_merge_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::{BoxedRule, Rule};
16use crate::PlanRef;
17use crate::optimizer::plan_node::generic::Agg;
18use crate::optimizer::plan_node::{LogicalProject, PlanTreeNodeUnary};
19
/// Optimizer rule that merges duplicated aggregate function calls in a `LogicalAgg`.
///
/// Calls that compare equal are collapsed into a single call. When at least one
/// duplicate is removed, the rewritten aggregation is wrapped in a
/// `LogicalProject` that projects the columns back to the original schema, so
/// the plan's output keeps the same column order as before.
pub struct AggCallMergeRule {}
22
23impl Rule for AggCallMergeRule {
24    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
25        let agg = plan.as_logical_agg()?;
26
27        let calls = agg.agg_calls();
28        let mut new_calls = Vec::with_capacity(calls.len());
29        let mut out_fields = (0..agg.group_key().len()).collect::<Vec<_>>();
30        out_fields.extend(calls.iter().map(|call| {
31            let pos = new_calls.iter().position(|c| c == call).unwrap_or_else(|| {
32                let pos = new_calls.len();
33                new_calls.push(call.clone());
34                pos
35            });
36            agg.group_key().len() + pos
37        }));
38
39        if calls.len() == new_calls.len() {
40            // no change
41            None
42        } else {
43            let new_agg = Agg::new(new_calls, agg.group_key().clone(), agg.input())
44                .with_enable_two_phase(agg.core().two_phase_agg_enabled())
45                .into();
46            Some(LogicalProject::with_out_col_idx(new_agg, out_fields.into_iter()).into())
47        }
48    }
49}
50
51impl AggCallMergeRule {
52    pub fn create() -> BoxedRule {
53        Box::new(Self {})
54    }
55}