risingwave_common/memory/mem_context.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::Deref;
16use std::sync::Arc;
17
18use prometheus::core::Atomic;
19use risingwave_common_metrics::TrAdderAtomic;
20
21use super::MonitoredGlobalAlloc;
22use crate::metrics::{LabelGuardedIntGauge, TrAdderGauge};
23
/// A shared byte counter that a [`MemoryContext`] charges and reads back.
///
/// Implementors must be `Send + Sync + 'static` because a context stores its counter as a
/// boxed trait object behind an `Arc`.
pub trait MemCounter: Send + Sync + 'static {
    /// Current running total recorded by this counter.
    fn get_bytes_used(&self) -> i64;

    /// Adjusts the running total by `bytes`; a negative value decreases it.
    fn add(&self, bytes: i64);
}
28
29impl MemCounter for TrAdderGauge {
30    fn add(&self, bytes: i64) {
31        self.add(bytes)
32    }
33
34    fn get_bytes_used(&self) -> i64 {
35        self.get()
36    }
37}
38
39impl MemCounter for TrAdderAtomic {
40    fn add(&self, bytes: i64) {
41        self.inc_by(bytes)
42    }
43
44    fn get_bytes_used(&self) -> i64 {
45        self.get()
46    }
47}
48
49impl<const N: usize> MemCounter for LabelGuardedIntGauge<N> {
50    fn add(&self, bytes: i64) {
51        self.deref().add(bytes)
52    }
53
54    fn get_bytes_used(&self) -> i64 {
55        self.get()
56    }
57}
58
59struct MemoryContextInner {
60    counter: Box<dyn MemCounter>,
61    parent: Option<MemoryContext>,
62    mem_limit: u64,
63}
64
65#[derive(Clone)]
66pub struct MemoryContext {
67    /// Add None op mem context, so that we don't need to return [`Option`] in
68    /// `BatchTaskContext`. This helps with later `Allocator` implementation.
69    inner: Option<Arc<MemoryContextInner>>,
70}
71
72impl MemoryContext {
73    pub fn new(parent: Option<MemoryContext>, counter: impl MemCounter) -> Self {
74        let mem_limit = parent.as_ref().map_or_else(|| u64::MAX, |p| p.mem_limit());
75        Self::new_with_mem_limit(parent, counter, mem_limit)
76    }
77
78    pub fn new_with_mem_limit(
79        parent: Option<MemoryContext>,
80        counter: impl MemCounter,
81        mem_limit: u64,
82    ) -> Self {
83        let c = Box::new(counter);
84        Self {
85            inner: Some(Arc::new(MemoryContextInner {
86                counter: c,
87                parent,
88                mem_limit,
89            })),
90        }
91    }
92
93    /// Creates a noop memory context.
94    pub fn none() -> Self {
95        Self { inner: None }
96    }
97
98    pub fn root(counter: impl MemCounter, mem_limit: u64) -> Self {
99        Self::new_with_mem_limit(None, counter, mem_limit)
100    }
101
102    pub fn for_spill_test() -> Self {
103        Self::new_with_mem_limit(None, TrAdderAtomic::new(0), 0)
104    }
105
106    /// Add `bytes` memory usage. Pass negative value to decrease memory usage.
107    /// Returns `false` if the memory usage exceeds the limit.
108    pub fn add(&self, bytes: i64) -> bool {
109        if let Some(inner) = &self.inner {
110            if (inner.counter.get_bytes_used() + bytes) as u64 > inner.mem_limit {
111                return false;
112            }
113            if let Some(parent) = &inner.parent {
114                if parent.add(bytes) {
115                    inner.counter.add(bytes);
116                } else {
117                    return false;
118                }
119            } else {
120                inner.counter.add(bytes);
121            }
122        }
123        true
124    }
125
126    pub fn get_bytes_used(&self) -> i64 {
127        if let Some(inner) = &self.inner {
128            inner.counter.get_bytes_used()
129        } else {
130            0
131        }
132    }
133
134    pub fn mem_limit(&self) -> u64 {
135        if let Some(inner) = &self.inner {
136            inner.mem_limit
137        } else {
138            u64::MAX
139        }
140    }
141
142    /// Check if the memory usage exceeds the limit.
143    /// Returns `false` if the memory usage exceeds the limit.
144    pub fn check_memory_usage(&self) -> bool {
145        if let Some(inner) = &self.inner {
146            if inner.counter.get_bytes_used() as u64 > inner.mem_limit {
147                return false;
148            }
149            if let Some(parent) = &inner.parent {
150                return parent.check_memory_usage();
151            }
152        }
153
154        true
155    }
156
157    /// Creates a new global allocator that reports memory usage to this context.
158    pub fn global_allocator(&self) -> MonitoredGlobalAlloc {
159        MonitoredGlobalAlloc::with_memory_context(self.clone())
160    }
161}
162
163impl Drop for MemoryContextInner {
164    fn drop(&mut self) {
165        if let Some(p) = &self.parent {
166            p.add(-self.counter.get_bytes_used());
167        }
168    }
169}