1use std::collections::HashMap;
23use std::error::Error;
24use std::future::Future;
25use std::hash::Hash;
26use std::ops::Deref;
27use std::ptr;
28use std::ptr::null_mut;
29use std::sync::Arc;
30use std::sync::atomic::{AtomicUsize, Ordering};
31
32use futures::FutureExt;
33use parking_lot::Mutex;
34use tokio::sync::oneshot::error::RecvError;
35use tokio::sync::oneshot::{Receiver, Sender, channel};
36use tokio::task::JoinHandle;
37
/// Flag bit: the handle is currently linked into its shard's hash table.
const IN_CACHE: u8 = 0b0000_0001;
/// AND-mask that clears [`IN_CACHE`].
const REVERSE_IN_CACHE: u8 = !IN_CACHE;

/// Flag bit (debug builds only): the handle is currently linked into the LRU list.
#[cfg(debug_assertions)]
const IN_LRU: u8 = 0b0000_0010;
/// AND-mask that clears [`IN_LRU`] (debug builds only).
#[cfg(debug_assertions)]
const REVERSE_IN_LRU: u8 = !IN_LRU;
/// Flag bit: the handle was inserted with high priority.
const IS_HIGH_PRI: u8 = 0b0000_0100;
/// Flag bit: the handle currently sits in the high-priority segment of the LRU list.
const IN_HIGH_PRI_POOL: u8 = 0b0000_1000;
47
/// Bounds every cache key must satisfy: hashable, comparable for equality, and
/// movable across threads. Blanket-implemented for every such type.
pub trait LruKey: Eq + Send + Hash {}
impl<K> LruKey for K where K: Eq + Send + Hash {}

/// Bounds every cached value must satisfy. Values are read through shared
/// `CacheableEntry` references from any thread, hence `Send + Sync`.
/// Blanket-implemented for every such type.
pub trait LruValue: Send + Sync {}
impl<V> LruValue for V where V: Send + Sync {}

