risingwave_stream/executor/top_n/top_n_state.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::ops::Bound;

use futures::{StreamExt, pin_mut};
use risingwave_common::row::{OwnedRow, Row, RowExt};
use risingwave_common::types::ScalarImpl;
use risingwave_common::util::epoch::EpochPair;
use risingwave_storage::StateStore;
use risingwave_storage::store::PrefetchOptions;

use super::top_n_cache::CacheKey;
use super::{CacheKeySerde, GroupKey, TopNCache, serialize_pk_to_cache_key};
use crate::common::table::state_table::{StateTable, StateTablePostCommit};
use crate::executor::error::StreamExecutorResult;
use crate::executor::top_n::top_n_cache::Cache;

/// * For TopN, the storage key is: `[ order_by + remaining columns of pk ]`
/// * For group TopN, the storage key is: `[ group_key + order_by + remaining columns of pk ]`
///
/// The key in [`TopNCache`] is [`CacheKey`], which is `[ order_by|remaining columns of pk ]`, and
/// `group_key` is not included.
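///
/// For example (illustrative only; `g`, `v` and `id` are hypothetical columns): for a group
/// TopN with group key `g`, `ORDER BY v`, and remaining pk column `id`, the storage key is
/// `[ g, v, id ]`, while the corresponding [`CacheKey`] is built from `[ v, id ]` only,
/// without `g`.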
pub struct ManagedTopNState<S: StateStore> {
    /// Relational table.
    state_table: StateTable<S>,

    /// Used for serializing pk into `CacheKey`.
    cache_key_serde: CacheKeySerde,
}

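/// A row read from the state table, paired with its serialized [`CacheKey`].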
#[derive(Clone, PartialEq, Debug)]
pub struct TopNStateRow {
    pub cache_key: CacheKey,
    pub row: OwnedRow,
}

impl TopNStateRow {
    pub fn new(cache_key: CacheKey, row: OwnedRow) -> Self {
        Self { cache_key, row }
    }
}

impl<S: StateStore> ManagedTopNState<S> {
    pub fn new(state_table: StateTable<S>, cache_key_serde: CacheKeySerde) -> Self {
        Self {
            state_table,
            cache_key_serde,
        }
    }

    /// Get an immutable reference to the managed state table.
    pub fn table(&self) -> &StateTable<S> {
        &self.state_table
    }

    /// Init epoch for the managed state table.
    pub async fn init_epoch(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
        self.state_table.init_epoch(epoch).await
    }

    /// Update watermark for the managed state table.
    pub fn update_watermark(&mut self, watermark: ScalarImpl) {
        self.state_table.update_watermark(watermark)
    }

    pub fn insert(&mut self, value: impl Row) {
        self.state_table.insert(value);
    }

    pub fn delete(&mut self, value: impl Row) {
        self.state_table.delete(value);
    }

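    /// Project away the group key columns of `row` and serialize the remaining pk columns
    /// into a [`CacheKey`].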
    fn get_topn_row(&self, row: OwnedRow, group_key_len: usize) -> TopNStateRow {
        let pk = (&row).project(&self.state_table.pk_indices()[group_key_len..]);
        let cache_key = serialize_pk_to_cache_key(pk, &self.cache_key_serde);

        TopNStateRow::new(cache_key, row)
    }

    /// Returns the rows in the range `[offset, offset + limit)`.
    ///
    /// If `group_key` is None, it will scan rows from the very beginning.
    /// Otherwise it will scan rows with prefix `group_key`.
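    ///
    /// For example, in the tests below, `find_range(NO_GROUP_KEY, 1, Some(2))` returns the
    /// 2nd and 3rd rows in storage-key order.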
    #[cfg(test)]
    pub async fn find_range(
        &self,
        group_key: Option<impl GroupKey>,
        offset: usize,
        limit: Option<usize>,
    ) -> StreamExecutorResult<Vec<TopNStateRow>> {
        let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
        let state_table_iter = self
            .state_table
            .iter_with_prefix(&group_key, sub_range, Default::default())
            .await?;
        pin_mut!(state_table_iter);

        // here we don't expect users to have large OFFSET.
        let (mut rows, mut stream) = if let Some(limit) = limit {
            (
                Vec::with_capacity(limit.min(1024)),
                state_table_iter.skip(offset).take(limit),
            )
        } else {
            (
                Vec::with_capacity(1024),
                state_table_iter.skip(offset).take(1024),
            )
        };
        while let Some(item) = stream.next().await {
            rows.push(self.get_topn_row(item?.into_owned_row(), group_key.len()));
        }
        Ok(rows)
    }

    /// # Arguments
    ///
    /// * `group_key` - Used as the prefix of the key to scan. Only for group TopN.
    /// * `start_key` - The start point of the key to scan. It should be the last key of the middle
    ///   cache. It doesn't contain the group key.
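    ///
    /// Rows whose cache key is less than or equal to `start_key` are skipped. When `WITH_TIES`
    /// is set, rows tying with the last key in the high cache are loaded as well, even if this
    /// exceeds `cache_size_limit`.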
    pub async fn fill_high_cache<const WITH_TIES: bool>(
        &self,
        group_key: Option<impl GroupKey>,
        topn_cache: &mut TopNCache<WITH_TIES>,
        start_key: Option<CacheKey>,
        cache_size_limit: usize,
    ) -> StreamExecutorResult<()> {
        let high_cache = &mut topn_cache.high;
        assert!(high_cache.is_empty());

        // TODO(rc): iterate from `start_key`
        let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
        let state_table_iter = self
            .state_table
            .iter_with_prefix(
                &group_key,
                sub_range,
                PrefetchOptions {
                    prefetch: cache_size_limit == usize::MAX,
                    for_large_query: false,
                },
            )
            .await?;
        pin_mut!(state_table_iter);

        let mut group_row_count = 0;

        while let Some(item) = state_table_iter.next().await {
            group_row_count += 1;

            // Note(bugen): should first compare with start key before constructing TopNStateRow.
            let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
            if let Some(start_key) = start_key.as_ref()
                && &topn_row.cache_key <= start_key
            {
                continue;
            }
            high_cache.insert(topn_row.cache_key, (&topn_row.row).into());
            if high_cache.len() == cache_size_limit {
                break;
            }
        }

        if WITH_TIES && topn_cache.high_is_full() {
            let high_last_sort_key = topn_cache.high.last_key_value().unwrap().0.0.clone();
            while let Some(item) = state_table_iter.next().await {
                group_row_count += 1;

                let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                if topn_row.cache_key.0 == high_last_sort_key {
                    topn_cache
                        .high
                        .insert(topn_row.cache_key, (&topn_row.row).into());
                } else {
                    break;
                }
            }
        }

        if state_table_iter.next().await.is_none() {
            // We can only update the row count when we have seen all rows of the group in the table.
            topn_cache.update_table_row_count(group_row_count);
        }

        Ok(())
    }

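    /// Fill the low, middle and (unless `skip_high` is set) high caches of `topn_cache` by
    /// scanning the state table with prefix `group_key`. All caches are expected to be empty.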
    pub async fn init_topn_cache_inner<const WITH_TIES: bool>(
        &self,
        group_key: Option<impl GroupKey>,
        topn_cache: &mut TopNCache<WITH_TIES>,
        skip_high: bool,
    ) -> StreamExecutorResult<()> {
        assert!(topn_cache.low.as_ref().map(Cache::is_empty).unwrap_or(true));
        assert!(topn_cache.middle.is_empty());
        assert!(topn_cache.high.is_empty());

        let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
        let state_table_iter = self
            .state_table
            .iter_with_prefix(
                &group_key,
                sub_range,
                PrefetchOptions {
                    prefetch: topn_cache.limit == usize::MAX,
                    for_large_query: false,
                },
            )
            .await?;
        pin_mut!(state_table_iter);

        let mut group_row_count = 0;

        if let Some(low) = &mut topn_cache.low {
            while let Some(item) = state_table_iter.next().await {
                group_row_count += 1;
                let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                low.insert(topn_row.cache_key, (&topn_row.row).into());
                if low.len() == topn_cache.offset {
                    break;
                }
            }
        }

        assert!(topn_cache.limit > 0, "topn cache limit should always be > 0");
        while let Some(item) = state_table_iter.next().await {
            group_row_count += 1;
            let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
            topn_cache
                .middle
                .insert(topn_row.cache_key, (&topn_row.row).into());
            if topn_cache.middle.len() == topn_cache.limit {
                break;
            }
        }
        if WITH_TIES && topn_cache.middle_is_full() {
            let middle_last_sort_key = topn_cache.middle.last_key_value().unwrap().0.0.clone();
            while let Some(item) = state_table_iter.next().await {
                group_row_count += 1;
                let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                if topn_row.cache_key.0 == middle_last_sort_key {
                    topn_cache
                        .middle
                        .insert(topn_row.cache_key, (&topn_row.row).into());
                } else {
                    topn_cache
                        .high
                        .insert(topn_row.cache_key, (&topn_row.row).into());
                    break;
                }
            }
        }

        if !skip_high {
            assert!(
                topn_cache.high_cache_capacity > 0,
                "topn cache high_capacity should always be > 0"
            );
            while !topn_cache.high_is_full()
                && let Some(item) = state_table_iter.next().await
            {
                group_row_count += 1;
                let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                topn_cache
                    .high
                    .insert(topn_row.cache_key, (&topn_row.row).into());
            }
            if WITH_TIES && topn_cache.high_is_full() {
                let high_last_sort_key = topn_cache.high.last_key_value().unwrap().0.0.clone();
                while let Some(item) = state_table_iter.next().await {
                    group_row_count += 1;
                    let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                    if topn_row.cache_key.0 == high_last_sort_key {
                        topn_cache
                            .high
                            .insert(topn_row.cache_key, (&topn_row.row).into());
                    } else {
                        break;
                    }
                }
            }
            if state_table_iter.next().await.is_none() {
                // The iterator is exhausted, so all table entries of the group are now in the
                // cache and we know the precise table row count.
                topn_cache.update_table_row_count(group_row_count);
            }
        } else {
            topn_cache.update_table_row_count(group_row_count);
        }

        Ok(())
    }

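    /// Initialize the TopN cache (low, middle and high) from the state table.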
    pub async fn init_topn_cache<const WITH_TIES: bool>(
        &self,
        group_key: Option<impl GroupKey>,
        topn_cache: &mut TopNCache<WITH_TIES>,
    ) -> StreamExecutorResult<()> {
        self.init_topn_cache_inner(group_key, topn_cache, false)
            .await
    }

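    /// Like [`Self::init_topn_cache`], but skips filling the high cache (`skip_high == true`).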
    pub async fn init_append_only_topn_cache<const WITH_TIES: bool>(
        &self,
        group_key: Option<impl GroupKey>,
        topn_cache: &mut TopNCache<WITH_TIES>,
    ) -> StreamExecutorResult<()> {
        self.init_topn_cache_inner(group_key, topn_cache, true)
            .await
    }

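    /// Flush buffered changes to the state table and commit the given epoch.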
    pub async fn flush(
        &mut self,
        epoch: EpochPair,
    ) -> StreamExecutorResult<StateTablePostCommit<'_, S>> {
        self.state_table.commit(epoch).await
    }

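    /// Try to flush buffered changes to the state store without committing an epoch.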
    pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
        self.state_table.try_flush().await?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_common::util::sort_util::{ColumnOrder, OrderType};

    use super::*;
    use crate::executor::test_utils::top_n_executor::create_in_memory_state_table;
    use crate::executor::top_n::top_n_cache::{TopNCacheTrait, TopNStaging};
    use crate::executor::top_n::{NO_GROUP_KEY, create_cache_key_serde};
    use crate::row_nonnull;

    fn cache_key_serde() -> CacheKeySerde {
        let data_types = vec![DataType::Varchar, DataType::Int64];
        let schema = Schema::new(data_types.into_iter().map(Field::unnamed).collect());
        let storage_key = vec![
            ColumnOrder::new(0, OrderType::ascending()),
            ColumnOrder::new(1, OrderType::ascending()),
        ];
        let order_by = vec![ColumnOrder::new(0, OrderType::ascending())];

        create_cache_key_serde(&storage_key, &schema, &order_by, &[])
    }

    #[tokio::test]
    async fn test_managed_top_n_state() {
        let state_table = {
            let mut tb = create_in_memory_state_table(
                &[DataType::Varchar, DataType::Int64],
                &[OrderType::ascending(), OrderType::ascending()],
                &[0, 1],
            )
            .await;
            tb.init_epoch(EpochPair::new_test_epoch(test_epoch(1)))
                .await
                .unwrap();
            tb
        };

        let cache_key_serde = cache_key_serde();
        let mut managed_state = ManagedTopNState::new(state_table, cache_key_serde.clone());

        let row1 = row_nonnull!["abc", 2i64];
        let row2 = row_nonnull!["abc", 3i64];
        let row3 = row_nonnull!["abd", 3i64];
        let row4 = row_nonnull!["ab", 4i64];

        let row1_bytes = serialize_pk_to_cache_key(row1.clone(), &cache_key_serde);
        let row2_bytes = serialize_pk_to_cache_key(row2.clone(), &cache_key_serde);
        let row3_bytes = serialize_pk_to_cache_key(row3.clone(), &cache_key_serde);
        let row4_bytes = serialize_pk_to_cache_key(row4.clone(), &cache_key_serde);
        let rows = [row1, row2, row3, row4];
        let ordered_rows = [row1_bytes, row2_bytes, row3_bytes, row4_bytes];
        managed_state.insert(rows[3].clone());

        // now ("ab", 4)
        let valid_rows = managed_state
            .find_range(NO_GROUP_KEY, 0, Some(1))
            .await
            .unwrap();

        assert_eq!(valid_rows.len(), 1);
        assert_eq!(valid_rows[0].cache_key, ordered_rows[3].clone());

        managed_state.insert(rows[2].clone());
        let valid_rows = managed_state
            .find_range(NO_GROUP_KEY, 1, Some(1))
            .await
            .unwrap();
        assert_eq!(valid_rows.len(), 1);
        assert_eq!(valid_rows[0].cache_key, ordered_rows[2].clone());

        managed_state.insert(rows[1].clone());

        let valid_rows = managed_state
            .find_range(NO_GROUP_KEY, 1, Some(2))
            .await
            .unwrap();
        assert_eq!(valid_rows.len(), 2);
        assert_eq!(
            valid_rows.first().unwrap().cache_key,
            ordered_rows[1].clone()
        );
        assert_eq!(
            valid_rows.last().unwrap().cache_key,
            ordered_rows[2].clone()
        );

        // delete ("abc", 3)
        managed_state.delete(rows[1].clone());

        // insert ("abc", 2)
        managed_state.insert(rows[0].clone());

        let valid_rows = managed_state
            .find_range(NO_GROUP_KEY, 0, Some(3))
            .await
            .unwrap();

        assert_eq!(valid_rows.len(), 3);
        assert_eq!(valid_rows[0].cache_key, ordered_rows[3].clone());
        assert_eq!(valid_rows[1].cache_key, ordered_rows[0].clone());
        assert_eq!(valid_rows[2].cache_key, ordered_rows[2].clone());
    }

    #[tokio::test]
    async fn test_managed_top_n_state_fill_cache() {
        let data_types = vec![DataType::Varchar, DataType::Int64];
        let state_table = {
            let mut tb = create_in_memory_state_table(
                &data_types,
                &[OrderType::ascending(), OrderType::ascending()],
                &[0, 1],
            )
            .await;
            tb.init_epoch(EpochPair::new_test_epoch(test_epoch(1)))
                .await
                .unwrap();
            tb
        };

        let cache_key_serde = cache_key_serde();
        let mut managed_state = ManagedTopNState::new(state_table, cache_key_serde.clone());

        let row1 = row_nonnull!["abc", 2i64];
        let row2 = row_nonnull!["abc", 3i64];
        let row3 = row_nonnull!["abd", 3i64];
        let row4 = row_nonnull!["ab", 4i64];
        let row5 = row_nonnull!["abcd", 5i64];

        let row1_bytes = serialize_pk_to_cache_key(row1.clone(), &cache_key_serde);
        let row2_bytes = serialize_pk_to_cache_key(row2.clone(), &cache_key_serde);
        let row3_bytes = serialize_pk_to_cache_key(row3.clone(), &cache_key_serde);
        let row4_bytes = serialize_pk_to_cache_key(row4.clone(), &cache_key_serde);
        let row5_bytes = serialize_pk_to_cache_key(row5.clone(), &cache_key_serde);
        let rows = [row1, row2, row3, row4, row5];
        let ordered_rows = vec![row1_bytes, row2_bytes, row3_bytes, row4_bytes, row5_bytes];

        let mut cache = TopNCache::<false>::new(1, 1, data_types);

        managed_state.insert(rows[3].clone());
        managed_state.insert(rows[1].clone());
        managed_state.insert(rows[2].clone());
        managed_state.insert(rows[4].clone());

        managed_state
            .fill_high_cache(NO_GROUP_KEY, &mut cache, Some(ordered_rows[3].clone()), 2)
            .await
            .unwrap();
        assert_eq!(cache.high.len(), 2);
        assert_eq!(cache.high.first_key_value().unwrap().0, &ordered_rows[1]);
        assert_eq!(cache.high.last_key_value().unwrap().0, &ordered_rows[4]);
    }

    #[tokio::test]
    async fn test_top_n_cache_limit_1() {
        let data_types = vec![DataType::Varchar, DataType::Int64];
        let state_table = {
            let mut tb = create_in_memory_state_table(
                &data_types,
                &[OrderType::ascending(), OrderType::ascending()],
                &[0, 1],
            )
            .await;
            tb.init_epoch(EpochPair::new_test_epoch(test_epoch(1)))
                .await
                .unwrap();
            tb
        };

        let cache_key_serde = cache_key_serde();
        let mut managed_state = ManagedTopNState::new(state_table, cache_key_serde.clone());

        let row1 = row_nonnull!["abc", 2i64];
        let row1_bytes = serialize_pk_to_cache_key(row1.clone(), &cache_key_serde);

        let mut cache = TopNCache::<true>::new(0, 1, data_types);
        cache.insert(row1_bytes.clone(), row1.clone(), &mut TopNStaging::new());
        cache
            .delete(
                NO_GROUP_KEY,
                &mut managed_state,
                row1_bytes,
                row1,
                &mut TopNStaging::new(),
            )
            .await
            .unwrap();
    }
}