risingwave_common/memory/monitored_heap.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

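//! A [`BinaryHeap`] wrapper whose buffer capacity and element heap sizes are
//! charged to a [`MemoryContext`], making heap usage visible to memory accounting.
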
use std::collections::BinaryHeap;
use std::mem::size_of;

use risingwave_common_estimate_size::EstimateSize;

use crate::memory::{MemoryContext, MonitoredGlobalAlloc};

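/// A max-heap backed by [`BinaryHeap`] that reports its memory usage to the
/// associated [`MemoryContext`] on every mutation.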
pub struct MemMonitoredHeap<T> {
    inner: BinaryHeap<T>,
    mem_ctx: MemoryContext,
}

impl<T: Ord + EstimateSize> MemMonitoredHeap<T> {
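    /// Creates an empty monitored heap bound to `mem_ctx`.
    /// [`BinaryHeap::new`] does not allocate, so nothing is charged yet.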
    pub fn new_with(mem_ctx: MemoryContext) -> Self {
        Self {
            inner: BinaryHeap::new(),
            mem_ctx,
        }
    }

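    /// Creates a heap with room for at least `capacity` elements, charging the
    /// requested capacity (`capacity * size_of::<T>()` bytes) to `mem_ctx`.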
    pub fn with_capacity(capacity: usize, mem_ctx: MemoryContext) -> Self {
        let inner = BinaryHeap::with_capacity(capacity);
        mem_ctx.add((capacity * size_of::<T>()) as i64);
        Self { inner, mem_ctx }
    }

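    /// Pushes an item, charging both the growth of the backing buffer and the
    /// item's estimated heap size to the memory context.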
    pub fn push(&mut self, item: T) {
        let prev_cap = self.inner.capacity();
        let item_heap = item.estimated_heap_size();
        self.inner.push(item);
        let new_cap = self.inner.capacity();
        self.mem_ctx
            .add(((new_cap - prev_cap) * size_of::<T>() + item_heap) as i64);
    }

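    /// Pops the greatest item, releasing its estimated heap size (and any
    /// shrinkage of the backing buffer) from the memory context.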
    pub fn pop(&mut self) -> Option<T> {
        let prev_cap = self.inner.capacity();
        let item = self.inner.pop();
        let item_heap = item.as_ref().map(|i| i.estimated_heap_size()).unwrap_or(0);
        let new_cap = self.inner.capacity();
        self.mem_ctx
            .add(-(((prev_cap - new_cap) * size_of::<T>() + item_heap) as i64));

        item
    }

    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    pub fn len(&self) -> usize {
        self.inner.len()
    }

    pub fn peek(&self) -> Option<&T> {
        self.inner.peek()
    }

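    /// Consumes the heap and collects its elements in sorted order into a `Vec`
    /// allocated through [`MonitoredGlobalAlloc`]; the heap's own buffer is then
    /// un-charged from the memory context.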
    pub fn into_sorted_vec(self) -> Vec<T, MonitoredGlobalAlloc> {
        let old_cap = self.inner.capacity();
        let alloc = MonitoredGlobalAlloc::with_memory_context(self.mem_ctx.clone());
        let vec = self.inner.into_iter_sorted();

        let mut ret = Vec::with_capacity_in(vec.len(), alloc);
        ret.extend(vec);

        self.mem_ctx.add(-((old_cap * size_of::<T>()) as i64));
        ret
    }

    pub fn mem_context(&self) -> &MemoryContext {
        &self.mem_ctx
    }
}

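/// Bulk insertion: reserves space for the iterator's lower size hint up front and
/// charges the total capacity growth plus the items' estimated heap sizes in a
/// single adjustment.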
impl<T> Extend<T> for MemMonitoredHeap<T>
where
    T: Ord + EstimateSize,
{
    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
        let old_cap = self.inner.capacity();
        let mut items_heap_size = 0usize;
        let items = iter.into_iter();
        self.inner.reserve_exact(items.size_hint().0);
        for item in items {
            items_heap_size += item.estimated_heap_size();
            self.inner.push(item);
        }

        let new_cap = self.inner.capacity();

        let diff = (new_cap - old_cap) * size_of::<T>() + items_heap_size;
        self.mem_ctx.add(diff as i64);
    }
}

#[cfg(test)]
mod tests {
    use super::MemMonitoredHeap;
    use crate::memory::MemoryContext;
    use crate::metrics::LabelGuardedIntGauge;

    #[test]
    fn test_heap() {
        let gauge = LabelGuardedIntGauge::<4>::test_int_gauge();
        let mem_ctx = MemoryContext::root(gauge.clone(), u64::MAX);

        let mut heap = MemMonitoredHeap::<u8>::new_with(mem_ctx);
        assert_eq!(0, gauge.get());

        heap.push(9u8);
        heap.push(1u8);
        assert_eq!(heap.inner.capacity() as i64, gauge.get());

        heap.pop().unwrap();
        assert_eq!(heap.inner.capacity() as i64, gauge.get());

        assert!(!heap.is_empty());
    }

    #[test]
    fn test_heap_drop() {
        let gauge = LabelGuardedIntGauge::<4>::test_int_gauge();
        let mem_ctx = MemoryContext::root(gauge.clone(), u64::MAX);

        let vec = {
            let mut heap = MemMonitoredHeap::<u8>::new_with(mem_ctx);
            assert_eq!(0, gauge.get());

            heap.push(9u8);
            heap.push(1u8);
            assert_eq!(heap.inner.capacity() as i64, gauge.get());

            heap.into_sorted_vec()
        };

        assert_eq!(2, gauge.get());

        drop(vec);

        assert_eq!(0, gauge.get());
    }
}
157}