risingwave_stream/executor/top_n/
top_n_state.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::ops::Bound;
16
17use futures::{StreamExt, pin_mut};
18use risingwave_common::row::{OwnedRow, Row, RowExt};
19use risingwave_common::types::ScalarImpl;
20use risingwave_common::util::epoch::EpochPair;
21use risingwave_storage::StateStore;
22use risingwave_storage::store::PrefetchOptions;
23
24use super::top_n_cache::CacheKey;
25use super::{CacheKeySerde, GroupKey, TopNCache, serialize_pk_to_cache_key};
26use crate::common::table::state_table::{StateTable, StateTablePostCommit};
27use crate::executor::error::StreamExecutorResult;
28use crate::executor::top_n::top_n_cache::Cache;
29
30/// * For TopN, the storage key is: `[ order_by + remaining columns of pk ]`
31/// * For group TopN, the storage key is: `[ group_key + order_by + remaining columns of pk ]`
32///
33/// The key in [`TopNCache`] is [`CacheKey`], which is `[ order_by|remaining columns of pk ]`, and
34/// `group_key` is not included.
35pub struct ManagedTopNState<S: StateStore> {
36    /// Relational table.
37    state_table: StateTable<S>,
38
39    /// Used for serializing pk into `CacheKey`.
40    cache_key_serde: CacheKeySerde,
41}
42
43#[derive(Clone, PartialEq, Debug)]
44pub struct TopNStateRow {
45    pub cache_key: CacheKey,
46    pub row: OwnedRow,
47}
48
49impl TopNStateRow {
50    pub fn new(cache_key: CacheKey, row: OwnedRow) -> Self {
51        Self { cache_key, row }
52    }
53}
54
55impl<S: StateStore> ManagedTopNState<S> {
56    pub fn new(state_table: StateTable<S>, cache_key_serde: CacheKeySerde) -> Self {
57        Self {
58            state_table,
59            cache_key_serde,
60        }
61    }
62
63    /// Get the immutable reference of managed state table.
64    pub fn table(&self) -> &StateTable<S> {
65        &self.state_table
66    }
67
68    /// Init epoch for the managed state table.
69    pub async fn init_epoch(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
70        self.state_table.init_epoch(epoch).await
71    }
72
73    /// Update watermark for the managed state table.
74    pub fn update_watermark(&mut self, watermark: ScalarImpl) {
75        self.state_table.update_watermark(watermark)
76    }
77
78    pub fn insert(&mut self, value: impl Row) {
79        self.state_table.insert(value);
80    }
81
82    pub fn delete(&mut self, value: impl Row) {
83        self.state_table.delete(value);
84    }
85
86    fn get_topn_row(&self, row: OwnedRow, group_key_len: usize) -> TopNStateRow {
87        let pk = (&row).project(&self.state_table.pk_indices()[group_key_len..]);
88        let cache_key = serialize_pk_to_cache_key(pk, &self.cache_key_serde);
89
90        TopNStateRow::new(cache_key, row)
91    }
92
93    /// This function will return the rows in the range of [`offset`, `offset` + `limit`).
94    ///
95    /// If `group_key` is None, it will scan rows from the very beginning.
96    /// Otherwise it will scan rows with prefix `group_key`.
97    #[cfg(test)]
98    pub async fn find_range(
99        &self,
100        group_key: Option<impl GroupKey>,
101        offset: usize,
102        limit: Option<usize>,
103    ) -> StreamExecutorResult<Vec<TopNStateRow>> {
104        let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
105        let state_table_iter = self
106            .state_table
107            .iter_with_prefix(&group_key, sub_range, Default::default())
108            .await?;
109        pin_mut!(state_table_iter);
110
111        // here we don't expect users to have large OFFSET.
112        let (mut rows, mut stream) = if let Some(limit) = limit {
113            (
114                Vec::with_capacity(limit.min(1024)),
115                state_table_iter.skip(offset).take(limit),
116            )
117        } else {
118            (
119                Vec::with_capacity(1024),
120                state_table_iter.skip(offset).take(1024),
121            )
122        };
123        while let Some(item) = stream.next().await {
124            rows.push(self.get_topn_row(item?.into_owned_row(), group_key.len()));
125        }
126        Ok(rows)
127    }
128
129    /// # Arguments
130    ///
131    /// * `group_key` - Used as the prefix of the key to scan. Only for group TopN.
132    /// * `start_key` - The start point of the key to scan. It should be the last key of the middle
133    ///   cache. It doesn't contain the group key.
134    pub async fn fill_high_cache<const WITH_TIES: bool>(
135        &self,
136        group_key: Option<impl GroupKey>,
137        topn_cache: &mut TopNCache<WITH_TIES>,
138        start_key: Option<CacheKey>,
139        cache_size_limit: usize,
140    ) -> StreamExecutorResult<()> {
141        let high_cache = &mut topn_cache.high;
142        assert!(high_cache.is_empty());
143
144        // TODO(rc): iterate from `start_key`
145        let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
146        let state_table_iter = self
147            .state_table
148            .iter_with_prefix(
149                &group_key,
150                sub_range,
151                PrefetchOptions {
152                    prefetch: cache_size_limit == usize::MAX,
153                    for_large_query: false,
154                },
155            )
156            .await?;
157        pin_mut!(state_table_iter);
158
159        let mut group_row_count = 0;
160
161        while let Some(item) = state_table_iter.next().await {
162            group_row_count += 1;
163
164            // Note(bugen): should first compare with start key before constructing TopNStateRow.
165            let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
166            if let Some(start_key) = start_key.as_ref()
167                && &topn_row.cache_key <= start_key
168            {
169                continue;
170            }
171            high_cache.insert(topn_row.cache_key, (&topn_row.row).into());
172            if high_cache.len() == cache_size_limit {
173                break;
174            }
175        }
176
177        if WITH_TIES && topn_cache.high_is_full() {
178            let high_last_sort_key = topn_cache.high.last_key_value().unwrap().0.0.clone();
179            while let Some(item) = state_table_iter.next().await {
180                group_row_count += 1;
181
182                let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
183                if topn_row.cache_key.0 == high_last_sort_key {
184                    topn_cache
185                        .high
186                        .insert(topn_row.cache_key, (&topn_row.row).into());
187                } else {
188                    break;
189                }
190            }
191        }
192
193        if state_table_iter.next().await.is_none() {
194            // We can only update the row count when we have seen all rows of the group in the table.
195            topn_cache.update_table_row_count(group_row_count);
196        }
197
198        Ok(())
199    }
200
201    pub async fn init_topn_cache<const WITH_TIES: bool>(
202        &self,
203        group_key: Option<impl GroupKey>,
204        topn_cache: &mut TopNCache<WITH_TIES>,
205    ) -> StreamExecutorResult<()> {
206        assert!(topn_cache.low.as_ref().map(Cache::is_empty).unwrap_or(true));
207        assert!(topn_cache.middle.is_empty());
208        assert!(topn_cache.high.is_empty());
209
210        let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
211        let state_table_iter = self
212            .state_table
213            .iter_with_prefix(
214                &group_key,
215                sub_range,
216                PrefetchOptions {
217                    prefetch: topn_cache.limit == usize::MAX,
218                    for_large_query: false,
219                },
220            )
221            .await?;
222        pin_mut!(state_table_iter);
223
224        let mut group_row_count = 0;
225
226        if let Some(low) = &mut topn_cache.low {
227            while let Some(item) = state_table_iter.next().await {
228                group_row_count += 1;
229                let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
230                low.insert(topn_row.cache_key, (&topn_row.row).into());
231                if low.len() == topn_cache.offset {
232                    break;
233                }
234            }
235        }
236
237        assert!(topn_cache.limit > 0, "topn cache limit should always > 0");
238        while let Some(item) = state_table_iter.next().await {
239            group_row_count += 1;
240            let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
241            topn_cache
242                .middle
243                .insert(topn_row.cache_key, (&topn_row.row).into());
244            if topn_cache.middle.len() == topn_cache.limit {
245                break;
246            }
247        }
248        if WITH_TIES && topn_cache.middle_is_full() {
249            let middle_last_sort_key = topn_cache.middle.last_key_value().unwrap().0.0.clone();
250            while let Some(item) = state_table_iter.next().await {
251                group_row_count += 1;
252                let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
253                if topn_row.cache_key.0 == middle_last_sort_key {
254                    topn_cache
255                        .middle
256                        .insert(topn_row.cache_key, (&topn_row.row).into());
257                } else {
258                    topn_cache
259                        .high
260                        .insert(topn_row.cache_key, (&topn_row.row).into());
261                    break;
262                }
263            }
264        }
265
266        assert!(
267            topn_cache.high_cache_capacity > 0,
268            "topn cache high_capacity should always > 0"
269        );
270        while !topn_cache.high_is_full()
271            && let Some(item) = state_table_iter.next().await
272        {
273            group_row_count += 1;
274            let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
275            topn_cache
276                .high
277                .insert(topn_row.cache_key, (&topn_row.row).into());
278        }
279        if WITH_TIES && topn_cache.high_is_full() {
280            let high_last_sort_key = topn_cache.high.last_key_value().unwrap().0.0.clone();
281            while let Some(item) = state_table_iter.next().await {
282                group_row_count += 1;
283                let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
284                if topn_row.cache_key.0 == high_last_sort_key {
285                    topn_cache
286                        .high
287                        .insert(topn_row.cache_key, (&topn_row.row).into());
288                } else {
289                    break;
290                }
291            }
292        }
293
294        if state_table_iter.next().await.is_none() {
295            // After trying to initially fill in the cache, all table entries are in the cache,
296            // we then get the precise table row count.
297            topn_cache.update_table_row_count(group_row_count);
298        }
299
300        Ok(())
301    }
302
303    pub async fn flush(
304        &mut self,
305        epoch: EpochPair,
306    ) -> StreamExecutorResult<StateTablePostCommit<'_, S>> {
307        self.state_table.commit(epoch).await
308    }
309
310    pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
311        self.state_table.try_flush().await?;
312        Ok(())
313    }
314}
315
316#[cfg(test)]
317mod tests {
318    use risingwave_common::catalog::{Field, Schema};
319    use risingwave_common::types::DataType;
320    use risingwave_common::util::epoch::test_epoch;
321    use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
322
323    use super::*;
324    use crate::executor::test_utils::top_n_executor::create_in_memory_state_table;
325    use crate::executor::top_n::top_n_cache::{TopNCacheTrait, TopNStaging};
326    use crate::executor::top_n::{NO_GROUP_KEY, create_cache_key_serde};
327    use crate::row_nonnull;
328
329    fn cache_key_serde() -> CacheKeySerde {
330        let data_types = vec![DataType::Varchar, DataType::Int64];
331        let schema = Schema::new(data_types.into_iter().map(Field::unnamed).collect());
332        let storage_key = vec![
333            ColumnOrder::new(0, OrderType::ascending()),
334            ColumnOrder::new(1, OrderType::ascending()),
335        ];
336        let order_by = vec![ColumnOrder::new(0, OrderType::ascending())];
337
338        create_cache_key_serde(&storage_key, &schema, &order_by, &[])
339    }
340
341    #[tokio::test]
342    async fn test_managed_top_n_state() {
343        let state_table = {
344            let mut tb = create_in_memory_state_table(
345                &[DataType::Varchar, DataType::Int64],
346                &[OrderType::ascending(), OrderType::ascending()],
347                &[0, 1],
348            )
349            .await;
350            tb.init_epoch(EpochPair::new_test_epoch(test_epoch(1)))
351                .await
352                .unwrap();
353            tb
354        };
355
356        let cache_key_serde = cache_key_serde();
357        let mut managed_state = ManagedTopNState::new(state_table, cache_key_serde.clone());
358
359        let row1 = row_nonnull!["abc", 2i64];
360        let row2 = row_nonnull!["abc", 3i64];
361        let row3 = row_nonnull!["abd", 3i64];
362        let row4 = row_nonnull!["ab", 4i64];
363
364        let row1_bytes = serialize_pk_to_cache_key(row1.clone(), &cache_key_serde);
365        let row2_bytes = serialize_pk_to_cache_key(row2.clone(), &cache_key_serde);
366        let row3_bytes = serialize_pk_to_cache_key(row3.clone(), &cache_key_serde);
367        let row4_bytes = serialize_pk_to_cache_key(row4.clone(), &cache_key_serde);
368        let rows = [row1, row2, row3, row4];
369        let ordered_rows = [row1_bytes, row2_bytes, row3_bytes, row4_bytes];
370        managed_state.insert(rows[3].clone());
371
372        // now ("ab", 4)
373        let valid_rows = managed_state
374            .find_range(NO_GROUP_KEY, 0, Some(1))
375            .await
376            .unwrap();
377
378        assert_eq!(valid_rows.len(), 1);
379        assert_eq!(valid_rows[0].cache_key, ordered_rows[3].clone());
380
381        managed_state.insert(rows[2].clone());
382        let valid_rows = managed_state
383            .find_range(NO_GROUP_KEY, 1, Some(1))
384            .await
385            .unwrap();
386        assert_eq!(valid_rows.len(), 1);
387        assert_eq!(valid_rows[0].cache_key, ordered_rows[2].clone());
388
389        managed_state.insert(rows[1].clone());
390
391        let valid_rows = managed_state
392            .find_range(NO_GROUP_KEY, 1, Some(2))
393            .await
394            .unwrap();
395        assert_eq!(valid_rows.len(), 2);
396        assert_eq!(
397            valid_rows.first().unwrap().cache_key,
398            ordered_rows[1].clone()
399        );
400        assert_eq!(
401            valid_rows.last().unwrap().cache_key,
402            ordered_rows[2].clone()
403        );
404
405        // delete ("abc", 3)
406        managed_state.delete(rows[1].clone());
407
408        // insert ("abc", 2)
409        managed_state.insert(rows[0].clone());
410
411        let valid_rows = managed_state
412            .find_range(NO_GROUP_KEY, 0, Some(3))
413            .await
414            .unwrap();
415
416        assert_eq!(valid_rows.len(), 3);
417        assert_eq!(valid_rows[0].cache_key, ordered_rows[3].clone());
418        assert_eq!(valid_rows[1].cache_key, ordered_rows[0].clone());
419        assert_eq!(valid_rows[2].cache_key, ordered_rows[2].clone());
420    }
421
422    #[tokio::test]
423    async fn test_managed_top_n_state_fill_cache() {
424        let data_types = vec![DataType::Varchar, DataType::Int64];
425        let state_table = {
426            let mut tb = create_in_memory_state_table(
427                &data_types,
428                &[OrderType::ascending(), OrderType::ascending()],
429                &[0, 1],
430            )
431            .await;
432            tb.init_epoch(EpochPair::new_test_epoch(test_epoch(1)))
433                .await
434                .unwrap();
435            tb
436        };
437
438        let cache_key_serde = cache_key_serde();
439        let mut managed_state = ManagedTopNState::new(state_table, cache_key_serde.clone());
440
441        let row1 = row_nonnull!["abc", 2i64];
442        let row2 = row_nonnull!["abc", 3i64];
443        let row3 = row_nonnull!["abd", 3i64];
444        let row4 = row_nonnull!["ab", 4i64];
445        let row5 = row_nonnull!["abcd", 5i64];
446
447        let row1_bytes = serialize_pk_to_cache_key(row1.clone(), &cache_key_serde);
448        let row2_bytes = serialize_pk_to_cache_key(row2.clone(), &cache_key_serde);
449        let row3_bytes = serialize_pk_to_cache_key(row3.clone(), &cache_key_serde);
450        let row4_bytes = serialize_pk_to_cache_key(row4.clone(), &cache_key_serde);
451        let row5_bytes = serialize_pk_to_cache_key(row5.clone(), &cache_key_serde);
452        let rows = [row1, row2, row3, row4, row5];
453        let ordered_rows = vec![row1_bytes, row2_bytes, row3_bytes, row4_bytes, row5_bytes];
454
455        let mut cache = TopNCache::<false>::new(1, 1, data_types);
456
457        managed_state.insert(rows[3].clone());
458        managed_state.insert(rows[1].clone());
459        managed_state.insert(rows[2].clone());
460        managed_state.insert(rows[4].clone());
461
462        managed_state
463            .fill_high_cache(NO_GROUP_KEY, &mut cache, Some(ordered_rows[3].clone()), 2)
464            .await
465            .unwrap();
466        assert_eq!(cache.high.len(), 2);
467        assert_eq!(cache.high.first_key_value().unwrap().0, &ordered_rows[1]);
468        assert_eq!(cache.high.last_key_value().unwrap().0, &ordered_rows[4]);
469    }
470
471    #[tokio::test]
472    async fn test_top_n_cache_limit_1() {
473        let data_types = vec![DataType::Varchar, DataType::Int64];
474        let state_table = {
475            let mut tb = create_in_memory_state_table(
476                &data_types,
477                &[OrderType::ascending(), OrderType::ascending()],
478                &[0, 1],
479            )
480            .await;
481            tb.init_epoch(EpochPair::new_test_epoch(test_epoch(1)))
482                .await
483                .unwrap();
484            tb
485        };
486
487        let cache_key_serde = cache_key_serde();
488        let mut managed_state = ManagedTopNState::new(state_table, cache_key_serde.clone());
489
490        let row1 = row_nonnull!["abc", 2i64];
491        let row1_bytes = serialize_pk_to_cache_key(row1.clone(), &cache_key_serde);
492
493        let mut cache = TopNCache::<true>::new(0, 1, data_types);
494        cache.insert(row1_bytes.clone(), row1.clone(), &mut TopNStaging::new());
495        cache
496            .delete(
497                NO_GROUP_KEY,
498                &mut managed_state,
499                row1_bytes,
500                row1,
501                &mut TopNStaging::new(),
502            )
503            .await
504            .unwrap();
505    }
506}