risingwave_batch_executors/executor/
top_n.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::Ordering;
16use std::sync::Arc;
17
18use futures_async_stream::try_stream;
19use risingwave_common::array::DataChunk;
20use risingwave_common::catalog::Schema;
21use risingwave_common::memory::{MemMonitoredHeap, MemoryContext};
22use risingwave_common::row::{OwnedRow, Row};
23use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
24use risingwave_common::util::memcmp_encoding::{MemcmpEncoded, encode_chunk};
25use risingwave_common::util::sort_util::ColumnOrder;
26use risingwave_common_estimate_size::EstimateSize;
27use risingwave_pb::batch_plan::plan_node::NodeBody;
28
29use crate::error::{BatchError, Result};
30use crate::executor::{
31    BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
32};
33
/// Top-N Executor
///
/// Use a N-heap to store the smallest N rows.
///
/// Every row produced by `child` is pushed into a [`TopNHeap`] keyed by the row's
/// encoding under `column_orders`. Once the input is exhausted, the retained rows are
/// emitted in chunks of at most `chunk_size` rows.
pub struct TopNExecutor {
    /// Input executor whose rows are ranked.
    child: BoxedExecutor,
    /// Sort specification handed to `encode_chunk` to build each row's comparison key.
    column_orders: Vec<ColumnOrder>,
    /// Number of leading rows (in sort order) to skip before emitting.
    offset: usize,
    /// Number of rows to emit after the skipped ones; `0` yields an empty stream.
    limit: usize,
    /// Whether rows that tie with the last retained row are kept as well.
    with_ties: bool,
    /// Output schema, cloned from `child` at construction time.
    schema: Schema,
    /// Name reported by `Executor::identity`.
    identity: String,
    /// Row capacity passed to the `DataChunkBuilder` that assembles the output chunks.
    chunk_size: usize,
    /// Memory context handed to the heap that buffers candidate rows.
    mem_ctx: MemoryContext,
}
48
49impl BoxedExecutorBuilder for TopNExecutor {
50    async fn new_boxed_executor(
51        source: &ExecutorBuilder<'_>,
52        inputs: Vec<BoxedExecutor>,
53    ) -> Result<BoxedExecutor> {
54        let [child]: [_; 1] = inputs.try_into().unwrap();
55
56        let top_n_node =
57            try_match_expand!(source.plan_node().get_node_body().unwrap(), NodeBody::TopN)?;
58
59        let column_orders = top_n_node
60            .column_orders
61            .iter()
62            .map(ColumnOrder::from_protobuf)
63            .collect();
64
65        let identity = source.plan_node().get_identity();
66
67        Ok(Box::new(Self::new(
68            child,
69            column_orders,
70            top_n_node.get_offset() as usize,
71            top_n_node.get_limit() as usize,
72            top_n_node.get_with_ties(),
73            identity.clone(),
74            source.context().get_config().developer.chunk_size,
75            source.context().create_executor_mem_context(identity),
76        )))
77    }
78}
79
80impl TopNExecutor {
81    pub fn new(
82        child: BoxedExecutor,
83        column_orders: Vec<ColumnOrder>,
84        offset: usize,
85        limit: usize,
86        with_ties: bool,
87        identity: String,
88        chunk_size: usize,
89        mem_ctx: MemoryContext,
90    ) -> Self {
91        let schema = child.schema().clone();
92        Self {
93            child,
94            column_orders,
95            offset,
96            limit,
97            with_ties,
98            schema,
99            identity,
100            chunk_size,
101            mem_ctx,
102        }
103    }
104}
105
106impl Executor for TopNExecutor {
107    fn schema(&self) -> &Schema {
108        &self.schema
109    }
110
111    fn identity(&self) -> &str {
112        &self.identity
113    }
114
115    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
116        self.do_execute()
117    }
118}
119
/// Upper bound on the capacity requested up front by `TopNHeap::new`, which asks for
/// `min(limit + offset, MAX_TOPN_INIT_HEAP_CAPACITY)` slots.
pub const MAX_TOPN_INIT_HEAP_CAPACITY: usize = 1024;
121
/// A max-heap used to find the smallest `limit+offset` items.
///
/// Elements are ordered by their encoded sort key, so the top of the heap is the largest
/// element currently retained.
pub struct TopNHeap {
    /// Backing heap of candidate rows.
    heap: MemMonitoredHeap<HeapElem>,
    /// Number of elements to return after skipping `offset`.
    limit: usize,
    /// Number of smallest elements skipped by `dump`.
    offset: usize,
    /// When set, elements equal to the current largest retained element are kept too.
    with_ties: bool,
}
129
130impl TopNHeap {
131    pub fn new(limit: usize, offset: usize, with_ties: bool, mem_ctx: MemoryContext) -> Self {
132        assert!(limit > 0);
133        Self {
134            heap: MemMonitoredHeap::with_capacity(
135                (limit + offset).min(MAX_TOPN_INIT_HEAP_CAPACITY),
136                mem_ctx,
137            ),
138            limit,
139            offset,
140            with_ties,
141        }
142    }
143
144    // Only used for swapping out the heap in hashmap, due to a bug in hashmap which forbids us from
145    // using `into_iter`. We should remove this after Hashmap upgraded and fixed the bug.
146    pub fn empty() -> Self {
147        Self {
148            heap: MemMonitoredHeap::with_capacity(0, MemoryContext::none()),
149            limit: 0,
150            offset: 0,
151            with_ties: false,
152        }
153    }
154
155    pub fn push(&mut self, elem: HeapElem) {
156        if self.heap.len() < self.limit + self.offset {
157            self.heap.push(elem);
158        } else {
159            // heap is full
160            if !self.with_ties {
161                let peek = self.heap.pop().unwrap();
162                if elem < peek {
163                    self.heap.push(elem);
164                } else {
165                    self.heap.push(peek);
166                }
167                // let inner = self.heap.inner();
168                // let mut peek = inner.peek_mut().unwrap();
169                // if elem < *peek {
170                //     *peek = elem;
171                // }
172            } else {
173                let peek = self.heap.peek().unwrap().clone();
174                match elem.cmp(&peek) {
175                    Ordering::Less => {
176                        let mut ties_with_peek = vec![];
177                        // pop all the ties with peek
178                        ties_with_peek.push(self.heap.pop().unwrap());
179                        while let Some(e) = self.heap.peek()
180                            && e.encoded_row == peek.encoded_row
181                        {
182                            ties_with_peek.push(self.heap.pop().unwrap());
183                        }
184                        self.heap.push(elem);
185                        // If the size is smaller than limit, we can push all the elements back.
186                        if self.heap.len() < self.limit {
187                            self.heap.extend(ties_with_peek);
188                        }
189                    }
190                    Ordering::Equal => {
191                        // It's a tie.
192                        self.heap.push(elem);
193                    }
194                    Ordering::Greater => {}
195                }
196            }
197        }
198    }
199
200    /// Returns the elements in the range `[offset, offset+limit)`.
201    pub fn dump(self) -> impl Iterator<Item = HeapElem> {
202        self.heap
203            .into_sorted_vec()
204            .into_iter()
205            .rev()
206            .skip(self.offset)
207    }
208}
209
/// A row paired with its encoded sort key.
///
/// Equality and ordering of `HeapElem` consider `encoded_row` only; `row` is payload.
#[derive(Clone, EstimateSize)]
pub struct HeapElem {
    /// Encoded sort key; the sole input to comparisons between elements.
    encoded_row: MemcmpEncoded,
    /// The full row carried along with the key.
    row: OwnedRow,
}
215
216impl PartialEq for HeapElem {
217    fn eq(&self, other: &Self) -> bool {
218        self.encoded_row.eq(&other.encoded_row)
219    }
220}
221
222impl Eq for HeapElem {}
223
224impl PartialOrd for HeapElem {
225    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
226        Some(self.cmp(other))
227    }
228}
229
230impl Ord for HeapElem {
231    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
232        self.encoded_row.cmp(&other.encoded_row)
233    }
234}
235
236impl HeapElem {
237    pub fn new(encoded_row: MemcmpEncoded, row: impl Row) -> Self {
238        Self {
239            encoded_row,
240            row: row.into_owned_row(),
241        }
242    }
243
244    pub fn row(&self) -> impl Row + '_ {
245        &self.row
246    }
247}
248
249impl TopNExecutor {
250    #[try_stream(boxed, ok = DataChunk, error = BatchError)]
251    async fn do_execute(self: Box<Self>) {
252        if self.limit == 0 {
253            return Ok(());
254        }
255        let mut heap = TopNHeap::new(
256            self.limit,
257            self.offset,
258            self.with_ties,
259            self.mem_ctx.clone(),
260        );
261
262        #[for_await]
263        for chunk in self.child.execute() {
264            let chunk = Arc::new(chunk?.compact());
265            for (row_id, encoded_row) in encode_chunk(&chunk, &self.column_orders)?
266                .into_iter()
267                .enumerate()
268            {
269                heap.push(HeapElem {
270                    encoded_row,
271                    row: chunk.row_at(row_id).0.to_owned_row(),
272                });
273            }
274        }
275
276        let mut chunk_builder = DataChunkBuilder::new(self.schema.data_types(), self.chunk_size);
277        for HeapElem { row, .. } in heap.dump() {
278            if let Some(spilled) = chunk_builder.append_one_row(row) {
279                yield spilled
280            }
281        }
282        if let Some(spilled) = chunk_builder.consume_all() {
283            yield spilled
284        }
285    }
286}
287
#[cfg(test)]
mod tests {
    use futures::stream::StreamExt;
    use itertools::Itertools;
    use risingwave_common::array::Array;
    use risingwave_common::catalog::Field;
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::DataType;
    use risingwave_common::util::sort_util::OrderType;

