1use std::ops::Bound;
16
17use futures::{StreamExt, pin_mut};
18use risingwave_common::row::{OwnedRow, Row, RowExt};
19use risingwave_common::types::ScalarImpl;
20use risingwave_common::util::epoch::EpochPair;
21use risingwave_storage::StateStore;
22use risingwave_storage::store::PrefetchOptions;
23
24use super::top_n_cache::CacheKey;
25use super::{CacheKeySerde, GroupKey, TopNCache, serialize_pk_to_cache_key};
26use crate::common::table::state_table::{StateTable, StateTablePostCommit};
27use crate::executor::error::StreamExecutorResult;
28use crate::executor::top_n::top_n_cache::Cache;
29
/// Top-N executor state backed by a relational [`StateTable`].
///
/// Rows are persisted in the table ordered by the storage primary key; this
/// wrapper converts between stored rows and the in-memory [`CacheKey`]s used
/// by [`TopNCache`].
pub struct ManagedTopNState<S: StateStore> {
    /// The relational table holding all rows of the Top-N state.
    state_table: StateTable<S>,

    /// Serde used to (de)serialize the pk columns after the group-key prefix
    /// into a `CacheKey` (see `get_topn_row`).
    cache_key_serde: CacheKeySerde,
}
42
/// A row read from the state table, paired with the `CacheKey` derived from
/// its primary key (the pk columns after the group-key prefix).
#[derive(Clone, PartialEq, Debug)]
pub struct TopNStateRow {
    /// Serialized ordering key used by the in-memory caches.
    pub cache_key: CacheKey,
    /// The full row as stored in the state table.
    pub row: OwnedRow,
}
48
49impl TopNStateRow {
50 pub fn new(cache_key: CacheKey, row: OwnedRow) -> Self {
51 Self { cache_key, row }
52 }
53}
54
impl<S: StateStore> ManagedTopNState<S> {
    /// Create a managed Top-N state over `state_table`, using
    /// `cache_key_serde` to (de)serialize the per-group ordering key.
    pub fn new(state_table: StateTable<S>, cache_key_serde: CacheKeySerde) -> Self {
        Self {
            state_table,
            cache_key_serde,
        }
    }

    /// Read-only access to the underlying state table.
    pub fn table(&self) -> &StateTable<S> {
        &self.state_table
    }

    /// Initialize the state table for the given epoch pair.
    pub async fn init_epoch(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
        self.state_table.init_epoch(epoch).await
    }

    /// Forward a state-cleaning watermark to the state table.
    pub fn update_watermark(&mut self, watermark: ScalarImpl) {
        self.state_table.update_watermark(watermark)
    }

    /// Buffer an insertion of `value` into the state table.
    pub fn insert(&mut self, value: impl Row) {
        self.state_table.insert(value);
    }

    /// Buffer a deletion of `value` from the state table.
    pub fn delete(&mut self, value: impl Row) {
        self.state_table.delete(value);
    }

    /// Convert a raw storage row into a [`TopNStateRow`] by serializing the
    /// pk columns *after* the group-key prefix into a `CacheKey`.
    fn get_topn_row(&self, row: OwnedRow, group_key_len: usize) -> TopNStateRow {
        let pk = (&row).project(&self.state_table.pk_indices()[group_key_len..]);
        let cache_key = serialize_pk_to_cache_key(pk, &self.cache_key_serde);

        TopNStateRow::new(cache_key, row)
    }

    /// Test-only helper: scan the rows of a group, skipping `offset` rows and
    /// returning at most `limit` rows. Note that when `limit` is `None`, the
    /// scan is still capped at 1024 rows.
    #[cfg(test)]
    pub async fn find_range(
        &self,
        group_key: Option<impl GroupKey>,
        offset: usize,
        limit: Option<usize>,
    ) -> StreamExecutorResult<Vec<TopNStateRow>> {
        let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
        let state_table_iter = self
            .state_table
            .iter_with_prefix(&group_key, sub_range, Default::default())
            .await?;
        pin_mut!(state_table_iter);

        // Cap both the pre-allocation and the unlimited scan at 1024 rows.
        let (mut rows, mut stream) = if let Some(limit) = limit {
            (
                Vec::with_capacity(limit.min(1024)),
                state_table_iter.skip(offset).take(limit),
            )
        } else {
            (
                Vec::with_capacity(1024),
                state_table_iter.skip(offset).take(1024),
            )
        };
        while let Some(item) = stream.next().await {
            rows.push(self.get_topn_row(item?.into_owned_row(), group_key.len()));
        }
        Ok(rows)
    }

