1use std::collections::HashMap;
23use std::error::Error;
24use std::future::Future;
25use std::hash::Hash;
26use std::ops::Deref;
27use std::ptr;
28use std::ptr::null_mut;
29use std::sync::Arc;
30use std::sync::atomic::{AtomicUsize, Ordering};
31
32use futures::FutureExt;
33use parking_lot::Mutex;
34use tokio::sync::oneshot::error::RecvError;
35use tokio::sync::oneshot::{Receiver, Sender, channel};
36use tokio::task::JoinHandle;
37
// Bit flags stored in `LruHandle::flags`.

/// Set while the handle is reachable from its shard's hash table.
const IN_CACHE: u8 = 0b0000_0001;
/// Mask that clears `IN_CACHE` when and-ed with the flags.
const REVERSE_IN_CACHE: u8 = !IN_CACHE;

/// Debug-only bookkeeping: set while the handle is linked into the LRU list.
#[cfg(debug_assertions)]
const IN_LRU: u8 = 0b0000_0010;
/// Mask that clears `IN_LRU` when and-ed with the flags.
#[cfg(debug_assertions)]
const REVERSE_IN_LRU: u8 = !IN_LRU;
/// Set when the entry was inserted with `CachePriority::High`.
const IS_HIGH_PRI: u8 = 0b0000_0100;
/// Set while the handle sits in the high-priority section of the LRU list.
const IN_HIGH_PRI_POOL: u8 = 0b0000_1000;
47
/// Requirements on cache keys: comparable, hashable and movable across threads.
pub trait LruKey: Eq + Send + Hash {}
impl<X> LruKey for X where X: Eq + Send + Hash {}

/// Requirements on cache values: movable across threads and shareable by reference.
pub trait LruValue: Send + Sync {}
impl<X> LruValue for X where X: Send + Sync {}
53
/// Priority hint given when inserting an entry.
///
/// When a shard's high-priority pool has non-zero capacity, `High` entries are
/// linked at the end of the LRU list opposite to where eviction starts, behind
/// the low-priority section. Otherwise both priorities are treated the same.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum CachePriority {
    /// Place in the high-priority pool while it has room.
    High,
    /// Regular entry.
    Low,
}
59
60pub struct LruHandle<K: LruKey, T: LruValue> {
87 next_hash: *mut LruHandle<K, T>,
89
90 next: *mut LruHandle<K, T>,
92
93 prev: *mut LruHandle<K, T>,
95
96 kv: Option<(K, T)>,
99 hash: u64,
100 charge: usize,
101
102 refs: u32,
105 flags: u8,
106}
107
108impl<K: LruKey, T: LruValue> Default for LruHandle<K, T> {
109 fn default() -> Self {
110 Self {
111 next_hash: null_mut(),
112 next: null_mut(),
113 prev: null_mut(),
114 kv: None,
115 hash: 0,
116 charge: 0,
117 refs: 0,
118 flags: 0,
119 }
120 }
121}
122
123impl<K: LruKey, T: LruValue> LruHandle<K, T> {
124 pub fn new(key: K, value: T, hash: u64, charge: usize) -> Self {
125 let mut ret = Self::default();
126 ret.init(key, value, hash, charge);
127 ret
128 }
129
130 pub fn init(&mut self, key: K, value: T, hash: u64, charge: usize) {
131 self.next_hash = null_mut();
132 self.prev = null_mut();
133 self.next = null_mut();
134 self.kv = Some((key, value));
135 self.hash = hash;
136 self.charge = charge;
137 self.flags = 0;
138 self.refs = 0;
139 }
140
141 fn set_in_cache(&mut self, in_cache: bool) {
148 if in_cache {
149 self.flags |= IN_CACHE;
150 } else {
151 self.flags &= REVERSE_IN_CACHE;
152 }
153 }
154
155 fn is_high_priority(&self) -> bool {
156 (self.flags & IS_HIGH_PRI) > 0
157 }
158
159 fn set_high_priority(&mut self, high_priority: bool) {
160 if high_priority {
161 self.flags |= IS_HIGH_PRI;
162 } else {
163 self.flags &= !IS_HIGH_PRI;
164 }
165 }
166
167 fn set_in_high_pri_pool(&mut self, in_high_pri_pool: bool) {
168 if in_high_pri_pool {
169 self.flags |= IN_HIGH_PRI_POOL;
170 } else {
171 self.flags &= !IN_HIGH_PRI_POOL;
172 }
173 }
174
175 fn is_in_high_pri_pool(&self) -> bool {
176 (self.flags & IN_HIGH_PRI_POOL) > 0
177 }
178
179 fn add_ref(&mut self) {
180 self.refs += 1;
181 }
182
183 fn add_multi_refs(&mut self, ref_count: u32) {
184 self.refs += ref_count;
185 }
186
187 fn unref(&mut self) -> bool {
188 debug_assert!(self.refs > 0);
189 self.refs -= 1;
190 self.refs == 0
191 }
192
193 fn has_refs(&self) -> bool {
194 self.refs > 0
195 }
196
197 fn is_in_cache(&self) -> bool {
200 (self.flags & IN_CACHE) > 0
201 }
202
203 unsafe fn get_key(&self) -> &K {
204 unsafe {
205 debug_assert!(self.kv.is_some());
206 &self.kv.as_ref().unwrap_unchecked().0
207 }
208 }
209
210 unsafe fn get_value(&self) -> &T {
211 unsafe {
212 debug_assert!(self.kv.is_some());
213 &self.kv.as_ref().unwrap_unchecked().1
214 }
215 }
216
217 unsafe fn is_same_key(&self, key: &K) -> bool {
218 unsafe {
219 debug_assert!(self.kv.is_some());
220 self.kv.as_ref().unwrap_unchecked().0.eq(key)
221 }
222 }
223
224 unsafe fn take_kv(&mut self) -> (K, T) {
225 unsafe {
226 debug_assert!(self.kv.is_some());
227 self.kv.take().unwrap_unchecked()
228 }
229 }
230
231 #[cfg(debug_assertions)]
232 fn is_in_lru(&self) -> bool {
233 (self.flags & IN_LRU) > 0
234 }
235
236 #[cfg(debug_assertions)]
237 fn set_in_lru(&mut self, in_lru: bool) {
238 if in_lru {
239 self.flags |= IN_LRU;
240 } else {
241 self.flags &= REVERSE_IN_LRU;
242 }
243 }
244}
245
246unsafe impl<K: LruKey, T: LruValue> Send for LruHandle<K, T> {}
247
/// Intrusive chained hash table mapping keys to shard handles.
///
/// Buckets are selected with `hash & (list.len() - 1)`, so the bucket count is
/// kept a power of two (it starts at 16). Collisions are chained through
/// `LruHandle::next_hash`; the table never owns or frees the handles.
pub struct LruHandleTable<K: LruKey, T: LruValue> {
    /// Bucket heads; null for an empty bucket.
    list: Vec<*mut LruHandle<K, T>>,
    /// Number of handles currently linked into the table.
    elems: usize,
}

impl<K: LruKey, T: LruValue> LruHandleTable<K, T> {
    /// Creates an empty table with 16 buckets.
    fn new() -> Self {
        Self {
            list: vec![null_mut(); 16],
            elems: 0,
        }
    }

