risingwave_frontend/optimizer/rule/
apply_agg_transpose_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::types::DataType;
16use risingwave_expr::aggregate::{AggType, PbAggKind};
17use risingwave_pb::plan_common::JoinType;
18
19use super::{ApplyOffsetRewriter, BoxedRule, Rule};
20use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef};
21use crate::optimizer::PlanRef;
22use crate::optimizer::plan_node::generic::Agg;
23use crate::optimizer::plan_node::{LogicalAgg, LogicalApply, LogicalFilter, LogicalProject};
24use crate::utils::{Condition, IndexSet};
25
/// Optimizer rule that pulls a `LogicalAgg` sitting on the right side of a `LogicalApply`
/// above that apply, so the apply ends up directly over the aggregation's input.
///
/// Plan shape matched by the rule:
///
/// ```text
///     LogicalApply
///    /            \
///  Domain      LogicalAgg
///                  |
///                Input
/// ```
///
/// Plan shape produced by the rule:
///
/// ```text
///      LogicalAgg
///          |
///     LogicalApply
///    /            \
///  Domain        Input
/// ```
///
/// The new aggregation groups by every column of `Domain` in addition to the original group
/// key. For an aggregation without group key the inner apply is built as a left-outer apply
/// over `Input` extended with a constant column, and the original `on` condition of the apply
/// is re-attached through `LogicalFilter::create` on top of the new aggregation.
pub struct ApplyAggTransposeRule {}
48impl Rule for ApplyAggTransposeRule {
49    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
50        let apply: &LogicalApply = plan.as_logical_apply()?;
51        let (left, right, on, join_type, correlated_id, correlated_indices, max_one_row) =
52            apply.clone().decompose();
53        assert_eq!(join_type, JoinType::Inner);
54        let agg: &LogicalAgg = right.as_logical_agg()?;
55        let (mut agg_calls, agg_group_key, grouping_sets, agg_input, enable_two_phase) =
56            agg.clone().decompose();
57        assert!(grouping_sets.is_empty());
58        let is_scalar_agg = agg_group_key.is_empty();
59        let apply_left_len = left.schema().len();
60
61        if !is_scalar_agg && max_one_row {
62            // We can only eliminate max_one_row for scalar aggregation.
63            return None;
64        }
65
66        let input = if is_scalar_agg {
67            // add a constant column to help convert count(*) to count(c) where c is non-nullable.
68            let mut exprs: Vec<ExprImpl> = agg_input
69                .schema()
70                .data_types()
71                .into_iter()
72                .enumerate()
73                .map(|(i, data_type)| InputRef::new(i, data_type).into())
74                .collect();
75            exprs.push(ExprImpl::literal_int(1));
76            LogicalProject::create(agg_input, exprs)
77        } else {
78            agg_input
79        };
80
81        let node = if is_scalar_agg {
82            // LOJ Apply need to be converted to cross Apply.
83            let left_len = left.schema().len();
84            let eq_predicates = left
85                .schema()
86                .data_types()
87                .into_iter()
88                .enumerate()
89                .map(|(i, data_type)| {
90                    let left = InputRef::new(i, data_type.clone());
91                    let right = InputRef::new(i + left_len, data_type);
92                    // use null-safe equal
93                    FunctionCall::new_unchecked(
94                        ExprType::IsNotDistinctFrom,
95                        vec![left.into(), right.into()],
96                        DataType::Boolean,
97                    )
98                    .into()
99                })
100                .collect();
101            LogicalApply::new(
102                left.clone(),
103                input,
104                JoinType::LeftOuter,
105                Condition::true_cond(),
106                correlated_id,
107                correlated_indices.clone(),
108                false,
109                false,
110            )
111            .translate_apply(left, eq_predicates)
112        } else {
113            LogicalApply::create(
114                left,
115                input,
116                JoinType::Inner,
117                Condition::true_cond(),
118                correlated_id,
119                correlated_indices.clone(),
120                false,
121            )
122        };
123
124        let group_agg = {
125            // shift index of agg_calls' `InputRef` with `apply_left_len`.
126            let offset = apply_left_len as isize;
127            let mut rewriter =
128                ApplyOffsetRewriter::new(apply_left_len, &correlated_indices, correlated_id);
129            agg_calls.iter_mut().for_each(|agg_call| {
130                agg_call.inputs.iter_mut().for_each(|input_ref| {
131                    input_ref.shift_with_offset(offset);
132                });
133                agg_call
134                    .order_by
135                    .iter_mut()
136                    .for_each(|o| o.shift_with_offset(offset));
137                agg_call.filter = agg_call.filter.clone().rewrite_expr(&mut rewriter);
138            });
139            if is_scalar_agg {
140                // convert count(*) to count(1).
141                let pos_of_constant_column = node.schema().len() - 1;
142                agg_calls.iter_mut().for_each(|agg_call| {
143                    match agg_call.agg_type {
144                        AggType::Builtin(PbAggKind::Count) if agg_call.inputs.is_empty() => {
145                            let input_ref = InputRef::new(pos_of_constant_column, DataType::Int32);
146                            agg_call.inputs.push(input_ref);
147                        }
148                        AggType::Builtin(PbAggKind::ArrayAgg
149                        | PbAggKind::JsonbAgg
150                        | PbAggKind::JsonbObjectAgg)
151                        | AggType::UserDefined(_)
152                        | AggType::WrapScalar(_) => {
153                            let input_ref = InputRef::new(pos_of_constant_column, DataType::Int32);
154                            let cond = FunctionCall::new(ExprType::IsNotNull, vec![input_ref.into()]).unwrap();
155                            agg_call.filter.conjunctions.push(cond.into());
156                        }
157                        AggType::Builtin(PbAggKind::Count
158                        | PbAggKind::Sum
159                        | PbAggKind::Sum0
160                        | PbAggKind::Avg
161                        | PbAggKind::Min
162                        | PbAggKind::Max
163                        | PbAggKind::BitAnd
164                        | PbAggKind::BitOr
165                        | PbAggKind::BitXor
166                        | PbAggKind::BoolAnd
167                        | PbAggKind::BoolOr
168                        | PbAggKind::StringAgg
169                        // not in PostgreSQL
170                        | PbAggKind::ApproxCountDistinct
171                        | PbAggKind::FirstValue
172                        | PbAggKind::LastValue
173                        | PbAggKind::InternalLastSeenValue
174                        // All statistical aggregates only consider non-null inputs.
175                        | PbAggKind::ApproxPercentile
176                        | PbAggKind::VarPop
177                        | PbAggKind::VarSamp
178                        | PbAggKind::StddevPop
179                        | PbAggKind::StddevSamp
180                        // All ordered-set aggregates ignore null values in their aggregated input.
181                        | PbAggKind::PercentileCont
182                        | PbAggKind::PercentileDisc
183                        | PbAggKind::Mode
184                        // `grouping` has no *aggregate* input and unreachable when `is_scalar_agg`.
185                        | PbAggKind::Grouping)
186                        => {
187                            // no-op when `agg(0 rows) == agg(1 row of nulls)`
188                        }
189                        AggType::Builtin(PbAggKind::Unspecified | PbAggKind::UserDefined | PbAggKind::WrapScalar) => {
190                            panic!("Unexpected aggregate function: {:?}", agg_call.agg_type)
191                        }
192                    }
193                });
194            }
195            let mut group_keys: IndexSet = (0..apply_left_len).collect();
196            group_keys.extend(agg_group_key.indices().map(|key| key + apply_left_len));
197            Agg::new(agg_calls, group_keys, node)
198                .with_enable_two_phase(enable_two_phase)
199                .into()
200        };
201
202        let filter = LogicalFilter::create(group_agg, on);
203        Some(filter)
204    }
205}
206
207impl ApplyAggTransposeRule {
208    pub fn create() -> BoxedRule {
209        Box::new(ApplyAggTransposeRule {})
210    }
211}