risingwave_frontend/optimizer/rule/distinct_agg_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::mem;
17
18use fixedbitset::FixedBitSet;
19use itertools::Itertools;
20use risingwave_common::types::DataType;
21use risingwave_expr::aggregate::{AggType, PbAggKind, agg_types};
22
23use super::{BoxedRule, Rule};
24use crate::expr::{CollectInputRef, ExprType, FunctionCall, InputRef, Literal};
25use crate::optimizer::PlanRef;
26use crate::optimizer::plan_node::generic::Agg;
27use crate::optimizer::plan_node::{LogicalAgg, LogicalExpand, LogicalProject, PlanAggCall};
28use crate::utils::{ColIndexMapping, Condition, IndexSet};
29
/// Optimizer rule rewriting an aggregation that contains `DISTINCT` calls into the shape
/// `LogicalAgg` (final) -> `LogicalAgg` (middle) -> [`Expand`] -> `Input`, where the
/// `Expand` is only inserted when it is actually required.
pub struct DistinctAggRule {
    /// `true` when the rule is applied to a streaming plan; in that mode aggregations with
    /// group keys are left untouched.
    for_stream: bool,
}
34
35impl Rule for DistinctAggRule {
36    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
37        let agg: &LogicalAgg = plan.as_logical_agg()?;
38        let (mut agg_calls, mut agg_group_keys, grouping_sets, input, enable_two_phase) =
39            agg.clone().decompose();
40        assert!(grouping_sets.is_empty());
41
42        if agg_calls.iter().all(|c| !c.distinct) {
43            // there's no distinct agg call
44            return None;
45        }
46
47        if self.for_stream && !agg_group_keys.is_empty() {
48            // Due to performance issue, we don't do 2-phase agg for stream distinct agg with group
49            // by. See https://github.com/risingwavelabs/risingwave/issues/7271 for more.
50            return None;
51        }
52
53        if !agg_calls.iter().all(|c| {
54            assert!(
55                !matches!(c.agg_type, agg_types::rewritten!()),
56                "We shouldn't see agg kind {} here",
57                c.agg_type
58            );
59            let agg_type_ok = !matches!(c.agg_type, agg_types::simply_cannot_two_phase!());
60            let order_ok = matches!(
61                c.agg_type,
62                agg_types::result_unaffected_by_order_by!()
63                    | AggType::Builtin(PbAggKind::ApproxPercentile)
64            ) || c.order_by.is_empty();
65            agg_type_ok && order_ok
66        }) {
67            tracing::warn!("DistinctAggRule: unsupported agg kind, fallback to backend impl");
68            return None;
69        }
70
71        let (node, flag_values, has_expand) =
72            Self::build_expand(input, &mut agg_group_keys, &mut agg_calls);
73        let mid_agg =
74            Self::build_middle_agg(node, agg_group_keys.clone(), agg_calls.clone(), has_expand);
75
76        // The middle agg will extend some fields for `agg_group_keys`, so we need to find out the
77        // original group key for the final agg.
78        let mut final_agg_group_keys = IndexSet::empty();
79        for (i, v) in mid_agg.group_key.indices().enumerate() {
80            if agg_group_keys.contains(v) {
81                final_agg_group_keys.insert(i);
82            }
83        }
84
85        Some(Self::build_final_agg(
86            mid_agg,
87            final_agg_group_keys,
88            agg_calls,
89            flag_values,
90            has_expand,
91            enable_two_phase,
92        ))
93    }
94}
95
96impl DistinctAggRule {
97    pub fn create(for_stream: bool) -> BoxedRule {
98        Box::new(DistinctAggRule { for_stream })
99    }
100
101    /// Construct `Expand` for distinct aggregates.
102    /// `group_keys` and `agg_calls` will be changed in `build_project` due to column pruning.
103    /// It returns either `LogicalProject` or original input, plus `flag_values` for every distinct
104    /// aggregate and `has_expand` as a flag.
105    ///
106    /// To simplify, we will first deduplicate `column_subsets` and then skip building
107    /// `Expand` if there is only one `subset`.
108    fn build_expand(
109        input: PlanRef,
110        group_keys: &mut IndexSet,
111        agg_calls: &mut Vec<PlanAggCall>,
112    ) -> (PlanRef, Vec<usize>, bool) {
113        let input_schema_len = input.schema().len();
114        // each `subset` in `column_subsets` consists of `group_keys`, `agg_call`'s input indices
115        // and the input indices of `agg_call`'s `filter`.
116        let mut column_subsets = vec![];
117        // flag values of distinct aggregates.
118        let mut flag_values = vec![];
119        // mapping from `subset` to `flag_value`, which is used to deduplicate `column_subsets`.
120        let mut hash_map = HashMap::new();
121        let (distinct_aggs, non_distinct_aggs): (Vec<_>, Vec<_>) =
122            agg_calls.iter().partition(|agg_call| agg_call.distinct);
123        assert!(!distinct_aggs.is_empty());
124
125        if !non_distinct_aggs.is_empty() {
126            let subset = {
127                let mut subset = group_keys.clone();
128                non_distinct_aggs.iter().for_each(|agg_call| {
129                    subset.extend(agg_call.input_indices());
130                });
131                subset.to_vec()
132            };
133            hash_map.insert(subset.clone(), 0);
134            column_subsets.push(subset);
135        }
136
137        distinct_aggs.iter().for_each(|agg_call| {
138            let subset = {
139                let mut subset = group_keys.clone();
140                subset.extend(agg_call.input_indices());
141                subset.to_vec()
142            };
143            if let Some(i) = hash_map.get(&subset) {
144                flag_values.push(*i);
145            } else {
146                let flag_value = column_subsets.len();
147                flag_values.push(flag_value);
148                hash_map.insert(subset.clone(), flag_value);
149                column_subsets.push(subset);
150            }
151        });
152
153        let n_different_distinct = distinct_aggs
154            .iter()
155            .unique_by(|agg_call| agg_call.input_indices()[0])
156            .count();
157        assert_ne!(n_different_distinct, 0); // since `distinct_aggs` is not empty here
158        if n_different_distinct == 1 {
159            // no need to have expand if there is only one distinct aggregates.
160            return (input, flag_values, false);
161        }
162
163        let expand = LogicalExpand::create(input, column_subsets);
164        // manual version of column pruning for expand.
165        let project = Self::build_project(input_schema_len, expand, group_keys, agg_calls);
166        (project, flag_values, true)
167    }
168
169    /// Used to do column pruning for `Expand`.
170    fn build_project(
171        input_schema_len: usize,
172        expand: PlanRef,
173        group_keys: &mut IndexSet,
174        agg_calls: &mut Vec<PlanAggCall>,
175    ) -> PlanRef {
176        // shift the indices of filter first to make later rewrite more convenient.
177        let mut shift_with_offset =
178            ColIndexMapping::with_shift_offset(input_schema_len, input_schema_len as isize);
179        for agg_call in &mut *agg_calls {
180            agg_call.filter = mem::replace(&mut agg_call.filter, Condition::true_cond())
181                .rewrite_expr(&mut shift_with_offset);
182        }
183
184        // collect indices.
185        let expand_schema_len = expand.schema().len();
186        let mut input_indices = CollectInputRef::with_capacity(expand_schema_len);
187        input_indices.extend(group_keys.indices());
188        for agg_call in &*agg_calls {
189            input_indices.extend(agg_call.input_indices());
190            agg_call.filter.visit_expr(&mut input_indices);
191        }
192        // append `flag`.
193        input_indices.extend(vec![expand_schema_len - 1]);
194        let mut mapping = ColIndexMapping::with_remaining_columns(
195            &FixedBitSet::from(input_indices).ones().collect_vec(),
196            expand_schema_len,
197        );
198
199        // remap indices.
200        let mut new_group_keys = IndexSet::empty();
201        for i in group_keys.indices() {
202            new_group_keys.insert(mapping.map(i))
203        }
204        *group_keys = new_group_keys;
205        for agg_call in agg_calls {
206            for input in &mut agg_call.inputs {
207                input.index = mapping.map(input.index);
208            }
209            agg_call.filter = mem::replace(&mut agg_call.filter, Condition::true_cond())
210                .rewrite_expr(&mut mapping);
211        }
212
213        LogicalProject::with_mapping(expand, mapping).into()
214    }
215
216    fn build_middle_agg(
217        project: PlanRef,
218        mut group_keys: IndexSet,
219        agg_calls: Vec<PlanAggCall>,
220        has_expand: bool,
221    ) -> Agg<PlanRef> {
222        // The middle `LogicalAgg` groups by (`agg_group_keys` + arguments of distinct aggregates +
223        // `flag`).
224        let agg_calls = agg_calls
225            .into_iter()
226            .filter_map(|mut agg_call| {
227                if agg_call.distinct {
228                    // collect distinct agg's input indices.
229                    group_keys.extend(agg_call.input_indices());
230                    // filter out distinct agg without real filter(i.e. filter that isn't always
231                    // true).
232                    if agg_call.filter.always_true() {
233                        return None;
234                    }
235                    // convert distinct agg with real filter to count(*) with original filter.
236                    agg_call = PlanAggCall::count_star().with_condition(agg_call.filter);
237                }
238                Some(agg_call)
239            })
240            .collect_vec();
241        if has_expand {
242            // append `flag`.
243            group_keys.insert(project.schema().len() - 1);
244        }
245        Agg::new(agg_calls, group_keys, project).with_enable_two_phase(false)
246    }
247
248    fn build_final_agg(
249        mid_agg: Agg<PlanRef>,
250        final_agg_group_keys: IndexSet,
251        mut agg_calls: Vec<PlanAggCall>,
252        flag_values: Vec<usize>,
253        has_expand: bool,
254        enable_two_phase: bool,
255    ) -> PlanRef {
256        // the index of `flag` in schema of the middle `LogicalAgg`, if has `Expand`.
257        let pos_of_flag = mid_agg.group_key.len() - 1;
258        let mut flag_values = flag_values.into_iter();
259
260        // ```ignore
261        // if has `Expand`, the input(middle agg) has the following schema:
262        // original group columns | distinct agg arguments | flag | count_star_with_filter or non-distinct agg
263        // <-                group                              -> <-             agg calls                 ->
264        // ```
265
266        // scan through `count_star_with_filter` or `non-distinct agg`.
267        let mut index_of_middle_agg = mid_agg.group_key.len();
268        agg_calls.iter_mut().for_each(|agg_call| {
269            let flag_value = if agg_call.distinct {
270                agg_call.distinct = false;
271
272                agg_call.inputs.iter_mut().for_each(|input_ref| {
273                    input_ref.index = mid_agg
274                        .group_key
275                        .indices()
276                        .position(|x| x == input_ref.index)
277                        .unwrap();
278                });
279
280                // distinct-agg with real filter has its corresponding middle agg, which is count(*)
281                // with its original filter.
282                if !agg_call.filter.always_true() {
283                    // make sure count(*) with original filter > 0.
284                    let check_count = FunctionCall::new(
285                        ExprType::GreaterThan,
286                        vec![
287                            InputRef::new(index_of_middle_agg, DataType::Int64).into(),
288                            Literal::new(Some(0_i64.into()), DataType::Int64).into(),
289                        ],
290                    )
291                    .unwrap();
292                    index_of_middle_agg += 1;
293                    agg_call.filter.conjunctions = vec![check_count.into()];
294                }
295
296                flag_values.next().unwrap() as i64
297            } else {
298                // non-distinct agg has its corresponding middle agg.
299                agg_call.inputs = vec![InputRef::new(
300                    index_of_middle_agg,
301                    agg_call.return_type.clone(),
302                )];
303                index_of_middle_agg += 1;
304
305                // the filter of non-distinct agg has been calculated in middle agg.
306                agg_call.filter = Condition::true_cond();
307
308                // change final agg's agg_type just like two-phase agg.
309                agg_call.agg_type = agg_call.agg_type.partial_to_total().expect(
310                    "we should get a valid total phase agg kind here since unsupported cases have been filtered out"
311                );
312
313                // the index of non-distinct aggs' subset in `column_subsets` is always 0 if it
314                // exists.
315                0
316            };
317            if has_expand {
318                // `filter_expr` is used to pick up the rows that are really needed by aggregates.
319                let filter_expr = FunctionCall::new(
320                    ExprType::Equal,
321                    vec![
322                        InputRef::new(pos_of_flag, DataType::Int64).into(),
323                        Literal::new(Some(flag_value.into()), DataType::Int64).into(),
324                    ],
325                )
326                .unwrap();
327                agg_call.filter.conjunctions.push(filter_expr.into());
328            }
329        });
330
331        Agg::new(agg_calls, final_agg_group_keys, mid_agg.into())
332            .with_enable_two_phase(enable_two_phase)
333            .into()
334    }
335}