risingwave_stream/executor/top_n/group_top_n.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::ops::{Deref, DerefMut};
17
18use risingwave_common::array::Op;
19use risingwave_common::hash::HashKey;
20use risingwave_common::row::{RowDeserializer, RowExt};
21use risingwave_common::util::epoch::EpochPair;
22use risingwave_common::util::iter_util::ZipEqDebug;
23use risingwave_common::util::sort_util::ColumnOrder;
24
25use super::top_n_cache::TopNCacheTrait;
26use super::utils::*;
27use super::{ManagedTopNState, TopNCache};
28use crate::cache::ManagedLruCache;
29use crate::common::metrics::MetricsInfo;
30use crate::common::table::state_table::StateTablePostCommit;
31use crate::executor::monitor::GroupTopNMetrics;
32use crate::executor::prelude::*;
33use crate::executor::top_n::top_n_cache::TopNStaging;
34
35pub type GroupTopNExecutor<K, S, const WITH_TIES: bool> =
36    TopNExecutorWrapper<InnerGroupTopNExecutor<K, S, WITH_TIES>>;
37
38impl<K: HashKey, S: StateStore, const WITH_TIES: bool> GroupTopNExecutor<K, S, WITH_TIES> {
39    #[allow(clippy::too_many_arguments)]
40    pub fn new(
41        input: Executor,
42        ctx: ActorContextRef,
43        schema: Schema,
44        storage_key: Vec<ColumnOrder>,
45        offset_and_limit: (usize, usize),
46        order_by: Vec<ColumnOrder>,
47        group_by: Vec<usize>,
48        state_table: StateTable<S>,
49        watermark_epoch: AtomicU64Ref,
50    ) -> StreamResult<Self> {
51        let inner = InnerGroupTopNExecutor::new(
52            schema,
53            storage_key,
54            offset_and_limit,
55            order_by,
56            group_by,
57            state_table,
58            watermark_epoch,
59            &ctx,
60        )?;
61        Ok(TopNExecutorWrapper { input, ctx, inner })
62    }
63}
64
65pub struct InnerGroupTopNExecutor<K: HashKey, S: StateStore, const WITH_TIES: bool> {
66    schema: Schema,
67
68    /// `LIMIT XXX`. None means no limit.
69    limit: usize,
70
71    /// `OFFSET XXX`. `0` means no offset.
72    offset: usize,
73
74    /// The storage key indices of the `GroupTopNExecutor`
75    storage_key_indices: Vec<usize>,
76
77    managed_state: ManagedTopNState<S>,
78
79    /// which column we used to group the data.
80    group_by: Vec<usize>,
81
82    /// group key -> cache for this group
83    caches: GroupTopNCache<K, WITH_TIES>,
84
85    /// Used for serializing pk into `CacheKey`.
86    cache_key_serde: CacheKeySerde,
87
88    /// Minimum cache capacity per group from config
89    topn_cache_min_capacity: usize,
90
91    metrics: GroupTopNMetrics,
92}
93
94impl<K: HashKey, S: StateStore, const WITH_TIES: bool> InnerGroupTopNExecutor<K, S, WITH_TIES> {
95    #[allow(clippy::too_many_arguments)]
96    pub fn new(
97        schema: Schema,
98        storage_key: Vec<ColumnOrder>,
99        offset_and_limit: (usize, usize),
100        order_by: Vec<ColumnOrder>,
101        group_by: Vec<usize>,
102        state_table: StateTable<S>,
103        watermark_epoch: AtomicU64Ref,
104        ctx: &ActorContext,
105    ) -> StreamResult<Self> {
106        let metrics_info = MetricsInfo::new(
107            ctx.streaming_metrics.clone(),
108            state_table.table_id(),
109            ctx.id,
110            "GroupTopN",
111        );
112        let metrics = ctx.streaming_metrics.new_group_top_n_metrics(
113            state_table.table_id(),
114            ctx.id,
115            ctx.fragment_id,
116        );
117
118        let cache_key_serde = create_cache_key_serde(&storage_key, &schema, &order_by, &group_by);
119        let managed_state = ManagedTopNState::<S>::new(state_table, cache_key_serde.clone());
120
121        Ok(Self {
122            schema,
123            offset: offset_and_limit.0,
124            limit: offset_and_limit.1,
125            managed_state,
126            storage_key_indices: storage_key.into_iter().map(|op| op.column_index).collect(),
127            group_by,
128            caches: GroupTopNCache::new(watermark_epoch, metrics_info),
129            cache_key_serde,
130            topn_cache_min_capacity: ctx.config.developer.topn_cache_min_capacity,
131            metrics,
132        })
133    }
134}
135
136pub struct GroupTopNCache<K: HashKey, const WITH_TIES: bool> {
137    data: ManagedLruCache<K, TopNCache<WITH_TIES>>,
138}
139
140impl<K: HashKey, const WITH_TIES: bool> GroupTopNCache<K, WITH_TIES> {
141    pub fn new(watermark_sequence: AtomicU64Ref, metrics_info: MetricsInfo) -> Self {
142        let cache = ManagedLruCache::unbounded(watermark_sequence, metrics_info);
143        Self { data: cache }
144    }
145}
146
147impl<K: HashKey, const WITH_TIES: bool> Deref for GroupTopNCache<K, WITH_TIES> {
148    type Target = ManagedLruCache<K, TopNCache<WITH_TIES>>;
149
150    fn deref(&self) -> &Self::Target {
151        &self.data
152    }
153}
154
155impl<K: HashKey, const WITH_TIES: bool> DerefMut for GroupTopNCache<K, WITH_TIES> {
156    fn deref_mut(&mut self) -> &mut Self::Target {
157        &mut self.data
158    }
159}
160
161impl<K: HashKey, S: StateStore, const WITH_TIES: bool> TopNExecutorBase
162    for InnerGroupTopNExecutor<K, S, WITH_TIES>
163where
164    TopNCache<WITH_TIES>: TopNCacheTrait,
165{
166    type State = S;
167
168    async fn apply_chunk(
169        &mut self,
170        chunk: StreamChunk,
171    ) -> StreamExecutorResult<Option<StreamChunk>> {
172        let keys = K::build_many(&self.group_by, chunk.data_chunk());
173        let mut stagings: HashMap<K, TopNStaging> = HashMap::new();
174
175        for (r, group_cache_key) in chunk.rows_with_holes().zip_eq_debug(keys.iter()) {
176            let Some((op, row_ref)) = r else {
177                continue;
178            };
179
180            // The pk without group by
181            let pk_row = row_ref.project(&self.storage_key_indices[self.group_by.len()..]);
182            let cache_key = serialize_pk_to_cache_key(pk_row, &self.cache_key_serde);
183
184            let group_key = row_ref.project(&self.group_by);
185            self.metrics.group_top_n_total_query_cache_count.inc();
186            // If 'self.caches' does not already have a cache for the current group, create a new
187            // cache for it and insert it into `self.caches`
188            if !self.caches.contains(group_cache_key) {
189                self.metrics.group_top_n_cache_miss_count.inc();
190                let mut topn_cache = TopNCache::with_min_capacity(
191                    self.offset,
192                    self.limit,
193                    self.schema.data_types(),
194                    self.topn_cache_min_capacity,
195                );
196                self.managed_state
197                    .init_topn_cache(Some(group_key), &mut topn_cache)
198                    .await?;
199                self.caches.put(group_cache_key.clone(), topn_cache);
200            }
201
202            let mut cache = self.caches.get_mut(group_cache_key).unwrap();
203            let staging = stagings.entry(group_cache_key.clone()).or_default();
204
205            // apply the chunk to state table
206            match op {
207                Op::Insert | Op::UpdateInsert => {
208                    self.managed_state.insert(row_ref);
209                    cache.insert(cache_key, row_ref, staging);
210                }
211
212                Op::Delete | Op::UpdateDelete => {
213                    self.managed_state.delete(row_ref);
214                    cache
215                        .delete(
216                            Some(group_key),
217                            &mut self.managed_state,
218                            cache_key,
219                            row_ref,
220                            staging,
221                        )
222                        .await?;
223                }
224            }
225        }
226
227        self.metrics
228            .group_top_n_cached_entry_count
229            .set(self.caches.len() as i64);
230
231        let data_types = self.schema.data_types();
232        let deserializer = RowDeserializer::new(data_types.clone());
233        let mut chunk_builder = StreamChunkBuilder::unlimited(data_types, Some(chunk.capacity()));
234        for staging in stagings.into_values() {
235            for res in staging.into_deserialized_changes(&deserializer) {
236                let record = res?;
237                let _none = chunk_builder.append_record(record);
238            }
239        }
240        Ok(chunk_builder.take())
241    }
242
243    async fn flush_data(
244        &mut self,
245        epoch: EpochPair,
246    ) -> StreamExecutorResult<StateTablePostCommit<'_, S>> {
247        self.managed_state.flush(epoch).await
248    }
249
250    async fn try_flush_data(&mut self) -> StreamExecutorResult<()> {
251        self.managed_state.try_flush().await
252    }
253
254    fn clear_cache(&mut self) {
255        self.caches.clear();
256    }
257
258    fn evict(&mut self) {
259        self.caches.evict()
260    }
261
262    async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
263        self.managed_state.init_epoch(epoch).await
264    }
265
266    async fn handle_watermark(&mut self, watermark: Watermark) -> Option<Watermark> {
267        if watermark.col_idx == self.group_by[0] {
268            self.managed_state.update_watermark(watermark.val.clone());
269            Some(watermark)
270        } else {
271            None
272        }
273    }
274}
275
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicU64;

    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::catalog::Field;
    use risingwave_common::hash::SerializedKey;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_common::util::sort_util::OrderType;
    use risingwave_storage::memory::MemoryStateStore;

    use super::*;
    use crate::executor::test_utils::top_n_executor::create_in_memory_state_table;
    use crate::executor::test_utils::{MockSource, StreamExecutorTestExt};

    /// Schema shared by the tests: three `Int64` columns.
    fn create_schema() -> Schema {
        Schema {
            fields: (0..3).map(|_| Field::unnamed(DataType::Int64)).collect(),
        }
    }

    /// Storage key `(1, 2, 0)`, every column ascending.
    fn storage_key() -> Vec<ColumnOrder> {
        [1, 2, 0]
            .into_iter()
            .map(|column_index| ColumnOrder::new(column_index, OrderType::ascending()))
            .collect()
    }

    /// group by 1, order by 2
    fn order_by_1() -> Vec<ColumnOrder> {
        vec![ColumnOrder::new(2, OrderType::ascending())]
    }

    /// group by 1,2, order by 0
    fn order_by_2() -> Vec<ColumnOrder> {
        vec![ColumnOrder::new(0, OrderType::ascending())]
    }

    /// Stream key of the mock source; also the pk of the in-memory state table.
    fn stream_key() -> StreamKey {
        vec![1, 2, 0]
    }

    /// The four input chunks, one per epoch.
    fn create_stream_chunks() -> Vec<StreamChunk> {
        let chunk0 = StreamChunk::from_pretty(
            "  I I I
            + 10 9 1
            +  8 8 2
            +  7 8 2
            +  9 1 1
            + 10 1 1
            +  8 1 3",
        );
        let chunk1 = StreamChunk::from_pretty(
            "  I I I
            - 10 9 1
            -  8 8 2
            - 10 1 1",
        );
        let chunk2 = StreamChunk::from_pretty(
            " I I I
            - 7 8 2
            - 8 1 3
            - 9 1 1",
        );
        let chunk3 = StreamChunk::from_pretty(
            "  I I I
            +  5 1 1
            +  2 1 1
            +  3 1 2
            +  4 1 3",
        );
        vec![chunk0, chunk1, chunk2, chunk3]
    }

    /// Mock source: barrier of epoch 1, then each chunk followed by the next epoch's barrier.
    fn create_source() -> Executor {
        let schema = create_schema();
        let mut messages = vec![Message::Barrier(Barrier::new_test_barrier(test_epoch(1)))];
        for (epoch, chunk) in (2..).zip(create_stream_chunks()) {
            messages.push(Message::Chunk(chunk));
            messages.push(Message::Barrier(Barrier::new_test_barrier(test_epoch(epoch))));
        }
        MockSource::with_messages(messages).into_executor(schema, stream_key())
    }

    /// Builds a non-`WITH_TIES` group top-N executor over `create_source()`, using the shared
    /// storage key and a fresh in-memory state table keyed by `stream_key()`.
    async fn create_top_n(
        offset_and_limit: (usize, usize),
        order_by: Vec<ColumnOrder>,
        group_by: Vec<usize>,
    ) -> GroupTopNExecutor<SerializedKey, MemoryStateStore, false> {
        let source = create_source();
        let state_table = create_in_memory_state_table(
            &[DataType::Int64, DataType::Int64, DataType::Int64],
            &[
                OrderType::ascending(),
                OrderType::ascending(),
                OrderType::ascending(),
            ],
            &stream_key(),
        )
        .await;
        let schema = source.schema().clone();
        GroupTopNExecutor::<SerializedKey, MemoryStateStore, false>::new(
            source,
            ActorContext::for_test(0),
            schema,
            storage_key(),
            offset_and_limit,
            order_by,
            group_by,
            state_table,
            Arc::new(AtomicU64::new(0)),
        )
        .unwrap()
    }

    #[tokio::test]
    async fn test_without_offset_and_with_limits() {
        let mut top_n = create_top_n((0, 2), order_by_1(), vec![1])
            .await
            .boxed()
            .execute();

        // consume the init barrier
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                + 10 9 1
                +  8 8 2
                +  7 8 2
                +  9 1 1
                + 10 1 1
                ",
            )
            .sort_rows(),
        );

        // epoch 2
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                - 10 9 1
                -  8 8 2
                - 10 1 1
                +  8 1 3
                ",
            )
            .sort_rows(),
        );

        // epoch 3
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                " I I I
                - 7 8 2
                - 8 1 3
                - 9 1 1
                ",
            )
            .sort_rows(),
        );

        // epoch 4
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                " I I I
                + 5 1 1
                + 2 1 1
                ",
            )
            .sort_rows(),
        );
    }

    #[tokio::test]
    async fn test_with_offset_and_with_limits() {
        let mut top_n = create_top_n((1, 2), order_by_1(), vec![1])
            .await
            .boxed()
            .execute();

        // consume the init barrier
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                +  8 8 2
                + 10 1 1
                +  8 1 3
                ",
            )
            .sort_rows(),
        );

        // epoch 2
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                -  8 8 2
                - 10 1 1
                ",
            )
            .sort_rows(),
        );

        // epoch 3
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                " I I I
                - 8 1 3",
            )
            .sort_rows(),
        );

        // epoch 4
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                " I I I
                + 5 1 1
                + 3 1 2
                ",
            )
            .sort_rows(),
        );
    }

    #[tokio::test]
    async fn test_multi_group_key() {
        let mut top_n = create_top_n((0, 2), order_by_2(), vec![1, 2])
            .await
            .boxed()
            .execute();

        // consume the init barrier
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                + 10 9 1
                +  8 8 2
                +  7 8 2
                +  9 1 1
                + 10 1 1
                +  8 1 3",
            )
            .sort_rows(),
        );

        // epoch 2
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                - 10 9 1
                -  8 8 2
                - 10 1 1",
            )
            .sort_rows(),
        );

        // epoch 3
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                - 7 8 2
                - 8 1 3
                - 9 1 1",
            )
            .sort_rows(),
        );

        // epoch 4
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                +  5 1 1
                +  2 1 1
                +  3 1 2
                +  4 1 3",
            )
            .sort_rows(),
        );
    }

    #[tokio::test]
    async fn test_compact_changes() {
        let schema = create_schema();
        let source = MockSource::with_messages(vec![
            Message::Barrier(Barrier::new_test_barrier(test_epoch(1))),
            Message::Chunk(StreamChunk::from_pretty(
                "  I I I
                +  0 0 9
                +  0 0 8
                +  0 0 7
                +  0 0 6
                +  0 1 15
                +  0 1 14",
            )),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(2))),
            Message::Chunk(StreamChunk::from_pretty(
                "  I I I
                -  0 0 6
                -  0 0 8
                +  0 0 4
                +  0 0 3
                +  0 1 12
                +  0 2 26
                -  0 1 12
                +  0 1 11",
            )),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(3))),
            Message::Chunk(StreamChunk::from_pretty(
                "  I I I
                +  0 0 11", // this should result in no chunk output
            )),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(4))),
        ])
        .into_executor(schema.clone(), vec![2]);

        let state_table = create_in_memory_state_table(
            &schema.data_types(),
            &[
                OrderType::ascending(),
                OrderType::ascending(),
                OrderType::ascending(),
            ],
            &[0, 1, 2], // table pk = group key (0, 1) + order key (2) + additional pk (empty)
        )
        .await;

        let top_n = GroupTopNExecutor::<SerializedKey, MemoryStateStore, false>::new(
            source,
            ActorContext::for_test(0),
            schema,
            vec![
                ColumnOrder::new(0, OrderType::ascending()),
                ColumnOrder::new(1, OrderType::ascending()),
                ColumnOrder::new(2, OrderType::ascending()),
            ],
            (0, 2), // (offset, limit)
            vec![ColumnOrder::new(2, OrderType::ascending())],
            vec![0, 1],
            state_table,
            Arc::new(AtomicU64::new(0)),
        )
        .unwrap();
        let mut top_n = top_n.boxed().execute();

        // initial barrier
        top_n.expect_barrier().await;

        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                +  0 0 7
                +  0 0 6
                +  0 1 15
                +  0 1 14",
            )
            .sort_rows(),
        );
        top_n.expect_barrier().await;

        // Changes within the same group are compacted into the net effect on the top 2.
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                -  0 0 6
                -  0 0 7
                +  0 0 4
                +  0 0 3
                -  0 1 15
                +  0 1 11
                +  0 2 26",
            )
            .sort_rows(),
        );
        top_n.expect_barrier().await;

        // no output chunk for the last input chunk
        top_n.expect_barrier().await;
    }
}