    /// Fill `topn_cache.high` with rows read from the state table.
    ///
    /// Rows whose cache key is `<= start_key` are skipped (NOTE(review): they
    /// are presumably covered by the caches below `high` — the caller is
    /// responsible for passing the correct boundary key). At most
    /// `cache_size_limit` rows are loaded, except that with `WITH_TIES` rows
    /// tying with the last loaded sort key are pulled in beyond the limit.
    /// If the scan reaches the end of the group, the exact group row count is
    /// recorded on the cache.
    pub async fn fill_high_cache<const WITH_TIES: bool>(
        &self,
        group_key: Option<impl GroupKey>,
        topn_cache: &mut TopNCache<WITH_TIES>,
        start_key: Option<CacheKey>,
        cache_size_limit: usize,
    ) -> StreamExecutorResult<()> {
        let high_cache = &mut topn_cache.high;
        assert!(high_cache.is_empty());

        let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
        let state_table_iter = self
            .state_table
            .iter_with_prefix(
                &group_key,
                sub_range,
                PrefetchOptions {
                    // Only prefetch when the caller wants the whole group.
                    prefetch: cache_size_limit == usize::MAX,
                    for_large_query: false,
                },
            )
            .await?;
        pin_mut!(state_table_iter);

        // Number of rows of this group scanned so far.
        let mut group_row_count = 0;

        while let Some(item) = state_table_iter.next().await {
            group_row_count += 1;

            let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
            // Skip rows at or before `start_key`; only rows strictly after it
            // belong in the high cache.
            if let Some(start_key) = start_key.as_ref()
                && &topn_row.cache_key <= start_key
            {
                continue;
            }
            high_cache.insert(topn_row.cache_key, (&topn_row.row).into());
            if high_cache.len() == cache_size_limit {
                break;
            }
        }

        if WITH_TIES && topn_cache.high_is_full() {
            // Pull in every remaining row whose sort key ties with the last
            // cached one, so a tie group is never split across the cache
            // boundary.
            let high_last_sort_key = topn_cache.high.last_key_value().unwrap().0.0.clone();
            while let Some(item) = state_table_iter.next().await {
                group_row_count += 1;

                let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                if topn_row.cache_key.0 == high_last_sort_key {
                    topn_cache
                        .high
                        .insert(topn_row.cache_key, (&topn_row.row).into());
                } else {
                    break;
                }
            }
        }

        // If the iterator is now exhausted, the whole group was scanned and
        // `group_row_count` is the exact total, so record it on the cache.
        if state_table_iter.next().await.is_none() {
            topn_cache.update_table_row_count(group_row_count);
        }

        Ok(())
    }

