risingwave_frontend/optimizer/rule/
over_window_to_topn_rule.rs1use fixedbitset::FixedBitSet;
16use risingwave_common::types::DataType;
17use risingwave_expr::window_function::WindowFuncKind;
18
19use super::{BoxedRule, Rule};
20use crate::PlanRef;
21use crate::expr::{ExprImpl, ExprType, collect_input_refs};
22use crate::optimizer::plan_node::generic::GenericPlanRef;
23use crate::optimizer::plan_node::{LogicalFilter, LogicalTopN, PlanTreeNodeUnary};
24use crate::optimizer::property::Order;
25use crate::planner::LIMIT_ALL_COUNT;
26
/// Rewrites a filter on the output of a single numbering window function (`row_number` or
/// `rank`) into a group Top-N.
///
/// Matches `[Project] <- Filter <- OverWindow`, where the projection is optional and the filter
/// constrains the ranking column with constant comparisons (`<`, `<=`, `>`, `>=`, `=`). The
/// window function's `PARTITION BY` / `ORDER BY` become the Top-N group key and order. The
/// remaining (non-rank) predicates are applied in a filter above the Top-N. The `OverWindow` is
/// kept on top only when the ranking number is still needed in the output.
///
/// `dense_rank` is not supported, and `rank` combined with a non-zero offset is rejected.
pub struct OverWindowToTopNRule;
46
47impl OverWindowToTopNRule {
48 pub fn create() -> BoxedRule {
49 Box::new(OverWindowToTopNRule)
50 }
51}
52
53impl Rule for OverWindowToTopNRule {
54 fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
55 let ctx = plan.ctx();
56 let (project, plan) = {
57 if let Some(project) = plan.as_logical_project() {
58 (Some(project), project.input())
59 } else {
60 (None, plan)
61 }
62 };
63 let filter = plan.as_logical_filter()?;
64 let plan = filter.input();
65 let over_window = plan.as_logical_over_window()?;
67
68 if over_window.window_functions().len() != 1 {
69 return None;
71 }
72 let window_func = &over_window.window_functions()[0];
73 if !window_func.kind.is_numbering() {
74 return None;
76 }
77
78 let output_len = over_window.schema().len();
79 let window_func_pos = output_len - 1;
80
81 let with_ties = match window_func.kind {
82 WindowFuncKind::RowNumber => false,
84 WindowFuncKind::Rank => true,
85 WindowFuncKind::DenseRank => {
86 ctx.warn_to_user("`dense_rank` is not supported in Top-N pattern, will fallback to inefficient implementation");
87 return None;
88 }
89 _ => unreachable!("window functions other than rank functions should not reach here"),
90 };
91
92 let (rank_pred, other_pred) = {
93 let predicate = filter.predicate();
94 let mut rank_col = FixedBitSet::with_capacity(output_len);
95 rank_col.set(window_func_pos, true);
96 predicate.clone().split_disjoint(&rank_col)
97 };
98
99 let (limit, offset) = handle_rank_preds(&rank_pred.conjunctions, window_func_pos)?;
100
101 if offset > 0 && with_ties {
102 tracing::warn!("Failed to optimize with ties and offset");
103 ctx.warn_to_user("group topN with ties and offset is not supported, see https://www.risingwave.dev/docs/current/sql-pattern-topn/ for more information");
104 return None;
105 }
106
107 let topn: PlanRef = LogicalTopN::new(
108 over_window.input(),
109 limit,
110 offset,
111 with_ties,
112 Order {
113 column_orders: window_func.order_by.to_vec(),
114 },
115 window_func.partition_by.iter().map(|i| i.index).collect(),
116 )
117 .into();
118 let filter = LogicalFilter::create(topn, other_pred);
119
120 let plan = if let Some(project) = project {
121 let referred_cols = collect_input_refs(output_len, project.exprs());
122 if !referred_cols.contains(window_func_pos) {
123 project.clone_with_input(filter).into()
125 } else {
126 project
128 .clone_with_input(over_window.clone_with_input(filter).into())
129 .into()
130 }
131 } else {
132 ctx.warn_to_user("It can be inefficient to output ranking number in Top-N, see https://www.risingwave.dev/docs/current/sql-pattern-topn/ for more information");
134 over_window.clone_with_input(filter).into()
135 };
136 Some(plan)
137 }
138}
139
140fn handle_rank_preds(rank_preds: &[ExprImpl], window_func_pos: usize) -> Option<(u64, u64)> {
142 if rank_preds.is_empty() {
143 return None;
144 }
145
146 let mut lb: Option<i64> = None;
148 let mut ub: Option<i64> = None;
150 let mut eq: Option<i64> = None;
152
153 for cond in rank_preds {
154 if let Some((input_ref, cmp, v)) = cond.as_comparison_const() {
155 assert_eq!(input_ref.index, window_func_pos);
156 let v = v.cast_implicit(DataType::Int64).ok()?.fold_const().ok()??;
157 let v = *v.as_int64();
158 match cmp {
159 ExprType::LessThanOrEqual => ub = ub.map_or(Some(v), |ub| Some(ub.min(v))),
160 ExprType::LessThan => ub = ub.map_or(Some(v - 1), |ub| Some(ub.min(v - 1))),
161 ExprType::GreaterThan => lb = lb.map_or(Some(v + 1), |lb| Some(lb.max(v + 1))),
162 ExprType::GreaterThanOrEqual => lb = lb.map_or(Some(v), |lb| Some(lb.max(v))),
163 _ => unreachable!(),
164 }
165 } else if let Some((input_ref, v)) = cond.as_eq_const() {
166 assert_eq!(input_ref.index, window_func_pos);
167 let v = v.cast_implicit(DataType::Int64).ok()?.fold_const().ok()??;
168 let v = *v.as_int64();
169 if let Some(eq) = eq
170 && eq != v
171 {
172 tracing::warn!(
173 "Failed to optimize rank predicate with conflicting equal conditions."
174 );
175 return None;
176 }
177 eq = Some(v)
178 } else {
179 tracing::warn!("Failed to optimize complex rank predicate {:?}", cond);
181 return None;
182 }
183 }
184
185 if let Some(eq) = eq {
187 if eq < 1 {
188 tracing::warn!(
189 "Failed to optimize rank predicate with invalid predicate rank={}.",
190 eq
191 );
192 return None;
193 }
194 let lb = lb.unwrap_or(i64::MIN);
195 let ub = ub.unwrap_or(i64::MAX);
196 if !(lb <= eq && eq <= ub) {
197 tracing::warn!("Failed to optimize rank predicate with conflicting bounds.");
198 return None;
199 }
200 Some((1, (eq - 1) as u64))
201 } else {
202 match (lb, ub) {
203 (Some(lb), Some(ub)) => Some(((ub - lb + 1).max(0) as u64, (lb - 1).max(0) as u64)),
204 (Some(lb), None) => Some((LIMIT_ALL_COUNT, (lb - 1).max(0) as u64)),
205 (None, Some(ub)) => Some((ub.max(0) as u64, 0)),
206 (None, None) => unreachable!(),
207 }
208 }
209}