risingwave_frontend/optimizer/rule/
over_window_to_topn_rule.rs1use fixedbitset::FixedBitSet;
16use risingwave_common::types::DataType;
17use risingwave_expr::window_function::WindowFuncKind;
18
19use super::prelude::{PlanRef, *};
20use crate::expr::{
21 Expr, ExprImpl, ExprRewriter, ExprType, FunctionCall, Literal, collect_input_refs,
22};
23use crate::optimizer::plan_node::generic::GenericPlanRef;
24use crate::optimizer::plan_node::{LogicalFilter, LogicalTopN, PlanTreeNodeUnary};
25use crate::optimizer::property::Order;
26use crate::planner::LIMIT_ALL_COUNT;
27use crate::utils::Condition;
28
29pub struct OverWindowToTopNRule;
51
52impl OverWindowToTopNRule {
53 pub fn create() -> BoxedRule {
54 Box::new(OverWindowToTopNRule)
55 }
56}
57
58impl Rule<Logical> for OverWindowToTopNRule {
59 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
60 let ctx = plan.ctx();
61 let (project, plan) = {
62 if let Some(project) = plan.as_logical_project() {
63 (Some(project), project.input())
64 } else {
65 (None, plan)
66 }
67 };
68 let filter = plan.as_logical_filter()?;
69 let plan = filter.input();
70 let over_window = plan.as_logical_over_window()?;
72
73 let filter = if let Some(simplified) = self.simplify_filter_arithmetic(filter) {
75 simplified
76 } else {
77 filter.clone()
78 };
79
80 if over_window.window_functions().len() != 1 {
81 return None;
83 }
84 let window_func = &over_window.window_functions()[0];
85 if !window_func.kind.is_numbering() {
86 return None;
88 }
89
90 let output_len = over_window.schema().len();
91 let window_func_pos = output_len - 1;
92
93 let with_ties = match window_func.kind {
94 WindowFuncKind::RowNumber => false,
96 WindowFuncKind::Rank => true,
97 WindowFuncKind::DenseRank => {
98 ctx.warn_to_user("`dense_rank` is not supported in Top-N pattern, will fallback to inefficient implementation");
99 return None;
100 }
101 _ => unreachable!("window functions other than rank functions should not reach here"),
102 };
103
104 let (rank_pred, other_pred) = {
105 let predicate = filter.predicate();
106 let mut rank_col = FixedBitSet::with_capacity(output_len);
107 rank_col.set(window_func_pos, true);
108 predicate.clone().split_disjoint(&rank_col)
109 };
110
111 let (limit, offset) = handle_rank_preds(&rank_pred.conjunctions, window_func_pos)?;
112
113 if offset > 0 && with_ties {
114 tracing::warn!("Failed to optimize with ties and offset");
115 ctx.warn_to_user("group topN with ties and offset is not supported, see https://www.risingwave.dev/docs/current/sql-pattern-topn/ for more information");
116 return None;
117 }
118
119 let topn: PlanRef = LogicalTopN::new(
120 over_window.input(),
121 limit,
122 offset,
123 with_ties,
124 Order {
125 column_orders: window_func.order_by.to_vec(),
126 },
127 window_func.partition_by.iter().map(|i| i.index).collect(),
128 )
129 .into();
130 let filter = LogicalFilter::create(topn, other_pred);
131
132 let plan = if let Some(project) = project {
133 let referred_cols = collect_input_refs(output_len, project.exprs());
134 if !referred_cols.contains(window_func_pos) {
135 project.clone_with_input(filter).into()
137 } else {
138 project
140 .clone_with_input(over_window.clone_with_input(filter).into())
141 .into()
142 }
143 } else {
144 ctx.warn_to_user("It can be inefficient to output ranking number in Top-N, see https://www.risingwave.dev/docs/current/sql-pattern-topn/ for more information");
146 over_window.clone_with_input(filter).into()
147 };
148 Some(plan)
149 }
150}
151
152impl OverWindowToTopNRule {
153 fn simplify_filter_arithmetic(&self, filter: &LogicalFilter) -> Option<LogicalFilter> {
156 let new_predicate = self.simplify_filter_arithmetic_condition(filter.predicate())?;
157 Some(LogicalFilter::new(filter.input(), new_predicate))
158 }
159
160 fn simplify_filter_arithmetic_condition(&self, predicate: &Condition) -> Option<Condition> {
162 let expr = predicate.as_expr_unless_true()?;
163 let mut rewriter = FilterArithmeticRewriter {};
164 let new_expr = rewriter.rewrite_expr(expr.clone());
165
166 if new_expr != expr {
167 Some(Condition::with_expr(new_expr))
168 } else {
169 None
170 }
171 }
172}
173
174struct FilterArithmeticRewriter {}
176
177impl ExprRewriter for FilterArithmeticRewriter {
178 fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
179 use ExprType::{
180 Equal, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, NotEqual,
181 };
182
183 match func_call.func_type() {
185 Equal | NotEqual | LessThan | LessThanOrEqual | GreaterThan | GreaterThanOrEqual => {
186 let inputs = func_call.inputs();
187 if inputs.len() == 2 {
188 if let ExprImpl::FunctionCall(left_func) = &inputs[0]
190 && inputs[1].is_const()
191 && let Some(simplified) = self.simplify_arithmetic_comparison(
192 left_func,
193 &inputs[1],
194 func_call.func_type(),
195 )
196 {
197 return simplified;
198 }
199 }
200 }
201 _ => {}
202 }
203
204 let (func_type, inputs, ret_type) = func_call.decompose();
206 let new_inputs: Vec<_> = inputs
207 .into_iter()
208 .map(|input| self.rewrite_expr(input))
209 .collect();
210
211 FunctionCall::new_unchecked(func_type, new_inputs, ret_type).into()
212 }
213}
214
215impl FilterArithmeticRewriter {
216 fn simplify_arithmetic_comparison(
218 &self,
219 arithmetic_func: &FunctionCall,
220 comparison_const: &ExprImpl,
221 comparison_op: ExprType,
222 ) -> Option<ExprImpl> {
223 use ExprType::{Add, Subtract};
224
225 match arithmetic_func.func_type() {
227 Add | Subtract => {
228 let inputs = arithmetic_func.inputs();
229 if inputs.len() == 2 {
230 let (column_ref, arith_const, reverse_op) = if inputs[1].is_const() {
232 let reverse_op = match arithmetic_func.func_type() {
234 Add => Subtract,
235 Subtract => Add,
236 _ => unreachable!(),
237 };
238 (&inputs[0], &inputs[1], reverse_op)
239 } else if inputs[0].is_const() && arithmetic_func.func_type() == Add {
240 (&inputs[1], &inputs[0], Subtract)
242 } else {
243 return None;
244 };
245
246 if let Ok(new_const_func) = FunctionCall::new(
248 reverse_op,
249 vec![comparison_const.clone(), arith_const.clone()],
250 ) {
251 let new_const_expr: ExprImpl = new_const_func.into();
252 if let Some(Ok(Some(folded_value))) = new_const_expr.try_fold_const() {
254 let new_const =
255 Literal::new(Some(folded_value), new_const_expr.return_type())
256 .into();
257
258 if let Ok(new_comparison) = FunctionCall::new(
260 comparison_op,
261 vec![column_ref.clone(), new_const],
262 ) {
263 return Some(new_comparison.into());
264 }
265 }
266 }
267 }
268 }
269 _ => {}
270 }
271
272 None
273 }
274}
275
276fn handle_rank_preds(rank_preds: &[ExprImpl], window_func_pos: usize) -> Option<(u64, u64)> {
278 if rank_preds.is_empty() {
279 return None;
280 }
281
282 let mut lb: Option<i64> = None;
284 let mut ub: Option<i64> = None;
286 let mut eq: Option<i64> = None;
288
289 for cond in rank_preds {
290 if let Some((input_ref, cmp, v)) = cond.as_comparison_const() {
291 assert_eq!(input_ref.index, window_func_pos);
292 let v = v
293 .cast_implicit(&DataType::Int64)
294 .ok()?
295 .fold_const()
296 .ok()??;
297 let v = *v.as_int64();
298 match cmp {
299 ExprType::LessThanOrEqual => ub = ub.map_or(Some(v), |ub| Some(ub.min(v))),
300 ExprType::LessThan => ub = ub.map_or(Some(v - 1), |ub| Some(ub.min(v - 1))),
301 ExprType::GreaterThan => lb = lb.map_or(Some(v + 1), |lb| Some(lb.max(v + 1))),
302 ExprType::GreaterThanOrEqual => lb = lb.map_or(Some(v), |lb| Some(lb.max(v))),
303 _ => unreachable!(),
304 }
305 } else if let Some((input_ref, v)) = cond.as_eq_const() {
306 assert_eq!(input_ref.index, window_func_pos);
307 let v = v
308 .cast_implicit(&DataType::Int64)
309 .ok()?
310 .fold_const()
311 .ok()??;
312 let v = *v.as_int64();
313 if let Some(eq) = eq
314 && eq != v
315 {
316 tracing::warn!(
317 "Failed to optimize rank predicate with conflicting equal conditions."
318 );
319 return None;
320 }
321 eq = Some(v)
322 } else {
323 tracing::warn!("Failed to optimize complex rank predicate {:?}", cond);
325 return None;
326 }
327 }
328
329 if let Some(eq) = eq {
331 if eq < 1 {
332 tracing::warn!(
333 "Failed to optimize rank predicate with invalid predicate rank={}.",
334 eq
335 );
336 return None;
337 }
338 let lb = lb.unwrap_or(i64::MIN);
339 let ub = ub.unwrap_or(i64::MAX);
340 if !(lb <= eq && eq <= ub) {
341 tracing::warn!("Failed to optimize rank predicate with conflicting bounds.");
342 return None;
343 }
344 Some((1, (eq - 1) as u64))
345 } else {
346 match (lb, ub) {
347 (Some(lb), Some(ub)) => Some(((ub - lb + 1).max(0) as u64, (lb - 1).max(0) as u64)),
348 (Some(lb), None) => Some((LIMIT_ALL_COUNT, (lb - 1).max(0) as u64)),
349 (None, Some(ub)) => Some((ub.max(0) as u64, 0)),
350 (None, None) => unreachable!(),
351 }
352 }
353}