risingwave_frontend/optimizer/rule/
distinct_agg_rule.rs1use std::collections::HashMap;
16use std::mem;
17
18use fixedbitset::FixedBitSet;
19use itertools::Itertools;
20use risingwave_common::types::DataType;
21use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};
22
23use super::{BoxedRule, Rule};
24use crate::expr::{CollectInputRef, ExprType, FunctionCall, InputRef, Literal};
25use crate::optimizer::PlanRef;
26use crate::optimizer::plan_node::generic::Agg;
27use crate::optimizer::plan_node::{LogicalAgg, LogicalExpand, LogicalProject, PlanAggCall};
28use crate::utils::{ColIndexMapping, Condition, IndexSet};
29
/// Optimizer rule that targets `LogicalAgg` plans containing distinct aggregate calls.
pub struct DistinctAggRule {
    /// When `true`, the rule declines to rewrite aggregations that have group keys.
    for_stream: bool,
}
34
35impl Rule for DistinctAggRule {
36 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
37 let agg: &LogicalAgg = plan.as_logical_agg()?;
38 let (mut agg_calls, mut agg_group_keys, grouping_sets, input, enable_two_phase) =
39 agg.clone().decompose();
40 assert!(grouping_sets.is_empty());
41
42 if agg_calls.iter().all(|c| !c.distinct) {
43 return None;
45 }
46
47 if self.for_stream && !agg_group_keys.is_empty() {
48 return None;
51 }
52
53 if !agg_calls.iter().all(|c| {
54 assert!(
55 !matches!(c.agg_type, agg_types::rewritten!()),
56 "We shouldn't see agg kind {} here",
57 c.agg_type
58 );
59 let agg_type_ok = !matches!(c.agg_type, agg_types::simply_cannot_two_phase!());
60 let order_ok = matches!(
61 c.agg_type,
62 agg_types::result_unaffected_by_order_by!()
63 | AggType::Builtin(PbAggKind::ApproxPercentile)
64 ) || c.order_by.is_empty();
65 agg_type_ok && order_ok
66 }) {
67 tracing::warn!("DistinctAggRule: unsupported agg kind, fallback to backend impl");
68 return None;
69 }
70
71 let (node, flag_values, has_expand) =
72 Self::build_expand(input, &mut agg_group_keys, &mut agg_calls);
73 let mid_agg =
74 Self::build_middle_agg(node, agg_group_keys.clone(), agg_calls.clone(), has_expand);
75
76 let mut final_agg_group_keys = IndexSet::empty();
79 for (i, v) in mid_agg.group_key.indices().enumerate() {
80 if agg_group_keys.contains(v) {
81 final_agg_group_keys.insert(i);
82 }
83 }
84
85 Some(Self::build_final_agg(
86 mid_agg,
87 final_agg_group_keys,
88 agg_calls,
89 flag_values,
90 has_expand,
91 enable_two_phase,
92 ))
93 }
94}
95
96impl DistinctAggRule {
97 pub fn create(for_stream: bool) -> BoxedRule {
98 Box::new(DistinctAggRule { for_stream })
99 }
100
101 fn build_expand(
109 input: PlanRef,
110 group_keys: &mut IndexSet,
111 agg_calls: &mut Vec<PlanAggCall>,
112 ) -> (PlanRef, Vec<usize>, bool) {
113 let input_schema_len = input.schema().len();
114 let mut column_subsets = vec![];
117 let mut flag_values = vec![];
119 let mut hash_map = HashMap::new();
121 let (distinct_aggs, non_distinct_aggs): (Vec<_>, Vec<_>) =
122 agg_calls.iter().partition(|agg_call| agg_call.distinct);
123 assert!(!distinct_aggs.is_empty());
124
125 if !non_distinct_aggs.is_empty() {
126 let subset = {
127 let mut subset = group_keys.clone();
128 non_distinct_aggs.iter().for_each(|agg_call| {
129 subset.extend(agg_call.input_indices());
130 });
131 subset.to_vec()
132 };
133 hash_map.insert(subset.clone(), 0);
134 column_subsets.push(subset);
135 }
136
137 distinct_aggs.iter().for_each(|agg_call| {
138 let subset = {
139 let mut subset = group_keys.clone();
140 subset.extend(agg_call.input_indices());
141 subset.to_vec()
142 };
143 if let Some(i) = hash_map.get(&subset) {
144 flag_values.push(*i);
145 } else {
146 let flag_value = column_subsets.len();
147 flag_values.push(flag_value);
148 hash_map.insert(subset.clone(), flag_value);
149 column_subsets.push(subset);
150 }
151 });
152
153 let n_different_distinct = distinct_aggs
154 .iter()
155 .unique_by(|agg_call| agg_call.input_indices()[0])
156 .count();
157 assert_ne!(n_different_distinct, 0); if n_different_distinct == 1 {
159 return (input, flag_values, false);
161 }
162
163 let expand = LogicalExpand::create(input, column_subsets);
164 let project = Self::build_project(input_schema_len, expand, group_keys, agg_calls);
166 (project, flag_values, true)
167 }
168
169 fn build_project(
171 input_schema_len: usize,
172 expand: PlanRef,
173 group_keys: &mut IndexSet,
174 agg_calls: &mut Vec<PlanAggCall>,
175 ) -> PlanRef {
176 let mut shift_with_offset =
178 ColIndexMapping::with_shift_offset(input_schema_len, input_schema_len as isize);
179 for agg_call in &mut *agg_calls {
180 agg_call.filter = mem::replace(&mut agg_call.filter, Condition::true_cond())
181 .rewrite_expr(&mut shift_with_offset);
182 }
183
184 let expand_schema_len = expand.schema().len();
186 let mut input_indices = CollectInputRef::with_capacity(expand_schema_len);
187 input_indices.extend(group_keys.indices());
188 for agg_call in &*agg_calls {
189 input_indices.extend(agg_call.input_indices());
190 agg_call.filter.visit_expr(&mut input_indices);
191 }
192 input_indices.extend(vec![expand_schema_len - 1]);
194 let mut mapping = ColIndexMapping::with_remaining_columns(
195 &FixedBitSet::from(input_indices).ones().collect_vec(),
196 expand_schema_len,
197 );
198
199 let mut new_group_keys = IndexSet::empty();
201 for i in group_keys.indices() {
202 new_group_keys.insert(mapping.map(i))
203 }
204 *group_keys = new_group_keys;
205 for agg_call in agg_calls {
206 for input in &mut agg_call.inputs {
207 input.index = mapping.map(input.index);
208 }
209 agg_call.filter = mem::replace(&mut agg_call.filter, Condition::true_cond())
210 .rewrite_expr(&mut mapping);
211 }
212
213 LogicalProject::with_mapping(expand, mapping).into()
214 }
215
216 fn build_middle_agg(
217 project: PlanRef,
218 mut group_keys: IndexSet,
219 agg_calls: Vec<PlanAggCall>,
220 has_expand: bool,
221 ) -> Agg<PlanRef> {
222 let agg_calls = agg_calls
225 .into_iter()
226 .filter_map(|mut agg_call| {
227 if agg_call.distinct {
228 group_keys.extend(agg_call.input_indices());
230 if agg_call.filter.always_true() {
233 return None;
234 }
235 agg_call = PlanAggCall::count_star().with_condition(agg_call.filter);
237 }
238 Some(agg_call)
239 })
240 .collect_vec();
241 if has_expand {
242 group_keys.insert(project.schema().len() - 1);
244 }
245 Agg::new(agg_calls, group_keys, project).with_enable_two_phase(false)
246 }
247
248 fn build_final_agg(
249 mid_agg: Agg<PlanRef>,
250 final_agg_group_keys: IndexSet,
251 mut agg_calls: Vec<PlanAggCall>,
252 flag_values: Vec<usize>,
253 has_expand: bool,
254 enable_two_phase: bool,
255 ) -> PlanRef {
256 let pos_of_flag = mid_agg.group_key.len() - 1;
258 let mut flag_values = flag_values.into_iter();
259
260 let mut index_of_middle_agg = mid_agg.group_key.len();
268 agg_calls.iter_mut().for_each(|agg_call| {
269 let flag_value = if agg_call.distinct {
270 agg_call.distinct = false;
271
272 agg_call.inputs.iter_mut().for_each(|input_ref| {
273 input_ref.index = mid_agg
274 .group_key
275 .indices()
276 .position(|x| x == input_ref.index)
277 .unwrap();
278 });
279
280 if !agg_call.filter.always_true() {
283 let check_count = FunctionCall::new(
285 ExprType::GreaterThan,
286 vec![
287 InputRef::new(index_of_middle_agg, DataType::Int64).into(),
288 Literal::new(Some(0_i64.into()), DataType::Int64).into(),
289 ],
290 )
291 .unwrap();
292 index_of_middle_agg += 1;
293 agg_call.filter.conjunctions = vec![check_count.into()];
294 }
295
296 flag_values.next().unwrap() as i64
297 } else {
298 agg_call.inputs = vec![InputRef::new(
300 index_of_middle_agg,
301 agg_call.return_type.clone(),
302 )];
303 index_of_middle_agg += 1;
304
305 agg_call.filter = Condition::true_cond();
307
308 agg_call.agg_type = agg_call.agg_type.partial_to_total().expect(
310 "we should get a valid total phase agg kind here since unsupported cases have been filtered out"
311 );
312
313 0
316 };
317 if has_expand {
318 let filter_expr = FunctionCall::new(
320 ExprType::Equal,
321 vec![
322 InputRef::new(pos_of_flag, DataType::Int64).into(),
323 Literal::new(Some(flag_value.into()), DataType::Int64).into(),
324 ],
325 )
326 .unwrap();
327 agg_call.filter.conjunctions.push(filter_expr.into());
328 }
329 });
330
331 Agg::new(agg_calls, final_agg_group_keys, mid_agg.into())
332 .with_enable_two_phase(enable_two_phase)
333 .into()
334 }
335}