risingwave_frontend/optimizer/plan_node/
logical_topn.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use fixedbitset::FixedBitSet;
16use itertools::Itertools;
17use risingwave_common::bail_not_implemented;
18use risingwave_common::util::sort_util::ColumnOrder;
19
20use super::generic::{GenericPlanRef, TopNLimit};
21use super::utils::impl_distill_by_unit;
22use super::{
23    BatchGroupTopN, ColPrunable, ExprRewritable, Logical, PlanBase, PlanRef, PlanTreeNodeUnary,
24    PredicatePushdown, StreamGroupTopN, StreamProject, ToBatch, ToStream, gen_filter_and_pushdown,
25    generic,
26};
27use crate::error::{ErrorCode, Result, RwError};
28use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
29use crate::optimizer::plan_node::{
30    BatchTopN, ColumnPruningContext, LogicalProject, PredicatePushdownContext,
31    RewriteStreamContext, StreamTopN, ToStreamContext,
32};
33use crate::optimizer::property::{Distribution, Order, RequiredDist};
34use crate::planner::LIMIT_ALL_COUNT;
35use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, Condition};
36
/// `LogicalTopN` sorts the input data and fetches up to `limit` rows from `offset`.
///
/// With a non-empty group key it is lowered to the group Top-N operators
/// (`BatchGroupTopN` / `StreamGroupTopN`); otherwise to the plain Top-N operators.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LogicalTopN {
    /// Logical plan-node base, derived from `core` via `PlanBase::new_logical_with_core`.
    pub base: PlanBase<Logical>,
    /// Generic Top-N description: input, limit attribute, offset, order and group key.
    core: generic::TopN<PlanRef>,
}
43
44impl From<generic::TopN<PlanRef>> for LogicalTopN {
45    fn from(core: generic::TopN<PlanRef>) -> Self {
46        let base = PlanBase::new_logical_with_core(&core);
47        Self { base, core }
48    }
49}
50
51impl LogicalTopN {
52    pub fn new(
53        input: PlanRef,
54        limit: u64,
55        offset: u64,
56        with_ties: bool,
57        order: Order,
58        group_key: Vec<usize>,
59    ) -> Self {
60        let limit_attr = TopNLimit::new(limit, with_ties);
61        let core = generic::TopN::with_group(input, limit_attr, offset, order, group_key);
62        core.into()
63    }
64
65    pub fn create(
66        input: PlanRef,
67        limit: u64,
68        offset: u64,
69        order: Order,
70        with_ties: bool,
71        group_key: Vec<usize>,
72    ) -> Result<PlanRef> {
73        if with_ties && offset > 0 {
74            bail_not_implemented!("WITH TIES is not supported with OFFSET");
75        }
76        Ok(Self::new(input, limit, offset, with_ties, order, group_key).into())
77    }
78
    /// Returns the limit attribute of this Top-N, i.e. the `TopNLimit` built from the row
    /// limit and the `with_ties` flag.
    pub fn limit_attr(&self) -> TopNLimit {
        self.core.limit_attr
    }
82
    /// Returns the offset stored in the Top-N core.
    pub fn offset(&self) -> u64 {
        self.core.offset
    }
86
    /// `topn_order` returns the order of the Top-N operator. It is named this way because
    /// `order()` already exists and was designed to return the operator's physical property
    /// order.
    ///
    /// Note that for streaming queries, `order()` and `topn_order()` may differ. `order()`,
    /// which implies the output ordering of an operator, is never guaranteed, while
    /// `topn_order()` must be non-null because it is critical information for Top-N operators
    /// to work.
    pub fn topn_order(&self) -> &Order {
        &self.core.order
    }
96
    /// Returns the group-key column indices; an empty slice means an ungrouped Top-N.
    pub fn group_key(&self) -> &[usize] {
        &self.core.group_key
    }
100
    /// Consumes the node and returns its parts by delegating to the core's `decompose`:
    /// decompose -> (input, limit, offset, `with_ties`, order, `group_key`)
    pub fn decompose(self) -> (PlanRef, u64, u64, bool, Order, Vec<usize>) {
        self.core.decompose()
    }
105
106    fn gen_dist_stream_top_n_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
107        use super::stream::prelude::*;
108
109        let input_dist = stream_input.distribution().clone();
110
111        // if it is append only, for now we don't generate 2-phase rules
112        if stream_input.append_only() {
113            return self.gen_single_stream_top_n_plan(stream_input);
114        }
115
116        match input_dist {
117            Distribution::Single | Distribution::SomeShard => {
118                self.gen_single_stream_top_n_plan(stream_input)
119            }
120            Distribution::Broadcast => bail_not_implemented!("topN does not support Broadcast"),
121            Distribution::HashShard(dists) | Distribution::UpstreamHashShard(dists, _) => {
122                self.gen_vnode_two_phase_stream_top_n_plan(stream_input, &dists)
123            }
124        }
125    }
126
127    fn gen_single_stream_top_n_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
128        let input = RequiredDist::single().enforce_if_not_satisfies(stream_input, &Order::any())?;
129        let mut core = self.core.clone();
130        core.input = input;
131        Ok(StreamTopN::new(core).into())
132    }
133
134    fn gen_vnode_two_phase_stream_top_n_plan(
135        &self,
136        stream_input: PlanRef,
137        dist_key: &[usize],
138    ) -> Result<PlanRef> {
139        // use projectiton to add a column for vnode, and use this column as group key.
140        let project = StreamProject::new(generic::Project::with_vnode_col(stream_input, dist_key));
141        let vnode_col_idx = project.base.schema().len() - 1;
142
143        let limit_attr = TopNLimit::new(
144            self.limit_attr().limit() + self.offset(),
145            self.limit_attr().with_ties(),
146        );
147        let local_top_n = generic::TopN::with_group(
148            project.into(),
149            limit_attr,
150            0,
151            self.topn_order().clone(),
152            vec![vnode_col_idx],
153        );
154        let local_top_n = StreamGroupTopN::new(local_top_n, Some(vnode_col_idx));
155
156        let exchange =
157            RequiredDist::single().enforce_if_not_satisfies(local_top_n.into(), &Order::any())?;
158
159        let global_top_n = generic::TopN::without_group(
160            exchange,
161            self.limit_attr(),
162            self.offset(),
163            self.topn_order().clone(),
164        );
165        let global_top_n = StreamTopN::new(global_top_n);
166
167        // use another projection to remove the column we added before.
168        assert_eq!(vnode_col_idx, global_top_n.base.schema().len() - 1);
169        let project = StreamProject::new(generic::Project::with_out_col_idx(
170            global_top_n.into(),
171            0..vnode_col_idx,
172        ));
173        Ok(project.into())
174    }
175
176    pub fn clone_with_input_and_prefix(&self, input: PlanRef, prefix: Order) -> Self {
177        let mut core = self.core.clone();
178        core.input = input;
179        core.order = prefix.concat(core.order);
180        core.into()
181    }
182}
183
184impl PlanTreeNodeUnary for LogicalTopN {
185    fn input(&self) -> PlanRef {
186        self.core.input.clone()
187    }
188
189    fn clone_with_input(&self, input: PlanRef) -> Self {
190        let mut core = self.core.clone();
191        core.input = input;
192        core.into()
193    }
194
195    fn rewrite_with_input(
196        &self,
197        input: PlanRef,
198        input_col_change: ColIndexMapping,
199    ) -> (Self, ColIndexMapping) {
200        let mut core = self.core.clone();
201        core.input = input;
202        core.order = input_col_change
203            .rewrite_required_order(self.topn_order())
204            .unwrap();
205        for key in &mut core.group_key {
206            *key = input_col_change.map(*key)
207        }
208        (core.into(), input_col_change)
209    }
210}
// Presumably derives the generic plan-tree-node plumbing for `LogicalTopN` from its
// `PlanTreeNodeUnary` impl. NOTE(review): the macro is defined elsewhere; confirm what it generates.
impl_plan_tree_node_for_unary! {LogicalTopN}
// Presumably implements distill (explain output) for `LogicalTopN` by delegating to its `core`
// field under the display name "LogicalTopN". NOTE(review): confirm against the macro definition.
impl_distill_by_unit!(LogicalTopN, core, "LogicalTopN");
213
214impl ColPrunable for LogicalTopN {
215    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
216        let input_required_bitset = FixedBitSet::from_iter(required_cols.iter().copied());
217        let order_required_cols = {
218            let mut order_required_cols = FixedBitSet::with_capacity(self.input().schema().len());
219            self.topn_order()
220                .column_orders
221                .iter()
222                .for_each(|o| order_required_cols.insert(o.column_index));
223            order_required_cols
224        };
225        let group_required_cols = {
226            let mut group_required_cols = FixedBitSet::with_capacity(self.input().schema().len());
227            self.group_key()
228                .iter()
229                .for_each(|idx| group_required_cols.insert(*idx));
230            group_required_cols
231        };
232
233        let input_required_cols = {
234            let mut tmp = order_required_cols;
235            tmp.union_with(&input_required_bitset);
236            tmp.union_with(&group_required_cols);
237            tmp.ones().collect_vec()
238        };
239        let mapping = ColIndexMapping::with_remaining_columns(
240            &input_required_cols,
241            self.input().schema().len(),
242        );
243        let new_order = Order {
244            column_orders: self
245                .topn_order()
246                .column_orders
247                .iter()
248                .map(|o| ColumnOrder::new(mapping.map(o.column_index), o.order_type))
249                .collect(),
250        };
251        let new_group_key = self
252            .group_key()
253            .iter()
254            .map(|group_key| mapping.map(*group_key))
255            .collect();
256        let new_input = self.input().prune_col(&input_required_cols, ctx);
257        let top_n = Self::new(
258            new_input,
259            self.limit_attr().limit(),
260            self.offset(),
261            self.limit_attr().with_ties(),
262            new_order,
263            new_group_key,
264        )
265        .into();
266
267        if input_required_cols == required_cols {
268            top_n
269        } else {
270            let output_required_cols = required_cols
271                .iter()
272                .map(|&idx| mapping.map(idx))
273                .collect_vec();
274            let src_size = top_n.schema().len();
275            LogicalProject::with_mapping(
276                top_n,
277                ColIndexMapping::with_remaining_columns(&output_required_cols, src_size),
278            )
279            .into()
280        }
281    }
282}
283
// Empty impl: relies entirely on the trait's provided methods. Presumably this is because the
// Top-N core holds only column indices and no expressions — NOTE(review): confirm the defaults.
impl ExprRewritable for LogicalTopN {}
285
// Empty impl: relies entirely on the trait's provided methods. Presumably there are no
// expressions to visit in a Top-N node — NOTE(review): confirm against the trait defaults.
impl ExprVisitable for LogicalTopN {}
287
288impl PredicatePushdown for LogicalTopN {
289    fn predicate_pushdown(
290        &self,
291        predicate: Condition,
292        ctx: &mut PredicatePushdownContext,
293    ) -> PlanRef {
294        // filter can not transpose topN
295        gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx)
296    }
297}
298
299impl ToBatch for LogicalTopN {
300    fn to_batch(&self) -> Result<PlanRef> {
301        let new_input = self.input().to_batch()?;
302        let mut new_logical = self.core.clone();
303        new_logical.input = new_input;
304        if self.group_key().is_empty() {
305            Ok(BatchTopN::new(new_logical).into())
306        } else {
307            Ok(BatchGroupTopN::new(new_logical).into())
308        }
309    }
310}
311
312impl ToStream for LogicalTopN {
313    fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
314        if self.offset() != 0 && self.limit_attr().limit() == LIMIT_ALL_COUNT {
315            return Err(RwError::from(ErrorCode::InvalidInputSyntax(
316                "OFFSET without LIMIT in streaming mode".to_owned(),
317            )));
318        }
319        if self.limit_attr().limit() == 0 {
320            return Err(RwError::from(ErrorCode::InvalidInputSyntax(
321                "LIMIT 0 in streaming mode".to_owned(),
322            )));
323        }
324        Ok(if !self.group_key().is_empty() {
325            let input = self.input().to_stream(ctx)?;
326            let input = RequiredDist::hash_shard(self.group_key())
327                .enforce_if_not_satisfies(input, &Order::any())?;
328            let mut core = self.core.clone();
329            core.input = input;
330            StreamGroupTopN::new(core, None).into()
331        } else {
332            self.gen_dist_stream_top_n_plan(self.input().to_stream(ctx)?)?
333        })
334    }
335
336    fn logical_rewrite_for_stream(
337        &self,
338        ctx: &mut RewriteStreamContext,
339    ) -> Result<(PlanRef, ColIndexMapping)> {
340        let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?;
341        let (top_n, out_col_change) = self.rewrite_with_input(input, input_col_change);
342        Ok((top_n.into(), out_col_change))
343    }
344}
345
#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;

    use super::LogicalTopN;
    use crate::PlanRef;
    use crate::optimizer::optimizer_context::OptimizerContext;
    use crate::optimizer::plan_node::{ColPrunable, ColumnPruningContext, LogicalValues};
    use crate::optimizer::property::Order;

    /// Pruning with every input column required must keep the Top-N node at the root and
    /// leave its group key untouched.
    #[tokio::test]
    async fn test_prune_col() {
        let ctx = OptimizerContext::mock().await;
        let fields: Vec<Field> = ["v1", "v2", "v3"]
            .into_iter()
            .map(|name| Field::with_name(DataType::Int32, name))
            .collect();
        let input: PlanRef = LogicalValues::new(vec![], Schema { fields }, ctx).into();

        // Grouped Top-N: limit 1, offset 0, no ties, grouped by column 1.
        let top_n = LogicalTopN::new(input, 1, 0, false, Order::default(), vec![1]);
        assert_eq!(top_n.group_key(), &[1]);

        let plan: PlanRef = top_n.into();
        let mut prune_ctx = ColumnPruningContext::new(plan.clone());
        let pruned_plan = plan.prune_col(&[0, 1, 2], &mut prune_ctx);

        let pruned_top_n = pruned_plan.as_logical_top_n().unwrap();
        assert_eq!(pruned_top_n.group_key(), &[1]);
    }
}