    /// Initialize an empty [`TopNCache`] by scanning the group from the state
    /// table: fill `low` (up to `offset` rows), then `middle` (up to `limit`
    /// rows, extended over ties when `WITH_TIES`), then `high` (up to its
    /// capacity, also extended over ties). Records the exact group row count
    /// when the whole group was scanned.
    pub async fn init_topn_cache<const WITH_TIES: bool>(
        &self,
        group_key: Option<impl GroupKey>,
        topn_cache: &mut TopNCache<WITH_TIES>,
    ) -> StreamExecutorResult<()> {
        assert!(topn_cache.low.as_ref().map(Cache::is_empty).unwrap_or(true));
        assert!(topn_cache.middle.is_empty());
        assert!(topn_cache.high.is_empty());

        let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
        let state_table_iter = self
            .state_table
            .iter_with_prefix(
                &group_key,
                sub_range,
                PrefetchOptions {
                    // Only prefetch when the cache is unbounded.
                    prefetch: topn_cache.limit == usize::MAX,
                    for_large_query: false,
                },
            )
            .await?;
        pin_mut!(state_table_iter);

        // Number of rows of this group scanned so far.
        let mut group_row_count = 0;

        // 1. Fill the `low` cache with the first `offset` rows, if an offset
        // is configured.
        if let Some(low) = &mut topn_cache.low {
            while let Some(item) = state_table_iter.next().await {
                group_row_count += 1;
                let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                low.insert(topn_row.cache_key, (&topn_row.row).into());
                if low.len() == topn_cache.offset {
                    break;
                }
            }
        }

        // 2. Fill the `middle` cache with the next `limit` rows.
        assert!(topn_cache.limit > 0, "topn cache limit should always > 0");
        while let Some(item) = state_table_iter.next().await {
            group_row_count += 1;
            let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
            topn_cache
                .middle
                .insert(topn_row.cache_key, (&topn_row.row).into());
            if topn_cache.middle.len() == topn_cache.limit {
                break;
            }
        }
        if WITH_TIES && topn_cache.middle_is_full() {
            // Keep pulling rows that tie with the last `middle` sort key into
            // `middle`; the first non-tying row goes to `high` instead so it
            // is not lost.
            let middle_last_sort_key = topn_cache.middle.last_key_value().unwrap().0.0.clone();
            while let Some(item) = state_table_iter.next().await {
                group_row_count += 1;
                let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                if topn_row.cache_key.0 == middle_last_sort_key {
                    topn_cache
                        .middle
                        .insert(topn_row.cache_key, (&topn_row.row).into());
                } else {
                    topn_cache
                        .high
                        .insert(topn_row.cache_key, (&topn_row.row).into());
                    break;
                }
            }
        }

        // 3. Fill the `high` cache up to its capacity.
        assert!(
            topn_cache.high_cache_capacity > 0,
            "topn cache high_capacity should always > 0"
        );
        while !topn_cache.high_is_full()
            && let Some(item) = state_table_iter.next().await
        {
            group_row_count += 1;
            let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
            topn_cache
                .high
                .insert(topn_row.cache_key, (&topn_row.row).into());
        }
        if WITH_TIES && topn_cache.high_is_full() {
            // Extend `high` over rows tying with its last sort key, as above.
            let high_last_sort_key = topn_cache.high.last_key_value().unwrap().0.0.clone();
            while let Some(item) = state_table_iter.next().await {
                group_row_count += 1;
                let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                if topn_row.cache_key.0 == high_last_sort_key {
                    topn_cache
                        .high
                        .insert(topn_row.cache_key, (&topn_row.row).into());
                } else {
                    break;
                }
            }
        }

        // If the iterator is now exhausted, the whole group was scanned and
        // `group_row_count` is the exact total, so record it on the cache.
        if state_table_iter.next().await.is_none() {
            topn_cache.update_table_row_count(group_row_count);
        }

        Ok(())
    }