    /// Walks bucket `idx` looking for `key`.
    ///
    /// Returns `(prev, ptr)`. `ptr` is the matching handle (or null if absent),
    /// and `prev` is the chain node just before `ptr`: null when `ptr` is the
    /// bucket head. When the key is absent, `prev` is the last node of the chain.
    ///
    /// # Safety
    /// Every handle chained in the bucket must be live and hold a key/value.
    unsafe fn find_pointer(
        &self,
        idx: usize,
        key: &K,
    ) -> (*mut LruHandle<K, T>, *mut LruHandle<K, T>) {
        unsafe {
            let mut ptr = self.list[idx];
            let mut prev = null_mut();
            while !ptr.is_null() && !(*ptr).is_same_key(key) {
                prev = ptr;
                ptr = (*ptr).next_hash;
            }
            (prev, ptr)
        }
    }

    /// Unlinks the handle stored under `key` and clears its `IN_CACHE` flag.
    ///
    /// Returns the removed handle, or null if `key` is not present. The handle
    /// itself is left alive; freeing it is the caller's job.
    ///
    /// # Safety
    /// All chained handles must be live; `hash` must be the hash stored with `key`.
    unsafe fn remove(&mut self, hash: u64, key: &K) -> *mut LruHandle<K, T> {
        unsafe {
            debug_assert!(self.list.len().is_power_of_two());
            let idx = (hash as usize) & (self.list.len() - 1);
            let (prev, ptr) = self.find_pointer(idx, key);
            if ptr.is_null() {
                return null_mut();
            }
            debug_assert!((*ptr).is_in_cache());
            (*ptr).set_in_cache(false);
            // Splice `ptr` out of the chain (it may be the bucket head).
            if prev.is_null() {
                self.list[idx] = (*ptr).next_hash;
            } else {
                (*prev).next_hash = (*ptr).next_hash;
            }
            self.elems -= 1;
            ptr
        }
    }

    /// Links `h` into the table and sets its `IN_CACHE` flag.
    ///
    /// If a handle with the same key already exists, `h` takes its place in the
    /// chain, and the displaced handle has `IN_CACHE` cleared and is returned.
    /// The displaced handle is not freed here. Otherwise, returns null and grows
    /// the bucket array once the element count exceeds the bucket count.
    ///
    /// # Safety
    /// `h` must be a live handle holding a key/value that is not yet in the table.
    unsafe fn insert(&mut self, hash: u64, h: *mut LruHandle<K, T>) -> *mut LruHandle<K, T> {
        unsafe {
            debug_assert!(!h.is_null());
            debug_assert!(!(*h).is_in_cache());
            (*h).set_in_cache(true);
            debug_assert!(self.list.len().is_power_of_two());
            let idx = (hash as usize) & (self.list.len() - 1);
            let (prev, ptr) = self.find_pointer(idx, (*h).get_key());
            // `h` goes exactly where the match (or the chain end) was found.
            if prev.is_null() {
                self.list[idx] = h;
            } else {
                (*prev).next_hash = h;
            }

            if !ptr.is_null() {
                // Replacing an existing entry: inherit its chain successor; `elems` is unchanged.
                debug_assert!((*ptr).is_same_key((*h).get_key()));
                debug_assert!((*ptr).is_in_cache());
                (*ptr).set_in_cache(false);
                (*h).next_hash = (*ptr).next_hash;
                return ptr;
            }

            // New key: `ptr` is null here, so `h` terminates the chain.
            (*h).next_hash = ptr;

            self.elems += 1;
            if self.elems > self.list.len() {
                self.resize();
            }
            null_mut()
        }
    }

    /// Returns the handle stored under `key`, or null. Does not change any flags.
    ///
    /// # Safety
    /// All chained handles must be live; `hash` must be the hash stored with `key`.
    unsafe fn lookup(&self, hash: u64, key: &K) -> *mut LruHandle<K, T> {
        unsafe {
            debug_assert!(self.list.len().is_power_of_two());
            let idx = (hash as usize) & (self.list.len() - 1);
            let (_, ptr) = self.find_pointer(idx, key);
            ptr
        }
    }

    /// Rehashes every handle into a bucket array whose length is a power of two,
    /// at least the current length (and at least 16), and at least
    /// `elems * 3 / 2`. The table never shrinks.
    ///
    /// # Safety
    /// All chained handles must be live.
    unsafe fn resize(&mut self) {
        unsafe {
            let mut l = std::cmp::max(16, self.list.len());
            let next_capacity = self.elems * 3 / 2;
            while l < next_capacity {
                l <<= 1;
            }
            let mut count = 0;
            let mut new_list = vec![null_mut(); l];
            for head in self.list.drain(..) {
                let mut handle = head;
                while !handle.is_null() {
                    let idx = (*handle).hash as usize & (l - 1);
                    // Read the successor before relinking `handle` at the head of its new bucket.
                    let next = (*handle).next_hash;
                    (*handle).next_hash = new_list[idx];
                    new_list[idx] = handle;
                    handle = next;
                    count += 1;
                }
            }
            // Sanity check that no handle was lost or duplicated during the rehash.
            assert_eq!(count, self.elems);
            self.list = new_list;
        }
    }

    /// Calls `f` with every key/value currently linked into the table.
    ///
    /// # Safety
    /// All chained handles must be live and hold a key/value.
    unsafe fn for_all<F>(&self, f: &mut F)
    where
        F: FnMut(&K, &T),
    {
        unsafe {
            for idx in 0..self.list.len() {
                let mut ptr = self.list[idx];
                while !ptr.is_null() {
                    f((*ptr).get_key(), (*ptr).get_value());
                    ptr = (*ptr).next_hash;
                }
            }
        }
    }
}
382
383type RequestQueue<K, T> = Vec<Sender<CacheableEntry<K, T>>>;
384pub struct LruCacheShard<K: LruKey, T: LruValue> {
385 lru: Box<LruHandle<K, T>>,
390 low_priority_head: *mut LruHandle<K, T>,
391 high_priority_pool_capacity: usize,
392 high_priority_pool_usage: usize,
393 table: LruHandleTable<K, T>,
394 object_pool: Vec<Box<LruHandle<K, T>>>,
396 write_request: HashMap<K, RequestQueue<K, T>>,
397 lru_usage: Arc<AtomicUsize>,
398 usage: Arc<AtomicUsize>,
399 capacity: usize,
400}
401
402unsafe impl<K: LruKey, T: LruValue> Send for LruCacheShard<K, T> {}
403
impl<K: LruKey, T: LruValue> LruCacheShard<K, T> {
    /// Creates an empty shard with a charge budget of `capacity`.
    ///
    /// `high_priority_ratio_percent` is an integer percentage of `capacity` that
    /// bounds the high-priority pool; if it works out to 0, the pool is disabled.
    /// `DEFAULT_OBJECT_POOL_SIZE` empty handles are pre-allocated for reuse.
    fn new_with_priority_pool(capacity: usize, high_priority_ratio_percent: usize) -> Self {
        let mut lru = Box::<LruHandle<K, T>>::default();
        // An empty circular list is the sentinel pointing at itself.
        lru.prev = lru.as_mut();
        lru.next = lru.as_mut();
        let mut object_pool = Vec::with_capacity(DEFAULT_OBJECT_POOL_SIZE);
        for _ in 0..DEFAULT_OBJECT_POOL_SIZE {
            object_pool.push(Box::default());
        }
        Self {
            capacity,
            lru_usage: Arc::new(AtomicUsize::new(0)),
            usage: Arc::new(AtomicUsize::new(0)),
            object_pool,
            // The low-priority insertion point starts at the sentinel; the Box keeps its address stable.
            low_priority_head: lru.as_mut(),
            // NOTE(review): `high_priority_ratio_percent * capacity` can overflow for very
            // large capacities (e.g. near `usize::MAX`); confirm callers never pass such values.
            high_priority_pool_capacity: high_priority_ratio_percent * capacity / 100,
            lru,
            table: LruHandleTable::new(),
            write_request: HashMap::with_capacity(16),
            high_priority_pool_usage: 0,
        }
    }