    use super::*;
    use crate::executor::test_utils::MockExecutor;

    const CHUNK_SIZE: usize = 1024;

    /// Builds a `TopNExecutor` (without ties) over a fixed five-row, two-column `Int32`
    /// input, ordered by column 1 and then column 0, both ascending.
    fn top_n_over_fixture(offset: usize, limit: usize) -> Box<TopNExecutor> {
        let schema = Schema {
            fields: vec![
                Field::unnamed(DataType::Int32),
                Field::unnamed(DataType::Int32),
            ],
        };
        let mut input = MockExecutor::new(schema);
        input.add(DataChunk::from_pretty(
            "i i
             1 5
             2 4
             3 3
             4 2
             5 1",
        ));

        let ascending_on = |column_index| ColumnOrder {
            column_index,
            order_type: OrderType::ascending(),
        };
        let column_orders = vec![ascending_on(1), ascending_on(0)];

        Box::new(TopNExecutor::new(
            Box::new(input),
            column_orders,
            offset,
            limit,
            false,
            "TopNExecutor".to_owned(),
            CHUNK_SIZE,
            MemoryContext::none(),
        ))
    }

    /// Checks that the executor exposes the two `Int32` columns of its input.
    fn assert_two_int32_columns(executor: &TopNExecutor) {
        let fields = &executor.schema().fields;
        assert_eq!(fields[0].data_type, DataType::Int32);
        assert_eq!(fields[1].data_type, DataType::Int32);
    }

    #[tokio::test]
    async fn test_simple_top_n_executor() {
        // offset = 1, limit = 3
        let top_n_executor = top_n_over_fixture(1, 3);
        assert_two_int32_columns(&top_n_executor);

        let mut stream = top_n_executor.execute();

        let first = stream.next().await;
        assert!(first.is_some());
        let chunk = first.unwrap().unwrap();
        assert_eq!(chunk.cardinality(), 3);
        assert_eq!(
            chunk.column_at(0).as_int32().iter().collect_vec(),
            vec![Some(4), Some(3), Some(2)]
        );

        // All three rows fit in a single chunk, so the stream ends here.
        let second = stream.next().await;
        assert!(second.is_none());
    }

    #[tokio::test]
    async fn test_limit_0() {
        // offset = 1, limit = 0
        let top_n_executor = top_n_over_fixture(1, 0);
        assert_two_int32_columns(&top_n_executor);

        let mut stream = top_n_executor.execute();
        let first = stream.next().await;

        assert!(first.is_none());
    }
}
407}