risingwave_stream/executor/over_window/over_partition.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Types and functions that store or manipulate state/cache inside one single over window
//! partition.

use std::collections::BTreeMap;
use std::marker::PhantomData;
use std::ops::{Bound, RangeInclusive};

use delta_btree_map::{Change, DeltaBTreeMap};
use educe::Educe;
use futures_async_stream::for_await;
use risingwave_common::array::stream_record::Record;
use risingwave_common::row::{OwnedRow, Row, RowExt};
use risingwave_common::session_config::OverWindowCachePolicy as CachePolicy;
use risingwave_common::types::{Datum, Sentinelled};
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_expr::window_function::{StateKey, WindowStates, create_window_state};
use risingwave_storage::StateStore;
use risingwave_storage::store::PrefetchOptions;

use super::general::{Calls, RowConverter};
use super::range_cache::{CacheKey, PartitionCache};
use crate::common::table::state_table::StateTable;
use crate::consistency::{consistency_error, enable_strict_consistency};
use crate::executor::StreamExecutorResult;
use crate::executor::over_window::frame_finder::*;

/// Changes happened in one over window partition.
pub(super) type PartitionDelta = BTreeMap<CacheKey, Change<OwnedRow>>;

#[derive(Default, Debug)]
pub(super) struct OverPartitionStats {
    // stats for range cache operations
    pub lookup_count: u64,
    pub left_miss_count: u64,
    pub right_miss_count: u64,

    // stats for window function state computation
    pub accessed_entry_count: u64,
    pub compute_count: u64,
    pub same_output_count: u64,
}

/// [`AffectedRange`] represents a range of keys that are affected by a delta.
/// The [`CacheKey`] fields are keys in the partition range cache + delta, which is
/// represented by [`DeltaBTreeMap`].
///
/// - `first_curr_key` and `last_curr_key` are the current keys of the first and the last
///   windows affected. They are used to pinpoint the bounds where state needs to be updated.
/// - `first_frame_start` and `last_frame_end` are the frame start and end of the first and
///   the last windows affected. They are used to pinpoint the bounds where state needs to be
///   included for computing the new state.
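///
/// Illustrative layout of the four keys (they always satisfy
/// `first_frame_start <= first_curr_key <= last_curr_key <= last_frame_end`):
///
/// ```text
/// ... first_frame_start ... first_curr_key ...... last_curr_key ... last_frame_end ...
///     |<-- rows fed into the window states to recompute outputs ------------------->|
///                           |<-- rows whose output may change -->|
/// ```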
#[derive(Debug, Educe)]
#[educe(Clone, Copy)]
pub(super) struct AffectedRange<'a> {
    pub first_frame_start: &'a CacheKey,
    pub first_curr_key: &'a CacheKey,
    pub last_curr_key: &'a CacheKey,
    pub last_frame_end: &'a CacheKey,
}

impl<'a> AffectedRange<'a> {
    fn new(
        first_frame_start: &'a CacheKey,
        first_curr_key: &'a CacheKey,
        last_curr_key: &'a CacheKey,
        last_frame_end: &'a CacheKey,
    ) -> Self {
        Self {
            first_frame_start,
            first_curr_key,
            last_curr_key,
            last_frame_end,
        }
    }
}

/// A wrapper of [`PartitionCache`] that provides helper methods to manipulate the cache.
/// By putting this type inside the `private` module, we can avoid misuse of the internal fields and
/// methods.
pub(super) struct OverPartition<'a, S: StateStore> {
    deduped_part_key: &'a OwnedRow,
    range_cache: &'a mut PartitionCache,
    cache_policy: CachePolicy,

    calls: &'a Calls,
    row_conv: RowConverter<'a>,

    stats: OverPartitionStats,

    _phantom: PhantomData<S>,
}

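/// Maximum number of rows to load from the state table in one batch when extending the range
/// cache entry-by-entry (see `extend_cache_leftward_by_n_inner` and
/// `extend_cache_rightward_by_n_inner`). The exact value is a heuristic.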
const MAGIC_BATCH_SIZE: usize = 512;

impl<'a, S: StateStore> OverPartition<'a, S> {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        deduped_part_key: &'a OwnedRow,
        cache: &'a mut PartitionCache,
        cache_policy: CachePolicy,
        calls: &'a Calls,
        row_conv: RowConverter<'a>,
    ) -> Self {
        Self {
            deduped_part_key,
            range_cache: cache,
            cache_policy,

            calls,
            row_conv,

            stats: Default::default(),

            _phantom: PhantomData,
        }
    }

    /// Get a summary of the execution that happened in this [`OverPartition`] in the current
    /// round. This will consume the [`OverPartition`] value itself.
    pub fn summarize(self) -> OverPartitionStats {
        // We may extend this function in the future.
        self.stats
    }

    /// Get the number of cached entries ignoring sentinels.
    pub fn cache_real_len(&self) -> usize {
        self.range_cache.normal_len()
    }

    /// Build changes for the partition, with the given `delta`. Necessary maintenance of the range
    /// cache will be done during this process, like loading rows from the `table` into the cache.
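    ///
    /// Returns the changes to emit for this partition (keyed by [`StateKey`]) together with the
    /// inclusive range of state keys accessed during the computation, if any.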
    pub async fn build_changes(
        &mut self,
        table: &StateTable<S>,
        mut delta: PartitionDelta,
    ) -> StreamExecutorResult<(
        BTreeMap<StateKey, Record<OwnedRow>>,
        Option<RangeInclusive<StateKey>>,
    )> {
        let calls = self.calls;
        let input_schema_len = table.get_data_types().len() - calls.len();
        let numbering_only = calls.numbering_only;
        let has_rank = calls.has_rank;

        // return values
        let mut part_changes = BTreeMap::new();
        let mut accessed_range: Option<RangeInclusive<StateKey>> = None;

        // stats
        let mut accessed_entry_count = 0;
        let mut compute_count = 0;
        let mut same_output_count = 0;

        // Find affected ranges; this also ensures that all rows in the affected ranges are
        // loaded into the cache.
        let (part_with_delta, affected_ranges) =
            self.find_affected_ranges(table, &mut delta).await?;

        let snapshot = part_with_delta.snapshot();
        let delta = part_with_delta.delta();
        let last_delta_key = delta.last_key_value().map(|(k, _)| k.as_normal_expect());

        // Generate delete changes first, because deletes are skipped during iteration over
        // `part_with_delta` in the next step.
        for (key, change) in delta {
            if change.is_delete() {
                part_changes.insert(
                    key.as_normal_expect().clone(),
                    Record::Delete {
                        old_row: snapshot.get(key).unwrap().clone(),
                    },
                );
            }
        }

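        // For each affected range: feed the rows in `[first_frame_start, last_frame_end]` into
        // fresh window states, slide to `first_curr_key`, then slide through
        // `[first_curr_key, last_curr_key]`, recording an `Insert` for new rows and an `Update`
        // for rows whose output changed.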
        for AffectedRange {
            first_frame_start,
            first_curr_key,
            last_curr_key,
            last_frame_end,
        } in affected_ranges
        {
            assert!(first_frame_start <= first_curr_key);
            assert!(first_curr_key <= last_curr_key);
            assert!(last_curr_key <= last_frame_end);
            assert!(first_frame_start.is_normal());
            assert!(first_curr_key.is_normal());
            assert!(last_curr_key.is_normal());
            assert!(last_frame_end.is_normal());

            let last_delta_key = last_delta_key.unwrap();

            if let Some(accessed_range) = accessed_range.as_mut() {
                let min_start = first_frame_start
                    .as_normal_expect()
                    .min(accessed_range.start())
                    .clone();
                let max_end = last_frame_end
                    .as_normal_expect()
                    .max(accessed_range.end())
                    .clone();
                *accessed_range = min_start..=max_end;
            } else {
                accessed_range = Some(
                    first_frame_start.as_normal_expect().clone()
                        ..=last_frame_end.as_normal_expect().clone(),
                );
            }

            let mut states =
                WindowStates::new(calls.iter().map(create_window_state).try_collect()?);

            // Populate window states with the affected range of rows.
            {
                let mut cursor = part_with_delta
                    .before(first_frame_start)
                    .expect("first frame start key must exist");

                while let Some((key, row)) = cursor.next() {
                    accessed_entry_count += 1;

                    for (call, state) in calls.iter().zip_eq_fast(states.iter_mut()) {
                        // TODO(rc): batch appending
                        // TODO(rc): append not only the arguments but also the old output for optimization
                        state.append(
                            key.as_normal_expect().clone(),
                            row.project(call.args.val_indices())
                                .into_owned_row()
                                .as_inner()
                                .into(),
                        );
                    }

                    if key == last_frame_end {
                        break;
                    }
                }
            }

            // Slide to the first affected key. We can safely pass in `first_curr_key` here
            // because it definitely exists in the states by the definition of affected range.
            states.just_slide_to(first_curr_key.as_normal_expect())?;
            let mut curr_key_cursor = part_with_delta.before(first_curr_key).unwrap();
            assert_eq!(
                states.curr_key(),
                curr_key_cursor
                    .peek_next()
                    .map(|(k, _)| k)
                    .map(CacheKey::as_normal_expect)
            );

            // Slide and generate changes.
            while let Some((key, row)) = curr_key_cursor.next() {
                let mut should_stop = false;

                let output = states.slide_no_evict_hint()?;
                compute_count += 1;

                let old_output = &row.as_inner()[input_schema_len..];
                if !old_output.is_empty() && old_output == output {
                    same_output_count += 1;

                    if numbering_only {
                        if has_rank {
                            // It's possible that an `Insert` doesn't affect its ties but affects
                            // all the following rows, so we need to check the `order_key`.
                            if key.as_normal_expect().order_key > last_delta_key.order_key {
                                // there won't be any more changes after this point, so we can stop early
                                should_stop = true;
                            }
                        } else if key.as_normal_expect() >= last_delta_key {
                            // there won't be any more changes after this point, so we can stop early
                            should_stop = true;
                        }
                    }
                }

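                // Assemble the new row: the original input columns followed by the freshly
                // computed window function outputs.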
                let new_row = OwnedRow::new(
                    row.as_inner()
                        .iter()
                        .take(input_schema_len)
                        .cloned()
                        .chain(output)
                        .collect(),
                );

                if let Some(old_row) = snapshot.get(key).cloned() {
                    // update
                    if old_row != new_row {
                        part_changes.insert(
                            key.as_normal_expect().clone(),
                            Record::Update { old_row, new_row },
                        );
                    }
                } else {
                    // insert
                    part_changes.insert(key.as_normal_expect().clone(), Record::Insert { new_row });
                }

                if should_stop || key == last_curr_key {
                    break;
                }
            }
        }

        self.stats.accessed_entry_count += accessed_entry_count;
        self.stats.compute_count += compute_count;
        self.stats.same_output_count += same_output_count;

        Ok((part_changes, accessed_range))
    }

    /// Write a change record to state table and cache.
    /// This function must be called after finding affected ranges, which means the change records
    /// should never exceed the cached range.
    pub fn write_record(
        &mut self,
        table: &mut StateTable<S>,
        key: StateKey,
        record: Record<OwnedRow>,
    ) {
        table.write_record(record.as_ref());
        match record {
            Record::Insert { new_row } | Record::Update { new_row, .. } => {
                self.range_cache.insert(CacheKey::from(key), new_row);
            }
            Record::Delete { .. } => {
                self.range_cache.remove(&CacheKey::from(key));

                if self.range_cache.normal_len() == 0 && self.range_cache.len() == 1 {
                    // only one sentinel remains, should insert the other
                    self.range_cache
                        .insert(CacheKey::Smallest, OwnedRow::empty());
                    self.range_cache
                        .insert(CacheKey::Largest, OwnedRow::empty());
                }
            }
        }
    }

    /// Find all ranges in the partition that are affected by the given delta.
    /// The returned ranges are guaranteed to be sorted and non-overlapping. All keys in the ranges
    /// are guaranteed to be cached, which means they should be [`Sentinelled::Normal`]s.
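    ///
    /// Returns the merged snapshot + delta view ([`DeltaBTreeMap`]) and the affected ranges.
    /// If the delta is empty, no range is returned. Rows may be loaded from the `table` into
    /// the range cache during the search.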
    async fn find_affected_ranges<'s, 'delta>(
        &'s mut self,
        table: &StateTable<S>,
        delta: &'delta mut PartitionDelta,
    ) -> StreamExecutorResult<(
        DeltaBTreeMap<'delta, CacheKey, OwnedRow>,
        Vec<AffectedRange<'delta>>,
    )>
    where
        'a: 'delta,
        's: 'delta,
    {
        if delta.is_empty() {
            return Ok((DeltaBTreeMap::new(self.range_cache.inner(), delta), vec![]));
        }

        self.ensure_delta_in_cache(table, delta).await?;
        let delta = &*delta; // let's make it immutable

        let delta_first = delta.first_key_value().unwrap().0.as_normal_expect();
        let delta_last = delta.last_key_value().unwrap().0.as_normal_expect();

        let range_frame_logical_curr =
            calc_logical_curr_for_range_frames(&self.calls.range_frames, delta_first, delta_last);

        loop {
            // TERMINABILITY: `extend_cache_leftward_by_n` and `extend_cache_rightward_by_n` keep
            // pushing the cache to the boundary of the current partition. In these two methods,
            // when either side of the boundary is reached, the sentinel key will be removed, so
            // finally `Self::find_affected_ranges_readonly` will return `Ok`.

            // SAFETY: Here we shortly borrow the range cache and turn the reference into a
            // `'delta` one to bypass the borrow checker. This is safe because we only return
            // the reference once we don't need to do any further mutation.
            let cache_inner = unsafe { &*(self.range_cache.inner() as *const _) };
            let part_with_delta = DeltaBTreeMap::new(cache_inner, delta);

            self.stats.lookup_count += 1;
            let res = self
                .find_affected_ranges_readonly(part_with_delta, range_frame_logical_curr.as_ref());

            let (need_extend_leftward, need_extend_rightward) = match res {
                Ok(ranges) => return Ok((part_with_delta, ranges)),
                Err(cache_extend_hint) => cache_extend_hint,
            };

            if need_extend_leftward {
                self.stats.left_miss_count += 1;
                tracing::trace!(partition=?self.deduped_part_key, "partition cache left extension triggered");
                let left_most = self
                    .range_cache
                    .first_normal_key()
                    .unwrap_or(delta_first)
                    .clone();
                self.extend_cache_leftward_by_n(table, &left_most).await?;
            }
            if need_extend_rightward {
                self.stats.right_miss_count += 1;
                tracing::trace!(partition=?self.deduped_part_key, "partition cache right extension triggered");
                let right_most = self
                    .range_cache
                    .last_normal_key()
                    .unwrap_or(delta_last)
                    .clone();
                self.extend_cache_rightward_by_n(table, &right_most).await?;
            }
            tracing::trace!(partition=?self.deduped_part_key, "partition cache extended");
        }
    }

    async fn ensure_delta_in_cache(
        &mut self,
        table: &StateTable<S>,
        delta: &mut PartitionDelta,
    ) -> StreamExecutorResult<()> {
        if delta.is_empty() {
            return Ok(());
        }

        let delta_first = delta.first_key_value().unwrap().0.as_normal_expect();
        let delta_last = delta.last_key_value().unwrap().0.as_normal_expect();

        if self.cache_policy.is_full() {
            // ensure everything is in the cache
            self.extend_cache_to_boundary(table).await?;
        } else {
            // TODO(rc): later we should extend cache using `self.calls.super_rows_frame_bounds` and
            // `range_frame_logical_curr` as hints.

            // ensure the cache covers all delta (if possible)
            self.extend_cache_by_range(table, delta_first..=delta_last)
                .await?;
        }

        if !enable_strict_consistency() {
            // in non-strict mode, we should ensure the delta is consistent with the cache
            let cache = self.range_cache.inner();
            delta.retain(|key, change| match &*change {
                Change::Insert(_) => {
                    // this also includes the case of double-insert and ghost-update,
                    // but since we already lost the information, let's just ignore it
                    true
                }
                Change::Delete => {
                    // if the key is not in the cache, it's a ghost-delete
                    let consistent = cache.contains_key(key);
                    if !consistent {
                        consistency_error!(?key, "removing a row with non-existing key");
                    }
                    consistent
                }
            });
        }

        Ok(())
    }

    /// Try to find affected ranges on immutable range cache + delta. If the algorithm reaches
    /// any sentinel node in the cache, which means some entries in the affected range may be
    /// in the state table, it returns an `Err((bool, bool))` to notify the caller that the
    /// left side, the right side, or both sides of the cache should be extended.
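    ///
    /// Rough sketch of the procedure (see the body for details):
    /// 1. Derive `first_curr_key` / `last_curr_key` from the delta boundaries, using both the
    ///    rows frames and the range frames of the window function calls.
    /// 2. Derive `first_frame_start` / `last_frame_end` from the curr keys in the same way.
    /// 3. If any of these keys turns out to be a sentinel, report which side(s) of the cache
    ///    need to be extended instead of returning a range.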
    ///
    /// TODO(rc): Currently at most one range will be in the result vector. Ideally we should
    /// recognize discontinuous changes in the delta and find multiple ranges, but that would be
    /// too complex for now.
    fn find_affected_ranges_readonly<'delta>(
        &self,
        part_with_delta: DeltaBTreeMap<'delta, CacheKey, OwnedRow>,
        range_frame_logical_curr: Option<&(Sentinelled<Datum>, Sentinelled<Datum>)>,
    ) -> std::result::Result<Vec<AffectedRange<'delta>>, (bool, bool)> {
        if part_with_delta.first_key().is_none() {
            // nothing is left after applying the delta, meaning all entries are deleted
            return Ok(vec![]);
        }

        let delta_first_key = part_with_delta.delta().first_key_value().unwrap().0;
        let delta_last_key = part_with_delta.delta().last_key_value().unwrap().0;
        let cache_key_pk_len = delta_first_key.as_normal_expect().pk.len();

        if part_with_delta.snapshot().is_empty() {
            // all existing keys are inserted in the delta
            return Ok(vec![AffectedRange::new(
                delta_first_key,
                delta_first_key,
                delta_last_key,
                delta_last_key,
            )]);
        }

        let first_key = part_with_delta.first_key().unwrap();
        let last_key = part_with_delta.last_key().unwrap();

        let first_curr_key = if self.calls.end_is_unbounded || delta_first_key == first_key {
            // If the frame end is unbounded, or the first key is in the delta, then the frame
            // corresponding to the first key is always affected.
            first_key
        } else {
            let mut key = find_first_curr_for_rows_frame(
                &self.calls.super_rows_frame_bounds,
                part_with_delta,
                delta_first_key,
            );

            if let Some((logical_first_curr, _)) = range_frame_logical_curr {
                let logical_curr = logical_first_curr.as_normal_expect(); // otherwise should go `end_is_unbounded` branch
                let new_key = find_left_for_range_frames(
                    &self.calls.range_frames,
                    part_with_delta,
                    logical_curr,
                    cache_key_pk_len,
                );
                key = std::cmp::min(key, new_key);
            }

            key
        };

        let last_curr_key = if self.calls.start_is_unbounded || delta_last_key == last_key {
            // similar to `first_curr_key`
            last_key
        } else {
            let mut key = find_last_curr_for_rows_frame(
                &self.calls.super_rows_frame_bounds,
                part_with_delta,
                delta_last_key,
            );

            if let Some((_, logical_last_curr)) = range_frame_logical_curr {
                let logical_curr = logical_last_curr.as_normal_expect(); // otherwise should go `start_is_unbounded` branch
                let new_key = find_right_for_range_frames(
                    &self.calls.range_frames,
                    part_with_delta,
                    logical_curr,
                    cache_key_pk_len,
                );
                key = std::cmp::max(key, new_key);
            }

            key
        };

        {
            // We quickly return if there's any sentinel in `[first_curr_key, last_curr_key]`,
            // just for the sake of simplicity.
            let mut need_extend_leftward = false;
            let mut need_extend_rightward = false;
            for key in [first_curr_key, last_curr_key] {
                if key.is_smallest() {
                    need_extend_leftward = true;
                } else if key.is_largest() {
                    need_extend_rightward = true;
                }
            }
            if need_extend_leftward || need_extend_rightward {
                return Err((need_extend_leftward, need_extend_rightward));
            }
        }

        // From now on we definitely have two normal `curr_key`s.

        if first_curr_key > last_curr_key {
            // Note that we cannot move this check before the block above. For example, if the
            // range cache contains `[Smallest, 5, Largest]`, the delta contains only `Delete 5`,
            // and the frame is `RANGE BETWEEN CURRENT ROW AND CURRENT ROW`, then `first_curr_key`
            // will be `Largest` and `last_curr_key` will be `Smallest`; in this case there may be
            // some other entries with order value `5` in the table, which should be considered
            // *affected*.
            return Ok(vec![]);
        }

        let range_frame_logical_boundary = calc_logical_boundary_for_range_frames(
            &self.calls.range_frames,
            first_curr_key.as_normal_expect(),
            last_curr_key.as_normal_expect(),
        );

        let first_frame_start = if self.calls.start_is_unbounded || first_curr_key == first_key {
            // If the frame start is unbounded, or the first curr key is the first key, then the
            // first key always needs to be included in the affected range.
            first_key
        } else {
            let mut key = find_frame_start_for_rows_frame(
                &self.calls.super_rows_frame_bounds,
                part_with_delta,
                first_curr_key,
            );

            if let Some((logical_first_start, _)) = range_frame_logical_boundary.as_ref() {
                let logical_boundary = logical_first_start.as_normal_expect(); // otherwise should go `end_is_unbounded` branch
                let new_key = find_left_for_range_frames(
                    &self.calls.range_frames,
                    part_with_delta,
                    logical_boundary,
                    cache_key_pk_len,
                );
                key = std::cmp::min(key, new_key);
            }

            key
        };
        assert!(first_frame_start <= first_curr_key);

        let last_frame_end = if self.calls.end_is_unbounded || last_curr_key == last_key {
            // similar to `first_frame_start`
            last_key
        } else {
            let mut key = find_frame_end_for_rows_frame(
                &self.calls.super_rows_frame_bounds,
                part_with_delta,
                last_curr_key,
            );

            if let Some((_, logical_last_end)) = range_frame_logical_boundary.as_ref() {
                let logical_boundary = logical_last_end.as_normal_expect(); // otherwise should go `end_is_unbounded` branch
                let new_key = find_right_for_range_frames(
                    &self.calls.range_frames,
                    part_with_delta,
                    logical_boundary,
                    cache_key_pk_len,
                );
                key = std::cmp::max(key, new_key);
            }

            key
        };
        assert!(last_frame_end >= last_curr_key);

        let mut need_extend_leftward = false;
        let mut need_extend_rightward = false;
        for key in [
            first_curr_key,
            last_curr_key,
            first_frame_start,
            last_frame_end,
        ] {
            if key.is_smallest() {
                need_extend_leftward = true;
            } else if key.is_largest() {
                need_extend_rightward = true;
            }
        }

        if need_extend_leftward || need_extend_rightward {
            Err((need_extend_leftward, need_extend_rightward))
        } else {
            Ok(vec![AffectedRange::new(
                first_frame_start,
                first_curr_key,
                last_curr_key,
                last_frame_end,
            )])
        }
    }

    async fn extend_cache_to_boundary(
        &mut self,
        table: &StateTable<S>,
    ) -> StreamExecutorResult<()> {
        if self.range_cache.normal_len() == self.range_cache.len() {
            // no sentinel in the cache, meaning we already cached all entries of this partition
            return Ok(());
        }

        tracing::trace!(partition=?self.deduped_part_key, "loading the whole partition into cache");

        let mut new_cache = PartitionCache::new_without_sentinels(); // shouldn't use `new` here because we are extending to the boundary
        let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
        let table_iter = table
            .iter_with_prefix(self.deduped_part_key, sub_range, PrefetchOptions::default())
            .await?;

        #[for_await]
        for row in table_iter {
            let row: OwnedRow = row?.into_owned_row();
            new_cache.insert(self.row_conv.row_to_state_key(&row)?.into(), row);
        }
        *self.range_cache = new_cache;

        Ok(())
    }

    /// Try to load the given range of entries from table into cache.
    /// When the function returns, it's guaranteed that there's no entry in the table that is within
    /// the given range but not in the cache.
    async fn extend_cache_by_range(
        &mut self,
        table: &StateTable<S>,
        range: RangeInclusive<&StateKey>,
    ) -> StreamExecutorResult<()> {
        if self.range_cache.normal_len() == self.range_cache.len() {
            // no sentinel in the cache, meaning we already cached all entries of this partition
            return Ok(());
        }
        assert!(self.range_cache.len() >= 2);

        let cache_first_normal_key = self.range_cache.first_normal_key();
        let cache_last_normal_key = self.range_cache.last_normal_key();

        if cache_first_normal_key.is_some() && *range.end() < cache_first_normal_key.unwrap()
            || cache_last_normal_key.is_some() && *range.start() > cache_last_normal_key.unwrap()
        {
            // completely non-overlapping; for the sake of simplicity, we re-init the cache
            tracing::debug!(
                partition=?self.deduped_part_key,
                cache_first=?cache_first_normal_key,
                cache_last=?cache_last_normal_key,
                range=?range,
                "modified range is completely non-overlapping with the cached range, re-initializing the cache"
            );
            *self.range_cache = PartitionCache::new();
        }

        if self.cache_real_len() == 0 {
            // no normal entry in the cache, just load the given range
            let table_sub_range = (
                Bound::Included(self.row_conv.state_key_to_table_sub_pk(range.start())?),
                Bound::Included(self.row_conv.state_key_to_table_sub_pk(range.end())?),
            );
            tracing::debug!(
                partition=?self.deduped_part_key,
                table_sub_range=?table_sub_range,
                "cache is empty, just loading the given range"
            );
            return self
                .extend_cache_by_range_inner(table, table_sub_range)
                .await;
        }

        let cache_real_first_key = self
            .range_cache
            .first_normal_key()
            .expect("cache real len is not 0");
        if self.range_cache.left_is_sentinel() && *range.start() < cache_real_first_key {
            // extend leftward only if there's a smallest sentinel
            let table_sub_range = (
                Bound::Included(self.row_conv.state_key_to_table_sub_pk(range.start())?),
                Bound::Excluded(
                    self.row_conv
                        .state_key_to_table_sub_pk(cache_real_first_key)?,
                ),
            );
            tracing::trace!(
                partition=?self.deduped_part_key,
                table_sub_range=?table_sub_range,
                "loading the left half of given range"
            );
            self.extend_cache_by_range_inner(table, table_sub_range)
                .await?;
        }

        let cache_real_last_key = self
            .range_cache
            .last_normal_key()
            .expect("cache real len is not 0");
        if self.range_cache.right_is_sentinel() && *range.end() > cache_real_last_key {
            // extend rightward only if there's a largest sentinel
            let table_sub_range = (
                Bound::Excluded(
                    self.row_conv
                        .state_key_to_table_sub_pk(cache_real_last_key)?,
                ),
                Bound::Included(self.row_conv.state_key_to_table_sub_pk(range.end())?),
            );
            tracing::trace!(
                partition=?self.deduped_part_key,
                table_sub_range=?table_sub_range,
                "loading the right half of given range"
            );
            self.extend_cache_by_range_inner(table, table_sub_range)
                .await?;
        }

        // prefetch rows before the start of the range
        self.extend_cache_leftward_by_n(table, range.start())
            .await?;

        // prefetch rows after the end of the range
        self.extend_cache_rightward_by_n(table, range.end()).await
    }

    async fn extend_cache_leftward_by_n(
        &mut self,
        table: &StateTable<S>,
        hint_key: &StateKey,
    ) -> StreamExecutorResult<()> {
        if self.range_cache.normal_len() == self.range_cache.len() {
            // no sentinel in the cache, meaning we already cached all entries of this partition
            return Ok(());
        }
        assert!(self.range_cache.len() >= 2);

        let left_second = {
            let mut iter = self.range_cache.inner().iter();
            let left_first = iter.next().unwrap().0;
            if left_first.is_normal() {
                // the left side already reaches the beginning of this partition in the table
                return Ok(());
            }
            iter.next().unwrap().0
        };
        let range_to_exclusive = match left_second {
            CacheKey::Normal(smallest_in_cache) => smallest_in_cache,
            CacheKey::Largest => hint_key, // no normal entry in the cache
            _ => unreachable!(),
        }
        .clone();

        self.extend_cache_leftward_by_n_inner(table, &range_to_exclusive)
            .await?;

        if self.cache_real_len() == 0 {
            // Cache was empty, and extending leftward didn't add anything to the cache, but we
            // can't just remove the smallest sentinel, we must also try extending rightward.
            self.extend_cache_rightward_by_n_inner(table, hint_key)
                .await?;
            if self.cache_real_len() == 0 {
                // still empty, meaning the table is empty
                self.range_cache.remove(&CacheKey::Smallest);
                self.range_cache.remove(&CacheKey::Largest);
            }
        }

        Ok(())
    }

    async fn extend_cache_rightward_by_n(
        &mut self,
        table: &StateTable<S>,
        hint_key: &StateKey,
    ) -> StreamExecutorResult<()> {
        if self.range_cache.normal_len() == self.range_cache.len() {
            // no sentinel in the cache, meaning we already cached all entries of this partition
            return Ok(());
        }
        assert!(self.range_cache.len() >= 2);

        let right_second = {
            let mut iter = self.range_cache.inner().iter();
            let right_first = iter.next_back().unwrap().0;
            if right_first.is_normal() {
                // the right side already reaches the end of this partition in the table
                return Ok(());
            }
            iter.next_back().unwrap().0
        };
        let range_from_exclusive = match right_second {
            CacheKey::Normal(largest_in_cache) => largest_in_cache,
            CacheKey::Smallest => hint_key, // no normal entry in the cache
            _ => unreachable!(),
        }
        .clone();

        self.extend_cache_rightward_by_n_inner(table, &range_from_exclusive)
            .await?;

        if self.cache_real_len() == 0 {
            // Cache was empty, and extending rightward didn't add anything to the cache, but we
            // can't just remove the largest sentinel, we must also try extending leftward.
            self.extend_cache_leftward_by_n_inner(table, hint_key)
                .await?;
            if self.cache_real_len() == 0 {
                // still empty, meaning the table is empty
                self.range_cache.remove(&CacheKey::Smallest);
                self.range_cache.remove(&CacheKey::Largest);
            }
        }

        Ok(())
    }
886
887    async fn extend_cache_by_range_inner(
888        &mut self,
889        table: &StateTable<S>,
890        table_sub_range: (Bound<impl Row>, Bound<impl Row>),
891    ) -> StreamExecutorResult<()> {
892        let stream = table
893            .iter_with_prefix(
894                self.deduped_part_key,
895                &table_sub_range,
896                PrefetchOptions::default(),
897            )
898            .await?;
899
900        #[for_await]
901        for row in stream {
902            let row: OwnedRow = row?.into_owned_row();
903            let key = self.row_conv.row_to_state_key(&row)?;
904            self.range_cache.insert(CacheKey::from(key), row);
905        }
906
907        Ok(())
908    }
909
910    async fn extend_cache_leftward_by_n_inner(
911        &mut self,
912        table: &StateTable<S>,
913        range_to_exclusive: &StateKey,
914    ) -> StreamExecutorResult<()> {
915        let mut n_extended = 0usize;
916        {
917            let sub_range = (
918                Bound::<OwnedRow>::Unbounded,
919                Bound::Excluded(
920                    self.row_conv
921                        .state_key_to_table_sub_pk(range_to_exclusive)?,
922                ),
923            );
924            let rev_stream = table
925                .rev_iter_with_prefix(
926                    self.deduped_part_key,
927                    &sub_range,
928                    PrefetchOptions::default(),
929                )
930                .await?;
931
932            #[for_await]
933            for row in rev_stream {
934                let row: OwnedRow = row?.into_owned_row();
935
936                let key = self.row_conv.row_to_state_key(&row)?;
937                self.range_cache.insert(CacheKey::from(key), row);
938
939                n_extended += 1;
940                if n_extended == MAGIC_BATCH_SIZE {
941                    break;
942                }
943            }
944        }
945
946        if n_extended < MAGIC_BATCH_SIZE && self.cache_real_len() > 0 {
947            // we reached the beginning of this partition in the table
948            self.range_cache.remove(&CacheKey::Smallest);
949        }
950
951        Ok(())
952    }
953
954    async fn extend_cache_rightward_by_n_inner(
955        &mut self,
956        table: &StateTable<S>,
957        range_from_exclusive: &StateKey,
958    ) -> StreamExecutorResult<()> {
959        let mut n_extended = 0usize;
960        {
961            let sub_range = (
962                Bound::Excluded(
963                    self.row_conv
964                        .state_key_to_table_sub_pk(range_from_exclusive)?,
965                ),
966                Bound::<OwnedRow>::Unbounded,
967            );
968            let stream = table
969                .iter_with_prefix(
970                    self.deduped_part_key,
971                    &sub_range,
972                    PrefetchOptions::default(),
973                )
974                .await?;
975
976            #[for_await]
977            for row in stream {
978                let row: OwnedRow = row?.into_owned_row();
979
980                let key = self.row_conv.row_to_state_key(&row)?;
981                self.range_cache.insert(CacheKey::from(key), row);
982
983                n_extended += 1;
984                if n_extended == MAGIC_BATCH_SIZE {
985                    break;
986                }
987            }
988        }
989
990        if n_extended < MAGIC_BATCH_SIZE && self.cache_real_len() > 0 {
991            // we reached the end of this partition in the table
992            self.range_cache.remove(&CacheKey::Largest);
993        }
994
995        Ok(())
996    }
997}