risingwave_stream/common/state_cache/top_n.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use risingwave_common::array::Op;
16use risingwave_common_estimate_size::EstimateSize;
17use risingwave_common_estimate_size::collections::EstimatedBTreeMap;
18
19use super::{StateCache, StateCacheFiller};
20
21/// An implementation of [`StateCache`] that keeps a limited number of entries in an ordered in-memory map.
22#[derive(Clone, EstimateSize)]
23pub struct TopNStateCache<K: Ord + EstimateSize, V: EstimateSize> {
24    table_row_count: Option<usize>,
25    cache: EstimatedBTreeMap<K, V>,
26    capacity: usize,
27    synced: bool,
28}
29
30impl<K: Ord + EstimateSize, V: EstimateSize> TopNStateCache<K, V> {
31    pub fn new(capacity: usize) -> Self {
32        Self {
33            table_row_count: None,
34            cache: Default::default(),
35            capacity,
36            synced: false,
37        }
38    }
39
40    pub fn with_table_row_count(capacity: usize, table_row_count: usize) -> Self {
41        Self {
42            table_row_count: Some(table_row_count),
43            cache: Default::default(),
44            capacity,
45            synced: false,
46        }
47    }
48
49    pub fn set_table_row_count(&mut self, table_row_count: usize) {
50        self.table_row_count = Some(table_row_count);
51    }
52
53    #[cfg(test)]
54    pub fn get_table_row_count(&self) -> &Option<usize> {
55        &self.table_row_count
56    }
57
58    fn row_count_matched(&self) -> bool {
59        self.table_row_count
60            .map(|n| n == self.cache.len())
61            .unwrap_or(false)
62    }
63
64    /// Insert an entry with the assumption that the cache is SYNCED.
65    fn insert_synced(&mut self, key: K, value: V) -> Option<V> {
66        let old_v = if self.row_count_matched()
67            || self.cache.is_empty()
68            || &key <= self.cache.last_key().unwrap()
69        {
70            let old_v = self.cache.insert(key, value);
71            // evict if capacity is reached
72            while self.cache.len() > self.capacity {
73                self.cache.pop_last();
74            }
75            old_v
76        } else {
77            None
78        };
79        // In other cases, we can't insert this key because we're not sure whether there're keys
80        // less than it in the table. So we only update table row count.
81        self.table_row_count = self.table_row_count.map(|n| n + 1);
82        old_v
83    }
84
85    /// Delete an entry with the assumption that the cache is SYNCED.
86    fn delete_synced(&mut self, key: &K) -> Option<V> {
87        let old_val = self.cache.remove(key);
88        self.table_row_count = self.table_row_count.map(|n| n - 1);
89        if self.cache.is_empty() && !self.row_count_matched() {
90            // The cache becomes empty, but there're still rows in the table, so mark it as not
91            // synced.
92            self.synced = false;
93        }
94        old_val
95    }
96
97    pub fn capacity(&self) -> usize {
98        self.capacity
99    }
100
101    pub fn len(&self) -> usize {
102        self.cache.len()
103    }
104
105    pub fn is_empty(&self) -> bool {
106        self.cache.is_empty()
107    }
108}
109
110impl<K: Ord + EstimateSize, V: EstimateSize> StateCache for TopNStateCache<K, V> {
111    type Filler<'a>
112        = &'a mut Self
113    where
114        Self: 'a;
115    type Key = K;
116    type Value = V;
117
118    fn is_synced(&self) -> bool {
119        self.synced
120    }
121
122    fn begin_syncing(&mut self) -> Self::Filler<'_> {
123        self.synced = false;
124        self.cache.clear();
125        self
126    }
127
128    fn insert(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value> {
129        if self.synced {
130            self.insert_synced(key, value)
131        } else {
132            None
133        }
134    }
135
136    fn delete(&mut self, key: &Self::Key) -> Option<Self::Value> {
137        if self.synced {
138            self.delete_synced(key)
139        } else {
140            None
141        }
142    }
143
144    fn apply_batch(&mut self, batch: impl IntoIterator<Item = (Op, Self::Key, Self::Value)>) {
145        if self.synced {
146            for (op, key, value) in batch {
147                match op {
148                    Op::Insert | Op::UpdateInsert => {
149                        self.insert_synced(key, value);
150                    }
151                    Op::Delete | Op::UpdateDelete => {
152                        self.delete_synced(&key);
153                        if !self.synced {
154                            break;
155                        }
156                    }
157                }
158            }
159        }
160    }
161
162    fn clear(&mut self) {
163        self.cache.clear();
164        self.synced = false;
165    }
166
167    fn values(&self) -> impl Iterator<Item = &Self::Value> {
168        assert!(self.synced);
169        self.cache.values()
170    }
171
172    fn first_key_value(&self) -> Option<(&Self::Key, &Self::Value)> {
173        assert!(self.synced);
174        self.cache.first_key_value()
175    }
176}
177
178impl<K: Ord + EstimateSize, V: EstimateSize> StateCacheFiller for &mut TopNStateCache<K, V> {
179    type Key = K;
180    type Value = V;
181
182    fn capacity(&self) -> Option<usize> {
183        Some(TopNStateCache::capacity(self))
184    }
185
186    fn insert_unchecked(&mut self, key: Self::Key, value: Self::Value) {
187        self.cache.insert(key, value);
188    }
189
190    fn finish(self) {
191        self.synced = true;
192    }
193}