risingwave_frontend/optimizer/rule/
over_window_to_topn_rule.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use fixedbitset::FixedBitSet;
16use risingwave_common::types::DataType;
17use risingwave_expr::window_function::WindowFuncKind;
18
19use super::{BoxedRule, Rule};
20use crate::PlanRef;
21use crate::expr::{
22    Expr, ExprImpl, ExprRewriter, ExprType, FunctionCall, Literal, collect_input_refs,
23};
24use crate::optimizer::plan_node::generic::GenericPlanRef;
25use crate::optimizer::plan_node::{LogicalFilter, LogicalTopN, PlanTreeNodeUnary};
26use crate::optimizer::property::Order;
27use crate::planner::LIMIT_ALL_COUNT;
28use crate::utils::Condition;
29
/// Rewrites ranking queries evaluated by an `OverWindow` into a group `TopN`.
///
/// Two shapes are recognized. When the ranking number is only filtered on and not
/// projected (No Ranking Output), the `OverWindow` is replaced by a group `TopN`:
///
/// ```sql
/// -- project - filter - over window
/// SELECT .. FROM
///   (SELECT .., ROW_NUMBER() OVER(PARTITION BY .. ORDER BY ..) rank FROM ..)
/// WHERE rank [ < | <= | > | >= | = ] ..;
/// ```
///
/// When the ranking number is part of the output (Ranking Output), a group `TopN` is placed
/// below the `OverWindow`, so the window only has to number the rows that survive:
///
/// ```sql
/// -- filter - over window
/// SELECT .., ROW_NUMBER() OVER(PARTITION BY .. ORDER BY ..) rank
/// FROM ..
/// WHERE rank [ < | <= | > | >= | = ] ..;
/// ```
///
/// Only `ROW_NUMBER` and `RANK` are supported; `DENSE_RANK` falls back to the plain
/// `OverWindow` with a warning to the user.
///
/// Before the ranking predicates are matched, simple arithmetic in the filter sitting directly
/// on the `OverWindow` (with or without a `Project` on top) is folded, e.g.
/// `(row_number - 1) = 0` is simplified to `row_number = 1`.
pub struct OverWindowToTopNRule;
52
53impl OverWindowToTopNRule {
54    pub fn create() -> BoxedRule {
55        Box::new(OverWindowToTopNRule)
56    }
57}
58
59impl Rule for OverWindowToTopNRule {
60    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
61        let ctx = plan.ctx();
62        let (project, plan) = {
63            if let Some(project) = plan.as_logical_project() {
64                (Some(project), project.input())
65            } else {
66                (None, plan)
67            }
68        };
69        let filter = plan.as_logical_filter()?;
70        let plan = filter.input();
71        // The filter is directly on top of the over window after predicate pushdown.
72        let over_window = plan.as_logical_over_window()?;
73
74        // First try to simplify filter arithmetic expressions
75        let filter = if let Some(simplified) = self.simplify_filter_arithmetic(filter) {
76            simplified
77        } else {
78            filter.clone()
79        };
80
81        if over_window.window_functions().len() != 1 {
82            // Queries with multiple window function calls are not supported yet.
83            return None;
84        }
85        let window_func = &over_window.window_functions()[0];
86        if !window_func.kind.is_numbering() {
87            // Only rank functions can be converted to TopN.
88            return None;
89        }
90
91        let output_len = over_window.schema().len();
92        let window_func_pos = output_len - 1;
93
94        let with_ties = match window_func.kind {
95            // Only `ROW_NUMBER` and `RANK` can be optimized to TopN now.
96            WindowFuncKind::RowNumber => false,
97            WindowFuncKind::Rank => true,
98            WindowFuncKind::DenseRank => {
99                ctx.warn_to_user("`dense_rank` is not supported in Top-N pattern, will fallback to inefficient implementation");
100                return None;
101            }
102            _ => unreachable!("window functions other than rank functions should not reach here"),
103        };
104
105        let (rank_pred, other_pred) = {
106            let predicate = filter.predicate();
107            let mut rank_col = FixedBitSet::with_capacity(output_len);
108            rank_col.set(window_func_pos, true);
109            predicate.clone().split_disjoint(&rank_col)
110        };
111
112        let (limit, offset) = handle_rank_preds(&rank_pred.conjunctions, window_func_pos)?;
113
114        if offset > 0 && with_ties {
115            tracing::warn!("Failed to optimize with ties and offset");
116            ctx.warn_to_user("group topN with ties and offset is not supported, see https://www.risingwave.dev/docs/current/sql-pattern-topn/ for more information");
117            return None;
118        }
119
120        let topn: PlanRef = LogicalTopN::new(
121            over_window.input(),
122            limit,
123            offset,
124            with_ties,
125            Order {
126                column_orders: window_func.order_by.to_vec(),
127            },
128            window_func.partition_by.iter().map(|i| i.index).collect(),
129        )
130        .into();
131        let filter = LogicalFilter::create(topn, other_pred);
132
133        let plan = if let Some(project) = project {
134            let referred_cols = collect_input_refs(output_len, project.exprs());
135            if !referred_cols.contains(window_func_pos) {
136                // No Ranking Output
137                project.clone_with_input(filter).into()
138            } else {
139                // Ranking Output, with project
140                project
141                    .clone_with_input(over_window.clone_with_input(filter).into())
142                    .into()
143            }
144        } else {
145            // Ranking Output, without project
146            ctx.warn_to_user("It can be inefficient to output ranking number in Top-N, see https://www.risingwave.dev/docs/current/sql-pattern-topn/ for more information");
147            over_window.clone_with_input(filter).into()
148        };
149        Some(plan)
150    }
151}
152
153impl OverWindowToTopNRule {
154    /// Simplify arithmetic expressions in filter conditions before TopN optimization
155    /// For example: `(row_number - 1) = 0` -> `row_number = 1`
156    fn simplify_filter_arithmetic(&self, filter: &LogicalFilter) -> Option<LogicalFilter> {
157        let new_predicate = self.simplify_filter_arithmetic_condition(filter.predicate())?;
158        Some(LogicalFilter::new(filter.input(), new_predicate))
159    }
160
161    /// Simplify arithmetic expressions in the filter condition
162    fn simplify_filter_arithmetic_condition(&self, predicate: &Condition) -> Option<Condition> {
163        let expr = predicate.as_expr_unless_true()?;
164        let mut rewriter = FilterArithmeticRewriter {};
165        let new_expr = rewriter.rewrite_expr(expr.clone());
166
167        if new_expr != expr {
168            Some(Condition::with_expr(new_expr))
169        } else {
170            None
171        }
172    }
173}
174
175/// Filter arithmetic simplification rewriter: simplifies `(col op const) = const2` to `col = (const2 reverse_op const)`
176struct FilterArithmeticRewriter {}
177
178impl ExprRewriter for FilterArithmeticRewriter {
179    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
180        use ExprType::{
181            Equal, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, NotEqual,
182        };
183
184        // Check if this is a comparison operation
185        match func_call.func_type() {
186            Equal | NotEqual | LessThan | LessThanOrEqual | GreaterThan | GreaterThanOrEqual => {
187                let inputs = func_call.inputs();
188                if inputs.len() == 2 {
189                    // Check if left operand is an arithmetic expression and right operand is a constant
190                    if let ExprImpl::FunctionCall(left_func) = &inputs[0]
191                        && inputs[1].is_const()
192                        && let Some(simplified) = self.simplify_arithmetic_comparison(
193                            left_func,
194                            &inputs[1],
195                            func_call.func_type(),
196                        )
197                    {
198                        return simplified;
199                    }
200                }
201            }
202            _ => {}
203        }
204
205        // Recursively handle sub-expressions
206        let (func_type, inputs, ret_type) = func_call.decompose();
207        let new_inputs: Vec<_> = inputs
208            .into_iter()
209            .map(|input| self.rewrite_expr(input))
210            .collect();
211
212        FunctionCall::new_unchecked(func_type, new_inputs, ret_type).into()
213    }
214}
215
216impl FilterArithmeticRewriter {
217    /// Simplify arithmetic comparison: `(col op const1) comp const2` -> `col comp (const2 reverse_op const1)`
218    fn simplify_arithmetic_comparison(
219        &self,
220        arithmetic_func: &FunctionCall,
221        comparison_const: &ExprImpl,
222        comparison_op: ExprType,
223    ) -> Option<ExprImpl> {
224        use ExprType::{Add, Subtract};
225
226        // Check arithmetic operation
227        match arithmetic_func.func_type() {
228            Add | Subtract => {
229                let inputs = arithmetic_func.inputs();
230                if inputs.len() == 2 {
231                    // Find column reference and constant
232                    let (column_ref, arith_const, reverse_op) = if inputs[1].is_const() {
233                        // col op const
234                        let reverse_op = match arithmetic_func.func_type() {
235                            Add => Subtract,
236                            Subtract => Add,
237                            _ => unreachable!(),
238                        };
239                        (&inputs[0], &inputs[1], reverse_op)
240                    } else if inputs[0].is_const() && arithmetic_func.func_type() == Add {
241                        // const + col
242                        (&inputs[1], &inputs[0], Subtract)
243                    } else {
244                        return None;
245                    };
246
247                    // Calculate new constant value
248                    if let Ok(new_const_func) = FunctionCall::new(
249                        reverse_op,
250                        vec![comparison_const.clone(), arith_const.clone()],
251                    ) {
252                        let new_const_expr: ExprImpl = new_const_func.into();
253                        // Try constant folding
254                        if let Some(Ok(Some(folded_value))) = new_const_expr.try_fold_const() {
255                            let new_const =
256                                Literal::new(Some(folded_value), new_const_expr.return_type())
257                                    .into();
258
259                            // Construct new comparison expression
260                            if let Ok(new_comparison) = FunctionCall::new(
261                                comparison_op,
262                                vec![column_ref.clone(), new_const],
263                            ) {
264                                return Some(new_comparison.into());
265                            }
266                        }
267                    }
268                }
269            }
270            _ => {}
271        }
272
273        None
274    }
275}
276
277/// Returns `None` if the conditions are too complex or invalid. `Some((limit, offset))` otherwise.
278fn handle_rank_preds(rank_preds: &[ExprImpl], window_func_pos: usize) -> Option<(u64, u64)> {
279    if rank_preds.is_empty() {
280        return None;
281    }
282
283    // rank >= lb
284    let mut lb: Option<i64> = None;
285    // rank <= ub
286    let mut ub: Option<i64> = None;
287    // rank == eq
288    let mut eq: Option<i64> = None;
289
290    for cond in rank_preds {
291        if let Some((input_ref, cmp, v)) = cond.as_comparison_const() {
292            assert_eq!(input_ref.index, window_func_pos);
293            let v = v.cast_implicit(DataType::Int64).ok()?.fold_const().ok()??;
294            let v = *v.as_int64();
295            match cmp {
296                ExprType::LessThanOrEqual => ub = ub.map_or(Some(v), |ub| Some(ub.min(v))),
297                ExprType::LessThan => ub = ub.map_or(Some(v - 1), |ub| Some(ub.min(v - 1))),
298                ExprType::GreaterThan => lb = lb.map_or(Some(v + 1), |lb| Some(lb.max(v + 1))),
299                ExprType::GreaterThanOrEqual => lb = lb.map_or(Some(v), |lb| Some(lb.max(v))),
300                _ => unreachable!(),
301            }
302        } else if let Some((input_ref, v)) = cond.as_eq_const() {
303            assert_eq!(input_ref.index, window_func_pos);
304            let v = v.cast_implicit(DataType::Int64).ok()?.fold_const().ok()??;
305            let v = *v.as_int64();
306            if let Some(eq) = eq
307                && eq != v
308            {
309                tracing::warn!(
310                    "Failed to optimize rank predicate with conflicting equal conditions."
311                );
312                return None;
313            }
314            eq = Some(v)
315        } else {
316            // TODO: support between and in
317            tracing::warn!("Failed to optimize complex rank predicate {:?}", cond);
318            return None;
319        }
320    }
321
322    // Note: rank functions start from 1
323    if let Some(eq) = eq {
324        if eq < 1 {
325            tracing::warn!(
326                "Failed to optimize rank predicate with invalid predicate rank={}.",
327                eq
328            );
329            return None;
330        }
331        let lb = lb.unwrap_or(i64::MIN);
332        let ub = ub.unwrap_or(i64::MAX);
333        if !(lb <= eq && eq <= ub) {
334            tracing::warn!("Failed to optimize rank predicate with conflicting bounds.");
335            return None;
336        }
337        Some((1, (eq - 1) as u64))
338    } else {
339        match (lb, ub) {
340            (Some(lb), Some(ub)) => Some(((ub - lb + 1).max(0) as u64, (lb - 1).max(0) as u64)),
341            (Some(lb), None) => Some((LIMIT_ALL_COUNT, (lb - 1).max(0) as u64)),
342            (None, Some(ub)) => Some((ub.max(0) as u64, 0)),
343            (None, None) => unreachable!(),
344        }
345    }
346}