risingwave_frontend/optimizer/rule/
distinct_agg_rule.rs1use std::collections::HashMap;
16use std::mem;
17
18use fixedbitset::FixedBitSet;
19use itertools::Itertools;
20use risingwave_common::types::DataType;
21use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};
22
23use super::prelude::{PlanRef, *};
24use crate::expr::{CollectInputRef, ExprType, FunctionCall, InputRef, Literal};
25use crate::optimizer::plan_node::generic::Agg;
26use crate::optimizer::plan_node::{LogicalAgg, LogicalExpand, LogicalProject, PlanAggCall};
27use crate::utils::{ColIndexMapping, Condition, IndexSet};
28
/// Optimizer rule that rewrites a logical aggregation containing `DISTINCT` aggregate calls
/// into a final aggregation stacked on top of a middle aggregation (see `apply`).
pub struct DistinctAggRule {
    /// When `true`, `apply` refuses to rewrite aggregations that have group keys.
    for_stream: bool,
}
33
34impl Rule<Logical> for DistinctAggRule {
35 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
36 let agg: &LogicalAgg = plan.as_logical_agg()?;
37 let (mut agg_calls, mut agg_group_keys, grouping_sets, input, enable_two_phase) =
38 agg.clone().decompose();
39 assert!(grouping_sets.is_empty());
40
41 if agg_calls.iter().all(|c| !c.distinct) {
42 return None;
44 }
45
46 if self.for_stream && !agg_group_keys.is_empty() {
47 return None;
50 }
51
52 if !agg_calls.iter().all(|c| {
53 assert!(
54 !matches!(c.agg_type, agg_types::rewritten!()),
55 "We shouldn't see agg kind {} here",
56 c.agg_type
57 );
58 let agg_type_ok = !matches!(c.agg_type, agg_types::simply_cannot_two_phase!());
59 let order_ok = matches!(
60 c.agg_type,
61 agg_types::result_unaffected_by_order_by!()
62 | AggType::Builtin(PbAggKind::ApproxPercentile)
63 ) || c.order_by.is_empty();
64 agg_type_ok && order_ok
65 }) {
66 tracing::warn!("DistinctAggRule: unsupported agg kind, fallback to backend impl");
67 return None;
68 }
69
70 let (node, flag_values, has_expand) =
71 Self::build_expand(input, &mut agg_group_keys, &mut agg_calls);
72 let mid_agg =
73 Self::build_middle_agg(node, agg_group_keys.clone(), agg_calls.clone(), has_expand);
74
75 let mut final_agg_group_keys = IndexSet::empty();
78 for (i, v) in mid_agg.group_key.indices().enumerate() {
79 if agg_group_keys.contains(v) {
80 final_agg_group_keys.insert(i);
81 }
82 }
83
84 Some(Self::build_final_agg(
85 mid_agg,
86 final_agg_group_keys,
87 agg_calls,
88 flag_values,
89 has_expand,
90 enable_two_phase,
91 ))
92 }
93}
94
95impl DistinctAggRule {
96 pub fn create(for_stream: bool) -> BoxedRule {
97 Box::new(DistinctAggRule { for_stream })
98 }
99
100 fn build_expand(
108 input: PlanRef,
109 group_keys: &mut IndexSet,
110 agg_calls: &mut Vec<PlanAggCall>,
111 ) -> (PlanRef, Vec<usize>, bool) {
112 let input_schema_len = input.schema().len();
113 let mut column_subsets = vec![];
116 let mut flag_values = vec![];
118 let mut hash_map = HashMap::new();
120 let (distinct_aggs, non_distinct_aggs): (Vec<_>, Vec<_>) =
121 agg_calls.iter().partition(|agg_call| agg_call.distinct);
122 assert!(!distinct_aggs.is_empty());
123
124 if !non_distinct_aggs.is_empty() {
125 let subset = {
126 let mut subset = group_keys.clone();
127 non_distinct_aggs.iter().for_each(|agg_call| {
128 subset.extend(agg_call.input_indices());
129 });
130 subset.to_vec()
131 };
132 hash_map.insert(subset.clone(), 0);
133 column_subsets.push(subset);
134 }
135
136 distinct_aggs.iter().for_each(|agg_call| {
137 let subset = {
138 let mut subset = group_keys.clone();
139 subset.extend(agg_call.input_indices());
140 subset.to_vec()
141 };
142 if let Some(i) = hash_map.get(&subset) {
143 flag_values.push(*i);
144 } else {
145 let flag_value = column_subsets.len();
146 flag_values.push(flag_value);
147 hash_map.insert(subset.clone(), flag_value);
148 column_subsets.push(subset);
149 }
150 });
151
152 let n_different_distinct = distinct_aggs
153 .iter()
154 .unique_by(|agg_call| agg_call.input_indices()[0])
155 .count();
156 assert_ne!(n_different_distinct, 0); if n_different_distinct == 1 {
158 return (input, flag_values, false);
160 }
161
162 let expand = LogicalExpand::create(input, column_subsets);
163 let project = Self::build_project(input_schema_len, expand, group_keys, agg_calls);
165 (project, flag_values, true)
166 }
167
168 fn build_project(
170 input_schema_len: usize,
171 expand: PlanRef,
172 group_keys: &mut IndexSet,
173 agg_calls: &mut Vec<PlanAggCall>,
174 ) -> PlanRef {
175 let mut shift_with_offset =
177 ColIndexMapping::with_shift_offset(input_schema_len, input_schema_len as isize);
178 for agg_call in &mut *agg_calls {
179 agg_call.filter = mem::replace(&mut agg_call.filter, Condition::true_cond())
180 .rewrite_expr(&mut shift_with_offset);
181 }
182
183 let expand_schema_len = expand.schema().len();
185 let mut input_indices = CollectInputRef::with_capacity(expand_schema_len);
186 input_indices.extend(group_keys.indices());
187 for agg_call in &*agg_calls {
188 input_indices.extend(agg_call.input_indices());
189 agg_call.filter.visit_expr(&mut input_indices);
190 }
191 input_indices.extend(vec![expand_schema_len - 1]);
193 let mut mapping = ColIndexMapping::with_remaining_columns(
194 &FixedBitSet::from(input_indices).ones().collect_vec(),
195 expand_schema_len,
196 );
197
198 let mut new_group_keys = IndexSet::empty();
200 for i in group_keys.indices() {
201 new_group_keys.insert(mapping.map(i))
202 }
203 *group_keys = new_group_keys;
204 for agg_call in agg_calls {
205 for input in &mut agg_call.inputs {
206 input.index = mapping.map(input.index);
207 }
208 agg_call.filter = mem::replace(&mut agg_call.filter, Condition::true_cond())
209 .rewrite_expr(&mut mapping);
210 }
211
212 LogicalProject::with_mapping(expand, mapping).into()
213 }
214
215 fn build_middle_agg(
216 project: PlanRef,
217 mut group_keys: IndexSet,
218 agg_calls: Vec<PlanAggCall>,
219 has_expand: bool,
220 ) -> Agg<PlanRef> {
221 let agg_calls = agg_calls
224 .into_iter()
225 .filter_map(|mut agg_call| {
226 if agg_call.distinct {
227 group_keys.extend(agg_call.input_indices());
229 if agg_call.filter.always_true() {
232 return None;
233 }
234 agg_call = PlanAggCall::count_star().with_condition(agg_call.filter);
236 }
237 Some(agg_call)
238 })
239 .collect_vec();
240 if has_expand {
241 group_keys.insert(project.schema().len() - 1);
243 }
244 Agg::new(agg_calls, group_keys, project).with_enable_two_phase(false)
245 }
246
247 fn build_final_agg(
248 mid_agg: Agg<PlanRef>,
249 final_agg_group_keys: IndexSet,
250 mut agg_calls: Vec<PlanAggCall>,
251 flag_values: Vec<usize>,
252 has_expand: bool,
253 enable_two_phase: bool,
254 ) -> PlanRef {
255 let pos_of_flag = mid_agg.group_key.len() - 1;
257 let mut flag_values = flag_values.into_iter();
258
259 let mut index_of_middle_agg = mid_agg.group_key.len();
267 agg_calls.iter_mut().for_each(|agg_call| {
268 let flag_value = if agg_call.distinct {
269 agg_call.distinct = false;
270
271 agg_call.inputs.iter_mut().for_each(|input_ref| {
272 input_ref.index = mid_agg
273 .group_key
274 .indices()
275 .position(|x| x == input_ref.index)
276 .unwrap();
277 });
278
279 if !agg_call.filter.always_true() {
282 let check_count = FunctionCall::new(
284 ExprType::GreaterThan,
285 vec![
286 InputRef::new(index_of_middle_agg, DataType::Int64).into(),
287 Literal::new(Some(0_i64.into()), DataType::Int64).into(),
288 ],
289 )
290 .unwrap();
291 index_of_middle_agg += 1;
292 agg_call.filter.conjunctions = vec![check_count.into()];
293 }
294
295 flag_values.next().unwrap() as i64
296 } else {
297 agg_call.inputs = vec![InputRef::new(
299 index_of_middle_agg,
300 agg_call.return_type.clone(),
301 )];
302 index_of_middle_agg += 1;
303
304 agg_call.filter = Condition::true_cond();
306
307 agg_call.agg_type = agg_call.agg_type.partial_to_total().expect(
309 "we should get a valid total phase agg kind here since unsupported cases have been filtered out"
310 );
311
312 0
315 };
316 if has_expand {
317 let filter_expr = FunctionCall::new(
319 ExprType::Equal,
320 vec![
321 InputRef::new(pos_of_flag, DataType::Int64).into(),
322 Literal::new(Some(flag_value.into()), DataType::Int64).into(),
323 ],
324 )
325 .unwrap();
326 agg_call.filter.conjunctions.push(filter_expr.into());
327 }
328 });
329
330 Agg::new(agg_calls, final_agg_group_keys, mid_agg.into())
331 .with_enable_two_phase(enable_two_phase)
332 .into()
333 }
334}