risingwave_frontend/optimizer/plan_node/
logical_topn.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::bail_not_implemented;
18use risingwave_common::util::sort_util::ColumnOrder;
19
20use super::generic::{GenericPlanRef, TopNLimit};
21use super::utils::impl_distill_by_unit;
22use super::{
23    BatchGroupTopN, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef, PlanBase,
24    PlanTreeNodeUnary, PredicatePushdown, StreamGroupTopN, StreamPlanRef, StreamProject, ToBatch,
25    ToStream, gen_filter_and_pushdown, generic,
26};
27use crate::error::{ErrorCode, Result, RwError};
28use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
29use crate::optimizer::plan_node::{
30    BatchTopN, ColumnPruningContext, LogicalProject, PredicatePushdownContext,
31    RewriteStreamContext, StreamTopN, ToStreamContext,
32};
33use crate::optimizer::property::{Distribution, Order, RequiredDist};
34use crate::planner::LIMIT_ALL_COUNT;
35use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition};
36
37/// `LogicalTopN` sorts the input data and fetches up to `limit` rows from `offset`
38#[derive(Debug, Clone, PartialEq, Eq, Hash)]
39pub struct LogicalTopN {
40    pub base: PlanBase<Logical>,
41    core: generic::TopN<PlanRef>,
42}
43
44impl From<generic::TopN<PlanRef>> for LogicalTopN {
45    fn from(core: generic::TopN<PlanRef>) -> Self {
46        let base = PlanBase::new_logical_with_core(&core);
47        Self { base, core }
48    }
49}
50
51impl LogicalTopN {
52    pub fn new(
53        input: PlanRef,
54        limit: u64,
55        offset: u64,
56        with_ties: bool,
57        order: Order,
58        group_key: Vec<usize>,
59    ) -> Self {
60        let limit_attr = TopNLimit::new(limit, with_ties);
61        let core = generic::TopN::with_group(input, limit_attr, offset, order, group_key);
62        core.into()
63    }
64
65    pub fn create(
66        input: PlanRef,
67        limit: u64,
68        offset: u64,
69        order: Order,
70        with_ties: bool,
71        group_key: Vec<usize>,
72    ) -> Result<PlanRef> {
73        if with_ties && offset > 0 {
74            bail_not_implemented!("WITH TIES is not supported with OFFSET");
75        }
76        Ok(Self::new(input, limit, offset, with_ties, order, group_key).into())
77    }
78
79    pub fn limit_attr(&self) -> TopNLimit {
80        self.core.limit_attr
81    }
82
83    pub fn offset(&self) -> u64 {
84        self.core.offset
85    }
86
87    /// `topn_order` returns the order of the Top-N operator. This naming is because `order()`
88    /// already exists and it was designed to return the operator's physical property order.
89    ///
90    /// Note that for streaming query, `order()` and `topn_order()` may differ. `order()` which
91    /// implies the output ordering of an operator, is never guaranteed; while `topn_order()` must
92    /// be non-null because it's a critical information for Top-N operators to work
93    pub fn topn_order(&self) -> &Order {
94        &self.core.order
95    }
96
97    pub fn group_key(&self) -> &[usize] {
98        &self.core.group_key
99    }
100
101    /// decompose -> (input, limit, offset, `with_ties`, order, `group_key`)
102    pub fn decompose(self) -> (PlanRef, u64, u64, bool, Order, Vec<usize>) {
103        self.core.decompose()
104    }
105
106    fn gen_dist_stream_top_n_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
107        use super::stream::prelude::*;
108
109        let input_dist = stream_input.distribution().clone();
110
111        // if it is append only, for now we don't generate 2-phase rules
112        if stream_input.append_only() {
113            return self.gen_single_stream_top_n_plan(stream_input);
114        }
115
116        match input_dist {
117            Distribution::Single | Distribution::SomeShard => {
118                self.gen_single_stream_top_n_plan(stream_input)
119            }
120            Distribution::Broadcast => bail_not_implemented!("topN does not support Broadcast"),
121            Distribution::HashShard(dists) | Distribution::UpstreamHashShard(dists, _) => {
122                self.gen_vnode_two_phase_stream_top_n_plan(stream_input, &dists)
123            }
124        }
125    }
126
127    fn gen_single_stream_top_n_plan(&self, stream_input: StreamPlanRef) -> Result<StreamPlanRef> {
128        let input = RequiredDist::single().streaming_enforce_if_not_satisfies(stream_input)?;
129        let core = self.core.clone_with_input(input);
130        Ok(StreamTopN::new(core)?.into())
131    }
132
133    fn gen_vnode_two_phase_stream_top_n_plan(
134        &self,
135        stream_input: StreamPlanRef,
136        dist_key: &[usize],
137    ) -> Result<StreamPlanRef> {
138        // use projectiton to add a column for vnode, and use this column as group key.
139        let project = StreamProject::new(generic::Project::with_vnode_col(stream_input, dist_key));
140        let vnode_col_idx = project.base.schema().len() - 1;
141
142        let limit_attr = TopNLimit::new(
143            self.limit_attr().limit() + self.offset(),
144            self.limit_attr().with_ties(),
145        );
146        let local_top_n = generic::TopN::with_group(
147            project.into(),
148            limit_attr,
149            0,
150            self.topn_order().clone(),
151            vec![vnode_col_idx],
152        );
153        let local_top_n = StreamGroupTopN::new(local_top_n, Some(vnode_col_idx))?;
154
155        let exchange =
156            RequiredDist::single().streaming_enforce_if_not_satisfies(local_top_n.into())?;
157
158        let global_top_n = generic::TopN::without_group(
159            exchange,
160            self.limit_attr(),
161            self.offset(),
162            self.topn_order().clone(),
163        );
164        let global_top_n = StreamTopN::new(global_top_n)?;
165
166        // use another projection to remove the column we added before.
167        assert_eq!(vnode_col_idx, global_top_n.base.schema().len() - 1);
168        let project = StreamProject::new(generic::Project::with_out_col_idx(
169            global_top_n.into(),
170            0..vnode_col_idx,
171        ));
172        Ok(project.into())
173    }
174
175    pub fn clone_with_input_and_prefix(&self, input: PlanRef, prefix: Order) -> Self {
176        let mut core = self.core.clone();
177        core.input = input;
178        core.order = prefix.concat(core.order);
179        core.into()
180    }
181}
182
183impl PlanTreeNodeUnary<Logical> for LogicalTopN {
184    fn input(&self) -> PlanRef {
185        self.core.input.clone()
186    }
187
188    fn clone_with_input(&self, input: PlanRef) -> Self {
189        let mut core = self.core.clone();
190        core.input = input;
191        core.into()
192    }
193
194    fn rewrite_with_input(
195        &self,
196        input: PlanRef,
197        input_col_change: ColIndexMapping,
198    ) -> (Self, ColIndexMapping) {
199        let mut core = self.core.clone();
200        core.input = input;
201        core.order = input_col_change
202            .rewrite_required_order(self.topn_order())
203            .unwrap();
204        for key in &mut core.group_key {
205            *key = input_col_change.map(*key)
206        }
207        (core.into(), input_col_change)
208    }
209}
210impl_plan_tree_node_for_unary! { Logical, LogicalTopN}
211impl_distill_by_unit!(LogicalTopN, core, "LogicalTopN");
212
213impl ColPrunable for LogicalTopN {
214    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
215        let input_required_bitset = FixedBitSet::from_iter(required_cols.iter().copied());
216        let order_required_cols = {
217            let mut order_required_cols = FixedBitSet::with_capacity(self.input().schema().len());
218            self.topn_order()
219                .column_orders
220                .iter()
221                .for_each(|o| order_required_cols.insert(o.column_index));
222            order_required_cols
223        };
224        let group_required_cols = {
225            let mut group_required_cols = FixedBitSet::with_capacity(self.input().schema().len());
226            self.group_key()
227                .iter()
228                .for_each(|idx| group_required_cols.insert(*idx));
229            group_required_cols
230        };
231
232        let input_required_cols = {
233            let mut tmp = order_required_cols;
234            tmp.union_with(&input_required_bitset);
235            tmp.union_with(&group_required_cols);
236            tmp.ones().collect_vec()
237        };
238        let mapping = ColIndexMapping::with_remaining_columns(
239            &input_required_cols,
240            self.input().schema().len(),
241        );
242        let new_order = Order {
243            column_orders: self
244                .topn_order()
245                .column_orders
246                .iter()
247                .map(|o| ColumnOrder::new(mapping.map(o.column_index), o.order_type))
248                .collect(),
249        };
250        let new_group_key = self
251            .group_key()
252            .iter()
253            .map(|group_key| mapping.map(*group_key))
254            .collect();
255        let new_input = self.input().prune_col(&input_required_cols, ctx);
256        let top_n = Self::new(
257            new_input,
258            self.limit_attr().limit(),
259            self.offset(),
260            self.limit_attr().with_ties(),
261            new_order,
262            new_group_key,
263        )
264        .into();
265
266        if input_required_cols == required_cols {
267            top_n
268        } else {
269            let output_required_cols = required_cols
270                .iter()
271                .map(|&idx| mapping.map(idx))
272                .collect_vec();
273            let src_size = top_n.schema().len();
274            LogicalProject::with_mapping(
275                top_n,
276                ColIndexMapping::with_remaining_columns(&output_required_cols, src_size),
277            )
278            .into()
279        }
280    }
281}
282
283impl ExprRewritable<Logical> for LogicalTopN {}
284
285impl ExprVisitable for LogicalTopN {}
286
287impl PredicatePushdown for LogicalTopN {
288    fn predicate_pushdown(
289        &self,
290        predicate: Condition,
291        ctx: &mut PredicatePushdownContext,
292    ) -> PlanRef {
293        // filter can not transpose topN
294        gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx)
295    }
296}
297
298impl ToBatch for LogicalTopN {
299    fn to_batch(&self) -> Result<crate::optimizer::plan_node::BatchPlanRef> {
300        let new_input = self.input().to_batch()?;
301        let core = self.core.clone_with_input(new_input);
302        if self.group_key().is_empty() {
303            Ok(BatchTopN::new(core).into())
304        } else {
305            Ok(BatchGroupTopN::new(core).into())
306        }
307    }
308}
309
310impl ToStream for LogicalTopN {
311    fn to_stream(
312        &self,
313        ctx: &mut ToStreamContext,
314    ) -> Result<crate::optimizer::plan_node::StreamPlanRef> {
315        if self.offset() != 0 && self.limit_attr().limit() == LIMIT_ALL_COUNT {
316            return Err(RwError::from(ErrorCode::InvalidInputSyntax(
317                "OFFSET without LIMIT in streaming mode".to_owned(),
318            )));
319        }
320        if self.limit_attr().limit() == 0 {
321            return Err(RwError::from(ErrorCode::InvalidInputSyntax(
322                "LIMIT 0 in streaming mode".to_owned(),
323            )));
324        }
325        Ok(if !self.group_key().is_empty() {
326            let input = self.input().to_stream(ctx)?;
327            let input = RequiredDist::hash_shard(self.group_key())
328                .streaming_enforce_if_not_satisfies(input)?;
329            let core = self.core.clone_with_input(input);
330            StreamGroupTopN::new(core, None)?.into()
331        } else {
332            self.gen_dist_stream_top_n_plan(self.input().to_stream(ctx)?)?
333        })
334    }
335
336    fn logical_rewrite_for_stream(
337        &self,
338        ctx: &mut RewriteStreamContext,
339    ) -> Result<(PlanRef, ColIndexMapping)> {
340        let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
341        let (top_n, out_col_change) = self.rewrite_with_input(input, input_col_change);
342        Ok((top_n.into(), out_col_change))
343    }
344}
345
#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;

    use super::LogicalTopN;
    use crate::PlanRef;
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::{
        ColPrunable, ColumnPruningContext, LogicalPlanRef, LogicalValues,
    };
    use crate::optimizer::property::Order;

    /// Pruning with every column required must leave the group key indices unchanged.
    #[tokio::test]
    async fn test_prune_col() {
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = ["v1", "v2", "v3"]
            .into_iter()
            .map(|name| Field::with_name(DataType::Int32, name))
            .collect();
        let values_plan = PlanRef::from(LogicalValues::new(vec![], Schema { fields }, ctx));

        let top_n = LogicalTopN::new(values_plan, 1, 0, false, Order::default(), vec![1]);
        assert_eq!(top_n.group_key(), &[1]);

        let plan: LogicalPlanRef = top_n.into();
        let pruning_ctx = &mut ColumnPruningContext::new(plan.clone());
        let pruned = plan.prune_col(&[0, 1, 2], pruning_ctx);

        let pruned_top_n = pruned.as_logical_top_n().unwrap();
        assert_eq!(pruned_top_n.group_key(), &[1]);
    }
}