/// Priority requested when inserting an entry.
///
/// `High` entries go to the high-priority segment of a shard's LRU list when that
/// segment has non-zero capacity; all other entries go to the low-priority segment.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum CachePriority {
    High,
    Low,
}
59
60pub struct LruHandle<K: LruKey, T: LruValue> {
87 next_hash: *mut LruHandle<K, T>,
89
90 next: *mut LruHandle<K, T>,
92
93 prev: *mut LruHandle<K, T>,
95
96 kv: Option<(K, T)>,
99 hash: u64,
100 charge: usize,
101
102 refs: u32,
105 flags: u8,
106}
107
108impl<K: LruKey, T: LruValue> Default for LruHandle<K, T> {
109 fn default() -> Self {
110 Self {
111 next_hash: null_mut(),
112 next: null_mut(),
113 prev: null_mut(),
114 kv: None,
115 hash: 0,
116 charge: 0,
117 refs: 0,
118 flags: 0,
119 }
120 }
121}
122
123impl<K: LruKey, T: LruValue> LruHandle<K, T> {
124 pub fn new(key: K, value: T, hash: u64, charge: usize) -> Self {
125 let mut ret = Self::default();
126 ret.init(key, value, hash, charge);
127 ret
128 }
129
130 pub fn init(&mut self, key: K, value: T, hash: u64, charge: usize) {
131 self.next_hash = null_mut();
132 self.prev = null_mut();
133 self.next = null_mut();
134 self.kv = Some((key, value));
135 self.hash = hash;
136 self.charge = charge;
137 self.flags = 0;
138 self.refs = 0;
139 }
140
141 fn set_in_cache(&mut self, in_cache: bool) {
148 if in_cache {
149 self.flags |= IN_CACHE;
150 } else {
151 self.flags &= REVERSE_IN_CACHE;
152 }
153 }
154
155 fn is_high_priority(&self) -> bool {
156 (self.flags & IS_HIGH_PRI) > 0
157 }
158
159 fn set_high_priority(&mut self, high_priority: bool) {
160 if high_priority {
161 self.flags |= IS_HIGH_PRI;
162 } else {
163 self.flags &= !IS_HIGH_PRI;
164 }
165 }
166
167 fn set_in_high_pri_pool(&mut self, in_high_pri_pool: bool) {
168 if in_high_pri_pool {
169 self.flags |= IN_HIGH_PRI_POOL;
170 } else {
171 self.flags &= !IN_HIGH_PRI_POOL;
172 }
173 }
174
175 fn is_in_high_pri_pool(&self) -> bool {
176 (self.flags & IN_HIGH_PRI_POOL) > 0
177 }
178
179 fn add_ref(&mut self) {
180 self.refs += 1;
181 }
182
183 fn add_multi_refs(&mut self, ref_count: u32) {
184 self.refs += ref_count;
185 }
186
187 fn unref(&mut self) -> bool {
188 debug_assert!(self.refs > 0);
189 self.refs -= 1;
190 self.refs == 0
191 }
192
193 fn has_refs(&self) -> bool {
194 self.refs > 0
195 }
196
197 fn is_in_cache(&self) -> bool {
200 (self.flags & IN_CACHE) > 0
201 }
202
203 unsafe fn get_key(&self) -> &K {
204 unsafe {
205 debug_assert!(self.kv.is_some());
206 &self.kv.as_ref().unwrap_unchecked().0
207 }
208 }
209
210 unsafe fn get_value(&self) -> &T {
211 unsafe {
212 debug_assert!(self.kv.is_some());
213 &self.kv.as_ref().unwrap_unchecked().1
214 }
215 }
216
217 unsafe fn is_same_key(&self, key: &K) -> bool {
218 unsafe {
219 debug_assert!(self.kv.is_some());
220 self.kv.as_ref().unwrap_unchecked().0.eq(key)
221 }
222 }
223
224 unsafe fn take_kv(&mut self) -> (K, T) {
225 unsafe {
226 debug_assert!(self.kv.is_some());
227 self.kv.take().unwrap_unchecked()
228 }
229 }
230
231 #[cfg(debug_assertions)]
232 fn is_in_lru(&self) -> bool {
233 (self.flags & IN_LRU) > 0
234 }
235
236 #[cfg(debug_assertions)]
237 fn set_in_lru(&mut self, in_lru: bool) {
238 if in_lru {
239 self.flags |= IN_LRU;
240 } else {
241 self.flags &= REVERSE_IN_LRU;
242 }
243 }
244}
245
246unsafe impl<K: LruKey, T: LruValue> Send for LruHandle<K, T> {}
247
248pub struct LruHandleTable<K: LruKey, T: LruValue> {
249 list: Vec<*mut LruHandle<K, T>>,
250 elems: usize,
251}
252
253impl<K: LruKey, T: LruValue> LruHandleTable<K, T> {
254 fn new() -> Self {
255 Self {
256 list: vec![null_mut(); 16],
257 elems: 0,
258 }
259 }
260
261 unsafe fn find_pointer(
263 &self,
264 idx: usize,
265 key: &K,
266 ) -> (*mut LruHandle<K, T>, *mut LruHandle<K, T>) {
267 unsafe {
268 let mut ptr = self.list[idx];
269 let mut prev = null_mut();
270 while !ptr.is_null() && !(*ptr).is_same_key(key) {
271 prev = ptr;
272 ptr = (*ptr).next_hash;
273 }
274 (prev, ptr)
275 }
276 }
277
278 unsafe fn remove(&mut self, hash: u64, key: &K) -> *mut LruHandle<K, T> {
279 unsafe {
280 debug_assert!(self.list.len().is_power_of_two());
281 let idx = (hash as usize) & (self.list.len() - 1);
282 let (prev, ptr) = self.find_pointer(idx, key);
283 if ptr.is_null() {
284 return null_mut();
285 }
286 debug_assert!((*ptr).is_in_cache());
287 (*ptr).set_in_cache(false);
288 if prev.is_null() {
289 self.list[idx] = (*ptr).next_hash;
290 } else {
291 (*prev).next_hash = (*ptr).next_hash;
292 }
293 self.elems -= 1;
294 ptr
295 }
296 }
297
298 unsafe fn insert(&mut self, hash: u64, h: *mut LruHandle<K, T>) -> *mut LruHandle<K, T> {
301 unsafe {
302 debug_assert!(!h.is_null());
303 debug_assert!(!(*h).is_in_cache());
304 (*h).set_in_cache(true);
305 debug_assert!(self.list.len().is_power_of_two());
306 let idx = (hash as usize) & (self.list.len() - 1);
307 let (prev, ptr) = self.find_pointer(idx, (*h).get_key());
308 if prev.is_null() {
309 self.list[idx] = h;
310 } else {
311 (*prev).next_hash = h;
312 }
313
314 if !ptr.is_null() {
315 debug_assert!((*ptr).is_same_key((*h).get_key()));
316 debug_assert!((*ptr).is_in_cache());
317 (*ptr).set_in_cache(false);
319 (*h).next_hash = (*ptr).next_hash;
320 return ptr;
321 }
322
323 (*h).next_hash = ptr;
324
325 self.elems += 1;
326 if self.elems > self.list.len() {
327 self.resize();
328 }
329 null_mut()
330 }
331 }
332
333 unsafe fn lookup(&self, hash: u64, key: &K) -> *mut LruHandle<K, T> {
334 unsafe {
335 debug_assert!(self.list.len().is_power_of_two());
336 let idx = (hash as usize) & (self.list.len() - 1);
337 let (_, ptr) = self.find_pointer(idx, key);
338 ptr
339 }
340 }
341
342 unsafe fn resize(&mut self) {
343 unsafe {
344 let mut l = std::cmp::max(16, self.list.len());
345 let next_capacity = self.elems * 3 / 2;
346 while l < next_capacity {
347 l <<= 1;
348 }
349 let mut count = 0;
350 let mut new_list = vec![null_mut(); l];
351 for head in self.list.drain(..) {
352 let mut handle = head;
353 while !handle.is_null() {
354 let idx = (*handle).hash as usize & (l - 1);
355 let next = (*handle).next_hash;
356 (*handle).next_hash = new_list[idx];
357 new_list[idx] = handle;
358 handle = next;
359 count += 1;
360 }
361 }
362 assert_eq!(count, self.elems);
363 self.list = new_list;
364 }
365 }
366
367 unsafe fn for_all<F>(&self, f: &mut F)
368 where
369 F: FnMut(&K, &T),
370 {
371 unsafe {
372 for idx in 0..self.list.len() {
373 let mut ptr = self.list[idx];
374 while !ptr.is_null() {
375 f((*ptr).get_key(), (*ptr).get_value());
376 ptr = (*ptr).next_hash;
377 }
378 }
379 }
380 }
381}
382
383type RequestQueue<K, T> = Vec<Sender<CacheableEntry<K, T>>>;
384pub struct LruCacheShard<K: LruKey, T: LruValue> {
385 lru: Box<LruHandle<K, T>>,
390 low_priority_head: *mut LruHandle<K, T>,
391 high_priority_pool_capacity: usize,
392 high_priority_pool_usage: usize,
393 table: LruHandleTable<K, T>,
394 object_pool: Vec<Box<LruHandle<K, T>>>,
396 write_request: HashMap<K, RequestQueue<K, T>>,
397 lru_usage: Arc<AtomicUsize>,
398 usage: Arc<AtomicUsize>,
399 capacity: usize,
400}
401
402unsafe impl<K: LruKey, T: LruValue> Send for LruCacheShard<K, T> {}
403
404impl<K: LruKey, T: LruValue> LruCacheShard<K, T> {
405 fn new_with_priority_pool(capacity: usize, high_priority_ratio_percent: usize) -> Self {
407 let mut lru = Box::<LruHandle<K, T>>::default();
408 lru.prev = lru.as_mut();
409 lru.next = lru.as_mut();
410 let mut object_pool = Vec::with_capacity(DEFAULT_OBJECT_POOL_SIZE);
411 for _ in 0..DEFAULT_OBJECT_POOL_SIZE {
412 object_pool.push(Box::default());
413 }
414 Self {
415 capacity,
416 lru_usage: Arc::new(AtomicUsize::new(0)),
417 usage: Arc::new(AtomicUsize::new(0)),
418 object_pool,
419 low_priority_head: lru.as_mut(),
420 high_priority_pool_capacity: high_priority_ratio_percent * capacity / 100,
421 lru,
422 table: LruHandleTable::new(),
423 write_request: HashMap::with_capacity(16),
424 high_priority_pool_usage: 0,
425 }
426 }
427
428 unsafe fn lru_remove(&mut self, e: *mut LruHandle<K, T>) {
429 unsafe {
430 debug_assert!(!e.is_null());
431 #[cfg(debug_assertions)]
432 {
433 assert!((*e).is_in_lru());
434 (*e).set_in_lru(false);
435 }
436
437 if ptr::eq(e, self.low_priority_head) {
438 self.low_priority_head = (*e).prev;
439 }
440
441 (*(*e).next).prev = (*e).prev;
442 (*(*e).prev).next = (*e).next;
443 (*e).prev = null_mut();
444 (*e).next = null_mut();
445 if (*e).is_in_high_pri_pool() {
446 debug_assert!(self.high_priority_pool_usage >= (*e).charge);
447 self.high_priority_pool_usage -= (*e).charge;
448 }
449 self.lru_usage.fetch_sub((*e).charge, Ordering::Relaxed);
450 }
451 }
452
453 unsafe fn lru_insert(&mut self, e: *mut LruHandle<K, T>) {
455 unsafe {
456 debug_assert!(!e.is_null());
457 let entry = &mut (*e);
458 #[cfg(debug_assertions)]
459 {
460 assert!(!(*e).is_in_lru());
461 (*e).set_in_lru(true);
462 }
463
464 if self.high_priority_pool_capacity > 0 && entry.is_high_priority() {
465 entry.set_in_high_pri_pool(true);
466 entry.next = self.lru.as_mut();
467 entry.prev = self.lru.prev;
468 (*entry.prev).next = e;
469 (*entry.next).prev = e;
470 self.high_priority_pool_usage += (*e).charge;
471 self.maintain_pool_size();
472 } else {
473 entry.set_in_high_pri_pool(false);
474 entry.next = (*self.low_priority_head).next;
475 entry.prev = self.low_priority_head;
476 (*entry.next).prev = e;
477 (*entry.prev).next = e;
478 self.low_priority_head = e;
479 }
480 self.lru_usage.fetch_add((*e).charge, Ordering::Relaxed);
481 }
482 }
483
484 unsafe fn maintain_pool_size(&mut self) {
485 unsafe {
486 while self.high_priority_pool_usage > self.high_priority_pool_capacity {
487 self.low_priority_head = (*self.low_priority_head).next;
489 (*self.low_priority_head).set_in_high_pri_pool(false);
490 self.high_priority_pool_usage -= (*self.low_priority_head).charge;
491 }
492 }
493 }
494
495 unsafe fn evict_from_lru(&mut self, charge: usize, last_reference_list: &mut Vec<(K, T)>) {
496 unsafe {
497 while self.usage.load(Ordering::Relaxed) + charge > self.capacity
500 && !std::ptr::eq(self.lru.next, self.lru.as_mut())
501 {
502 let old_ptr = self.lru.next;
503 self.table.remove((*old_ptr).hash, (*old_ptr).get_key());
504 self.lru_remove(old_ptr);
505 let (key, value) = self.clear_handle(old_ptr);
506 last_reference_list.push((key, value));
507 }
508 }
509 }
510
511 unsafe fn clear_handle(&mut self, h: *mut LruHandle<K, T>) -> (K, T) {
513 unsafe {
514 debug_assert!(!h.is_null());
515 debug_assert!((*h).kv.is_some());
516 #[cfg(debug_assertions)]
517 assert!(!(*h).is_in_lru());
518 debug_assert!(!(*h).is_in_cache());
519 debug_assert!(!(*h).has_refs());
520 self.usage.fetch_sub((*h).charge, Ordering::Relaxed);
521 let (key, value) = (*h).take_kv();
522 self.try_recycle_handle_object(h);
523 (key, value)
524 }
525 }
526
527 unsafe fn try_recycle_handle_object(&mut self, h: *mut LruHandle<K, T>) {
531 unsafe {
532 let mut node = Box::from_raw(h);
533 if self.object_pool.len() < self.object_pool.capacity() {
534 node.next_hash = null_mut();
535 node.next = null_mut();
536 node.prev = null_mut();
537 debug_assert!(node.kv.is_none());
538 self.object_pool.push(node);
539 }
540 }
541 }
542
543 unsafe fn insert(
545 &mut self,
546 key: K,
547 hash: u64,
548 charge: usize,
549 value: T,
550 priority: CachePriority,
551 last_reference_list: &mut Vec<(K, T)>,
552 ) -> *mut LruHandle<K, T> {
553 unsafe {
554 self.evict_from_lru(charge, last_reference_list);
555
556 let mut handle = match self.object_pool.pop() {
557 Some(mut h) => {
558 h.init(key, value, hash, charge);
559 h
560 }
561 _ => Box::new(LruHandle::new(key, value, hash, charge)),
562 };
563 if priority == CachePriority::High {
564 handle.set_high_priority(true);
565 }
566 let ptr = Box::into_raw(handle);
567 let old = self.table.insert(hash, ptr);
568 if !old.is_null()
569 && let Some(data) = self.try_remove_cache_handle(old)
570 {
571 last_reference_list.push(data);
572 }
573 self.usage.fetch_add(charge, Ordering::Relaxed);
574 (*ptr).add_ref();
575 ptr
576 }
577 }
578
579 unsafe fn release(&mut self, h: *mut LruHandle<K, T>) -> Option<(K, T)> {
583 unsafe {
584 debug_assert!(!h.is_null());
585 #[cfg(debug_assertions)]
587 assert!(!(*h).is_in_lru());
588 let last_reference = (*h).unref();
589 if !last_reference {
591 return None;
592 }
593
594 if (*h).is_in_cache() {
596 if self.usage.load(Ordering::Relaxed) <= self.capacity {
597 self.lru_insert(h);
598 return None;
599 }
600 self.table.remove((*h).hash, (*h).get_key());
602 }
603
604 #[cfg(debug_assertions)]
607 assert!(!(*h).is_in_lru());
608
609 let (key, value) = self.clear_handle(h);
610 Some((key, value))
611 }
612 }
613
614 unsafe fn lookup(&mut self, hash: u64, key: &K) -> *mut LruHandle<K, T> {
615 unsafe {
616 let e = self.table.lookup(hash, key);
617 if !e.is_null() {
618 if !(*e).has_refs() {
621 self.lru_remove(e);
622 }
623 (*e).add_ref();
624 }
625 e
626 }
627 }
628
629 unsafe fn erase(&mut self, hash: u64, key: &K) -> Option<(K, T)> {
631 unsafe {
632 let h = self.table.remove(hash, key);
633 if !h.is_null() {
634 self.try_remove_cache_handle(h)
635 } else {
636 None
637 }
638 }
639 }
640
641 unsafe fn try_remove_cache_handle(&mut self, h: *mut LruHandle<K, T>) -> Option<(K, T)> {
645 unsafe {
646 debug_assert!(!h.is_null());
647 if !(*h).has_refs() {
648 self.lru_remove(h);
652 let (key, value) = self.clear_handle(h);
653 return Some((key, value));
654 }
655 None
656 }
657 }
658
659 unsafe fn clear(&mut self) {
662 unsafe {
663 while !std::ptr::eq(self.lru.next, self.lru.as_mut()) {
664 let handle = self.lru.next;
665 self.erase((*handle).hash, (*handle).get_key());
667 }
668 }
669 }
670
671 fn for_all<F>(&self, f: &mut F)
672 where
673 F: FnMut(&K, &T),
674 {
675 unsafe { self.table.for_all(f) };
676 }
677}
678
679impl<K: LruKey, T: LruValue> Drop for LruCacheShard<K, T> {
680 fn drop(&mut self) {
681 unsafe {
684 self.clear();
685 }
686 }
687}
688
689pub trait LruCacheEventListener: Send + Sync {
690 type K: LruKey;
691 type T: LruValue;
692
693 fn on_release(&self, _key: Self::K, _value: Self::T) {}
698}
699
700pub struct LruCache<K: LruKey, T: LruValue> {
701 shards: Vec<Mutex<LruCacheShard<K, T>>>,
702 shard_usages: Vec<Arc<AtomicUsize>>,
703 shard_lru_usages: Vec<Arc<AtomicUsize>>,
704
705 listener: Option<Arc<dyn LruCacheEventListener<K = K, T = T>>>,
706}
707
708const DEFAULT_OBJECT_POOL_SIZE: usize = 1024;
711
712impl<K: LruKey, T: LruValue> LruCache<K, T> {
713 pub fn new(num_shards: usize, capacity: usize, high_priority_ratio: usize) -> Self {
714 Self::new_inner(num_shards, capacity, high_priority_ratio, None)
715 }
716
717 pub fn with_event_listener(
718 num_shards: usize,
719 capacity: usize,
720 high_priority_ratio: usize,
721 listener: Arc<dyn LruCacheEventListener<K = K, T = T>>,
722 ) -> Self {
723 Self::new_inner(num_shards, capacity, high_priority_ratio, Some(listener))
724 }
725
726 fn new_inner(
727 num_shards: usize,
728 capacity: usize,
729 high_priority_ratio: usize,
730 listener: Option<Arc<dyn LruCacheEventListener<K = K, T = T>>>,
731 ) -> Self {
732 let mut shards = Vec::with_capacity(num_shards);
733 let per_shard = capacity / num_shards;
734 let mut shard_usages = Vec::with_capacity(num_shards);
735 let mut shard_lru_usages = Vec::with_capacity(num_shards);
736 for _ in 0..num_shards {
737 let shard = LruCacheShard::new_with_priority_pool(per_shard, high_priority_ratio);
738 shard_usages.push(shard.usage.clone());
739 shard_lru_usages.push(shard.lru_usage.clone());
740 shards.push(Mutex::new(shard));
741 }
742 Self {
743 shards,
744 shard_usages,
745 shard_lru_usages,
746 listener,
747 }
748 }
749
750 pub fn contains(self: &Arc<Self>, hash: u64, key: &K) -> bool {
751 let shard = self.shards[self.shard(hash)].lock();
752 unsafe {
753 let ptr = shard.table.lookup(hash, key);
754 !ptr.is_null()
755 }
756 }
757
758 pub fn lookup(self: &Arc<Self>, hash: u64, key: &K) -> Option<CacheableEntry<K, T>> {
759 let mut shard = self.shards[self.shard(hash)].lock();
760 unsafe {
761 let ptr = shard.lookup(hash, key);
762 if ptr.is_null() {
763 return None;
764 }
765 let entry = CacheableEntry {
766 cache: self.clone(),
767 handle: ptr,
768 };
769 Some(entry)
770 }
771 }
772
773 pub fn lookup_for_request(self: &Arc<Self>, hash: u64, key: K) -> LookupResult<K, T> {
774 let mut shard = self.shards[self.shard(hash)].lock();
775 unsafe {
776 let ptr = shard.lookup(hash, &key);
777 if !ptr.is_null() {
778 return LookupResult::Cached(CacheableEntry {
779 cache: self.clone(),
780 handle: ptr,
781 });
782 }
783 if let Some(que) = shard.write_request.get_mut(&key) {
784 let (tx, recv) = channel();
785 que.push(tx);
786 return LookupResult::WaitPendingRequest(recv);
787 }
788 shard.write_request.insert(key, vec![]);
789 LookupResult::Miss
790 }
791 }
792
793 unsafe fn release(&self, handle: *mut LruHandle<K, T>) {
794 unsafe {
795 debug_assert!(!handle.is_null());
796 let data = {
797 let mut shard = self.shards[self.shard((*handle).hash)].lock();
798 shard.release(handle)
799 };
800 if let Some((key, value)) = data
802 && let Some(listener) = &self.listener
803 {
804 listener.on_release(key, value);
805 }
806 }
807 }
808
809 unsafe fn inc_reference(&self, handle: *mut LruHandle<K, T>) {
810 unsafe {
811 let _shard = self.shards[self.shard((*handle).hash)].lock();
812 (*handle).refs += 1;
813 }
814 }
815
816 pub fn insert(
817 self: &Arc<Self>,
818 key: K,
819 hash: u64,
820 charge: usize,
821 value: T,
822 priority: CachePriority,
823 ) -> CacheableEntry<K, T> {
824 let mut to_delete = vec![];
825 let mut senders = vec![];
827 let handle = unsafe {
828 let mut shard = self.shards[self.shard(hash)].lock();
829 let pending_request = shard.write_request.remove(&key);
830 let ptr = shard.insert(key, hash, charge, value, priority, &mut to_delete);
831 debug_assert!(!ptr.is_null());
832 if let Some(mut que) = pending_request {
833 (*ptr).add_multi_refs(que.len() as u32);
834 senders = std::mem::take(&mut que);
835 }
836 CacheableEntry {
837 cache: self.clone(),
838 handle: ptr,
839 }
840 };
841 for sender in senders {
842 let _ = sender.send(CacheableEntry {
843 cache: self.clone(),
844 handle: handle.handle,
845 });
846 }
847
848 if let Some(listener) = &self.listener {
850 for (key, value) in to_delete {
851 listener.on_release(key, value);
852 }
853 }
854 handle
855 }
856
857 pub fn clear_pending_request(&self, key: &K, hash: u64) {
858 let mut shard = self.shards[self.shard(hash)].lock();
859 shard.write_request.remove(key);
860 }
861
862 pub fn erase(&self, hash: u64, key: &K) {
863 let data = unsafe {
864 let mut shard = self.shards[self.shard(hash)].lock();
865 shard.erase(hash, key)
866 };
867 if let Some((key, value)) = data
869 && let Some(listener) = &self.listener
870 {
871 listener.on_release(key, value);
872 }
873 }
874
875 pub fn get_memory_usage(&self) -> usize {
876 self.shard_usages
877 .iter()
878 .map(|x| x.load(Ordering::Relaxed))
879 .sum()
880 }
881
882 pub fn get_lru_usage(&self) -> usize {
883 self.shard_lru_usages
884 .iter()
885 .map(|x| x.load(Ordering::Relaxed))
886 .sum()
887 }
888
889 fn shard(&self, hash: u64) -> usize {
890 hash as usize % self.shards.len()
891 }
892
893 pub fn for_all<F>(&self, mut f: F)
901 where
902 F: FnMut(&K, &T),
903 {
904 for shard in &self.shards {
905 let shard = shard.lock();
906 shard.for_all(&mut f);
907 }
908 }
909
910 pub fn clear(&self) {
914 for shard in &self.shards {
915 unsafe {
916 let mut shard = shard.lock();
917 shard.clear();
918 }
919 }
920 }
921}
922
923pub struct CleanCacheGuard<'a, K: LruKey + Clone + 'static, T: LruValue + 'static> {
924 cache: &'a Arc<LruCache<K, T>>,
925 key: Option<K>,
926 hash: u64,
927}
928
929impl<K: LruKey + Clone + 'static, T: LruValue + 'static> CleanCacheGuard<'_, K, T> {
930 fn mark_success(mut self) -> K {
931 self.key.take().unwrap()
932 }
933}
934
935impl<K: LruKey + Clone + 'static, T: LruValue + 'static> Drop for CleanCacheGuard<'_, K, T> {
936 fn drop(&mut self) {
937 if let Some(key) = self.key.as_ref() {
938 self.cache.clear_pending_request(key, self.hash);
939 }
940 }
941}
942
943#[derive(Default)]
949pub enum LookupResponse<K: LruKey + Clone + 'static, T: LruValue + 'static, E> {
950 #[default]
951 Invalid,
952 Cached(CacheableEntry<K, T>),
953 WaitPendingRequest(Receiver<CacheableEntry<K, T>>),
954 Miss(JoinHandle<Result<CacheableEntry<K, T>, E>>),
955}
956
957impl<K: LruKey + Clone + 'static, T: LruValue + 'static, E: From<RecvError>> Future
958 for LookupResponse<K, T, E>
959{
960 type Output = Result<CacheableEntry<K, T>, E>;
961
962 fn poll(
963 mut self: std::pin::Pin<&mut Self>,
964 cx: &mut std::task::Context<'_>,
965 ) -> std::task::Poll<Self::Output> {
966 match &mut *self {
967 Self::Invalid => unreachable!(),
968 Self::Cached(_) => std::task::Poll::Ready(Ok(
969 must_match!(std::mem::take(&mut *self), Self::Cached(entry) => entry),
970 )),
971 Self::WaitPendingRequest(receiver) => {
972 receiver.poll_unpin(cx).map_err(|recv_err| recv_err.into())
973 }
974 Self::Miss(join_handle) => join_handle
975 .poll_unpin(cx)
976 .map(|join_result| join_result.unwrap()),
977 }
978 }
979}
980
981impl<K: LruKey + Clone + 'static, T: LruValue + 'static> LruCache<K, T> {
984 pub fn lookup_with_request_dedup<F, E, VC>(
985 self: &Arc<Self>,
986 hash: u64,
987 key: K,
988 priority: CachePriority,
989 fetch_value: F,
990 ) -> LookupResponse<K, T, E>
991 where
992 F: FnOnce() -> VC,
993 E: Error + Send + 'static + From<RecvError>,
994 VC: Future<Output = Result<(T, usize), E>> + Send + 'static,
995 {
996 match self.lookup_for_request(hash, key.clone()) {
997 LookupResult::Cached(entry) => LookupResponse::Cached(entry),
998 LookupResult::WaitPendingRequest(receiver) => {
999 LookupResponse::WaitPendingRequest(receiver)
1000 }
1001 LookupResult::Miss => {
1002 let this = self.clone();
1003 let fetch_value = fetch_value();
1004 let key2 = key;
1005 let join_handle = tokio::spawn(async move {
1006 let guard = CleanCacheGuard {
1007 cache: &this,
1008 key: Some(key2),
1009 hash,
1010 };
1011 let (value, charge) = fetch_value.await?;
1012 let key2 = guard.mark_success();
1013 let entry = this.insert(key2, hash, charge, value, priority);
1014 Ok(entry)
1015 });
1016 LookupResponse::Miss(join_handle)
1017 }
1018 }
1019 }
1020}
1021
1022pub struct CacheableEntry<K: LruKey, T: LruValue> {
1023 cache: Arc<LruCache<K, T>>,
1024 handle: *mut LruHandle<K, T>,
1025}
1026
1027pub enum LookupResult<K: LruKey, T: LruValue> {
1028 Cached(CacheableEntry<K, T>),
1029 Miss,
1030 WaitPendingRequest(Receiver<CacheableEntry<K, T>>),
1031}
1032
1033unsafe impl<K: LruKey, T: LruValue> Send for CacheableEntry<K, T> {}
1034unsafe impl<K: LruKey, T: LruValue> Sync for CacheableEntry<K, T> {}
1035
1036impl<K: LruKey, T: LruValue> Deref for CacheableEntry<K, T> {
1037 type Target = T;
1038
1039 fn deref(&self) -> &Self::Target {
1040 unsafe { (*self.handle).get_value() }
1041 }
1042}
1043
1044impl<K: LruKey, T: LruValue> Drop for CacheableEntry<K, T> {
1045 fn drop(&mut self) {
1046 unsafe {
1047 self.cache.release(self.handle);
1048 }
1049 }
1050}
1051
1052impl<K: LruKey, T: LruValue> Clone for CacheableEntry<K, T> {
1053 fn clone(&self) -> Self {
1054 unsafe {
1055 self.cache.inc_reference(self.handle);
1056 CacheableEntry {
1057 cache: self.cache.clone(),
1058 handle: self.handle,
1059 }
1060 }
1061 }
1062}
1063
1064#[cfg(test)]
1065mod tests {
1066 use std::collections::hash_map::DefaultHasher;
1067 use std::hash::Hasher;
1068 use std::pin::Pin;
1069 use std::sync::atomic::AtomicBool;
1070 use std::sync::atomic::Ordering::Relaxed;
1071 use std::task::{Context, Poll};
1072
1073 use rand::rngs::SmallRng;
1074 use rand::{RngCore, SeedableRng};
1075 use tokio::sync::oneshot::error::TryRecvError;
1076
1077 use super::*;
1078
1079 pub struct Block {
1080 pub offset: u64,
1081 #[allow(dead_code)]
1082 pub sst: u64,
1083 }
1084
1085 #[test]
1086 fn test_cache_handle_basic() {
1087 let mut h = Box::new(LruHandle::new(1, 2, 0, 0));
1088 h.set_in_cache(true);
1089 assert!(h.is_in_cache());
1090 h.set_in_cache(false);
1091 assert!(!h.is_in_cache());
1092 }
1093
1094 #[test]
1095 fn test_cache_shard() {
1096 let cache = Arc::new(LruCache::<(u64, u64), Block>::new(4, 256, 0));
1097 assert_eq!(cache.shard(0), 0);
1098 assert_eq!(cache.shard(1), 1);
1099 assert_eq!(cache.shard(10), 2);
1100 }
1101
1102 #[test]
1103 fn test_cache_basic() {
1104 let cache = Arc::new(LruCache::<(u64, u64), Block>::new(2, 256, 0));
1105 let seed = 10244021u64;
1106 let mut rng = SmallRng::seed_from_u64(seed);
1107 for _ in 0..100000 {
1108 let block_offset = rng.next_u64() % 1024;
1109 let sst = rng.next_u64() % 1024;
1110 let mut hasher = DefaultHasher::new();
1111 sst.hash(&mut hasher);
1112 block_offset.hash(&mut hasher);
1113 let h = hasher.finish();
1114 if let Some(block) = cache.lookup(h, &(sst, block_offset)) {
1115 assert_eq!(block.offset, block_offset);
1116 drop(block);
1117 continue;
1118 }
1119 cache.insert(
1120 (sst, block_offset),
1121 h,
1122 1,
1123 Block {
1124 offset: block_offset,
1125 sst,
1126 },
1127 CachePriority::High,
1128 );
1129 }
1130 assert_eq!(256, cache.get_memory_usage());
1131 }
1132
1133 fn validate_lru_list(cache: &mut LruCacheShard<String, String>, keys: Vec<&str>) {
1134 unsafe {
1135 let mut lru: *mut LruHandle<String, String> = cache.lru.as_mut();
1136 for k in keys {
1137 lru = (*lru).next;
1138 assert!(
1139 (*lru).is_same_key(&k.to_owned()),
1140 "compare failed: {} vs {}, get value: {:?}",
1141 (*lru).get_key(),
1142 k,
1143 (*lru).get_value()
1144 );
1145 }
1146 }
1147 }
1148
1149 fn create_cache(capacity: usize) -> LruCacheShard<String, String> {
1150 LruCacheShard::new_with_priority_pool(capacity, 0)
1151 }
1152
1153 fn lookup(cache: &mut LruCacheShard<String, String>, key: &str) -> bool {
1154 unsafe {
1155 let h = cache.lookup(0, &key.to_owned());
1156 let exist = !h.is_null();
1157 if exist {
1158 assert!((*h).is_same_key(&key.to_owned()));
1159 cache.release(h);
1160 }
1161 exist
1162 }
1163 }
1164
1165 fn insert_priority(
1166 cache: &mut LruCacheShard<String, String>,
1167 key: &str,
1168 value: &str,
1169 priority: CachePriority,
1170 ) {
1171 let mut free_list = vec![];
1172 unsafe {
1173 let handle = cache.insert(
1174 key.to_owned(),
1175 0,
1176 value.len(),
1177 value.to_owned(),
1178 priority,
1179 &mut free_list,
1180 );
1181 cache.release(handle);
1182 }
1183 free_list.clear();
1184 }
1185
1186 fn insert(cache: &mut LruCacheShard<String, String>, key: &str, value: &str) {
1187 insert_priority(cache, key, value, CachePriority::Low);
1188 }
1189
1190 #[test]
1191 fn test_basic_lru() {
1192 let mut cache = LruCacheShard::new_with_priority_pool(5, 40);
1193 let keys = vec!["a", "b", "c", "d", "e"];
1194 for &k in &keys {
1195 insert(&mut cache, k, k);
1196 }
1197 validate_lru_list(&mut cache, keys);
1198 for k in ["x", "y", "z"] {
1199 insert(&mut cache, k, k);
1200 }
1201 validate_lru_list(&mut cache, vec!["d", "e", "x", "y", "z"]);
1202 assert!(!lookup(&mut cache, "b"));
1203 assert!(lookup(&mut cache, "e"));
1204 validate_lru_list(&mut cache, vec!["d", "x", "y", "z", "e"]);
1205 assert!(lookup(&mut cache, "z"));
1206 validate_lru_list(&mut cache, vec!["d", "x", "y", "e", "z"]);
1207 unsafe {
1208 let h = cache.erase(0, &"x".to_owned());
1209 assert!(h.is_some());
1210 validate_lru_list(&mut cache, vec!["d", "y", "e", "z"]);
1211 }
1212 assert!(lookup(&mut cache, "d"));
1213 validate_lru_list(&mut cache, vec!["y", "e", "z", "d"]);
1214 insert(&mut cache, "u", "u");
1215 validate_lru_list(&mut cache, vec!["y", "e", "z", "d", "u"]);
1216 insert(&mut cache, "v", "v");
1217 validate_lru_list(&mut cache, vec!["e", "z", "d", "u", "v"]);
1218 insert_priority(&mut cache, "x", "x", CachePriority::High);
1219 validate_lru_list(&mut cache, vec!["z", "d", "u", "v", "x"]);
1220 assert!(lookup(&mut cache, "d"));
1221 validate_lru_list(&mut cache, vec!["z", "u", "v", "d", "x"]);
1222 insert(&mut cache, "y", "y");
1223 validate_lru_list(&mut cache, vec!["u", "v", "d", "y", "x"]);
1224 insert_priority(&mut cache, "z", "z", CachePriority::High);
1225 validate_lru_list(&mut cache, vec!["v", "d", "y", "x", "z"]);
1226 insert(&mut cache, "u", "u");
1227 validate_lru_list(&mut cache, vec!["d", "y", "u", "x", "z"]);
1228 insert_priority(&mut cache, "v", "v", CachePriority::High);
1229 validate_lru_list(&mut cache, vec!["y", "u", "x", "z", "v"]);
1230 }
1231
1232 #[test]
1233 fn test_reference_and_usage() {
1234 let mut cache = create_cache(5);
1235 insert(&mut cache, "k1", "a");
1236 assert_eq!(cache.usage.load(Ordering::Relaxed), 1);
1237 insert(&mut cache, "k0", "aa");
1238 assert_eq!(cache.usage.load(Ordering::Relaxed), 3);
1239 insert(&mut cache, "k1", "aa");
1240 assert_eq!(cache.usage.load(Ordering::Relaxed), 4);
1241 insert(&mut cache, "k2", "aa");
1242 assert_eq!(cache.usage.load(Ordering::Relaxed), 4);
1243 let mut free_list = vec![];
1244 validate_lru_list(&mut cache, vec!["k1", "k2"]);
1245 unsafe {
1246 let h1 = cache.lookup(0, &"k1".to_owned());
1247 assert!(!h1.is_null());
1248 let h2 = cache.lookup(0, &"k2".to_owned());
1249 assert!(!h2.is_null());
1250
1251 let h3 = cache.insert(
1252 "k3".to_owned(),
1253 0,
1254 2,
1255 "bb".to_owned(),
1256 CachePriority::High,
1257 &mut free_list,
1258 );
1259 assert_eq!(cache.usage.load(Ordering::Relaxed), 6);
1260 assert!(!h3.is_null());
1261 let h4 = cache.lookup(0, &"k1".to_owned());
1262 assert!(!h4.is_null());
1263
1264 cache.release(h1);
1265 assert_eq!(cache.usage.load(Ordering::Relaxed), 6);
1266 cache.release(h4);
1267 assert_eq!(cache.usage.load(Ordering::Relaxed), 4);
1268
1269 cache.release(h3);
1270 cache.release(h2);
1271
1272 validate_lru_list(&mut cache, vec!["k3", "k2"]);
1273 }
1274 }
1275
#[test]
fn test_update_referenced_key() {
    unsafe {
        let mut to_delete = vec![];
        let mut cache = create_cache(5);

        // Insert the first version and take a second reference to it through a lookup.
        let insert_handle = cache.insert(
            "key".to_owned(),
            0,
            1,
            "old_value".to_owned(),
            CachePriority::High,
            &mut to_delete,
        );
        let old_entry = cache.lookup(0, &"key".to_owned());
        assert!(!old_entry.is_null());
        assert_eq!((*old_entry).get_value(), &"old_value".to_owned());
        assert_eq!((*old_entry).refs, 2);
        cache.release(insert_handle);
        assert_eq!((*old_entry).refs, 1);

        // Overwrite the key while `old_entry` is still referenced: the old handle leaves
        // the cache but stays readable through the reference we hold.
        let insert_handle = cache.insert(
            "key".to_owned(),
            0,
            1,
            "new_value".to_owned(),
            CachePriority::Low,
            &mut to_delete,
        );
        assert!(!(*old_entry).is_in_cache());
        let new_entry = cache.lookup(0, &"key".to_owned());
        assert!(!new_entry.is_null());
        assert_eq!((*new_entry).get_value(), &"new_value".to_owned());
        assert_eq!((*new_entry).refs, 2);
        cache.release(insert_handle);
        assert_eq!((*new_entry).refs, 1);

        assert_eq!((*old_entry).get_value(), &"old_value".to_owned());
        assert_eq!((*old_entry).refs, 1);

        // Releasing the last reference to the current entry leaves it cached (and, in
        // debug builds, on the LRU list).
        cache.release(new_entry);
        assert!((*new_entry).is_in_cache());
        #[cfg(debug_assertions)]
        assert!((*new_entry).is_in_lru());

        // Inspect the stale handle *before* giving up its last reference. After that
        // release the caller no longer owns the handle (the shard may reclaim or recycle
        // it), so `old_entry` must not be dereferenced afterwards.
        assert_eq!((*old_entry).get_value(), &"old_value".to_owned());
        assert_eq!((*old_entry).refs, 1);
        assert!(!(*old_entry).is_in_cache());
        #[cfg(debug_assertions)]
        assert!(!(*old_entry).is_in_lru());
        cache.release(old_entry);

        // Releasing the stale entry must leave the current one untouched ...
        assert!((*new_entry).is_in_cache());
        #[cfg(debug_assertions)]
        assert!((*new_entry).is_in_lru());

        // ... and the key must still resolve to the current entry, not the stale one.
        let current = cache.lookup(0, &"key".to_owned());
        assert_eq!(current, new_entry);
        cache.release(current);
        #[cfg(debug_assertions)]
        assert!((*new_entry).is_in_lru());
    }
}
1335
#[test]
fn test_release_stale_value() {
    unsafe {
        let key = "key".to_owned();
        let mut to_delete = vec![];
        // Room for one unit of charge only, so two live versions overflow the shard.
        let mut cache = create_cache(1);

        // First version: insert it, then drop the insertion handle straight away.
        let first_handle = cache.insert(
            key.clone(),
            0,
            1,
            "old_value".to_owned(),
            CachePriority::High,
            &mut to_delete,
        );
        cache.release(first_handle);

        // Hold on to the first version through a lookup.
        let stale = cache.lookup(0, &key);
        assert!(!stale.is_null());
        assert_eq!((*stale).get_value(), &"old_value".to_owned());
        assert_eq!((*stale).refs, 1);

        // Second version takes over the key while `stale` is still referenced.
        let second_handle = cache.insert(
            key.clone(),
            0,
            1,
            "new_value".to_owned(),
            CachePriority::High,
            &mut to_delete,
        );
        assert!(!(*stale).is_in_cache());
        let fresh = cache.lookup(0, &key);
        assert!(!fresh.is_null());
        assert_eq!((*fresh).get_value(), &"new_value".to_owned());
        assert_eq!((*fresh).refs, 2);
        cache.release(second_handle);
        assert_eq!((*fresh).refs, 1);

        // Both versions are still charged; lru_usage stays at zero.
        assert_eq!(2, cache.usage.load(Relaxed));
        assert_eq!(0, cache.lru_usage.load(Relaxed));

        // Giving up the stale version returns its charge.
        cache.release(stale);
        assert_eq!(1, cache.usage.load(Relaxed));
        assert_eq!(0, cache.lru_usage.load(Relaxed));

        // Looking the key up again yields the fresh version.
        let fresh_again = cache.lookup(0, &key);
        assert!(!fresh_again.is_null());
        assert_eq!((*fresh_again).get_value(), &"new_value".to_owned());
        assert_eq!((*fresh_again).refs, 2);

        // Once every reference is gone the fresh version is accounted in lru_usage.
        cache.release(fresh);
        cache.release(fresh_again);
        assert_eq!(1, cache.usage.load(Relaxed));
        assert_eq!(1, cache.lru_usage.load(Relaxed));
    }
}
1393
#[test]
fn test_write_request_pending() {
    let cache = Arc::new(LruCache::new(1, 5, 0));
    // Seed "a" through the shard-level helpers so the request path sees a plain hit.
    {
        let mut shard = cache.shards[0].lock();
        insert(&mut shard, "a", "v1");
        assert!(lookup(&mut shard, "a"));
    }
    assert!(matches!(
        cache.lookup_for_request(0, "a".to_owned()),
        LookupResult::Cached(_)
    ));

    // The first request for "b" is a miss ...
    assert!(matches!(
        cache.lookup_for_request(0, "b".to_owned()),
        LookupResult::Miss
    ));

    // ... so a second request for the same key must wait on the first one.
    let pending = cache.lookup_for_request(0, "b".to_owned());
    if let LookupResult::WaitPendingRequest(mut waiter) = pending {
        // Nothing has been delivered to the waiter yet.
        assert!(matches!(waiter.try_recv(), Err(TryRecvError::Empty)));
        // Inserting the value delivers it to the waiter and makes it a cache hit.
        cache.insert("b".to_owned(), 0, 1, "v2".to_owned(), CachePriority::Low);
        waiter.try_recv().unwrap();
        assert!(matches!(
            cache.lookup_for_request(0, "b".to_owned()),
            LookupResult::Cached(v) if v.eq("v2")
        ));
    } else {
        panic!();
    }
}
1423
1424 #[derive(Default, Debug)]
1425 struct TestLruCacheEventListener {
1426 released: Arc<Mutex<HashMap<String, String>>>,
1427 }
1428
1429 impl LruCacheEventListener for TestLruCacheEventListener {
1430 type K = String;
1431 type T = String;
1432
1433 fn on_release(&self, key: Self::K, value: Self::T) {
1434 self.released.lock().insert(key, value);
1435 }
1436 }
1437
#[test]
fn test_event_listener() {
    let listener = Arc::new(TestLruCacheEventListener::default());
    let cache = Arc::new(LruCache::with_event_listener(1, 2, 0, listener.clone()));

    // Every entry uses hash 0 and charge 1; only key, value and priority vary.
    let put = |key: &str, value: &str, priority: CachePriority| {
        cache.insert(key.to_owned(), 0, 1, value.to_owned(), priority)
    };
    // Removes `key` from the release log, reporting whether it had been released.
    let take_released = |key: &str| listener.released.lock().remove(key).is_some();
    // True when the release log holds nothing.
    let nothing_released = || listener.released.lock().is_empty();

    // Two unreferenced entries fit: nothing is released.
    drop(put("k1", "v1", CachePriority::High));
    drop(put("k2", "v2", CachePriority::High));
    assert_eq!(cache.get_memory_usage(), 2);
    assert!(nothing_released());

    // A third entry pushes out the oldest one, k1.
    drop(put("k3", "v3", CachePriority::High));
    assert_eq!(cache.get_memory_usage(), 2);
    assert!(take_released("k1"));

    // Explicit erasure is reported as a release as well.
    cache.erase(0, &"k2".to_owned());
    assert_eq!(cache.get_memory_usage(), 1);
    assert!(take_released("k2"));

    drop(put("k4", "v4", CachePriority::Low));
    assert_eq!(cache.get_memory_usage(), 2);
    assert!(nothing_released());

    // Inserting while holding handles releases the unreferenced entries, oldest first.
    let h1 = put("k5", "v5", CachePriority::Low);
    assert_eq!(cache.get_memory_usage(), 2);
    assert!(take_released("k3"));
    let h2 = put("k6", "v6", CachePriority::Low);
    assert_eq!(cache.get_memory_usage(), 2);
    assert!(take_released("k4"));

    // With every entry referenced nothing is released and usage exceeds capacity.
    let h3 = put("k7", "v7", CachePriority::Low);
    assert_eq!(cache.get_memory_usage(), 3);
    assert!(nothing_released());

    // Dropping a handle while over capacity releases that entry right away.
    drop(h1);
    assert_eq!(cache.get_memory_usage(), 2);
    assert!(take_released("k5"));

    // Back at capacity, dropped handles leave their entries cached.
    drop(h2);
    assert_eq!(cache.get_memory_usage(), 2);
    assert!(nothing_released());
    drop(h3);
    assert_eq!(cache.get_memory_usage(), 2);
    assert!(nothing_released());

    // Dropping the cache itself does not report releases to the listener.
    drop(cache);
    assert!(nothing_released());
}
1499
/// Test wrapper around a future plus a shared flag.
///
/// Its `Future` impl forwards polls to `inner` only while `polled` is clear, and resolves to
/// `()` without touching `inner` once the flag has been raised.
pub struct SyncPointFuture<F>
where
    F: Future,
{
    /// The wrapped future.
    inner: F,
    /// Shared flag consulted before each poll.
    polled: Arc<AtomicBool>,
}
1504
1505 impl<F: Future + Unpin> Future for SyncPointFuture<F> {
1506 type Output = ();
1507
1508 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1509 if self.polled.load(Ordering::Acquire) {
1510 return Poll::Ready(());
1511 }
1512 self.inner.poll_unpin(cx).map(|_| ())
1513 }
1514 }
1515
1516 #[tokio::test]
1517 async fn test_future_cancel() {
1518 let cache: Arc<LruCache<u64, u64>> = Arc::new(LruCache::new(1, 5, 0));
1519 let (_, recv) = channel::<()>();
1521 let polled = Arc::new(AtomicBool::new(false));
1522 let cache2 = cache.clone();
1523 let polled2 = polled.clone();
1524 let f = Box::pin(async move {
1525 cache2
1526 .lookup_with_request_dedup(1, 2, CachePriority::High, || async move {
1527 polled2.store(true, Ordering::Release);
1528 recv.await.map(|_| (1, 1))
1529 })
1530 .await
1531 .unwrap();
1532 });
1533 let wrapper = SyncPointFuture {
1534 inner: f,
1535 polled: polled.clone(),
1536 };
1537 {
1538 let handle = tokio::spawn(wrapper);
1539 while !polled.load(Ordering::Acquire) {
1540 tokio::task::yield_now().await;
1541 }
1542 handle.await.unwrap();
1543 }
1544 assert!(cache.shards[0].lock().write_request.is_empty());
1545 }
1546}