risingwave_stream/executor/over_window/range_cache.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::BTreeMap;
use std::ops::{Bound, RangeInclusive};

use risingwave_common::row::OwnedRow;
use risingwave_common::session_config::OverWindowCachePolicy as CachePolicy;
use risingwave_common::types::Sentinelled;
use risingwave_common_estimate_size::EstimateSize;
use risingwave_common_estimate_size::collections::EstimatedBTreeMap;
use risingwave_expr::window_function::StateKey;
use static_assertions::const_assert;

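/// Cache key of [`PartitionCache`]: either a `Normal` entry wrapping a
/// `StateKey`, or a `Smallest`/`Largest` sentinel marking that the cache may
/// not cover the partition beyond that end. (Descriptive note added here;
/// see the struct docs below for the exact layout invariants.)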
pub(super) type CacheKey = Sentinelled<StateKey>;

/// Range cache for one over window partition.
/// The cache entries can be:
///
/// - `(Normal)*`
/// - `Smallest, (Normal)*, Largest`
/// - `(Normal)+, Largest`
/// - `Smallest, (Normal)+`
///
/// This means it's impossible to have only one sentinel in the cache without any normal entry,
/// and each of the two sentinel types can appear at most once. Also, since sentinels are either
/// smallest or largest, they always appear at the beginning or the end of the cache.
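///
/// For example (illustrative), a cache covering only the middle of a partition
/// may look like `Smallest, Normal(k3), Normal(k4), Largest`: rows that sort
/// before `k3` or after `k4` may exist but are not cached.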
#[derive(Clone, Debug, Default)]
pub(super) struct PartitionCache {
    inner: EstimatedBTreeMap<CacheKey, OwnedRow>,
}

impl PartitionCache {
    /// Create a new empty partition cache without sentinel values.
    pub fn new_without_sentinels() -> Self {
        Self {
            inner: EstimatedBTreeMap::new(),
        }
    }

    /// Create a new empty partition cache with sentinel values.
    pub fn new() -> Self {
        let mut cache = Self {
            inner: EstimatedBTreeMap::new(),
        };
        cache.insert(CacheKey::Smallest, OwnedRow::empty());
        cache.insert(CacheKey::Largest, OwnedRow::empty());
        cache
    }

    /// Get access to the inner `BTreeMap` for cursor operations.
    pub fn inner(&self) -> &BTreeMap<CacheKey, OwnedRow> {
        self.inner.inner()
    }

    /// Insert a key-value pair into the cache.
    pub fn insert(&mut self, key: CacheKey, value: OwnedRow) -> Option<OwnedRow> {
        self.inner.insert(key, value)
    }

    /// Remove a key from the cache.
    pub fn remove(&mut self, key: &CacheKey) -> Option<OwnedRow> {
        self.inner.remove(key)
    }

    /// Get the number of entries in the cache.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Check if the cache is empty.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Get the first key-value pair in the cache.
    pub fn first_key_value(&self) -> Option<(&CacheKey, &OwnedRow)> {
        self.inner.first_key_value()
    }

    /// Get the last key-value pair in the cache.
    pub fn last_key_value(&self) -> Option<(&CacheKey, &OwnedRow)> {
        self.inner.last_key_value()
    }

    /// Retain entries in the given range, removing others.
    /// Returns `(left_removed, right_removed)` where sentinels are filtered out.
    /// Sentinels are preserved in the cache.
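    ///
    /// For example (illustrative): retaining `k2..=k4` in a cache holding
    /// `{Smallest, k1, ..., k5, Largest}` returns `({k1}, {k5})`, while both
    /// sentinels remain in the cache.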
    fn retain_range(
        &mut self,
        range: RangeInclusive<&CacheKey>,
    ) -> (BTreeMap<CacheKey, OwnedRow>, BTreeMap<CacheKey, OwnedRow>) {
        // Check if we had sentinels before the operation
        let had_smallest = self.inner.inner().contains_key(&CacheKey::Smallest);
        let had_largest = self.inner.inner().contains_key(&CacheKey::Largest);

        let (left_removed, right_removed) = self.inner.retain_range(range);

        // Restore sentinels if they were present before
        if had_smallest {
            self.inner.insert(CacheKey::Smallest, OwnedRow::empty());
        }
        if had_largest {
            self.inner.insert(CacheKey::Largest, OwnedRow::empty());
        }

        // Filter out sentinels from the returned maps
        let left_removed = left_removed
            .into_iter()
            .filter(|(k, _)| k.is_normal())
            .collect();
        let right_removed = right_removed
            .into_iter()
            .filter(|(k, _)| k.is_normal())
            .collect();

        (left_removed, right_removed)
    }

    /// Get the number of cached `Sentinel::Normal` entries.
    pub fn normal_len(&self) -> usize {
        let len = self.inner().len();
        if len <= 1 {
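            // By the invariant documented on the struct, a lone entry cannot
            // be a sentinel, so `len` already equals the number of normal
            // entries (checked by the debug assertion below).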
            debug_assert!(
                self.inner()
                    .first_key_value()
                    .map(|(k, _)| k.is_normal())
                    .unwrap_or(true)
            );
            return len;
        }
        // len >= 2
        let cache_inner = self.inner();
        let sentinels = [
            // sentinels only appear at the beginning and/or the end
            cache_inner.first_key_value().unwrap().0.is_sentinel(),
            cache_inner.last_key_value().unwrap().0.is_sentinel(),
        ];
        len - sentinels.into_iter().filter(|x| *x).count()
    }

    /// Get the first normal key in the cache, if any.
    pub fn first_normal_key(&self) -> Option<&StateKey> {
        self.inner()
            .iter()
            .find(|(k, _)| k.is_normal())
            .map(|(k, _)| k.as_normal_expect())
    }

    /// Get the last normal key in the cache, if any.
    pub fn last_normal_key(&self) -> Option<&StateKey> {
        self.inner()
            .iter()
            .rev()
            .find(|(k, _)| k.is_normal())
            .map(|(k, _)| k.as_normal_expect())
    }

    /// Whether the leftmost entry is a sentinel.
    pub fn left_is_sentinel(&self) -> bool {
        self.first_key_value()
            .map(|(k, _)| k.is_sentinel())
            .unwrap_or(false)
    }

    /// Whether the rightmost entry is a sentinel.
    pub fn right_is_sentinel(&self) -> bool {
        self.last_key_value()
            .map(|(k, _)| k.is_sentinel())
            .unwrap_or(false)
    }

    /// Shrink the partition cache based on the given policy and recently accessed range.
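    ///
    /// Summary of the policies as implemented below:
    ///
    /// - `Full`: no eviction.
    /// - `Recent`: retain the recently accessed range plus up to
    ///   `MAGIC_JITTER_PREVENTION` extra entries on each side.
    /// - `RecentFirstN`/`RecentLastN`: evict only when the cache holds more
    ///   than `MAGIC_CACHE_SIZE` entries; retain roughly `MAGIC_CACHE_SIZE`
    ///   entries anchored at the start/end of the recently accessed range.
    ///
    /// Afterwards, a sentinel is re-inserted on each side that lost entries.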
    pub fn shrink(
        &mut self,
        deduped_part_key: &OwnedRow,
        cache_policy: CachePolicy,
        recently_accessed_range: RangeInclusive<StateKey>,
    ) {
        const MAGIC_CACHE_SIZE: usize = 1024;
        const MAGIC_JITTER_PREVENTION: usize = MAGIC_CACHE_SIZE / 8;
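        // (Rationale, inferred from the name: keeping a margin of entries
        // beyond the accessed range lets slightly shifted access ranges in
        // later epochs still hit the cache instead of causing jittery
        // evict-and-refill cycles.)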

        tracing::trace!(
            partition=?deduped_part_key,
            cache_policy=?cache_policy,
            recently_accessed_range=?recently_accessed_range,
            "find the range to retain in the range cache"
        );

        let (start, end) = match cache_policy {
            CachePolicy::Full => {
                // evict nothing if the policy is to cache the full partition
                return;
            }
            CachePolicy::Recent => {
                let (sk_start, sk_end) = recently_accessed_range.into_inner();
                let (ck_start, ck_end) = (CacheKey::from(sk_start), CacheKey::from(sk_end));

                // find the cursor just before `ck_start`
                let mut cursor = self.inner().upper_bound(Bound::Excluded(&ck_start));
                for _ in 0..MAGIC_JITTER_PREVENTION {
                    if cursor.prev().is_none() {
                        // already at the beginning
                        break;
                    }
                }
                let start = cursor
                    .peek_prev()
                    .map(|(k, _)| k)
                    .unwrap_or_else(|| self.first_key_value().unwrap().0)
                    .clone();

                // find the cursor just after `ck_end`
                let mut cursor = self.inner().lower_bound(Bound::Excluded(&ck_end));
                for _ in 0..MAGIC_JITTER_PREVENTION {
                    if cursor.next().is_none() {
                        // already at the end
                        break;
                    }
                }
                let end = cursor
                    .peek_next()
                    .map(|(k, _)| k)
                    .unwrap_or_else(|| self.last_key_value().unwrap().0)
                    .clone();

                (start, end)
            }
            CachePolicy::RecentFirstN => {
                if self.len() <= MAGIC_CACHE_SIZE {
                    // no need to evict if cache len <= N
                    return;
                } else {
                    let (sk_start, _sk_end) = recently_accessed_range.into_inner();
                    let ck_start = CacheKey::from(sk_start);

                    let mut capacity_remain = MAGIC_CACHE_SIZE; // precision is not important here; code simplicity comes first
                    const_assert!(MAGIC_JITTER_PREVENTION < MAGIC_CACHE_SIZE);
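                    // (the assert above guarantees `capacity_remain` cannot
                    // underflow in the backward loop below)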

                    // find the cursor just before `ck_start`
                    let cursor_just_before_ck_start =
                        self.inner().upper_bound(Bound::Excluded(&ck_start));

                    let mut cursor = cursor_just_before_ck_start.clone();
                    // go back for at most `MAGIC_JITTER_PREVENTION` entries
                    for _ in 0..MAGIC_JITTER_PREVENTION {
                        if cursor.prev().is_none() {
                            // already at the beginning
                            break;
                        }
                        capacity_remain -= 1;
                    }
                    let start = cursor
                        .peek_prev()
                        .map(|(k, _)| k)
                        .unwrap_or_else(|| self.first_key_value().unwrap().0)
                        .clone();

                    let mut cursor = cursor_just_before_ck_start;
                    // go forward for at most `capacity_remain` entries
                    for _ in 0..capacity_remain {
                        if cursor.next().is_none() {
                            // already at the end
                            break;
                        }
                    }
                    let end = cursor
                        .peek_next()
                        .map(|(k, _)| k)
                        .unwrap_or_else(|| self.last_key_value().unwrap().0)
                        .clone();

                    (start, end)
                }
            }
            CachePolicy::RecentLastN => {
                if self.len() <= MAGIC_CACHE_SIZE {
                    // no need to evict if cache len <= N
                    return;
                } else {
                    let (_sk_start, sk_end) = recently_accessed_range.into_inner();
                    let ck_end = CacheKey::from(sk_end);

                    let mut capacity_remain = MAGIC_CACHE_SIZE; // precision is not important here; code simplicity comes first
                    const_assert!(MAGIC_JITTER_PREVENTION < MAGIC_CACHE_SIZE);

                    // find the cursor just after `ck_end`
                    let cursor_just_after_ck_end =
                        self.inner().lower_bound(Bound::Excluded(&ck_end));

                    let mut cursor = cursor_just_after_ck_end.clone();
                    // go forward for at most `MAGIC_JITTER_PREVENTION` entries
                    for _ in 0..MAGIC_JITTER_PREVENTION {
                        if cursor.next().is_none() {
                            // already at the end
                            break;
                        }
                        capacity_remain -= 1;
                    }
                    let end = cursor
                        .peek_next()
                        .map(|(k, _)| k)
                        .unwrap_or_else(|| self.last_key_value().unwrap().0)
                        .clone();

                    let mut cursor = cursor_just_after_ck_end;
                    // go back for at most `capacity_remain` entries
                    for _ in 0..capacity_remain {
                        if cursor.prev().is_none() {
                            // already at the beginning
                            break;
                        }
                    }
                    let start = cursor
                        .peek_prev()
                        .map(|(k, _)| k)
                        .unwrap_or_else(|| self.first_key_value().unwrap().0)
                        .clone();

                    (start, end)
                }
            }
        };

        tracing::trace!(
            partition=?deduped_part_key,
            retain_range=?(&start..=&end),
            "retain range in the range cache"
        );

        let (left_removed, right_removed) = self.retain_range(&start..=&end);
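        // Put sentinels back on each side where normal entries were evicted,
        // marking the cache as incomplete on that side.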
        if self.is_empty() {
            if !left_removed.is_empty() || !right_removed.is_empty() {
                self.insert(CacheKey::Smallest, OwnedRow::empty());
                self.insert(CacheKey::Largest, OwnedRow::empty());
            }
        } else {
            if !left_removed.is_empty() {
                self.insert(CacheKey::Smallest, OwnedRow::empty());
            }
            if !right_removed.is_empty() {
                self.insert(CacheKey::Largest, OwnedRow::empty());
            }
        }
    }
}

impl EstimateSize for PartitionCache {
    fn estimated_heap_size(&self) -> usize {
        self.inner.estimated_heap_size()
    }
}

#[cfg(test)]
mod tests {
    use risingwave_common::row::OwnedRow;
    use risingwave_common::session_config::OverWindowCachePolicy as CachePolicy;
    use risingwave_common::types::{DefaultOrdered, ScalarImpl};
    use risingwave_common::util::memcmp_encoding::encode_value;
    use risingwave_common::util::sort_util::OrderType;
    use risingwave_expr::window_function::StateKey;

    use super::*;

    fn create_test_state_key(value: i32) -> StateKey {
        let row = OwnedRow::new(vec![Some(ScalarImpl::Int32(value))]);
        StateKey {
            order_key: encode_value(Some(ScalarImpl::Int32(value)), OrderType::ascending())
                .unwrap(),
            pk: DefaultOrdered::new(row),
        }
    }

    fn create_test_cache_key(value: i32) -> CacheKey {
        CacheKey::from(create_test_state_key(value))
    }

    fn create_test_row(value: i32) -> OwnedRow {
        OwnedRow::new(vec![Some(ScalarImpl::Int32(value))])
    }
    #[test]
    fn test_partition_cache_new_without_sentinels() {
        let cache = PartitionCache::new_without_sentinels();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_partition_cache_new_with_sentinels() {
        let cache = PartitionCache::new();
        assert!(!cache.is_empty());
        assert_eq!(cache.len(), 2);

        // Should have smallest and largest sentinels
        let first = cache.first_key_value().unwrap();
        let last = cache.last_key_value().unwrap();

        assert_eq!(*first.0, CacheKey::Smallest);
        assert_eq!(*last.0, CacheKey::Largest);
    }

    #[test]
    fn test_partition_cache_insert_and_remove() {
        let mut cache = PartitionCache::new_without_sentinels();
        let key = create_test_cache_key(1);
        let value = create_test_row(100);

        // Insert
        assert!(cache.insert(key.clone(), value.clone()).is_none());
        assert_eq!(cache.len(), 1);
        assert!(!cache.is_empty());

        // Remove
        let removed = cache.remove(&key);
        assert!(removed.is_some());
        assert_eq!(removed.unwrap(), value);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_partition_cache_first_last_key_value() {
        let mut cache = PartitionCache::new_without_sentinels();

        // Empty cache
        assert!(cache.first_key_value().is_none());
        assert!(cache.last_key_value().is_none());

        // Add some entries
        cache.insert(create_test_cache_key(2), create_test_row(200));
        cache.insert(create_test_cache_key(1), create_test_row(100));
        cache.insert(create_test_cache_key(3), create_test_row(300));

        let first = cache.first_key_value().unwrap();
        let last = cache.last_key_value().unwrap();

        // BTreeMap should order by key
        assert_eq!(*first.0, create_test_cache_key(1));
        assert_eq!(*first.1, create_test_row(100));

        assert_eq!(*last.0, create_test_cache_key(3));
        assert_eq!(*last.1, create_test_row(300));
    }

    #[test]
    fn test_partition_cache_retain_range() {
        let mut cache = PartitionCache::new();

        // Add some entries
        for i in 1..=5 {
            cache.insert(create_test_cache_key(i), create_test_row(i * 100));
        }

        assert_eq!(cache.len(), 7); // 5 normal entries + 2 sentinels

        // Retain range [2, 4]
        let start = create_test_cache_key(2);
        let end = create_test_cache_key(4);
        let (left_removed, right_removed) = cache.retain_range(&start..=&end);

        // Should have removed key 1 on the left and key 5 on the right
        assert_eq!(left_removed.len(), 1);
        assert_eq!(right_removed.len(), 1);
        assert!(left_removed.contains_key(&create_test_cache_key(1)));
        assert!(right_removed.contains_key(&create_test_cache_key(5)));

        // Cache should now contain keys 2, 3, 4 plus sentinels
        assert_eq!(cache.len(), 5);
        for i in 2..=4 {
            let key = create_test_cache_key(i);
            assert!(cache.inner.iter().any(|(k, _)| *k == key));
        }
    }

    #[test]
    fn test_partition_cache_shrink_full_policy() {
        let mut cache = PartitionCache::new();

        // Add many entries
        for i in 1..=10 {
            cache.insert(create_test_cache_key(i), create_test_row(i * 100));
        }

        let initial_len = cache.len();
        let deduped_part_key = OwnedRow::empty();
        let recently_accessed_range = create_test_state_key(3)..=create_test_state_key(7);

        // Full policy should not shrink anything
        cache.shrink(
            &deduped_part_key,
            CachePolicy::Full,
            recently_accessed_range,
        );

        assert_eq!(cache.len(), initial_len);
    }

    #[test]
    fn test_partition_cache_shrink_recent_policy() {
        let mut cache = PartitionCache::new();

        // Add entries
        for i in 1..=10 {
            cache.insert(create_test_cache_key(i), create_test_row(i * 100));
        }

        let deduped_part_key = OwnedRow::empty();
        let recently_accessed_range = create_test_state_key(4)..=create_test_state_key(6);

        // Recent policy should keep entries around the accessed range
        cache.shrink(
            &deduped_part_key,
            CachePolicy::Recent,
            recently_accessed_range,
        );

        // Cache should still contain the accessed range and some nearby entries
        let remaining_keys: Vec<_> = cache
            .inner
            .iter()
            .filter_map(|(k, _)| match k {
                CacheKey::Normal(state_key) => Some(state_key),
                _ => None,
            })
            .collect();

        // Should contain at least the accessed range
        for i in 4..=6 {
            let target_key = create_test_state_key(i);
            assert!(
                remaining_keys
                    .iter()
                    .any(|k| k.order_key == target_key.order_key)
            );
        }
    }

    #[test]
    fn test_partition_cache_shrink_with_small_cache() {
        let mut cache = PartitionCache::new();

        // Add only a few entries (less than MAGIC_CACHE_SIZE)
        for i in 1..=5 {
            cache.insert(create_test_cache_key(i), create_test_row(i * 100));
        }

        let initial_len = cache.len();
        let deduped_part_key = OwnedRow::empty();
        let recently_accessed_range = create_test_state_key(2)..=create_test_state_key(4);

        // RecentFirstN and RecentLastN should not shrink small caches
        cache.shrink(
            &deduped_part_key,
            CachePolicy::RecentFirstN,
            recently_accessed_range.clone(),
        );
        assert_eq!(cache.len(), initial_len);

        cache.shrink(
            &deduped_part_key,
            CachePolicy::RecentLastN,
            recently_accessed_range,
        );
        assert_eq!(cache.len(), initial_len);
    }

    #[test]
    fn test_partition_cache_estimate_size() {
        let cache = PartitionCache::new_without_sentinels();
        let size_without_sentinels = cache.estimated_heap_size();

        let mut cache = PartitionCache::new();
        let size_with_sentinels = cache.estimated_heap_size();

        // The sentinels should not decrease the estimated size
        assert!(size_with_sentinels >= size_without_sentinels);

        cache.insert(create_test_cache_key(1), create_test_row(100));
        let size_with_entry = cache.estimated_heap_size();

        // Size should increase when adding an entry
        assert!(size_with_entry > size_with_sentinels);
    }
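
    #[test]
    fn test_partition_cache_normal_len_and_keys() {
        // Sanity-check sketch for the sentinel-aware accessors; it relies only
        // on the invariants documented on `PartitionCache` above.
        let mut cache = PartitionCache::new();

        // With only the two sentinels, there are no normal entries.
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.normal_len(), 0);
        assert!(cache.first_normal_key().is_none());
        assert!(cache.last_normal_key().is_none());
        assert!(cache.left_is_sentinel());
        assert!(cache.right_is_sentinel());

        for i in 1..=3 {
            cache.insert(create_test_cache_key(i), create_test_row(i * 100));
        }

        // Sentinels are excluded from the normal entry count.
        assert_eq!(cache.len(), 5);
        assert_eq!(cache.normal_len(), 3);
        assert!(cache.first_normal_key().unwrap().order_key == create_test_state_key(1).order_key);
        assert!(cache.last_normal_key().unwrap().order_key == create_test_state_key(3).order_key);
    }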
}