risingwave_stream/executor/top_n/group_top_n.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::ops::{Deref, DerefMut};

use risingwave_common::array::Op;
use risingwave_common::hash::HashKey;
use risingwave_common::row::{RowDeserializer, RowExt};
use risingwave_common::util::epoch::EpochPair;
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_common::util::sort_util::ColumnOrder;

use super::top_n_cache::TopNCacheTrait;
use super::utils::*;
use super::{ManagedTopNState, TopNCache};
use crate::cache::ManagedLruCache;
use crate::common::metrics::MetricsInfo;
use crate::common::table::state_table::StateTablePostCommit;
use crate::executor::monitor::GroupTopNMetrics;
use crate::executor::prelude::*;

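/// Group Top-N executor: maintains a Top-N result per group key, optionally with `WITH TIES`
/// semantics selected by the `WITH_TIES` const parameter. It is the generic
/// [`TopNExecutorWrapper`] driver combined with [`InnerGroupTopNExecutor`], generic over the
/// hash key type `K` and the state store `S`.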
pub type GroupTopNExecutor<K, S, const WITH_TIES: bool> =
    TopNExecutorWrapper<InnerGroupTopNExecutor<K, S, WITH_TIES>>;

impl<K: HashKey, S: StateStore, const WITH_TIES: bool> GroupTopNExecutor<K, S, WITH_TIES> {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        input: Executor,
        ctx: ActorContextRef,
        schema: Schema,
        storage_key: Vec<ColumnOrder>,
        offset_and_limit: (usize, usize),
        order_by: Vec<ColumnOrder>,
        group_by: Vec<usize>,
        state_table: StateTable<S>,
        watermark_epoch: AtomicU64Ref,
    ) -> StreamResult<Self> {
        let inner = InnerGroupTopNExecutor::new(
            schema,
            storage_key,
            offset_and_limit,
            order_by,
            group_by,
            state_table,
            watermark_epoch,
            &ctx,
        )?;
        Ok(TopNExecutorWrapper { input, ctx, inner })
    }
}

pub struct InnerGroupTopNExecutor<K: HashKey, S: StateStore, const WITH_TIES: bool> {
    schema: Schema,

    /// `LIMIT XXX`. None means no limit.
    limit: usize,

    /// `OFFSET XXX`. `0` means no offset.
    offset: usize,

    /// The storage key indices of the `GroupTopNExecutor`.
    storage_key_indices: PkIndices,

    managed_state: ManagedTopNState<S>,

    /// Which columns we use to group the data.
    group_by: Vec<usize>,

    /// Group key -> cache for this group.
    caches: GroupTopNCache<K, WITH_TIES>,

    /// Used for serializing pk into `CacheKey`.
    cache_key_serde: CacheKeySerde,

    metrics: GroupTopNMetrics,
}

impl<K: HashKey, S: StateStore, const WITH_TIES: bool> InnerGroupTopNExecutor<K, S, WITH_TIES> {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        schema: Schema,
        storage_key: Vec<ColumnOrder>,
        offset_and_limit: (usize, usize),
        order_by: Vec<ColumnOrder>,
        group_by: Vec<usize>,
        state_table: StateTable<S>,
        watermark_epoch: AtomicU64Ref,
        ctx: &ActorContext,
    ) -> StreamResult<Self> {
        let metrics_info = MetricsInfo::new(
            ctx.streaming_metrics.clone(),
            state_table.table_id(),
            ctx.id,
            "GroupTopN",
        );
        let metrics = ctx.streaming_metrics.new_group_top_n_metrics(
            state_table.table_id(),
            ctx.id,
            ctx.fragment_id,
        );

        let cache_key_serde = create_cache_key_serde(&storage_key, &schema, &order_by, &group_by);
        let managed_state = ManagedTopNState::<S>::new(state_table, cache_key_serde.clone());

        Ok(Self {
            schema,
            offset: offset_and_limit.0,
            limit: offset_and_limit.1,
            managed_state,
            storage_key_indices: storage_key.into_iter().map(|op| op.column_index).collect(),
            group_by,
            caches: GroupTopNCache::new(watermark_epoch, metrics_info),
            cache_key_serde,
            metrics,
        })
    }
}

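/// An LRU cache from group (hash) key to the in-memory [`TopNCache`] of that group.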
pub struct GroupTopNCache<K: HashKey, const WITH_TIES: bool> {
    data: ManagedLruCache<K, TopNCache<WITH_TIES>>,
}

impl<K: HashKey, const WITH_TIES: bool> GroupTopNCache<K, WITH_TIES> {
    pub fn new(watermark_sequence: AtomicU64Ref, metrics_info: MetricsInfo) -> Self {
        let cache = ManagedLruCache::unbounded(watermark_sequence, metrics_info);
        Self { data: cache }
    }
}

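// Deref to the underlying `ManagedLruCache` so that lookups, insertions, and evictions can be
// performed directly on `GroupTopNCache`.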
impl<K: HashKey, const WITH_TIES: bool> Deref for GroupTopNCache<K, WITH_TIES> {
    type Target = ManagedLruCache<K, TopNCache<WITH_TIES>>;

    fn deref(&self) -> &Self::Target {
        &self.data
    }
}

impl<K: HashKey, const WITH_TIES: bool> DerefMut for GroupTopNCache<K, WITH_TIES> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.data
    }
}

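// The executor logic below is shared between the plain and `WITH_TIES` variants; the concrete
// insert/delete behavior comes from the `TopNCacheTrait` bound on `TopNCache<WITH_TIES>`.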
impl<K: HashKey, S: StateStore, const WITH_TIES: bool> TopNExecutorBase
    for InnerGroupTopNExecutor<K, S, WITH_TIES>
where
    TopNCache<WITH_TIES>: TopNCacheTrait,
{
    type State = S;

    async fn apply_chunk(
        &mut self,
        chunk: StreamChunk,
    ) -> StreamExecutorResult<Option<StreamChunk>> {
        let keys = K::build_many(&self.group_by, chunk.data_chunk());
        let mut stagings = HashMap::new(); // K -> `TopNStaging`

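        // For each row: locate (or rebuild from the state table) the per-group cache, apply the
        // change to both the state table and the cache, and stage the resulting output changes
        // per group.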
        for (r, group_cache_key) in chunk.rows_with_holes().zip_eq_debug(keys.iter()) {
            let Some((op, row_ref)) = r else {
                continue;
            };

            // The pk part of the storage key, i.e. the storage key without the group-by columns.
            let pk_row = row_ref.project(&self.storage_key_indices[self.group_by.len()..]);
            let cache_key = serialize_pk_to_cache_key(pk_row, &self.cache_key_serde);

            let group_key = row_ref.project(&self.group_by);
            self.metrics.group_top_n_total_query_cache_count.inc();
            // If `self.caches` does not already have a cache for the current group, create a new
            // cache for it and insert it into `self.caches`.
            if !self.caches.contains(group_cache_key) {
                self.metrics.group_top_n_cache_miss_count.inc();
                let mut topn_cache =
                    TopNCache::new(self.offset, self.limit, self.schema.data_types());
                self.managed_state
                    .init_topn_cache(Some(group_key), &mut topn_cache)
                    .await?;
                self.caches.push(group_cache_key.clone(), topn_cache);
            }

            let mut cache = self.caches.get_mut(group_cache_key).unwrap();
            let staging = stagings.entry(group_cache_key.clone()).or_default();

            // Apply the change to the state table and the cache.
            match op {
                Op::Insert | Op::UpdateInsert => {
                    self.managed_state.insert(row_ref);
                    cache.insert(cache_key, row_ref, staging);
                }

                Op::Delete | Op::UpdateDelete => {
                    self.managed_state.delete(row_ref);
                    cache
                        .delete(
                            Some(group_key),
                            &mut self.managed_state,
                            cache_key,
                            row_ref,
                            staging,
                        )
                        .await?;
                }
            }
        }

        self.metrics
            .group_top_n_cached_entry_count
            .set(self.caches.len() as i64);

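        // Compact the staged changes of all groups into a single output chunk. If nothing
        // changed, `chunk_builder.take()` returns `None` and no chunk is emitted.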
        let data_types = self.schema.data_types();
        let deserializer = RowDeserializer::new(data_types.clone());
        let mut chunk_builder = StreamChunkBuilder::unlimited(data_types, Some(chunk.capacity()));
        for staging in stagings.into_values() {
            for res in staging.into_deserialized_changes(&deserializer) {
                let (op, row) = res?;
                let _none = chunk_builder.append_row(op, row);
            }
        }
        Ok(chunk_builder.take())
    }

    async fn flush_data(
        &mut self,
        epoch: EpochPair,
    ) -> StreamExecutorResult<StateTablePostCommit<'_, S>> {
        self.managed_state.flush(epoch).await
    }

    async fn try_flush_data(&mut self) -> StreamExecutorResult<()> {
        self.managed_state.try_flush().await
    }

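    // Cache maintenance hooks: `clear_cache` drops all per-group caches (used when the cached
    // state may no longer be valid), while `evict` only lets the LRU cache evict entries up to
    // the current eviction sequence.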
    fn clear_cache(&mut self) {
        self.caches.clear();
    }

    fn evict(&mut self) {
        self.caches.evict()
    }

    async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
        self.managed_state.init_epoch(epoch).await
    }

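    // Only a watermark on the first group-by column is useful here: it is applied to the managed
    // state (enabling state cleaning) and forwarded; watermarks on other columns are dropped.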
    async fn handle_watermark(&mut self, watermark: Watermark) -> Option<Watermark> {
        if watermark.col_idx == self.group_by[0] {
            self.managed_state.update_watermark(watermark.val.clone());
            Some(watermark)
        } else {
            None
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicU64;

    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::catalog::Field;
    use risingwave_common::hash::SerializedKey;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_common::util::sort_util::OrderType;
    use risingwave_storage::memory::MemoryStateStore;

    use super::*;
    use crate::executor::test_utils::top_n_executor::create_in_memory_state_table;
    use crate::executor::test_utils::{MockSource, StreamExecutorTestExt};

    fn create_schema() -> Schema {
        Schema {
            fields: vec![
                Field::unnamed(DataType::Int64),
                Field::unnamed(DataType::Int64),
                Field::unnamed(DataType::Int64),
            ],
        }
    }

    fn storage_key() -> Vec<ColumnOrder> {
        vec![
            ColumnOrder::new(1, OrderType::ascending()),
            ColumnOrder::new(2, OrderType::ascending()),
            ColumnOrder::new(0, OrderType::ascending()),
        ]
    }

    /// Order key for the tests that group by column 1 and order by column 2.
    fn order_by_1() -> Vec<ColumnOrder> {
        vec![ColumnOrder::new(2, OrderType::ascending())]
    }

    /// Order key for the tests that group by columns 1 and 2 and order by column 0.
    fn order_by_2() -> Vec<ColumnOrder> {
        vec![ColumnOrder::new(0, OrderType::ascending())]
    }

    fn pk_indices() -> PkIndices {
        vec![1, 2, 0]
    }

    fn create_stream_chunks() -> Vec<StreamChunk> {
        let chunk0 = StreamChunk::from_pretty(
            "  I I I
            + 10 9 1
            +  8 8 2
            +  7 8 2
            +  9 1 1
            + 10 1 1
            +  8 1 3",
        );
        let chunk1 = StreamChunk::from_pretty(
            "  I I I
            - 10 9 1
            -  8 8 2
            - 10 1 1",
        );
        let chunk2 = StreamChunk::from_pretty(
            " I I I
            - 7 8 2
            - 8 1 3
            - 9 1 1",
        );
        let chunk3 = StreamChunk::from_pretty(
            "  I I I
            +  5 1 1
            +  2 1 1
            +  3 1 2
            +  4 1 3",
        );
        vec![chunk0, chunk1, chunk2, chunk3]
    }

    fn create_source() -> Executor {
        let mut chunks = create_stream_chunks();
        let schema = create_schema();
        MockSource::with_messages(vec![
            Message::Barrier(Barrier::new_test_barrier(test_epoch(1))),
            Message::Chunk(std::mem::take(&mut chunks[0])),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(2))),
            Message::Chunk(std::mem::take(&mut chunks[1])),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(3))),
            Message::Chunk(std::mem::take(&mut chunks[2])),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(4))),
            Message::Chunk(std::mem::take(&mut chunks[3])),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(5))),
        ])
        .into_executor(schema, pk_indices())
    }

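    // The tests below drive the executor over the chunks from `create_stream_chunks` (or a
    // dedicated input for `test_compact_changes`) with different (offset, limit) and group-by
    // configurations, and assert the compacted output emitted between barriers.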
    #[tokio::test]
    async fn test_without_offset_and_with_limits() {
        let source = create_source();
        let state_table = create_in_memory_state_table(
            &[DataType::Int64, DataType::Int64, DataType::Int64],
            &[
                OrderType::ascending(),
                OrderType::ascending(),
                OrderType::ascending(),
            ],
            &pk_indices(),
        )
        .await;
        let schema = source.schema().clone();
        let top_n = GroupTopNExecutor::<SerializedKey, MemoryStateStore, false>::new(
            source,
            ActorContext::for_test(0),
            schema,
            storage_key(),
            (0, 2),
            order_by_1(),
            vec![1],
            state_table,
            Arc::new(AtomicU64::new(0)),
        )
        .unwrap();
        let mut top_n = top_n.boxed().execute();

        // consume the init barrier
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                + 10 9 1
                +  8 8 2
                +  7 8 2
                +  9 1 1
                + 10 1 1
                ",
            )
            .sort_rows(),
        );

        // barrier
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                - 10 9 1
                -  8 8 2
                - 10 1 1
                +  8 1 3
                ",
            )
            .sort_rows(),
        );

        // barrier
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                " I I I
                - 7 8 2
                - 8 1 3
                - 9 1 1
                ",
            )
            .sort_rows(),
        );

        // barrier
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                " I I I
                + 5 1 1
                + 2 1 1
                ",
            )
            .sort_rows(),
        );
    }

    #[tokio::test]
    async fn test_with_offset_and_with_limits() {
        let source = create_source();
        let state_table = create_in_memory_state_table(
            &[DataType::Int64, DataType::Int64, DataType::Int64],
            &[
                OrderType::ascending(),
                OrderType::ascending(),
                OrderType::ascending(),
            ],
            &pk_indices(),
        )
        .await;
        let schema = source.schema().clone();
        let top_n = GroupTopNExecutor::<SerializedKey, MemoryStateStore, false>::new(
            source,
            ActorContext::for_test(0),
            schema,
            storage_key(),
            (1, 2),
            order_by_1(),
            vec![1],
            state_table,
            Arc::new(AtomicU64::new(0)),
        )
        .unwrap();
        let mut top_n = top_n.boxed().execute();

        // consume the init barrier
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                +  8 8 2
                + 10 1 1
                +  8 1 3
                ",
            )
            .sort_rows(),
        );

        // barrier
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                -  8 8 2
                - 10 1 1
                ",
            )
            .sort_rows(),
        );

        // barrier
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                " I I I
                - 8 1 3",
            )
            .sort_rows(),
        );

        // barrier
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                " I I I
                + 5 1 1
                + 3 1 2
                ",
            )
            .sort_rows(),
        );
    }

    #[tokio::test]
    async fn test_multi_group_key() {
        let source = create_source();
        let state_table = create_in_memory_state_table(
            &[DataType::Int64, DataType::Int64, DataType::Int64],
            &[
                OrderType::ascending(),
                OrderType::ascending(),
                OrderType::ascending(),
            ],
            &pk_indices(),
        )
        .await;
        let schema = source.schema().clone();
        let top_n = GroupTopNExecutor::<SerializedKey, MemoryStateStore, false>::new(
            source,
            ActorContext::for_test(0),
            schema,
            storage_key(),
            (0, 2),
            order_by_2(),
            vec![1, 2],
            state_table,
            Arc::new(AtomicU64::new(0)),
        )
        .unwrap();
        let mut top_n = top_n.boxed().execute();

        // consume the init barrier
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                + 10 9 1
                +  8 8 2
                +  7 8 2
                +  9 1 1
                + 10 1 1
                +  8 1 3",
            )
            .sort_rows(),
        );

        // barrier
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                - 10 9 1
                -  8 8 2
                - 10 1 1",
            )
            .sort_rows(),
        );

        // barrier
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                - 7 8 2
                - 8 1 3
                - 9 1 1",
            )
            .sort_rows(),
        );

        // barrier
        top_n.expect_barrier().await;
        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                +  5 1 1
                +  2 1 1
                +  3 1 2
                +  4 1 3",
            )
            .sort_rows(),
        );
    }

    #[tokio::test]
    async fn test_compact_changes() {
        let schema = create_schema();
        let source = MockSource::with_messages(vec![
            Message::Barrier(Barrier::new_test_barrier(test_epoch(1))),
            Message::Chunk(StreamChunk::from_pretty(
                "  I I I
                +  0 0 9
                +  0 0 8
                +  0 0 7
                +  0 0 6
                +  0 1 15
                +  0 1 14",
            )),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(2))),
            Message::Chunk(StreamChunk::from_pretty(
                "  I I I
                -  0 0 6
                -  0 0 8
                +  0 0 4
                +  0 0 3
                +  0 1 12
                +  0 2 26
                -  0 1 12
                +  0 1 11",
            )),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(3))),
            Message::Chunk(StreamChunk::from_pretty(
                "  I I I
                +  0 0 11", // this should result in no chunk output
            )),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(4))),
        ])
        .into_executor(schema.clone(), vec![2]);

        let state_table = create_in_memory_state_table(
            &schema.data_types(),
            &[
                OrderType::ascending(),
                OrderType::ascending(),
                OrderType::ascending(),
            ],
            &[0, 1, 2], // table pk = group key (0, 1) + order key (2) + additional pk (empty)
        )
        .await;

        let top_n = GroupTopNExecutor::<SerializedKey, MemoryStateStore, false>::new(
            source,
            ActorContext::for_test(0),
            schema,
            vec![
                ColumnOrder::new(0, OrderType::ascending()),
                ColumnOrder::new(1, OrderType::ascending()),
                ColumnOrder::new(2, OrderType::ascending()),
            ],
            (0, 2), // (offset, limit)
            vec![ColumnOrder::new(2, OrderType::ascending())],
            vec![0, 1],
            state_table,
            Arc::new(AtomicU64::new(0)),
        )
        .unwrap();
        let mut top_n = top_n.boxed().execute();

        // initial barrier
        top_n.expect_barrier().await;

        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                +  0 0 7
                +  0 0 6
                +  0 1 15
                +  0 1 14",
            )
            .sort_rows(),
        );
        top_n.expect_barrier().await;

        assert_eq!(
            top_n.expect_chunk().await.sort_rows(),
            StreamChunk::from_pretty(
                "  I I I
                -  0 0 6
                -  0 0 7
                +  0 0 4
                +  0 0 3
                -  0 1 15
                +  0 1 11
                +  0 2 26",
            )
            .sort_rows(),
        );
        top_n.expect_barrier().await;

        // no output chunk for the last input chunk
        top_n.expect_barrier().await;
    }
}