risingwave_frontend/optimizer/rule/unify_first_last_value_rule.rs
1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use risingwave_common::types::{DataType, StructType};
18use risingwave_common::util::sort_util::ColumnOrder;
19use risingwave_expr::aggregate::{AggType, PbAggKind};
20
21use super::prelude::{PlanRef, *};
22use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef, Literal};
23use crate::optimizer::plan_node::generic::{Agg, PlanAggCall};
24use crate::optimizer::plan_node::{LogicalProject, PlanTreeNodeUnary};
25
/// Unifies `FIRST_VALUE` and `LAST_VALUE` aggregation calls that share the same ordering so that
/// they reuse a single aggregation call over a constructed `ROW`.
///
/// This rule is particularly beneficial for streaming queries as it reduces the number of states
/// maintained. It specifically targets `FIRST_VALUE` and `LAST_VALUE` aggregation calls that take
/// exactly one input and agree on aggregate type, `ORDER BY`, `DISTINCT` flag and `FILTER`
/// condition. Calls that do not fall into such a group of at least two are kept unchanged.
///
/// # Example transformation
///
/// ## Before
/// ```sql
/// SELECT
///     LAST_VALUE(col1 ORDER BY col2) AS last_col1,
///     LAST_VALUE(col3 ORDER BY col2) AS last_col3,
///     LAST_VALUE(col4 ORDER BY col2) AS last_col4,
///     LAST_VALUE(col5 ORDER BY col2) AS last_col5
/// FROM table_name
/// GROUP BY group_col;
/// ```
///
/// ## After
/// ```sql
/// SELECT
///     (unified_last).f0 AS last_col1,
///     (unified_last).f1 AS last_col3,
///     (unified_last).f2 AS last_col4,
///     (unified_last).f3 AS last_col5
/// FROM (
///     SELECT
///         LAST_VALUE(ROW(col1, col3, col4, col5) ORDER BY col2) AS unified_last
///     FROM table_name GROUP BY group_col
/// ) sub;
/// ```
///
/// # Plan transformation
///
/// ## Before
/// ```text
/// LogicalAgg [group_col]
/// ├─agg_calls:
/// │  ├─ LAST_VALUE(col1 ORDER BY col2) -- State 1
/// │  ├─ LAST_VALUE(col3 ORDER BY col2) -- State 2
/// │  ├─ LAST_VALUE(col4 ORDER BY col2) -- State 3
/// │  └─ LAST_VALUE(col5 ORDER BY col2) -- State 4
/// └─LogicalScan { table: table_name }
/// ```
///
/// ## After
/// ```text
/// LogicalProject
/// ├─exprs: [(unified_last).f0, (unified_last).f1, (unified_last).f2, (unified_last).f3]
/// └─LogicalAgg [group_col]
///    ├─agg_calls:
///    │  └─ LAST_VALUE(ROW(col1,col3,col4,col5) ORDER BY col2) -- Single State!
///    └─LogicalProject
///       ├─exprs: [group_col, col1, col2, col3, col4, col5, ROW(col1,col3,col4,col5)]
///       └─LogicalScan { table: table_name }
/// ```
///
/// The inner projection keeps every original input column at its original position and appends
/// one `ROW(...)` column per merged group.
///
/// The key benefit: **4 aggregation states → 1 aggregation state** for streaming performance!
pub struct UnifyFirstLastValueRule {}
85
86impl Rule<Logical> for UnifyFirstLastValueRule {
87 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
88 let agg = plan.as_logical_agg()?;
89
90 let calls = agg.agg_calls();
91 if calls.len() < 2 {
92 // Need at least 2 calls to merge
93 return None;
94 }
95
96 // Group calls by their "pattern" (agg_type + order_by + distinct + filter)
97 let mut pattern_groups: HashMap<AggPattern, Vec<(usize, &PlanAggCall)>> = HashMap::new();
98
99 for (idx, call) in calls.iter().enumerate() {
100 let PlanAggCall {
101 agg_type,
102 return_type: _,
103 inputs,
104 distinct,
105 order_by,
106 filter,
107 direct_args: _,
108 } = call;
109
110 if !self.is_supported_agg_type(agg_type) {
111 continue;
112 }
113
114 // Only support single input aggregations for now
115 if inputs.len() != 1 {
116 continue;
117 }
118
119 let pattern = AggPattern {
120 agg_type: agg_type.clone(),
121 order_by: order_by.clone(),
122 distinct: *distinct,
123 filter: filter.clone(),
124 };
125
126 pattern_groups.entry(pattern).or_default().push((idx, call));
127 }
128
129 // Find ALL patterns with multiple calls that can be merged
130 let mut mergeable_patterns: Vec<(AggPattern, Vec<(usize, &PlanAggCall)>)> = pattern_groups
131 .into_iter()
132 .filter(|(_, calls_in_pattern)| calls_in_pattern.len() >= 2)
133 .collect();
134
135 // Sort by the minimum original index in each pattern to ensure stable optimization results
136 mergeable_patterns.sort_by_key(|(_, calls_in_pattern)| {
137 calls_in_pattern.iter().map(|(idx, _)| *idx).min().unwrap()
138 });
139
140 // Also sort calls within each pattern by their original index
141 for (_, calls_in_pattern) in &mut mergeable_patterns {
142 calls_in_pattern.sort_by_key(|(idx, _)| *idx);
143 }
144
145 if mergeable_patterns.is_empty() {
146 return None;
147 }
148
149 // Build the complete transformation pipeline:
150 // 1. Pre-projection: Construct ROW expressions for each mergeable pattern
151 // 2. Aggregation: Apply aggregation functions on ROW expressions
152 // 3. Post-projection: Extract individual fields from ROW results
153
154 // Step 1: Build pre-projection that constructs ROW expressions
155 let input_plan = agg.input();
156 let input_schema = input_plan.schema();
157 let mut pre_proj_exprs = Vec::new();
158
159 // Add all original input columns first
160 for i in 0..input_schema.len() {
161 pre_proj_exprs
162 .push(InputRef::new(i, input_schema.fields()[i].data_type.clone()).into());
163 }
164
165 // Add ROW construction expressions for each mergeable pattern
166 let mut pattern_to_row_col_idx = HashMap::new();
167 for (pattern, mergeable_calls) in &mergeable_patterns {
168 // Create ROW expression with inputs from this pattern
169 let row_inputs: Vec<ExprImpl> = mergeable_calls
170 .iter()
171 .map(|(_, call)| call.inputs[0].clone().into())
172 .collect();
173
174 let field_types = mergeable_calls
175 .iter()
176 .map(|(_, call)| call.return_type.clone())
177 .collect::<Vec<_>>();
178 let row_data_type = DataType::Struct(StructType::unnamed(field_types));
179
180 let row_expr = FunctionCall::new_unchecked(ExprType::Row, row_inputs, row_data_type);
181
182 let row_col_idx = pre_proj_exprs.len();
183 pre_proj_exprs.push(row_expr.into());
184 pattern_to_row_col_idx.insert(pattern.clone(), row_col_idx);
185 }
186
187 // Create pre-projection
188 let pre_projection = LogicalProject::create(input_plan.clone(), pre_proj_exprs);
189
190 // Step 2: Build new aggregation calls operating on ROW columns
191 let mut new_calls = Vec::new();
192 let mut original_to_output_mapping = Vec::new();
193
194 // Process mergeable patterns
195 for (pattern, mergeable_calls) in &mergeable_patterns {
196 let row_col_idx = pattern_to_row_col_idx[pattern];
197 let row_data_type = pre_projection.schema().fields()[row_col_idx]
198 .data_type
199 .clone();
200
201 // Create aggregation call that operates on the ROW column
202 let merged_call = PlanAggCall {
203 agg_type: pattern.agg_type.clone(),
204 return_type: row_data_type,
205 inputs: vec![InputRef::new(
206 row_col_idx,
207 pre_projection.schema().fields()[row_col_idx]
208 .data_type
209 .clone(),
210 )],
211 distinct: pattern.distinct,
212 order_by: pattern
213 .order_by
214 .iter()
215 .map(|order| {
216 // Adjust column references in ORDER BY to point to ROW column
217 ColumnOrder::new(row_col_idx, order.order_type)
218 })
219 .collect(),
220 filter: pattern.filter.clone(),
221 direct_args: vec![],
222 };
223
224 let merged_call_idx = new_calls.len();
225 new_calls.push(merged_call);
226
227 // Map each original call to its field in the merged result
228 for (field_idx, (original_idx, _)) in mergeable_calls.iter().enumerate() {
229 original_to_output_mapping.push((*original_idx, merged_call_idx, Some(field_idx)));
230 }
231 }
232
233 // Add non-mergeable calls
234 for (original_idx, call) in calls.iter().enumerate() {
235 // Check if this call was already handled in mergeable patterns
236 if !original_to_output_mapping
237 .iter()
238 .any(|(idx, _, _)| *idx == original_idx)
239 {
240 let new_call_idx = new_calls.len();
241 new_calls.push(call.clone());
242 original_to_output_mapping.push((original_idx, new_call_idx, None));
243 }
244 }
245
246 // Sort mapping by original index to maintain order
247 original_to_output_mapping.sort_by_key(|(original_idx, _, _)| *original_idx);
248
249 // Create aggregation on pre-projection
250 let new_agg: PlanRef = Agg::new(new_calls, agg.group_key().clone(), pre_projection)
251 .with_enable_two_phase(agg.core().two_phase_agg_enabled())
252 .into();
253
254 // Step 3: Build post-projection to extract fields from ROW results
255 let mut post_proj_exprs = Vec::new();
256
257 // Add group key columns
258 for i in 0..agg.group_key().len() {
259 let group_col_idx = agg.group_key().indices().nth(i).unwrap();
260 post_proj_exprs.push(
261 InputRef::new(i, input_schema.fields()[group_col_idx].data_type.clone()).into(),
262 );
263 }
264
265 // Add aggregation result columns with proper field extraction
266 for (original_idx, new_call_idx, field_idx_opt) in original_to_output_mapping {
267 let original_return_type = calls[original_idx].return_type.clone();
268
269 if let Some(field_idx) = field_idx_opt {
270 // Extract field from ROW result using proper field access
271 let struct_ref = InputRef::new(
272 agg.group_key().len() + new_call_idx,
273 new_agg.schema().fields()[agg.group_key().len() + new_call_idx]
274 .data_type
275 .clone(),
276 );
277
278 let field_access = FunctionCall::new_unchecked(
279 ExprType::Field,
280 vec![
281 struct_ref.into(),
282 Literal::new(Some((field_idx as i32).into()), DataType::Int32).into(),
283 ],
284 original_return_type,
285 );
286
287 post_proj_exprs.push(field_access.into());
288 } else {
289 // Direct reference to non-merged call
290 post_proj_exprs.push(
291 InputRef::new(agg.group_key().len() + new_call_idx, original_return_type)
292 .into(),
293 );
294 }
295 }
296
297 Some(LogicalProject::create(new_agg, post_proj_exprs))
298 }
299}
300
/// Grouping key used to decide which aggregation calls can be merged: two calls are merged
/// only when all four fields below compare equal.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct AggPattern {
    /// Aggregate function of the call (only `FirstValue` / `LastValue` are ever grouped).
    agg_type: AggType,
    /// `ORDER BY` columns of the call, copied from `PlanAggCall::order_by`.
    order_by: Vec<ColumnOrder>,
    /// Whether the call is `DISTINCT`, copied from `PlanAggCall::distinct`.
    distinct: bool,
    /// `FILTER` condition of the call, copied from `PlanAggCall::filter`.
    filter: crate::utils::Condition,
}
308
309impl UnifyFirstLastValueRule {
310 pub fn create() -> BoxedRule {
311 Box::new(Self {})
312 }
313
314 fn is_supported_agg_type(&self, agg_type: &AggType) -> bool {
315 matches!(
316 agg_type,
317 AggType::Builtin(PbAggKind::FirstValue) | AggType::Builtin(PbAggKind::LastValue)
318 )
319 }
320}