risingwave_stream/executor/top_n/
group_top_n.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::ops::{Deref, DerefMut};
17
18use risingwave_common::array::Op;
19use risingwave_common::hash::HashKey;
20use risingwave_common::row::{RowDeserializer, RowExt};
21use risingwave_common::util::epoch::EpochPair;
22use risingwave_common::util::iter_util::ZipEqDebug;
23use risingwave_common::util::sort_util::ColumnOrder;
24
25use super::top_n_cache::TopNCacheTrait;
26use super::utils::*;
27use super::{ManagedTopNState, TopNCache};
28use crate::cache::ManagedLruCache;
29use crate::common::metrics::MetricsInfo;
30use crate::common::table::state_table::StateTablePostCommit;
31use crate::executor::monitor::GroupTopNMetrics;
32use crate::executor::prelude::*;
33
34pub type GroupTopNExecutor<K, S, const WITH_TIES: bool> =
35    TopNExecutorWrapper<InnerGroupTopNExecutor<K, S, WITH_TIES>>;
36
37impl<K: HashKey, S: StateStore, const WITH_TIES: bool> GroupTopNExecutor<K, S, WITH_TIES> {
38    #[allow(clippy::too_many_arguments)]
39    pub fn new(
40        input: Executor,
41        ctx: ActorContextRef,
42        schema: Schema,
43        storage_key: Vec<ColumnOrder>,
44        offset_and_limit: (usize, usize),
45        order_by: Vec<ColumnOrder>,
46        group_by: Vec<usize>,
47        state_table: StateTable<S>,
48        watermark_epoch: AtomicU64Ref,
49    ) -> StreamResult<Self> {
50        let inner = InnerGroupTopNExecutor::new(
51            schema,
52            storage_key,
53            offset_and_limit,
54            order_by,
55            group_by,
56            state_table,
57            watermark_epoch,
58            &ctx,
59        )?;
60        Ok(TopNExecutorWrapper { input, ctx, inner })
61    }
62}
63
64pub struct InnerGroupTopNExecutor<K: HashKey, S: StateStore, const WITH_TIES: bool> {
65    schema: Schema,
66
67    /// `LIMIT XXX`. None means no limit.
68    limit: usize,
69
70    /// `OFFSET XXX`. `0` means no offset.
71    offset: usize,
72
73    /// The storage key indices of the `GroupTopNExecutor`
74    storage_key_indices: PkIndices,
75
76    managed_state: ManagedTopNState<S>,
77
78    /// which column we used to group the data.
79    group_by: Vec<usize>,
80
81    /// group key -> cache for this group
82    caches: GroupTopNCache<K, WITH_TIES>,
83
84    /// Used for serializing pk into `CacheKey`.
85    cache_key_serde: CacheKeySerde,
86
87    /// Minimum cache capacity per group from config
88    topn_cache_min_capacity: usize,
89
90    metrics: GroupTopNMetrics,
91}
92
93impl<K: HashKey, S: StateStore, const WITH_TIES: bool> InnerGroupTopNExecutor<K, S, WITH_TIES> {
94    #[allow(clippy::too_many_arguments)]
95    pub fn new(
96        schema: Schema,
97        storage_key: Vec<ColumnOrder>,
98        offset_and_limit: (usize, usize),
99        order_by: Vec<ColumnOrder>,
100        group_by: Vec<usize>,
101        state_table: StateTable<S>,
102        watermark_epoch: AtomicU64Ref,
103        ctx: &ActorContext,
104    ) -> StreamResult<Self> {
105        let metrics_info = MetricsInfo::new(
106            ctx.streaming_metrics.clone(),
107            state_table.table_id(),
108            ctx.id,
109            "GroupTopN",
110        );
111        let metrics = ctx.streaming_metrics.new_group_top_n_metrics(
112            state_table.table_id(),
113            ctx.id,
114            ctx.fragment_id,
115        );
116
117        let cache_key_serde = create_cache_key_serde(&storage_key, &schema, &order_by, &group_by);
118        let managed_state = ManagedTopNState::<S>::new(state_table, cache_key_serde.clone());
119
120        Ok(Self {
121            schema,
122            offset: offset_and_limit.0,
123            limit: offset_and_limit.1,
124            managed_state,
125            storage_key_indices: storage_key.into_iter().map(|op| op.column_index).collect(),
126            group_by,
127            caches: GroupTopNCache::new(watermark_epoch, metrics_info),
128            cache_key_serde,
129            topn_cache_min_capacity: ctx.streaming_config.developer.topn_cache_min_capacity,
130            metrics,
131        })
132    }
133}
134
135pub struct GroupTopNCache<K: HashKey, const WITH_TIES: bool> {
136    data: ManagedLruCache<K, TopNCache<WITH_TIES>>,
137}
138
139impl<K: HashKey, const WITH_TIES: bool> GroupTopNCache<K, WITH_TIES> {
140    pub fn new(watermark_sequence: AtomicU64Ref, metrics_info: MetricsInfo) -> Self {
141        let cache = ManagedLruCache::unbounded(watermark_sequence, metrics_info);
142        Self { data: cache }
143    }
144}
145
146impl<K: HashKey, const WITH_TIES: bool> Deref for GroupTopNCache<K, WITH_TIES> {
147    type Target = ManagedLruCache<K, TopNCache<WITH_TIES>>;
148
149    fn deref(&self) -> &Self::Target {
150        &self.data
151    }
152}
153
154impl<K: HashKey, const WITH_TIES: bool> DerefMut for GroupTopNCache<K, WITH_TIES> {
155    fn deref_mut(&mut self) -> &mut Self::Target {
156        &mut self.data
157    }
158}
159
160impl<K: HashKey, S: StateStore, const WITH_TIES: bool> TopNExecutorBase
161    for InnerGroupTopNExecutor<K, S, WITH_TIES>
162where
163    TopNCache<WITH_TIES>: TopNCacheTrait,
164{
165    type State = S;
166
167    async fn apply_chunk(
168        &mut self,
169        chunk: StreamChunk,
170    ) -> StreamExecutorResult<Option<StreamChunk>> {
171        let keys = K::build_many(&self.group_by, chunk.data_chunk());
172        let mut stagings = HashMap::new(); // K -> `TopNStaging`
173
174        for (r, group_cache_key) in chunk.rows_with_holes().zip_eq_debug(keys.iter()) {
175            let Some((op, row_ref)) = r else {
176                continue;
177            };
178
179            // The pk without group by
180            let pk_row = row_ref.project(&self.storage_key_indices[self.group_by.len()..]);
181            let cache_key = serialize_pk_to_cache_key(pk_row, &self.cache_key_serde);
182
183            let group_key = row_ref.project(&self.group_by);
184            self.metrics.group_top_n_total_query_cache_count.inc();
185            // If 'self.caches' does not already have a cache for the current group, create a new
186            // cache for it and insert it into `self.caches`
187            if !self.caches.contains(group_cache_key) {
188                self.metrics.group_top_n_cache_miss_count.inc();
189                let mut topn_cache = TopNCache::with_min_capacity(
190                    self.offset,
191                    self.limit,
192                    self.schema.data_types(),
193                    self.topn_cache_min_capacity,
194                );
195                self.managed_state
196                    .init_topn_cache(Some(group_key), &mut topn_cache)
197                    .await?;
198                self.caches.put(group_cache_key.clone(), topn_cache);
199            }
200
201            let mut cache = self.caches.get_mut(group_cache_key).unwrap();
202            let staging = stagings.entry(group_cache_key.clone()).or_default();
203
204            // apply the chunk to state table
205            match op {
206                Op::Insert | Op::UpdateInsert => {
207                    self.managed_state.insert(row_ref);
208                    cache.insert(cache_key, row_ref, staging);
209                }
210
211                Op::Delete | Op::UpdateDelete => {
212                    self.managed_state.delete(row_ref);
213                    cache
214                        .delete(
215                            Some(group_key),
216                            &mut self.managed_state,
217                            cache_key,
218                            row_ref,
219                            staging,
220                        )
221                        .await?;
222                }
223            }
224        }
225
226        self.metrics
227            .group_top_n_cached_entry_count
228            .set(self.caches.len() as i64);
229
230        let data_types = self.schema.data_types();
231        let deserializer = RowDeserializer::new(data_types.clone());
232        let mut chunk_builder = StreamChunkBuilder::unlimited(data_types, Some(chunk.capacity()));
233        for staging in stagings.into_values() {
234            for res in staging.into_deserialized_changes(&deserializer) {
235                let (op, row) = res?;
236                let _none = chunk_builder.append_row(op, row);
237            }
238        }
239        Ok(chunk_builder.take())
240    }
241
242    async fn flush_data(
243        &mut self,
244        epoch: EpochPair,
245    ) -> StreamExecutorResult<StateTablePostCommit<'_, S>> {
246        self.managed_state.flush(epoch).await
247    }
248
249    async fn try_flush_data(&mut self) -> StreamExecutorResult<()> {
250        self.managed_state.try_flush().await
251    }
252
253    fn clear_cache(&mut self) {
254        self.caches.clear();
255    }
256
257    fn evict(&mut self) {
258        self.caches.evict()
259    }
260
261    async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
262        self.managed_state.init_epoch(epoch).await
263    }
264
265    async fn handle_watermark(&mut self, watermark: Watermark) -> Option<Watermark> {
266        if watermark.col_idx == self.group_by[0] {
267            self.managed_state.update_watermark(watermark.val.clone());
268            Some(watermark)
269        } else {
270            None
271        }
272    }
273}
274
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicU64;

    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::catalog::Field;
    use risingwave_common::hash::SerializedKey;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_common::util::sort_util::OrderType;
    use risingwave_storage::memory::MemoryStateStore;

    use super::*;
    use crate::executor::test_utils::top_n_executor::create_in_memory_state_table;
    use crate::executor::test_utils::{MockSource, StreamExecutorTestExt};

    /// Schema shared by the tests: three unnamed `Int64` columns.
    fn create_schema() -> Schema {
        Schema {
            fields: vec![Field::unnamed(DataType::Int64); 3],
        }
    }

    /// Storage key `(col 1, col 2, col 0)`, all ascending.
    fn storage_key() -> Vec<ColumnOrder> {
        [1, 2, 0]
            .into_iter()
            .map(|column_index| ColumnOrder::new(column_index, OrderType::ascending()))
            .collect()
    }

    /// group by 1, order by 2
    fn order_by_1() -> Vec<ColumnOrder> {
        vec![ColumnOrder::new(2, OrderType::ascending())]
    }

    /// group by 1,2, order by 0
    fn order_by_2() -> Vec<ColumnOrder> {
        vec![ColumnOrder::new(0, OrderType::ascending())]
    }

    /// Pk of the mock source and of the state table; mirrors `storage_key()`.
    fn pk_indices() -> PkIndices {
        vec![1, 2, 0]
    }

    /// Four input chunks, fed one per epoch by `create_source`.
    fn create_stream_chunks() -> Vec<StreamChunk> {
        let chunk0 = StreamChunk::from_pretty(
            "  I I I
            + 10 9 1
            +  8 8 2
            +  7 8 2
            +  9 1 1
            + 10 1 1
            +  8 1 3",
        );
        let chunk1 = StreamChunk::from_pretty(
            "  I I I
            - 10 9 1
            -  8 8 2
            - 10 1 1",
        );
        let chunk2 = StreamChunk::from_pretty(
            " I I I
            - 7 8 2
            - 8 1 3
            - 9 1 1",
        );
        let chunk3 = StreamChunk::from_pretty(
            "  I I I
            +  5 1 1
            +  2 1 1
            +  3 1 2
            +  4 1 3",
        );
        vec![chunk0, chunk1, chunk2, chunk3]
    }

    /// Mock source emitting `barrier(1), chunk0, barrier(2), chunk1, ..., chunk3, barrier(5)`.
    fn create_source() -> Executor {
        let mut messages = vec![Message::Barrier(Barrier::new_test_barrier(test_epoch(1)))];
        for (chunk, next_epoch) in create_stream_chunks().into_iter().zip(2..) {
            messages.push(Message::Chunk(chunk));
            messages.push(Message::Barrier(Barrier::new_test_barrier(test_epoch(
                next_epoch,
            ))));
        }
        MockSource::with_messages(messages).into_executor(create_schema(), pk_indices())
    }

    /// Builds a group Top-N (without ties) over `create_source()`, backed by a fresh in-memory
    /// state table whose pk is `pk_indices()`.
    async fn create_top_n_executor(
        offset_and_limit: (usize, usize),
        order_by: Vec<ColumnOrder>,
        group_by: Vec<usize>,
    ) -> GroupTopNExecutor<SerializedKey, MemoryStateStore, false> {
        let source = create_source();
        let state_table = create_in_memory_state_table(
            &[DataType::Int64, DataType::Int64, DataType::Int64],
            &[
                OrderType::ascending(),
                OrderType::ascending(),
                OrderType::ascending(),
            ],
            &pk_indices(),
        )
        .await;
        let schema = source.schema().clone();
        GroupTopNExecutor::<SerializedKey, MemoryStateStore, false>::new(
            source,
            ActorContext::for_test(0),
            schema,
            storage_key(),
            offset_and_limit,
            order_by,
            group_by,
            state_table,
            Arc::new(AtomicU64::new(0)),
        )
        .unwrap()
    }

    #[tokio::test]
    async fn test_without_offset_and_with_limits() {
        let mut top_n = create_top_n_executor((0, 2), order_by_1(), vec![1])
            .await
            .boxed()
            .execute();

        // epoch 1: the initial barrier, then the output for chunk0
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                + 10 9 1
                +  8 8 2
                +  7 8 2
                +  9 1 1
                + 10 1 1
                ",
            )
            .sort_rows(),
        );

        // epoch 2: output for chunk1
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                - 10 9 1
                -  8 8 2
                - 10 1 1
                +  8 1 3
                ",
            )
            .sort_rows(),
        );

        // epoch 3: output for chunk2
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                " I I I
                - 7 8 2
                - 8 1 3
                - 9 1 1
                ",
            )
            .sort_rows(),
        );

        // epoch 4: output for chunk3
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                " I I I
                + 5 1 1
                + 2 1 1
                ",
            )
            .sort_rows(),
        );
    }

    #[tokio::test]
    async fn test_with_offset_and_with_limits() {
        let mut top_n = create_top_n_executor((1, 2), order_by_1(), vec![1])
            .await
            .boxed()
            .execute();

        // epoch 1: the initial barrier, then the output for chunk0
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                +  8 8 2
                + 10 1 1
                +  8 1 3
                ",
            )
            .sort_rows(),
        );

        // epoch 2: output for chunk1
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                -  8 8 2
                - 10 1 1
                ",
            )
            .sort_rows(),
        );

        // epoch 3: output for chunk2
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                " I I I
                - 8 1 3",
            )
            .sort_rows(),
        );

        // epoch 4: output for chunk3
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                " I I I
                + 5 1 1
                + 3 1 2
                ",
            )
            .sort_rows(),
        );
    }

    #[tokio::test]
    async fn test_multi_group_key() {
        let mut top_n = create_top_n_executor((0, 2), order_by_2(), vec![1, 2])
            .await
            .boxed()
            .execute();

        // epoch 1: the initial barrier, then the output for chunk0
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                + 10 9 1
                +  8 8 2
                +  7 8 2
                +  9 1 1
                + 10 1 1
                +  8 1 3",
            )
            .sort_rows(),
        );

        // epoch 2: output for chunk1
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                - 10 9 1
                -  8 8 2
                - 10 1 1",
            )
            .sort_rows(),
        );

        // epoch 3: output for chunk2
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                - 7 8 2
                - 8 1 3
                - 9 1 1",
            )
            .sort_rows(),
        );

        // epoch 4: output for chunk3
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                +  5 1 1
                +  2 1 1
                +  3 1 2
                +  4 1 3",
            )
            .sort_rows(),
        );
    }

    #[tokio::test]
    async fn test_compact_changes() {
        let schema = create_schema();
        let source = MockSource::with_messages(vec![
            Message::Barrier(Barrier::new_test_barrier(test_epoch(1))),
            Message::Chunk(StreamChunk::from_pretty(
                "  I I I
                +  0 0 9
                +  0 0 8
                +  0 0 7
                +  0 0 6
                +  0 1 15
                +  0 1 14",
            )),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(2))),
            Message::Chunk(StreamChunk::from_pretty(
                "  I I I
                -  0 0 6
                -  0 0 8
                +  0 0 4
                +  0 0 3
                +  0 1 12
                +  0 2 26
                -  0 1 12
                +  0 1 11",
            )),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(3))),
            Message::Chunk(StreamChunk::from_pretty(
                "  I I I
                +  0 0 11", // this should result in no chunk output
            )),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(4))),
        ])
        .into_executor(schema.clone(), vec![2]);

        // table pk = group key (0, 1) + order key (2) + additional pk (empty)
        let state_table = create_in_memory_state_table(
            &schema.data_types(),
            &[
                OrderType::ascending(),
                OrderType::ascending(),
                OrderType::ascending(),
            ],
            &[0, 1, 2],
        )
        .await;

        // The storage key mirrors the table pk above.
        let compact_storage_key: Vec<ColumnOrder> = (0..3)
            .map(|column_index| ColumnOrder::new(column_index, OrderType::ascending()))
            .collect();

        let mut top_n = GroupTopNExecutor::<SerializedKey, MemoryStateStore, false>::new(
            source,
            ActorContext::for_test(0),
            schema,
            compact_storage_key,
            (0, 2), // (offset, limit)
            vec![ColumnOrder::new(2, OrderType::ascending())],
            vec![0, 1],
            state_table,
            Arc::new(AtomicU64::new(0)),
        )
        .unwrap()
        .boxed()
        .execute();

        // initial barrier
        top_n.expect_barrier().await;

        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                +  0 0 7
                +  0 0 6
                +  0 1 15
                +  0 1 14",
            )
            .sort_rows(),
        );
        top_n.expect_barrier().await;

        // Intermediate changes inside one chunk are compacted into their net effect.
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                -  0 0 6
                -  0 0 7
                +  0 0 4
                +  0 0 3
                -  0 1 15
                +  0 1 11
                +  0 2 26",
            )
            .sort_rows(),
        );
        top_n.expect_barrier().await;

        // no output chunk for the last input chunk
        top_n.expect_barrier().await;
    }
}