risingwave_frontend/optimizer/rule/
over_window_to_topn_rule.rs1use fixedbitset::FixedBitSet;
16use risingwave_common::types::DataType;
17use risingwave_expr::window_function::WindowFuncKind;
18
19use super::{BoxedRule, Rule};
20use crate::PlanRef;
21use crate::expr::{
22 Expr, ExprImpl, ExprRewriter, ExprType, FunctionCall, Literal, collect_input_refs,
23};
24use crate::optimizer::plan_node::generic::GenericPlanRef;
25use crate::optimizer::plan_node::{LogicalFilter, LogicalTopN, PlanTreeNodeUnary};
26use crate::optimizer::property::Order;
27use crate::planner::LIMIT_ALL_COUNT;
28use crate::utils::Condition;
29
30pub struct OverWindowToTopNRule;
52
53impl OverWindowToTopNRule {
54 pub fn create() -> BoxedRule {
55 Box::new(OverWindowToTopNRule)
56 }
57}
58
59impl Rule for OverWindowToTopNRule {
60 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
61 let ctx = plan.ctx();
62 let (project, plan) = {
63 if let Some(project) = plan.as_logical_project() {
64 (Some(project), project.input())
65 } else {
66 (None, plan)
67 }
68 };
69 let filter = plan.as_logical_filter()?;
70 let plan = filter.input();
71 let over_window = plan.as_logical_over_window()?;
73
74 let filter = if let Some(simplified) = self.simplify_filter_arithmetic(filter) {
76 simplified
77 } else {
78 filter.clone()
79 };
80
81 if over_window.window_functions().len() != 1 {
82 return None;
84 }
85 let window_func = &over_window.window_functions()[0];
86 if !window_func.kind.is_numbering() {
87 return None;
89 }
90
91 let output_len = over_window.schema().len();
92 let window_func_pos = output_len - 1;
93
94 let with_ties = match window_func.kind {
95 WindowFuncKind::RowNumber => false,
97 WindowFuncKind::Rank => true,
98 WindowFuncKind::DenseRank => {
99 ctx.warn_to_user("`dense_rank` is not supported in Top-N pattern, will fallback to inefficient implementation");
100 return None;
101 }
102 _ => unreachable!("window functions other than rank functions should not reach here"),
103 };
104
105 let (rank_pred, other_pred) = {
106 let predicate = filter.predicate();
107 let mut rank_col = FixedBitSet::with_capacity(output_len);
108 rank_col.set(window_func_pos, true);
109 predicate.clone().split_disjoint(&rank_col)
110 };
111
112 let (limit, offset) = handle_rank_preds(&rank_pred.conjunctions, window_func_pos)?;
113
114 if offset > 0 && with_ties {
115 tracing::warn!("Failed to optimize with ties and offset");
116 ctx.warn_to_user("group topN with ties and offset is not supported, see https://www.risingwave.dev/docs/current/sql-pattern-topn/ for more information");
117 return None;
118 }
119
120 let topn: PlanRef = LogicalTopN::new(
121 over_window.input(),
122 limit,
123 offset,
124 with_ties,
125 Order {
126 column_orders: window_func.order_by.to_vec(),
127 },
128 window_func.partition_by.iter().map(|i| i.index).collect(),
129 )
130 .into();
131 let filter = LogicalFilter::create(topn, other_pred);
132
133 let plan = if let Some(project) = project {
134 let referred_cols = collect_input_refs(output_len, project.exprs());
135 if !referred_cols.contains(window_func_pos) {
136 project.clone_with_input(filter).into()
138 } else {
139 project
141 .clone_with_input(over_window.clone_with_input(filter).into())
142 .into()
143 }
144 } else {
145 ctx.warn_to_user("It can be inefficient to output ranking number in Top-N, see https://www.risingwave.dev/docs/current/sql-pattern-topn/ for more information");
147 over_window.clone_with_input(filter).into()
148 };
149 Some(plan)
150 }
151}
152
153impl OverWindowToTopNRule {
154 fn simplify_filter_arithmetic(&self, filter: &LogicalFilter) -> Option<LogicalFilter> {
157 let new_predicate = self.simplify_filter_arithmetic_condition(filter.predicate())?;
158 Some(LogicalFilter::new(filter.input(), new_predicate))
159 }
160
161 fn simplify_filter_arithmetic_condition(&self, predicate: &Condition) -> Option<Condition> {
163 let expr = predicate.as_expr_unless_true()?;
164 let mut rewriter = FilterArithmeticRewriter {};
165 let new_expr = rewriter.rewrite_expr(expr.clone());
166
167 if new_expr != expr {
168 Some(Condition::with_expr(new_expr))
169 } else {
170 None
171 }
172 }
173}
174
175struct FilterArithmeticRewriter {}
177
178impl ExprRewriter for FilterArithmeticRewriter {
179 fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
180 use ExprType::{
181 Equal, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, NotEqual,
182 };
183
184 match func_call.func_type() {
186 Equal | NotEqual | LessThan | LessThanOrEqual | GreaterThan | GreaterThanOrEqual => {
187 let inputs = func_call.inputs();
188 if inputs.len() == 2 {
189 if let ExprImpl::FunctionCall(left_func) = &inputs[0]
191 && inputs[1].is_const()
192 && let Some(simplified) = self.simplify_arithmetic_comparison(
193 left_func,
194 &inputs[1],
195 func_call.func_type(),
196 )
197 {
198 return simplified;
199 }
200 }
201 }
202 _ => {}
203 }
204
205 let (func_type, inputs, ret_type) = func_call.decompose();
207 let new_inputs: Vec<_> = inputs
208 .into_iter()
209 .map(|input| self.rewrite_expr(input))
210 .collect();
211
212 FunctionCall::new_unchecked(func_type, new_inputs, ret_type).into()
213 }
214}
215
216impl FilterArithmeticRewriter {
217 fn simplify_arithmetic_comparison(
219 &self,
220 arithmetic_func: &FunctionCall,
221 comparison_const: &ExprImpl,
222 comparison_op: ExprType,
223 ) -> Option<ExprImpl> {
224 use ExprType::{Add, Subtract};
225
226 match arithmetic_func.func_type() {
228 Add | Subtract => {
229 let inputs = arithmetic_func.inputs();
230 if inputs.len() == 2 {
231 let (column_ref, arith_const, reverse_op) = if inputs[1].is_const() {
233 let reverse_op = match arithmetic_func.func_type() {
235 Add => Subtract,
236 Subtract => Add,
237 _ => unreachable!(),
238 };
239 (&inputs[0], &inputs[1], reverse_op)
240 } else if inputs[0].is_const() && arithmetic_func.func_type() == Add {
241 (&inputs[1], &inputs[0], Subtract)
243 } else {
244 return None;
245 };
246
247 if let Ok(new_const_func) = FunctionCall::new(
249 reverse_op,
250 vec![comparison_const.clone(), arith_const.clone()],
251 ) {
252 let new_const_expr: ExprImpl = new_const_func.into();
253 if let Some(Ok(Some(folded_value))) = new_const_expr.try_fold_const() {
255 let new_const =
256 Literal::new(Some(folded_value), new_const_expr.return_type())
257 .into();
258
259 if let Ok(new_comparison) = FunctionCall::new(
261 comparison_op,
262 vec![column_ref.clone(), new_const],
263 ) {
264 return Some(new_comparison.into());
265 }
266 }
267 }
268 }
269 }
270 _ => {}
271 }
272
273 None
274 }
275}
276
277fn handle_rank_preds(rank_preds: &[ExprImpl], window_func_pos: usize) -> Option<(u64, u64)> {
279 if rank_preds.is_empty() {
280 return None;
281 }
282
283 let mut lb: Option<i64> = None;
285 let mut ub: Option<i64> = None;
287 let mut eq: Option<i64> = None;
289
290 for cond in rank_preds {
291 if let Some((input_ref, cmp, v)) = cond.as_comparison_const() {
292 assert_eq!(input_ref.index, window_func_pos);
293 let v = v.cast_implicit(DataType::Int64).ok()?.fold_const().ok()??;
294 let v = *v.as_int64();
295 match cmp {
296 ExprType::LessThanOrEqual => ub = ub.map_or(Some(v), |ub| Some(ub.min(v))),
297 ExprType::LessThan => ub = ub.map_or(Some(v - 1), |ub| Some(ub.min(v - 1))),
298 ExprType::GreaterThan => lb = lb.map_or(Some(v + 1), |lb| Some(lb.max(v + 1))),
299 ExprType::GreaterThanOrEqual => lb = lb.map_or(Some(v), |lb| Some(lb.max(v))),
300 _ => unreachable!(),
301 }
302 } else if let Some((input_ref, v)) = cond.as_eq_const() {
303 assert_eq!(input_ref.index, window_func_pos);
304 let v = v.cast_implicit(DataType::Int64).ok()?.fold_const().ok()??;
305 let v = *v.as_int64();
306 if let Some(eq) = eq
307 && eq != v
308 {
309 tracing::warn!(
310 "Failed to optimize rank predicate with conflicting equal conditions."
311 );
312 return None;
313 }
314 eq = Some(v)
315 } else {
316 tracing::warn!("Failed to optimize complex rank predicate {:?}", cond);
318 return None;
319 }
320 }
321
322 if let Some(eq) = eq {
324 if eq < 1 {
325 tracing::warn!(
326 "Failed to optimize rank predicate with invalid predicate rank={}.",
327 eq
328 );
329 return None;
330 }
331 let lb = lb.unwrap_or(i64::MIN);
332 let ub = ub.unwrap_or(i64::MAX);
333 if !(lb <= eq && eq <= ub) {
334 tracing::warn!("Failed to optimize rank predicate with conflicting bounds.");
335 return None;
336 }
337 Some((1, (eq - 1) as u64))
338 } else {
339 match (lb, ub) {
340 (Some(lb), Some(ub)) => Some(((ub - lb + 1).max(0) as u64, (lb - 1).max(0) as u64)),
341 (Some(lb), None) => Some((LIMIT_ALL_COUNT, (lb - 1).max(0) as u64)),
342 (None, Some(ub)) => Some((ub.max(0) as u64, 0)),
343 (None, None) => unreachable!(),
344 }
345 }
346}