    /// Unlinks `e` from the LRU list.
    ///
    /// Subtracts its charge from `lru_usage`, and from the high-priority pool
    /// usage when it was in that pool.
    ///
    /// # Safety
    /// `e` must be a live handle currently linked into this shard's LRU list.
    unsafe fn lru_remove(&mut self, e: *mut LruHandle<K, T>) {
        unsafe {
            debug_assert!(!e.is_null());
            #[cfg(debug_assertions)]
            {
                assert!((*e).is_in_lru());
                (*e).set_in_lru(false);
            }

            // Keep the low-priority insertion point valid if that node is being removed.
            if ptr::eq(e, self.low_priority_head) {
                self.low_priority_head = (*e).prev;
            }

            (*(*e).next).prev = (*e).prev;
            (*(*e).prev).next = (*e).next;
            (*e).prev = null_mut();
            (*e).next = null_mut();
            if (*e).is_in_high_pri_pool() {
                debug_assert!(self.high_priority_pool_usage >= (*e).charge);
                self.high_priority_pool_usage -= (*e).charge;
            }
            self.lru_usage.fetch_sub((*e).charge, Ordering::Relaxed);
        }
    }

    /// Links `e` into the LRU list and adds its charge to `lru_usage`.
    ///
    /// High-priority handles (when the pool has non-zero capacity) go to the tail,
    /// just before the sentinel. They may push the oldest pool members into the
    /// low-priority section via `maintain_pool_size`. All other handles are linked
    /// right after `low_priority_head`, which then advances to `e`. Eviction starts
    /// from `lru.next`, the opposite end.
    ///
    /// # Safety
    /// `e` must be a live handle of this shard that is not already on the LRU list.
    unsafe fn lru_insert(&mut self, e: *mut LruHandle<K, T>) {
        unsafe {
            debug_assert!(!e.is_null());
            let entry = &mut (*e);
            #[cfg(debug_assertions)]
            {
                assert!(!(*e).is_in_lru());
                (*e).set_in_lru(true);
            }

            if self.high_priority_pool_capacity > 0 && entry.is_high_priority() {
                entry.set_in_high_pri_pool(true);
                entry.next = self.lru.as_mut();
                entry.prev = self.lru.prev;
                (*entry.prev).next = e;
                (*entry.next).prev = e;
                self.high_priority_pool_usage += (*e).charge;
                self.maintain_pool_size();
            } else {
                entry.set_in_high_pri_pool(false);
                entry.next = (*self.low_priority_head).next;
                entry.prev = self.low_priority_head;
                (*entry.next).prev = e;
                (*entry.prev).next = e;
                self.low_priority_head = e;
            }
            self.lru_usage.fetch_add((*e).charge, Ordering::Relaxed);
        }
    }

    /// Demotes the oldest high-priority pool members into the low-priority section
    /// until the pool's usage fits its capacity.
    ///
    /// Demotion works by advancing `low_priority_head` over them.
    ///
    /// # Safety
    /// The LRU list must be well formed.
    unsafe fn maintain_pool_size(&mut self) {
        unsafe {
            while self.high_priority_pool_usage > self.high_priority_pool_capacity {
                self.low_priority_head = (*self.low_priority_head).next;
                (*self.low_priority_head).set_in_high_pri_pool(false);
                self.high_priority_pool_usage -= (*self.low_priority_head).charge;
            }
        }
    }

    /// Evicts entries from the cold end of the LRU list (`lru.next`) while
    /// `usage + charge > capacity` and the list is non-empty.
    ///
    /// Evicted key/value pairs are appended to `last_reference_list`. Referenced
    /// entries are not on the LRU list, so they are never evicted here; `usage` may
    /// therefore stay above `capacity`.
    unsafe fn evict_from_lru(&mut self, charge: usize, last_reference_list: &mut Vec<(K, T)>) {
        unsafe {
            while self.usage.load(Ordering::Relaxed) + charge > self.capacity
                && !std::ptr::eq(self.lru.next, self.lru.as_mut())
            {
                let old_ptr = self.lru.next;
                // LRU members are in the table, so the key lookup finds `old_ptr` itself.
                self.table.remove((*old_ptr).hash, (*old_ptr).get_key());
                self.lru_remove(old_ptr);
                let (key, value) = self.clear_handle(old_ptr);
                last_reference_list.push((key, value));
            }
        }
    }

    /// Frees a fully detached handle.
    ///
    /// Subtracts its charge from `usage`, moves its key/value out, and hands the
    /// allocation to the object pool.
    ///
    /// # Safety
    /// `h` must be live, hold a key/value, be in neither the table nor the LRU list,
    /// and have no references. `h` must not be used after this call.
    unsafe fn clear_handle(&mut self, h: *mut LruHandle<K, T>) -> (K, T) {
        unsafe {
            debug_assert!(!h.is_null());
            debug_assert!((*h).kv.is_some());
            #[cfg(debug_assertions)]
            assert!(!(*h).is_in_lru());
            debug_assert!(!(*h).is_in_cache());
            debug_assert!(!(*h).has_refs());
            self.usage.fetch_sub((*h).charge, Ordering::Relaxed);
            let (key, value) = (*h).take_kv();
            self.try_recycle_handle_object(h);
            (key, value)
        }
    }

    /// Takes back ownership of `h`'s allocation.
    ///
    /// The allocation is kept in `object_pool` (with links reset) if the pool has
    /// spare capacity; otherwise it is freed.
    ///
    /// # Safety
    /// `h` must come from `Box::into_raw`, have an empty `kv`, and not be used afterwards.
    unsafe fn try_recycle_handle_object(&mut self, h: *mut LruHandle<K, T>) {
        unsafe {
            let mut node = Box::from_raw(h);
            if self.object_pool.len() < self.object_pool.capacity() {
                node.next_hash = null_mut();
                node.next = null_mut();
                node.prev = null_mut();
                debug_assert!(node.kv.is_none());
                self.object_pool.push(node);
            }
        }
    }

    /// Inserts `key`/`value` and returns its handle with one reference already held
    /// for the caller.
    ///
    /// Unreferenced entries are evicted first to make room for `charge`; their data
    /// goes to `last_reference_list`. If `key` was already present, the old handle
    /// is detached from the table. It is freed now (data pushed to
    /// `last_reference_list`) when unreferenced, otherwise on its last release.
    /// The new entry is always admitted, even if `usage` then exceeds `capacity`,
    /// and it stays off the LRU list until its references are released.
    unsafe fn insert(
        &mut self,
        key: K,
        hash: u64,
        charge: usize,
        value: T,
        priority: CachePriority,
        last_reference_list: &mut Vec<(K, T)>,
    ) -> *mut LruHandle<K, T> {
        unsafe {
            self.evict_from_lru(charge, last_reference_list);

            // Reuse a pooled allocation when available.
            let mut handle = match self.object_pool.pop() {
                Some(mut h) => {
                    h.init(key, value, hash, charge);
                    h
                }
                _ => Box::new(LruHandle::new(key, value, hash, charge)),
            };
            if priority == CachePriority::High {
                handle.set_high_priority(true);
            }
            let ptr = Box::into_raw(handle);
            let old = self.table.insert(hash, ptr);
            if !old.is_null() {
                if let Some(data) = self.try_remove_cache_handle(old) {
                    last_reference_list.push(data);
                }
            }
            self.usage.fetch_add(charge, Ordering::Relaxed);
            (*ptr).add_ref();
            ptr
        }
    }

