risingwave_frontend/optimizer/plan_node/
logical_topn.rs1use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::bail_not_implemented;
18use risingwave_common::util::sort_util::ColumnOrder;
19
20use super::generic::{GenericPlanRef, TopNLimit};
21use super::utils::impl_distill_by_unit;
22use super::{
23 BatchGroupTopN, ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PlanTreeNodeUnary,
24 PredicatePushdown, StreamGroupTopN, StreamProject, ToBatch, ToStream, gen_filter_and_pushdown,
25 generic,
26};
27use crate::error::{ErrorCode, Result, RwError};
28use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
29use crate::optimizer::plan_node::{
30 BatchTopN, ColumnPruningContext, LogicalProject, PredicatePushdownContext,
31 RewriteStreamContext, StreamTopN, ToStreamContext,
32};
33use crate::optimizer::property::{Distribution, Order, RequiredDist};
34use crate::planner::LIMIT_ALL_COUNT;
35use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition};
36
37#[derive(Debug, Clone, PartialEq, Eq, Hash)]
39pub struct LogicalTopN {
40 pub base: PlanBase<Logical>,
41 core: generic::TopN<PlanRef>,
42}
43
44impl From<generic::TopN<PlanRef>> for LogicalTopN {
45 fn from(core: generic::TopN<PlanRef>) -> Self {
46 let base = PlanBase::new_logical_with_core(&core);
47 Self { base, core }
48 }
49}
50
51impl LogicalTopN {
52 pub fn new(
53 input: PlanRef,
54 limit: u64,
55 offset: u64,
56 with_ties: bool,
57 order: Order,
58 group_key: Vec<usize>,
59 ) -> Self {
60 let limit_attr = TopNLimit::new(limit, with_ties);
61 let core = generic::TopN::with_group(input, limit_attr, offset, order, group_key);
62 core.into()
63 }
64
65 pub fn create(
66 input: PlanRef,
67 limit: u64,
68 offset: u64,
69 order: Order,
70 with_ties: bool,
71 group_key: Vec<usize>,
72 ) -> Result<PlanRef> {
73 if with_ties && offset > 0 {
74 bail_not_implemented!("WITH TIES is not supported with OFFSET");
75 }
76 Ok(Self::new(input, limit, offset, with_ties, order, group_key).into())
77 }
78
79 pub fn limit_attr(&self) -> TopNLimit {
80 self.core.limit_attr
81 }
82
83 pub fn offset(&self) -> u64 {
84 self.core.offset
85 }
86
87 pub fn topn_order(&self) -> &Order {
94 &self.core.order
95 }
96
97 pub fn group_key(&self) -> &[usize] {
98 &self.core.group_key
99 }
100
101 pub fn decompose(self) -> (PlanRef, u64, u64, bool, Order, Vec<usize>) {
103 self.core.decompose()
104 }
105
106 fn gen_dist_stream_top_n_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
107 use super::stream::prelude::*;
108
109 let input_dist = stream_input.distribution().clone();
110
111 if stream_input.append_only() {
113 return self.gen_single_stream_top_n_plan(stream_input);
114 }
115
116 match input_dist {
117 Distribution::Single | Distribution::SomeShard => {
118 self.gen_single_stream_top_n_plan(stream_input)
119 }
120 Distribution::Broadcast => bail_not_implemented!("topN does not support Broadcast"),
121 Distribution::HashShard(dists) | Distribution::UpstreamHashShard(dists, _) => {
122 self.gen_vnode_two_phase_stream_top_n_plan(stream_input, &dists)
123 }
124 }
125 }
126
127 fn gen_single_stream_top_n_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
128 let input = RequiredDist::single().enforce_if_not_satisfies(stream_input, &Order::any())?;
129 let mut core = self.core.clone();
130 core.input = input;
131 Ok(StreamTopN::new(core).into())
132 }
133
134 fn gen_vnode_two_phase_stream_top_n_plan(
135 &self,
136 stream_input: PlanRef,
137 dist_key: &[usize],
138 ) -> Result<PlanRef> {
139 let project = StreamProject::new(generic::Project::with_vnode_col(stream_input, dist_key));
141 let vnode_col_idx = project.base.schema().len() - 1;
142
143 let limit_attr = TopNLimit::new(
144 self.limit_attr().limit() + self.offset(),
145 self.limit_attr().with_ties(),
146 );
147 let local_top_n = generic::TopN::with_group(
148 project.into(),
149 limit_attr,
150 0,
151 self.topn_order().clone(),
152 vec![vnode_col_idx],
153 );
154 let local_top_n = StreamGroupTopN::new(local_top_n, Some(vnode_col_idx));
155
156 let exchange =
157 RequiredDist::single().enforce_if_not_satisfies(local_top_n.into(), &Order::any())?;
158
159 let global_top_n = generic::TopN::without_group(
160 exchange,
161 self.limit_attr(),
162 self.offset(),
163 self.topn_order().clone(),
164 );
165 let global_top_n = StreamTopN::new(global_top_n);
166
167 assert_eq!(vnode_col_idx, global_top_n.base.schema().len() - 1);
169 let project = StreamProject::new(generic::Project::with_out_col_idx(
170 global_top_n.into(),
171 0..vnode_col_idx,
172 ));
173 Ok(project.into())
174 }
175
176 pub fn clone_with_input_and_prefix(&self, input: PlanRef, prefix: Order) -> Self {
177 let mut core = self.core.clone();
178 core.input = input;
179 core.order = prefix.concat(core.order);
180 core.into()
181 }
182}
183
184impl PlanTreeNodeUnary for LogicalTopN {
185 fn input(&self) -> PlanRef {
186 self.core.input.clone()
187 }
188
189 fn clone_with_input(&self, input: PlanRef) -> Self {
190 let mut core = self.core.clone();
191 core.input = input;
192 core.into()
193 }
194
195 fn rewrite_with_input(
196 &self,
197 input: PlanRef,
198 input_col_change: ColIndexMapping,
199 ) -> (Self, ColIndexMapping) {
200 let mut core = self.core.clone();
201 core.input = input;
202 core.order = input_col_change
203 .rewrite_required_order(self.topn_order())
204 .unwrap();
205 for key in &mut core.group_key {
206 *key = input_col_change.map(*key)
207 }
208 (core.into(), input_col_change)
209 }
210}
211impl_plan_tree_node_for_unary! {LogicalTopN}
212impl_distill_by_unit!(LogicalTopN, core, "LogicalTopN");
213
214impl ColPrunable for LogicalTopN {
215 fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
216 let input_required_bitset = FixedBitSet::from_iter(required_cols.iter().copied());
217 let order_required_cols = {
218 let mut order_required_cols = FixedBitSet::with_capacity(self.input().schema().len());
219 self.topn_order()
220 .column_orders
221 .iter()
222 .for_each(|o| order_required_cols.insert(o.column_index));
223 order_required_cols
224 };
225 let group_required_cols = {
226 let mut group_required_cols = FixedBitSet::with_capacity(self.input().schema().len());
227 self.group_key()
228 .iter()
229 .for_each(|idx| group_required_cols.insert(*idx));
230 group_required_cols
231 };
232
233 let input_required_cols = {
234 let mut tmp = order_required_cols;
235 tmp.union_with(&input_required_bitset);
236 tmp.union_with(&group_required_cols);
237 tmp.ones().collect_vec()
238 };
239 let mapping = ColIndexMapping::with_remaining_columns(
240 &input_required_cols,
241 self.input().schema().len(),
242 );
243 let new_order = Order {
244 column_orders: self
245 .topn_order()
246 .column_orders
247 .iter()
248 .map(|o| ColumnOrder::new(mapping.map(o.column_index), o.order_type))
249 .collect(),
250 };
251 let new_group_key = self
252 .group_key()
253 .iter()
254 .map(|group_key| mapping.map(*group_key))
255 .collect();
256 let new_input = self.input().prune_col(&input_required_cols, ctx);
257 let top_n = Self::new(
258 new_input,
259 self.limit_attr().limit(),
260 self.offset(),
261 self.limit_attr().with_ties(),
262 new_order,
263 new_group_key,
264 )
265 .into();
266
267 if input_required_cols == required_cols {
268 top_n
269 } else {
270 let output_required_cols = required_cols
271 .iter()
272 .map(|&idx| mapping.map(idx))
273 .collect_vec();
274 let src_size = top_n.schema().len();
275 LogicalProject::with_mapping(
276 top_n,
277 ColIndexMapping::with_remaining_columns(&output_required_cols, src_size),
278 )
279 .into()
280 }
281 }
282}
283
284impl ExprRewritable for LogicalTopN {}
285
286impl ExprVisitable for LogicalTopN {}
287
288impl PredicatePushdown for LogicalTopN {
289 fn predicate_pushdown(
290 &self,
291 predicate: Condition,
292 ctx: &mut PredicatePushdownContext,
293 ) -> PlanRef {
294 gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx)
296 }
297}
298
299impl ToBatch for LogicalTopN {
300 fn to_batch(&self) -> Result<PlanRef> {
301 let new_input = self.input().to_batch()?;
302 let mut new_logical = self.core.clone();
303 new_logical.input = new_input;
304 if self.group_key().is_empty() {
305 Ok(BatchTopN::new(new_logical).into())
306 } else {
307 Ok(BatchGroupTopN::new(new_logical).into())
308 }
309 }
310}
311
312impl ToStream for LogicalTopN {
313 fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
314 if self.offset() != 0 && self.limit_attr().limit() == LIMIT_ALL_COUNT {
315 return Err(RwError::from(ErrorCode::InvalidInputSyntax(
316 "OFFSET without LIMIT in streaming mode".to_owned(),
317 )));
318 }
319 if self.limit_attr().limit() == 0 {
320 return Err(RwError::from(ErrorCode::InvalidInputSyntax(
321 "LIMIT 0 in streaming mode".to_owned(),
322 )));
323 }
324 Ok(if !self.group_key().is_empty() {
325 let input = self.input().to_stream(ctx)?;
326 let input = RequiredDist::hash_shard(self.group_key())
327 .enforce_if_not_satisfies(input, &Order::any())?;
328 let mut core = self.core.clone();
329 core.input = input;
330 StreamGroupTopN::new(core, None).into()
331 } else {
332 self.gen_dist_stream_top_n_plan(self.input().to_stream(ctx)?)?
333 })
334 }
335
336 fn logical_rewrite_for_stream(
337 &self,
338 ctx: &mut RewriteStreamContext,
339 ) -> Result<(PlanRef, ColIndexMapping)> {
340 let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
341 let (top_n, out_col_change) = self.rewrite_with_input(input, input_col_change);
342 Ok((top_n.into(), out_col_change))
343 }
344}
345
#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;

    use super::LogicalTopN;
    use crate::PlanRef;
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::{ColPrunable, ColumnPruningContext, LogicalValues};
    use crate::optimizer::property::Order;

    /// Pruning with every column required must keep a `LogicalTopN` with its group key intact.
    #[tokio::test]
    async fn test_prune_col() {
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = ["v1", "v2", "v3"]
            .into_iter()
            .map(|name| Field::with_name(DataType::Int32, name))
            .collect();
        let values_plan = PlanRef::from(LogicalValues::new(vec![], Schema { fields }, ctx));

        let top_n = LogicalTopN::new(values_plan, 1, 0, false, Order::default(), vec![1]);
        assert_eq!(top_n.group_key(), &[1]);

        let top_n_plan: PlanRef = top_n.into();
        let mut prune_ctx = ColumnPruningContext::new(top_n_plan.clone());
        let pruned = top_n_plan.prune_col(&[0, 1, 2], &mut prune_ctx);

        let pruned_top_n = pruned.as_logical_top_n().unwrap();
        assert_eq!(pruned_top_n.group_key(), &[1]);
    }
}