risingwave_common_estimate_size/collections/
btreemap.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use core::fmt;
16use std::collections::BTreeMap;
17use std::ops::{Bound, RangeInclusive};
18
19use crate::{EstimateSize, KvSize};
20
21#[derive(Clone)]
22pub struct EstimatedBTreeMap<K, V> {
23    inner: BTreeMap<K, V>,
24    heap_size: KvSize,
25}
26
27impl<K, V> EstimatedBTreeMap<K, V> {
28    pub fn new() -> Self {
29        Self {
30            inner: BTreeMap::new(),
31            heap_size: KvSize::new(),
32        }
33    }
34
35    pub fn inner(&self) -> &BTreeMap<K, V> {
36        &self.inner
37    }
38
39    pub fn len(&self) -> usize {
40        self.inner.len()
41    }
42
43    pub fn is_empty(&self) -> bool {
44        self.inner.is_empty()
45    }
46
47    pub fn iter(&self) -> impl DoubleEndedIterator<Item = (&K, &V)> {
48        self.inner.iter()
49    }
50
51    pub fn range<R>(&self, range: R) -> std::collections::btree_map::Range<'_, K, V>
52    where
53        K: Ord,
54        R: std::ops::RangeBounds<K>,
55    {
56        self.inner.range(range)
57    }
58
59    pub fn values(&self) -> impl Iterator<Item = &V> {
60        self.inner.values()
61    }
62}
63
64impl<K, V> EstimatedBTreeMap<K, V>
65where
66    K: Ord,
67{
68    pub fn first_key_value(&self) -> Option<(&K, &V)> {
69        self.inner.first_key_value()
70    }
71
72    pub fn first_key(&self) -> Option<&K> {
73        self.first_key_value().map(|(k, _)| k)
74    }
75
76    pub fn first_value(&self) -> Option<&V> {
77        self.first_key_value().map(|(_, v)| v)
78    }
79
80    pub fn last_key_value(&self) -> Option<(&K, &V)> {
81        self.inner.last_key_value()
82    }
83
84    pub fn last_key(&self) -> Option<&K> {
85        self.last_key_value().map(|(k, _)| k)
86    }
87
88    pub fn last_value(&self) -> Option<&V> {
89        self.last_key_value().map(|(_, v)| v)
90    }
91}
92
93impl<K, V> Default for EstimatedBTreeMap<K, V> {
94    fn default() -> Self {
95        Self::new()
96    }
97}
98
99impl<K, V> EstimatedBTreeMap<K, V>
100where
101    K: EstimateSize + Ord,
102    V: EstimateSize,
103{
104    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
105        let key_size = self.heap_size.add_val(&key);
106        self.heap_size.add_val(&value);
107        let old_value = self.inner.insert(key, value);
108        if let Some(old_value) = &old_value {
109            self.heap_size.sub_size(key_size);
110            self.heap_size.sub_val(old_value);
111        }
112        old_value
113    }
114
115    pub fn remove(&mut self, key: &K) -> Option<V> {
116        let old_value = self.inner.remove(key);
117        if let Some(old_value) = &old_value {
118            self.heap_size.sub(key, old_value);
119        }
120        old_value
121    }
122
123    pub fn clear(&mut self) {
124        self.inner.clear();
125        self.heap_size.set(0);
126    }
127
128    pub fn pop_first(&mut self) -> Option<(K, V)> {
129        let (key, value) = self.inner.pop_first()?;
130        self.heap_size.sub(&key, &value);
131        Some((key, value))
132    }
133
134    pub fn pop_last(&mut self) -> Option<(K, V)> {
135        let (key, value) = self.inner.pop_last()?;
136        self.heap_size.sub(&key, &value);
137        Some((key, value))
138    }
139
140    pub fn last_entry(&mut self) -> Option<OccupiedEntry<'_, K, V>> {
141        self.inner.last_entry().map(|inner| OccupiedEntry {
142            inner,
143            heap_size: &mut self.heap_size,
144        })
145    }
146
147    /// Retain the given range of entries in the map, removing others.
148    pub fn retain_range(&mut self, range: RangeInclusive<&K>) -> (BTreeMap<K, V>, BTreeMap<K, V>)
149    where
150        K: Clone,
151    {
152        let start = *range.start();
153        let end = *range.end();
154
155        // [ left, [mid], right ]
156        let mut mid_right = self.inner.split_off(start);
157        let mid_right_split_key = mid_right
158            .lower_bound(Bound::Excluded(end))
159            .peek_next()
160            .map(|(k, _)| k)
161            .cloned();
162        let right = if let Some(ref mid_right_split_key) = mid_right_split_key {
163            mid_right.split_off(mid_right_split_key)
164        } else {
165            Default::default()
166        };
167        let mid = mid_right;
168        let left = std::mem::replace(&mut self.inner, mid);
169
170        for (k, v) in &left {
171            self.heap_size.sub(k, v);
172        }
173        for (k, v) in &right {
174            self.heap_size.sub(k, v);
175        }
176
177        (left, right)
178    }
179
180    pub fn extract_if<'a, F>(
181        &'a mut self,
182        mut pred: F,
183    ) -> ExtractIf<'a, K, V, impl FnMut(&K, &mut V) -> bool + use<F, K, V>>
184    where
185        F: 'a + FnMut(&K, &V) -> bool,
186    {
187        let pred_immut = move |key: &K, value: &mut V| pred(key, value);
188        ExtractIf {
189            inner: self.inner.extract_if(pred_immut),
190            heap_size: &mut self.heap_size,
191        }
192    }
193
194    pub fn retain<F>(&mut self, mut f: F)
195    where
196        F: FnMut(&K, &V) -> bool,
197    {
198        self.extract_if(|k, v| !f(k, v)).for_each(drop);
199    }
200}
201
202impl<K, V> EstimateSize for EstimatedBTreeMap<K, V>
203where
204    K: EstimateSize,
205    V: EstimateSize,
206{
207    fn estimated_heap_size(&self) -> usize {
208        self.heap_size.size()
209    }
210}
211
212impl<K, V> fmt::Debug for EstimatedBTreeMap<K, V>
213where
214    K: fmt::Debug,
215    V: fmt::Debug,
216{
217    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218        self.inner.fmt(f)
219    }
220}
221
222pub struct OccupiedEntry<'a, K, V> {
223    inner: std::collections::btree_map::OccupiedEntry<'a, K, V>,
224    heap_size: &'a mut KvSize,
225}
226
227impl<K, V> OccupiedEntry<'_, K, V>
228where
229    K: EstimateSize + Ord,
230    V: EstimateSize,
231{
232    pub fn key(&self) -> &K {
233        self.inner.key()
234    }
235
236    pub fn remove_entry(self) -> (K, V) {
237        let (key, value) = self.inner.remove_entry();
238        self.heap_size.sub(&key, &value);
239        (key, value)
240    }
241}
242
243pub struct ExtractIf<'a, K, V, F>
244where
245    F: FnMut(&K, &mut V) -> bool,
246{
247    inner: std::collections::btree_map::ExtractIf<'a, K, V, F>,
248    heap_size: &'a mut KvSize,
249}
250
251impl<K, V, F> Iterator for ExtractIf<'_, K, V, F>
252where
253    K: EstimateSize,
254    V: EstimateSize,
255    F: FnMut(&K, &mut V) -> bool,
256{
257    type Item = (K, V);
258
259    fn next(&mut self) -> Option<Self::Item> {
260        let (key, value) = self.inner.next()?;
261        self.heap_size.sub(&key, &value);
262        Some((key, value))
263    }
264
265    fn size_hint(&self) -> (usize, Option<usize>) {
266        self.inner.size_hint()
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::EstimatedBTreeMap;

    /// Exercises `retain_range` on an empty map, with entries dropped on one side
    /// only, and with entries dropped on both sides.
    #[test]
    fn test_retain_range() {
        let mut map = EstimatedBTreeMap::new();

        // Empty map: nothing can be dropped on either side.
        let (below, above) = map.retain_range(&1..=&10);
        assert!(below.is_empty());
        assert!(above.is_empty());

        // Keep exactly key 6; key 1 falls off the low side.
        map.insert(1, String::from("hello"));
        map.insert(6, String::from("world"));
        let (below, above) = map.retain_range(&6..=&6);
        assert_eq!(map.len(), 1);
        assert_eq!(map.inner[&6], String::from("world"));
        assert_eq!(below.len(), 1);
        assert_eq!(below[&1], String::from("hello"));
        assert!(above.is_empty());

        // Keep keys 3 and 6; key 0 falls off the low side and key 8 off the high side.
        map.insert(8, String::from("risingwave"));
        map.insert(3, String::from("great"));
        map.insert(0, String::from("wooow"));
        let (below, above) = map.retain_range(&2..=&7);
        assert_eq!(map.len(), 2);
        assert_eq!(map.inner[&3], String::from("great"));
        assert_eq!(map.inner[&6], String::from("world"));
        assert_eq!(below.len(), 1);
        assert_eq!(below[&0], String::from("wooow"));
        assert_eq!(above.len(), 1);
        assert_eq!(above[&8], String::from("risingwave"));
    }
}