risingwave_common/memory/
alloc.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::alloc::{AllocError, Allocator, Global, Layout};
16use std::ptr::NonNull;
17
18use crate::memory::MemoryContext;
19
20pub type MonitoredGlobalAlloc = MonitoredAlloc<Global>;
21
22pub struct MonitoredAlloc<A: Allocator> {
23    ctx: MemoryContext,
24    alloc: A,
25}
26
27impl<A: Allocator> MonitoredAlloc<A> {
28    pub fn new(ctx: MemoryContext, alloc: A) -> Self {
29        Self { ctx, alloc }
30    }
31}
32
33impl MonitoredGlobalAlloc {
34    pub fn with_memory_context(ctx: MemoryContext) -> Self {
35        Self { ctx, alloc: Global }
36    }
37
38    pub fn for_test() -> Self {
39        Self::with_memory_context(MemoryContext::none())
40    }
41}
42
43unsafe impl<A: Allocator> Allocator for MonitoredAlloc<A> {
44    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
45        let ret = self.alloc.allocate(layout)?;
46        // We don't throw an AllocError if the memory context is out of memory, otherwise the whole process will crash.
47        self.ctx.add(layout.size() as i64);
48        Ok(ret)
49    }
50
51    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
52        unsafe {
53            self.alloc.deallocate(ptr, layout);
54            self.ctx.add(-(layout.size() as i64));
55        }
56    }
57}
58
59impl<A: Allocator + Clone> Clone for MonitoredAlloc<A> {
60    fn clone(&self) -> Self {
61        Self {
62            ctx: self.ctx.clone(),
63            alloc: self.alloc.clone(),
64        }
65    }
66}