    /// Gives up one reference on `h`.
    ///
    /// Behaviour depends on what remains:
    /// * Other references remain: returns `None`.
    /// * This was the last reference, the handle is still in the cache, and usage
    ///   is within capacity: the handle is parked on the LRU list and `None` is
    ///   returned.
    /// * Otherwise: the handle is dropped from the table (if still there), freed,
    ///   and its key/value are returned.
    ///
    /// # Safety
    /// `h` must be a live handle of this shard on which the caller holds a reference.
    unsafe fn release(&mut self, h: *mut LruHandle<K, T>) -> Option<(K, T)> {
        unsafe {
            debug_assert!(!h.is_null());
            // A referenced handle must never be on the LRU list.
            #[cfg(debug_assertions)]
            assert!(!(*h).is_in_lru());
            let last_reference = (*h).unref();
            if !last_reference {
                return None;
            }

            if (*h).is_in_cache() {
                if self.usage.load(Ordering::Relaxed) <= self.capacity {
                    self.lru_insert(h);
                    return None;
                }
                // Over capacity: drop the entry instead of keeping it around unreferenced.
                self.table.remove((*h).hash, (*h).get_key());
            }

            #[cfg(debug_assertions)]
            assert!(!(*h).is_in_lru());

            let (key, value) = self.clear_handle(h);
            Some((key, value))
        }
    }

    /// Looks `key` up and takes a reference on the hit (null on a miss).
    ///
    /// A handle that had no references is unlinked from the LRU list first.
    unsafe fn lookup(&mut self, hash: u64, key: &K) -> *mut LruHandle<K, T> {
        unsafe {
            let e = self.table.lookup(hash, key);
            if !e.is_null() {
                if !(*e).has_refs() {
                    self.lru_remove(e);
                }
                (*e).add_ref();
            }
            e
        }
    }

    /// Removes `key` from the table.
    ///
    /// Returns its key/value if the handle could be freed immediately (no
    /// references). A still-referenced handle is freed later, by its last release.
    unsafe fn erase(&mut self, hash: u64, key: &K) -> Option<(K, T)> {
        unsafe {
            let h = self.table.remove(hash, key);
            if !h.is_null() {
                self.try_remove_cache_handle(h)
            } else {
                None
            }
        }
    }

    /// Frees a handle that has already been detached from the table, if it has no
    /// references (it is then on the LRU list and is unlinked first).
    ///
    /// Returns its key/value on success.
    unsafe fn try_remove_cache_handle(&mut self, h: *mut LruHandle<K, T>) -> Option<(K, T)> {
        unsafe {
            debug_assert!(!h.is_null());
            if !(*h).has_refs() {
                self.lru_remove(h);
                let (key, value) = self.clear_handle(h);
                return Some((key, value));
            }
            None
        }
    }

    /// Erases every entry on the LRU list, i.e. every unreferenced entry.
    ///
    /// The removed key/value pairs are dropped here. Referenced entries are left
    /// untouched.
    unsafe fn clear(&mut self) {
        unsafe {
            while !std::ptr::eq(self.lru.next, self.lru.as_mut()) {
                let handle = self.lru.next;
                self.erase((*handle).hash, (*handle).get_key());
            }
        }
    }

