risingwave_frontend/optimizer/rule/over_window_to_topn_rule.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use fixedbitset::FixedBitSet;
16use risingwave_common::types::DataType;
17use risingwave_expr::window_function::WindowFuncKind;
18
19use super::{BoxedRule, Rule};
20use crate::PlanRef;
21use crate::expr::{ExprImpl, ExprType, collect_input_refs};
22use crate::optimizer::plan_node::generic::GenericPlanRef;
23use crate::optimizer::plan_node::{LogicalFilter, LogicalTopN, PlanTreeNodeUnary};
24use crate::optimizer::property::Order;
25use crate::planner::LIMIT_ALL_COUNT;
26
27/// Transforms the following pattern to group `TopN` (No Ranking Output).
28///
29/// ```sql
30/// -- project - filter - over window
31/// SELECT .. FROM
32///   (SELECT .., ROW_NUMBER() OVER(PARTITION BY .. ORDER BY ..) rank FROM ..)
33/// WHERE rank [ < | <= | > | >= | = ] ..;
34/// ```
35///
36/// Transforms the following pattern to `OverWindow` + group `TopN` (Ranking Output).
37/// The `TopN` decreases the number of rows to be processed by the `OverWindow`.
38///
39/// ```sql
40/// -- filter - over window
41/// SELECT .., ROW_NUMBER() OVER(PARTITION BY .. ORDER BY ..) rank
42/// FROM ..
43/// WHERE rank [ < | <= | > | >= | = ] ..;
44/// ```
45pub struct OverWindowToTopNRule;
46
47impl OverWindowToTopNRule {
48    pub fn create() -> BoxedRule {
49        Box::new(OverWindowToTopNRule)
50    }
51}
52
53impl Rule for OverWindowToTopNRule {
54    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
55        let ctx = plan.ctx();
56        let (project, plan) = {
57            if let Some(project) = plan.as_logical_project() {
58                (Some(project), project.input())
59            } else {
60                (None, plan)
61            }
62        };
63        let filter = plan.as_logical_filter()?;
64        let plan = filter.input();
65        // The filter is directly on top of the over window after predicate pushdown.
66        let over_window = plan.as_logical_over_window()?;
67
68        if over_window.window_functions().len() != 1 {
69            // Queries with multiple window function calls are not supported yet.
70            return None;
71        }
72        let window_func = &over_window.window_functions()[0];
73        if !window_func.kind.is_numbering() {
74            // Only rank functions can be converted to TopN.
75            return None;
76        }
77
78        let output_len = over_window.schema().len();
79        let window_func_pos = output_len - 1;
80
81        let with_ties = match window_func.kind {
82            // Only `ROW_NUMBER` and `RANK` can be optimized to TopN now.
83            WindowFuncKind::RowNumber => false,
84            WindowFuncKind::Rank => true,
85            WindowFuncKind::DenseRank => {
86                ctx.warn_to_user("`dense_rank` is not supported in Top-N pattern, will fallback to inefficient implementation");
87                return None;
88            }
89            _ => unreachable!("window functions other than rank functions should not reach here"),
90        };
91
92        let (rank_pred, other_pred) = {
93            let predicate = filter.predicate();
94            let mut rank_col = FixedBitSet::with_capacity(output_len);
95            rank_col.set(window_func_pos, true);
96            predicate.clone().split_disjoint(&rank_col)
97        };
98
99        let (limit, offset) = handle_rank_preds(&rank_pred.conjunctions, window_func_pos)?;
100
101        if offset > 0 && with_ties {
102            tracing::warn!("Failed to optimize with ties and offset");
103            ctx.warn_to_user("group topN with ties and offset is not supported, see https://www.risingwave.dev/docs/current/sql-pattern-topn/ for more information");
104            return None;
105        }
106
107        let topn: PlanRef = LogicalTopN::new(
108            over_window.input(),
109            limit,
110            offset,
111            with_ties,
112            Order {
113                column_orders: window_func.order_by.to_vec(),
114            },
115            window_func.partition_by.iter().map(|i| i.index).collect(),
116        )
117        .into();
118        let filter = LogicalFilter::create(topn, other_pred);
119
120        let plan = if let Some(project) = project {
121            let referred_cols = collect_input_refs(output_len, project.exprs());
122            if !referred_cols.contains(window_func_pos) {
123                // No Ranking Output
124                project.clone_with_input(filter).into()
125            } else {
126                // Ranking Output, with project
127                project
128                    .clone_with_input(over_window.clone_with_input(filter).into())
129                    .into()
130            }
131        } else {
132            // Ranking Output, without project
133            ctx.warn_to_user("It can be inefficient to output ranking number in Top-N, see https://www.risingwave.dev/docs/current/sql-pattern-topn/ for more information");
134            over_window.clone_with_input(filter).into()
135        };
136        Some(plan)
137    }
138}
139
140/// Returns `None` if the conditions are too complex or invalid. `Some((limit, offset))` otherwise.
141fn handle_rank_preds(rank_preds: &[ExprImpl], window_func_pos: usize) -> Option<(u64, u64)> {
142    if rank_preds.is_empty() {
143        return None;
144    }
145
146    // rank >= lb
147    let mut lb: Option<i64> = None;
148    // rank <= ub
149    let mut ub: Option<i64> = None;
150    // rank == eq
151    let mut eq: Option<i64> = None;
152
153    for cond in rank_preds {
154        if let Some((input_ref, cmp, v)) = cond.as_comparison_const() {
155            assert_eq!(input_ref.index, window_func_pos);
156            let v = v.cast_implicit(DataType::Int64).ok()?.fold_const().ok()??;
157            let v = *v.as_int64();
158            match cmp {
159                ExprType::LessThanOrEqual => ub = ub.map_or(Some(v), |ub| Some(ub.min(v))),
160                ExprType::LessThan => ub = ub.map_or(Some(v - 1), |ub| Some(ub.min(v - 1))),
161                ExprType::GreaterThan => lb = lb.map_or(Some(v + 1), |lb| Some(lb.max(v + 1))),
162                ExprType::GreaterThanOrEqual => lb = lb.map_or(Some(v), |lb| Some(lb.max(v))),
163                _ => unreachable!(),
164            }
165        } else if let Some((input_ref, v)) = cond.as_eq_const() {
166            assert_eq!(input_ref.index, window_func_pos);
167            let v = v.cast_implicit(DataType::Int64).ok()?.fold_const().ok()??;
168            let v = *v.as_int64();
169            if let Some(eq) = eq
170                && eq != v
171            {
172                tracing::warn!(
173                    "Failed to optimize rank predicate with conflicting equal conditions."
174                );
175                return None;
176            }
177            eq = Some(v)
178        } else {
179            // TODO: support between and in
180            tracing::warn!("Failed to optimize complex rank predicate {:?}", cond);
181            return None;
182        }
183    }
184
185    // Note: rank functions start from 1
186    if let Some(eq) = eq {
187        if eq < 1 {
188            tracing::warn!(
189                "Failed to optimize rank predicate with invalid predicate rank={}.",
190                eq
191            );
192            return None;
193        }
194        let lb = lb.unwrap_or(i64::MIN);
195        let ub = ub.unwrap_or(i64::MAX);
196        if !(lb <= eq && eq <= ub) {
197            tracing::warn!("Failed to optimize rank predicate with conflicting bounds.");
198            return None;
199        }
200        Some((1, (eq - 1) as u64))
201    } else {
202        match (lb, ub) {
203            (Some(lb), Some(ub)) => Some(((ub - lb + 1).max(0) as u64, (lb - 1).max(0) as u64)),
204            (Some(lb), None) => Some((LIMIT_ALL_COUNT, (lb - 1).max(0) as u64)),
205            (None, Some(ub)) => Some((ub.max(0) as u64, 0)),
206            (None, None) => unreachable!(),
207        }
208    }
209}