risingwave_frontend/optimizer/rule/
agg_call_merge_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::prelude::{PlanRef, *};
16use crate::optimizer::plan_node::generic::Agg;
17use crate::optimizer::plan_node::{LogicalProject, PlanTreeNodeUnary};
18
/// Optimizer rule that deduplicates identical aggregate calls within a `LogicalAgg`.
///
/// When the same call appears more than once, only one copy is kept in the aggregation and a
/// `LogicalProject` placed on top restores the original output column layout.
pub struct AggCallMergeRule {}
21
22impl Rule<Logical> for AggCallMergeRule {
23    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
24        let agg = plan.as_logical_agg()?;
25
26        let calls = agg.agg_calls();
27        let mut new_calls = Vec::with_capacity(calls.len());
28        let mut out_fields = (0..agg.group_key().len()).collect::<Vec<_>>();
29        out_fields.extend(calls.iter().map(|call| {
30            let pos = new_calls.iter().position(|c| c == call).unwrap_or_else(|| {
31                let pos = new_calls.len();
32                new_calls.push(call.clone());
33                pos
34            });
35            agg.group_key().len() + pos
36        }));
37
38        if calls.len() == new_calls.len() {
39            // no change
40            None
41        } else {
42            let new_agg = Agg::new(new_calls, agg.group_key().clone(), agg.input())
43                .with_enable_two_phase(agg.core().two_phase_agg_enabled())
44                .into();
45            Some(LogicalProject::with_out_col_idx(new_agg, out_fields.into_iter()).into())
46        }
47    }
48}
49
50impl AggCallMergeRule {
51    pub fn create() -> BoxedRule {
52        Box::new(Self {})
53    }
54}