// risingwave_batch_executors/executor/top_n.rs
use std::cmp::Ordering;
16use std::sync::Arc;
17
18use futures_async_stream::try_stream;
19use risingwave_common::array::DataChunk;
20use risingwave_common::catalog::Schema;
21use risingwave_common::memory::{MemMonitoredHeap, MemoryContext};
22use risingwave_common::row::{OwnedRow, Row};
23use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
24use risingwave_common::util::memcmp_encoding::{MemcmpEncoded, encode_chunk};
25use risingwave_common::util::sort_util::ColumnOrder;
26use risingwave_common_estimate_size::EstimateSize;
27use risingwave_pb::batch_plan::plan_node::NodeBody;
28
29use crate::error::{BatchError, Result};
30use crate::executor::{
31 BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder,
32};
33
34pub struct TopNExecutor {
38 child: BoxedExecutor,
39 column_orders: Vec<ColumnOrder>,
40 offset: usize,
41 limit: usize,
42 with_ties: bool,
43 schema: Schema,
44 identity: String,
45 chunk_size: usize,
46 mem_ctx: MemoryContext,
47}
48
49impl BoxedExecutorBuilder for TopNExecutor {
50 async fn new_boxed_executor(
51 source: &ExecutorBuilder<'_>,
52 inputs: Vec<BoxedExecutor>,
53 ) -> Result<BoxedExecutor> {
54 let [child]: [_; 1] = inputs.try_into().unwrap();
55
56 let top_n_node =
57 try_match_expand!(source.plan_node().get_node_body().unwrap(), NodeBody::TopN)?;
58
59 let column_orders = top_n_node
60 .column_orders
61 .iter()
62 .map(ColumnOrder::from_protobuf)
63 .collect();
64
65 let identity = source.plan_node().get_identity();
66
67 Ok(Box::new(Self::new(
68 child,
69 column_orders,
70 top_n_node.get_offset() as usize,
71 top_n_node.get_limit() as usize,
72 top_n_node.get_with_ties(),
73 identity.clone(),
74 source.context().get_config().developer.chunk_size,
75 source.context().create_executor_mem_context(identity),
76 )))
77 }
78}
79
80impl TopNExecutor {
81 pub fn new(
82 child: BoxedExecutor,
83 column_orders: Vec<ColumnOrder>,
84 offset: usize,
85 limit: usize,
86 with_ties: bool,
87 identity: String,
88 chunk_size: usize,
89 mem_ctx: MemoryContext,
90 ) -> Self {
91 let schema = child.schema().clone();
92 Self {
93 child,
94 column_orders,
95 offset,
96 limit,
97 with_ties,
98 schema,
99 identity,
100 chunk_size,
101 mem_ctx,
102 }
103 }
104}
105
106impl Executor for TopNExecutor {
107 fn schema(&self) -> &Schema {
108 &self.schema
109 }
110
111 fn identity(&self) -> &str {
112 &self.identity
113 }
114
115 fn execute(self: Box<Self>) -> BoxedDataChunkStream {
116 self.do_execute()
117 }
118}
119
/// Upper bound on the capacity pre-allocated for a [`TopNHeap`]; `TopNHeap::new` reserves
/// `min(limit + offset, MAX_TOPN_INIT_HEAP_CAPACITY)` slots up front.
pub const MAX_TOPN_INIT_HEAP_CAPACITY: usize = 1 << 10;
121
122pub struct TopNHeap {
124 heap: MemMonitoredHeap<HeapElem>,
125 limit: usize,
126 offset: usize,
127 with_ties: bool,
128}
129
130impl TopNHeap {
131 pub fn new(limit: usize, offset: usize, with_ties: bool, mem_ctx: MemoryContext) -> Self {
132 assert!(limit > 0);
133 Self {
134 heap: MemMonitoredHeap::with_capacity(
135 (limit + offset).min(MAX_TOPN_INIT_HEAP_CAPACITY),
136 mem_ctx,
137 ),
138 limit,
139 offset,
140 with_ties,
141 }
142 }
143
144 pub fn empty() -> Self {
147 Self {
148 heap: MemMonitoredHeap::with_capacity(0, MemoryContext::none()),
149 limit: 0,
150 offset: 0,
151 with_ties: false,
152 }
153 }
154
155 pub fn push(&mut self, elem: HeapElem) {
156 if self.heap.len() < self.limit + self.offset {
157 self.heap.push(elem);
158 } else {
159 if !self.with_ties {
161 let peek = self.heap.pop().unwrap();
162 if elem < peek {
163 self.heap.push(elem);
164 } else {
165 self.heap.push(peek);
166 }
167 } else {
173 let peek = self.heap.peek().unwrap().clone();
174 match elem.cmp(&peek) {
175 Ordering::Less => {
176 let mut ties_with_peek = vec![];
177 ties_with_peek.push(self.heap.pop().unwrap());
179 while let Some(e) = self.heap.peek()
180 && e.encoded_row == peek.encoded_row
181 {
182 ties_with_peek.push(self.heap.pop().unwrap());
183 }
184 self.heap.push(elem);
185 if self.heap.len() < self.limit {
187 self.heap.extend(ties_with_peek);
188 }
189 }
190 Ordering::Equal => {
191 self.heap.push(elem);
193 }
194 Ordering::Greater => {}
195 }
196 }
197 }
198 }
199
200 pub fn dump(self) -> impl Iterator<Item = HeapElem> {
202 self.heap
203 .into_sorted_vec()
204 .into_iter()
205 .rev()
206 .skip(self.offset)
207 }
208}
209
210#[derive(Clone, EstimateSize)]
211pub struct HeapElem {
212 encoded_row: MemcmpEncoded,
213 row: OwnedRow,
214}
215
216impl PartialEq for HeapElem {
217 fn eq(&self, other: &Self) -> bool {
218 self.encoded_row.eq(&other.encoded_row)
219 }
220}
221
222impl Eq for HeapElem {}
223
224impl PartialOrd for HeapElem {
225 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
226 Some(self.cmp(other))
227 }
228}
229
230impl Ord for HeapElem {
231 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
232 self.encoded_row.cmp(&other.encoded_row)
233 }
234}
235
236impl HeapElem {
237 pub fn new(encoded_row: MemcmpEncoded, row: impl Row) -> Self {
238 Self {
239 encoded_row,
240 row: row.into_owned_row(),
241 }
242 }
243
244 pub fn row(&self) -> impl Row + '_ {
245 &self.row
246 }
247}
248
249impl TopNExecutor {
250 #[try_stream(boxed, ok = DataChunk, error = BatchError)]
251 async fn do_execute(self: Box<Self>) {
252 if self.limit == 0 {
253 return Ok(());
254 }
255 let mut heap = TopNHeap::new(
256 self.limit,
257 self.offset,
258 self.with_ties,
259 self.mem_ctx.clone(),
260 );
261
262 #[for_await]
263 for chunk in self.child.execute() {
264 let chunk = Arc::new(chunk?.compact());
265 for (row_id, encoded_row) in encode_chunk(&chunk, &self.column_orders)?
266 .into_iter()
267 .enumerate()
268 {
269 heap.push(HeapElem {
270 encoded_row,
271 row: chunk.row_at(row_id).0.to_owned_row(),
272 });
273 }
274 }
275
276 let mut chunk_builder = DataChunkBuilder::new(self.schema.data_types(), self.chunk_size);
277 for HeapElem { row, .. } in heap.dump() {
278 if let Some(spilled) = chunk_builder.append_one_row(row) {
279 yield spilled
280 }
281 }
282 if let Some(spilled) = chunk_builder.consume_all() {
283 yield spilled
284 }
285 }
286}
287
#[cfg(test)]
mod tests {
    use futures::stream::StreamExt;
    use itertools::Itertools;
    use risingwave_common::array::Array;
    use risingwave_common::catalog::Field;
    use risingwave_common::test_prelude::DataChunkTestExt;
    use risingwave_common::types::DataType;
    use risingwave_common::util::sort_util::OrderType;

