risingwave_frontend/optimizer/rule/
agg_group_by_simplify_rule.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use risingwave_expr::aggregate::PbAggKind;
16
17use super::super::plan_node::*;
18use super::{BoxedRule, Rule};
19use crate::expr::InputRef;
20use crate::optimizer::plan_node::generic::{Agg, GenericPlanRef};
21use crate::utils::{Condition, IndexSet};
22
/// Uses functional dependencies of the aggregation's input to shrink its group-by key.
///
/// Before:
///   group by = [a, b, c], where b -> [a, c]
/// After:
///   group by = [b], and the dropped keys `a` and `c` are each produced by a
///   `PbAggKind::InternalLastSeenValue` aggregate call instead.
///
/// A projection is placed on top of the rewritten aggregation so that the output
/// columns keep the same order as the original aggregation's output.
pub struct AggGroupBySimplifyRule {}
29impl Rule for AggGroupBySimplifyRule {
30    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
31        let agg: &LogicalAgg = plan.as_logical_agg()?;
32        let (agg_calls, group_key, grouping_sets, agg_input, _two_phase) = agg.clone().decompose();
33        if !grouping_sets.is_empty() {
34            return None;
35        }
36        let functional_dependency = agg_input.functional_dependency();
37        let group_key = group_key.to_vec();
38        if !functional_dependency.is_key(&group_key) {
39            return None;
40        }
41        let minimized_group_key = functional_dependency.minimize_key(&group_key);
42        if minimized_group_key.len() < group_key.len() {
43            let new_group_key = IndexSet::from(minimized_group_key);
44            let new_group_key_len = new_group_key.len();
45            let mut new_agg_calls = vec![];
46            for &i in &group_key {
47                if !new_group_key.contains(i) {
48                    let data_type = agg_input.schema().fields[i].data_type();
49                    new_agg_calls.push(PlanAggCall {
50                        agg_type: PbAggKind::InternalLastSeenValue.into(),
51                        return_type: data_type.clone(),
52                        inputs: vec![InputRef::new(i, data_type)],
53                        distinct: false,
54                        order_by: vec![],
55                        filter: Condition::true_cond(),
56                        direct_args: vec![],
57                    });
58                }
59            }
60            new_agg_calls.extend(agg_calls);
61
62            // Use project to align schema type
63            let mut out_fields = vec![];
64            let mut remained_group_key_offset = 0;
65            let mut removed_group_key_offset = new_group_key_len;
66            for &i in &group_key {
67                if new_group_key.contains(i) {
68                    out_fields.push(remained_group_key_offset);
69                    remained_group_key_offset += 1;
70                } else {
71                    out_fields.push(removed_group_key_offset);
72                    removed_group_key_offset += 1;
73                }
74            }
75            for i in group_key.len()..agg.base.schema().len() {
76                out_fields.push(i);
77            }
78            let new_agg = Agg::new(new_agg_calls, new_group_key, agg.input());
79
80            Some(LogicalProject::with_out_col_idx(new_agg.into(), out_fields.into_iter()).into())
81        } else {
82            None
83        }
84    }
85}
86
87impl AggGroupBySimplifyRule {
88    pub fn create() -> BoxedRule {
89        Box::new(AggGroupBySimplifyRule {})
90    }
91}