// risingwave_common/memory/mem_context.rs
1use std::ops::Deref;
16use std::sync::Arc;
17
18use prometheus::core::Atomic;
19use risingwave_common_metrics::TrAdderAtomic;
20
21use super::MonitoredGlobalAlloc;
22use crate::metrics::{LabelGuardedIntGauge, TrAdderGauge};
23
/// A sink for memory-usage accounting, measured in bytes.
///
/// Implementors must be `Send + Sync + 'static` so they can be boxed as a
/// `Box<dyn MemCounter>` and shared between threads.
pub trait MemCounter: Send + Sync + 'static {
    /// Adjusts the recorded usage by `bytes`; callers pass negative values to release memory.
    fn add(&self, bytes: i64);

    /// Returns the usage currently recorded, in bytes.
    fn get_bytes_used(&self) -> i64;
}
28
29impl MemCounter for TrAdderGauge {
30 fn add(&self, bytes: i64) {
31 self.add(bytes)
32 }
33
34 fn get_bytes_used(&self) -> i64 {
35 self.get()
36 }
37}
38
39impl MemCounter for TrAdderAtomic {
40 fn add(&self, bytes: i64) {
41 self.inc_by(bytes)
42 }
43
44 fn get_bytes_used(&self) -> i64 {
45 self.get()
46 }
47}
48
49impl<const N: usize> MemCounter for LabelGuardedIntGauge<N> {
50 fn add(&self, bytes: i64) {
51 self.deref().add(bytes)
52 }
53
54 fn get_bytes_used(&self) -> i64 {
55 self.get()
56 }
57}
58
59struct MemoryContextInner {
60 counter: Box<dyn MemCounter>,
61 parent: Option<MemoryContext>,
62 mem_limit: u64,
63}
64
65#[derive(Clone)]
66pub struct MemoryContext {
67 inner: Option<Arc<MemoryContextInner>>,
70}
71
72impl MemoryContext {
73 pub fn new(parent: Option<MemoryContext>, counter: impl MemCounter) -> Self {
74 let mem_limit = parent.as_ref().map_or_else(|| u64::MAX, |p| p.mem_limit());
75 Self::new_with_mem_limit(parent, counter, mem_limit)
76 }
77
78 pub fn new_with_mem_limit(
79 parent: Option<MemoryContext>,
80 counter: impl MemCounter,
81 mem_limit: u64,
82 ) -> Self {
83 let c = Box::new(counter);
84 Self {
85 inner: Some(Arc::new(MemoryContextInner {
86 counter: c,
87 parent,
88 mem_limit,
89 })),
90 }
91 }
92
93 pub fn none() -> Self {
95 Self { inner: None }
96 }
97
98 pub fn root(counter: impl MemCounter, mem_limit: u64) -> Self {
99 Self::new_with_mem_limit(None, counter, mem_limit)
100 }
101
102 pub fn for_spill_test() -> Self {
103 Self::new_with_mem_limit(None, TrAdderAtomic::new(0), 0)
104 }
105
106 pub fn add(&self, bytes: i64) -> bool {
109 if let Some(inner) = &self.inner {
110 if (inner.counter.get_bytes_used() + bytes) as u64 > inner.mem_limit {
111 return false;
112 }
113 if let Some(parent) = &inner.parent {
114 if parent.add(bytes) {
115 inner.counter.add(bytes);
116 } else {
117 return false;
118 }
119 } else {
120 inner.counter.add(bytes);
121 }
122 }
123 true
124 }
125
126 pub fn get_bytes_used(&self) -> i64 {
127 if let Some(inner) = &self.inner {
128 inner.counter.get_bytes_used()
129 } else {
130 0
131 }
132 }
133
134 pub fn mem_limit(&self) -> u64 {
135 if let Some(inner) = &self.inner {
136 inner.mem_limit
137 } else {
138 u64::MAX
139 }
140 }
141
142 pub fn check_memory_usage(&self) -> bool {
145 if let Some(inner) = &self.inner {
146 if inner.counter.get_bytes_used() as u64 > inner.mem_limit {
147 return false;
148 }
149 if let Some(parent) = &inner.parent {
150 return parent.check_memory_usage();
151 }
152 }
153
154 true
155 }
156
157 pub fn global_allocator(&self) -> MonitoredGlobalAlloc {
159 MonitoredGlobalAlloc::with_memory_context(self.clone())
160 }
161}
162
163impl Drop for MemoryContextInner {
164 fn drop(&mut self) {
165 if let Some(p) = &self.parent {
166 p.add(-self.counter.get_bytes_used());
167 }
168 }
169}