    /// Calls `f` with every key/value currently in the table.
    fn for_all<F>(&self, f: &mut F)
    where
        F: FnMut(&K, &T),
    {
        unsafe { self.table.for_all(f) };
    }
}
678
679impl<K: LruKey, T: LruValue> Drop for LruCacheShard<K, T> {
680 fn drop(&mut self) {
681 unsafe {
684 self.clear();
685 }
686 }
687}
688
689pub trait LruCacheEventListener: Send + Sync {
690 type K: LruKey;
691 type T: LruValue;
692
693 fn on_release(&self, _key: Self::K, _value: Self::T) {}
698}
699
700pub struct LruCache<K: LruKey, T: LruValue> {
701 shards: Vec<Mutex<LruCacheShard<K, T>>>,
702 shard_usages: Vec<Arc<AtomicUsize>>,
703 shard_lru_usages: Vec<Arc<AtomicUsize>>,
704
705 listener: Option<Arc<dyn LruCacheEventListener<K = K, T = T>>>,
706}
707
708const DEFAULT_OBJECT_POOL_SIZE: usize = 1024;
711
712impl<K: LruKey, T: LruValue> LruCache<K, T> {
713 pub fn new(num_shards: usize, capacity: usize, high_priority_ratio: usize) -> Self {
714 Self::new_inner(num_shards, capacity, high_priority_ratio, None)
715 }
716
717 pub fn with_event_listener(
718 num_shards: usize,
719 capacity: usize,
720 high_priority_ratio: usize,
721 listener: Arc<dyn LruCacheEventListener<K = K, T = T>>,
722 ) -> Self {
723 Self::new_inner(num_shards, capacity, high_priority_ratio, Some(listener))
724 }
725
726 fn new_inner(
727 num_shards: usize,
728 capacity: usize,
729 high_priority_ratio: usize,
730 listener: Option<Arc<dyn LruCacheEventListener<K = K, T = T>>>,
731 ) -> Self {
732 let mut shards = Vec::with_capacity(num_shards);
733 let per_shard = capacity / num_shards;
734 let mut shard_usages = Vec::with_capacity(num_shards);
735 let mut shard_lru_usages = Vec::with_capacity(num_shards);
736 for _ in 0..num_shards {
737 let shard = LruCacheShard::new_with_priority_pool(per_shard, high_priority_ratio);
738 shard_usages.push(shard.usage.clone());
739 shard_lru_usages.push(shard.lru_usage.clone());
740 shards.push(Mutex::new(shard));
741 }
742 Self {
743 shards,
744 shard_usages,
745 shard_lru_usages,
746 listener,
747 }
748 }
749
750 pub fn contains(self: &Arc<Self>, hash: u64, key: &K) -> bool {
751 let shard = self.shards[self.shard(hash)].lock();
752 unsafe {
753 let ptr = shard.table.lookup(hash, key);
754 !ptr.is_null()
755 }
756 }
757
758 pub fn lookup(self: &Arc<Self>, hash: u64, key: &K) -> Option<CacheableEntry<K, T>> {
759 let mut shard = self.shards[self.shard(hash)].lock();
760 unsafe {
761 let ptr = shard.lookup(hash, key);
762 if ptr.is_null() {
763 return None;
764 }
765 let entry = CacheableEntry {
766 cache: self.clone(),
767 handle: ptr,
768 };
769 Some(entry)
770 }
771 }
772
773 pub fn lookup_for_request(self: &Arc<Self>, hash: u64, key: K) -> LookupResult<K, T> {
774 let mut shard = self.shards[self.shard(hash)].lock();
775 unsafe {
776 let ptr = shard.lookup(hash, &key);
777 if !ptr.is_null() {
778 return LookupResult::Cached(CacheableEntry {
779 cache: self.clone(),
780 handle: ptr,
781 });
782 }
783 if let Some(que) = shard.write_request.get_mut(&key) {
784 let (tx, recv) = channel();
785 que.push(tx);
786 return LookupResult::WaitPendingRequest(recv);
787 }
788 shard.write_request.insert(key, vec![]);
789 LookupResult::Miss
790 }
791 }
792
793 unsafe fn release(&self, handle: *mut LruHandle<K, T>) {
794 unsafe {
795 debug_assert!(!handle.is_null());
796 let data = {
797 let mut shard = self.shards[self.shard((*handle).hash)].lock();
798 shard.release(handle)
799 };
800 if let Some((key, value)) = data
802 && let Some(listener) = &self.listener
803 {
804 listener.on_release(key, value);
805 }
806 }
807 }
808
809 unsafe fn inc_reference(&self, handle: *mut LruHandle<K, T>) {
810 unsafe {
811 let _shard = self.shards[self.shard((*handle).hash)].lock();
812 (*handle).refs += 1;
813 }
814 }
815
816 pub fn insert(
817 self: &Arc<Self>,
818 key: K,
819 hash: u64,
820 charge: usize,
821 value: T,
822 priority: CachePriority,
823 ) -> CacheableEntry<K, T> {
824 let mut to_delete = vec![];
825 let mut senders = vec![];
827 let handle = unsafe {
828 let mut shard = self.shards[self.shard(hash)].lock();
829 let pending_request = shard.write_request.remove(&key);
830 let ptr = shard.insert(key, hash, charge, value, priority, &mut to_delete);
831 debug_assert!(!ptr.is_null());
832 if let Some(mut que) = pending_request {
833 (*ptr).add_multi_refs(que.len() as u32);
834 senders = std::mem::take(&mut que);
835 }
836 CacheableEntry {
837 cache: self.clone(),
838 handle: ptr,
839 }
840 };
841 for sender in senders {
842 let _ = sender.send(CacheableEntry {
843 cache: self.clone(),
844 handle: handle.handle,
845 });
846 }
847
848 if let Some(listener) = &self.listener {
850 for (key, value) in to_delete {
851 listener.on_release(key, value);
852 }
853 }
854 handle
855 }
856
857 pub fn clear_pending_request(&self, key: &K, hash: u64) {
858 let mut shard = self.shards[self.shard(hash)].lock();
859 shard.write_request.remove(key);
860 }
861
862 pub fn erase(&self, hash: u64, key: &K) {
863 let data = unsafe {
864 let mut shard = self.shards[self.shard(hash)].lock();
865 shard.erase(hash, key)
866 };
867 if let Some((key, value)) = data
869 && let Some(listener) = &self.listener
870 {
871 listener.on_release(key, value);
872 }
873 }
874
875 pub fn get_memory_usage(&self) -> usize {
876 self.shard_usages
877 .iter()
878 .map(|x| x.load(Ordering::Relaxed))
879 .sum()
880 }
881
882 pub fn get_lru_usage(&self) -> usize {
883 self.shard_lru_usages
884 .iter()
885 .map(|x| x.load(Ordering::Relaxed))
886 .sum()
887 }
888
889 fn shard(&self, hash: u64) -> usize {
890 hash as usize % self.shards.len()
891 }
892
893 pub fn for_all<F>(&self, mut f: F)
901 where
902 F: FnMut(&K, &T),
903 {
904 for shard in &self.shards {
905 let shard = shard.lock();
906 shard.for_all(&mut f);
907 }
908 }
909
910 pub fn clear(&self) {
914 for shard in &self.shards {
915 unsafe {
916 let mut shard = shard.lock();
917 shard.clear();
918 }
919 }
920 }
921}
922
923pub struct CleanCacheGuard<'a, K: LruKey + Clone + 'static, T: LruValue + 'static> {
924 cache: &'a Arc<LruCache<K, T>>,
925 key: Option<K>,
926 hash: u64,
927}
928
929impl<K: LruKey + Clone + 'static, T: LruValue + 'static> CleanCacheGuard<'_, K, T> {
930 fn mark_success(mut self) -> K {
931 self.key.take().unwrap()
932 }
933}
934
935impl<K: LruKey + Clone + 'static, T: LruValue + 'static> Drop for CleanCacheGuard<'_, K, T> {
936 fn drop(&mut self) {
937 if let Some(key) = self.key.as_ref() {
938 self.cache.clear_pending_request(key, self.hash);
939 }
940 }
941}
942
943pub enum LookupResponse<K: LruKey + Clone + 'static, T: LruValue + 'static, E> {
949 Invalid,
950 Cached(CacheableEntry<K, T>),
951 WaitPendingRequest(Receiver<CacheableEntry<K, T>>),
952 Miss(JoinHandle<Result<CacheableEntry<K, T>, E>>),
953}
954
955impl<K: LruKey + Clone + 'static, T: LruValue + 'static, E> Default for LookupResponse<K, T, E> {
956 fn default() -> Self {
957 Self::Invalid
958 }
959}
960
961impl<K: LruKey + Clone + 'static, T: LruValue + 'static, E: From<RecvError>> Future
962 for LookupResponse<K, T, E>
963{
964 type Output = Result<CacheableEntry<K, T>, E>;
965
966 fn poll(
967 mut self: std::pin::Pin<&mut Self>,
968 cx: &mut std::task::Context<'_>,
969 ) -> std::task::Poll<Self::Output> {
970 match &mut *self {
971 Self::Invalid => unreachable!(),
972 Self::Cached(_) => std::task::Poll::Ready(Ok(
973 must_match!(std::mem::take(&mut *self), Self::Cached(entry) => entry),
974 )),
975 Self::WaitPendingRequest(receiver) => {
976 receiver.poll_unpin(cx).map_err(|recv_err| recv_err.into())
977 }
978 Self::Miss(join_handle) => join_handle
979 .poll_unpin(cx)
980 .map(|join_result| join_result.unwrap()),
981 }
982 }
983}
984
985impl<K: LruKey + Clone + 'static, T: LruValue + 'static> LruCache<K, T> {
988 pub fn lookup_with_request_dedup<F, E, VC>(
989 self: &Arc<Self>,
990 hash: u64,
991 key: K,
992 priority: CachePriority,
993 fetch_value: F,
994 ) -> LookupResponse<K, T, E>
995 where
996 F: FnOnce() -> VC,
997 E: Error + Send + 'static + From<RecvError>,
998 VC: Future<Output = Result<(T, usize), E>> + Send + 'static,
999 {
1000 match self.lookup_for_request(hash, key.clone()) {
1001 LookupResult::Cached(entry) => LookupResponse::Cached(entry),
1002 LookupResult::WaitPendingRequest(receiver) => {
1003 LookupResponse::WaitPendingRequest(receiver)
1004 }
1005 LookupResult::Miss => {
1006 let this = self.clone();
1007 let fetch_value = fetch_value();
1008 let key2 = key;
1009 let join_handle = tokio::spawn(async move {
1010 let guard = CleanCacheGuard {
1011 cache: &this,
1012 key: Some(key2),
1013 hash,
1014 };
1015 let (value, charge) = fetch_value.await?;
1016 let key2 = guard.mark_success();
1017 let entry = this.insert(key2, hash, charge, value, priority);
1018 Ok(entry)
1019 });
1020 LookupResponse::Miss(join_handle)
1021 }
1022 }
1023 }
1024}
1025
1026pub struct CacheableEntry<K: LruKey, T: LruValue> {
1027 cache: Arc<LruCache<K, T>>,
1028 handle: *mut LruHandle<K, T>,
1029}
1030
1031pub enum LookupResult<K: LruKey, T: LruValue> {
1032 Cached(CacheableEntry<K, T>),
1033 Miss,
1034 WaitPendingRequest(Receiver<CacheableEntry<K, T>>),
1035}
1036
1037unsafe impl<K: LruKey, T: LruValue> Send for CacheableEntry<K, T> {}
1038unsafe impl<K: LruKey, T: LruValue> Sync for CacheableEntry<K, T> {}
1039
1040impl<K: LruKey, T: LruValue> Deref for CacheableEntry<K, T> {
1041 type Target = T;
1042
1043 fn deref(&self) -> &Self::Target {
1044 unsafe { (*self.handle).get_value() }
1045 }
1046}
1047
1048impl<K: LruKey, T: LruValue> Drop for CacheableEntry<K, T> {
1049 fn drop(&mut self) {
1050 unsafe {
1051 self.cache.release(self.handle);
1052 }
1053 }
1054}
1055
1056impl<K: LruKey, T: LruValue> Clone for CacheableEntry<K, T> {
1057 fn clone(&self) -> Self {
1058 unsafe {
1059 self.cache.inc_reference(self.handle);
1060 CacheableEntry {
1061 cache: self.cache.clone(),
1062 handle: self.handle,
1063 }
1064 }
1065 }
1066}
1067
1068#[cfg(test)]
1069mod tests {
1070 use std::collections::hash_map::DefaultHasher;
1071 use std::hash::Hasher;
1072 use std::pin::Pin;
1073 use std::sync::atomic::AtomicBool;
1074 use std::sync::atomic::Ordering::Relaxed;
1075 use std::task::{Context, Poll};
1076
1077 use rand::rngs::SmallRng;
1078 use rand::{RngCore, SeedableRng};
1079 use tokio::sync::oneshot::error::TryRecvError;
1080
1081 use super::*;
1082
1083 pub struct Block {
1084 pub offset: u64,
1085 #[allow(dead_code)]
1086 pub sst: u64,
1087 }
1088
1089 #[test]
1090 fn test_cache_handle_basic() {
1091 let mut h = Box::new(LruHandle::new(1, 2, 0, 0));
1092 h.set_in_cache(true);
1093 assert!(h.is_in_cache());
1094 h.set_in_cache(false);
1095 assert!(!h.is_in_cache());
1096 }
1097
1098 #[test]
1099 fn test_cache_shard() {
1100 let cache = Arc::new(LruCache::<(u64, u64), Block>::new(4, 256, 0));
1101 assert_eq!(cache.shard(0), 0);
1102 assert_eq!(cache.shard(1), 1);
1103 assert_eq!(cache.shard(10), 2);
1104 }
1105
1106 #[test]
1107 fn test_cache_basic() {
1108 let cache = Arc::new(LruCache::<(u64, u64), Block>::new(2, 256, 0));
1109 let seed = 10244021u64;
1110 let mut rng = SmallRng::seed_from_u64(seed);
1111 for _ in 0..100000 {
1112 let block_offset = rng.next_u64() % 1024;
1113 let sst = rng.next_u64() % 1024;
1114 let mut hasher = DefaultHasher::new();
1115 sst.hash(&mut hasher);
1116 block_offset.hash(&mut hasher);
1117 let h = hasher.finish();
1118 if let Some(block) = cache.lookup(h, &(sst, block_offset)) {
1119 assert_eq!(block.offset, block_offset);
1120 drop(block);
1121 continue;
1122 }
1123 cache.insert(
1124 (sst, block_offset),
1125 h,
1126 1,
1127 Block {
1128 offset: block_offset,
1129 sst,
1130 },
1131 CachePriority::High,
1132 );
1133 }
1134 assert_eq!(256, cache.get_memory_usage());
1135 }
1136
1137 fn validate_lru_list(cache: &mut LruCacheShard<String, String>, keys: Vec<&str>) {
1138 unsafe {
1139 let mut lru: *mut LruHandle<String, String> = cache.lru.as_mut();
1140 for k in keys {
1141 lru = (*lru).next;
1142 assert!(
1143 (*lru).is_same_key(&k.to_owned()),
1144 "compare failed: {} vs {}, get value: {:?}",
1145 (*lru).get_key(),
1146 k,
1147 (*lru).get_value()
1148 );
1149 }
1150 }
1151 }
1152
1153 fn create_cache(capacity: usize) -> LruCacheShard<String, String> {
1154 LruCacheShard::new_with_priority_pool(capacity, 0)
1155 }
1156
1157 fn lookup(cache: &mut LruCacheShard<String, String>, key: &str) -> bool {
1158 unsafe {
1159 let h = cache.lookup(0, &key.to_owned());
1160 let exist = !h.is_null();
1161 if exist {
1162 assert!((*h).is_same_key(&key.to_owned()));
1163 cache.release(h);
1164 }
1165 exist
1166 }
1167 }
1168
1169 fn insert_priority(
1170 cache: &mut LruCacheShard<String, String>,
1171 key: &str,
1172 value: &str,
1173 priority: CachePriority,
1174 ) {
1175 let mut free_list = vec![];
1176 unsafe {
1177 let handle = cache.insert(
1178 key.to_owned(),
1179 0,
1180 value.len(),
1181 value.to_owned(),
1182 priority,
1183 &mut free_list,
1184 );
1185 cache.release(handle);
1186 }
1187 free_list.clear();
1188 }
1189
1190 fn insert(cache: &mut LruCacheShard<String, String>, key: &str, value: &str) {
1191 insert_priority(cache, key, value, CachePriority::Low);
1192 }
1193
1194 #[test]
1195 fn test_basic_lru() {
1196 let mut cache = LruCacheShard::new_with_priority_pool(5, 40);
1197 let keys = vec!["a", "b", "c", "d", "e"];
1198 for &k in &keys {
1199 insert(&mut cache, k, k);
1200 }
1201 validate_lru_list(&mut cache, keys);
1202 for k in ["x", "y", "z"] {
1203 insert(&mut cache, k, k);
1204 }
1205 validate_lru_list(&mut cache, vec!["d", "e", "x", "y", "z"]);
1206 assert!(!lookup(&mut cache, "b"));
1207 assert!(lookup(&mut cache, "e"));
1208 validate_lru_list(&mut cache, vec!["d", "x", "y", "z", "e"]);
1209 assert!(lookup(&mut cache, "z"));
1210 validate_lru_list(&mut cache, vec!["d", "x", "y", "e", "z"]);
1211 unsafe {
1212 let h = cache.erase(0, &"x".to_owned());
1213 assert!(h.is_some());
1214 validate_lru_list(&mut cache, vec!["d", "y", "e", "z"]);
1215 }
1216 assert!(lookup(&mut cache, "d"));
1217 validate_lru_list(&mut cache, vec!["y", "e", "z", "d"]);
1218 insert(&mut cache, "u", "u");
1219 validate_lru_list(&mut cache, vec!["y", "e", "z", "d", "u"]);
1220 insert(&mut cache, "v", "v");
1221 validate_lru_list(&mut cache, vec!["e", "z", "d", "u", "v"]);
1222 insert_priority(&mut cache, "x", "x", CachePriority::High);
1223 validate_lru_list(&mut cache, vec!["z", "d", "u", "v", "x"]);
1224 assert!(lookup(&mut cache, "d"));
1225 validate_lru_list(&mut cache, vec!["z", "u", "v", "d", "x"]);
1226 insert(&mut cache, "y", "y");
1227 validate_lru_list(&mut cache, vec!["u", "v", "d", "y", "x"]);
1228 insert_priority(&mut cache, "z", "z", CachePriority::High);
1229 validate_lru_list(&mut cache, vec!["v", "d", "y", "x", "z"]);
1230 insert(&mut cache, "u", "u");
1231 validate_lru_list(&mut cache, vec!["d", "y", "u", "x", "z"]);
1232 insert_priority(&mut cache, "v", "v", CachePriority::High);
1233 validate_lru_list(&mut cache, vec!["y", "u", "x", "z", "v"]);
1234 }
1235
1236 #[test]
1237 fn test_reference_and_usage() {
1238 let mut cache = create_cache(5);
1239 insert(&mut cache, "k1", "a");
1240 assert_eq!(cache.usage.load(Ordering::Relaxed), 1);
1241 insert(&mut cache, "k0", "aa");
1242 assert_eq!(cache.usage.load(Ordering::Relaxed), 3);
1243 insert(&mut cache, "k1", "aa");
1244 assert_eq!(cache.usage.load(Ordering::Relaxed), 4);
1245 insert(&mut cache, "k2", "aa");
1246 assert_eq!(cache.usage.load(Ordering::Relaxed), 4);
1247 let mut free_list = vec![];
1248 validate_lru_list(&mut cache, vec!["k1", "k2"]);
1249 unsafe {
1250 let h1 = cache.lookup(0, &"k1".to_owned());
1251 assert!(!h1.is_null());
1252 let h2 = cache.lookup(0, &"k2".to_owned());
1253 assert!(!h2.is_null());
1254
1255 let h3 = cache.insert(
1256 "k3".to_owned(),
1257 0,
1258 2,
1259 "bb".to_owned(),
1260 CachePriority::High,
1261 &mut free_list,
1262 );
1263 assert_eq!(cache.usage.load(Ordering::Relaxed), 6);
1264 assert!(!h3.is_null());
1265 let h4 = cache.lookup(0, &"k1".to_owned());
1266 assert!(!h4.is_null());
1267
1268 cache.release(h1);
1269 assert_eq!(cache.usage.load(Ordering::Relaxed), 6);
1270 cache.release(h4);
1271 assert_eq!(cache.usage.load(Ordering::Relaxed), 4);
1272
1273 cache.release(h3);
1274 cache.release(h2);
1275
1276 validate_lru_list(&mut cache, vec!["k3", "k2"]);
1277 }
1278 }
1279
    /// Re-inserting a key while the previous entry for it is still referenced detaches the old
    /// entry from the cache. The old handle stays readable (value and refcount intact) until its
    /// last reference is released, and releasing it does not disturb the new entry.
    #[test]
    fn test_update_referenced_key() {
        unsafe {
            // Passed to every `insert`; presumably collects displaced entries — never inspected
            // by this test.
            let mut to_delete = vec![];
            // NOTE(review): argument presumably is the shard capacity — confirm in `create_cache`.
            let mut cache = create_cache(5);
            let insert_handle = cache.insert(
                "key".to_owned(),
                0,
                1,
                "old_value".to_owned(),
                CachePriority::High,
                &mut to_delete,
            );
            let old_entry = cache.lookup(0, &"key".to_owned());
            assert!(!old_entry.is_null());
            assert_eq!((*old_entry).get_value(), &"old_value".to_owned());
            // One reference from the insert handle plus one from the lookup.
            assert_eq!((*old_entry).refs, 2);
            cache.release(insert_handle);
            assert_eq!((*old_entry).refs, 1);
            // Replace the value while `old_entry` is still held.
            let insert_handle = cache.insert(
                "key".to_owned(),
                0,
                1,
                "new_value".to_owned(),
                CachePriority::Low,
                &mut to_delete,
            );
            // The old entry is no longer reachable through the cache ...
            assert!(!(*old_entry).is_in_cache());
            let new_entry = cache.lookup(0, &"key".to_owned());
            assert!(!new_entry.is_null());
            assert_eq!((*new_entry).get_value(), &"new_value".to_owned());
            assert_eq!((*new_entry).refs, 2);
            cache.release(insert_handle);
            assert_eq!((*new_entry).refs, 1);

            // ... but the outstanding reference keeps its value and refcount intact.
            assert!(!old_entry.is_null());
            assert_eq!((*old_entry).get_value(), &"old_value".to_owned());
            assert_eq!((*old_entry).refs, 1);

            // Dropping the last reference to the new entry keeps it cached; the LRU flag only
            // exists in debug builds, hence the cfg gate.
            cache.release(new_entry);
            assert!((*new_entry).is_in_cache());
            #[cfg(debug_assertions)]
            assert!((*new_entry).is_in_lru());

            // Releasing the new entry must not have touched the detached old entry.
            assert!(!old_entry.is_null());
            assert_eq!((*old_entry).get_value(), &"old_value".to_owned());
            assert_eq!((*old_entry).refs, 1);

            // Last reference to the detached entry.
            // NOTE(review): `old_entry` is dereferenced below after its final release. That is
            // only sound if the shard keeps released handle memory valid (e.g. recycles it into
            // a pool rather than deallocating) — confirm against `release`.
            cache.release(old_entry);
            assert!(!(*old_entry).is_in_cache());
            assert!((*new_entry).is_in_cache());
            #[cfg(debug_assertions)]
            {
                assert!(!(*old_entry).is_in_lru());
                assert!((*new_entry).is_in_lru());
            }
        }
    }
1339
    /// Releasing a "stale" handle (one whose key has since been re-inserted with a new value)
    /// must give back only the stale entry's charge. The replacement entry stays cached and is
    /// still returned by later lookups.
    #[test]
    fn test_release_stale_value() {
        unsafe {
            let mut to_delete = vec![];
            // NOTE(review): argument presumably is the shard capacity (1 here) — confirm in
            // `create_cache`.
            let mut cache = create_cache(1);
            let insert_handle = cache.insert(
                "key".to_owned(),
                0,
                1,
                "old_value".to_owned(),
                CachePriority::High,
                &mut to_delete,
            );
            cache.release(insert_handle);
            // After the insert handle was released, the lookup holds the only reference.
            let old_entry = cache.lookup(0, &"key".to_owned());
            assert!(!old_entry.is_null());
            assert_eq!((*old_entry).get_value(), &"old_value".to_owned());
            assert_eq!((*old_entry).refs, 1);

            // Re-insert the same key while `old_entry` is still held; the old entry is
            // detached from the cache.
            let insert_handle = cache.insert(
                "key".to_owned(),
                0,
                1,
                "new_value".to_owned(),
                CachePriority::High,
                &mut to_delete,
            );
            assert!(!(*old_entry).is_in_cache());
            let new_entry = cache.lookup(0, &"key".to_owned());
            assert!(!new_entry.is_null());
            assert_eq!((*new_entry).get_value(), &"new_value".to_owned());
            assert_eq!((*new_entry).refs, 2);
            cache.release(insert_handle);
            assert_eq!((*new_entry).refs, 1);

            // Both entries are still charged because each has an outstanding reference. None is
            // charged to `lru_usage` (presumably because no entry is unreferenced yet).
            assert_eq!(2, cache.usage.load(Relaxed));
            assert_eq!(0, cache.lru_usage.load(Relaxed));

            // Releasing the stale handle drops exactly its own charge.
            cache.release(old_entry);
            assert_eq!(1, cache.usage.load(Relaxed));
            assert_eq!(0, cache.lru_usage.load(Relaxed));

            // The replacement entry is unaffected and still served by lookups.
            let new_entry_again = cache.lookup(0, &"key".to_owned());
            assert!(!new_entry_again.is_null());
            assert_eq!((*new_entry_again).get_value(), &"new_value".to_owned());
            assert_eq!((*new_entry_again).refs, 2);

            cache.release(new_entry);
            cache.release(new_entry_again);

            // Fully released: it stays charged to `usage`, and its charge now shows up in
            // `lru_usage`.
            assert_eq!(1, cache.usage.load(Relaxed));
            assert_eq!(1, cache.lru_usage.load(Relaxed));
        }
    }
1397
1398 #[test]
1399 fn test_write_request_pending() {
1400 let cache = Arc::new(LruCache::new(1, 5, 0));
1401 {
1402 let mut shard = cache.shards[0].lock();
1403 insert(&mut shard, "a", "v1");
1404 assert!(lookup(&mut shard, "a"));
1405 }
1406 assert!(matches!(
1407 cache.lookup_for_request(0, "a".to_owned()),
1408 LookupResult::Cached(_)
1409 ));
1410 assert!(matches!(
1411 cache.lookup_for_request(0, "b".to_owned()),
1412 LookupResult::Miss
1413 ));
1414 let ret2 = cache.lookup_for_request(0, "b".to_owned());
1415 match ret2 {
1416 LookupResult::WaitPendingRequest(mut recv) => {
1417 assert!(matches!(recv.try_recv(), Err(TryRecvError::Empty)));
1418 cache.insert("b".to_owned(), 0, 1, "v2".to_owned(), CachePriority::Low);
1419 recv.try_recv().unwrap();
1420 assert!(
1421 matches!(cache.lookup_for_request(0, "b".to_owned()), LookupResult::Cached(v) if v.eq("v2"))
1422 );
1423 }
1424 _ => panic!(),
1425 }
1426 }
1427
1428 #[derive(Default, Debug)]
1429 struct TestLruCacheEventListener {
1430 released: Arc<Mutex<HashMap<String, String>>>,
1431 }
1432
1433 impl LruCacheEventListener for TestLruCacheEventListener {
1434 type K = String;
1435 type T = String;
1436
1437 fn on_release(&self, key: Self::K, value: Self::T) {
1438 self.released.lock().insert(key, value);
1439 }
1440 }
1441
1442 #[test]
1443 fn test_event_listener() {
1444 let listener = Arc::new(TestLruCacheEventListener::default());
1445 let cache = Arc::new(LruCache::with_event_listener(1, 2, 0, listener.clone()));
1446
1447 let h = cache.insert("k1".to_owned(), 0, 1, "v1".to_owned(), CachePriority::High);
1449 drop(h);
1450 let h = cache.insert("k2".to_owned(), 0, 1, "v2".to_owned(), CachePriority::High);
1451 drop(h);
1452 assert_eq!(cache.get_memory_usage(), 2);
1453 assert!(listener.released.lock().is_empty());
1454
1455 let h = cache.insert("k3".to_owned(), 0, 1, "v3".to_owned(), CachePriority::High);
1457 drop(h);
1458 assert_eq!(cache.get_memory_usage(), 2);
1459 assert!(listener.released.lock().remove("k1").is_some());
1460
1461 cache.erase(0, &"k2".to_owned());
1463 assert_eq!(cache.get_memory_usage(), 1);
1464 assert!(listener.released.lock().remove("k2").is_some());
1465
1466 let h = cache.insert("k4".to_owned(), 0, 1, "v4".to_owned(), CachePriority::Low);
1468 drop(h);
1469 assert_eq!(cache.get_memory_usage(), 2);
1470 assert!(listener.released.lock().is_empty());
1471
1472 let h1 = cache.insert("k5".to_owned(), 0, 1, "v5".to_owned(), CachePriority::Low);
1475 assert_eq!(cache.get_memory_usage(), 2);
1476 assert!(listener.released.lock().remove("k3").is_some());
1477 let h2 = cache.insert("k6".to_owned(), 0, 1, "v6".to_owned(), CachePriority::Low);
1478 assert_eq!(cache.get_memory_usage(), 2);
1479 assert!(listener.released.lock().remove("k4").is_some());
1480
1481 let h3 = cache.insert("k7".to_owned(), 0, 1, "v7".to_owned(), CachePriority::Low);
1483 assert_eq!(cache.get_memory_usage(), 3);
1484 assert!(listener.released.lock().is_empty());
1485
1486 drop(h1);
1488 assert_eq!(cache.get_memory_usage(), 2);
1489 assert!(listener.released.lock().remove("k5").is_some());
1490
1491 drop(h2);
1493 assert_eq!(cache.get_memory_usage(), 2);
1494 assert!(listener.released.lock().is_empty());
1495 drop(h3);
1496 assert_eq!(cache.get_memory_usage(), 2);
1497 assert!(listener.released.lock().is_empty());
1498
1499 drop(cache);
1501 assert!(listener.released.lock().is_empty());
1502 }
1503
1504 pub struct SyncPointFuture<F: Future> {
1505 inner: F,
1506 polled: Arc<AtomicBool>,
1507 }
1508
1509 impl<F: Future + Unpin> Future for SyncPointFuture<F> {
1510 type Output = ();
1511
1512 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1513 if self.polled.load(Ordering::Acquire) {
1514 return Poll::Ready(());
1515 }
1516 self.inner.poll_unpin(cx).map(|_| ())
1517 }
1518 }
1519
1520 #[tokio::test]
1521 async fn test_future_cancel() {
1522 let cache: Arc<LruCache<u64, u64>> = Arc::new(LruCache::new(1, 5, 0));
1523 let (_, recv) = channel::<()>();
1525 let polled = Arc::new(AtomicBool::new(false));
1526 let cache2 = cache.clone();
1527 let polled2 = polled.clone();
1528 let f = Box::pin(async move {
1529 cache2
1530 .lookup_with_request_dedup(1, 2, CachePriority::High, || async move {
1531 polled2.store(true, Ordering::Release);
1532 recv.await.map(|_| (1, 1))
1533 })
1534 .await
1535 .unwrap();
1536 });
1537 let wrapper = SyncPointFuture {
1538 inner: f,
1539 polled: polled.clone(),
1540 };
1541 {
1542 let handle = tokio::spawn(wrapper);
1543 while !polled.load(Ordering::Acquire) {
1544 tokio::task::yield_now().await;
1545 }
1546 handle.await.unwrap();
1547 }
1548 assert!(cache.shards[0].lock().write_request.is_empty());
1549 }
1550}