// risingwave_common/memory/alloc.rs

// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::alloc::{AllocError, Allocator, Global, Layout};
16use std::ptr::NonNull;
17
18use allocator_api2::alloc::{AllocError as AllocErrorApi2, Allocator as AllocatorApi2};
19
20use crate::memory::MemoryContext;
21
22pub type MonitoredGlobalAlloc = MonitoredAlloc<Global>;
23
24pub struct MonitoredAlloc<A: Allocator> {
25    ctx: MemoryContext,
26    alloc: A,
27}
28
29impl<A: Allocator> MonitoredAlloc<A> {
30    pub fn new(ctx: MemoryContext, alloc: A) -> Self {
31        Self { ctx, alloc }
32    }
33}
34
35impl MonitoredGlobalAlloc {
36    pub fn with_memory_context(ctx: MemoryContext) -> Self {
37        Self { ctx, alloc: Global }
38    }
39
40    pub fn for_test() -> Self {
41        Self::with_memory_context(MemoryContext::none())
42    }
43}
44
45unsafe impl<A: Allocator> Allocator for MonitoredAlloc<A> {
46    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
47        let ret = self.alloc.allocate(layout)?;
48        // We don't throw an AllocError if the memory context is out of memory, otherwise the whole process will crash.
49        self.ctx.add(layout.size() as i64);
50        Ok(ret)
51    }
52
53    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
54        unsafe {
55            self.alloc.deallocate(ptr, layout);
56            self.ctx.add(-(layout.size() as i64));
57        }
58    }
59}
60
61unsafe impl<A: Allocator> AllocatorApi2 for MonitoredAlloc<A> {
62    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocErrorApi2> {
63        let ret = self.alloc.allocate(layout).map_err(|_| AllocErrorApi2)?;
64        // Keep memory accounting behavior consistent with the std::alloc::Allocator path.
65        self.ctx.add(layout.size() as i64);
66        Ok(ret)
67    }
68
69    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
70        unsafe {
71            self.alloc.deallocate(ptr, layout);
72            self.ctx.add(-(layout.size() as i64));
73        }
74    }
75}
76
77impl<A: Allocator + Clone> Clone for MonitoredAlloc<A> {
78    fn clone(&self) -> Self {
79        Self {
80            ctx: self.ctx.clone(),
81            alloc: self.alloc.clone(),
82        }
83    }
84}