risingwave_common/
lru.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::alloc::{Allocator, Global};
16use std::borrow::Borrow;
17use std::cell::RefCell;
18use std::hash::{BuildHasher, Hash};
19use std::mem::MaybeUninit;
20use std::ptr::NonNull;
21use std::sync::atomic::Ordering;
22
23pub use ahash::RandomState;
24use hashbrown::HashTable;
25use hashbrown::hash_table::Entry;
26
27use crate::sequence::{AtomicSequence, Sequence, Sequencer};
28
29thread_local! {
30    pub static SEQUENCER: RefCell<Sequencer> = const { RefCell::new(Sequencer::new(Sequencer::DEFAULT_STEP, Sequencer::DEFAULT_LAG)) };
31}
32
33static SEQUENCER_DEFAULT_STEP: AtomicSequence = AtomicSequence::new(Sequencer::DEFAULT_STEP);
34static SEQUENCER_DEFAULT_LAG: AtomicSequence = AtomicSequence::new(Sequencer::DEFAULT_LAG);
35
36pub fn init_global_sequencer_args(step: Sequence, lag: Sequence) {
37    SEQUENCER_DEFAULT_STEP.store(step, Ordering::Relaxed);
38    SEQUENCER_DEFAULT_LAG.store(lag, Ordering::Relaxed);
39}
40
41struct LruEntry<K, V>
42where
43    K: Hash + Eq,
44{
45    prev: Option<NonNull<LruEntry<K, V>>>,
46    next: Option<NonNull<LruEntry<K, V>>>,
47    key: MaybeUninit<K>,
48    value: MaybeUninit<V>,
49    hash: u64,
50    sequence: Sequence,
51}
52
53impl<K, V> LruEntry<K, V>
54where
55    K: Hash + Eq,
56{
57    fn key(&self) -> &K {
58        unsafe { self.key.assume_init_ref() }
59    }
60
61    fn value(&self) -> &V {
62        unsafe { self.value.assume_init_ref() }
63    }
64
65    fn value_mut(&mut self) -> &mut V {
66        unsafe { self.value.assume_init_mut() }
67    }
68}
69
70unsafe impl<K, V> Send for LruEntry<K, V> where K: Hash + Eq {}
71unsafe impl<K, V> Sync for LruEntry<K, V> where K: Hash + Eq {}
72
73pub struct LruCache<K, V, S = RandomState, A = Global>
74where
75    K: Hash + Eq,
76    S: BuildHasher + Send + Sync + 'static,
77    A: Clone + Allocator,
78{
79    map: HashTable<NonNull<LruEntry<K, V>>, A>,
80    /// dummy node of the lru linked list
81    dummy: Box<LruEntry<K, V>, A>,
82
83    alloc: A,
84    hash_builder: S,
85}
86
87unsafe impl<K, V, S, A> Send for LruCache<K, V, S, A>
88where
89    K: Hash + Eq,
90    S: BuildHasher + Send + Sync + 'static,
91    A: Clone + Allocator,
92{
93}
94unsafe impl<K, V, S, A> Sync for LruCache<K, V, S, A>
95where
96    K: Hash + Eq,
97    S: BuildHasher + Send + Sync + 'static,
98    A: Clone + Allocator,
99{
100}
101
102impl<K, V> LruCache<K, V>
103where
104    K: Hash + Eq,
105{
106    pub fn unbounded() -> Self {
107        Self::unbounded_with_hasher_in(RandomState::default(), Global)
108    }
109}
110
111impl<K, V, S, A> LruCache<K, V, S, A>
112where
113    K: Hash + Eq,
114    S: BuildHasher + Send + Sync + 'static,
115    A: Clone + Allocator,
116{
117    pub fn unbounded_with_hasher_in(hash_builder: S, alloc: A) -> Self {
118        let map = HashTable::new_in(alloc.clone());
119        let mut dummy = Box::new_in(
120            LruEntry {
121                prev: None,
122                next: None,
123                key: MaybeUninit::uninit(),
124                value: MaybeUninit::uninit(),
125                hash: 0,
126                sequence: Sequence::default(),
127            },
128            alloc.clone(),
129        );
130        let ptr = unsafe { NonNull::new_unchecked(dummy.as_mut() as *mut _) };
131        dummy.next = Some(ptr);
132        dummy.prev = Some(ptr);
133        Self {
134            map,
135            dummy,
136            alloc,
137            hash_builder,
138        }
139    }
140
141    pub fn put(&mut self, key: K, mut value: V) -> Option<V> {
142        unsafe {
143            let hash = self.hash_builder.hash_one(&key);
144
145            match self
146                .map
147                .entry(hash, |p| p.as_ref().key() == &key, |p| p.as_ref().hash)
148            {
149                Entry::Occupied(o) => {
150                    let mut ptr = *o.get();
151                    let entry = ptr.as_mut();
152                    std::mem::swap(&mut value, entry.value_mut());
153                    Self::detach(ptr);
154                    self.attach(ptr);
155                    Some(value)
156                }
157                Entry::Vacant(v) => {
158                    let entry = Box::new_in(
159                        LruEntry {
160                            prev: None,
161                            next: None,
162                            key: MaybeUninit::new(key),
163                            value: MaybeUninit::new(value),
164                            hash,
165                            // sequence will be updated by `attach`
166                            sequence: 0,
167                        },
168                        self.alloc.clone(),
169                    );
170                    let ptr = NonNull::new_unchecked(Box::into_raw(entry));
171                    v.insert(ptr);
172                    self.attach(ptr);
173                    None
174                }
175            }
176        }
177    }
178
179    pub fn remove(&mut self, key: &K) -> Option<V> {
180        unsafe {
181            let hash = self.hash_builder.hash_one(key);
182
183            match self
184                .map
185                .entry(hash, |p| p.as_ref().key() == key, |p| p.as_ref().hash)
186            {
187                Entry::Occupied(o) => {
188                    let ptr = *o.get();
189
190                    // Detach the entry from the LRU list
191                    Self::detach(ptr);
192
193                    // Extract entry from the box and get its value
194                    let mut entry = Box::from_raw_in(ptr.as_ptr(), self.alloc.clone());
195                    entry.key.assume_init_drop();
196                    let value = entry.value.assume_init();
197
198                    // Remove entry from the hash table
199                    o.remove();
200
201                    Some(value)
202                }
203                Entry::Vacant(_) => None,
204            }
205        }
206    }
207
208    pub fn get<'a, Q>(&'a mut self, key: &Q) -> Option<&'a V>
209    where
210        K: Borrow<Q>,
211        Q: Hash + Eq + ?Sized,
212    {
213        unsafe {
214            let key = key.borrow();
215            let hash = self.hash_builder.hash_one(key);
216            if let Some(ptr) = self.map.find(hash, |p| p.as_ref().key().borrow() == key) {
217                let ptr = *ptr;
218                Self::detach(ptr);
219                self.attach(ptr);
220                Some(ptr.as_ref().value())
221            } else {
222                None
223            }
224        }
225    }
226
227    pub fn get_mut<'a, Q>(&'a mut self, key: &Q) -> Option<&'a mut V>
228    where
229        K: Borrow<Q>,
230        Q: Hash + Eq + ?Sized,
231    {
232        unsafe {
233            let key = key.borrow();
234            let hash = self.hash_builder.hash_one(key);
235            if let Some(ptr) = self
236                .map
237                .find_mut(hash, |p| p.as_ref().key().borrow() == key)
238            {
239                let mut ptr = *ptr;
240                Self::detach(ptr);
241                self.attach(ptr);
242                Some(ptr.as_mut().value_mut())
243            } else {
244                None
245            }
246        }
247    }
248
249    pub fn peek<'a, Q>(&'a self, key: &Q) -> Option<&'a V>
250    where
251        K: Borrow<Q>,
252        Q: Hash + Eq + ?Sized,
253    {
254        unsafe {
255            let key = key.borrow();
256            let hash = self.hash_builder.hash_one(key);
257            self.map
258                .find(hash, |p| p.as_ref().key().borrow() == key)
259                .map(|ptr| ptr.as_ref().value())
260        }
261    }
262
263    pub fn peek_mut<'a, Q>(&'a mut self, key: &Q) -> Option<&'a mut V>
264    where
265        K: Borrow<Q>,
266        Q: Hash + Eq + ?Sized,
267    {
268        unsafe {
269            let key = key.borrow();
270            let hash = self.hash_builder.hash_one(key);
271            self.map
272                .find(hash, |p| p.as_ref().key().borrow() == key)
273                .map(|ptr| ptr.clone().as_mut().value_mut())
274        }
275    }
276
277    pub fn contains<Q>(&self, key: &Q) -> bool
278    where
279        K: Borrow<Q>,
280        Q: Hash + Eq + ?Sized,
281    {
282        unsafe {
283            let key = key.borrow();
284            let hash = self.hash_builder.hash_one(key);
285            self.map
286                .find(hash, |p| p.as_ref().key().borrow() == key)
287                .is_some()
288        }
289    }
290
291    pub fn len(&self) -> usize {
292        self.map.len()
293    }
294
295    pub fn is_empty(&self) -> bool {
296        self.len() == 0
297    }
298
299    /// Pop first entry if its sequence is less than the given sequence.
300    pub fn pop_with_sequence(&mut self, sequence: Sequence) -> Option<(K, V, Sequence)> {
301        unsafe {
302            if self.is_empty() {
303                return None;
304            }
305
306            let ptr = self.dummy.next.unwrap_unchecked();
307            if ptr.as_ref().sequence >= sequence {
308                return None;
309            }
310
311            Self::detach(ptr);
312
313            let entry = Box::from_raw_in(ptr.as_ptr(), self.alloc.clone());
314
315            let key = entry.key.assume_init();
316            let value = entry.value.assume_init();
317            let sequence = entry.sequence;
318
319            let hash = self.hash_builder.hash_one(&key);
320
321            match self
322                .map
323                .entry(hash, |p| p.as_ref().key() == &key, |p| p.as_ref().hash)
324            {
325                Entry::Occupied(o) => {
326                    o.remove();
327                }
328                Entry::Vacant(_) => {}
329            }
330
331            Some((key, value, sequence))
332        }
333    }
334
335    pub fn clear(&mut self) {
336        unsafe {
337            let mut map = HashTable::new_in(self.alloc.clone());
338            std::mem::swap(&mut map, &mut self.map);
339
340            for ptr in map.drain() {
341                Self::detach(ptr);
342                let mut entry = Box::from_raw_in(ptr.as_ptr(), self.alloc.clone());
343                entry.key.assume_init_drop();
344                entry.value.assume_init_drop();
345            }
346
347            debug_assert!(self.is_empty());
348            debug_assert_eq!(
349                self.dummy.as_mut() as *mut _,
350                self.dummy.next.unwrap_unchecked().as_ptr()
351            )
352        }
353    }
354
355    fn detach(mut ptr: NonNull<LruEntry<K, V>>) {
356        unsafe {
357            let entry = ptr.as_mut();
358
359            debug_assert!(entry.prev.is_some() && entry.next.is_some());
360
361            entry.prev.unwrap_unchecked().as_mut().next = entry.next;
362            entry.next.unwrap_unchecked().as_mut().prev = entry.prev;
363
364            entry.next = None;
365            entry.prev = None;
366        }
367    }
368
369    fn attach(&mut self, mut ptr: NonNull<LruEntry<K, V>>) {
370        unsafe {
371            let entry = ptr.as_mut();
372
373            debug_assert!(entry.prev.is_none() && entry.next.is_none());
374
375            entry.next = Some(NonNull::new_unchecked(self.dummy.as_mut() as *mut _));
376            entry.prev = self.dummy.prev;
377
378            self.dummy.prev.unwrap_unchecked().as_mut().next = Some(ptr);
379            self.dummy.prev = Some(ptr);
380
381            entry.sequence = SEQUENCER.with(|s| s.borrow_mut().alloc());
382        }
383    }
384}
385
386impl<K, V, S, A> Drop for LruCache<K, V, S, A>
387where
388    K: Hash + Eq,
389    S: BuildHasher + Send + Sync + 'static,
390    A: Clone + Allocator,
391{
392    fn drop(&mut self) {
393        self.clear()
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unbounded() {
        let cache: LruCache<i32, &str> = LruCache::unbounded();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_unbounded_with_hasher_in() {
        let cache: LruCache<i32, &str, RandomState, Global> =
            LruCache::unbounded_with_hasher_in(RandomState::default(), Global);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_put() {
        let mut cache = LruCache::unbounded();

        // Put new entry
        assert_eq!(cache.put(1, "one"), None);
        assert_eq!(cache.len(), 1);

        // Update existing entry
        assert_eq!(cache.put(1, "ONE"), Some("one"));
        assert_eq!(cache.len(), 1);

        // Multiple entries
        assert_eq!(cache.put(2, "two"), None);
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_put_existing_refreshes_lru_order() {
        let mut cache = LruCache::unbounded();
        cache.put(1, "one");
        cache.put(2, "two");

        // Overwriting key 1 also marks it as most recently used.
        cache.put(1, "ONE");

        let (key, value, _) = cache.pop_with_sequence(u64::MAX).unwrap();
        assert_eq!((key, value), (2, "two"));
        let (key, value, _) = cache.pop_with_sequence(u64::MAX).unwrap();
        assert_eq!((key, value), (1, "ONE"));
        assert!(cache.is_empty());
    }

    #[test]
    fn test_remove() {
        let mut cache = LruCache::unbounded();

        // Remove non-existent key
        assert_eq!(cache.remove(&1), None);

        // Remove existing key
        cache.put(1, "one");
        assert_eq!(cache.remove(&1), Some("one"));
        assert!(cache.is_empty());

        // Remove already removed key
        assert_eq!(cache.remove(&1), None);

        // Multiple entries
        cache.put(1, "one");
        cache.put(2, "two");
        cache.put(3, "three");
        assert_eq!(cache.remove(&2), Some("two"));
        assert_eq!(cache.len(), 2);
        assert!(!cache.contains(&2));
    }

    #[test]
    fn test_get() {
        let mut cache = LruCache::unbounded();

        // Get non-existent key
        assert_eq!(cache.get(&1), None);

        // Get existing key
        cache.put(1, "one");
        assert_eq!(cache.get(&1), Some(&"one"));

        // Check LRU order updated after get
        cache.put(2, "two");
        let _ = cache.get(&1); // Moves 1 to most recently used

        // Verify LRU order by using pop_with_sequence
        let (key, _, _) = cache.pop_with_sequence(u64::MAX).unwrap();
        assert_eq!(key, 2); // key 2 should be least recently used
    }

    #[test]
    fn test_get_mut() {
        let mut cache = LruCache::unbounded();

        // Get_mut non-existent key
        assert_eq!(cache.get_mut(&1), None);

        // Get_mut and modify existing key
        cache.put(1, String::from("one"));
        {
            let val = cache.get_mut(&1).unwrap();
            *val = String::from("ONE");
        }
        assert_eq!(cache.get(&1), Some(&String::from("ONE")));

        // Check LRU order updated after get_mut
        cache.put(2, String::from("two"));
        let _ = cache.get_mut(&1); // Moves 1 to most recently used

        // Verify LRU order by using pop_with_sequence
        let (key, _, _) = cache.pop_with_sequence(u64::MAX).unwrap();
        assert_eq!(key, 2); // key 2 should be least recently used
    }

    #[test]
    fn test_peek() {
        let mut cache = LruCache::unbounded();

        // Peek non-existent key
        assert_eq!(cache.peek(&1), None);

        // Peek existing key
        cache.put(1, "one");
        cache.put(2, "two");
        assert_eq!(cache.peek(&1), Some(&"one"));

        // Verify LRU order NOT updated after peek
        let (key, _, _) = cache.pop_with_sequence(u64::MAX).unwrap();
        assert_eq!(key, 1); // key 1 should still be least recently used
    }

    #[test]
    fn test_peek_mut() {
        let mut cache = LruCache::unbounded();

        // Peek_mut non-existent key
        assert_eq!(cache.peek_mut(&1), None);

        // Peek_mut and modify existing key
        cache.put(1, String::from("one"));
        cache.put(2, String::from("two"));
        {
            let val = cache.peek_mut(&1).unwrap();
            *val = String::from("ONE");
        }
        assert_eq!(cache.peek(&1), Some(&String::from("ONE")));

        // Verify LRU order NOT updated after peek_mut
        let (key, _, _) = cache.pop_with_sequence(u64::MAX).unwrap();
        assert_eq!(key, 1); // key 1 should still be least recently used
    }

    #[test]
    fn test_contains() {
        let mut cache = LruCache::unbounded();

        // Contains on empty cache
        assert!(!cache.contains(&1));

        // Contains after put
        cache.put(1, "one");
        assert!(cache.contains(&1));

        // Contains after remove
        cache.remove(&1);
        assert!(!cache.contains(&1));
    }

    #[test]
    fn test_len_and_is_empty() {
        let mut cache = LruCache::unbounded();

        // Empty cache
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);

        // Non-empty cache
        cache.put(1, "one");
        assert!(!cache.is_empty());
        assert_eq!(cache.len(), 1);

        // After multiple operations
        cache.put(2, "two");
        assert_eq!(cache.len(), 2);
        cache.remove(&1);
        assert_eq!(cache.len(), 1);
        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_clear() {
        let mut cache = LruCache::unbounded();

        // Clear empty cache
        cache.clear();
        assert!(cache.is_empty());

        // Clear non-empty cache
        cache.put(1, "one");
        cache.put(2, "two");
        cache.put(3, "three");
        assert_eq!(cache.len(), 3);
        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert!(!cache.contains(&1));
        assert!(!cache.contains(&2));
        assert!(!cache.contains(&3));
    }

    #[test]
    fn test_lru_behavior() {
        let mut cache = LruCache::unbounded();

        // Insert in order
        cache.put(1, "one");
        cache.put(2, "two");
        cache.put(3, "three");

        // Manipulate LRU order
        let _ = cache.get(&1); // Moves 1 to most recently used

        // Check order: 2->3->1
        let (key, _, _) = cache.pop_with_sequence(u64::MAX).unwrap();
        assert_eq!(key, 2);
        let (key, _, _) = cache.pop_with_sequence(u64::MAX).unwrap();
        assert_eq!(key, 3);
        let (key, _, _) = cache.pop_with_sequence(u64::MAX).unwrap();
        assert_eq!(key, 1);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_pop_with_sequence_threshold() {
        let mut cache = LruCache::unbounded();

        // Nothing to pop from an empty cache.
        assert!(cache.pop_with_sequence(u64::MAX).is_none());

        cache.put(1, "one");

        // No entry has a sequence below 0, so this threshold never pops anything.
        assert!(cache.pop_with_sequence(0).is_none());
        assert_eq!(cache.len(), 1);

        // A popped entry's sequence is strictly below the threshold that allowed the pop.
        let (key, value, sequence) = cache.pop_with_sequence(u64::MAX).unwrap();
        assert_eq!((key, value), (1, "one"));
        assert!(sequence < u64::MAX);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_values_are_released() {
        use std::rc::Rc;

        let tracker = Rc::new(());
        let mut cache = LruCache::unbounded();
        cache.put(1, Rc::clone(&tracker));
        cache.put(2, Rc::clone(&tracker));
        cache.put(3, Rc::clone(&tracker));
        assert_eq!(Rc::strong_count(&tracker), 4);

        // Overwriting hands the old value back to the caller instead of leaking it.
        drop(cache.put(1, Rc::clone(&tracker)));
        assert_eq!(Rc::strong_count(&tracker), 4);

        // `remove` and `pop_with_sequence` hand their values back to the caller.
        drop(cache.remove(&2));
        drop(cache.pop_with_sequence(u64::MAX));
        assert_eq!(Rc::strong_count(&tracker), 2);

        // Dropping the cache releases whatever is still inside.
        drop(cache);
        assert_eq!(Rc::strong_count(&tracker), 1);
    }
}