    /// Commit buffered mutations for this epoch and return the state table's
    /// post-commit handle.
    pub async fn flush(
        &mut self,
        epoch: EpochPair,
    ) -> StreamExecutorResult<StateTablePostCommit<'_, S>> {
        self.state_table.commit(epoch).await
    }

    /// Opportunistically flush buffered mutations without committing an epoch.
    pub async fn try_flush(&mut self) -> StreamExecutorResult<()> {
        self.state_table.try_flush().await?;
        Ok(())
    }
}
315
316#[cfg(test)]
317mod tests {
318 use risingwave_common::catalog::{Field, Schema};
319 use risingwave_common::types::DataType;
320 use risingwave_common::util::epoch::test_epoch;
321 use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
322
323 use super::*;
324 use crate::executor::test_utils::top_n_executor::create_in_memory_state_table;
325 use crate::executor::top_n::top_n_cache::{TopNCacheTrait, TopNStaging};
326 use crate::executor::top_n::{NO_GROUP_KEY, create_cache_key_serde};
327 use crate::row_nonnull;
328
329 fn cache_key_serde() -> CacheKeySerde {
330 let data_types = vec![DataType::Varchar, DataType::Int64];
331 let schema = Schema::new(data_types.into_iter().map(Field::unnamed).collect());
332 let storage_key = vec![
333 ColumnOrder::new(0, OrderType::ascending()),
334 ColumnOrder::new(1, OrderType::ascending()),
335 ];
336 let order_by = vec![ColumnOrder::new(0, OrderType::ascending())];
337
338 create_cache_key_serde(&storage_key, &schema, &order_by, &[])
339 }
340
    /// End-to-end check of `insert`/`delete` plus `find_range` ordering on a
    /// single (group-less) Top-N state.
    #[tokio::test]
    async fn test_managed_top_n_state() {
        // In-memory state table with schema (Varchar, Int64), pk = both
        // columns ascending.
        let state_table = {
            let mut tb = create_in_memory_state_table(
                &[DataType::Varchar, DataType::Int64],
                &[OrderType::ascending(), OrderType::ascending()],
                &[0, 1],
            )
            .await;
            tb.init_epoch(EpochPair::new_test_epoch(test_epoch(1)))
                .await
                .unwrap();
            tb
        };

        let cache_key_serde = cache_key_serde();
        let mut managed_state = ManagedTopNState::new(state_table, cache_key_serde.clone());

        let row1 = row_nonnull!["abc", 2i64];
        let row2 = row_nonnull!["abc", 3i64];
        let row3 = row_nonnull!["abd", 3i64];
        let row4 = row_nonnull!["ab", 4i64];

        let row1_bytes = serialize_pk_to_cache_key(row1.clone(), &cache_key_serde);
        let row2_bytes = serialize_pk_to_cache_key(row2.clone(), &cache_key_serde);
        let row3_bytes = serialize_pk_to_cache_key(row3.clone(), &cache_key_serde);
        let row4_bytes = serialize_pk_to_cache_key(row4.clone(), &cache_key_serde);
        let rows = [row1, row2, row3, row4];
        let ordered_rows = [row1_bytes, row2_bytes, row3_bytes, row4_bytes];
        managed_state.insert(rows[3].clone());

        // Only ("ab", 4) is stored, so offset 0 / limit 1 returns it.
        let valid_rows = managed_state
            .find_range(NO_GROUP_KEY, 0, Some(1))
            .await
            .unwrap();

        assert_eq!(valid_rows.len(), 1);
        assert_eq!(valid_rows[0].cache_key, ordered_rows[3].clone());

        // Add ("abd", 3); with offset 1 it is the second row.
        managed_state.insert(rows[2].clone());
        let valid_rows = managed_state
            .find_range(NO_GROUP_KEY, 1, Some(1))
            .await
            .unwrap();
        assert_eq!(valid_rows.len(), 1);
        assert_eq!(valid_rows[0].cache_key, ordered_rows[2].clone());

        // Add ("abc", 3); order is now "ab" < "abc" < "abd".
        managed_state.insert(rows[1].clone());

        let valid_rows = managed_state
            .find_range(NO_GROUP_KEY, 1, Some(2))
            .await
            .unwrap();
        assert_eq!(valid_rows.len(), 2);
        assert_eq!(
            valid_rows.first().unwrap().cache_key,
            ordered_rows[1].clone()
        );
        assert_eq!(
            valid_rows.last().unwrap().cache_key,
            ordered_rows[2].clone()
        );

        // Swap ("abc", 3) for ("abc", 2) via delete + insert.
        managed_state.delete(rows[1].clone());

        managed_state.insert(rows[0].clone());

        // Final order: ("ab", 4) < ("abc", 2) < ("abd", 3).
        let valid_rows = managed_state
            .find_range(NO_GROUP_KEY, 0, Some(3))
            .await
            .unwrap();

        assert_eq!(valid_rows.len(), 3);
        assert_eq!(valid_rows[0].cache_key, ordered_rows[3].clone());
        assert_eq!(valid_rows[1].cache_key, ordered_rows[0].clone());
        assert_eq!(valid_rows[2].cache_key, ordered_rows[2].clone());
    }
421
    /// `fill_high_cache` must start strictly after `start_key` and stop once
    /// the requested cache size is reached.
    #[tokio::test]
    async fn test_managed_top_n_state_fill_cache() {
        let data_types = vec![DataType::Varchar, DataType::Int64];
        let state_table = {
            let mut tb = create_in_memory_state_table(
                &data_types,
                &[OrderType::ascending(), OrderType::ascending()],
                &[0, 1],
            )
            .await;
            tb.init_epoch(EpochPair::new_test_epoch(test_epoch(1)))
                .await
                .unwrap();
            tb
        };

        let cache_key_serde = cache_key_serde();
        let mut managed_state = ManagedTopNState::new(state_table, cache_key_serde.clone());

        let row1 = row_nonnull!["abc", 2i64];
        let row2 = row_nonnull!["abc", 3i64];
        let row3 = row_nonnull!["abd", 3i64];
        let row4 = row_nonnull!["ab", 4i64];
        let row5 = row_nonnull!["abcd", 5i64];

        let row1_bytes = serialize_pk_to_cache_key(row1.clone(), &cache_key_serde);
        let row2_bytes = serialize_pk_to_cache_key(row2.clone(), &cache_key_serde);
        let row3_bytes = serialize_pk_to_cache_key(row3.clone(), &cache_key_serde);
        let row4_bytes = serialize_pk_to_cache_key(row4.clone(), &cache_key_serde);
        let row5_bytes = serialize_pk_to_cache_key(row5.clone(), &cache_key_serde);
        let rows = [row1, row2, row3, row4, row5];
        let ordered_rows = vec![row1_bytes, row2_bytes, row3_bytes, row4_bytes, row5_bytes];

        let mut cache = TopNCache::<false>::new(1, 1, data_types);

        // Stored rows sorted by key: "ab" < "abc" < "abcd" < "abd".
        // Note rows[0] ("abc", 2) is NOT inserted.
        managed_state.insert(rows[3].clone());
        managed_state.insert(rows[1].clone());
        managed_state.insert(rows[2].clone());
        managed_state.insert(rows[4].clone());

        // Start after "ab" (rows[3]) with a size limit of 2: the high cache
        // should hold "abc" (rows[1]) and "abcd" (rows[4]), leaving "abd" out.
        managed_state
            .fill_high_cache(NO_GROUP_KEY, &mut cache, Some(ordered_rows[3].clone()), 2)
            .await
            .unwrap();
        assert_eq!(cache.high.len(), 2);
        assert_eq!(cache.high.first_key_value().unwrap().0, &ordered_rows[1]);
        assert_eq!(cache.high.last_key_value().unwrap().0, &ordered_rows[4]);
    }
470
    /// With `offset == 0`, `limit == 1` and `WITH_TIES`, inserting a single
    /// row and then deleting it again must complete without error (the delete
    /// path is given the managed state so it can consult the table).
    #[tokio::test]
    async fn test_top_n_cache_limit_1() {
        let data_types = vec![DataType::Varchar, DataType::Int64];
        let state_table = {
            let mut tb = create_in_memory_state_table(
                &data_types,
                &[OrderType::ascending(), OrderType::ascending()],
                &[0, 1],
            )
            .await;
            tb.init_epoch(EpochPair::new_test_epoch(test_epoch(1)))
                .await
                .unwrap();
            tb
        };

        let cache_key_serde = cache_key_serde();
        let mut managed_state = ManagedTopNState::new(state_table, cache_key_serde.clone());

        let row1 = row_nonnull!["abc", 2i64];
        let row1_bytes = serialize_pk_to_cache_key(row1.clone(), &cache_key_serde);

        // offset = 0, limit = 1, with ties.
        let mut cache = TopNCache::<true>::new(0, 1, data_types);
        cache.insert(row1_bytes.clone(), row1.clone(), &mut TopNStaging::new());
        cache
            .delete(
                NO_GROUP_KEY,
                &mut managed_state,
                row1_bytes,
                row1,
                &mut TopNStaging::new(),
            )
            .await
            .unwrap();
    }
506}