risingwave_frontend/optimizer/rule/
grouping_sets_to_expand_rule.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::types::DataType;
18use risingwave_common::util::column_index_mapping::ColIndexMapping;
19use risingwave_expr::aggregate::{AggType, PbAggKind};
20
21use super::super::plan_node::*;
22use super::{BoxedRule, Rule};
23use crate::expr::{Expr, ExprImpl, ExprType, FunctionCall, InputRef};
24use crate::optimizer::plan_node::generic::{Agg, GenericPlanNode, GenericPlanRef};
/// Optimizer rule that rewrites a `LogicalAgg` carrying grouping sets into a
/// `LogicalProject` on top of an aggregation over a `LogicalExpand`.
///
/// The rule holds no state; instances are obtained through `create`.
pub struct GroupingSetsToExpandRule {}
26
27impl GroupingSetsToExpandRule {
28    pub fn create() -> BoxedRule {
29        Box::new(Self {})
30    }
31
32    /// TODO: Remove this method when we support column pruning for `Expand`.
33    fn prune_column_for_agg(agg: &LogicalAgg) -> LogicalAgg {
34        let group_key_required_cols = agg.group_key().to_bitset();
35        let agg_call_required_cols = {
36            let input_cnt = agg.input().schema().len();
37            let mut tmp = FixedBitSet::with_capacity(input_cnt);
38
39            agg.agg_calls().iter().for_each(|agg_call| {
40                tmp.extend(agg_call.inputs.iter().map(|x| x.index()));
41                tmp.extend(agg_call.order_by.iter().map(|x| x.column_index));
42                // collect columns used in aggregate filter expressions
43                for i in &agg_call.filter.conjunctions {
44                    tmp.union_with(&i.collect_input_refs(input_cnt));
45                }
46            });
47            tmp
48        };
49
50        let input_required_cols = {
51            let mut tmp = FixedBitSet::with_capacity(agg.input().schema().len());
52            tmp.union_with(&group_key_required_cols);
53            tmp.union_with(&agg_call_required_cols);
54            tmp.ones().collect_vec()
55        };
56        let input_col_change = ColIndexMapping::with_remaining_columns(
57            &input_required_cols,
58            agg.input().schema().len(),
59        );
60        let input =
61            LogicalProject::with_out_col_idx(agg.input(), input_required_cols.iter().cloned())
62                .into();
63
64        let (new_agg, output_col_change) =
65            agg.rewrite_with_input_agg(input, agg.agg_calls(), input_col_change);
66        assert!(output_col_change.is_identity());
67        new_agg
68    }
69}
70
71impl Rule for GroupingSetsToExpandRule {
72    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
73        let agg: &LogicalAgg = plan.as_logical_agg()?;
74        if agg.grouping_sets().is_empty() {
75            return None;
76        }
77        let agg = Self::prune_column_for_agg(agg);
78        let (old_agg_calls, old_group_keys, grouping_sets, input, enable_two_phase) =
79            agg.decompose();
80
81        let old_input_schema_len = input.schema().len();
82        let flag_col_idx = old_group_keys.len();
83
84        let column_subset = grouping_sets
85            .iter()
86            .map(|set| set.indices().collect_vec())
87            .collect_vec();
88
89        let expand = LogicalExpand::create(input, column_subset.clone());
90        let new_group_keys = {
91            let mut k = old_group_keys.clone();
92            // Add the expand flag.
93            k.extend(std::iter::once(expand.schema().len() - 1));
94            k
95        };
96
97        // Map from old input ref to expanded input (`LogicalExpand` prepends the same number of fields
98        // as expanded ones with NULLs before the real input fields).
99        let mut input_col_change =
100            ColIndexMapping::with_shift_offset(old_input_schema_len, old_input_schema_len as isize);
101
102        // Grouping agg calls need to be transformed into a project expression, and other agg calls
103        // need to shift their `input_ref`.
104        let mut project_agg_call_exprs = vec![];
105        let mut new_agg_calls = vec![];
106        for agg_call in old_agg_calls {
107            // Deal with grouping agg call for grouping sets.
108            if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Grouping)) {
109                let mut grouping_values = vec![];
110                let args = agg_call
111                    .inputs
112                    .iter()
113                    .map(|input_ref| input_ref.index)
114                    .collect_vec();
115                for subset in &column_subset {
116                    let mut value = 0;
117                    for arg in &args {
118                        value <<= 1;
119                        if !subset.contains(arg) {
120                            value += 1;
121                        }
122                    }
123                    grouping_values.push(value);
124                }
125
126                let mut case_inputs = vec![];
127                for (i, grouping_value) in grouping_values.into_iter().enumerate() {
128                    let condition = ExprImpl::FunctionCall(
129                        FunctionCall::new_unchecked(
130                            ExprType::Equal,
131                            vec![
132                                ExprImpl::literal_bigint(i as i64),
133                                ExprImpl::InputRef(
134                                    InputRef::new(flag_col_idx, DataType::Int64).into(),
135                                ),
136                            ],
137                            DataType::Boolean,
138                        )
139                        .into(),
140                    );
141                    let value = ExprImpl::literal_int(grouping_value);
142                    case_inputs.push(condition);
143                    case_inputs.push(value);
144                }
145
146                let case_expr = ExprImpl::FunctionCall(
147                    FunctionCall::new_unchecked(ExprType::Case, case_inputs, DataType::Int32)
148                        .into(),
149                );
150                project_agg_call_exprs.push(case_expr);
151            } else {
152                let mut new_agg_call = agg_call;
153                // Shift agg_call to the original input columns
154                new_agg_call.inputs.iter_mut().for_each(|i| {
155                    let new_i = input_col_change.map(i.index());
156                    assert_eq!(expand.schema()[new_i].data_type(), i.return_type());
157                    *i = InputRef::new(new_i, i.return_type());
158                });
159                new_agg_call.order_by.iter_mut().for_each(|o| {
160                    o.column_index = input_col_change.map(o.column_index);
161                });
162                new_agg_call.filter = new_agg_call.filter.rewrite_expr(&mut input_col_change);
163                project_agg_call_exprs.push(ExprImpl::InputRef(
164                    InputRef::new(
165                        new_group_keys.len() + new_agg_calls.len(),
166                        new_agg_call.return_type.clone(),
167                    )
168                    .into(),
169                ));
170                new_agg_calls.push(new_agg_call);
171            }
172        }
173
174        let new_agg =
175            Agg::new(new_agg_calls, new_group_keys, expand).with_enable_two_phase(enable_two_phase);
176        let project_exprs = (0..old_group_keys.len())
177            .map(|i| ExprImpl::InputRef(InputRef::new(i, new_agg.schema()[i].data_type()).into()))
178            .chain(project_agg_call_exprs)
179            .collect();
180
181        let project = LogicalProject::new(new_agg.into(), project_exprs);
182
183        Some(project.into())
184    }
185}