risingwave_stream/executor/over_window/over_partition.rs

// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Types and functions that store or manipulate state/cache inside one single over window
//! partition.

use std::collections::BTreeMap;
use std::marker::PhantomData;
use std::ops::{Bound, RangeInclusive};

use delta_btree_map::{Change, DeltaBTreeMap};
use educe::Educe;
use futures_async_stream::for_await;
use risingwave_common::array::stream_record::Record;
use risingwave_common::config::streaming::OverWindowCachePolicy as CachePolicy;
use risingwave_common::row::{OwnedRow, Row, RowExt};
use risingwave_common::types::{Datum, Sentinelled};
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_expr::window_function::{StateKey, WindowStates, create_window_state};
use risingwave_storage::StateStore;
use risingwave_storage::store::PrefetchOptions;

use super::general::{Calls, RowConverter};
use super::range_cache::{CacheKey, PartitionCache};
use crate::common::table::state_table::StateTable;
use crate::consistency::{consistency_error, enable_strict_consistency};
use crate::executor::StreamExecutorResult;
use crate::executor::over_window::frame_finder::*;

/// Changes that happened in one over window partition.
pub(super) type PartitionDelta = BTreeMap<CacheKey, Change<OwnedRow>>;

#[derive(Default, Debug)]
pub(super) struct OverPartitionStats {
    // stats for range cache operations
    pub lookup_count: u64,
    pub left_miss_count: u64,
    pub right_miss_count: u64,

    // stats for window function state computation
    pub accessed_entry_count: u64,
    pub compute_count: u64,
    pub same_output_count: u64,
}

/// [`AffectedRange`] represents a range of keys that are affected by a delta.
/// The [`CacheKey`] fields are keys in the partition range cache + delta, which is
/// represented by [`DeltaBTreeMap`].
///
/// - `first_curr_key` and `last_curr_key` are the current keys of the first and the last
///   windows affected. They are used to pinpoint the bounds where state needs to be updated.
/// - `first_frame_start` and `last_frame_end` are the frame start and end of the first and
///   the last windows affected. They are used to pinpoint the bounds where state needs to be
///   included for computing the new state.
#[derive(Debug, Educe)]
#[educe(Clone, Copy)]
pub(super) struct AffectedRange<'a> {
    pub first_frame_start: &'a CacheKey,
    pub first_curr_key: &'a CacheKey,
    pub last_curr_key: &'a CacheKey,
    pub last_frame_end: &'a CacheKey,
}

impl<'a> AffectedRange<'a> {
    fn new(
        first_frame_start: &'a CacheKey,
        first_curr_key: &'a CacheKey,
        last_curr_key: &'a CacheKey,
        last_frame_end: &'a CacheKey,
    ) -> Self {
        Self {
            first_frame_start,
            first_curr_key,
            last_curr_key,
            last_frame_end,
        }
    }
}

/// A wrapper of [`PartitionCache`] that provides helper methods to manipulate the cache.
/// By putting this type inside a `private` module, we can avoid misuse of the internal fields
/// and methods.
pub(super) struct OverPartition<'a, S: StateStore> {
    deduped_part_key: &'a OwnedRow,
    range_cache: &'a mut PartitionCache,
    cache_policy: CachePolicy,

    calls: &'a Calls,
    row_conv: RowConverter<'a>,

    stats: OverPartitionStats,

    _phantom: PhantomData<S>,
}

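/// Maximum number of rows to load from the state table in one batch when extending the range
/// cache leftward or rightward.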
const MAGIC_BATCH_SIZE: usize = 512;

impl<'a, S: StateStore> OverPartition<'a, S> {
    pub fn new(
        deduped_part_key: &'a OwnedRow,
        cache: &'a mut PartitionCache,
        cache_policy: CachePolicy,
        calls: &'a Calls,
        row_conv: RowConverter<'a>,
    ) -> Self {
        Self {
            deduped_part_key,
            range_cache: cache,
            cache_policy,

            calls,
            row_conv,

            stats: Default::default(),

            _phantom: PhantomData,
        }
    }

    /// Get a summary of the execution that happened in the [`OverPartition`] in the current round.
    /// This will consume the [`OverPartition`] value itself.
    pub fn summarize(self) -> OverPartitionStats {
        // We may extend this function in the future.
        self.stats
    }

    /// Get the number of cached entries ignoring sentinels.
    pub fn cache_real_len(&self) -> usize {
        self.range_cache.normal_len()
    }

    /// Build changes for the partition with the given `delta`. Necessary maintenance of the range
    /// cache, like loading rows from the `table` into the cache, will be done during this process.
    pub async fn build_changes(
        &mut self,
        table: &StateTable<S>,
        mut delta: PartitionDelta,
    ) -> StreamExecutorResult<(
        BTreeMap<StateKey, Record<OwnedRow>>,
        Option<RangeInclusive<StateKey>>,
    )> {
        let calls = self.calls;
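        // Rows in the state table (and the cache) consist of the input columns followed by
        // one output column per window function call.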
        let input_schema_len = table.get_data_types().len() - calls.len();
        let numbering_only = calls.numbering_only;
        let has_rank = calls.has_rank;

        // return values
        let mut part_changes = BTreeMap::new();
        let mut accessed_range: Option<RangeInclusive<StateKey>> = None;

        // stats
        let mut accessed_entry_count = 0;
        let mut compute_count = 0;
        let mut same_output_count = 0;

        // Find affected ranges; this also ensures that all rows in the affected ranges are loaded into the cache.
        let (part_with_delta, affected_ranges) =
            self.find_affected_ranges(table, &mut delta).await?;

        let snapshot = part_with_delta.snapshot();
        let delta = part_with_delta.delta();
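        // The last delta key is used later to decide whether we can stop generating changes early
        // when all window function calls are numbering calls (`numbering_only`).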
        let last_delta_key = delta.last_key_value().map(|(k, _)| k.as_normal_expect());

        // Generate delete changes first, because deletes are skipped during iteration over
        // `part_with_delta` in the next step.
        for (key, change) in delta {
            if change.is_delete() {
                part_changes.insert(
                    key.as_normal_expect().clone(),
                    Record::Delete {
                        old_row: snapshot.get(key).unwrap().clone(),
                    },
                );
            }
        }

        for AffectedRange {
            first_frame_start,
            first_curr_key,
            last_curr_key,
            last_frame_end,
        } in affected_ranges
        {
            assert!(first_frame_start <= first_curr_key);
            assert!(first_curr_key <= last_curr_key);
            assert!(last_curr_key <= last_frame_end);
            assert!(first_frame_start.is_normal());
            assert!(first_curr_key.is_normal());
            assert!(last_curr_key.is_normal());
            assert!(last_frame_end.is_normal());

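            // Since there is at least one affected range, the delta must be non-empty, hence the
            // `unwrap` is safe.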
            let last_delta_key = last_delta_key.unwrap();

            if let Some(accessed_range) = accessed_range.as_mut() {
                let min_start = first_frame_start
                    .as_normal_expect()
                    .min(accessed_range.start())
                    .clone();
                let max_end = last_frame_end
                    .as_normal_expect()
                    .max(accessed_range.end())
                    .clone();
                *accessed_range = min_start..=max_end;
            } else {
                accessed_range = Some(
                    first_frame_start.as_normal_expect().clone()
                        ..=last_frame_end.as_normal_expect().clone(),
                );
            }

            let mut states =
                WindowStates::new(calls.iter().map(create_window_state).try_collect()?);

            // Populate window states with the affected range of rows.
            {
                let mut cursor = part_with_delta
                    .before(first_frame_start)
                    .expect("first frame start key must exist");

                while let Some((key, row)) = cursor.next() {
                    accessed_entry_count += 1;

                    for (call, state) in calls.iter().zip_eq_fast(states.iter_mut()) {
                        // TODO(rc): batch appending
                        // TODO(rc): append not only the arguments but also the old output for optimization
                        state.append(
                            key.as_normal_expect().clone(),
                            row.project(call.args.val_indices())
                                .into_owned_row()
                                .as_inner()
                                .into(),
                        );
                    }

                    if key == last_frame_end {
                        break;
                    }
                }
            }

            // Slide to the first affected key. We can safely pass in `first_curr_key` here
            // because it definitely exists in the states by the definition of affected range.
            states.just_slide_to(first_curr_key.as_normal_expect())?;
            let mut curr_key_cursor = part_with_delta.before(first_curr_key).unwrap();
            assert_eq!(
                states.curr_key(),
                curr_key_cursor
                    .peek_next()
                    .map(|(k, _)| k)
                    .map(CacheKey::as_normal_expect)
            );

            // Slide and generate changes.
            while let Some((key, row)) = curr_key_cursor.next() {
                let mut should_stop = false;

                let output = states.slide_no_evict_hint()?;
                compute_count += 1;

                let old_output = &row.as_inner()[input_schema_len..];
                if !old_output.is_empty() && old_output == output {
                    same_output_count += 1;

                    if numbering_only {
                        if has_rank {
                            // It's possible that an `Insert` doesn't affect its ties but affects
                            // all the following rows, so we need to check the `order_key`.
                            if key.as_normal_expect().order_key > last_delta_key.order_key {
                                // there won't be any more changes after this point, we can stop early
                                should_stop = true;
                            }
                        } else if key.as_normal_expect() >= last_delta_key {
                            // there won't be any more changes after this point, we can stop early
                            should_stop = true;
                        }
                    }
                }

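                // Assemble the new row: the original input columns followed by the freshly
                // computed window function outputs.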
                let new_row = OwnedRow::new(
                    row.as_inner()
                        .iter()
                        .take(input_schema_len)
                        .cloned()
                        .chain(output)
                        .collect(),
                );

                if let Some(old_row) = snapshot.get(key).cloned() {
                    // update
                    if old_row != new_row {
                        part_changes.insert(
                            key.as_normal_expect().clone(),
                            Record::Update { old_row, new_row },
                        );
                    }
                } else {
                    // insert
                    part_changes.insert(key.as_normal_expect().clone(), Record::Insert { new_row });
                }

                if should_stop || key == last_curr_key {
                    break;
                }
            }
        }

        self.stats.accessed_entry_count += accessed_entry_count;
        self.stats.compute_count += compute_count;
        self.stats.same_output_count += same_output_count;

        Ok((part_changes, accessed_range))
    }

    /// Write a change record to the state table and the cache.
    /// This function must be called after finding affected ranges, which means the change records
    /// should never exceed the cached range.
    pub fn write_record(
        &mut self,
        table: &mut StateTable<S>,
        key: StateKey,
        record: Record<OwnedRow>,
    ) {
        table.write_record(record.as_ref());
        match record {
            Record::Insert { new_row } | Record::Update { new_row, .. } => {
                self.range_cache.insert(CacheKey::from(key), new_row);
            }
            Record::Delete { .. } => {
                self.range_cache.remove(&CacheKey::from(key));

                if self.range_cache.normal_len() == 0 && self.range_cache.len() == 1 {
                    // only one sentinel remains, should insert the other
                    self.range_cache
                        .insert(CacheKey::Smallest, OwnedRow::empty());
                    self.range_cache
                        .insert(CacheKey::Largest, OwnedRow::empty());
                }
            }
        }
    }

    /// Find all ranges in the partition that are affected by the given delta.
    /// The returned ranges are guaranteed to be sorted and non-overlapping. All keys in the ranges
    /// are guaranteed to be cached, which means they should be [`Sentinelled::Normal`]s.
    async fn find_affected_ranges<'s, 'delta>(
        &'s mut self,
        table: &StateTable<S>,
        delta: &'delta mut PartitionDelta,
    ) -> StreamExecutorResult<(
        DeltaBTreeMap<'delta, CacheKey, OwnedRow>,
        Vec<AffectedRange<'delta>>,
    )>
    where
        'a: 'delta,
        's: 'delta,
    {
        if delta.is_empty() {
            return Ok((DeltaBTreeMap::new(self.range_cache.inner(), delta), vec![]));
        }

        self.ensure_delta_in_cache(table, delta).await?;
        let delta = &*delta; // let's make it immutable

        let delta_first = delta.first_key_value().unwrap().0.as_normal_expect();
        let delta_last = delta.last_key_value().unwrap().0.as_normal_expect();

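        // For `RANGE` frames, compute the logical order values of the first and last current rows
        // that may be affected by this delta. This is computed once and reused in every iteration
        // of the loop below.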
        let range_frame_logical_curr =
            calc_logical_curr_for_range_frames(&self.calls.range_frames, delta_first, delta_last);

        loop {
            // TERMINABILITY: `extend_cache_leftward_by_n` and `extend_cache_rightward_by_n` keep
            // pushing the cache to the boundary of the current partition. In these two methods,
            // when either side of the boundary is reached, the corresponding sentinel key will be
            // removed, so finally `Self::find_affected_ranges_readonly` will return `Ok`.

            // SAFETY: Here we briefly borrow the range cache and turn the reference into a
            // `'delta` one to bypass the borrow checker. This is safe because we only return
            // the reference once we don't need to do any further mutation.
            let cache_inner = unsafe { &*(self.range_cache.inner() as *const _) };
            let part_with_delta = DeltaBTreeMap::new(cache_inner, delta);

            self.stats.lookup_count += 1;
            let res = self
                .find_affected_ranges_readonly(part_with_delta, range_frame_logical_curr.as_ref());

            let (need_extend_leftward, need_extend_rightward) = match res {
                Ok(ranges) => return Ok((part_with_delta, ranges)),
                Err(cache_extend_hint) => cache_extend_hint,
            };

            if need_extend_leftward {
                self.stats.left_miss_count += 1;
                tracing::trace!(partition=?self.deduped_part_key, "partition cache left extension triggered");
                let left_most = self
                    .range_cache
                    .first_normal_key()
                    .unwrap_or(delta_first)
                    .clone();
                self.extend_cache_leftward_by_n(table, &left_most).await?;
            }
            if need_extend_rightward {
                self.stats.right_miss_count += 1;
                tracing::trace!(partition=?self.deduped_part_key, "partition cache right extension triggered");
                let right_most = self
                    .range_cache
                    .last_normal_key()
                    .unwrap_or(delta_last)
                    .clone();
                self.extend_cache_rightward_by_n(table, &right_most).await?;
            }
            tracing::trace!(partition=?self.deduped_part_key, "partition cache extended");
        }
    }

    async fn ensure_delta_in_cache(
        &mut self,
        table: &StateTable<S>,
        delta: &mut PartitionDelta,
    ) -> StreamExecutorResult<()> {
        if delta.is_empty() {
            return Ok(());
        }

        let delta_first = delta.first_key_value().unwrap().0.as_normal_expect();
        let delta_last = delta.last_key_value().unwrap().0.as_normal_expect();

        if self.cache_policy.is_full() {
            // ensure everything is in the cache
            self.extend_cache_to_boundary(table).await?;
        } else {
            // TODO(rc): later we should extend cache using `self.calls.super_rows_frame_bounds` and
            // `range_frame_logical_curr` as hints.

            // ensure the cache covers all delta (if possible)
            self.extend_cache_by_range(table, delta_first..=delta_last)
                .await?;
        }

        if !enable_strict_consistency() {
            // in non-strict mode, we should ensure the delta is consistent with the cache
            let cache = self.range_cache.inner();
            delta.retain(|key, change| match &*change {
                Change::Insert(_) => {
                    // this also includes the case of double-insert and ghost-update,
                    // but since we already lost the information, let's just ignore it
                    true
                }
                Change::Delete => {
                    // if the key is not in the cache, it's a ghost-delete
                    let consistent = cache.contains_key(key);
                    if !consistent {
                        consistency_error!(?key, "removing a row with non-existing key");
                    }
                    consistent
                }
            });
        }

        Ok(())
    }

    /// Try to find affected ranges on the immutable range cache + delta. If the algorithm reaches
    /// any sentinel node in the cache, which means some entries in the affected range may be
    /// in the state table, it returns an `Err((bool, bool))` to notify the caller that the
    /// left side, the right side, or both sides of the cache should be extended.
    ///
    /// TODO(rc): Currently at most one range will be in the result vector. Ideally we should
    /// recognize discontinuous changes in the delta and find multiple ranges, but that will be
    /// too complex for now.
    fn find_affected_ranges_readonly<'delta>(
        &self,
        part_with_delta: DeltaBTreeMap<'delta, CacheKey, OwnedRow>,
        range_frame_logical_curr: Option<&(Sentinelled<Datum>, Sentinelled<Datum>)>,
    ) -> std::result::Result<Vec<AffectedRange<'delta>>, (bool, bool)> {
        if part_with_delta.first_key().is_none() {
            // nothing is left after applying the delta, meaning all entries are deleted
            return Ok(vec![]);
        }

        let delta_first_key = part_with_delta.delta().first_key_value().unwrap().0;
        let delta_last_key = part_with_delta.delta().last_key_value().unwrap().0;
        let cache_key_pk_len = delta_first_key.as_normal_expect().pk.len();

        if part_with_delta.snapshot().is_empty() {
            // all existing keys are inserted in the delta
            return Ok(vec![AffectedRange::new(
                delta_first_key,
                delta_first_key,
                delta_last_key,
                delta_last_key,
            )]);
        }

        let first_key = part_with_delta.first_key().unwrap();
        let last_key = part_with_delta.last_key().unwrap();

        let first_curr_key = if self.calls.end_is_unbounded || delta_first_key == first_key {
            // If the frame end is unbounded, or the first key is in the delta, then the frame corresponding
            // to the first key is always affected.
            first_key
        } else {
            let mut key = find_first_curr_for_rows_frame(
                &self.calls.super_rows_frame_bounds,
                part_with_delta,
                delta_first_key,
            );

            if let Some((logical_first_curr, _)) = range_frame_logical_curr {
                let logical_curr = logical_first_curr.as_normal_expect(); // otherwise should go `end_is_unbounded` branch
                let new_key = find_left_for_range_frames(
                    &self.calls.range_frames,
                    part_with_delta,
                    logical_curr,
                    cache_key_pk_len,
                );
                key = std::cmp::min(key, new_key);
            }

            key
        };

        let last_curr_key = if self.calls.start_is_unbounded || delta_last_key == last_key {
            // similar to `first_curr_key`
            last_key
        } else {
            let mut key = find_last_curr_for_rows_frame(
                &self.calls.super_rows_frame_bounds,
                part_with_delta,
                delta_last_key,
            );

            if let Some((_, logical_last_curr)) = range_frame_logical_curr {
                let logical_curr = logical_last_curr.as_normal_expect(); // otherwise should go `start_is_unbounded` branch
                let new_key = find_right_for_range_frames(
                    &self.calls.range_frames,
                    part_with_delta,
                    logical_curr,
                    cache_key_pk_len,
                );
                key = std::cmp::max(key, new_key);
            }

            key
        };

        {
            // We quickly return if there's any sentinel in `[first_curr_key, last_curr_key]`,
            // just for the sake of simplicity.
            let mut need_extend_leftward = false;
            let mut need_extend_rightward = false;
            for key in [first_curr_key, last_curr_key] {
                if key.is_smallest() {
                    need_extend_leftward = true;
                } else if key.is_largest() {
                    need_extend_rightward = true;
                }
            }
            if need_extend_leftward || need_extend_rightward {
                return Err((need_extend_leftward, need_extend_rightward));
            }
        }

        // From now on we definitely have two normal `curr_key`s.

        if first_curr_key > last_curr_key {
            // Note that we cannot move this check before the above block, because for example,
            // if the range cache contains `[Smallest, 5, Largest]`, the delta contains only
            // `Delete 5`, and the frame is `RANGE BETWEEN CURRENT ROW AND CURRENT ROW`, then
            // `first_curr_key` will be `Largest` and `last_curr_key` will be `Smallest`; in this
            // case there may be some other entries with order value `5` in the table, which should
            // be *affected*.
            return Ok(vec![]);
        }

        let range_frame_logical_boundary = calc_logical_boundary_for_range_frames(
            &self.calls.range_frames,
            first_curr_key.as_normal_expect(),
            last_curr_key.as_normal_expect(),
        );

        let first_frame_start = if self.calls.start_is_unbounded || first_curr_key == first_key {
            // If the frame start is unbounded, or the first curr key is the first key, then the first key
            // always needs to be included in the affected range.
            first_key
        } else {
            let mut key = find_frame_start_for_rows_frame(
                &self.calls.super_rows_frame_bounds,
                part_with_delta,
                first_curr_key,
            );

            if let Some((logical_first_start, _)) = range_frame_logical_boundary.as_ref() {
                let logical_boundary = logical_first_start.as_normal_expect(); // otherwise should go `end_is_unbounded` branch
                let new_key = find_left_for_range_frames(
                    &self.calls.range_frames,
                    part_with_delta,
                    logical_boundary,
                    cache_key_pk_len,
                );
                key = std::cmp::min(key, new_key);
            }

            key
        };
        assert!(first_frame_start <= first_curr_key);

        let last_frame_end = if self.calls.end_is_unbounded || last_curr_key == last_key {
            // similar to `first_frame_start`
            last_key
        } else {
            let mut key = find_frame_end_for_rows_frame(
                &self.calls.super_rows_frame_bounds,
                part_with_delta,
                last_curr_key,
            );

            if let Some((_, logical_last_end)) = range_frame_logical_boundary.as_ref() {
                let logical_boundary = logical_last_end.as_normal_expect(); // otherwise should go `end_is_unbounded` branch
                let new_key = find_right_for_range_frames(
                    &self.calls.range_frames,
                    part_with_delta,
                    logical_boundary,
                    cache_key_pk_len,
                );
                key = std::cmp::max(key, new_key);
            }

            key
        };
        assert!(last_frame_end >= last_curr_key);

        let mut need_extend_leftward = false;
        let mut need_extend_rightward = false;
        for key in [
            first_curr_key,
            last_curr_key,
            first_frame_start,
            last_frame_end,
        ] {
            if key.is_smallest() {
                need_extend_leftward = true;
            } else if key.is_largest() {
                need_extend_rightward = true;
            }
        }

        if need_extend_leftward || need_extend_rightward {
            Err((need_extend_leftward, need_extend_rightward))
        } else {
            Ok(vec![AffectedRange::new(
                first_frame_start,
                first_curr_key,
                last_curr_key,
                last_frame_end,
            )])
        }
    }

    async fn extend_cache_to_boundary(
        &mut self,
        table: &StateTable<S>,
    ) -> StreamExecutorResult<()> {
        if self.range_cache.normal_len() == self.range_cache.len() {
            // no sentinel in the cache, meaning we already cached all entries of this partition
            return Ok(());
        }

        tracing::trace!(partition=?self.deduped_part_key, "loading the whole partition into cache");

        let mut new_cache = PartitionCache::new_without_sentinels(); // shouldn't use `new` here because we are extending to boundary
        let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
        let table_iter = table
            .iter_with_prefix(self.deduped_part_key, sub_range, PrefetchOptions::default())
            .await?;

        #[for_await]
        for row in table_iter {
            let row: OwnedRow = row?.into_owned_row();
            new_cache.insert(self.row_conv.row_to_state_key(&row)?.into(), row);
        }
        *self.range_cache = new_cache;

        Ok(())
    }

    /// Try to load the given range of entries from the table into the cache.
    /// When the function returns, it's guaranteed that there's no entry in the table that is within
    /// the given range but not in the cache.
    async fn extend_cache_by_range(
        &mut self,
        table: &StateTable<S>,
        range: RangeInclusive<&StateKey>,
    ) -> StreamExecutorResult<()> {
        if self.range_cache.normal_len() == self.range_cache.len() {
            // no sentinel in the cache, meaning we already cached all entries of this partition
            return Ok(());
        }
        assert!(self.range_cache.len() >= 2);

        let cache_first_normal_key = self.range_cache.first_normal_key();
        let cache_last_normal_key = self.range_cache.last_normal_key();

        if cache_first_normal_key.is_some() && *range.end() < cache_first_normal_key.unwrap()
            || cache_last_normal_key.is_some() && *range.start() > cache_last_normal_key.unwrap()
        {
            // completely not overlapping, for the sake of simplicity, we re-init the cache
            tracing::debug!(
                partition=?self.deduped_part_key,
                cache_first=?cache_first_normal_key,
                cache_last=?cache_last_normal_key,
                range=?range,
                "modified range is completely non-overlapping with the cached range, re-initializing the cache"
            );
            *self.range_cache = PartitionCache::new();
        }

        if self.cache_real_len() == 0 {
            // no normal entry in the cache, just load the given range
            let table_sub_range = (
                Bound::Included(self.row_conv.state_key_to_table_sub_pk(range.start())?),
                Bound::Included(self.row_conv.state_key_to_table_sub_pk(range.end())?),
            );
            tracing::debug!(
                partition=?self.deduped_part_key,
                table_sub_range=?table_sub_range,
                "cache is empty, just loading the given range"
            );
            return self
                .extend_cache_by_range_inner(table, table_sub_range)
                .await;
        }

        let cache_real_first_key = self
            .range_cache
            .first_normal_key()
            .expect("cache real len is not 0");
        if self.range_cache.left_is_sentinel() && *range.start() < cache_real_first_key {
            // extend leftward only if there's smallest sentinel
            let table_sub_range = (
                Bound::Included(self.row_conv.state_key_to_table_sub_pk(range.start())?),
                Bound::Excluded(
                    self.row_conv
                        .state_key_to_table_sub_pk(cache_real_first_key)?,
                ),
            );
            tracing::trace!(
                partition=?self.deduped_part_key,
                table_sub_range=?table_sub_range,
                "loading the left half of given range"
            );
            self.extend_cache_by_range_inner(table, table_sub_range)
                .await?;
        }

        let cache_real_last_key = self
            .range_cache
            .last_normal_key()
            .expect("cache real len is not 0");
        if self.range_cache.right_is_sentinel() && *range.end() > cache_real_last_key {
            // extend rightward only if there's largest sentinel
            let table_sub_range = (
                Bound::Excluded(
                    self.row_conv
                        .state_key_to_table_sub_pk(cache_real_last_key)?,
                ),
                Bound::Included(self.row_conv.state_key_to_table_sub_pk(range.end())?),
            );
            tracing::trace!(
                partition=?self.deduped_part_key,
                table_sub_range=?table_sub_range,
                "loading the right half of given range"
            );
            self.extend_cache_by_range_inner(table, table_sub_range)
                .await?;
        }

        // prefetch rows before the start of the range
        self.extend_cache_leftward_by_n(table, range.start())
            .await?;

        // prefetch rows after the end of the range
        self.extend_cache_rightward_by_n(table, range.end()).await
    }

    async fn extend_cache_leftward_by_n(
        &mut self,
        table: &StateTable<S>,
        hint_key: &StateKey,
    ) -> StreamExecutorResult<()> {
        if self.range_cache.normal_len() == self.range_cache.len() {
            // no sentinel in the cache, meaning we already cached all entries of this partition
            return Ok(());
        }
        assert!(self.range_cache.len() >= 2);

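        // Find the key next to the left-most one. If the left-most key is already a normal key
        // (i.e. not the smallest sentinel), the beginning of the partition is already cached and
        // there is nothing to extend.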
        let left_second = {
            let mut iter = self.range_cache.inner().iter();
            let left_first = iter.next().unwrap().0;
            if left_first.is_normal() {
                // the left side already reaches the beginning of this partition in the table
                return Ok(());
            }
            iter.next().unwrap().0
        };
        let range_to_exclusive = match left_second {
            CacheKey::Normal(smallest_in_cache) => smallest_in_cache,
            CacheKey::Largest => hint_key, // no normal entry in the cache
            _ => unreachable!(),
        }
        .clone();

        self.extend_cache_leftward_by_n_inner(table, &range_to_exclusive)
            .await?;

        if self.cache_real_len() == 0 {
            // Cache was empty, and extending leftward didn't add anything to the cache, but we
            // can't just remove the smallest sentinel; we must also try extending rightward.
            self.extend_cache_rightward_by_n_inner(table, hint_key)
                .await?;
            if self.cache_real_len() == 0 {
                // still empty, meaning the table is empty
                self.range_cache.remove(&CacheKey::Smallest);
                self.range_cache.remove(&CacheKey::Largest);
            }
        }

        Ok(())
    }

    async fn extend_cache_rightward_by_n(
        &mut self,
        table: &StateTable<S>,
        hint_key: &StateKey,
    ) -> StreamExecutorResult<()> {
        if self.range_cache.normal_len() == self.range_cache.len() {
            // no sentinel in the cache, meaning we already cached all entries of this partition
            return Ok(());
        }
        assert!(self.range_cache.len() >= 2);

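        // Find the key next to the right-most one. If the right-most key is already a normal key
        // (i.e. not the largest sentinel), the end of the partition is already cached and there is
        // nothing to extend.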
        let right_second = {
            let mut iter = self.range_cache.inner().iter();
            let right_first = iter.next_back().unwrap().0;
            if right_first.is_normal() {
                // the right side already reaches the end of this partition in the table
                return Ok(());
            }
            iter.next_back().unwrap().0
        };
        let range_from_exclusive = match right_second {
            CacheKey::Normal(largest_in_cache) => largest_in_cache,
            CacheKey::Smallest => hint_key, // no normal entry in the cache
            _ => unreachable!(),
        }
        .clone();

        self.extend_cache_rightward_by_n_inner(table, &range_from_exclusive)
            .await?;

        if self.cache_real_len() == 0 {
            // Cache was empty, and extending rightward didn't add anything to the cache, but we
            // can't just remove the largest sentinel; we must also try extending leftward.
            self.extend_cache_leftward_by_n_inner(table, hint_key)
                .await?;
            if self.cache_real_len() == 0 {
                // still empty, meaning the table is empty
                self.range_cache.remove(&CacheKey::Smallest);
                self.range_cache.remove(&CacheKey::Largest);
            }
        }

        Ok(())
    }

    async fn extend_cache_by_range_inner(
        &mut self,
        table: &StateTable<S>,
        table_sub_range: (Bound<impl Row>, Bound<impl Row>),
    ) -> StreamExecutorResult<()> {
        let stream = table
            .iter_with_prefix(
                self.deduped_part_key,
                &table_sub_range,
                PrefetchOptions::default(),
            )
            .await?;

        #[for_await]
        for row in stream {
            let row: OwnedRow = row?.into_owned_row();
            let key = self.row_conv.row_to_state_key(&row)?;
            self.range_cache.insert(CacheKey::from(key), row);
        }

        Ok(())
    }

    async fn extend_cache_leftward_by_n_inner(
        &mut self,
        table: &StateTable<S>,
        range_to_exclusive: &StateKey,
    ) -> StreamExecutorResult<()> {
        let mut n_extended = 0usize;
        {
            let sub_range = (
                Bound::<OwnedRow>::Unbounded,
                Bound::Excluded(
                    self.row_conv
                        .state_key_to_table_sub_pk(range_to_exclusive)?,
                ),
            );
            let rev_stream = table
                .rev_iter_with_prefix(
                    self.deduped_part_key,
                    &sub_range,
                    PrefetchOptions::default(),
                )
                .await?;

            #[for_await]
            for row in rev_stream {
                let row: OwnedRow = row?.into_owned_row();

                let key = self.row_conv.row_to_state_key(&row)?;
                self.range_cache.insert(CacheKey::from(key), row);

                n_extended += 1;
                if n_extended == MAGIC_BATCH_SIZE {
                    break;
                }
            }
        }

        if n_extended < MAGIC_BATCH_SIZE && self.cache_real_len() > 0 {
            // we reached the beginning of this partition in the table
            self.range_cache.remove(&CacheKey::Smallest);
        }

        Ok(())
    }

    async fn extend_cache_rightward_by_n_inner(
        &mut self,
        table: &StateTable<S>,
        range_from_exclusive: &StateKey,
    ) -> StreamExecutorResult<()> {
        let mut n_extended = 0usize;
        {
            let sub_range = (
                Bound::Excluded(
                    self.row_conv
                        .state_key_to_table_sub_pk(range_from_exclusive)?,
                ),
                Bound::<OwnedRow>::Unbounded,
            );
            let stream = table
                .iter_with_prefix(
                    self.deduped_part_key,
                    &sub_range,
                    PrefetchOptions::default(),
                )
                .await?;

            #[for_await]
            for row in stream {
                let row: OwnedRow = row?.into_owned_row();

                let key = self.row_conv.row_to_state_key(&row)?;
                self.range_cache.insert(CacheKey::from(key), row);

                n_extended += 1;
                if n_extended == MAGIC_BATCH_SIZE {
                    break;
                }
            }
        }

        if n_extended < MAGIC_BATCH_SIZE && self.cache_real_len() > 0 {
            // we reached the end of this partition in the table
            self.range_cache.remove(&CacheKey::Largest);
        }

        Ok(())
    }
995    }
996}