risingwave_common/cache.rs

1//  Copyright 2025 RisingWave Labs
2//
3//  Licensed under the Apache License, Version 2.0 (the "License");
4//  you may not use this file except in compliance with the License.
5//  You may obtain a copy of the License at
6//
7//  http://www.apache.org/licenses/LICENSE-2.0
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14//
15// Copyright (c) 2011-present, Facebook, Inc.  All rights reserved.
16// This source code is licensed under both the GPLv2 (found in the
17// COPYING file in the root directory) and Apache 2.0 License
18// (found in the LICENSE.Apache file in the root directory).
19
//! `LruCache` implementation ported from github.com/facebook/rocksdb. `LruCache` is thread-safe:
//! every operation on the cache is protected by a per-shard mutex.
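//!
//! A minimal usage sketch (marked `ignore`; the hash here is computed with `DefaultHasher` purely
//! for illustration, and the key/value types are arbitrary):
//!
//! ```ignore
//! use std::collections::hash_map::DefaultHasher;
//! use std::hash::{Hash, Hasher};
//! use std::sync::Arc;
//!
//! use risingwave_common::cache::{CachePriority, LruCache};
//!
//! // 4 shards, a total capacity of 1024 "charge" units, no high-priority pool.
//! let cache: Arc<LruCache<u64, String>> = Arc::new(LruCache::new(4, 1024, 0));
//!
//! let key = 42u64;
//! let mut hasher = DefaultHasher::new();
//! key.hash(&mut hasher);
//! let hash = hasher.finish();
//!
//! // `insert` charges 1 unit and returns a `CacheableEntry` that pins the value in the cache.
//! let entry = cache.insert(key, hash, 1, "value".to_owned(), CachePriority::Low);
//! assert_eq!(*entry, "value");
//! drop(entry); // dropping the last external reference moves the entry onto the LRU list
//!
//! assert!(cache.lookup(hash, &key).is_some());
//! ```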
22use std::collections::HashMap;
23use std::error::Error;
24use std::future::Future;
25use std::hash::Hash;
26use std::ops::Deref;
27use std::ptr;
28use std::ptr::null_mut;
29use std::sync::Arc;
30use std::sync::atomic::{AtomicUsize, Ordering};
31
32use futures::FutureExt;
33use parking_lot::Mutex;
34use tokio::sync::oneshot::error::RecvError;
35use tokio::sync::oneshot::{Receiver, Sender, channel};
36use tokio::task::JoinHandle;
37
38const IN_CACHE: u8 = 1;
39const REVERSE_IN_CACHE: u8 = !IN_CACHE;
40
41#[cfg(debug_assertions)]
42const IN_LRU: u8 = 1 << 1;
43#[cfg(debug_assertions)]
44const REVERSE_IN_LRU: u8 = !IN_LRU;
45const IS_HIGH_PRI: u8 = 1 << 2;
46const IN_HIGH_PRI_POOL: u8 = 1 << 3;
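
// Flag bit layout of `LruHandle::flags` (only the bits defined above):
//   IN_CACHE         = 0b0001: the handle is present in the hash table
//   IN_LRU (debug)   = 0b0010: the handle is linked into the LRU list
//   IS_HIGH_PRI      = 0b0100: the entry was inserted with `CachePriority::High`
//   IN_HIGH_PRI_POOL = 0b1000: the entry currently sits in the high-priority pool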
47
48pub trait LruKey: Eq + Send + Hash {}
49impl<T: Eq + Send + Hash> LruKey for T {}
50
51pub trait LruValue: Send + Sync {}
52impl<T: Send + Sync> LruValue for T {}
53
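/// Eviction priority of an entry. When a shard is configured with a non-zero high-priority ratio,
/// `High` entries are kept in a dedicated pool at the most recently used end of the LRU list, so
/// low-priority insertions evict other low-priority (or overflowed) entries first.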
54#[derive(Clone, Copy, Eq, PartialEq)]
55pub enum CachePriority {
56    High,
57    Low,
58}
59
/// An entry is a variable-length heap-allocated structure.
/// Entries are referenced by the cache and/or by external entities.
/// The cache keeps all its entries in a hash table. Some of them
/// are also stored on the LRU list.
///
/// `LruHandle` can be in these states:
/// 1. Referenced externally AND in the hash table.
///    In this case the entry is *not* in the LRU list
///    (`refs` >= 1 && `in_cache` == true).
/// 2. Not referenced externally AND in the hash table.
///    In this case the entry is in the LRU list and can be freed
///    (`refs` == 0 && `in_cache` == true).
/// 3. Referenced externally AND not in the hash table.
///    In this case the entry is in neither the LRU list nor the hash table.
///    The entry can be freed when `refs` becomes 0
///    (`refs` >= 1 && `in_cache` == false).
///
/// All newly created `LruHandle`s are in state 1. If you call
/// `LruCacheShard::release` on an entry in state 1, it will go into state 2.
/// To move from state 1 to state 3, either call `LruCacheShard::erase` or
/// `LruCacheShard::insert` with the same key (but possibly a different value).
/// To move from state 2 to state 1, use `LruCacheShard::lookup`.
/// Before destruction, make sure that no handle is in state 1. This means that
/// every successful `LruCacheShard::lookup`/`LruCacheShard::insert` must have a
/// matching `LruCache::release` (to move into state 2) or `LruCacheShard::erase`
/// (to move into state 3).
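///
/// In terms of the public `LruCache` API, a sketch of the transitions (`cache`, `key`, `hash`,
/// and `value` are assumed to be in scope; marked `ignore` since it is illustrative only):
///
/// ```ignore
/// let entry = cache.insert(key, hash, 1, value, CachePriority::Low); // state 1
/// drop(entry);                   // release: state 1 -> state 2 (entry now on the LRU list)
/// let entry = cache.lookup(hash, &key).unwrap(); // state 2 -> state 1
/// cache.erase(hash, &key);       // still referenced by `entry`: state 1 -> state 3
/// drop(entry);                   // last reference gone: the handle is freed or recycled
/// ```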
86pub struct LruHandle<K: LruKey, T: LruValue> {
87    /// next element in the linked-list of hash bucket, only used by hash-table.
88    next_hash: *mut LruHandle<K, T>,
89
90    /// next element in LRU linked list
91    next: *mut LruHandle<K, T>,
92
93    /// prev element in LRU linked list
94    prev: *mut LruHandle<K, T>,
95
    /// While the handle is in use, the field is `Some(...)`; once the handle has been cleared and
    /// recycled, the field is `None`.
98    kv: Option<(K, T)>,
99    hash: u64,
100    charge: usize,
101
    /// The count of external references. If `refs > 0`, the handle is not in the LRU list, and
    /// when `refs == 0`, the handle must either be in the LRU list or have been recycled.
104    refs: u32,
105    flags: u8,
106}
107
108impl<K: LruKey, T: LruValue> Default for LruHandle<K, T> {
109    fn default() -> Self {
110        Self {
111            next_hash: null_mut(),
112            next: null_mut(),
113            prev: null_mut(),
114            kv: None,
115            hash: 0,
116            charge: 0,
117            refs: 0,
118            flags: 0,
119        }
120    }
121}
122
123impl<K: LruKey, T: LruValue> LruHandle<K, T> {
124    pub fn new(key: K, value: T, hash: u64, charge: usize) -> Self {
125        let mut ret = Self::default();
126        ret.init(key, value, hash, charge);
127        ret
128    }
129
130    pub fn init(&mut self, key: K, value: T, hash: u64, charge: usize) {
131        self.next_hash = null_mut();
132        self.prev = null_mut();
133        self.next = null_mut();
134        self.kv = Some((key, value));
135        self.hash = hash;
136        self.charge = charge;
137        self.flags = 0;
138        self.refs = 0;
139    }
140
    /// Set the `in_cache` bit in the flags.
    ///
    /// Since only `in_cache` reflects whether the handle is present in the hash table, this method
    /// should only be called from the hash table's methods. Whenever the handle enters the hash
    /// table we call `set_in_cache(true)`, and whenever the handle leaves the hash table we call
    /// `set_in_cache(false)`.
147    fn set_in_cache(&mut self, in_cache: bool) {
148        if in_cache {
149            self.flags |= IN_CACHE;
150        } else {
151            self.flags &= REVERSE_IN_CACHE;
152        }
153    }
154
155    fn is_high_priority(&self) -> bool {
156        (self.flags & IS_HIGH_PRI) > 0
157    }
158
159    fn set_high_priority(&mut self, high_priority: bool) {
160        if high_priority {
161            self.flags |= IS_HIGH_PRI;
162        } else {
163            self.flags &= !IS_HIGH_PRI;
164        }
165    }
166
167    fn set_in_high_pri_pool(&mut self, in_high_pri_pool: bool) {
168        if in_high_pri_pool {
169            self.flags |= IN_HIGH_PRI_POOL;
170        } else {
171            self.flags &= !IN_HIGH_PRI_POOL;
172        }
173    }
174
175    fn is_in_high_pri_pool(&self) -> bool {
176        (self.flags & IN_HIGH_PRI_POOL) > 0
177    }
178
179    fn add_ref(&mut self) {
180        self.refs += 1;
181    }
182
183    fn add_multi_refs(&mut self, ref_count: u32) {
184        self.refs += ref_count;
185    }
186
187    fn unref(&mut self) -> bool {
188        debug_assert!(self.refs > 0);
189        self.refs -= 1;
190        self.refs == 0
191    }
192
193    fn has_refs(&self) -> bool {
194        self.refs > 0
195    }
196
    /// Test whether the handle is in the cache. Being `in cache` is equivalent to the handle
    /// being present in the hash table.
199    fn is_in_cache(&self) -> bool {
200        (self.flags & IN_CACHE) > 0
201    }
202
203    unsafe fn get_key(&self) -> &K {
204        unsafe {
205            debug_assert!(self.kv.is_some());
206            &self.kv.as_ref().unwrap_unchecked().0
207        }
208    }
209
210    unsafe fn get_value(&self) -> &T {
211        unsafe {
212            debug_assert!(self.kv.is_some());
213            &self.kv.as_ref().unwrap_unchecked().1
214        }
215    }
216
217    unsafe fn is_same_key(&self, key: &K) -> bool {
218        unsafe {
219            debug_assert!(self.kv.is_some());
220            self.kv.as_ref().unwrap_unchecked().0.eq(key)
221        }
222    }
223
224    unsafe fn take_kv(&mut self) -> (K, T) {
225        unsafe {
226            debug_assert!(self.kv.is_some());
227            self.kv.take().unwrap_unchecked()
228        }
229    }
230
231    #[cfg(debug_assertions)]
232    fn is_in_lru(&self) -> bool {
233        (self.flags & IN_LRU) > 0
234    }
235
236    #[cfg(debug_assertions)]
237    fn set_in_lru(&mut self, in_lru: bool) {
238        if in_lru {
239            self.flags |= IN_LRU;
240        } else {
241            self.flags &= REVERSE_IN_LRU;
242        }
243    }
244}
245
246unsafe impl<K: LruKey, T: LruValue> Send for LruHandle<K, T> {}
247
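/// A hash table with separate chaining: `list` is a power-of-two array of buckets, and each bucket
/// holds a singly-linked chain of handles threaded through `next_hash`. A handle with hash `h`
/// lives in bucket `h & (list.len() - 1)`, and the table grows (see `resize`) once `elems` exceeds
/// the number of buckets.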
248pub struct LruHandleTable<K: LruKey, T: LruValue> {
249    list: Vec<*mut LruHandle<K, T>>,
250    elems: usize,
251}
252
253impl<K: LruKey, T: LruValue> LruHandleTable<K, T> {
254    fn new() -> Self {
255        Self {
256            list: vec![null_mut(); 16],
257            elems: 0,
258        }
259    }
260
    // A utility method used only internally by this struct.
262    unsafe fn find_pointer(
263        &self,
264        idx: usize,
265        key: &K,
266    ) -> (*mut LruHandle<K, T>, *mut LruHandle<K, T>) {
267        unsafe {
268            let mut ptr = self.list[idx];
269            let mut prev = null_mut();
270            while !ptr.is_null() && !(*ptr).is_same_key(key) {
271                prev = ptr;
272                ptr = (*ptr).next_hash;
273            }
274            (prev, ptr)
275        }
276    }
277
278    unsafe fn remove(&mut self, hash: u64, key: &K) -> *mut LruHandle<K, T> {
279        unsafe {
280            debug_assert!(self.list.len().is_power_of_two());
281            let idx = (hash as usize) & (self.list.len() - 1);
282            let (prev, ptr) = self.find_pointer(idx, key);
283            if ptr.is_null() {
284                return null_mut();
285            }
286            debug_assert!((*ptr).is_in_cache());
287            (*ptr).set_in_cache(false);
288            if prev.is_null() {
289                self.list[idx] = (*ptr).next_hash;
290            } else {
291                (*prev).next_hash = (*ptr).next_hash;
292            }
293            self.elems -= 1;
294            ptr
295        }
296    }
297
    /// Insert a handle into the hash table. If the key already exists, return the handle holding
    /// the previous value.
300    unsafe fn insert(&mut self, hash: u64, h: *mut LruHandle<K, T>) -> *mut LruHandle<K, T> {
301        unsafe {
302            debug_assert!(!h.is_null());
303            debug_assert!(!(*h).is_in_cache());
304            (*h).set_in_cache(true);
305            debug_assert!(self.list.len().is_power_of_two());
306            let idx = (hash as usize) & (self.list.len() - 1);
307            let (prev, ptr) = self.find_pointer(idx, (*h).get_key());
308            if prev.is_null() {
309                self.list[idx] = h;
310            } else {
311                (*prev).next_hash = h;
312            }
313
314            if !ptr.is_null() {
315                debug_assert!((*ptr).is_same_key((*h).get_key()));
316                debug_assert!((*ptr).is_in_cache());
                // The replaced handle is marked as no longer in the cache.
318                (*ptr).set_in_cache(false);
319                (*h).next_hash = (*ptr).next_hash;
320                return ptr;
321            }
322
323            (*h).next_hash = ptr;
324
325            self.elems += 1;
326            if self.elems > self.list.len() {
327                self.resize();
328            }
329            null_mut()
330        }
331    }
332
333    unsafe fn lookup(&self, hash: u64, key: &K) -> *mut LruHandle<K, T> {
334        unsafe {
335            debug_assert!(self.list.len().is_power_of_two());
336            let idx = (hash as usize) & (self.list.len() - 1);
337            let (_, ptr) = self.find_pointer(idx, key);
338            ptr
339        }
340    }
341
342    unsafe fn resize(&mut self) {
343        unsafe {
344            let mut l = std::cmp::max(16, self.list.len());
345            let next_capacity = self.elems * 3 / 2;
346            while l < next_capacity {
347                l <<= 1;
348            }
349            let mut count = 0;
350            let mut new_list = vec![null_mut(); l];
351            for head in self.list.drain(..) {
352                let mut handle = head;
353                while !handle.is_null() {
354                    let idx = (*handle).hash as usize & (l - 1);
355                    let next = (*handle).next_hash;
356                    (*handle).next_hash = new_list[idx];
357                    new_list[idx] = handle;
358                    handle = next;
359                    count += 1;
360                }
361            }
362            assert_eq!(count, self.elems);
363            self.list = new_list;
364        }
365    }
366
367    unsafe fn for_all<F>(&self, f: &mut F)
368    where
369        F: FnMut(&K, &T),
370    {
371        unsafe {
372            for idx in 0..self.list.len() {
373                let mut ptr = self.list[idx];
374                while !ptr.is_null() {
375                    f((*ptr).get_key(), (*ptr).get_value());
376                    ptr = (*ptr).next_hash;
377                }
378            }
379        }
380    }
381}
382
383type RequestQueue<K, T> = Vec<Sender<CacheableEntry<K, T>>>;
384pub struct LruCacheShard<K: LruKey, T: LruValue> {
    /// The dummy head node of a circular linked list. The list is the LRU list, holding the cache
    /// handles that are not referenced externally. `lru.prev` points to the most recently used
    /// entry and `lru.next` to the least recently used one. Whenever an insertion would push usage
    /// above the capacity, we evict starting from `lru.next`.
389    lru: Box<LruHandle<K, T>>,
390    low_priority_head: *mut LruHandle<K, T>,
391    high_priority_pool_capacity: usize,
392    high_priority_pool_usage: usize,
393    table: LruHandleTable<K, T>,
394    // TODO: may want to use an atomic object linked list shared by all shards.
395    object_pool: Vec<Box<LruHandle<K, T>>>,
396    write_request: HashMap<K, RequestQueue<K, T>>,
397    lru_usage: Arc<AtomicUsize>,
398    usage: Arc<AtomicUsize>,
399    capacity: usize,
400}
401
402unsafe impl<K: LruKey, T: LruValue> Send for LruCacheShard<K, T> {}
403
404impl<K: LruKey, T: LruValue> LruCacheShard<K, T> {
    // `high_priority_ratio_percent` is a percentage of the capacity: 100 means 100%. For example,
    // with `capacity = 256` and a ratio of 40, the high-priority pool capacity is
    // `256 * 40 / 100 = 102`.
406    fn new_with_priority_pool(capacity: usize, high_priority_ratio_percent: usize) -> Self {
407        let mut lru = Box::<LruHandle<K, T>>::default();
408        lru.prev = lru.as_mut();
409        lru.next = lru.as_mut();
410        let mut object_pool = Vec::with_capacity(DEFAULT_OBJECT_POOL_SIZE);
411        for _ in 0..DEFAULT_OBJECT_POOL_SIZE {
412            object_pool.push(Box::default());
413        }
414        Self {
415            capacity,
416            lru_usage: Arc::new(AtomicUsize::new(0)),
417            usage: Arc::new(AtomicUsize::new(0)),
418            object_pool,
419            low_priority_head: lru.as_mut(),
420            high_priority_pool_capacity: high_priority_ratio_percent * capacity / 100,
421            lru,
422            table: LruHandleTable::new(),
423            write_request: HashMap::with_capacity(16),
424            high_priority_pool_usage: 0,
425        }
426    }
427
428    unsafe fn lru_remove(&mut self, e: *mut LruHandle<K, T>) {
429        unsafe {
430            debug_assert!(!e.is_null());
431            #[cfg(debug_assertions)]
432            {
433                assert!((*e).is_in_lru());
434                (*e).set_in_lru(false);
435            }
436
437            if ptr::eq(e, self.low_priority_head) {
438                self.low_priority_head = (*e).prev;
439            }
440
441            (*(*e).next).prev = (*e).prev;
442            (*(*e).prev).next = (*e).next;
443            (*e).prev = null_mut();
444            (*e).next = null_mut();
445            if (*e).is_in_high_pri_pool() {
446                debug_assert!(self.high_priority_pool_usage >= (*e).charge);
447                self.high_priority_pool_usage -= (*e).charge;
448            }
449            self.lru_usage.fetch_sub((*e).charge, Ordering::Relaxed);
450        }
451    }
452
    // Insert the entry into the LRU list, at the most recently used end of its priority region.
454    unsafe fn lru_insert(&mut self, e: *mut LruHandle<K, T>) {
455        unsafe {
456            debug_assert!(!e.is_null());
457            let entry = &mut (*e);
458            #[cfg(debug_assertions)]
459            {
460                assert!(!(*e).is_in_lru());
461                (*e).set_in_lru(true);
462            }
463
464            if self.high_priority_pool_capacity > 0 && entry.is_high_priority() {
465                entry.set_in_high_pri_pool(true);
466                entry.next = self.lru.as_mut();
467                entry.prev = self.lru.prev;
468                (*entry.prev).next = e;
469                (*entry.next).prev = e;
470                self.high_priority_pool_usage += (*e).charge;
471                self.maintain_pool_size();
472            } else {
473                entry.set_in_high_pri_pool(false);
474                entry.next = (*self.low_priority_head).next;
475                entry.prev = self.low_priority_head;
476                (*entry.next).prev = e;
477                (*entry.prev).next = e;
478                self.low_priority_head = e;
479            }
480            self.lru_usage.fetch_add((*e).charge, Ordering::Relaxed);
481        }
482    }
483
484    unsafe fn maintain_pool_size(&mut self) {
485        unsafe {
486            while self.high_priority_pool_usage > self.high_priority_pool_capacity {
                // Overflow the oldest entry of the high-priority pool into the low-priority pool.
488                self.low_priority_head = (*self.low_priority_head).next;
489                (*self.low_priority_head).set_in_high_pri_pool(false);
490                self.high_priority_pool_usage -= (*self.low_priority_head).charge;
491            }
492        }
493    }
494
495    unsafe fn evict_from_lru(&mut self, charge: usize, last_reference_list: &mut Vec<(K, T)>) {
496        unsafe {
            // TODO: may want to optimize by loading the atomic usage only once at the beginning
            // and storing it only once at the end.
499            while self.usage.load(Ordering::Relaxed) + charge > self.capacity
500                && !std::ptr::eq(self.lru.next, self.lru.as_mut())
501            {
502                let old_ptr = self.lru.next;
503                self.table.remove((*old_ptr).hash, (*old_ptr).get_key());
504                self.lru_remove(old_ptr);
505                let (key, value) = self.clear_handle(old_ptr);
506                last_reference_list.push((key, value));
507            }
508        }
509    }
510
    /// Clear an unreferenced handle's key/value and try to recycle the handle object.
512    unsafe fn clear_handle(&mut self, h: *mut LruHandle<K, T>) -> (K, T) {
513        unsafe {
514            debug_assert!(!h.is_null());
515            debug_assert!((*h).kv.is_some());
516            #[cfg(debug_assertions)]
517            assert!(!(*h).is_in_lru());
518            debug_assert!(!(*h).is_in_cache());
519            debug_assert!(!(*h).has_refs());
520            self.usage.fetch_sub((*h).charge, Ordering::Relaxed);
521            let (key, value) = (*h).take_kv();
522            self.try_recycle_handle_object(h);
523            (key, value)
524        }
525    }
526
    /// Try to recycle a handle object if the object pool is not full.
    ///
    /// The handle's kv should already have been cleared.
530    unsafe fn try_recycle_handle_object(&mut self, h: *mut LruHandle<K, T>) {
531        unsafe {
532            let mut node = Box::from_raw(h);
533            if self.object_pool.len() < self.object_pool.capacity() {
534                node.next_hash = null_mut();
535                node.next = null_mut();
536                node.prev = null_mut();
537                debug_assert!(node.kv.is_none());
538                self.object_pool.push(node);
539            }
540        }
541    }
542
    /// Insert a new key-value pair into the cache and return the handle for the new entry.
544    unsafe fn insert(
545        &mut self,
546        key: K,
547        hash: u64,
548        charge: usize,
549        value: T,
550        priority: CachePriority,
551        last_reference_list: &mut Vec<(K, T)>,
552    ) -> *mut LruHandle<K, T> {
553        unsafe {
554            self.evict_from_lru(charge, last_reference_list);
555
556            let mut handle = match self.object_pool.pop() {
557                Some(mut h) => {
558                    h.init(key, value, hash, charge);
559                    h
560                }
561                _ => Box::new(LruHandle::new(key, value, hash, charge)),
562            };
563            if priority == CachePriority::High {
564                handle.set_high_priority(true);
565            }
566            let ptr = Box::into_raw(handle);
567            let old = self.table.insert(hash, ptr);
568            if !old.is_null() {
569                if let Some(data) = self.try_remove_cache_handle(old) {
570                    last_reference_list.push(data);
571                }
572            }
573            self.usage.fetch_add(charge, Ordering::Relaxed);
574            (*ptr).add_ref();
575            ptr
576        }
577    }
578
    /// Release an external reference to a handle.
    ///
    /// Return `Some((key, value))` if the entry is removed from the cache and should be dropped by
    /// the caller, and `None` if the entry is still held by the cache or by other references.
582    unsafe fn release(&mut self, h: *mut LruHandle<K, T>) -> Option<(K, T)> {
583        unsafe {
584            debug_assert!(!h.is_null());
585            // The handle should not be in lru before calling this method.
586            #[cfg(debug_assertions)]
587            assert!(!(*h).is_in_lru());
588            let last_reference = (*h).unref();
589            // If the handle is still referenced by someone else, do nothing and return.
590            if !last_reference {
591                return None;
592            }
593
594            // Keep the handle in lru list if it is still in the cache and the cache is not over-sized.
595            if (*h).is_in_cache() {
596                if self.usage.load(Ordering::Relaxed) <= self.capacity {
597                    self.lru_insert(h);
598                    return None;
599                }
600                // Remove the handle from table.
601                self.table.remove((*h).hash, (*h).get_key());
602            }
603
604            // Since the released handle was previously used externally, it must not be in LRU, and we
605            // don't need to remove it from lru.
606            #[cfg(debug_assertions)]
607            assert!(!(*h).is_in_lru());
608
609            let (key, value) = self.clear_handle(h);
610            Some((key, value))
611        }
612    }
613
614    unsafe fn lookup(&mut self, hash: u64, key: &K) -> *mut LruHandle<K, T> {
615        unsafe {
616            let e = self.table.lookup(hash, key);
617            if !e.is_null() {
                // If the handle previously had no external references, it must be in the LRU list,
                // so it is safe to remove it from the LRU.
620                if !(*e).has_refs() {
621                    self.lru_remove(e);
622                }
623                (*e).add_ref();
624            }
625            e
626        }
627    }
628
629    /// Erase a key from the cache.
630    unsafe fn erase(&mut self, hash: u64, key: &K) -> Option<(K, T)> {
631        unsafe {
632            let h = self.table.remove(hash, key);
633            if !h.is_null() {
634                self.try_remove_cache_handle(h)
635            } else {
636                None
637            }
638        }
639    }
640
    /// Try removing the handle from the cache if it is no longer referenced externally.
    ///
    /// This method can only be called on a handle that has just been removed from the hash table.
644    unsafe fn try_remove_cache_handle(&mut self, h: *mut LruHandle<K, T>) -> Option<(K, T)> {
645        unsafe {
646            debug_assert!(!h.is_null());
647            if !(*h).has_refs() {
                // Since the handle has just been removed from the hash table, it must either be in
                // the LRU list or referenced externally. We have checked that it is not referenced
                // externally, so it must be in the LRU list and it is safe to call `lru_remove`.
651                self.lru_remove(h);
652                let (key, value) = self.clear_handle(h);
653                return Some((key, value));
654            }
655            None
656        }
657    }
658
659    // Clears the content of the cache.
660    // This method is safe to use only if no cache entries are referenced outside.
661    unsafe fn clear(&mut self) {
662        unsafe {
663            while !std::ptr::eq(self.lru.next, self.lru.as_mut()) {
664                let handle = self.lru.next;
665                // `listener` should not be triggered here, for it doesn't listen to `clear`.
666                self.erase((*handle).hash, (*handle).get_key());
667            }
668        }
669    }
670
671    fn for_all<F>(&self, f: &mut F)
672    where
673        F: FnMut(&K, &T),
674    {
675        unsafe { self.table.for_all(f) };
676    }
677}
678
679impl<K: LruKey, T: LruValue> Drop for LruCacheShard<K, T> {
680    fn drop(&mut self) {
        // Since the shard is being dropped, there must be no cache entries referenced outside, so
        // it is safe to call `clear`.
683        unsafe {
684            self.clear();
685        }
686    }
687}
688
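/// Listener for cache events. A minimal sketch of an implementation (the listener type and the
/// metric it tracks are illustrative only, not part of this module):
///
/// ```ignore
/// use std::sync::Arc;
/// use std::sync::atomic::{AtomicUsize, Ordering};
///
/// #[derive(Default)]
/// struct EvictionCounter {
///     evicted_bytes: AtomicUsize,
/// }
///
/// impl LruCacheEventListener for EvictionCounter {
///     type K = u64;
///     type T = Vec<u8>;
///
///     fn on_release(&self, _key: u64, value: Vec<u8>) {
///         // Called when an entry is erased, or evicted by a newer insertion.
///         self.evicted_bytes.fetch_add(value.len(), Ordering::Relaxed);
///     }
/// }
///
/// let listener = Arc::new(EvictionCounter::default());
/// let cache = LruCache::with_event_listener(4, 1 << 20, 0, listener.clone());
/// ```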
689pub trait LruCacheEventListener: Send + Sync {
690    type K: LruKey;
691    type T: LruValue;
692
    /// `on_release` is called when a cache entry is erased or evicted by a newly inserted entry.
    ///
    /// Note: `on_release` is not triggered when the `LruCache` and its remaining entries are
    /// dropped.
697    fn on_release(&self, _key: Self::K, _value: Self::T) {}
698}
699
700pub struct LruCache<K: LruKey, T: LruValue> {
701    shards: Vec<Mutex<LruCacheShard<K, T>>>,
702    shard_usages: Vec<Arc<AtomicUsize>>,
703    shard_lru_usages: Vec<Arc<AtomicUsize>>,
704
705    listener: Option<Arc<dyn LruCacheEventListener<K = K, T = T>>>,
706}
707
// We only need a small object pool because once the cache reaches its capacity limit, it will
// always release some object after inserting a new block.
710const DEFAULT_OBJECT_POOL_SIZE: usize = 1024;
711
712impl<K: LruKey, T: LruValue> LruCache<K, T> {
713    pub fn new(num_shards: usize, capacity: usize, high_priority_ratio: usize) -> Self {
714        Self::new_inner(num_shards, capacity, high_priority_ratio, None)
715    }
716
717    pub fn with_event_listener(
718        num_shards: usize,
719        capacity: usize,
720        high_priority_ratio: usize,
721        listener: Arc<dyn LruCacheEventListener<K = K, T = T>>,
722    ) -> Self {
723        Self::new_inner(num_shards, capacity, high_priority_ratio, Some(listener))
724    }
725
726    fn new_inner(
727        num_shards: usize,
728        capacity: usize,
729        high_priority_ratio: usize,
730        listener: Option<Arc<dyn LruCacheEventListener<K = K, T = T>>>,
731    ) -> Self {
732        let mut shards = Vec::with_capacity(num_shards);
733        let per_shard = capacity / num_shards;
734        let mut shard_usages = Vec::with_capacity(num_shards);
735        let mut shard_lru_usages = Vec::with_capacity(num_shards);
736        for _ in 0..num_shards {
737            let shard = LruCacheShard::new_with_priority_pool(per_shard, high_priority_ratio);
738            shard_usages.push(shard.usage.clone());
739            shard_lru_usages.push(shard.lru_usage.clone());
740            shards.push(Mutex::new(shard));
741        }
742        Self {
743            shards,
744            shard_usages,
745            shard_lru_usages,
746            listener,
747        }
748    }
749
750    pub fn contains(self: &Arc<Self>, hash: u64, key: &K) -> bool {
751        let shard = self.shards[self.shard(hash)].lock();
752        unsafe {
753            let ptr = shard.table.lookup(hash, key);
754            !ptr.is_null()
755        }
756    }
757
758    pub fn lookup(self: &Arc<Self>, hash: u64, key: &K) -> Option<CacheableEntry<K, T>> {
759        let mut shard = self.shards[self.shard(hash)].lock();
760        unsafe {
761            let ptr = shard.lookup(hash, key);
762            if ptr.is_null() {
763                return None;
764            }
765            let entry = CacheableEntry {
766                cache: self.clone(),
767                handle: ptr,
768            };
769            Some(entry)
770        }
771    }
772
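    /// Look up `key`; on a miss, register a pending write request so that concurrent lookups of
    /// the same key wait for one fetch instead of fetching the value repeatedly.
    ///
    /// A sketch of the calling protocol (`fetch` and `use_entry` are illustrative; on a fetch
    /// failure the caller must call `clear_pending_request` so that waiters are not stuck):
    ///
    /// ```ignore
    /// match cache.lookup_for_request(hash, key.clone()) {
    ///     LookupResult::Cached(entry) => use_entry(entry),
    ///     LookupResult::WaitPendingRequest(receiver) => {
    ///         // Another caller is already fetching this key; wait for it to insert.
    ///         use_entry(receiver.await?)
    ///     }
    ///     LookupResult::Miss => match fetch(&key).await {
    ///         Ok((value, charge)) => {
    ///             // Inserting also wakes up all `WaitPendingRequest` waiters.
    ///             use_entry(cache.insert(key, hash, charge, value, CachePriority::Low))
    ///         }
    ///         Err(e) => {
    ///             cache.clear_pending_request(&key, hash);
    ///             return Err(e);
    ///         }
    ///     },
    /// }
    /// ```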
773    pub fn lookup_for_request(self: &Arc<Self>, hash: u64, key: K) -> LookupResult<K, T> {
774        let mut shard = self.shards[self.shard(hash)].lock();
775        unsafe {
776            let ptr = shard.lookup(hash, &key);
777            if !ptr.is_null() {
778                return LookupResult::Cached(CacheableEntry {
779                    cache: self.clone(),
780                    handle: ptr,
781                });
782            }
783            if let Some(que) = shard.write_request.get_mut(&key) {
784                let (tx, recv) = channel();
785                que.push(tx);
786                return LookupResult::WaitPendingRequest(recv);
787            }
788            shard.write_request.insert(key, vec![]);
789            LookupResult::Miss
790        }
791    }
792
793    unsafe fn release(&self, handle: *mut LruHandle<K, T>) {
794        unsafe {
795            debug_assert!(!handle.is_null());
796            let data = {
797                let mut shard = self.shards[self.shard((*handle).hash)].lock();
798                shard.release(handle)
799            };
            // Do not deallocate the data while holding the mutex.
801            if let Some((key, value)) = data
802                && let Some(listener) = &self.listener
803            {
804                listener.on_release(key, value);
805            }
806        }
807    }
808
809    unsafe fn inc_reference(&self, handle: *mut LruHandle<K, T>) {
810        unsafe {
811            let _shard = self.shards[self.shard((*handle).hash)].lock();
812            (*handle).refs += 1;
813        }
814    }
815
816    pub fn insert(
817        self: &Arc<Self>,
818        key: K,
819        hash: u64,
820        charge: usize,
821        value: T,
822        priority: CachePriority,
823    ) -> CacheableEntry<K, T> {
        // Collect evicted entries and pending waiters so that they are handled outside the lock
        // to avoid deadlock.
        let mut to_delete = vec![];
        let mut senders = vec![];
827        let handle = unsafe {
828            let mut shard = self.shards[self.shard(hash)].lock();
829            let pending_request = shard.write_request.remove(&key);
830            let ptr = shard.insert(key, hash, charge, value, priority, &mut to_delete);
831            debug_assert!(!ptr.is_null());
832            if let Some(mut que) = pending_request {
833                (*ptr).add_multi_refs(que.len() as u32);
834                senders = std::mem::take(&mut que);
835            }
836            CacheableEntry {
837                cache: self.clone(),
838                handle: ptr,
839            }
840        };
841        for sender in senders {
842            let _ = sender.send(CacheableEntry {
843                cache: self.clone(),
844                handle: handle.handle,
845            });
846        }
847
        // Do not deallocate the data while holding the mutex.
849        if let Some(listener) = &self.listener {
850            for (key, value) in to_delete {
851                listener.on_release(key, value);
852            }
853        }
854        handle
855    }
856
857    pub fn clear_pending_request(&self, key: &K, hash: u64) {
858        let mut shard = self.shards[self.shard(hash)].lock();
859        shard.write_request.remove(key);
860    }
861
862    pub fn erase(&self, hash: u64, key: &K) {
863        let data = unsafe {
864            let mut shard = self.shards[self.shard(hash)].lock();
865            shard.erase(hash, key)
866        };
        // Do not deallocate the data while holding the mutex.
868        if let Some((key, value)) = data
869            && let Some(listener) = &self.listener
870        {
871            listener.on_release(key, value);
872        }
873    }
874
875    pub fn get_memory_usage(&self) -> usize {
876        self.shard_usages
877            .iter()
878            .map(|x| x.load(Ordering::Relaxed))
879            .sum()
880    }
881
882    pub fn get_lru_usage(&self) -> usize {
883        self.shard_lru_usages
884            .iter()
885            .map(|x| x.load(Ordering::Relaxed))
886            .sum()
887    }
888
889    fn shard(&self, hash: u64) -> usize {
890        hash as usize % self.shards.len()
891    }
892
    /// # Safety
    ///
    /// This method is intended for read-only [`LruCache`]s. It locks one shard at a time so that
    /// iteration does not block reads on all shards at once.
    ///
    /// If another thread inserts entries concurrently, the iteration may observe an inconsistent
    /// view of the cache.
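    ///
    /// A sketch of collecting all cached keys (the key/value types are illustrative):
    ///
    /// ```ignore
    /// let mut keys = Vec::new();
    /// cache.for_all(|key: &u64, _value: &String| keys.push(*key));
    /// ```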
900    pub fn for_all<F>(&self, mut f: F)
901    where
902        F: FnMut(&K, &T),
903    {
904        for shard in &self.shards {
905            let shard = shard.lock();
906            shard.for_all(&mut f);
907        }
908    }
909
    /// # Safety
    ///
    /// This method can only be called when no cache entries are referenced outside.
913    pub fn clear(&self) {
914        for shard in &self.shards {
915            unsafe {
916                let mut shard = shard.lock();
917                shard.clear();
918            }
919        }
920    }
921}
922
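/// Guard used by `lookup_with_request_dedup`: if the fetch future fails or is dropped before
/// `mark_success` is called, dropping the guard clears the pending request for `key`, so queued
/// waiters receive an error instead of waiting forever and later lookups can issue a new fetch.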
923pub struct CleanCacheGuard<'a, K: LruKey + Clone + 'static, T: LruValue + 'static> {
924    cache: &'a Arc<LruCache<K, T>>,
925    key: Option<K>,
926    hash: u64,
927}
928
929impl<K: LruKey + Clone + 'static, T: LruValue + 'static> CleanCacheGuard<'_, K, T> {
930    fn mark_success(mut self) -> K {
931        self.key.take().unwrap()
932    }
933}
934
935impl<K: LruKey + Clone + 'static, T: LruValue + 'static> Drop for CleanCacheGuard<'_, K, T> {
936    fn drop(&mut self) {
937        if let Some(key) = self.key.as_ref() {
938            self.cache.clear_pending_request(key, self.hash);
939        }
940    }
941}
942
/// Awaiting `lookup_with_request_dedup` directly yields a `Result<CacheableEntry<K, T>, E>`.
/// If we do not want to block when the cache does not hit, we can instead inspect the returned
/// `LookupResponse`, which holds a `Receiver<CacheableEntry<K, T>>` or a
/// `JoinHandle<Result<CacheableEntry<K, T>, E>>` when the cache does not hit.
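///
/// A sketch of the common case where the response is simply awaited (`fetch_block`, `hash`, and
/// `key` are illustrative, and the error type is assumed to satisfy the trait bounds):
///
/// ```ignore
/// let entry = cache
///     .lookup_with_request_dedup(hash, key, CachePriority::Low, || async move {
///         // On a miss this future is spawned onto tokio; it must return the value and its charge.
///         fetch_block().await.map(|block| (block, 1))
///     })
///     .await?;
/// ```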
948pub enum LookupResponse<K: LruKey + Clone + 'static, T: LruValue + 'static, E> {
949    Invalid,
950    Cached(CacheableEntry<K, T>),
951    WaitPendingRequest(Receiver<CacheableEntry<K, T>>),
952    Miss(JoinHandle<Result<CacheableEntry<K, T>, E>>),
953}
954
955impl<K: LruKey + Clone + 'static, T: LruValue + 'static, E> Default for LookupResponse<K, T, E> {
956    fn default() -> Self {
957        Self::Invalid
958    }
959}
960
961impl<K: LruKey + Clone + 'static, T: LruValue + 'static, E: From<RecvError>> Future
962    for LookupResponse<K, T, E>
963{
964    type Output = Result<CacheableEntry<K, T>, E>;
965
966    fn poll(
967        mut self: std::pin::Pin<&mut Self>,
968        cx: &mut std::task::Context<'_>,
969    ) -> std::task::Poll<Self::Output> {
970        match &mut *self {
971            Self::Invalid => unreachable!(),
972            Self::Cached(_) => std::task::Poll::Ready(Ok(
973                must_match!(std::mem::take(&mut *self), Self::Cached(entry) => entry),
974            )),
975            Self::WaitPendingRequest(receiver) => {
976                receiver.poll_unpin(cx).map_err(|recv_err| recv_err.into())
977            }
978            Self::Miss(join_handle) => join_handle
979                .poll_unpin(cx)
980                .map(|join_result| join_result.unwrap()),
981        }
982    }
983}
984
/// `lookup_with_request_dedup` is only implemented for `'static` keys and values, as they must be
/// sent across tokio-spawned futures.
987impl<K: LruKey + Clone + 'static, T: LruValue + 'static> LruCache<K, T> {
988    pub fn lookup_with_request_dedup<F, E, VC>(
989        self: &Arc<Self>,
990        hash: u64,
991        key: K,
992        priority: CachePriority,
993        fetch_value: F,
994    ) -> LookupResponse<K, T, E>
995    where
996        F: FnOnce() -> VC,
997        E: Error + Send + 'static + From<RecvError>,
998        VC: Future<Output = Result<(T, usize), E>> + Send + 'static,
999    {
1000        match self.lookup_for_request(hash, key.clone()) {
1001            LookupResult::Cached(entry) => LookupResponse::Cached(entry),
1002            LookupResult::WaitPendingRequest(receiver) => {
1003                LookupResponse::WaitPendingRequest(receiver)
1004            }
1005            LookupResult::Miss => {
1006                let this = self.clone();
1007                let fetch_value = fetch_value();
1008                let key2 = key;
1009                let join_handle = tokio::spawn(async move {
1010                    let guard = CleanCacheGuard {
1011                        cache: &this,
1012                        key: Some(key2),
1013                        hash,
1014                    };
1015                    let (value, charge) = fetch_value.await?;
1016                    let key2 = guard.mark_success();
1017                    let entry = this.insert(key2, hash, charge, value, priority);
1018                    Ok(entry)
1019                });
1020                LookupResponse::Miss(join_handle)
1021            }
1022        }
1023    }
1024}
1025
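/// A reference-counted guard over a cached value. While at least one `CacheableEntry` for a key is
/// alive, the entry stays off the LRU list and cannot be evicted; dropping the last guard releases
/// the entry back to the LRU list (or frees it if it was erased or the cache is over capacity).
///
/// A sketch of the lifecycle (`cache`, `key`, and `hash` are assumed to be in scope):
///
/// ```ignore
/// let entry = cache.insert(key, hash, 1, "value".to_owned(), CachePriority::Low);
/// assert_eq!(entry.len(), 5); // `Deref<Target = T>` exposes the cached value
/// let other = entry.clone();  // cloning bumps the reference count under the shard lock
/// drop(entry);
/// drop(other); // last external reference dropped: the entry moves onto the LRU list
/// ```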
1026pub struct CacheableEntry<K: LruKey, T: LruValue> {
1027    cache: Arc<LruCache<K, T>>,
1028    handle: *mut LruHandle<K, T>,
1029}
1030
1031pub enum LookupResult<K: LruKey, T: LruValue> {
1032    Cached(CacheableEntry<K, T>),
1033    Miss,
1034    WaitPendingRequest(Receiver<CacheableEntry<K, T>>),
1035}
1036
1037unsafe impl<K: LruKey, T: LruValue> Send for CacheableEntry<K, T> {}
1038unsafe impl<K: LruKey, T: LruValue> Sync for CacheableEntry<K, T> {}
1039
1040impl<K: LruKey, T: LruValue> Deref for CacheableEntry<K, T> {
1041    type Target = T;
1042
1043    fn deref(&self) -> &Self::Target {
1044        unsafe { (*self.handle).get_value() }
1045    }
1046}
1047
1048impl<K: LruKey, T: LruValue> Drop for CacheableEntry<K, T> {
1049    fn drop(&mut self) {
1050        unsafe {
1051            self.cache.release(self.handle);
1052        }
1053    }
1054}
1055
1056impl<K: LruKey, T: LruValue> Clone for CacheableEntry<K, T> {
1057    fn clone(&self) -> Self {
1058        unsafe {
1059            self.cache.inc_reference(self.handle);
1060            CacheableEntry {
1061                cache: self.cache.clone(),
1062                handle: self.handle,
1063            }
1064        }
1065    }
1066}
1067
1068#[cfg(test)]
1069mod tests {
1070    use std::collections::hash_map::DefaultHasher;
1071    use std::hash::Hasher;
1072    use std::pin::Pin;
1073    use std::sync::atomic::AtomicBool;
1074    use std::sync::atomic::Ordering::Relaxed;
1075    use std::task::{Context, Poll};
1076
1077    use rand::rngs::SmallRng;
1078    use rand::{RngCore, SeedableRng};
1079    use tokio::sync::oneshot::error::TryRecvError;
1080
1081    use super::*;
1082
1083    pub struct Block {
1084        pub offset: u64,
1085        #[allow(dead_code)]
1086        pub sst: u64,
1087    }
1088
1089    #[test]
1090    fn test_cache_handle_basic() {
1091        let mut h = Box::new(LruHandle::new(1, 2, 0, 0));
1092        h.set_in_cache(true);
1093        assert!(h.is_in_cache());
1094        h.set_in_cache(false);
1095        assert!(!h.is_in_cache());
1096    }
1097
1098    #[test]
1099    fn test_cache_shard() {
1100        let cache = Arc::new(LruCache::<(u64, u64), Block>::new(4, 256, 0));
1101        assert_eq!(cache.shard(0), 0);
1102        assert_eq!(cache.shard(1), 1);
1103        assert_eq!(cache.shard(10), 2);
1104    }
1105
1106    #[test]
1107    fn test_cache_basic() {
1108        let cache = Arc::new(LruCache::<(u64, u64), Block>::new(2, 256, 0));
1109        let seed = 10244021u64;
1110        let mut rng = SmallRng::seed_from_u64(seed);
1111        for _ in 0..100000 {
1112            let block_offset = rng.next_u64() % 1024;
1113            let sst = rng.next_u64() % 1024;
1114            let mut hasher = DefaultHasher::new();
1115            sst.hash(&mut hasher);
1116            block_offset.hash(&mut hasher);
1117            let h = hasher.finish();
1118            if let Some(block) = cache.lookup(h, &(sst, block_offset)) {
1119                assert_eq!(block.offset, block_offset);
1120                drop(block);
1121                continue;
1122            }
1123            cache.insert(
1124                (sst, block_offset),
1125                h,
1126                1,
1127                Block {
1128                    offset: block_offset,
1129                    sst,
1130                },
1131                CachePriority::High,
1132            );
1133        }
1134        assert_eq!(256, cache.get_memory_usage());
1135    }
1136
1137    fn validate_lru_list(cache: &mut LruCacheShard<String, String>, keys: Vec<&str>) {
1138        unsafe {
1139            let mut lru: *mut LruHandle<String, String> = cache.lru.as_mut();
1140            for k in keys {
1141                lru = (*lru).next;
1142                assert!(
1143                    (*lru).is_same_key(&k.to_owned()),
1144                    "compare failed: {} vs {}, get value: {:?}",
1145                    (*lru).get_key(),
1146                    k,
1147                    (*lru).get_value()
1148                );
1149            }
1150        }
1151    }
1152
1153    fn create_cache(capacity: usize) -> LruCacheShard<String, String> {
1154        LruCacheShard::new_with_priority_pool(capacity, 0)
1155    }
1156
1157    fn lookup(cache: &mut LruCacheShard<String, String>, key: &str) -> bool {
1158        unsafe {
1159            let h = cache.lookup(0, &key.to_owned());
1160            let exist = !h.is_null();
1161            if exist {
1162                assert!((*h).is_same_key(&key.to_owned()));
1163                cache.release(h);
1164            }
1165            exist
1166        }
1167    }
1168
1169    fn insert_priority(
1170        cache: &mut LruCacheShard<String, String>,
1171        key: &str,
1172        value: &str,
1173        priority: CachePriority,
1174    ) {
1175        let mut free_list = vec![];
1176        unsafe {
1177            let handle = cache.insert(
1178                key.to_owned(),
1179                0,
1180                value.len(),
1181                value.to_owned(),
1182                priority,
1183                &mut free_list,
1184            );
1185            cache.release(handle);
1186        }
1187        free_list.clear();
1188    }
1189
1190    fn insert(cache: &mut LruCacheShard<String, String>, key: &str, value: &str) {
1191        insert_priority(cache, key, value, CachePriority::Low);
1192    }
1193
1194    #[test]
1195    fn test_basic_lru() {
1196        let mut cache = LruCacheShard::new_with_priority_pool(5, 40);
1197        let keys = vec!["a", "b", "c", "d", "e"];
1198        for &k in &keys {
1199            insert(&mut cache, k, k);
1200        }
1201        validate_lru_list(&mut cache, keys);
1202        for k in ["x", "y", "z"] {
1203            insert(&mut cache, k, k);
1204        }
1205        validate_lru_list(&mut cache, vec!["d", "e", "x", "y", "z"]);
1206        assert!(!lookup(&mut cache, "b"));
1207        assert!(lookup(&mut cache, "e"));
1208        validate_lru_list(&mut cache, vec!["d", "x", "y", "z", "e"]);
1209        assert!(lookup(&mut cache, "z"));
1210        validate_lru_list(&mut cache, vec!["d", "x", "y", "e", "z"]);
1211        unsafe {
1212            let h = cache.erase(0, &"x".to_owned());
1213            assert!(h.is_some());
1214            validate_lru_list(&mut cache, vec!["d", "y", "e", "z"]);
1215        }
1216        assert!(lookup(&mut cache, "d"));
1217        validate_lru_list(&mut cache, vec!["y", "e", "z", "d"]);
1218        insert(&mut cache, "u", "u");
1219        validate_lru_list(&mut cache, vec!["y", "e", "z", "d", "u"]);
1220        insert(&mut cache, "v", "v");
1221        validate_lru_list(&mut cache, vec!["e", "z", "d", "u", "v"]);
1222        insert_priority(&mut cache, "x", "x", CachePriority::High);
1223        validate_lru_list(&mut cache, vec!["z", "d", "u", "v", "x"]);
1224        assert!(lookup(&mut cache, "d"));
1225        validate_lru_list(&mut cache, vec!["z", "u", "v", "d", "x"]);
1226        insert(&mut cache, "y", "y");
1227        validate_lru_list(&mut cache, vec!["u", "v", "d", "y", "x"]);
1228        insert_priority(&mut cache, "z", "z", CachePriority::High);
1229        validate_lru_list(&mut cache, vec!["v", "d", "y", "x", "z"]);
1230        insert(&mut cache, "u", "u");
1231        validate_lru_list(&mut cache, vec!["d", "y", "u", "x", "z"]);
1232        insert_priority(&mut cache, "v", "v", CachePriority::High);
1233        validate_lru_list(&mut cache, vec!["y", "u", "x", "z", "v"]);
1234    }
1235
1236    #[test]
1237    fn test_reference_and_usage() {
1238        let mut cache = create_cache(5);
1239        insert(&mut cache, "k1", "a");
1240        assert_eq!(cache.usage.load(Ordering::Relaxed), 1);
1241        insert(&mut cache, "k0", "aa");
1242        assert_eq!(cache.usage.load(Ordering::Relaxed), 3);
1243        insert(&mut cache, "k1", "aa");
1244        assert_eq!(cache.usage.load(Ordering::Relaxed), 4);
1245        insert(&mut cache, "k2", "aa");
1246        assert_eq!(cache.usage.load(Ordering::Relaxed), 4);
1247        let mut free_list = vec![];
1248        validate_lru_list(&mut cache, vec!["k1", "k2"]);
1249        unsafe {
1250            let h1 = cache.lookup(0, &"k1".to_owned());
1251            assert!(!h1.is_null());
1252            let h2 = cache.lookup(0, &"k2".to_owned());
1253            assert!(!h2.is_null());
1254
1255            let h3 = cache.insert(
1256                "k3".to_owned(),
1257                0,
1258                2,
1259                "bb".to_owned(),
1260                CachePriority::High,
1261                &mut free_list,
1262            );
1263            assert_eq!(cache.usage.load(Ordering::Relaxed), 6);
1264            assert!(!h3.is_null());
1265            let h4 = cache.lookup(0, &"k1".to_owned());
1266            assert!(!h4.is_null());
1267
1268            cache.release(h1);
1269            assert_eq!(cache.usage.load(Ordering::Relaxed), 6);
1270            cache.release(h4);
1271            assert_eq!(cache.usage.load(Ordering::Relaxed), 4);
1272
1273            cache.release(h3);
1274            cache.release(h2);
1275
1276            validate_lru_list(&mut cache, vec!["k3", "k2"]);
1277        }
1278    }
1279
1280    #[test]
1281    fn test_update_referenced_key() {
1282        unsafe {
1283            let mut to_delete = vec![];
1284            let mut cache = create_cache(5);
1285            let insert_handle = cache.insert(
1286                "key".to_owned(),
1287                0,
1288                1,
1289                "old_value".to_owned(),
1290                CachePriority::High,
1291                &mut to_delete,
1292            );
1293            let old_entry = cache.lookup(0, &"key".to_owned());
1294            assert!(!old_entry.is_null());
1295            assert_eq!((*old_entry).get_value(), &"old_value".to_owned());
1296            assert_eq!((*old_entry).refs, 2);
1297            cache.release(insert_handle);
1298            assert_eq!((*old_entry).refs, 1);
1299            let insert_handle = cache.insert(
1300                "key".to_owned(),
1301                0,
1302                1,
1303                "new_value".to_owned(),
1304                CachePriority::Low,
1305                &mut to_delete,
1306            );
1307            assert!(!(*old_entry).is_in_cache());
1308            let new_entry = cache.lookup(0, &"key".to_owned());
1309            assert!(!new_entry.is_null());
1310            assert_eq!((*new_entry).get_value(), &"new_value".to_owned());
1311            assert_eq!((*new_entry).refs, 2);
1312            cache.release(insert_handle);
1313            assert_eq!((*new_entry).refs, 1);
1314
1315            assert!(!old_entry.is_null());
1316            assert_eq!((*old_entry).get_value(), &"old_value".to_owned());
1317            assert_eq!((*old_entry).refs, 1);
1318
1319            cache.release(new_entry);
1320            assert!((*new_entry).is_in_cache());
1321            #[cfg(debug_assertions)]
1322            assert!((*new_entry).is_in_lru());
1323
1324            // assert old value unchanged.
1325            assert!(!old_entry.is_null());
1326            assert_eq!((*old_entry).get_value(), &"old_value".to_owned());
1327            assert_eq!((*old_entry).refs, 1);
1328
1329            cache.release(old_entry);
1330            assert!(!(*old_entry).is_in_cache());
1331            assert!((*new_entry).is_in_cache());
1332            #[cfg(debug_assertions)]
1333            {
1334                assert!(!(*old_entry).is_in_lru());
1335                assert!((*new_entry).is_in_lru());
1336            }
1337        }
1338    }
1339
1340    #[test]
1341    fn test_release_stale_value() {
1342        unsafe {
1343            let mut to_delete = vec![];
1344            // The cache can only hold one handle
1345            let mut cache = create_cache(1);
1346            let insert_handle = cache.insert(
1347                "key".to_owned(),
1348                0,
1349                1,
1350                "old_value".to_owned(),
1351                CachePriority::High,
1352                &mut to_delete,
1353            );
1354            cache.release(insert_handle);
1355            let old_entry = cache.lookup(0, &"key".to_owned());
1356            assert!(!old_entry.is_null());
1357            assert_eq!((*old_entry).get_value(), &"old_value".to_owned());
1358            assert_eq!((*old_entry).refs, 1);
1359
1360            let insert_handle = cache.insert(
1361                "key".to_owned(),
1362                0,
1363                1,
1364                "new_value".to_owned(),
1365                CachePriority::High,
1366                &mut to_delete,
1367            );
1368            assert!(!(*old_entry).is_in_cache());
1369            let new_entry = cache.lookup(0, &"key".to_owned());
1370            assert!(!new_entry.is_null());
1371            assert_eq!((*new_entry).get_value(), &"new_value".to_owned());
1372            assert_eq!((*new_entry).refs, 2);
1373            cache.release(insert_handle);
1374            assert_eq!((*new_entry).refs, 1);
1375
            // The handles for the new and old values are both referenced.
1377            assert_eq!(2, cache.usage.load(Relaxed));
1378            assert_eq!(0, cache.lru_usage.load(Relaxed));
1379
1380            // Release the old handle, it will be cleared since the cache capacity is 1
1381            cache.release(old_entry);
1382            assert_eq!(1, cache.usage.load(Relaxed));
1383            assert_eq!(0, cache.lru_usage.load(Relaxed));
1384
1385            let new_entry_again = cache.lookup(0, &"key".to_owned());
1386            assert!(!new_entry_again.is_null());
1387            assert_eq!((*new_entry_again).get_value(), &"new_value".to_owned());
1388            assert_eq!((*new_entry_again).refs, 2);
1389
1390            cache.release(new_entry);
1391            cache.release(new_entry_again);
1392
1393            assert_eq!(1, cache.usage.load(Relaxed));
1394            assert_eq!(1, cache.lru_usage.load(Relaxed));
1395        }
1396    }
1397
1398    #[test]
1399    fn test_write_request_pending() {
1400        let cache = Arc::new(LruCache::new(1, 5, 0));
1401        {
1402            let mut shard = cache.shards[0].lock();
1403            insert(&mut shard, "a", "v1");
1404            assert!(lookup(&mut shard, "a"));
1405        }
1406        assert!(matches!(
1407            cache.lookup_for_request(0, "a".to_owned()),
1408            LookupResult::Cached(_)
1409        ));
1410        assert!(matches!(
1411            cache.lookup_for_request(0, "b".to_owned()),
1412            LookupResult::Miss
1413        ));
1414        let ret2 = cache.lookup_for_request(0, "b".to_owned());
1415        match ret2 {
1416            LookupResult::WaitPendingRequest(mut recv) => {
1417                assert!(matches!(recv.try_recv(), Err(TryRecvError::Empty)));
1418                cache.insert("b".to_owned(), 0, 1, "v2".to_owned(), CachePriority::Low);
1419                recv.try_recv().unwrap();
1420                assert!(
1421                    matches!(cache.lookup_for_request(0, "b".to_owned()), LookupResult::Cached(v) if v.eq("v2"))
1422                );
1423            }
1424            _ => panic!(),
1425        }
1426    }
1427
1428    #[derive(Default, Debug)]
1429    struct TestLruCacheEventListener {
1430        released: Arc<Mutex<HashMap<String, String>>>,
1431    }
1432
1433    impl LruCacheEventListener for TestLruCacheEventListener {
1434        type K = String;
1435        type T = String;
1436
1437        fn on_release(&self, key: Self::K, value: Self::T) {
1438            self.released.lock().insert(key, value);
1439        }
1440    }
1441
1442    #[test]
1443    fn test_event_listener() {
1444        let listener = Arc::new(TestLruCacheEventListener::default());
1445        let cache = Arc::new(LruCache::with_event_listener(1, 2, 0, listener.clone()));
1446
        // fill the cache to capacity
1448        let h = cache.insert("k1".to_owned(), 0, 1, "v1".to_owned(), CachePriority::High);
1449        drop(h);
1450        let h = cache.insert("k2".to_owned(), 0, 1, "v2".to_owned(), CachePriority::High);
1451        drop(h);
1452        assert_eq!(cache.get_memory_usage(), 2);
1453        assert!(listener.released.lock().is_empty());
1454
1455        // test evict
1456        let h = cache.insert("k3".to_owned(), 0, 1, "v3".to_owned(), CachePriority::High);
1457        drop(h);
1458        assert_eq!(cache.get_memory_usage(), 2);
1459        assert!(listener.released.lock().remove("k1").is_some());
1460
1461        // test erase
1462        cache.erase(0, &"k2".to_owned());
1463        assert_eq!(cache.get_memory_usage(), 1);
1464        assert!(listener.released.lock().remove("k2").is_some());
1465
1466        // test refill
1467        let h = cache.insert("k4".to_owned(), 0, 1, "v4".to_owned(), CachePriority::Low);
1468        drop(h);
1469        assert_eq!(cache.get_memory_usage(), 2);
1470        assert!(listener.released.lock().is_empty());
1471
        // test release after the cache is full
        // 1. fill the cache to capacity but do not release the handles
1474        let h1 = cache.insert("k5".to_owned(), 0, 1, "v5".to_owned(), CachePriority::Low);
1475        assert_eq!(cache.get_memory_usage(), 2);
1476        assert!(listener.released.lock().remove("k3").is_some());
1477        let h2 = cache.insert("k6".to_owned(), 0, 1, "v6".to_owned(), CachePriority::Low);
1478        assert_eq!(cache.get_memory_usage(), 2);
1479        assert!(listener.released.lock().remove("k4").is_some());
1480
        // 2. insert one more entry after the cache is full; the cache becomes oversized
1482        let h3 = cache.insert("k7".to_owned(), 0, 1, "v7".to_owned(), CachePriority::Low);
1483        assert_eq!(cache.get_memory_usage(), 3);
1484        assert!(listener.released.lock().is_empty());
1485
        // 3. release one entry; it will be evicted immediately because the cache is oversized
1487        drop(h1);
1488        assert_eq!(cache.get_memory_usage(), 2);
1489        assert!(listener.released.lock().remove("k5").is_some());
1490
1491        // 4. release other entries, no entry will be evicted
1492        drop(h2);
1493        assert_eq!(cache.get_memory_usage(), 2);
1494        assert!(listener.released.lock().is_empty());
1495        drop(h3);
1496        assert_eq!(cache.get_memory_usage(), 2);
1497        assert!(listener.released.lock().is_empty());
1498
        // assert that the listener is not triggered when the cache itself is dropped (cleared)
1500        drop(cache);
1501        assert!(listener.released.lock().is_empty());
1502    }
1503
1504    pub struct SyncPointFuture<F: Future> {
1505        inner: F,
1506        polled: Arc<AtomicBool>,
1507    }
1508
1509    impl<F: Future + Unpin> Future for SyncPointFuture<F> {
1510        type Output = ();
1511
1512        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1513            if self.polled.load(Ordering::Acquire) {
1514                return Poll::Ready(());
1515            }
1516            self.inner.poll_unpin(cx).map(|_| ())
1517        }
1518    }
1519
1520    #[tokio::test]
1521    async fn test_future_cancel() {
1522        let cache: Arc<LruCache<u64, u64>> = Arc::new(LruCache::new(1, 5, 0));
        // The sender is not needed because this receiver will be cancelled.
1524        let (_, recv) = channel::<()>();
1525        let polled = Arc::new(AtomicBool::new(false));
1526        let cache2 = cache.clone();
1527        let polled2 = polled.clone();
1528        let f = Box::pin(async move {
1529            cache2
1530                .lookup_with_request_dedup(1, 2, CachePriority::High, || async move {
1531                    polled2.store(true, Ordering::Release);
1532                    recv.await.map(|_| (1, 1))
1533                })
1534                .await
1535                .unwrap();
1536        });
1537        let wrapper = SyncPointFuture {
1538            inner: f,
1539            polled: polled.clone(),
1540        };
1541        {
1542            let handle = tokio::spawn(wrapper);
1543            while !polled.load(Ordering::Acquire) {
1544                tokio::task::yield_now().await;
1545            }
1546            handle.await.unwrap();
1547        }
1548        assert!(cache.shards[0].lock().write_request.is_empty());
1549    }
1550}