use std::ops::Bound;

use futures::{StreamExt, pin_mut};
use risingwave_common::row::{OwnedRow, Row, RowExt};
use risingwave_common::types::ScalarImpl;
use risingwave_common::util::epoch::EpochPair;
use risingwave_storage::StateStore;
use risingwave_storage::store::PrefetchOptions;

use super::top_n_cache::CacheKey;
use super::{CacheKeySerde, GroupKey, TopNCache, serialize_pk_to_cache_key};
use crate::common::table::state_table::{StateTable, StateTablePostCommit};
use crate::executor::error::StreamExecutorResult;
use crate::executor::top_n::top_n_cache::Cache;

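/// `ManagedTopNState` wraps the [`StateTable`] that backs the streaming TopN executors and
/// converts stored rows into [`CacheKey`]s for the in-memory [`TopNCache`].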
pub struct ManagedTopNState<S: StateStore> {
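    /// The relational state table that persists the rows maintained by the TopN executor.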
    state_table: StateTable<S>,

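    /// Used to serialize the pk columns (after the group key) into a [`CacheKey`].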
    cache_key_serde: CacheKeySerde,
}

#[derive(Clone, PartialEq, Debug)]
pub struct TopNStateRow {
    pub cache_key: CacheKey,
    pub row: OwnedRow,
}

impl TopNStateRow {
    pub fn new(cache_key: CacheKey, row: OwnedRow) -> Self {
        Self { cache_key, row }
    }
}

impl<S: StateStore> ManagedTopNState<S> {
    pub fn new(state_table: StateTable<S>, cache_key_serde: CacheKeySerde) -> Self {
        Self {
            state_table,
            cache_key_serde,
        }
    }

    pub fn table(&self) -> &StateTable<S> {
        &self.state_table
    }

    pub async fn init_epoch(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
        self.state_table.init_epoch(epoch).await
    }

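    /// Update the state table watermark so that state below it can be cleaned up.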
    pub fn update_watermark(&mut self, watermark: ScalarImpl) {
        self.state_table.update_watermark(watermark)
    }

    pub fn insert(&mut self, value: impl Row) {
        self.state_table.insert(value);
    }

    pub fn delete(&mut self, value: impl Row) {
        self.state_table.delete(value);
    }

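    /// Convert a row read from the state table into a [`TopNStateRow`] by serializing the
    /// pk columns after the group key into a [`CacheKey`].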
    fn get_topn_row(&self, row: OwnedRow, group_key_len: usize) -> TopNStateRow {
        let pk = (&row).project(&self.state_table.pk_indices()[group_key_len..]);
        let cache_key = serialize_pk_to_cache_key(pk, &self.cache_key_serde);

        TopNStateRow::new(cache_key, row)
    }

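    /// Test-only helper that scans the given group and returns the rows in the range
    /// `[offset, offset + limit)`, or at most 1024 rows when `limit` is `None`.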
    #[cfg(test)]
    pub async fn find_range(
        &self,
        group_key: Option<impl GroupKey>,
        offset: usize,
        limit: Option<usize>,
    ) -> StreamExecutorResult<Vec<TopNStateRow>> {
        let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
        let state_table_iter = self
            .state_table
            .iter_with_prefix(&group_key, sub_range, Default::default())
            .await?;
        pin_mut!(state_table_iter);

        let (mut rows, mut stream) = if let Some(limit) = limit {
            (
                Vec::with_capacity(limit.min(1024)),
                state_table_iter.skip(offset).take(limit),
            )
        } else {
            (
                Vec::with_capacity(1024),
                state_table_iter.skip(offset).take(1024),
            )
        };
        while let Some(item) = stream.next().await {
            rows.push(self.get_topn_row(item?.into_owned_row(), group_key.len()));
        }
        Ok(rows)
    }

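    /// Fill the high part of `topn_cache` from the state table.
    ///
    /// Rows whose cache key is not greater than `start_key` are skipped, and at most
    /// `cache_size_limit` rows are inserted. With `WITH_TIES`, rows that tie with the last
    /// cached sort key are also kept. If the scan exhausts the group, the precise per-group
    /// row count is recorded in the cache.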
    pub async fn fill_high_cache<const WITH_TIES: bool>(
        &self,
        group_key: Option<impl GroupKey>,
        topn_cache: &mut TopNCache<WITH_TIES>,
        start_key: Option<CacheKey>,
        cache_size_limit: usize,
    ) -> StreamExecutorResult<()> {
        let high_cache = &mut topn_cache.high;
        assert!(high_cache.is_empty());

        let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
        let state_table_iter = self
            .state_table
            .iter_with_prefix(
                &group_key,
                sub_range,
                PrefetchOptions {
                    prefetch: cache_size_limit == usize::MAX,
                    for_large_query: false,
                },
            )
            .await?;
        pin_mut!(state_table_iter);

        let mut group_row_count = 0;

        while let Some(item) = state_table_iter.next().await {
            group_row_count += 1;

            let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
            if let Some(start_key) = start_key.as_ref()
                && &topn_row.cache_key <= start_key
            {
                continue;
            }
            high_cache.insert(topn_row.cache_key, (&topn_row.row).into());
            if high_cache.len() == cache_size_limit {
                break;
            }
        }

        if WITH_TIES && topn_cache.high_is_full() {
            let high_last_sort_key = topn_cache.high.last_key_value().unwrap().0.0.clone();
            while let Some(item) = state_table_iter.next().await {
                group_row_count += 1;

                let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                if topn_row.cache_key.0 == high_last_sort_key {
                    topn_cache
                        .high
                        .insert(topn_row.cache_key, (&topn_row.row).into());
                } else {
                    break;
                }
            }
        }

        if state_table_iter.next().await.is_none() {
            topn_cache.update_table_row_count(group_row_count);
        }

        Ok(())
    }

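    /// Scan the group from the state table to initialize `topn_cache`: fill the low cache up
    /// to `offset`, then the middle cache up to `limit` (plus ties when `WITH_TIES`), and
    /// finally the high cache unless `skip_high` is set. If the whole group is scanned, the
    /// precise per-group row count is recorded in the cache.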
    pub async fn init_topn_cache_inner<const WITH_TIES: bool>(
        &self,
        group_key: Option<impl GroupKey>,
        topn_cache: &mut TopNCache<WITH_TIES>,
        skip_high: bool,
    ) -> StreamExecutorResult<()> {
        assert!(topn_cache.low.as_ref().map(Cache::is_empty).unwrap_or(true));
        assert!(topn_cache.middle.is_empty());
        assert!(topn_cache.high.is_empty());

        let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
        let state_table_iter = self
            .state_table
            .iter_with_prefix(
                &group_key,
                sub_range,
                PrefetchOptions {
                    prefetch: topn_cache.limit == usize::MAX,
                    for_large_query: false,
                },
            )
            .await?;
        pin_mut!(state_table_iter);

        let mut group_row_count = 0;

        if let Some(low) = &mut topn_cache.low {
            while let Some(item) = state_table_iter.next().await {
                group_row_count += 1;
                let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                low.insert(topn_row.cache_key, (&topn_row.row).into());
                if low.len() == topn_cache.offset {
                    break;
                }
            }
        }

        assert!(topn_cache.limit > 0, "topn cache limit should always be > 0");
        while let Some(item) = state_table_iter.next().await {
            group_row_count += 1;
            let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
            topn_cache
                .middle
                .insert(topn_row.cache_key, (&topn_row.row).into());
            if topn_cache.middle.len() == topn_cache.limit {
                break;
            }
        }
        if WITH_TIES && topn_cache.middle_is_full() {
            let middle_last_sort_key = topn_cache.middle.last_key_value().unwrap().0.0.clone();
            while let Some(item) = state_table_iter.next().await {
                group_row_count += 1;
                let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                if topn_row.cache_key.0 == middle_last_sort_key {
                    topn_cache
                        .middle
                        .insert(topn_row.cache_key, (&topn_row.row).into());
                } else {
                    topn_cache
                        .high
                        .insert(topn_row.cache_key, (&topn_row.row).into());
                    break;
                }
            }
        }

        if !skip_high {
            assert!(
                topn_cache.high_cache_capacity > 0,
                "topn cache high_capacity should always be > 0"
            );
            while !topn_cache.high_is_full()
                && let Some(item) = state_table_iter.next().await
            {
                group_row_count += 1;
                let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                topn_cache
                    .high
                    .insert(topn_row.cache_key, (&topn_row.row).into());
            }
            if WITH_TIES && topn_cache.high_is_full() {
                let high_last_sort_key = topn_cache.high.last_key_value().unwrap().0.0.clone();
                while let Some(item) = state_table_iter.next().await {
                    group_row_count += 1;
                    let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                    if topn_row.cache_key.0 == high_last_sort_key {
                        topn_cache
                            .high
                            .insert(topn_row.cache_key, (&topn_row.row).into());
                    } else {
                        break;
                    }
                }
            }
            if state_table_iter.next().await.is_none() {
                topn_cache.update_table_row_count(group_row_count);
            }
        } else {
            topn_cache.update_table_row_count(group_row_count);
        }

        Ok(())
    }

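    /// Initialize `topn_cache` from the state table, including the high cache.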
    pub async fn init_topn_cache<const WITH_TIES: bool>(
        &self,
        group_key: Option<impl GroupKey>,
        topn_cache: &mut TopNCache<WITH_TIES>,
    ) -> StreamExecutorResult<()> {
        self.init_topn_cache_inner(group_key, topn_cache, false)
            .await
    }

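    /// Initialize `topn_cache` while skipping the high cache, for the append-only TopN
    /// executors, which never need to refill the middle cache from the high cache since
    /// rows are never deleted.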
    pub async fn init_append_only_topn_cache<const WITH_TIES: bool>(
        &self,
        group_key: Option<impl GroupKey>,
        topn_cache: &mut TopNCache<WITH_TIES>,
    ) -> StreamExecutorResult<()> {
        self.init_topn_cache_inner(group_key, topn_cache, true)
            .await
    }

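    /// Commit the pending state changes to the state table at the given epoch.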
    pub async fn flush(
        &mut self,
        epoch: EpochPair,
    ) -> StreamExecutorResult<StateTablePostCommit<'_, S>> {
        self.state_table.commit(epoch).await
    }

    pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
        self.state_table.try_flush().await?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_common::util::sort_util::{ColumnOrder, OrderType};

    use super::*;
    use crate::executor::test_utils::top_n_executor::create_in_memory_state_table;
    use crate::executor::top_n::top_n_cache::{TopNCacheTrait, TopNStaging};
    use crate::executor::top_n::{NO_GROUP_KEY, create_cache_key_serde};
    use crate::row_nonnull;

    fn cache_key_serde() -> CacheKeySerde {
        let data_types = vec![DataType::Varchar, DataType::Int64];
        let schema = Schema::new(data_types.into_iter().map(Field::unnamed).collect());
        let storage_key = vec![
            ColumnOrder::new(0, OrderType::ascending()),
            ColumnOrder::new(1, OrderType::ascending()),
        ];
        let order_by = vec![ColumnOrder::new(0, OrderType::ascending())];

        create_cache_key_serde(&storage_key, &schema, &order_by, &[])
    }

    #[tokio::test]
    async fn test_managed_top_n_state() {
        let state_table = {
            let mut tb = create_in_memory_state_table(
                &[DataType::Varchar, DataType::Int64],
                &[OrderType::ascending(), OrderType::ascending()],
                &[0, 1],
            )
            .await;
            tb.init_epoch(EpochPair::new_test_epoch(test_epoch(1)))
                .await
                .unwrap();
            tb
        };

        let cache_key_serde = cache_key_serde();
        let mut managed_state = ManagedTopNState::new(state_table, cache_key_serde.clone());

        let row1 = row_nonnull!["abc", 2i64];
        let row2 = row_nonnull!["abc", 3i64];
        let row3 = row_nonnull!["abd", 3i64];
        let row4 = row_nonnull!["ab", 4i64];

        let row1_bytes = serialize_pk_to_cache_key(row1.clone(), &cache_key_serde);
        let row2_bytes = serialize_pk_to_cache_key(row2.clone(), &cache_key_serde);
        let row3_bytes = serialize_pk_to_cache_key(row3.clone(), &cache_key_serde);
        let row4_bytes = serialize_pk_to_cache_key(row4.clone(), &cache_key_serde);
        let rows = [row1, row2, row3, row4];
        let ordered_rows = [row1_bytes, row2_bytes, row3_bytes, row4_bytes];
        managed_state.insert(rows[3].clone());

        let valid_rows = managed_state
            .find_range(NO_GROUP_KEY, 0, Some(1))
            .await
            .unwrap();

        assert_eq!(valid_rows.len(), 1);
        assert_eq!(valid_rows[0].cache_key, ordered_rows[3].clone());

        managed_state.insert(rows[2].clone());
        let valid_rows = managed_state
            .find_range(NO_GROUP_KEY, 1, Some(1))
            .await
            .unwrap();
        assert_eq!(valid_rows.len(), 1);
        assert_eq!(valid_rows[0].cache_key, ordered_rows[2].clone());

        managed_state.insert(rows[1].clone());

        let valid_rows = managed_state
            .find_range(NO_GROUP_KEY, 1, Some(2))
            .await
            .unwrap();
        assert_eq!(valid_rows.len(), 2);
        assert_eq!(
            valid_rows.first().unwrap().cache_key,
            ordered_rows[1].clone()
        );
        assert_eq!(
            valid_rows.last().unwrap().cache_key,
            ordered_rows[2].clone()
        );

        managed_state.delete(rows[1].clone());

        managed_state.insert(rows[0].clone());

        let valid_rows = managed_state
            .find_range(NO_GROUP_KEY, 0, Some(3))
            .await
            .unwrap();

        assert_eq!(valid_rows.len(), 3);
        assert_eq!(valid_rows[0].cache_key, ordered_rows[3].clone());
        assert_eq!(valid_rows[1].cache_key, ordered_rows[0].clone());
        assert_eq!(valid_rows[2].cache_key, ordered_rows[2].clone());
    }

    #[tokio::test]
    async fn test_managed_top_n_state_fill_cache() {
        let data_types = vec![DataType::Varchar, DataType::Int64];
        let state_table = {
            let mut tb = create_in_memory_state_table(
                &data_types,
                &[OrderType::ascending(), OrderType::ascending()],
                &[0, 1],
            )
            .await;
            tb.init_epoch(EpochPair::new_test_epoch(test_epoch(1)))
                .await
                .unwrap();
            tb
        };

        let cache_key_serde = cache_key_serde();
        let mut managed_state = ManagedTopNState::new(state_table, cache_key_serde.clone());

        let row1 = row_nonnull!["abc", 2i64];
        let row2 = row_nonnull!["abc", 3i64];
        let row3 = row_nonnull!["abd", 3i64];
        let row4 = row_nonnull!["ab", 4i64];
        let row5 = row_nonnull!["abcd", 5i64];

        let row1_bytes = serialize_pk_to_cache_key(row1.clone(), &cache_key_serde);
        let row2_bytes = serialize_pk_to_cache_key(row2.clone(), &cache_key_serde);
        let row3_bytes = serialize_pk_to_cache_key(row3.clone(), &cache_key_serde);
        let row4_bytes = serialize_pk_to_cache_key(row4.clone(), &cache_key_serde);
        let row5_bytes = serialize_pk_to_cache_key(row5.clone(), &cache_key_serde);
        let rows = [row1, row2, row3, row4, row5];
        let ordered_rows = vec![row1_bytes, row2_bytes, row3_bytes, row4_bytes, row5_bytes];

        let mut cache = TopNCache::<false>::new(1, 1, data_types);

        managed_state.insert(rows[3].clone());
        managed_state.insert(rows[1].clone());
        managed_state.insert(rows[2].clone());
        managed_state.insert(rows[4].clone());

        managed_state
            .fill_high_cache(NO_GROUP_KEY, &mut cache, Some(ordered_rows[3].clone()), 2)
            .await
            .unwrap();
        assert_eq!(cache.high.len(), 2);
        assert_eq!(cache.high.first_key_value().unwrap().0, &ordered_rows[1]);
        assert_eq!(cache.high.last_key_value().unwrap().0, &ordered_rows[4]);
    }

    #[tokio::test]
    async fn test_top_n_cache_limit_1() {
        let data_types = vec![DataType::Varchar, DataType::Int64];
        let state_table = {
            let mut tb = create_in_memory_state_table(
                &data_types,
                &[OrderType::ascending(), OrderType::ascending()],
                &[0, 1],
            )
            .await;
            tb.init_epoch(EpochPair::new_test_epoch(test_epoch(1)))
                .await
                .unwrap();
            tb
        };

        let cache_key_serde = cache_key_serde();
        let mut managed_state = ManagedTopNState::new(state_table, cache_key_serde.clone());

        let row1 = row_nonnull!["abc", 2i64];
        let row1_bytes = serialize_pk_to_cache_key(row1.clone(), &cache_key_serde);

        let mut cache = TopNCache::<true>::new(0, 1, data_types);
        cache.insert(row1_bytes.clone(), row1.clone(), &mut TopNStaging::new());
        cache
            .delete(
                NO_GROUP_KEY,
                &mut managed_state,
                row1_bytes,
                row1,
                &mut TopNStaging::new(),
            )
            .await
            .unwrap();
    }
}