// risingwave_frontend/optimizer/plan_node/logical_topn.rs

use fixedbitset::FixedBitSet;
use itertools::Itertools;
use risingwave_common::bail_not_implemented;
use risingwave_common::util::sort_util::ColumnOrder;

use super::generic::{GenericPlanRef, TopNLimit};
use super::utils::impl_distill_by_unit;
use super::{
    BatchGroupTopN, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef, PlanBase,
    PlanTreeNodeUnary, PredicatePushdown, StreamGroupTopN, StreamPlanRef, StreamProject, ToBatch,
    ToStream, gen_filter_and_pushdown, generic,
};
use crate::error::{ErrorCode, Result, RwError};
use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
use crate::optimizer::plan_node::{
    BatchTopN, ColumnPruningContext, LogicalProject, PredicatePushdownContext,
    RewriteStreamContext, StreamTopN, ToStreamContext,
};
use crate::optimizer::property::{Distribution, Order, RequiredDist};
use crate::planner::LIMIT_ALL_COUNT;
use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition};

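/// `LogicalTopN` sorts the input by `order` and, per group when a `group_key` is given,
/// keeps at most `limit` rows starting from `offset`.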
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalTopN {
    pub base: PlanBase<Logical>,
    core: generic::TopN<PlanRef>,
}

impl From<generic::TopN<PlanRef>> for LogicalTopN {
    fn from(core: generic::TopN<PlanRef>) -> Self {
        let base = PlanBase::new_logical_with_core(&core);
        Self { base, core }
    }
}

impl LogicalTopN {
    pub fn new(
        input: PlanRef,
        limit: u64,
        offset: u64,
        with_ties: bool,
        order: Order,
        group_key: Vec<usize>,
    ) -> Self {
        let limit_attr = TopNLimit::new(limit, with_ties);
        let core = generic::TopN::with_group(input, limit_attr, offset, order, group_key);
        core.into()
    }

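    /// Creates a `LogicalTopN` plan node, rejecting the currently unsupported
    /// combination of `WITH TIES` and a non-zero `OFFSET`.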
    pub fn create(
        input: PlanRef,
        limit: u64,
        offset: u64,
        order: Order,
        with_ties: bool,
        group_key: Vec<usize>,
    ) -> Result<PlanRef> {
        if with_ties && offset > 0 {
            bail_not_implemented!("WITH TIES is not supported with OFFSET");
        }
        Ok(Self::new(input, limit, offset, with_ties, order, group_key).into())
    }

    pub fn limit_attr(&self) -> TopNLimit {
        self.core.limit_attr
    }

    pub fn offset(&self) -> u64 {
        self.core.offset
    }

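    /// The sort order used to rank rows within each TopN group.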
    pub fn topn_order(&self) -> &Order {
        &self.core.order
    }

    pub fn group_key(&self) -> &[usize] {
        &self.core.group_key
    }

    pub fn decompose(self) -> (PlanRef, u64, u64, bool, Order, Vec<usize>) {
        self.core.decompose()
    }

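    /// Chooses a streaming TopN plan based on the input's distribution: append-only or
    /// singleton-like inputs use a single `StreamTopN`, hash-distributed inputs use the
    /// two-phase vnode plan, and broadcast inputs are not supported.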
    fn gen_dist_stream_top_n_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
        use super::stream::prelude::*;

        let input_dist = stream_input.distribution().clone();

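        // Append-only input is handled by a single TopN regardless of its distribution.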
        if stream_input.append_only() {
            return self.gen_single_stream_top_n_plan(stream_input);
        }

        match input_dist {
            Distribution::Single | Distribution::SomeShard => {
                self.gen_single_stream_top_n_plan(stream_input)
            }
            Distribution::Broadcast => bail_not_implemented!("topN does not support Broadcast"),
            Distribution::HashShard(dists) | Distribution::UpstreamHashShard(dists, _) => {
                self.gen_vnode_two_phase_stream_top_n_plan(stream_input, &dists)
            }
        }
    }

    fn gen_single_stream_top_n_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
        let input = RequiredDist::single().streaming_enforce_if_not_satisfies(stream_input)?;
        let core = self.core.clone_with_input(input);
        Ok(StreamTopN::new(core)?.into())
    }

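    /// Builds a two-phase streaming TopN over a hash-distributed input:
    /// append a vnode column, run a per-vnode local `StreamGroupTopN` that keeps
    /// `limit + offset` rows per vnode, gather to a single partition for a global
    /// `StreamTopN` with the original limit/offset, then project the vnode column away.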
    fn gen_vnode_two_phase_stream_top_n_plan(
        &self,
        stream_input: StreamPlanRef,
        dist_key: &[usize],
    ) -> Result<StreamPlanRef> {
        let project =
            StreamProject::new(generic::Project::with_vnode_col(stream_input, dist_key));
        let vnode_col_idx = project.base.schema().len() - 1;

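        // Each vnode keeps `limit + offset` rows locally so that the global phase still
        // sees every row that could end up in the final result.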
        let limit_attr = TopNLimit::new(
            self.limit_attr().limit() + self.offset(),
            self.limit_attr().with_ties(),
        );
        let local_top_n = generic::TopN::with_group(
            project.into(),
            limit_attr,
            0,
            self.topn_order().clone(),
            vec![vnode_col_idx],
        );
        let local_top_n = StreamGroupTopN::new(local_top_n, Some(vnode_col_idx))?;

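        // Gather the per-vnode results into a single partition for the global phase.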
        let exchange =
            RequiredDist::single().streaming_enforce_if_not_satisfies(local_top_n.into())?;

        let global_top_n = generic::TopN::without_group(
            exchange,
            self.limit_attr(),
            self.offset(),
            self.topn_order().clone(),
        );
        let global_top_n = StreamTopN::new(global_top_n)?;

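        // The vnode column is still the last column here; project it away so the output
        // schema matches the logical TopN.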
        assert_eq!(vnode_col_idx, global_top_n.base.schema().len() - 1);
        let project = StreamProject::new(generic::Project::with_out_col_idx(
            global_top_n.into(),
            0..vnode_col_idx,
        ));
        Ok(project.into())
    }

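    /// Clones this TopN onto a new input, prepending `prefix` to the existing sort order.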
    pub fn clone_with_input_and_prefix(&self, input: PlanRef, prefix: Order) -> Self {
        let mut core = self.core.clone();
        core.input = input;
        core.order = prefix.concat(core.order);
        core.into()
    }
}

impl PlanTreeNodeUnary<Logical> for LogicalTopN {
    fn input(&self) -> PlanRef {
        self.core.input.clone()
    }

    fn clone_with_input(&self, input: PlanRef) -> Self {
        let mut core = self.core.clone();
        core.input = input;
        core.into()
    }

    fn rewrite_with_input(
        &self,
        input: PlanRef,
        input_col_change: ColIndexMapping,
    ) -> (Self, ColIndexMapping) {
        let mut core = self.core.clone();
        core.input = input;
        core.order = input_col_change
            .rewrite_required_order(self.topn_order())
            .unwrap();
        for key in &mut core.group_key {
            *key = input_col_change.map(*key)
        }
        (core.into(), input_col_change)
    }
}
impl_plan_tree_node_for_unary! { Logical, LogicalTopN }
impl_distill_by_unit!(LogicalTopN, core, "LogicalTopN");

impl ColPrunable for LogicalTopN {
    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
        let input_required_bitset = FixedBitSet::from_iter(required_cols.iter().copied());
        let order_required_cols = {
            let mut order_required_cols = FixedBitSet::with_capacity(self.input().schema().len());
            self.topn_order()
                .column_orders
                .iter()
                .for_each(|o| order_required_cols.insert(o.column_index));
            order_required_cols
        };
        let group_required_cols = {
            let mut group_required_cols = FixedBitSet::with_capacity(self.input().schema().len());
            self.group_key()
                .iter()
                .for_each(|idx| group_required_cols.insert(*idx));
            group_required_cols
        };

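        // The pruned input must keep the caller's required columns plus every column
        // referenced by the TopN order and group key.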
        let input_required_cols = {
            let mut tmp = order_required_cols;
            tmp.union_with(&input_required_bitset);
            tmp.union_with(&group_required_cols);
            tmp.ones().collect_vec()
        };
        let mapping = ColIndexMapping::with_remaining_columns(
            &input_required_cols,
            self.input().schema().len(),
        );
        let new_order = Order {
            column_orders: self
                .topn_order()
                .column_orders
                .iter()
                .map(|o| ColumnOrder::new(mapping.map(o.column_index), o.order_type))
                .collect(),
        };
        let new_group_key = self
            .group_key()
            .iter()
            .map(|group_key| mapping.map(*group_key))
            .collect();
        let new_input = self.input().prune_col(&input_required_cols, ctx);
        let top_n = Self::new(
            new_input,
            self.limit_attr().limit(),
            self.offset(),
            self.limit_attr().with_ties(),
            new_order,
            new_group_key,
        )
        .into();

        if input_required_cols == required_cols {
            top_n
        } else {
            let output_required_cols = required_cols
                .iter()
                .map(|&idx| mapping.map(idx))
                .collect_vec();
            let src_size = top_n.schema().len();
            LogicalProject::with_mapping(
                top_n,
                ColIndexMapping::with_remaining_columns(&output_required_cols, src_size),
            )
            .into()
        }
    }
}

impl ExprRewritable<Logical> for LogicalTopN {}

impl ExprVisitable for LogicalTopN {}

impl PredicatePushdown for LogicalTopN {
    fn predicate_pushdown(
        &self,
        predicate: Condition,
        ctx: &mut PredicatePushdownContext,
    ) -> PlanRef {
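        // A filter cannot be transposed below a TopN (it would change which rows fall
        // within the limit), so the whole predicate stays above this node.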
        gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx)
    }
}

impl ToBatch for LogicalTopN {
    fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
        let new_input = self.input().to_batch()?;
        let core = self.core.clone_with_input(new_input);
        if self.group_key().is_empty() {
            Ok(BatchTopN::new(core).into())
        } else {
            Ok(BatchGroupTopN::new(core).into())
        }
    }
}

impl ToStream for LogicalTopN {
    fn to_stream(
        &self,
        ctx: &mut ToStreamContext,
    ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
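        // Reject queries that cannot run as a streaming TopN: OFFSET without a LIMIT
        // (unbounded) and LIMIT 0.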
        if self.offset() != 0 && self.limit_attr().limit() == LIMIT_ALL_COUNT {
            return Err(RwError::from(ErrorCode::InvalidInputSyntax(
                "OFFSET without LIMIT in streaming mode".to_owned(),
            )));
        }
        if self.limit_attr().limit() == 0 {
            return Err(RwError::from(ErrorCode::InvalidInputSyntax(
                "LIMIT 0 in streaming mode".to_owned(),
            )));
        }
        Ok(if !self.group_key().is_empty() {
            let input = self.input().to_stream(ctx)?;
            let input = RequiredDist::hash_shard(self.group_key())
                .streaming_enforce_if_not_satisfies(input)?;
            let core = self.core.clone_with_input(input);
            StreamGroupTopN::new(core, None)?.into()
        } else {
            self.gen_dist_stream_top_n_plan(self.input().to_stream(ctx)?)?
        })
    }

    fn logical_rewrite_for_stream(
        &self,
        ctx: &mut RewriteStreamContext,
    ) -> Result<(PlanRef, ColIndexMapping)> {
        let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
        let (top_n, out_col_change) = self.rewrite_with_input(input, input_col_change);
        Ok((top_n.into(), out_col_change))
    }
}

#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;

    use super::LogicalTopN;
    use crate::PlanRef;
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::{
        ColPrunable, ColumnPruningContext, LogicalPlanRef, LogicalValues,
    };
    use crate::optimizer::property::Order;

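    // With every column required, pruning should rebuild the TopN without an extra
    // projection and keep the group key pointing at column index 1.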
    #[tokio::test]
    async fn test_prune_col() {
        let ty = DataType::Int32;
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = vec![
            Field::with_name(ty.clone(), "v1"),
            Field::with_name(ty.clone(), "v2"),
            Field::with_name(ty.clone(), "v3"),
        ];
        let values = LogicalValues::new(vec![], Schema { fields }, ctx);
        let input = PlanRef::from(values);

        let original_logical = LogicalTopN::new(input, 1, 0, false, Order::default(), vec![1]);
        assert_eq!(original_logical.group_key(), &[1]);
        let original_logical: LogicalPlanRef = original_logical.into();
        let pruned_node = original_logical.prune_col(
            &[0, 1, 2],
            &mut ColumnPruningContext::new(original_logical.clone()),
        );

        let pruned_logical = pruned_node.as_logical_top_n().unwrap();
        assert_eq!(pruned_logical.group_key(), &[1]);
    }
}