risingwave_frontend/optimizer/rule/unify_first_last_value_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use risingwave_common::types::{DataType, StructType};
18use risingwave_common::util::sort_util::ColumnOrder;
19use risingwave_expr::aggregate::{AggType, PbAggKind};
20
21use super::prelude::{PlanRef, *};
22use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef, Literal};
23use crate::optimizer::plan_node::generic::{Agg, PlanAggCall};
24use crate::optimizer::plan_node::{LogicalProject, PlanTreeNodeUnary};
25
/// Unifies `FIRST_VALUE` and `LAST_VALUE` aggregation calls with the same ordering to reuse a single aggregation call with ROW construction.
///
/// This rule is particularly beneficial for streaming queries as it reduces the number of states maintained.
/// It specifically targets `FIRST_VALUE` and `LAST_VALUE` aggregation functions that have identical ordering clauses.
///
/// # Applicability
///
/// Only single-argument builtin `FIRST_VALUE` / `LAST_VALUE` calls are considered. Eligible calls
/// are bucketed by aggregate kind, `ORDER BY`, `DISTINCT` flag and `FILTER` condition, and a
/// bucket is merged only when it contains at least two calls. All other calls are carried over
/// unchanged, and the final projection restores the original output column order (group keys
/// first, then one column per original call).
///
/// # Example transformation:
///
/// ## Before:
/// ```sql
/// SELECT
///   LAST_VALUE(col1 ORDER BY col2) AS last_col1,
///   LAST_VALUE(col3 ORDER BY col2) AS last_col3,
///   LAST_VALUE(col4 ORDER BY col2) AS last_col4,
///   LAST_VALUE(col5 ORDER BY col2) AS last_col5
/// FROM table_name
/// GROUP BY group_col;
/// ```
///
/// ## After:
/// ```sql
/// SELECT
///   (unified_last).f0 AS last_col1,
///   (unified_last).f1 AS last_col3,
///   (unified_last).f2 AS last_col4,
///   (unified_last).f3 AS last_col5
/// FROM (
///   SELECT
///     LAST_VALUE(ROW(col1, col3, col4, col5) ORDER BY col2) AS unified_last
///   FROM table_name GROUP BY group_col
/// ) sub;
/// ```
///
/// # Plan transformation:
///
/// ## Before:
/// ```text
/// LogicalAgg [group_col]
///  ├─agg_calls:
///  │  ├─ LAST_VALUE(col1 ORDER BY col2)     -- State 1
///  │  ├─ LAST_VALUE(col3 ORDER BY col2)     -- State 2
///  │  ├─ LAST_VALUE(col4 ORDER BY col2)     -- State 3
///  │  └─ LAST_VALUE(col5 ORDER BY col2)     -- State 4
///  └─LogicalScan { table: table_name }
/// ```
///
/// ## After:
/// ```text
/// LogicalProject
///  ├─exprs: [(unified_last).f0, (unified_last).f1, (unified_last).f2, (unified_last).f3]
///  └─LogicalAgg [group_col]
///     ├─agg_calls:
///     │  └─ LAST_VALUE(ROW(col1,col3,col4,col5) ORDER BY col2)  -- Single State!
///     └─LogicalProject
///        ├─exprs: [group_col, col1, col2, col3, col4, col5, ROW(col1,col3,col4,col5)]
///        └─LogicalScan { table: table_name }
/// ```
///
/// The key benefit: **4 aggregation states → 1 aggregation state** for streaming performance!
pub struct UnifyFirstLastValueRule {}
85
86impl Rule<Logical> for UnifyFirstLastValueRule {
87    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
88        let agg = plan.as_logical_agg()?;
89
90        let calls = agg.agg_calls();
91        if calls.len() < 2 {
92            // Need at least 2 calls to merge
93            return None;
94        }
95
96        // Group calls by their "pattern" (agg_type + order_by + distinct + filter)
97        let mut pattern_groups: HashMap<AggPattern, Vec<(usize, &PlanAggCall)>> = HashMap::new();
98
99        for (idx, call) in calls.iter().enumerate() {
100            let PlanAggCall {
101                agg_type,
102                return_type: _,
103                inputs,
104                distinct,
105                order_by,
106                filter,
107                direct_args: _,
108            } = call;
109
110            if !self.is_supported_agg_type(agg_type) {
111                continue;
112            }
113
114            // Only support single input aggregations for now
115            if inputs.len() != 1 {
116                continue;
117            }
118
119            let pattern = AggPattern {
120                agg_type: agg_type.clone(),
121                order_by: order_by.clone(),
122                distinct: *distinct,
123                filter: filter.clone(),
124            };
125
126            pattern_groups.entry(pattern).or_default().push((idx, call));
127        }
128
129        // Find ALL patterns with multiple calls that can be merged
130        let mut mergeable_patterns: Vec<(AggPattern, Vec<(usize, &PlanAggCall)>)> = pattern_groups
131            .into_iter()
132            .filter(|(_, calls_in_pattern)| calls_in_pattern.len() >= 2)
133            .collect();
134
135        // Sort by the minimum original index in each pattern to ensure stable optimization results
136        mergeable_patterns.sort_by_key(|(_, calls_in_pattern)| {
137            calls_in_pattern.iter().map(|(idx, _)| *idx).min().unwrap()
138        });
139
140        // Also sort calls within each pattern by their original index
141        for (_, calls_in_pattern) in &mut mergeable_patterns {
142            calls_in_pattern.sort_by_key(|(idx, _)| *idx);
143        }
144
145        if mergeable_patterns.is_empty() {
146            return None;
147        }
148
149        // Build the complete transformation pipeline:
150        // 1. Pre-projection: Construct ROW expressions for each mergeable pattern
151        // 2. Aggregation: Apply aggregation functions on ROW expressions
152        // 3. Post-projection: Extract individual fields from ROW results
153
154        // Step 1: Build pre-projection that constructs ROW expressions
155        let input_plan = agg.input();
156        let input_schema = input_plan.schema();
157        let mut pre_proj_exprs = Vec::new();
158
159        // Add all original input columns first
160        for i in 0..input_schema.len() {
161            pre_proj_exprs
162                .push(InputRef::new(i, input_schema.fields()[i].data_type.clone()).into());
163        }
164
165        // Add ROW construction expressions for each mergeable pattern
166        let mut pattern_to_row_col_idx = HashMap::new();
167        for (pattern, mergeable_calls) in &mergeable_patterns {
168            // Create ROW expression with inputs from this pattern
169            let row_inputs: Vec<ExprImpl> = mergeable_calls
170                .iter()
171                .map(|(_, call)| call.inputs[0].clone().into())
172                .collect();
173
174            let field_types = mergeable_calls
175                .iter()
176                .map(|(_, call)| call.return_type.clone())
177                .collect::<Vec<_>>();
178            let row_data_type = DataType::Struct(StructType::unnamed(field_types));
179
180            let row_expr = FunctionCall::new_unchecked(ExprType::Row, row_inputs, row_data_type);
181
182            let row_col_idx = pre_proj_exprs.len();
183            pre_proj_exprs.push(row_expr.into());
184            pattern_to_row_col_idx.insert(pattern.clone(), row_col_idx);
185        }
186
187        // Create pre-projection
188        let pre_projection = LogicalProject::create(input_plan.clone(), pre_proj_exprs);
189
190        // Step 2: Build new aggregation calls operating on ROW columns
191        let mut new_calls = Vec::new();
192        let mut original_to_output_mapping = Vec::new();
193
194        // Process mergeable patterns
195        for (pattern, mergeable_calls) in &mergeable_patterns {
196            let row_col_idx = pattern_to_row_col_idx[pattern];
197            let row_data_type = pre_projection.schema().fields()[row_col_idx]
198                .data_type
199                .clone();
200
201            // Create aggregation call that operates on the ROW column
202            let merged_call = PlanAggCall {
203                agg_type: pattern.agg_type.clone(),
204                return_type: row_data_type,
205                inputs: vec![InputRef::new(
206                    row_col_idx,
207                    pre_projection.schema().fields()[row_col_idx]
208                        .data_type
209                        .clone(),
210                )],
211                distinct: pattern.distinct,
212                order_by: pattern
213                    .order_by
214                    .iter()
215                    .map(|order| {
216                        // Adjust column references in ORDER BY to point to ROW column
217                        ColumnOrder::new(row_col_idx, order.order_type)
218                    })
219                    .collect(),
220                filter: pattern.filter.clone(),
221                direct_args: vec![],
222            };
223
224            let merged_call_idx = new_calls.len();
225            new_calls.push(merged_call);
226
227            // Map each original call to its field in the merged result
228            for (field_idx, (original_idx, _)) in mergeable_calls.iter().enumerate() {
229                original_to_output_mapping.push((*original_idx, merged_call_idx, Some(field_idx)));
230            }
231        }
232
233        // Add non-mergeable calls
234        for (original_idx, call) in calls.iter().enumerate() {
235            // Check if this call was already handled in mergeable patterns
236            if !original_to_output_mapping
237                .iter()
238                .any(|(idx, _, _)| *idx == original_idx)
239            {
240                let new_call_idx = new_calls.len();
241                new_calls.push(call.clone());
242                original_to_output_mapping.push((original_idx, new_call_idx, None));
243            }
244        }
245
246        // Sort mapping by original index to maintain order
247        original_to_output_mapping.sort_by_key(|(original_idx, _, _)| *original_idx);
248
249        // Create aggregation on pre-projection
250        let new_agg: PlanRef = Agg::new(new_calls, agg.group_key().clone(), pre_projection)
251            .with_enable_two_phase(agg.core().two_phase_agg_enabled())
252            .into();
253
254        // Step 3: Build post-projection to extract fields from ROW results
255        let mut post_proj_exprs = Vec::new();
256
257        // Add group key columns
258        for i in 0..agg.group_key().len() {
259            let group_col_idx = agg.group_key().indices().nth(i).unwrap();
260            post_proj_exprs.push(
261                InputRef::new(i, input_schema.fields()[group_col_idx].data_type.clone()).into(),
262            );
263        }
264
265        // Add aggregation result columns with proper field extraction
266        for (original_idx, new_call_idx, field_idx_opt) in original_to_output_mapping {
267            let original_return_type = calls[original_idx].return_type.clone();
268
269            if let Some(field_idx) = field_idx_opt {
270                // Extract field from ROW result using proper field access
271                let struct_ref = InputRef::new(
272                    agg.group_key().len() + new_call_idx,
273                    new_agg.schema().fields()[agg.group_key().len() + new_call_idx]
274                        .data_type
275                        .clone(),
276                );
277
278                let field_access = FunctionCall::new_unchecked(
279                    ExprType::Field,
280                    vec![
281                        struct_ref.into(),
282                        Literal::new(Some((field_idx as i32).into()), DataType::Int32).into(),
283                    ],
284                    original_return_type,
285                );
286
287                post_proj_exprs.push(field_access.into());
288            } else {
289                // Direct reference to non-merged call
290                post_proj_exprs.push(
291                    InputRef::new(agg.group_key().len() + new_call_idx, original_return_type)
292                        .into(),
293                );
294            }
295        }
296
297        Some(LogicalProject::create(new_agg, post_proj_exprs))
298    }
299}
300
/// Key used to bucket aggregation calls that may be folded into one call over `ROW(...)`.
///
/// Calls are only merged together when every field below is equal.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct AggPattern {
    /// Aggregate kind (`FIRST_VALUE` or `LAST_VALUE`, given the rule's supported-type check).
    agg_type: AggType,
    /// `ORDER BY` clause of the call.
    order_by: Vec<ColumnOrder>,
    /// Whether the call was `DISTINCT`.
    distinct: bool,
    /// `FILTER (WHERE ...)` condition of the call.
    filter: crate::utils::Condition,
}
308
309impl UnifyFirstLastValueRule {
310    pub fn create() -> BoxedRule {
311        Box::new(Self {})
312    }
313
314    fn is_supported_agg_type(&self, agg_type: &AggType) -> bool {
315        matches!(
316            agg_type,
317            AggType::Builtin(PbAggKind::FirstValue) | AggType::Builtin(PbAggKind::LastValue)
318        )
319    }
320}