risingwave_frontend/optimizer/rule/
unify_first_last_value_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use risingwave_common::types::{DataType, StructType};
18use risingwave_common::util::sort_util::ColumnOrder;
19use risingwave_expr::aggregate::{AggType, PbAggKind};
20
21use super::prelude::{PlanRef, *};
22use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef, Literal};
23use crate::optimizer::plan_node::generic::{Agg, PlanAggCall};
24use crate::optimizer::plan_node::{LogicalProject, PlanTreeNodeUnary};
25
/// Optimizer rule that folds several `FIRST_VALUE` / `LAST_VALUE` aggregate calls into a single
/// call over a `ROW(...)` value whenever those calls share the same ordering.
///
/// Two calls are merged only if they agree on the aggregate kind, the `ORDER BY` list, the
/// `DISTINCT` flag and the `FILTER` condition, and each of them takes exactly one input column.
/// Every other aggregate call in the same `LogicalAgg` is carried over untouched.
///
/// The rewrite is aimed at streaming queries: each aggregate call keeps its own state, so
/// evaluating N compatible calls as one call over a row replaces N states with one.
///
/// # SQL view of the rewrite
///
/// Input query:
/// ```sql
/// SELECT
///   LAST_VALUE(col1 ORDER BY col2) AS last_col1,
///   LAST_VALUE(col3 ORDER BY col2) AS last_col3,
///   LAST_VALUE(col4 ORDER BY col2) AS last_col4,
///   LAST_VALUE(col5 ORDER BY col2) AS last_col5
/// FROM table_name
/// GROUP BY group_col;
/// ```
///
/// Equivalent query after the rule fires:
/// ```sql
/// SELECT
///   (unified_last).f0 AS last_col1,
///   (unified_last).f1 AS last_col3,
///   (unified_last).f2 AS last_col4,
///   (unified_last).f3 AS last_col5
/// FROM (
///   SELECT
///     LAST_VALUE(ROW(col1, col3, col4, col5) ORDER BY col2) AS unified_last
///   FROM table_name GROUP BY group_col
/// ) sub;
/// ```
///
/// # Plan view of the rewrite
///
/// Input plan (one state per call):
/// ```text
/// LogicalAgg [group_col]
///  ├─agg_calls:
///  │  ├─ LAST_VALUE(col1 ORDER BY col2)     -- State 1
///  │  ├─ LAST_VALUE(col3 ORDER BY col2)     -- State 2
///  │  ├─ LAST_VALUE(col4 ORDER BY col2)     -- State 3
///  │  └─ LAST_VALUE(col5 ORDER BY col2)     -- State 4
///  └─LogicalScan { table: table_name }
/// ```
///
/// Output plan (a single state for the whole group of calls):
/// ```text
/// LogicalProject
///  ├─exprs: [(unified_last).f0, (unified_last).f1, (unified_last).f2, (unified_last).f3]
///  └─LogicalAgg [group_col]
///     ├─agg_calls:
///     │  └─ LAST_VALUE(ROW(col1,col3,col4,col5) ORDER BY col2)  -- Single State!
///     └─LogicalProject
///        ├─exprs: [group_col, col1, col2, col3, col4, col5, ROW(col1,col3,col4,col5)]
///        └─LogicalScan { table: table_name }
/// ```
///
/// The inner projection forwards every input column unchanged and appends one `ROW(...)` column
/// per merged group; the outer projection restores the original output column order by
/// extracting the individual fields from each merged result.
pub struct UnifyFirstLastValueRule {}
85
86impl Rule<Logical> for UnifyFirstLastValueRule {
87    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
88        let agg = plan.as_logical_agg()?;
89
90        let calls = agg.agg_calls();
91        if calls.len() < 2 {
92            // Need at least 2 calls to merge
93            return None;
94        }
95
96        // Group calls by their "pattern" (agg_type + order_by + distinct + filter)
97        let mut pattern_groups: HashMap<AggPattern, Vec<(usize, &PlanAggCall)>> = HashMap::new();
98
99        for (idx, call) in calls.iter().enumerate() {
100            let PlanAggCall {
101                agg_type,
102                return_type: _,
103                inputs,
104                distinct,
105                order_by,
106                filter,
107                direct_args: _,
108            } = call;
109
110            if !self.is_supported_agg_type(agg_type) {
111                continue;
112            }
113
114            // Only support single input aggregations for now
115            if inputs.len() != 1 {
116                continue;
117            }
118
119            let pattern = AggPattern {
120                agg_type: agg_type.clone(),
121                order_by: order_by.clone(),
122                distinct: *distinct,
123                filter: filter.clone(),
124            };
125
126            pattern_groups.entry(pattern).or_default().push((idx, call));
127        }
128
129        // Find ALL patterns with multiple calls that can be merged
130        let mut mergeable_patterns: Vec<(AggPattern, Vec<(usize, &PlanAggCall)>)> = pattern_groups
131            .into_iter()
132            .filter(|(_, calls_in_pattern)| calls_in_pattern.len() >= 2)
133            .collect();
134
135        // Sort by the minimum original index in each pattern to ensure stable optimization results
136        mergeable_patterns.sort_by_key(|(_, calls_in_pattern)| {
137            calls_in_pattern.iter().map(|(idx, _)| *idx).min().unwrap()
138        });
139
140        // Also sort calls within each pattern by their original index
141        for (_, calls_in_pattern) in &mut mergeable_patterns {
142            calls_in_pattern.sort_by_key(|(idx, _)| *idx);
143        }
144
145        if mergeable_patterns.is_empty() {
146            return None;
147        }
148
149        // Build the complete transformation pipeline:
150        // 1. Pre-projection: Construct ROW expressions for each mergeable pattern
151        // 2. Aggregation: Apply aggregation functions on ROW expressions
152        // 3. Post-projection: Extract individual fields from ROW results
153
154        // Step 1: Build pre-projection that constructs ROW expressions
155        let input_plan = agg.input();
156        let input_schema = input_plan.schema();
157        let mut pre_proj_exprs = Vec::new();
158
159        // Add all original input columns first
160        for i in 0..input_schema.len() {
161            pre_proj_exprs
162                .push(InputRef::new(i, input_schema.fields()[i].data_type.clone()).into());
163        }
164
165        // Add ROW construction expressions for each mergeable pattern
166        let mut pattern_to_row_col_idx = HashMap::new();
167        for (pattern, mergeable_calls) in &mergeable_patterns {
168            // Create ROW expression with inputs from this pattern
169            let row_inputs: Vec<ExprImpl> = mergeable_calls
170                .iter()
171                .map(|(_, call)| call.inputs[0].clone().into())
172                .collect();
173
174            let field_types = mergeable_calls
175                .iter()
176                .map(|(_, call)| call.return_type.clone());
177            let row_data_type = DataType::Struct(StructType::unnamed(field_types));
178
179            let row_expr = FunctionCall::new_unchecked(ExprType::Row, row_inputs, row_data_type);
180
181            let row_col_idx = pre_proj_exprs.len();
182            pre_proj_exprs.push(row_expr.into());
183            pattern_to_row_col_idx.insert(pattern.clone(), row_col_idx);
184        }
185
186        // Create pre-projection
187        let pre_projection = LogicalProject::create(input_plan.clone(), pre_proj_exprs);
188
189        // Step 2: Build new aggregation calls operating on ROW columns
190        let mut new_calls = Vec::new();
191        let mut original_to_output_mapping = Vec::new();
192
193        // Process mergeable patterns
194        for (pattern, mergeable_calls) in &mergeable_patterns {
195            let row_col_idx = pattern_to_row_col_idx[pattern];
196            let row_data_type = pre_projection.schema().fields()[row_col_idx]
197                .data_type
198                .clone();
199
200            // Create aggregation call that operates on the ROW column
201            let merged_call = PlanAggCall {
202                agg_type: pattern.agg_type.clone(),
203                return_type: row_data_type,
204                inputs: vec![InputRef::new(
205                    row_col_idx,
206                    pre_projection.schema().fields()[row_col_idx]
207                        .data_type
208                        .clone(),
209                )],
210                distinct: pattern.distinct,
211                order_by: pattern
212                    .order_by
213                    .iter()
214                    .map(|order| {
215                        // Adjust column references in ORDER BY to point to ROW column
216                        ColumnOrder::new(row_col_idx, order.order_type)
217                    })
218                    .collect(),
219                filter: pattern.filter.clone(),
220                direct_args: vec![],
221            };
222
223            let merged_call_idx = new_calls.len();
224            new_calls.push(merged_call);
225
226            // Map each original call to its field in the merged result
227            for (field_idx, (original_idx, _)) in mergeable_calls.iter().enumerate() {
228                original_to_output_mapping.push((*original_idx, merged_call_idx, Some(field_idx)));
229            }
230        }
231
232        // Add non-mergeable calls
233        for (original_idx, call) in calls.iter().enumerate() {
234            // Check if this call was already handled in mergeable patterns
235            if !original_to_output_mapping
236                .iter()
237                .any(|(idx, _, _)| *idx == original_idx)
238            {
239                let new_call_idx = new_calls.len();
240                new_calls.push(call.clone());
241                original_to_output_mapping.push((original_idx, new_call_idx, None));
242            }
243        }
244
245        // Sort mapping by original index to maintain order
246        original_to_output_mapping.sort_by_key(|(original_idx, _, _)| *original_idx);
247
248        // Create aggregation on pre-projection
249        let new_agg: PlanRef = Agg::new(new_calls, agg.group_key().clone(), pre_projection)
250            .with_enable_two_phase(agg.core().two_phase_agg_enabled())
251            .into();
252
253        // Step 3: Build post-projection to extract fields from ROW results
254        let mut post_proj_exprs = Vec::new();
255
256        // Add group key columns
257        for i in 0..agg.group_key().len() {
258            let group_col_idx = agg.group_key().indices().nth(i).unwrap();
259            post_proj_exprs.push(
260                InputRef::new(i, input_schema.fields()[group_col_idx].data_type.clone()).into(),
261            );
262        }
263
264        // Add aggregation result columns with proper field extraction
265        for (original_idx, new_call_idx, field_idx_opt) in original_to_output_mapping {
266            let original_return_type = calls[original_idx].return_type.clone();
267
268            if let Some(field_idx) = field_idx_opt {
269                // Extract field from ROW result using proper field access
270                let struct_ref = InputRef::new(
271                    agg.group_key().len() + new_call_idx,
272                    new_agg.schema().fields()[agg.group_key().len() + new_call_idx]
273                        .data_type
274                        .clone(),
275                );
276
277                let field_access = FunctionCall::new_unchecked(
278                    ExprType::Field,
279                    vec![
280                        struct_ref.into(),
281                        Literal::new(Some((field_idx as i32).into()), DataType::Int32).into(),
282                    ],
283                    original_return_type,
284                );
285
286                post_proj_exprs.push(field_access.into());
287            } else {
288                // Direct reference to non-merged call
289                post_proj_exprs.push(
290                    InputRef::new(agg.group_key().len() + new_call_idx, original_return_type)
291                        .into(),
292                );
293            }
294        }
295
296        Some(LogicalProject::create(new_agg, post_proj_exprs))
297    }
298}
299
300#[derive(Clone, Debug, PartialEq, Eq, Hash)]
301struct AggPattern {
302    agg_type: AggType,
303    order_by: Vec<ColumnOrder>,
304    distinct: bool,
305    filter: crate::utils::Condition,
306}
307
308impl UnifyFirstLastValueRule {
309    pub fn create() -> BoxedRule {
310        Box::new(Self {})
311    }
312
313    fn is_supported_agg_type(&self, agg_type: &AggType) -> bool {
314        matches!(
315            agg_type,
316            AggType::Builtin(PbAggKind::FirstValue) | AggType::Builtin(PbAggKind::LastValue)
317        )
318    }
319}