risingwave_frontend/optimizer/rule/
over_window_to_topn_rule.rs1use fixedbitset::FixedBitSet;
16use risingwave_common::types::{DataType, ScalarImpl};
17use risingwave_expr::window_function::WindowFuncKind;
18
19use super::prelude::{PlanRef, *};
20use crate::expr::{
21 Expr, ExprImpl, ExprRewriter, ExprType, FunctionCall, InputRef, Literal, collect_input_refs,
22};
23use crate::optimizer::plan_node::generic::GenericPlanRef;
24use crate::optimizer::plan_node::{LogicalFilter, LogicalProject, LogicalTopN, PlanTreeNodeUnary};
25use crate::optimizer::property::Order;
26use crate::planner::LIMIT_ALL_COUNT;
27use crate::utils::Condition;
28
/// Replaces a numbering window function that is filtered by constant bounds with a group Top-N.
///
/// The rule matches `[Project] <- Filter <- OverWindow`, where the over window computes exactly
/// one `row_number()` or `rank()` and the filter constrains that rank column with constant
/// comparisons (`<`, `<=`, `>`, `>=`, `=`). The rank bounds become the Top-N `limit`/`offset`,
/// the window's `PARTITION BY` becomes the Top-N group key and its `ORDER BY` the Top-N order.
///
/// When the rank column is still needed above the filter, the over window is kept on top of the
/// Top-N (shifted by the Top-N offset). Before matching, arithmetic comparisons in the filter such
/// as `rank - 1 = 0` are normalized to `rank = 1`.
pub struct OverWindowToTopNRule;
51
52impl OverWindowToTopNRule {
53 pub fn create() -> BoxedRule {
54 Box::new(OverWindowToTopNRule)
55 }
56}
57
58impl Rule<Logical> for OverWindowToTopNRule {
59 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
60 let ctx = plan.ctx();
61 let (project, plan) = {
62 if let Some(project) = plan.as_logical_project() {
63 (Some(project), project.input())
64 } else {
65 (None, plan)
66 }
67 };
68 let filter = plan.as_logical_filter()?;
69 let plan = filter.input();
70 let over_window = plan.as_logical_over_window()?;
72
73 let filter = if let Some(simplified) = self.simplify_filter_arithmetic(filter) {
75 simplified
76 } else {
77 filter.clone()
78 };
79
80 if over_window.window_functions().len() != 1 {
81 return None;
83 }
84 let window_func = &over_window.window_functions()[0];
85 if !window_func.kind.is_numbering() {
86 return None;
88 }
89
90 let output_len = over_window.schema().len();
91 let window_func_pos = output_len - 1;
92
93 let with_ties = match window_func.kind {
94 WindowFuncKind::RowNumber => false,
96 WindowFuncKind::Rank => true,
97 WindowFuncKind::DenseRank => {
98 ctx.warn_to_user("`dense_rank` is not supported in Top-N pattern, will fallback to inefficient implementation");
99 return None;
100 }
101 _ => unreachable!("window functions other than rank functions should not reach here"),
102 };
103
104 let (rank_pred, other_pred) = {
105 let predicate = filter.predicate();
106 let mut rank_col = FixedBitSet::with_capacity(output_len);
107 rank_col.set(window_func_pos, true);
108 predicate.clone().split_disjoint(&rank_col)
109 };
110
111 let (limit, offset) = handle_rank_preds(&rank_pred.conjunctions, window_func_pos)?;
112
113 if offset > 0 && with_ties {
114 tracing::warn!("Failed to optimize with ties and offset");
115 ctx.warn_to_user("group topN with ties and offset is not supported, see https://www.risingwave.dev/docs/current/sql-pattern-topn/ for more information");
116 return None;
117 }
118
119 let topn: PlanRef = LogicalTopN::new(
120 over_window.input(),
121 limit,
122 offset,
123 with_ties,
124 Order {
125 column_orders: window_func.order_by.clone(),
126 },
127 window_func.partition_by.iter().map(|i| i.index).collect(),
128 )
129 .into();
130 let filter = LogicalFilter::create(topn, other_pred);
131
132 let plan = if let Some(project) = project {
133 let referred_cols = collect_input_refs(output_len, project.exprs());
134 if !referred_cols.contains(window_func_pos) {
135 project.clone_with_input(filter).into()
137 } else {
138 let over_window = over_window.clone_with_input(filter).into();
140 let over_window =
141 add_rank_offset_if_needed(over_window, offset, window_func_pos, output_len);
142 project.clone_with_input(over_window).into()
143 }
144 } else {
145 ctx.warn_to_user("It can be inefficient to output ranking number in Top-N, see https://www.risingwave.dev/docs/current/sql-pattern-topn/ for more information");
147 let over_window = over_window.clone_with_input(filter).into();
148 add_rank_offset_if_needed(over_window, offset, window_func_pos, output_len)
149 };
150 Some(plan)
151 }
152}
153
154impl OverWindowToTopNRule {
155 fn simplify_filter_arithmetic(&self, filter: &LogicalFilter) -> Option<LogicalFilter> {
158 let new_predicate = self.simplify_filter_arithmetic_condition(filter.predicate())?;
159 Some(LogicalFilter::new(filter.input(), new_predicate))
160 }
161
162 fn simplify_filter_arithmetic_condition(&self, predicate: &Condition) -> Option<Condition> {
164 let expr = predicate.as_expr_unless_true()?;
165 let mut rewriter = FilterArithmeticRewriter {};
166 let new_expr = rewriter.rewrite_expr(expr.clone());
167
168 if new_expr != expr {
169 Some(Condition::with_expr(new_expr))
170 } else {
171 None
172 }
173 }
174}
175
/// Stateless [`ExprRewriter`] that moves the constant operand of a `+`/`-` across a comparison
/// with a constant, e.g. `rank - 1 = 0` becomes `rank = 1`. Function calls it cannot simplify
/// have their arguments rewritten recursively.
struct FilterArithmeticRewriter {}
178
179impl ExprRewriter for FilterArithmeticRewriter {
180 fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
181 use ExprType::{
182 Equal, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, NotEqual,
183 };
184
185 match func_call.func_type() {
187 Equal | NotEqual | LessThan | LessThanOrEqual | GreaterThan | GreaterThanOrEqual => {
188 let inputs = func_call.inputs();
189 if inputs.len() == 2 {
190 if let ExprImpl::FunctionCall(left_func) = &inputs[0]
192 && inputs[1].is_const()
193 && let Some(simplified) = self.simplify_arithmetic_comparison(
194 left_func,
195 &inputs[1],
196 func_call.func_type(),
197 )
198 {
199 return simplified;
200 }
201 }
202 }
203 _ => {}
204 }
205
206 let (func_type, inputs, ret_type) = func_call.decompose();
208 let new_inputs: Vec<_> = inputs
209 .into_iter()
210 .map(|input| self.rewrite_expr(input))
211 .collect();
212
213 FunctionCall::new_unchecked(func_type, new_inputs, ret_type).into()
214 }
215}
216
217impl FilterArithmeticRewriter {
218 fn simplify_arithmetic_comparison(
220 &self,
221 arithmetic_func: &FunctionCall,
222 comparison_const: &ExprImpl,
223 comparison_op: ExprType,
224 ) -> Option<ExprImpl> {
225 use ExprType::{Add, Subtract};
226
227 match arithmetic_func.func_type() {
229 Add | Subtract => {
230 let inputs = arithmetic_func.inputs();
231 if inputs.len() == 2 {
232 let (column_ref, arith_const, reverse_op) = if inputs[1].is_const() {
234 let reverse_op = match arithmetic_func.func_type() {
236 Add => Subtract,
237 Subtract => Add,
238 _ => unreachable!(),
239 };
240 (&inputs[0], &inputs[1], reverse_op)
241 } else if inputs[0].is_const() && arithmetic_func.func_type() == Add {
242 (&inputs[1], &inputs[0], Subtract)
244 } else {
245 return None;
246 };
247
248 if let Ok(new_const_func) = FunctionCall::new(
250 reverse_op,
251 vec![comparison_const.clone(), arith_const.clone()],
252 ) {
253 let new_const_expr: ExprImpl = new_const_func.into();
254 if let Some(Ok(Some(folded_value))) = new_const_expr.try_fold_const() {
256 let new_const =
257 Literal::new(Some(folded_value), new_const_expr.return_type())
258 .into();
259
260 if let Ok(new_comparison) = FunctionCall::new(
262 comparison_op,
263 vec![column_ref.clone(), new_const],
264 ) {
265 return Some(new_comparison.into());
266 }
267 }
268 }
269 }
270 }
271 _ => {}
272 }
273
274 None
275 }
276}
277
278fn handle_rank_preds(rank_preds: &[ExprImpl], window_func_pos: usize) -> Option<(u64, u64)> {
280 if rank_preds.is_empty() {
281 return None;
282 }
283
284 let mut lb: Option<i64> = None;
286 let mut ub: Option<i64> = None;
288 let mut eq: Option<i64> = None;
290
291 for cond in rank_preds {
292 if let Some((input_ref, cmp, v)) = cond.as_comparison_const() {
293 assert_eq!(input_ref.index, window_func_pos);
294 let v = v
295 .cast_implicit(&DataType::Int64)
296 .ok()?
297 .fold_const()
298 .ok()??;
299 let v = *v.as_int64();
300 match cmp {
301 ExprType::LessThanOrEqual => ub = ub.map_or(Some(v), |ub| Some(ub.min(v))),
302 ExprType::LessThan => ub = ub.map_or(Some(v - 1), |ub| Some(ub.min(v - 1))),
303 ExprType::GreaterThan => lb = lb.map_or(Some(v + 1), |lb| Some(lb.max(v + 1))),
304 ExprType::GreaterThanOrEqual => lb = lb.map_or(Some(v), |lb| Some(lb.max(v))),
305 _ => unreachable!(),
306 }
307 } else if let Some((input_ref, v)) = cond.as_eq_const() {
308 assert_eq!(input_ref.index, window_func_pos);
309 let v = v
310 .cast_implicit(&DataType::Int64)
311 .ok()?
312 .fold_const()
313 .ok()??;
314 let v = *v.as_int64();
315 if let Some(eq) = eq
316 && eq != v
317 {
318 tracing::warn!(
319 "Failed to optimize rank predicate with conflicting equal conditions."
320 );
321 return None;
322 }
323 eq = Some(v)
324 } else {
325 tracing::warn!("Failed to optimize complex rank predicate {:?}", cond);
327 return None;
328 }
329 }
330
331 if let Some(eq) = eq {
333 if eq < 1 {
334 tracing::warn!(
335 "Failed to optimize rank predicate with invalid predicate rank={}.",
336 eq
337 );
338 return None;
339 }
340 let lb = lb.unwrap_or(i64::MIN);
341 let ub = ub.unwrap_or(i64::MAX);
342 if !(lb <= eq && eq <= ub) {
343 tracing::warn!("Failed to optimize rank predicate with conflicting bounds.");
344 return None;
345 }
346 Some((1, (eq - 1) as u64))
347 } else {
348 match (lb, ub) {
349 (Some(lb), Some(ub)) => Some(((ub - lb + 1).max(0) as u64, (lb - 1).max(0) as u64)),
350 (Some(lb), None) => Some((LIMIT_ALL_COUNT, (lb - 1).max(0) as u64)),
351 (None, Some(ub)) => Some((ub.max(0) as u64, 0)),
352 (None, None) => unreachable!(),
353 }
354 }
355}
356
357fn add_rank_offset_if_needed(
358 plan: PlanRef,
359 offset: u64,
360 window_func_pos: usize,
361 output_len: usize,
362) -> PlanRef {
363 if offset == 0 {
364 return plan;
365 }
366
367 let schema = plan.schema();
368 let mut exprs = Vec::with_capacity(output_len);
369
370 for idx in 0..output_len {
371 let input: ExprImpl = InputRef::new(idx, schema.fields()[idx].data_type().clone()).into();
372 if idx == window_func_pos {
373 let offset_expr: ExprImpl =
374 Literal::new(Some(ScalarImpl::Int64(offset as i64)), DataType::Int64).into();
375 let adjusted = FunctionCall::new(ExprType::Add, vec![input, offset_expr])
376 .unwrap()
377 .into();
378 exprs.push(adjusted);
379 } else {
380 exprs.push(input);
381 }
382 }
383
384 LogicalProject::create(plan, exprs)
385}