    use super::*;
    use crate::executor::test_utils::MockExecutor;

    const CHUNK_SIZE: usize = 1024;

    /// Builds a `TopNExecutor` (without ties) over one five-row chunk of two `Int32`
    /// columns, ordered by column 1 and then column 0, both ascending.
    fn top_n_over_five_rows(offset: usize, limit: usize) -> Box<TopNExecutor> {
        let schema = Schema {
            fields: vec![
                Field::unnamed(DataType::Int32),
                Field::unnamed(DataType::Int32),
            ],
        };
        let mut mock_executor = MockExecutor::new(schema);
        mock_executor.add(DataChunk::from_pretty(
            "i i
             1 5
             2 4
             3 3
             4 2
             5 1",
        ));
        let column_orders = vec![
            ColumnOrder {
                column_index: 1,
                order_type: OrderType::ascending(),
            },
            ColumnOrder {
                column_index: 0,
                order_type: OrderType::ascending(),
            },
        ];
        Box::new(TopNExecutor::new(
            Box::new(mock_executor),
            column_orders,
            offset,
            limit,
            false,
            "TopNExecutor".to_owned(),
            CHUNK_SIZE,
            MemoryContext::none(),
        ))
    }

    /// Checks that the executor exposes the two `Int32` columns of its child.
    fn assert_two_int32_fields(executor: &TopNExecutor) {
        let fields = &executor.schema().fields;
        assert_eq!(fields[0].data_type, DataType::Int32);
        assert_eq!(fields[1].data_type, DataType::Int32);
    }

    #[tokio::test]
    async fn test_simple_top_n_executor() {
        // OFFSET 1 LIMIT 3 over the five rows.
        let top_n_executor = top_n_over_five_rows(1, 3);
        assert_two_int32_fields(&top_n_executor);

        let mut stream = top_n_executor.execute();

        let first = stream.next().await;
        assert!(first.is_some());
        let chunk = first.unwrap().unwrap();
        assert_eq!(chunk.cardinality(), 3);
        assert_eq!(
            chunk.column_at(0).as_int32().iter().collect_vec(),
            vec![Some(4), Some(3), Some(2)]
        );

        // Everything fits into one chunk, so the stream must be exhausted now.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_limit_0() {
        // OFFSET 1 LIMIT 0 must produce no chunk at all.
        let top_n_executor = top_n_over_five_rows(1, 0);
        assert_two_int32_fields(&top_n_executor);

        let mut stream = top_n_executor.execute();
        assert!(stream.next().await.is_none());
    }
}