risingwave_common/
lru.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::alloc::{Allocator, Global};
16use std::borrow::Borrow;
17use std::cell::RefCell;
18use std::hash::{BuildHasher, Hash};
19use std::mem::MaybeUninit;
20use std::ptr::NonNull;
21use std::sync::atomic::Ordering;
22
23pub use ahash::RandomState;
24use hashbrown::HashTable;
25use hashbrown::hash_table::Entry;
26
27use crate::sequence::{AtomicSequence, Sequence, Sequencer};
28
29thread_local! {
30    pub static SEQUENCER: RefCell<Sequencer> = const { RefCell::new(Sequencer::new(Sequencer::DEFAULT_STEP, Sequencer::DEFAULT_LAG)) };
31}
32
33static SEQUENCER_DEFAULT_STEP: AtomicSequence = AtomicSequence::new(Sequencer::DEFAULT_STEP);
34static SEQUENCER_DEFAULT_LAG: AtomicSequence = AtomicSequence::new(Sequencer::DEFAULT_LAG);
35
36pub fn init_global_sequencer_args(step: Sequence, lag: Sequence) {
37    SEQUENCER_DEFAULT_STEP.store(step, Ordering::Relaxed);
38    SEQUENCER_DEFAULT_LAG.store(lag, Ordering::Relaxed);
39}
40
41struct LruEntry<K, V>
42where
43    K: Hash + Eq,
44{
45    prev: Option<NonNull<LruEntry<K, V>>>,
46    next: Option<NonNull<LruEntry<K, V>>>,
47    key: MaybeUninit<K>,
48    value: MaybeUninit<V>,
49    hash: u64,
50    sequence: Sequence,
51}
52
53impl<K, V> LruEntry<K, V>
54where
55    K: Hash + Eq,
56{
57    fn key(&self) -> &K {
58        unsafe { self.key.assume_init_ref() }
59    }
60
61    fn value(&self) -> &V {
62        unsafe { self.value.assume_init_ref() }
63    }
64
65    fn value_mut(&mut self) -> &mut V {
66        unsafe { self.value.assume_init_mut() }
67    }
68}
69
70unsafe impl<K, V> Send for LruEntry<K, V> where K: Hash + Eq {}
71unsafe impl<K, V> Sync for LruEntry<K, V> where K: Hash + Eq {}
72
73pub struct LruCache<K, V, S = RandomState, A = Global>
74where
75    K: Hash + Eq,
76    S: BuildHasher + Send + Sync + 'static,
77    A: Clone + Allocator,
78{
79    map: HashTable<NonNull<LruEntry<K, V>>, A>,
80    /// dummy node of the lru linked list
81    dummy: Box<LruEntry<K, V>, A>,
82
83    alloc: A,
84    hash_builder: S,
85}
86
87unsafe impl<K, V, S, A> Send for LruCache<K, V, S, A>
88where
89    K: Hash + Eq,
90    S: BuildHasher + Send + Sync + 'static,
91    A: Clone + Allocator,
92{
93}
94unsafe impl<K, V, S, A> Sync for LruCache<K, V, S, A>
95where
96    K: Hash + Eq,
97    S: BuildHasher + Send + Sync + 'static,
98    A: Clone + Allocator,
99{
100}
101
102impl<K, V> LruCache<K, V>
103where
104    K: Hash + Eq,
105{
106    pub fn unbounded() -> Self {
107        Self::unbounded_with_hasher_in(RandomState::default(), Global)
108    }
109}
110
111impl<K, V, S, A> LruCache<K, V, S, A>
112where
113    K: Hash + Eq,
114    S: BuildHasher + Send + Sync + 'static,
115    A: Clone + Allocator,
116{
117    pub fn unbounded_with_hasher_in(hash_builder: S, alloc: A) -> Self {
118        let map = HashTable::new_in(alloc.clone());
119        let mut dummy = Box::new_in(
120            LruEntry {
121                prev: None,
122                next: None,
123                key: MaybeUninit::uninit(),
124                value: MaybeUninit::uninit(),
125                hash: 0,
126                sequence: Sequence::default(),
127            },
128            alloc.clone(),
129        );
130        let ptr = unsafe { NonNull::new_unchecked(dummy.as_mut() as *mut _) };
131        dummy.next = Some(ptr);
132        dummy.prev = Some(ptr);
133        Self {
134            map,
135            dummy,
136            alloc,
137            hash_builder,
138        }
139    }
140
141    pub fn put(&mut self, key: K, mut value: V) -> Option<V> {
142        unsafe {
143            let hash = self.hash_builder.hash_one(&key);
144
145            match self
146                .map
147                .entry(hash, |p| p.as_ref().key() == &key, |p| p.as_ref().hash)
148            {
149                Entry::Occupied(o) => {
150                    let mut ptr = *o.get();
151                    let entry = ptr.as_mut();
152                    std::mem::swap(&mut value, entry.value_mut());
153                    self.detach(ptr);
154                    self.attach(ptr);
155                    Some(value)
156                }
157                Entry::Vacant(v) => {
158                    let entry = Box::new_in(
159                        LruEntry {
160                            prev: None,
161                            next: None,
162                            key: MaybeUninit::new(key),
163                            value: MaybeUninit::new(value),
164                            hash,
165                            // sequence will be updated by `attach`
166                            sequence: 0,
167                        },
168                        self.alloc.clone(),
169                    );
170                    let ptr = NonNull::new_unchecked(Box::into_raw(entry));
171                    v.insert(ptr);
172                    self.attach(ptr);
173                    None
174                }
175            }
176        }
177    }
178
179    pub fn get<'a, Q>(&'a mut self, key: &Q) -> Option<&'a V>
180    where
181        K: Borrow<Q>,
182        Q: Hash + Eq + ?Sized,
183    {
184        unsafe {
185            let key = key.borrow();
186            let hash = self.hash_builder.hash_one(key);
187            if let Some(ptr) = self.map.find(hash, |p| p.as_ref().key().borrow() == key) {
188                let ptr = *ptr;
189                self.detach(ptr);
190                self.attach(ptr);
191                Some(ptr.as_ref().value())
192            } else {
193                None
194            }
195        }
196    }
197
198    pub fn get_mut<'a, Q>(&'a mut self, key: &Q) -> Option<&'a mut V>
199    where
200        K: Borrow<Q>,
201        Q: Hash + Eq + ?Sized,
202    {
203        unsafe {
204            let key = key.borrow();
205            let hash = self.hash_builder.hash_one(key);
206            if let Some(ptr) = self
207                .map
208                .find_mut(hash, |p| p.as_ref().key().borrow() == key)
209            {
210                let mut ptr = *ptr;
211                self.detach(ptr);
212                self.attach(ptr);
213                Some(ptr.as_mut().value_mut())
214            } else {
215                None
216            }
217        }
218    }
219
220    pub fn peek<'a, Q>(&'a self, key: &Q) -> Option<&'a V>
221    where
222        K: Borrow<Q>,
223        Q: Hash + Eq + ?Sized,
224    {
225        unsafe {
226            let key = key.borrow();
227            let hash = self.hash_builder.hash_one(key);
228            self.map
229                .find(hash, |p| p.as_ref().key().borrow() == key)
230                .map(|ptr| ptr.as_ref().value())
231        }
232    }
233
234    pub fn peek_mut<'a, Q>(&'a mut self, key: &Q) -> Option<&'a mut V>
235    where
236        K: Borrow<Q>,
237        Q: Hash + Eq + ?Sized,
238    {
239        unsafe {
240            let key = key.borrow();
241            let hash = self.hash_builder.hash_one(key);
242            self.map
243                .find(hash, |p| p.as_ref().key().borrow() == key)
244                .map(|ptr| ptr.clone().as_mut().value_mut())
245        }
246    }
247
248    pub fn contains<Q>(&self, key: &Q) -> bool
249    where
250        K: Borrow<Q>,
251        Q: Hash + Eq + ?Sized,
252    {
253        unsafe {
254            let key = key.borrow();
255            let hash = self.hash_builder.hash_one(key);
256            self.map
257                .find(hash, |p| p.as_ref().key().borrow() == key)
258                .is_some()
259        }
260    }
261
262    pub fn len(&self) -> usize {
263        self.map.len()
264    }
265
266    pub fn is_empty(&self) -> bool {
267        self.len() == 0
268    }
269
270    /// Pop first entry if its sequence is less than the given sequence.
271    pub fn pop_with_sequence(&mut self, sequence: Sequence) -> Option<(K, V, Sequence)> {
272        unsafe {
273            if self.is_empty() {
274                return None;
275            }
276
277            let ptr = self.dummy.next.unwrap_unchecked();
278            if ptr.as_ref().sequence >= sequence {
279                return None;
280            }
281
282            self.detach(ptr);
283
284            let entry = Box::from_raw_in(ptr.as_ptr(), self.alloc.clone());
285
286            let key = entry.key.assume_init();
287            let value = entry.value.assume_init();
288            let sequence = entry.sequence;
289
290            let hash = self.hash_builder.hash_one(&key);
291
292            match self
293                .map
294                .entry(hash, |p| p.as_ref().key() == &key, |p| p.as_ref().hash)
295            {
296                Entry::Occupied(o) => {
297                    o.remove();
298                }
299                Entry::Vacant(_) => {}
300            }
301
302            Some((key, value, sequence))
303        }
304    }
305
306    pub fn clear(&mut self) {
307        unsafe {
308            let mut map = HashTable::new_in(self.alloc.clone());
309            std::mem::swap(&mut map, &mut self.map);
310
311            for ptr in map.drain() {
312                self.detach(ptr);
313                let mut entry = Box::from_raw_in(ptr.as_ptr(), self.alloc.clone());
314                entry.key.assume_init_drop();
315                entry.value.assume_init_drop();
316            }
317
318            debug_assert!(self.is_empty());
319            debug_assert_eq!(
320                self.dummy.as_mut() as *mut _,
321                self.dummy.next.unwrap_unchecked().as_ptr()
322            )
323        }
324    }
325
326    fn detach(&mut self, mut ptr: NonNull<LruEntry<K, V>>) {
327        unsafe {
328            let entry = ptr.as_mut();
329
330            debug_assert!(entry.prev.is_some() && entry.next.is_some());
331
332            entry.prev.unwrap_unchecked().as_mut().next = entry.next;
333            entry.next.unwrap_unchecked().as_mut().prev = entry.prev;
334
335            entry.next = None;
336            entry.prev = None;
337        }
338    }
339
340    fn attach(&mut self, mut ptr: NonNull<LruEntry<K, V>>) {
341        unsafe {
342            let entry = ptr.as_mut();
343
344            debug_assert!(entry.prev.is_none() && entry.next.is_none());
345
346            entry.next = Some(NonNull::new_unchecked(self.dummy.as_mut() as *mut _));
347            entry.prev = self.dummy.prev;
348
349            self.dummy.prev.unwrap_unchecked().as_mut().next = Some(ptr);
350            self.dummy.prev = Some(ptr);
351
352            entry.sequence = SEQUENCER.with(|s| s.borrow_mut().alloc());
353        }
354    }
355}
356
357impl<K, V, S, A> Drop for LruCache<K, V, S, A>
358where
359    K: Hash + Eq,
360    S: BuildHasher + Send + Sync + 'static,
361    A: Clone + Allocator,
362{
363    fn drop(&mut self) {
364        self.clear()
365    }
366}
367
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use super::*;

    /// Pops every entry from the least-recently-used end and returns the `(key, value)` pairs
    /// in pop order.
    fn drain_all(cache: &mut LruCache<u64, u64>) -> Vec<(u64, u64)> {
        std::iter::from_fn(|| cache.pop_with_sequence(Sequence::MAX).map(|(k, v, _)| (k, v)))
            .collect()
    }

    #[test]
    fn test_put_and_lookup() {
        let mut cache: LruCache<u64, String> = LruCache::unbounded();
        assert!(cache.is_empty());
        assert_eq!(cache.put(1, "a".to_owned()), None);
        assert_eq!(cache.put(2, "b".to_owned()), None);
        assert_eq!(cache.put(1, "c".to_owned()), Some("a".to_owned()));
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.peek(&1u64).map(String::as_str), Some("c"));
        assert_eq!(cache.get(&2u64).map(String::as_str), Some("b"));
        assert!(cache.contains(&1u64));
        assert!(!cache.contains(&3u64));
        assert!(cache.get(&3u64).is_none());
    }

    #[test]
    fn test_pop_follows_recency() {
        let mut cache: LruCache<u64, u64> = LruCache::unbounded();
        for i in 0..4u64 {
            cache.put(i, i * 10);
        }
        // `get` moves an entry to the MRU end; `peek` leaves recency untouched.
        assert_eq!(cache.get(&0u64), Some(&0));
        assert_eq!(cache.peek(&1u64), Some(&10));
        // Re-inserting an existing key refreshes it too.
        assert_eq!(cache.put(2, 21), Some(20));
        assert_eq!(
            drain_all(&mut cache),
            vec![(1, 10), (3, 30), (0, 0), (2, 21)]
        );
        assert!(cache.is_empty());
        assert!(!cache.contains(&1u64));
    }

    #[test]
    fn test_pop_with_sequence_threshold() {
        let mut cache: LruCache<u64, u64> = LruCache::unbounded();
        assert!(cache.pop_with_sequence(Sequence::MAX).is_none());
        cache.put(7, 70);
        // No entry can be strictly older than sequence 0.
        assert!(cache.pop_with_sequence(0).is_none());
        assert_eq!(cache.len(), 1);
        let (k, v, seq) = cache.pop_with_sequence(Sequence::MAX).unwrap();
        assert_eq!((k, v), (7, 70));
        assert!(seq < Sequence::MAX);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_get_mut_and_peek_mut() {
        let mut cache: LruCache<u64, u64> = LruCache::unbounded();
        cache.put(1, 10);
        cache.put(2, 20);
        *cache.get_mut(&1u64).unwrap() += 1;
        *cache.peek_mut(&2u64).unwrap() += 2;
        assert!(cache.get_mut(&3u64).is_none());
        assert!(cache.peek_mut(&3u64).is_none());
        // `get_mut` refreshed key 1; `peek_mut` left key 2 where it was.
        assert_eq!(drain_all(&mut cache), vec![(2, 22), (1, 11)]);
    }

    #[test]
    fn test_clear_and_drop_release_values() {
        let value = Rc::new(());
        let mut cache: LruCache<u64, Rc<()>> = LruCache::unbounded();
        for i in 0..8u64 {
            cache.put(i, Rc::clone(&value));
        }
        assert_eq!(Rc::strong_count(&value), 9);

        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(Rc::strong_count(&value), 1);

        // The cache stays usable after `clear`, and dropping it releases what it holds.
        cache.put(42, Rc::clone(&value));
        assert_eq!(Rc::strong_count(&value), 2);
        drop(cache);
        assert_eq!(Rc::strong_count(&value), 1);
    }
}