risingwave_common_estimate_size/collections/
btreemap.rs1use core::fmt;
16use std::collections::BTreeMap;
17use std::ops::{Bound, RangeInclusive};
18
19use crate::{EstimateSize, KvSize};
20
21#[derive(Clone)]
22pub struct EstimatedBTreeMap<K, V> {
23 inner: BTreeMap<K, V>,
24 heap_size: KvSize,
25}
26
27impl<K, V> EstimatedBTreeMap<K, V> {
28 pub fn new() -> Self {
29 Self {
30 inner: BTreeMap::new(),
31 heap_size: KvSize::new(),
32 }
33 }
34
35 pub fn inner(&self) -> &BTreeMap<K, V> {
36 &self.inner
37 }
38
39 pub fn len(&self) -> usize {
40 self.inner.len()
41 }
42
43 pub fn is_empty(&self) -> bool {
44 self.inner.is_empty()
45 }
46
47 pub fn iter(&self) -> impl DoubleEndedIterator<Item = (&K, &V)> {
48 self.inner.iter()
49 }
50
51 pub fn range<R>(&self, range: R) -> std::collections::btree_map::Range<'_, K, V>
52 where
53 K: Ord,
54 R: std::ops::RangeBounds<K>,
55 {
56 self.inner.range(range)
57 }
58
59 pub fn values(&self) -> impl Iterator<Item = &V> {
60 self.inner.values()
61 }
62}
63
64impl<K, V> EstimatedBTreeMap<K, V>
65where
66 K: Ord,
67{
68 pub fn first_key_value(&self) -> Option<(&K, &V)> {
69 self.inner.first_key_value()
70 }
71
72 pub fn first_key(&self) -> Option<&K> {
73 self.first_key_value().map(|(k, _)| k)
74 }
75
76 pub fn first_value(&self) -> Option<&V> {
77 self.first_key_value().map(|(_, v)| v)
78 }
79
80 pub fn last_key_value(&self) -> Option<(&K, &V)> {
81 self.inner.last_key_value()
82 }
83
84 pub fn last_key(&self) -> Option<&K> {
85 self.last_key_value().map(|(k, _)| k)
86 }
87
88 pub fn last_value(&self) -> Option<&V> {
89 self.last_key_value().map(|(_, v)| v)
90 }
91}
92
93impl<K, V> Default for EstimatedBTreeMap<K, V> {
94 fn default() -> Self {
95 Self::new()
96 }
97}
98
99impl<K, V> EstimatedBTreeMap<K, V>
100where
101 K: EstimateSize + Ord,
102 V: EstimateSize,
103{
104 pub fn insert(&mut self, key: K, value: V) -> Option<V> {
105 let key_size = self.heap_size.add_val(&key);
106 self.heap_size.add_val(&value);
107 let old_value = self.inner.insert(key, value);
108 if let Some(old_value) = &old_value {
109 self.heap_size.sub_size(key_size);
110 self.heap_size.sub_val(old_value);
111 }
112 old_value
113 }
114
115 pub fn remove(&mut self, key: &K) -> Option<V> {
116 let old_value = self.inner.remove(key);
117 if let Some(old_value) = &old_value {
118 self.heap_size.sub(key, old_value);
119 }
120 old_value
121 }
122
123 pub fn clear(&mut self) {
124 self.inner.clear();
125 self.heap_size.set(0);
126 }
127
128 pub fn pop_first(&mut self) -> Option<(K, V)> {
129 let (key, value) = self.inner.pop_first()?;
130 self.heap_size.sub(&key, &value);
131 Some((key, value))
132 }
133
134 pub fn pop_last(&mut self) -> Option<(K, V)> {
135 let (key, value) = self.inner.pop_last()?;
136 self.heap_size.sub(&key, &value);
137 Some((key, value))
138 }
139
140 pub fn last_entry(&mut self) -> Option<OccupiedEntry<'_, K, V>> {
141 self.inner.last_entry().map(|inner| OccupiedEntry {
142 inner,
143 heap_size: &mut self.heap_size,
144 })
145 }
146
147 pub fn retain_range(&mut self, range: RangeInclusive<&K>) -> (BTreeMap<K, V>, BTreeMap<K, V>)
149 where
150 K: Clone,
151 {
152 let start = *range.start();
153 let end = *range.end();
154
155 let mut mid_right = self.inner.split_off(start);
157 let mid_right_split_key = mid_right
158 .lower_bound(Bound::Excluded(end))
159 .peek_next()
160 .map(|(k, _)| k)
161 .cloned();
162 let right = if let Some(ref mid_right_split_key) = mid_right_split_key {
163 mid_right.split_off(mid_right_split_key)
164 } else {
165 Default::default()
166 };
167 let mid = mid_right;
168 let left = std::mem::replace(&mut self.inner, mid);
169
170 for (k, v) in &left {
171 self.heap_size.sub(k, v);
172 }
173 for (k, v) in &right {
174 self.heap_size.sub(k, v);
175 }
176
177 (left, right)
178 }
179
180 pub fn extract_if<'a, F>(
181 &'a mut self,
182 mut pred: F,
183 ) -> ExtractIf<'a, K, V, impl FnMut(&K, &mut V) -> bool + use<F, K, V>>
184 where
185 F: 'a + FnMut(&K, &V) -> bool,
186 {
187 let pred_immut = move |key: &K, value: &mut V| pred(key, value);
188 ExtractIf {
189 inner: self.inner.extract_if(pred_immut),
190 heap_size: &mut self.heap_size,
191 }
192 }
193
194 pub fn retain<F>(&mut self, mut f: F)
195 where
196 F: FnMut(&K, &V) -> bool,
197 {
198 self.extract_if(|k, v| !f(k, v)).for_each(drop);
199 }
200}
201
202impl<K, V> EstimateSize for EstimatedBTreeMap<K, V>
203where
204 K: EstimateSize,
205 V: EstimateSize,
206{
207 fn estimated_heap_size(&self) -> usize {
208 self.heap_size.size()
209 }
210}
211
212impl<K, V> fmt::Debug for EstimatedBTreeMap<K, V>
213where
214 K: fmt::Debug,
215 V: fmt::Debug,
216{
217 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218 self.inner.fmt(f)
219 }
220}
221
222pub struct OccupiedEntry<'a, K, V> {
223 inner: std::collections::btree_map::OccupiedEntry<'a, K, V>,
224 heap_size: &'a mut KvSize,
225}
226
227impl<K, V> OccupiedEntry<'_, K, V>
228where
229 K: EstimateSize + Ord,
230 V: EstimateSize,
231{
232 pub fn key(&self) -> &K {
233 self.inner.key()
234 }
235
236 pub fn remove_entry(self) -> (K, V) {
237 let (key, value) = self.inner.remove_entry();
238 self.heap_size.sub(&key, &value);
239 (key, value)
240 }
241}
242
243pub struct ExtractIf<'a, K, V, F>
244where
245 F: FnMut(&K, &mut V) -> bool,
246{
247 inner: std::collections::btree_map::ExtractIf<'a, K, V, F>,
248 heap_size: &'a mut KvSize,
249}
250
251impl<K, V, F> Iterator for ExtractIf<'_, K, V, F>
252where
253 K: EstimateSize,
254 V: EstimateSize,
255 F: FnMut(&K, &mut V) -> bool,
256{
257 type Item = (K, V);
258
259 fn next(&mut self) -> Option<Self::Item> {
260 let (key, value) = self.inner.next()?;
261 self.heap_size.sub(&key, &value);
262 Some((key, value))
263 }
264
265 fn size_hint(&self) -> (usize, Option<usize>) {
266 self.inner.size_hint()
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::EstimatedBTreeMap;
    use crate::EstimateSize;

    /// Builds a map holding exactly `entries`. Its estimate is the reference value
    /// that any map reaching the same contents through other operations must match.
    fn build(entries: &[(i32, &str)]) -> EstimatedBTreeMap<i32, String> {
        let mut map = EstimatedBTreeMap::new();
        for &(key, value) in entries {
            map.insert(key, value.to_owned());
        }
        map
    }

    /// Asserts that `map`'s running estimate equals that of a freshly built map
    /// holding `expected`.
    fn assert_heap_size(map: &EstimatedBTreeMap<i32, String>, expected: &[(i32, &str)]) {
        assert_eq!(
            map.estimated_heap_size(),
            build(expected).estimated_heap_size()
        );
    }

    #[test]
    fn test_retain_range() {
        let mut map = EstimatedBTreeMap::new();

        let (left, right) = map.retain_range(&1..=&10);
        assert!(left.is_empty());
        assert!(right.is_empty());

        map.insert(1, "hello".to_owned());
        map.insert(6, "world".to_owned());
        let (left, right) = map.retain_range(&6..=&6);
        assert_eq!(map.len(), 1);
        assert_eq!(map.inner[&6], "world".to_owned());
        assert_eq!(left.len(), 1);
        assert_eq!(left[&1], "hello".to_owned());
        assert!(right.is_empty());
        assert_heap_size(&map, &[(6, "world")]);

        map.insert(8, "risingwave".to_owned());
        map.insert(3, "great".to_owned());
        map.insert(0, "wooow".to_owned());
        let (left, right) = map.retain_range(&2..=&7);
        assert_eq!(map.len(), 2);
        assert_eq!(map.inner[&3], "great".to_owned());
        assert_eq!(map.inner[&6], "world".to_owned());
        assert_eq!(left.len(), 1);
        assert_eq!(left[&0], "wooow".to_owned());
        assert_eq!(right.len(), 1);
        assert_eq!(right[&8], "risingwave".to_owned());
        assert_heap_size(&map, &[(3, "great"), (6, "world")]);

        // An inverted range retains nothing: keys below `start` go left, the rest right.
        let (left, right) = map.retain_range(&5..=&4);
        assert!(map.is_empty());
        assert_eq!(left.keys().copied().collect::<Vec<_>>(), vec![3]);
        assert_eq!(right.keys().copied().collect::<Vec<_>>(), vec![6]);
        assert_heap_size(&map, &[]);
    }

    #[test]
    fn test_heap_size_tracking() {
        let mut map = EstimatedBTreeMap::new();
        map.insert(1, "hello".to_owned());
        map.insert(2, "world".to_owned());

        // Overwriting an entry swaps the old value's size for the new one's.
        assert_eq!(map.insert(1, "hi".to_owned()).as_deref(), Some("hello"));
        assert_heap_size(&map, &[(1, "hi"), (2, "world")]);

        assert_eq!(map.remove(&2).as_deref(), Some("world"));
        assert_eq!(map.remove(&2), None);
        assert_heap_size(&map, &[(1, "hi")]);

        map.insert(3, "risingwave".to_owned());
        map.insert(4, "great".to_owned());
        map.retain(|key, _| *key != 3);
        assert_heap_size(&map, &[(1, "hi"), (4, "great")]);

        assert_eq!(map.pop_first().map(|(key, _)| key), Some(1));
        assert_eq!(map.pop_last().map(|(key, _)| key), Some(4));
        assert!(map.is_empty());
        assert_heap_size(&map, &[]);

        map.insert(5, "again".to_owned());
        map.clear();
        assert!(map.is_empty());
        assert_eq!(map.estimated_heap_size(), 0);
    }
}