risingwave_frontend/optimizer/rule/
over_window_to_topn_rule.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use fixedbitset::FixedBitSet;
16use risingwave_common::types::{DataType, ScalarImpl};
17use risingwave_expr::window_function::WindowFuncKind;
18
19use super::prelude::{PlanRef, *};
20use crate::expr::{
21    Expr, ExprImpl, ExprRewriter, ExprType, FunctionCall, InputRef, Literal, collect_input_refs,
22};
23use crate::optimizer::plan_node::generic::GenericPlanRef;
24use crate::optimizer::plan_node::{LogicalFilter, LogicalProject, LogicalTopN, PlanTreeNodeUnary};
25use crate::optimizer::property::Order;
26use crate::planner::LIMIT_ALL_COUNT;
27use crate::utils::Condition;
28
29/// Transforms the following pattern to group `TopN` (No Ranking Output).
30///
31/// ```sql
32/// -- project - filter - over window
33/// SELECT .. FROM
34///   (SELECT .., ROW_NUMBER() OVER(PARTITION BY .. ORDER BY ..) rank FROM ..)
35/// WHERE rank [ < | <= | > | >= | = ] ..;
36/// ```
37///
38/// Transforms the following pattern to `OverWindow` + group `TopN` (Ranking Output).
39/// The `TopN` decreases the number of rows to be processed by the `OverWindow`.
40///
41/// ```sql
42/// -- filter - over window
43/// SELECT .., ROW_NUMBER() OVER(PARTITION BY .. ORDER BY ..) rank
44/// FROM ..
45/// WHERE rank [ < | <= | > | >= | = ] ..;
46/// ```
47///
48/// Also optimizes filter arithmetic expressions in the `Project <- Filter <- OverWindow` pattern,
49/// such as simplifying `(row_number - 1) = 0` to `row_number = 1`.
50pub struct OverWindowToTopNRule;
51
52impl OverWindowToTopNRule {
53    pub fn create() -> BoxedRule {
54        Box::new(OverWindowToTopNRule)
55    }
56}
57
58impl Rule<Logical> for OverWindowToTopNRule {
59    fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
60        let ctx = plan.ctx();
61        let (project, plan) = {
62            if let Some(project) = plan.as_logical_project() {
63                (Some(project), project.input())
64            } else {
65                (None, plan)
66            }
67        };
68        let filter = plan.as_logical_filter()?;
69        let plan = filter.input();
70        // The filter is directly on top of the over window after predicate pushdown.
71        let over_window = plan.as_logical_over_window()?;
72
73        // First try to simplify filter arithmetic expressions
74        let filter = if let Some(simplified) = self.simplify_filter_arithmetic(filter) {
75            simplified
76        } else {
77            filter.clone()
78        };
79
80        if over_window.window_functions().len() != 1 {
81            // Queries with multiple window function calls are not supported yet.
82            return None;
83        }
84        let window_func = &over_window.window_functions()[0];
85        if !window_func.kind.is_numbering() {
86            // Only rank functions can be converted to TopN.
87            return None;
88        }
89
90        let output_len = over_window.schema().len();
91        let window_func_pos = output_len - 1;
92
93        let with_ties = match window_func.kind {
94            // Only `ROW_NUMBER` and `RANK` can be optimized to TopN now.
95            WindowFuncKind::RowNumber => false,
96            WindowFuncKind::Rank => true,
97            WindowFuncKind::DenseRank => {
98                ctx.warn_to_user("`dense_rank` is not supported in Top-N pattern, will fallback to inefficient implementation");
99                return None;
100            }
101            _ => unreachable!("window functions other than rank functions should not reach here"),
102        };
103
104        let (rank_pred, other_pred) = {
105            let predicate = filter.predicate();
106            let mut rank_col = FixedBitSet::with_capacity(output_len);
107            rank_col.set(window_func_pos, true);
108            predicate.clone().split_disjoint(&rank_col)
109        };
110
111        let (limit, offset) = handle_rank_preds(&rank_pred.conjunctions, window_func_pos)?;
112
113        if offset > 0 && with_ties {
114            tracing::warn!("Failed to optimize with ties and offset");
115            ctx.warn_to_user("group topN with ties and offset is not supported, see https://docs.risingwave.com/processing/sql/top-n-by-group for more information");
116            return None;
117        }
118
119        let topn: PlanRef = LogicalTopN::new(
120            over_window.input(),
121            limit,
122            offset,
123            with_ties,
124            Order {
125                column_orders: window_func.order_by.clone(),
126            },
127            window_func.partition_by.iter().map(|i| i.index).collect(),
128        )
129        .into();
130        let filter = LogicalFilter::create(topn, other_pred);
131
132        let plan = if let Some(project) = project {
133            let referred_cols = collect_input_refs(output_len, project.exprs());
134            if !referred_cols.contains(window_func_pos) {
135                // No Ranking Output
136                project.clone_with_input(filter).into()
137            } else {
138                // Ranking Output, with project
139                let over_window = over_window.clone_with_input(filter).into();
140                let over_window =
141                    add_rank_offset_if_needed(over_window, offset, window_func_pos, output_len);
142                project.clone_with_input(over_window).into()
143            }
144        } else {
145            // Ranking Output, without project
146            ctx.warn_to_user("It can be inefficient to output ranking number in Top-N, see https://docs.risingwave.com/processing/sql/top-n-by-group for more information");
147            let over_window = over_window.clone_with_input(filter).into();
148            add_rank_offset_if_needed(over_window, offset, window_func_pos, output_len)
149        };
150        Some(plan)
151    }
152}
153
154impl OverWindowToTopNRule {
155    /// Simplify arithmetic expressions in filter conditions before TopN optimization
156    /// For example: `(row_number - 1) = 0` -> `row_number = 1`
157    fn simplify_filter_arithmetic(&self, filter: &LogicalFilter) -> Option<LogicalFilter> {
158        let new_predicate = self.simplify_filter_arithmetic_condition(filter.predicate())?;
159        Some(LogicalFilter::new(filter.input(), new_predicate))
160    }
161
162    /// Simplify arithmetic expressions in the filter condition
163    fn simplify_filter_arithmetic_condition(&self, predicate: &Condition) -> Option<Condition> {
164        let expr = predicate.as_expr_unless_true()?;
165        let mut rewriter = FilterArithmeticRewriter {};
166        let new_expr = rewriter.rewrite_expr(expr.clone());
167
168        if new_expr != expr {
169            Some(Condition::with_expr(new_expr))
170        } else {
171            None
172        }
173    }
174}
175
176/// Filter arithmetic simplification rewriter: simplifies `(col op const) = const2` to `col = (const2 reverse_op const)`
177struct FilterArithmeticRewriter {}
178
179impl ExprRewriter for FilterArithmeticRewriter {
180    fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
181        use ExprType::{
182            Equal, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, NotEqual,
183        };
184
185        // Check if this is a comparison operation
186        match func_call.func_type() {
187            Equal | NotEqual | LessThan | LessThanOrEqual | GreaterThan | GreaterThanOrEqual => {
188                let inputs = func_call.inputs();
189                if inputs.len() == 2
190                    && let Some(simplified) = self.simplify_arithmetic_comparison(
191                        &inputs[0],
192                        &inputs[1],
193                        func_call.func_type(),
194                    )
195                {
196                    return self.rewrite_expr(simplified);
197                }
198            }
199            _ => {}
200        }
201
202        // Recursively handle sub-expressions
203        let (func_type, inputs, ret_type) = func_call.decompose();
204        let new_inputs: Vec<_> = inputs
205            .into_iter()
206            .map(|input| self.rewrite_expr(input))
207            .collect();
208
209        FunctionCall::new_unchecked(func_type, new_inputs, ret_type).into()
210    }
211}
212
213impl FilterArithmeticRewriter {
214    fn reverse_comparison(comparison_op: ExprType) -> ExprType {
215        match comparison_op {
216            ExprType::LessThan => ExprType::GreaterThan,
217            ExprType::LessThanOrEqual => ExprType::GreaterThanOrEqual,
218            ExprType::GreaterThan => ExprType::LessThan,
219            ExprType::GreaterThanOrEqual => ExprType::LessThanOrEqual,
220            ExprType::Equal | ExprType::NotEqual => comparison_op,
221            _ => unreachable!(),
222        }
223    }
224
225    /// Simplify arithmetic comparison: `(col op const1) comp const2` -> `col comp (const2 reverse_op const1)`
226    fn simplify_arithmetic_comparison(
227        &self,
228        left: &ExprImpl,
229        right: &ExprImpl,
230        comparison_op: ExprType,
231    ) -> Option<ExprImpl> {
232        use ExprType::{Add, Subtract};
233
234        let (arithmetic_func, comparison_const, comparison_op) =
235            if let ExprImpl::FunctionCall(left_func) = left
236                && right.is_const()
237            {
238                (left_func, right, comparison_op)
239            } else if left.is_const()
240                && let ExprImpl::FunctionCall(right_func) = right
241            {
242                (right_func, left, Self::reverse_comparison(comparison_op))
243            } else {
244                return None;
245            };
246
247        // Check arithmetic operation
248        match arithmetic_func.func_type() {
249            Add | Subtract => {
250                let inputs = arithmetic_func.inputs();
251                if inputs.len() == 2 {
252                    // Find column reference and constant
253                    let (column_ref, arith_const, reverse_op) = if inputs[1].is_const() {
254                        // col op const
255                        let reverse_op = match arithmetic_func.func_type() {
256                            Add => Subtract,
257                            Subtract => Add,
258                            _ => unreachable!(),
259                        };
260                        (&inputs[0], &inputs[1], reverse_op)
261                    } else if inputs[0].is_const() && arithmetic_func.func_type() == Add {
262                        // const + col
263                        (&inputs[1], &inputs[0], Subtract)
264                    } else {
265                        return None;
266                    };
267
268                    // Calculate new constant value
269                    if let Ok(new_const_func) = FunctionCall::new(
270                        reverse_op,
271                        vec![comparison_const.clone(), arith_const.clone()],
272                    ) {
273                        let new_const_expr: ExprImpl = new_const_func.into();
274                        // Try constant folding
275                        if let Some(Ok(Some(folded_value))) = new_const_expr.try_fold_const() {
276                            let new_const =
277                                Literal::new(Some(folded_value), new_const_expr.return_type())
278                                    .into();
279
280                            // Construct new comparison expression
281                            if let Ok(new_comparison) = FunctionCall::new(
282                                comparison_op,
283                                vec![column_ref.clone(), new_const],
284                            ) {
285                                return Some(new_comparison.into());
286                            }
287                        }
288                    }
289                }
290            }
291            _ => {}
292        }
293
294        None
295    }
296}
297
298/// Returns `None` if the conditions are too complex or invalid. `Some((limit, offset))` otherwise.
299fn handle_rank_preds(rank_preds: &[ExprImpl], window_func_pos: usize) -> Option<(u64, u64)> {
300    if rank_preds.is_empty() {
301        return None;
302    }
303
304    // rank >= lb
305    let mut lb: Option<i64> = None;
306    // rank <= ub
307    let mut ub: Option<i64> = None;
308    // rank == eq
309    let mut eq: Option<i64> = None;
310
311    for cond in rank_preds {
312        if let Some((input_ref, cmp, v)) = cond.as_comparison_const() {
313            assert_eq!(input_ref.index, window_func_pos);
314            let v = v
315                .cast_implicit(&DataType::Int64)
316                .ok()?
317                .fold_const()
318                .ok()??;
319            let v = *v.as_int64();
320            match cmp {
321                ExprType::LessThanOrEqual => ub = ub.map_or(Some(v), |ub| Some(ub.min(v))),
322                ExprType::LessThan => ub = ub.map_or(Some(v - 1), |ub| Some(ub.min(v - 1))),
323                ExprType::GreaterThan => lb = lb.map_or(Some(v + 1), |lb| Some(lb.max(v + 1))),
324                ExprType::GreaterThanOrEqual => lb = lb.map_or(Some(v), |lb| Some(lb.max(v))),
325                _ => unreachable!(),
326            }
327        } else if let Some((input_ref, v)) = cond.as_eq_const() {
328            assert_eq!(input_ref.index, window_func_pos);
329            let v = v
330                .cast_implicit(&DataType::Int64)
331                .ok()?
332                .fold_const()
333                .ok()??;
334            let v = *v.as_int64();
335            if let Some(eq) = eq
336                && eq != v
337            {
338                tracing::warn!(
339                    "Failed to optimize rank predicate with conflicting equal conditions."
340                );
341                return None;
342            }
343            eq = Some(v)
344        } else {
345            // TODO: support between and in
346            tracing::warn!("Failed to optimize complex rank predicate {:?}", cond);
347            return None;
348        }
349    }
350
351    // Note: rank functions start from 1
352    if let Some(eq) = eq {
353        if eq < 1 {
354            tracing::warn!(
355                "Failed to optimize rank predicate with invalid predicate rank={}.",
356                eq
357            );
358            return None;
359        }
360        let lb = lb.unwrap_or(i64::MIN);
361        let ub = ub.unwrap_or(i64::MAX);
362        if !(lb <= eq && eq <= ub) {
363            tracing::warn!("Failed to optimize rank predicate with conflicting bounds.");
364            return None;
365        }
366        Some((1, (eq - 1) as u64))
367    } else {
368        match (lb, ub) {
369            (Some(lb), Some(ub)) => Some(((ub - lb + 1).max(0) as u64, (lb - 1).max(0) as u64)),
370            (Some(lb), None) => Some((LIMIT_ALL_COUNT, (lb - 1).max(0) as u64)),
371            (None, Some(ub)) => Some((ub.max(0) as u64, 0)),
372            (None, None) => unreachable!(),
373        }
374    }
375}
376
377fn add_rank_offset_if_needed(
378    plan: PlanRef,
379    offset: u64,
380    window_func_pos: usize,
381    output_len: usize,
382) -> PlanRef {
383    if offset == 0 {
384        return plan;
385    }
386
387    let schema = plan.schema();
388    let mut exprs = Vec::with_capacity(output_len);
389
390    for idx in 0..output_len {
391        let input: ExprImpl = InputRef::new(idx, schema.fields()[idx].data_type().clone()).into();
392        if idx == window_func_pos {
393            let offset_expr: ExprImpl =
394                Literal::new(Some(ScalarImpl::Int64(offset as i64)), DataType::Int64).into();
395            let adjusted = FunctionCall::new(ExprType::Add, vec![input, offset_expr])
396                .unwrap()
397                .into();
398            exprs.push(adjusted);
399        } else {
400            exprs.push(input);
401        }
402    }
403
404    LogicalProject::create(plan, exprs)
405}