risingwave_frontend/optimizer/rule/
grouping_sets_to_expand_rule.rs1use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::types::DataType;
18use risingwave_common::util::column_index_mapping::ColIndexMapping;
19use risingwave_expr::aggregate::{AggType, PbAggKind};
20
21use super::super::plan_node::*;
22use super::{BoxedRule, Rule};
23use crate::expr::{Expr, ExprImpl, ExprType, FunctionCall, InputRef};
24use crate::optimizer::plan_node::generic::{Agg, GenericPlanNode, GenericPlanRef};
/// Optimizer rule that rewrites an aggregation carrying `GROUPING SETS` into a
/// `LogicalExpand` feeding an ordinary aggregation, topped by a `LogicalProject`
/// that restores the original output columns.
pub struct GroupingSetsToExpandRule {}
26
27impl GroupingSetsToExpandRule {
28 pub fn create() -> BoxedRule {
29 Box::new(Self {})
30 }
31
32 fn prune_column_for_agg(agg: &LogicalAgg) -> LogicalAgg {
34 let group_key_required_cols = agg.group_key().to_bitset();
35 let agg_call_required_cols = {
36 let input_cnt = agg.input().schema().len();
37 let mut tmp = FixedBitSet::with_capacity(input_cnt);
38
39 agg.agg_calls().iter().for_each(|agg_call| {
40 tmp.extend(agg_call.inputs.iter().map(|x| x.index()));
41 tmp.extend(agg_call.order_by.iter().map(|x| x.column_index));
42 for i in &agg_call.filter.conjunctions {
44 tmp.union_with(&i.collect_input_refs(input_cnt));
45 }
46 });
47 tmp
48 };
49
50 let input_required_cols = {
51 let mut tmp = FixedBitSet::with_capacity(agg.input().schema().len());
52 tmp.union_with(&group_key_required_cols);
53 tmp.union_with(&agg_call_required_cols);
54 tmp.ones().collect_vec()
55 };
56 let input_col_change = ColIndexMapping::with_remaining_columns(
57 &input_required_cols,
58 agg.input().schema().len(),
59 );
60 let input =
61 LogicalProject::with_out_col_idx(agg.input(), input_required_cols.iter().cloned())
62 .into();
63
64 let (new_agg, output_col_change) =
65 agg.rewrite_with_input_agg(input, agg.agg_calls(), input_col_change);
66 assert!(output_col_change.is_identity());
67 new_agg
68 }
69}
70
71impl Rule for GroupingSetsToExpandRule {
72 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
73 let agg: &LogicalAgg = plan.as_logical_agg()?;
74 if agg.grouping_sets().is_empty() {
75 return None;
76 }
77 let agg = Self::prune_column_for_agg(agg);
78 let (old_agg_calls, old_group_keys, grouping_sets, input, enable_two_phase) =
79 agg.decompose();
80
81 let old_input_schema_len = input.schema().len();
82 let flag_col_idx = old_group_keys.len();
83
84 let column_subset = grouping_sets
85 .iter()
86 .map(|set| set.indices().collect_vec())
87 .collect_vec();
88
89 let expand = LogicalExpand::create(input, column_subset.clone());
90 let new_group_keys = {
91 let mut k = old_group_keys.clone();
92 k.extend(std::iter::once(expand.schema().len() - 1));
94 k
95 };
96
97 let mut input_col_change =
100 ColIndexMapping::with_shift_offset(old_input_schema_len, old_input_schema_len as isize);
101
102 let mut project_agg_call_exprs = vec![];
105 let mut new_agg_calls = vec![];
106 for agg_call in old_agg_calls {
107 if matches!(agg_call.agg_type, AggType::Builtin(PbAggKind::Grouping)) {
109 let mut grouping_values = vec![];
110 let args = agg_call
111 .inputs
112 .iter()
113 .map(|input_ref| input_ref.index)
114 .collect_vec();
115 for subset in &column_subset {
116 let mut value = 0;
117 for arg in &args {
118 value <<= 1;
119 if !subset.contains(arg) {
120 value += 1;
121 }
122 }
123 grouping_values.push(value);
124 }
125
126 let mut case_inputs = vec![];
127 for (i, grouping_value) in grouping_values.into_iter().enumerate() {
128 let condition = ExprImpl::FunctionCall(
129 FunctionCall::new_unchecked(
130 ExprType::Equal,
131 vec![
132 ExprImpl::literal_bigint(i as i64),
133 ExprImpl::InputRef(
134 InputRef::new(flag_col_idx, DataType::Int64).into(),
135 ),
136 ],
137 DataType::Boolean,
138 )
139 .into(),
140 );
141 let value = ExprImpl::literal_int(grouping_value);
142 case_inputs.push(condition);
143 case_inputs.push(value);
144 }
145
146 let case_expr = ExprImpl::FunctionCall(
147 FunctionCall::new_unchecked(ExprType::Case, case_inputs, DataType::Int32)
148 .into(),
149 );
150 project_agg_call_exprs.push(case_expr);
151 } else {
152 let mut new_agg_call = agg_call;
153 new_agg_call.inputs.iter_mut().for_each(|i| {
155 let new_i = input_col_change.map(i.index());
156 assert_eq!(expand.schema()[new_i].data_type(), i.return_type());
157 *i = InputRef::new(new_i, i.return_type());
158 });
159 new_agg_call.order_by.iter_mut().for_each(|o| {
160 o.column_index = input_col_change.map(o.column_index);
161 });
162 new_agg_call.filter = new_agg_call.filter.rewrite_expr(&mut input_col_change);
163 project_agg_call_exprs.push(ExprImpl::InputRef(
164 InputRef::new(
165 new_group_keys.len() + new_agg_calls.len(),
166 new_agg_call.return_type.clone(),
167 )
168 .into(),
169 ));
170 new_agg_calls.push(new_agg_call);
171 }
172 }
173
174 let new_agg =
175 Agg::new(new_agg_calls, new_group_keys, expand).with_enable_two_phase(enable_two_phase);
176 let project_exprs = (0..old_group_keys.len())
177 .map(|i| ExprImpl::InputRef(InputRef::new(i, new_agg.schema()[i].data_type()).into()))
178 .chain(project_agg_call_exprs)
179 .collect();
180
181 let project = LogicalProject::new(new_agg.into(), project_exprs);
182
183 Some(project.into())
184 }
185}