risingwave_stream/executor/over_window/range_cache.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::BTreeMap;
use std::ops::{Bound, RangeInclusive};

use risingwave_common::config::streaming::OverWindowCachePolicy as CachePolicy;
use risingwave_common::row::OwnedRow;
use risingwave_common::types::Sentinelled;
use risingwave_common_estimate_size::EstimateSize;
use risingwave_common_estimate_size::collections::EstimatedBTreeMap;
use risingwave_expr::window_function::StateKey;
use static_assertions::const_assert;

/// Cache key type of the over window range cache: a `StateKey` extended with
/// `Smallest` / `Largest` sentinel variants.
pub(super) type CacheKey = Sentinelled<StateKey>;

/// Range cache for one over window partition.
/// The cache entries can be:
///
/// - `(Normal)*`
/// - `Smallest, (Normal)*, Largest`
/// - `(Normal)+, Largest`
/// - `Smallest, (Normal)+`
///
/// This means it's impossible to have only one sentinel in the cache without any normal entry,
/// and each of the two sentinel types can appear at most once. Also, since sentinels are either
/// smallest or largest, they always appear at the beginning or the end of the cache.
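///
/// For illustration, suppose a partition's normal keys are `1..=5` (a hypothetical
/// example; `S` / `L` denote the `Smallest` / `Largest` sentinels). Valid cache
/// states then include:
///
/// - `[1, 2, 3, 4, 5]`: the whole partition is cached, no sentinel needed
/// - `[S, 3, 4, L]`: a middle slice; more rows may exist on both sides in storage
/// - `[1, 2, 3, L]`: a prefix; more rows may exist only on the right
/// - `[S, 4, 5]`: a suffix; more rows may exist only on the left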
#[derive(Clone, Debug, Default)]
pub(super) struct PartitionCache {
    inner: EstimatedBTreeMap<CacheKey, OwnedRow>,
}

impl PartitionCache {
    /// Create a new empty partition cache without sentinel values.
    pub fn new_without_sentinels() -> Self {
        Self {
            inner: EstimatedBTreeMap::new(),
        }
    }

    /// Create a new empty partition cache with sentinel values.
    pub fn new() -> Self {
        let mut cache = Self {
            inner: EstimatedBTreeMap::new(),
        };
        cache.insert(CacheKey::Smallest, OwnedRow::empty());
        cache.insert(CacheKey::Largest, OwnedRow::empty());
        cache
    }

    /// Get access to the inner `BTreeMap` for cursor operations.
    pub fn inner(&self) -> &BTreeMap<CacheKey, OwnedRow> {
        self.inner.inner()
    }

    /// Insert a key-value pair into the cache.
    pub fn insert(&mut self, key: CacheKey, value: OwnedRow) -> Option<OwnedRow> {
        self.inner.insert(key, value)
    }

    /// Remove a key from the cache.
    pub fn remove(&mut self, key: &CacheKey) -> Option<OwnedRow> {
        self.inner.remove(key)
    }

    /// Get the number of entries in the cache, including sentinels.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Check if the cache is empty.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Get the first key-value pair in the cache.
    pub fn first_key_value(&self) -> Option<(&CacheKey, &OwnedRow)> {
        self.inner.first_key_value()
    }

    /// Get the last key-value pair in the cache.
    pub fn last_key_value(&self) -> Option<(&CacheKey, &OwnedRow)> {
        self.inner.last_key_value()
    }
    /// Retain entries in the given range, removing entries outside of it.
    /// Returns the removed entries as `(left_removed, right_removed)`, with sentinels
    /// filtered out; the sentinels themselves are preserved in the cache.
    fn retain_range(
        &mut self,
        range: RangeInclusive<&CacheKey>,
    ) -> (BTreeMap<CacheKey, OwnedRow>, BTreeMap<CacheKey, OwnedRow>) {
        // Check if we had sentinels before the operation
        let had_smallest = self.inner.inner().contains_key(&CacheKey::Smallest);
        let had_largest = self.inner.inner().contains_key(&CacheKey::Largest);

        let (left_removed, right_removed) = self.inner.retain_range(range);

        // Restore sentinels if they were present before
        if had_smallest {
            self.inner.insert(CacheKey::Smallest, OwnedRow::empty());
        }
        if had_largest {
            self.inner.insert(CacheKey::Largest, OwnedRow::empty());
        }

        // Filter out sentinels from the returned maps
        let left_removed = left_removed
            .into_iter()
            .filter(|(k, _)| k.is_normal())
            .collect();
        let right_removed = right_removed
            .into_iter()
            .filter(|(k, _)| k.is_normal())
            .collect();

        (left_removed, right_removed)
    }

    /// Get the number of cached `CacheKey::Normal` entries, i.e. excluding sentinels.
    pub fn normal_len(&self) -> usize {
        let len = self.inner().len();
        if len <= 1 {
            // A sentinel cannot be the only entry, so a cache of length <= 1
            // contains either nothing or a single normal entry.
            debug_assert!(
                self.inner()
                    .first_key_value()
                    .map(|(k, _)| k.is_normal())
                    .unwrap_or(true)
            );
            return len;
        }
        // len >= 2
        let cache_inner = self.inner();
        let sentinels = [
            // sentinels only appear at the beginning and/or the end
            cache_inner.first_key_value().unwrap().0.is_sentinel(),
            cache_inner.last_key_value().unwrap().0.is_sentinel(),
        ];
        len - sentinels.into_iter().filter(|x| *x).count()
    }

    /// Get the first normal key in the cache, if any.
    pub fn first_normal_key(&self) -> Option<&StateKey> {
        self.inner()
            .iter()
            .find(|(k, _)| k.is_normal())
            .map(|(k, _)| k.as_normal_expect())
    }

    /// Get the last normal key in the cache, if any.
    pub fn last_normal_key(&self) -> Option<&StateKey> {
        self.inner()
            .iter()
            .rev()
            .find(|(k, _)| k.is_normal())
            .map(|(k, _)| k.as_normal_expect())
    }

    /// Whether the leftmost entry is a sentinel.
    pub fn left_is_sentinel(&self) -> bool {
        self.first_key_value()
            .map(|(k, _)| k.is_sentinel())
            .unwrap_or(false)
    }

    /// Whether the rightmost entry is a sentinel.
    pub fn right_is_sentinel(&self) -> bool {
        self.last_key_value()
            .map(|(k, _)| k.is_sentinel())
            .unwrap_or(false)
    }

    /// Shrink the partition cache based on the given policy and the recently accessed range,
    /// re-inserting sentinels on the sides where normal entries got evicted.
    pub fn shrink(
        &mut self,
        deduped_part_key: &OwnedRow,
        cache_policy: CachePolicy,
        recently_accessed_range: RangeInclusive<StateKey>,
    ) {
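        // Eviction heuristics: retain roughly `MAGIC_CACHE_SIZE` entries, and pad the
        // retained range by up to `MAGIC_JITTER_PREVENTION` entries so that small shifts
        // of the accessed range don't repeatedly evict and reload boundary entries.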
        const MAGIC_CACHE_SIZE: usize = 1024;
        const MAGIC_JITTER_PREVENTION: usize = MAGIC_CACHE_SIZE / 8;

        tracing::trace!(
            partition=?deduped_part_key,
            cache_policy=?cache_policy,
            recently_accessed_range=?recently_accessed_range,
            "find the range to retain in the range cache"
        );

        let (start, end) = match cache_policy {
            CachePolicy::Full => {
                // evict nothing if the policy is to cache full partition
                return;
            }
            CachePolicy::Recent => {
                let (sk_start, sk_end) = recently_accessed_range.into_inner();
                let (ck_start, ck_end) = (CacheKey::from(sk_start), CacheKey::from(sk_end));

                // find the cursor just before `ck_start`
                let mut cursor = self.inner().upper_bound(Bound::Excluded(&ck_start));
                for _ in 0..MAGIC_JITTER_PREVENTION {
                    if cursor.prev().is_none() {
                        // already at the beginning
                        break;
                    }
                }
                let start = cursor
                    .peek_prev()
                    .map(|(k, _)| k)
                    .unwrap_or_else(|| self.first_key_value().unwrap().0)
                    .clone();

                // find the cursor just after `ck_end`
                let mut cursor = self.inner().lower_bound(Bound::Excluded(&ck_end));
                for _ in 0..MAGIC_JITTER_PREVENTION {
                    if cursor.next().is_none() {
                        // already at the end
                        break;
                    }
                }
                let end = cursor
                    .peek_next()
                    .map(|(k, _)| k)
                    .unwrap_or_else(|| self.last_key_value().unwrap().0)
                    .clone();

                (start, end)
            }
            CachePolicy::RecentFirstN => {
                if self.len() <= MAGIC_CACHE_SIZE {
                    // no need to evict if cache len <= N
                    return;
                } else {
                    let (sk_start, _sk_end) = recently_accessed_range.into_inner();
                    let ck_start = CacheKey::from(sk_start);

                    // precision is not important here; code simplicity comes first
                    let mut capacity_remain = MAGIC_CACHE_SIZE;
                    const_assert!(MAGIC_JITTER_PREVENTION < MAGIC_CACHE_SIZE);

                    // find the cursor just before `ck_start`
                    let cursor_just_before_ck_start =
                        self.inner().upper_bound(Bound::Excluded(&ck_start));

                    let mut cursor = cursor_just_before_ck_start.clone();
                    // go back for at most `MAGIC_JITTER_PREVENTION` entries
                    for _ in 0..MAGIC_JITTER_PREVENTION {
                        if cursor.prev().is_none() {
                            // already at the beginning
                            break;
                        }
                        capacity_remain -= 1;
                    }
                    let start = cursor
                        .peek_prev()
                        .map(|(k, _)| k)
                        .unwrap_or_else(|| self.first_key_value().unwrap().0)
                        .clone();

                    let mut cursor = cursor_just_before_ck_start;
                    // go forward for at most `capacity_remain` entries
                    for _ in 0..capacity_remain {
                        if cursor.next().is_none() {
                            // already at the end
                            break;
                        }
                    }
                    let end = cursor
                        .peek_next()
                        .map(|(k, _)| k)
                        .unwrap_or_else(|| self.last_key_value().unwrap().0)
                        .clone();

                    (start, end)
                }
            }
            CachePolicy::RecentLastN => {
                if self.len() <= MAGIC_CACHE_SIZE {
                    // no need to evict if cache len <= N
                    return;
                } else {
                    let (_sk_start, sk_end) = recently_accessed_range.into_inner();
                    let ck_end = CacheKey::from(sk_end);

                    // precision is not important here; code simplicity comes first
                    let mut capacity_remain = MAGIC_CACHE_SIZE;
                    const_assert!(MAGIC_JITTER_PREVENTION < MAGIC_CACHE_SIZE);

                    // find the cursor just after `ck_end`
                    let cursor_just_after_ck_end =
                        self.inner().lower_bound(Bound::Excluded(&ck_end));

                    let mut cursor = cursor_just_after_ck_end.clone();
                    // go forward for at most `MAGIC_JITTER_PREVENTION` entries
                    for _ in 0..MAGIC_JITTER_PREVENTION {
                        if cursor.next().is_none() {
                            // already at the end
                            break;
                        }
                        capacity_remain -= 1;
                    }
                    let end = cursor
                        .peek_next()
                        .map(|(k, _)| k)
                        .unwrap_or_else(|| self.last_key_value().unwrap().0)
                        .clone();

                    let mut cursor = cursor_just_after_ck_end;
                    // go back for at most `capacity_remain` entries
                    for _ in 0..capacity_remain {
                        if cursor.prev().is_none() {
                            // already at the beginning
                            break;
                        }
                    }
                    let start = cursor
                        .peek_prev()
                        .map(|(k, _)| k)
                        .unwrap_or_else(|| self.first_key_value().unwrap().0)
                        .clone();

                    (start, end)
                }
            }
        };

        tracing::trace!(
            partition=?deduped_part_key,
            retain_range=?(&start..=&end),
            "retain range in the range cache"
        );

        let (left_removed, right_removed) = self.retain_range(&start..=&end);
        // Re-insert sentinels on the sides where normal entries were evicted, so that
        // later accesses know there may be more entries in storage beyond the cache.
        if self.is_empty() {
            if !left_removed.is_empty() || !right_removed.is_empty() {
                self.insert(CacheKey::Smallest, OwnedRow::empty());
                self.insert(CacheKey::Largest, OwnedRow::empty());
            }
        } else {
            if !left_removed.is_empty() {
                self.insert(CacheKey::Smallest, OwnedRow::empty());
            }
            if !right_removed.is_empty() {
                self.insert(CacheKey::Largest, OwnedRow::empty());
            }
        }
    }
}

impl EstimateSize for PartitionCache {
    fn estimated_heap_size(&self) -> usize {
        self.inner.estimated_heap_size()
    }
}

#[cfg(test)]
mod tests {
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::{DefaultOrdered, ScalarImpl};
    use risingwave_common::util::memcmp_encoding::encode_value;
    use risingwave_common::util::sort_util::OrderType;
    use risingwave_expr::window_function::StateKey;

    use super::*;

    fn create_test_state_key(value: i32) -> StateKey {
        let row = OwnedRow::new(vec![Some(ScalarImpl::Int32(value))]);
        StateKey {
            order_key: encode_value(Some(ScalarImpl::Int32(value)), OrderType::ascending())
                .unwrap(),
            pk: DefaultOrdered::new(row),
        }
    }

    fn create_test_cache_key(value: i32) -> CacheKey {
        CacheKey::from(create_test_state_key(value))
    }

    fn create_test_row(value: i32) -> OwnedRow {
        OwnedRow::new(vec![Some(ScalarImpl::Int32(value))])
    }

    #[test]
    fn test_partition_cache_new() {
        let cache = PartitionCache::new_without_sentinels();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_partition_cache_new_with_sentinels() {
        let cache = PartitionCache::new();
        assert!(!cache.is_empty());
        assert_eq!(cache.len(), 2);

        // Should have smallest and largest sentinels
        let first = cache.first_key_value().unwrap();
        let last = cache.last_key_value().unwrap();

        assert_eq!(*first.0, CacheKey::Smallest);
        assert_eq!(*last.0, CacheKey::Largest);
    }

    #[test]
    fn test_partition_cache_insert_and_remove() {
        let mut cache = PartitionCache::new_without_sentinels();
        let key = create_test_cache_key(1);
        let value = create_test_row(100);

        // Insert
        assert!(cache.insert(key.clone(), value.clone()).is_none());
        assert_eq!(cache.len(), 1);
        assert!(!cache.is_empty());

        // Remove
        let removed = cache.remove(&key);
        assert!(removed.is_some());
        assert_eq!(removed.unwrap(), value);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_partition_cache_first_last_key_value() {
        let mut cache = PartitionCache::new_without_sentinels();

        // Empty cache
        assert!(cache.first_key_value().is_none());
        assert!(cache.last_key_value().is_none());

        // Add some entries
        cache.insert(create_test_cache_key(2), create_test_row(200));
        cache.insert(create_test_cache_key(1), create_test_row(100));
        cache.insert(create_test_cache_key(3), create_test_row(300));

        let first = cache.first_key_value().unwrap();
        let last = cache.last_key_value().unwrap();

        // BTreeMap should order by key
        assert_eq!(*first.0, create_test_cache_key(1));
        assert_eq!(*first.1, create_test_row(100));

        assert_eq!(*last.0, create_test_cache_key(3));
        assert_eq!(*last.1, create_test_row(300));
    }

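    // A small sketch exercising the sentinel-aware accessors (`normal_len`,
    // `first_normal_key`, `last_normal_key`, `left_is_sentinel`, `right_is_sentinel`);
    // the key and row values are arbitrary.
    #[test]
    fn test_partition_cache_normal_accessors() {
        let mut cache = PartitionCache::new();

        // Only the two sentinels are present; no normal entries yet.
        assert_eq!(cache.normal_len(), 0);
        assert!(cache.first_normal_key().is_none());
        assert!(cache.last_normal_key().is_none());
        assert!(cache.left_is_sentinel());
        assert!(cache.right_is_sentinel());

        for i in 1..=3 {
            cache.insert(create_test_cache_key(i), create_test_row(i * 100));
        }

        // Sentinels are excluded from `normal_len` but still delimit the cache.
        assert_eq!(cache.normal_len(), 3);
        assert!(cache.first_normal_key().unwrap().order_key == create_test_state_key(1).order_key);
        assert!(cache.last_normal_key().unwrap().order_key == create_test_state_key(3).order_key);
        assert!(cache.left_is_sentinel());
        assert!(cache.right_is_sentinel());
    }
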
    #[test]
    fn test_partition_cache_retain_range() {
        let mut cache = PartitionCache::new();

        // Add some entries
        for i in 1..=5 {
            cache.insert(create_test_cache_key(i), create_test_row(i * 100));
        }

        assert_eq!(cache.len(), 7); // 5 normal entries + 2 sentinels

        // Retain range [2, 4]
        let start = create_test_cache_key(2);
        let end = create_test_cache_key(4);
        let (left_removed, right_removed) = cache.retain_range(&start..=&end);

        // Should have removed key 1 on the left and key 5 on the right
        assert_eq!(left_removed.len(), 1);
        assert_eq!(right_removed.len(), 1);
        assert!(left_removed.contains_key(&create_test_cache_key(1)));
        assert!(right_removed.contains_key(&create_test_cache_key(5)));

        // Cache should now contain keys 2, 3, 4 plus sentinels
        assert_eq!(cache.len(), 5);
        for i in 2..=4 {
            let key = create_test_cache_key(i);
            assert!(cache.inner.iter().any(|(k, _)| *k == key));
        }
    }

    #[test]
    fn test_partition_cache_shrink_full_policy() {
        let mut cache = PartitionCache::new();

        // Add many entries
        for i in 1..=10 {
            cache.insert(create_test_cache_key(i), create_test_row(i * 100));
        }

        let initial_len = cache.len();
        let deduped_part_key = OwnedRow::empty();
        let recently_accessed_range = create_test_state_key(3)..=create_test_state_key(7);

        // Full policy should not shrink anything
        cache.shrink(
            &deduped_part_key,
            CachePolicy::Full,
            recently_accessed_range,
        );

        assert_eq!(cache.len(), initial_len);
    }

    #[test]
    fn test_partition_cache_shrink_recent_policy() {
        let mut cache = PartitionCache::new();

        // Add entries
        for i in 1..=10 {
            cache.insert(create_test_cache_key(i), create_test_row(i * 100));
        }

        let deduped_part_key = OwnedRow::empty();
        let recently_accessed_range = create_test_state_key(4)..=create_test_state_key(6);

        // Recent policy should keep entries around the accessed range
        cache.shrink(
            &deduped_part_key,
            CachePolicy::Recent,
            recently_accessed_range,
        );

        // Cache should still contain the accessed range and some nearby entries
        let remaining_keys: Vec<_> = cache
            .inner
            .iter()
            .filter_map(|(k, _)| match k {
                CacheKey::Normal(state_key) => Some(state_key),
                _ => None,
            })
            .collect();

        // Should contain at least the accessed range
        for i in 4..=6 {
            let target_key = create_test_state_key(i);
            assert!(
                remaining_keys
                    .iter()
                    .any(|k| k.order_key == target_key.order_key)
            );
        }
    }

    #[test]
    fn test_partition_cache_shrink_with_small_cache() {
        let mut cache = PartitionCache::new();

        // Add only a few entries (less than MAGIC_CACHE_SIZE)
        for i in 1..=5 {
            cache.insert(create_test_cache_key(i), create_test_row(i * 100));
        }

        let initial_len = cache.len();
        let deduped_part_key = OwnedRow::empty();
        let recently_accessed_range = create_test_state_key(2)..=create_test_state_key(4);

        // RecentFirstN and RecentLastN should not shrink small caches
        cache.shrink(
            &deduped_part_key,
            CachePolicy::RecentFirstN,
            recently_accessed_range.clone(),
        );
        assert_eq!(cache.len(), initial_len);

        cache.shrink(
            &deduped_part_key,
            CachePolicy::RecentLastN,
            recently_accessed_range,
        );
        assert_eq!(cache.len(), initial_len);
    }

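    // A sketch of actual eviction: with more entries than the internal
    // `MAGIC_CACHE_SIZE` heuristic in `shrink` (1024 as defined above),
    // `RecentFirstN` should evict entries far from the start of the recently
    // accessed range while keeping the accessed keys themselves cached.
    #[test]
    fn test_partition_cache_shrink_recent_first_n_evicts() {
        let mut cache = PartitionCache::new();
        for i in 1..=2000 {
            cache.insert(create_test_cache_key(i), create_test_row(i));
        }
        let initial_len = cache.len();

        let deduped_part_key = OwnedRow::empty();
        let recently_accessed_range = create_test_state_key(1000)..=create_test_state_key(1010);

        cache.shrink(
            &deduped_part_key,
            CachePolicy::RecentFirstN,
            recently_accessed_range,
        );

        // Something must have been evicted, and both boundaries are sentinels
        // because normal entries were removed on both sides.
        assert!(cache.len() < initial_len);
        assert!(cache.left_is_sentinel());
        assert!(cache.right_is_sentinel());

        // The recently accessed keys must survive eviction.
        for i in 1000..=1010 {
            let key = create_test_cache_key(i);
            assert!(cache.inner.iter().any(|(k, _)| *k == key));
        }
    }
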
    #[test]
    fn test_partition_cache_estimate_size() {
        let cache = PartitionCache::new_without_sentinels();
        let size_without_sentinels = cache.estimated_heap_size();

        let mut cache = PartitionCache::new();
        let size_with_sentinels = cache.estimated_heap_size();

        // Adding the sentinels should not decrease the estimated size
        assert!(size_with_sentinels >= size_without_sentinels);

        cache.insert(create_test_cache_key(1), create_test_row(100));
        let size_with_entry = cache.estimated_heap_size();

        // Size should increase when adding a normal entry
        assert!(size_with_entry > size_with_sentinels);
    }
}