risingwave_frontend/optimizer/rule/unify_first_last_value_rule.rs
1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use risingwave_common::types::{DataType, StructType};
18use risingwave_common::util::sort_util::ColumnOrder;
19use risingwave_expr::aggregate::{AggType, PbAggKind};
20
21use super::prelude::{PlanRef, *};
22use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef, Literal};
23use crate::optimizer::plan_node::generic::{Agg, PlanAggCall};
24use crate::optimizer::plan_node::{LogicalProject, PlanTreeNodeUnary};
25
/// Merges `FIRST_VALUE` / `LAST_VALUE` aggregation calls that share the same ordering into a
/// single aggregation call over a `ROW(...)` of their arguments.
///
/// Calls are merged only when all of the following hold:
/// - they use the same aggregate kind (`FIRST_VALUE` calls are never combined with
///   `LAST_VALUE` calls, and vice versa);
/// - they have the same `ORDER BY` list, the same `DISTINCT` flag and the same `FILTER`;
/// - each call takes exactly one argument.
///
/// Every aggregation call keeps its own state, so this is mainly beneficial for streaming
/// queries: a group of N mergeable calls is evaluated as one call, i.e. one state.
/// Calls that cannot be merged are passed through unchanged.
///
/// # Example transformation
///
/// ## Before
/// ```sql
/// SELECT
///     LAST_VALUE(col1 ORDER BY col2) AS last_col1,
///     LAST_VALUE(col3 ORDER BY col2) AS last_col3,
///     LAST_VALUE(col4 ORDER BY col2) AS last_col4,
///     LAST_VALUE(col5 ORDER BY col2) AS last_col5
/// FROM table_name
/// GROUP BY group_col;
/// ```
///
/// ## After
/// ```sql
/// SELECT
///     (unified_last).f0 AS last_col1,
///     (unified_last).f1 AS last_col3,
///     (unified_last).f2 AS last_col4,
///     (unified_last).f3 AS last_col5
/// FROM (
///     SELECT
///         LAST_VALUE(ROW(col1, col3, col4, col5) ORDER BY col2) AS unified_last
///     FROM table_name
///     GROUP BY group_col
/// ) sub;
/// ```
///
/// # Plan transformation
///
/// ## Before
/// ```text
/// LogicalAgg [group_col]
/// ├─agg_calls:
/// │ ├─ LAST_VALUE(col1 ORDER BY col2) -- state 1
/// │ ├─ LAST_VALUE(col3 ORDER BY col2) -- state 2
/// │ ├─ LAST_VALUE(col4 ORDER BY col2) -- state 3
/// │ └─ LAST_VALUE(col5 ORDER BY col2) -- state 4
/// └─LogicalScan { table: table_name }
/// ```
///
/// ## After
/// ```text
/// LogicalProject
/// ├─exprs: [(unified_last).f0, (unified_last).f1, (unified_last).f2, (unified_last).f3]
/// └─LogicalAgg [group_col]
///   ├─agg_calls:
///   │ └─ LAST_VALUE(ROW(col1, col3, col4, col5) ORDER BY col2) -- single state
///   └─LogicalProject
///     ├─exprs: [group_col, col1, col2, col3, col4, col5, ROW(col1, col3, col4, col5)]
///     └─LogicalScan { table: table_name }
/// ```
///
/// The inner projection forwards every input column at its original position and appends
/// the `ROW(...)` columns. The outer projection restores the original output columns.
/// In this example, four aggregation states become one.
pub struct UnifyFirstLastValueRule {}
85
86impl Rule<Logical> for UnifyFirstLastValueRule {
87 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
88 let agg = plan.as_logical_agg()?;
89
90 let calls = agg.agg_calls();
91 if calls.len() < 2 {
92 // Need at least 2 calls to merge
93 return None;
94 }
95
96 // Group calls by their "pattern" (agg_type + order_by + distinct + filter)
97 let mut pattern_groups: HashMap<AggPattern, Vec<(usize, &PlanAggCall)>> = HashMap::new();
98
99 for (idx, call) in calls.iter().enumerate() {
100 let PlanAggCall {
101 agg_type,
102 return_type: _,
103 inputs,
104 distinct,
105 order_by,
106 filter,
107 direct_args: _,
108 } = call;
109
110 if !self.is_supported_agg_type(agg_type) {
111 continue;
112 }
113
114 // Only support single input aggregations for now
115 if inputs.len() != 1 {
116 continue;
117 }
118
119 let pattern = AggPattern {
120 agg_type: agg_type.clone(),
121 order_by: order_by.clone(),
122 distinct: *distinct,
123 filter: filter.clone(),
124 };
125
126 pattern_groups.entry(pattern).or_default().push((idx, call));
127 }
128
129 // Find ALL patterns with multiple calls that can be merged
130 let mut mergeable_patterns: Vec<(AggPattern, Vec<(usize, &PlanAggCall)>)> = pattern_groups
131 .into_iter()
132 .filter(|(_, calls_in_pattern)| calls_in_pattern.len() >= 2)
133 .collect();
134
135 // Sort by the minimum original index in each pattern to ensure stable optimization results
136 mergeable_patterns.sort_by_key(|(_, calls_in_pattern)| {
137 calls_in_pattern.iter().map(|(idx, _)| *idx).min().unwrap()
138 });
139
140 // Also sort calls within each pattern by their original index
141 for (_, calls_in_pattern) in &mut mergeable_patterns {
142 calls_in_pattern.sort_by_key(|(idx, _)| *idx);
143 }
144
145 if mergeable_patterns.is_empty() {
146 return None;
147 }
148
149 // Build the complete transformation pipeline:
150 // 1. Pre-projection: Construct ROW expressions for each mergeable pattern
151 // 2. Aggregation: Apply aggregation functions on ROW expressions
152 // 3. Post-projection: Extract individual fields from ROW results
153
154 // Step 1: Build pre-projection that constructs ROW expressions
155 let input_plan = agg.input();
156 let input_schema = input_plan.schema();
157 let mut pre_proj_exprs = Vec::new();
158
159 // Add all original input columns first
160 for i in 0..input_schema.len() {
161 pre_proj_exprs
162 .push(InputRef::new(i, input_schema.fields()[i].data_type.clone()).into());
163 }
164
165 // Add ROW construction expressions for each mergeable pattern
166 let mut pattern_to_row_col_idx = HashMap::new();
167 for (pattern, mergeable_calls) in &mergeable_patterns {
168 // Create ROW expression with inputs from this pattern
169 let row_inputs: Vec<ExprImpl> = mergeable_calls
170 .iter()
171 .map(|(_, call)| call.inputs[0].clone().into())
172 .collect();
173
174 let field_types = mergeable_calls
175 .iter()
176 .map(|(_, call)| call.return_type.clone());
177 let row_data_type = DataType::Struct(StructType::unnamed(field_types));
178
179 let row_expr = FunctionCall::new_unchecked(ExprType::Row, row_inputs, row_data_type);
180
181 let row_col_idx = pre_proj_exprs.len();
182 pre_proj_exprs.push(row_expr.into());
183 pattern_to_row_col_idx.insert(pattern.clone(), row_col_idx);
184 }
185
186 // Create pre-projection
187 let pre_projection = LogicalProject::create(input_plan.clone(), pre_proj_exprs);
188
189 // Step 2: Build new aggregation calls operating on ROW columns
190 let mut new_calls = Vec::new();
191 let mut original_to_output_mapping = Vec::new();
192
193 // Process mergeable patterns
194 for (pattern, mergeable_calls) in &mergeable_patterns {
195 let row_col_idx = pattern_to_row_col_idx[pattern];
196 let row_data_type = pre_projection.schema().fields()[row_col_idx]
197 .data_type
198 .clone();
199
200 // Create aggregation call that operates on the ROW column
201 let merged_call = PlanAggCall {
202 agg_type: pattern.agg_type.clone(),
203 return_type: row_data_type,
204 inputs: vec![InputRef::new(
205 row_col_idx,
206 pre_projection.schema().fields()[row_col_idx]
207 .data_type
208 .clone(),
209 )],
210 distinct: pattern.distinct,
211 order_by: pattern
212 .order_by
213 .iter()
214 .map(|order| {
215 // Adjust column references in ORDER BY to point to ROW column
216 ColumnOrder::new(row_col_idx, order.order_type)
217 })
218 .collect(),
219 filter: pattern.filter.clone(),
220 direct_args: vec![],
221 };
222
223 let merged_call_idx = new_calls.len();
224 new_calls.push(merged_call);
225
226 // Map each original call to its field in the merged result
227 for (field_idx, (original_idx, _)) in mergeable_calls.iter().enumerate() {
228 original_to_output_mapping.push((*original_idx, merged_call_idx, Some(field_idx)));
229 }
230 }
231
232 // Add non-mergeable calls
233 for (original_idx, call) in calls.iter().enumerate() {
234 // Check if this call was already handled in mergeable patterns
235 if !original_to_output_mapping
236 .iter()
237 .any(|(idx, _, _)| *idx == original_idx)
238 {
239 let new_call_idx = new_calls.len();
240 new_calls.push(call.clone());
241 original_to_output_mapping.push((original_idx, new_call_idx, None));
242 }
243 }
244
245 // Sort mapping by original index to maintain order
246 original_to_output_mapping.sort_by_key(|(original_idx, _, _)| *original_idx);
247
248 // Create aggregation on pre-projection
249 let new_agg: PlanRef = Agg::new(new_calls, agg.group_key().clone(), pre_projection)
250 .with_enable_two_phase(agg.core().two_phase_agg_enabled())
251 .into();
252
253 // Step 3: Build post-projection to extract fields from ROW results
254 let mut post_proj_exprs = Vec::new();
255
256 // Add group key columns
257 for i in 0..agg.group_key().len() {
258 let group_col_idx = agg.group_key().indices().nth(i).unwrap();
259 post_proj_exprs.push(
260 InputRef::new(i, input_schema.fields()[group_col_idx].data_type.clone()).into(),
261 );
262 }
263
264 // Add aggregation result columns with proper field extraction
265 for (original_idx, new_call_idx, field_idx_opt) in original_to_output_mapping {
266 let original_return_type = calls[original_idx].return_type.clone();
267
268 if let Some(field_idx) = field_idx_opt {
269 // Extract field from ROW result using proper field access
270 let struct_ref = InputRef::new(
271 agg.group_key().len() + new_call_idx,
272 new_agg.schema().fields()[agg.group_key().len() + new_call_idx]
273 .data_type
274 .clone(),
275 );
276
277 let field_access = FunctionCall::new_unchecked(
278 ExprType::Field,
279 vec![
280 struct_ref.into(),
281 Literal::new(Some((field_idx as i32).into()), DataType::Int32).into(),
282 ],
283 original_return_type,
284 );
285
286 post_proj_exprs.push(field_access.into());
287 } else {
288 // Direct reference to non-merged call
289 post_proj_exprs.push(
290 InputRef::new(agg.group_key().len() + new_call_idx, original_return_type)
291 .into(),
292 );
293 }
294 }
295
296 Some(LogicalProject::create(new_agg, post_proj_exprs))
297 }
298}
299
300#[derive(Clone, Debug, PartialEq, Eq, Hash)]
301struct AggPattern {
302 agg_type: AggType,
303 order_by: Vec<ColumnOrder>,
304 distinct: bool,
305 filter: crate::utils::Condition,
306}
307
308impl UnifyFirstLastValueRule {
309 pub fn create() -> BoxedRule {
310 Box::new(Self {})
311 }
312
313 fn is_supported_agg_type(&self, agg_type: &AggType) -> bool {
314 matches!(
315 agg_type,
316 AggType::Builtin(PbAggKind::FirstValue) | AggType::Builtin(PbAggKind::LastValue)
317 )
318 }
319}