risingwave_frontend/optimizer/rule/agg_project_merge_rule.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use itertools::Itertools;
16
17use super::super::plan_node::*;
18use super::{BoxedRule, Rule};
19use crate::utils::IndexSet;
20
21/// Merge [`LogicalAgg`] <- [`LogicalProject`] to [`LogicalAgg`].
22pub struct AggProjectMergeRule {}
23impl Rule for AggProjectMergeRule {
24    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
25        let agg = plan.as_logical_agg()?;
26        let agg = agg.core().clone();
27        assert!(agg.grouping_sets.is_empty());
28        let old_input = agg.input.clone();
29        let proj = old_input.as_logical_project()?;
30        // only apply when the input proj is all input-ref
31        if !proj.is_all_inputref() {
32            return None;
33        }
34        let proj_o2i = proj.o2i_col_mapping();
35
36        // modify group key according to projection
37        let new_agg_group_keys_in_vec = agg
38            .group_key
39            .indices()
40            .map(|x| proj_o2i.map(x))
41            .collect_vec();
42        let new_agg_group_keys = IndexSet::from_iter(new_agg_group_keys_in_vec.clone());
43
44        let mut agg = agg;
45        agg.input = proj.input();
46        // modify agg calls according to projection
47        agg.agg_calls
48            .iter_mut()
49            .for_each(|x| x.rewrite_input_index(proj_o2i.clone()));
50        agg.group_key = new_agg_group_keys.clone();
51        agg.input = proj.input();
52
53        if new_agg_group_keys.to_vec() != new_agg_group_keys_in_vec {
54            // Need a project
55            let new_agg_group_keys_cardinality = new_agg_group_keys.len();
56            let out_col_idx = new_agg_group_keys_in_vec
57                .into_iter()
58                .map(|x| new_agg_group_keys.indices().position(|y| y == x).unwrap())
59                .chain(
60                    new_agg_group_keys_cardinality
61                        ..new_agg_group_keys_cardinality + agg.agg_calls.len(),
62                );
63            Some(LogicalProject::with_out_col_idx(agg.into(), out_col_idx).into())
64        } else {
65            Some(agg.into())
66        }
67    }
68}
69
70impl AggProjectMergeRule {
71    pub fn create() -> BoxedRule {
72        Box::new(AggProjectMergeRule {})
73    }
74}