use std::backtrace::Backtrace;
use std::cmp::Ordering;
use std::collections::VecDeque;
use std::fmt::{Debug, Formatter};
use std::ops::Bound::{Excluded, Included, Unbounded};
use std::ops::{Bound, RangeBounds};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering as AtomicOrdering};
use std::time::{Duration, Instant};

use bytes::Bytes;
use foyer::Hint;
use futures::{Stream, StreamExt, pin_mut};
use parking_lot::Mutex;
use risingwave_common::catalog::{TableId, TableOption};
use risingwave_common::config::StorageMemoryConfig;
use risingwave_expr::codegen::try_stream;
use risingwave_hummock_sdk::can_concat;
use risingwave_hummock_sdk::compaction_group::StateTableId;
use risingwave_hummock_sdk::key::{
    EmptySliceRef, FullKey, TableKey, UserKey, bound_table_key_range,
};
use risingwave_hummock_sdk::sstable_info::SstableInfo;
use tokio::sync::oneshot::{Receiver, Sender, channel};

use super::{HummockError, HummockResult, SstableStoreRef};
use crate::error::{StorageError, StorageResult};
use crate::hummock::CachePolicy;
use crate::hummock::local_version::pinned_version::PinnedVersion;
use crate::mem_table::{KeyOp, MemTableError};
use crate::monitor::MemoryCollector;
use crate::store::{
    OpConsistencyLevel, ReadOptions, StateStoreGet, StateStoreKeyedRow, StateStoreRead,
};

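/// Returns `true` if `search_key_range` overlaps with the key range starting at
/// `inclusive_start_key` and ending at `end_key` (whose bound may be inclusive or exclusive).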
pub fn range_overlap<R, B>(
    search_key_range: &R,
    inclusive_start_key: &B,
    end_key: Bound<&B>,
) -> bool
where
    R: RangeBounds<B>,
    B: Ord,
{
    let (start_bound, end_bound) = (search_key_range.start_bound(), search_key_range.end_bound());

    // The probed key range ends before the search range starts.
    let too_left = match (start_bound, end_key) {
        (Included(range_start), Included(inclusive_end_key)) => range_start > inclusive_end_key,
        (Included(range_start), Excluded(end_key))
        | (Excluded(range_start), Included(end_key))
        | (Excluded(range_start), Excluded(end_key)) => range_start >= end_key,
        (Unbounded, _) | (_, Unbounded) => false,
    };
    // The probed key range starts after the search range ends.
    let too_right = match end_bound {
        Included(range_end) => range_end < inclusive_start_key,
        Excluded(range_end) => range_end <= inclusive_start_key,
        Unbounded => false,
    };

    !too_left && !too_right
}

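/// Returns `true` if the SST described by `info` may contain keys of `table_id` within
/// `table_key_range`, judging by both the SST's table-id list and its key range.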
pub fn filter_single_sst<R, B>(info: &SstableInfo, table_id: TableId, table_key_range: &R) -> bool
where
    R: RangeBounds<TableKey<B>>,
    B: AsRef<[u8]> + EmptySliceRef,
{
    debug_assert!(info.table_ids.is_sorted());
    let table_range = &info.key_range;
    let table_start = FullKey::decode(table_range.left.as_ref()).user_key;
    let table_end = FullKey::decode(table_range.right.as_ref()).user_key;
    let (left, right) = bound_table_key_range(table_id, table_key_range);
    let left: Bound<UserKey<&[u8]>> = left.as_ref().map(|key| key.as_ref());
    let right: Bound<UserKey<&[u8]>> = right.as_ref().map(|key| key.as_ref());

    info.table_ids.binary_search(&table_id).is_ok()
        && range_overlap(
            &(left, right),
            &table_start,
            if table_range.right_exclusive {
                Bound::Excluded(&table_end)
            } else {
                Bound::Included(&table_end)
            },
        )
}

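/// Returns the partition point in `ssts`: the index of the first SST whose left user key is
/// strictly greater than `key`. Assumes the SSTs are sorted by key range.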
pub(crate) fn search_sst_idx(ssts: &[SstableInfo], key: UserKey<&[u8]>) -> usize {
    ssts.partition_point(|table| {
        let ord = FullKey::decode(&table.key_range.left).user_key.cmp(&key);
        ord == Ordering::Less || ord == Ordering::Equal
    })
}

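/// Prunes SSTs of an overlapping level, keeping only those that may contain keys of `table_id`
/// within `table_key_range`.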
pub fn prune_overlapping_ssts<'a, R, B>(
    ssts: &'a [SstableInfo],
    table_id: TableId,
    table_key_range: &'a R,
) -> impl DoubleEndedIterator<Item = &'a SstableInfo>
where
    R: RangeBounds<TableKey<B>>,
    B: AsRef<[u8]> + EmptySliceRef,
{
    ssts.iter()
        .filter(move |info| filter_single_sst(info, table_id, table_key_range))
}

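/// Prunes SSTs of a non-overlapping level, keeping only the contiguous slice of SSTs whose key
/// ranges may intersect `user_key_range` and whose table-id list contains `table_id`. Requires
/// that the SSTs can be concatenated (sorted, non-overlapping key ranges).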
#[allow(clippy::type_complexity)]
pub fn prune_nonoverlapping_ssts<'a>(
    ssts: &'a [SstableInfo],
    user_key_range: (Bound<UserKey<&'a [u8]>>, Bound<UserKey<&'a [u8]>>),
    table_id: StateTableId,
) -> impl DoubleEndedIterator<Item = &'a SstableInfo> {
    debug_assert!(can_concat(ssts));
    let start_table_idx = match user_key_range.0 {
        Included(key) | Excluded(key) => search_sst_idx(ssts, key).saturating_sub(1),
        _ => 0,
    };
    let end_table_idx = match user_key_range.1 {
        Included(key) | Excluded(key) => search_sst_idx(ssts, key).saturating_sub(1),
        _ => ssts.len().saturating_sub(1),
    };
    ssts[start_table_idx..=end_table_idx]
        .iter()
        .filter(move |sst| sst.table_ids.binary_search(&table_id).is_ok())
}

type RequestQueue = VecDeque<(Sender<MemoryTracker>, u64)>;
enum MemoryRequest {
    Ready(MemoryTracker),
    Pending(Receiver<MemoryTracker>),
}

struct MemoryLimiterInner {
    total_size: AtomicU64,
    controller: Mutex<RequestQueue>,
    has_waiter: AtomicBool,
    quota: u64,
}

impl MemoryLimiterInner {
    fn release_quota(&self, quota: u64) {
        self.total_size.fetch_sub(quota, AtomicOrdering::SeqCst);
    }

    fn add_memory(&self, quota: u64) {
        self.total_size.fetch_add(quota, AtomicOrdering::SeqCst);
    }

    fn may_notify_waiters(self: &Arc<Self>) {
        // Fast path: avoid locking the queue when there is no waiter.
        if !self.has_waiter.load(AtomicOrdering::Acquire) {
            return;
        }
        let mut notify_waiters = vec![];
        {
            let mut waiters = self.controller.lock();
            while let Some((_, quota)) = waiters.front() {
                if !self.try_require_memory(*quota) {
                    break;
                }
                let (tx, quota) = waiters.pop_front().unwrap();
                notify_waiters.push((tx, quota));
            }

            if waiters.is_empty() {
                self.has_waiter.store(false, AtomicOrdering::Release);
            }
        }

        for (tx, quota) in notify_waiters {
            let _ = tx.send(MemoryTracker::new(self.clone(), quota));
        }
    }

    fn try_require_memory(&self, quota: u64) -> bool {
        let mut current_quota = self.total_size.load(AtomicOrdering::Acquire);
        while self.permit_quota(current_quota, quota) {
            match self.total_size.compare_exchange(
                current_quota,
                current_quota + quota,
                AtomicOrdering::SeqCst,
                AtomicOrdering::SeqCst,
            ) {
                Ok(_) => {
                    return true;
                }
                Err(old_quota) => {
                    current_quota = old_quota;
                }
            }
        }
        false
    }

    fn require_memory(self: &Arc<Self>, quota: u64) -> MemoryRequest {
        let mut waiters = self.controller.lock();
        let first_req = waiters.is_empty();
        if first_req {
            // Set the flag before trying to acquire quota so that a concurrent releaser
            // cannot miss this waiter in `may_notify_waiters`.
            self.has_waiter.store(true, AtomicOrdering::Release);
        }
        if self.try_require_memory(quota) {
            if first_req {
                self.has_waiter.store(false, AtomicOrdering::Release);
            }
            return MemoryRequest::Ready(MemoryTracker::new(self.clone(), quota));
        }
        let (tx, rx) = channel();
        waiters.push_back((tx, quota));
        MemoryRequest::Pending(rx)
    }

    fn permit_quota(&self, current_quota: u64, _request_quota: u64) -> bool {
        // Loose check: admit the request as long as the current usage has not yet exceeded
        // the quota, so the total usage may temporarily overshoot it.
        current_quota <= self.quota
    }
}

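/// A loose, quota-based memory limiter: an acquisition is admitted as long as the current usage
/// has not yet exceeded the quota, so the total usage may temporarily overshoot it. Requests
/// that cannot be admitted wait in FIFO order and are woken when quota is released.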
pub struct MemoryLimiter {
    inner: Arc<MemoryLimiterInner>,
}

impl Debug for MemoryLimiter {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MemoryLimiter")
            .field("quota", &self.inner.quota)
            .field("usage", &self.inner.total_size)
            .finish()
    }
}

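/// RAII handle for quota acquired from a [`MemoryLimiter`]. Dropping it releases the quota and
/// wakes pending waiters.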
pub struct MemoryTracker {
    limiter: Arc<MemoryLimiterInner>,
    quota: Option<u64>,
}

impl MemoryTracker {
    fn new(limiter: Arc<MemoryLimiterInner>, quota: u64) -> Self {
        Self {
            limiter,
            quota: Some(quota),
        }
    }
}

impl Debug for MemoryTracker {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MemoryTracker")
            .field("quota", &self.quota)
            .finish()
    }
}

impl MemoryLimiter {
    pub fn unlimit() -> Arc<Self> {
        Arc::new(Self::new(u64::MAX))
    }

    pub fn new(quota: u64) -> Self {
        Self {
            inner: Arc::new(MemoryLimiterInner {
                total_size: AtomicU64::new(0),
                controller: Mutex::new(VecDeque::default()),
                has_waiter: AtomicBool::new(false),
                quota,
            }),
        }
    }

    pub fn try_require_memory(&self, quota: u64) -> Option<MemoryTracker> {
        if self.inner.try_require_memory(quota) {
            Some(MemoryTracker::new(self.inner.clone(), quota))
        } else {
            None
        }
    }

    pub fn get_memory_usage(&self) -> u64 {
        self.inner.total_size.load(AtomicOrdering::Acquire)
    }

    pub fn quota(&self) -> u64 {
        self.inner.quota
    }

    pub fn must_require_memory(&self, quota: u64) -> MemoryTracker {
        // Account for the quota unconditionally, even when the limit is already exceeded.
        if !self.inner.try_require_memory(quota) {
            self.inner.add_memory(quota);
        }

        MemoryTracker::new(self.inner.clone(), quota)
    }
}

impl MemoryLimiter {
    pub async fn require_memory(&self, quota: u64) -> MemoryTracker {
        match self.inner.require_memory(quota) {
            MemoryRequest::Ready(tracker) => tracker,
            MemoryRequest::Pending(rx) => rx.await.unwrap(),
        }
    }
}

impl MemoryTracker {
    pub fn try_increase_memory(&mut self, target: u64) -> bool {
        let quota = self.quota.unwrap();
        if quota >= target {
            return true;
        }
        if self.limiter.try_require_memory(target - quota) {
            self.quota = Some(target);
            true
        } else {
            false
        }
    }
}

impl Drop for MemoryTracker {
    fn drop(&mut self) {
        if let Some(quota) = self.quota.take() {
            self.limiter.release_quota(quota);
            self.limiter.may_notify_waiters();
        }
    }
}

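/// Checks whether the items of `sub_iter` form a subsequence of `full_iter`, i.e. every item of
/// `sub_iter` appears in `full_iter` in the same relative order.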
pub fn check_subset_preserve_order<T: Eq>(
    sub_iter: impl Iterator<Item = T>,
    mut full_iter: impl Iterator<Item = T>,
) -> bool {
    for sub_iter_item in sub_iter {
        let mut found = false;
        for full_iter_item in full_iter.by_ref() {
            if sub_iter_item == full_iter_item {
                found = true;
                break;
            }
        }
        if !found {
            return false;
        }
    }
    true
}

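/// Whether the consistency sanity checks below are enabled. Defaults to enabled in debug builds
/// and disabled in release builds.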
static SANITY_CHECK_ENABLED: AtomicBool = AtomicBool::new(cfg!(debug_assertions));

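/// Disables the consistency sanity checks (`do_insert_sanity_check` and friends) globally at
/// runtime.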
pub fn disable_sanity_check() {
    SANITY_CHECK_ENABLED.store(false, AtomicOrdering::Release);
}

pub(crate) fn sanity_check_enabled() -> bool {
    SANITY_CHECK_ENABLED.load(AtomicOrdering::Acquire)
}

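/// Point-gets `key` from `state_store` and copies the value out as owned `Bytes`.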
async fn get_from_state_store(
    state_store: &impl StateStoreGet,
    key: TableKey<Bytes>,
    read_options: ReadOptions,
) -> StorageResult<Option<Bytes>> {
    state_store
        .on_key_value(key, read_options, |_, value| {
            Ok(Bytes::copy_from_slice(value))
        })
        .await
}

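/// Verifies that the key to insert does not already exist in storage. Skipped when the op
/// consistency level is `Inconsistent`.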
pub(crate) async fn do_insert_sanity_check(
    table_id: TableId,
    key: &TableKey<Bytes>,
    value: &Bytes,
    inner: &impl StateStoreRead,
    table_option: TableOption,
    op_consistency_level: &OpConsistencyLevel,
) -> StorageResult<()> {
    if let OpConsistencyLevel::Inconsistent = op_consistency_level {
        return Ok(());
    }
    let read_options = ReadOptions {
        retention_seconds: table_option.retention_seconds,
        cache_policy: CachePolicy::Fill(Hint::Normal),
        ..Default::default()
    };
    let stored_value = get_from_state_store(inner, key.clone(), read_options).await?;

    if let Some(stored_value) = stored_value {
        return Err(Box::new(MemTableError::InconsistentOperation {
            table_id,
            key: key.clone(),
            prev: KeyOp::Insert(stored_value),
            new: KeyOp::Insert(value.clone()),
        })
        .into());
    }
    Ok(())
}

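/// Verifies that the key to delete exists in storage and that its stored value matches
/// `old_value`. Skipped unless the op consistency level is `ConsistentOldValue`.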
pub(crate) async fn do_delete_sanity_check(
    table_id: TableId,
    key: &TableKey<Bytes>,
    old_value: &Bytes,
    inner: &impl StateStoreRead,
    table_option: TableOption,
    op_consistency_level: &OpConsistencyLevel,
) -> StorageResult<()> {
    let OpConsistencyLevel::ConsistentOldValue {
        check_old_value: old_value_checker,
        ..
    } = op_consistency_level
    else {
        return Ok(());
    };
    let read_options = ReadOptions {
        retention_seconds: table_option.retention_seconds,
        cache_policy: CachePolicy::Fill(Hint::Normal),
        ..Default::default()
    };
    match get_from_state_store(inner, key.clone(), read_options).await? {
        None => Err(Box::new(MemTableError::InconsistentOperation {
            table_id,
            key: key.clone(),
            prev: KeyOp::Delete(Bytes::default()),
            new: KeyOp::Delete(old_value.clone()),
        })
        .into()),
        Some(stored_value) => {
            if !old_value_checker(&stored_value, old_value) {
                Err(Box::new(MemTableError::InconsistentOperation {
                    table_id,
                    key: key.clone(),
                    prev: KeyOp::Insert(stored_value),
                    new: KeyOp::Delete(old_value.clone()),
                })
                .into())
            } else {
                Ok(())
            }
        }
    }
}

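/// Verifies that the key to update exists in storage and that its stored value matches
/// `old_value`. Skipped unless the op consistency level is `ConsistentOldValue`.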
pub(crate) async fn do_update_sanity_check(
    table_id: TableId,
    key: &TableKey<Bytes>,
    old_value: &Bytes,
    new_value: &Bytes,
    inner: &impl StateStoreRead,
    table_option: TableOption,
    op_consistency_level: &OpConsistencyLevel,
) -> StorageResult<()> {
    let OpConsistencyLevel::ConsistentOldValue {
        check_old_value: old_value_checker,
        ..
    } = op_consistency_level
    else {
        return Ok(());
    };
    let read_options = ReadOptions {
        retention_seconds: table_option.retention_seconds,
        cache_policy: CachePolicy::Fill(Hint::Normal),
        ..Default::default()
    };

    match get_from_state_store(inner, key.clone(), read_options).await? {
        None => Err(Box::new(MemTableError::InconsistentOperation {
            table_id,
            key: key.clone(),
            prev: KeyOp::Delete(Bytes::default()),
            new: KeyOp::Update((old_value.clone(), new_value.clone())),
        })
        .into()),
        Some(stored_value) => {
            if !old_value_checker(&stored_value, old_value) {
                Err(Box::new(MemTableError::InconsistentOperation {
                    table_id,
                    key: key.clone(),
                    prev: KeyOp::Insert(stored_value),
                    new: KeyOp::Update((old_value.clone(), new_value.clone())),
                })
                .into())
            } else {
                Ok(())
            }
        }
    }
}

pub fn cmp_delete_range_left_bounds(a: Bound<&Bytes>, b: Bound<&Bytes>) -> Ordering {
    match (a, b) {
        // The left bound of a delete range is never `Unbounded` here.
        (Unbounded, _) | (_, Unbounded) => unreachable!(),
        (Included(x), Included(y)) | (Excluded(x), Excluded(y)) => x.cmp(y),
        (Included(x), Excluded(y)) => x.cmp(y).then(Ordering::Less),
        (Excluded(x), Included(y)) => x.cmp(y).then(Ordering::Greater),
    }
}

pub(crate) fn validate_delete_range(left: &Bound<Bytes>, right: &Bound<Bytes>) -> bool {
    match (left, right) {
        (Unbounded, _) => unreachable!(),
        (_, Unbounded) => true,
        (Included(x), Included(y)) => x <= y,
        (Included(x), Excluded(y)) | (Excluded(x), Included(y)) | (Excluded(x), Excluded(y)) => {
            x < y
        }
    }
}

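/// Filters `kv_iter`, dropping the key-value pairs that fall into any of the delete ranges.
/// Both iterators are assumed to be sorted in ascending key order.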
#[expect(dead_code)]
pub(crate) fn filter_with_delete_range<'a>(
    kv_iter: impl Iterator<Item = (TableKey<Bytes>, KeyOp)> + 'a,
    mut delete_ranges_iter: impl Iterator<Item = &'a (Bound<Bytes>, Bound<Bytes>)> + 'a,
) -> impl Iterator<Item = (TableKey<Bytes>, KeyOp)> + 'a {
    let mut range = delete_ranges_iter.next();
    if let Some((range_start, range_end)) = range {
        assert!(
            validate_delete_range(range_start, range_end),
            "range_end {:?} smaller than range_start {:?}",
            range_end,
            range_start
        );
    }
    kv_iter.filter(move |(key, _)| {
        if let Some(range_bound) = range {
            if cmp_delete_range_left_bounds(Included(&key.0), range_bound.0.as_ref())
                == Ordering::Less
            {
                true
            } else if range_bound.contains(key.as_ref()) {
                false
            } else {
                // Advance to the first delete range that does not end before the key.
                loop {
                    range = delete_ranges_iter.next();
                    if let Some(range_bound) = range {
                        assert!(
                            validate_delete_range(&range_bound.0, &range_bound.1),
                            "range_end {:?} smaller than range_start {:?}",
                            range_bound.1,
                            range_bound.0
                        );
                        if cmp_delete_range_left_bounds(Included(key), range_bound.0.as_ref())
                            == Ordering::Less
                        {
                            // The key precedes the next delete range: keep it.
                            break true;
                        } else if range_bound.contains(key.as_ref()) {
                            // The key falls into the next delete range: drop it.
                            break false;
                        } else {
                            // The next delete range still ends before the key: keep advancing.
                            continue;
                        }
                    } else {
                        // No more delete ranges: keep the key.
                        break true;
                    }
                }
            }
        } else {
            true
        }
    })
}

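/// Waits until the committed epoch of `table_id` reaches at least `wait_epoch`. Returns an error
/// if the table disappears from the version (i.e. has been dropped) while waiting.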
pub(crate) async fn wait_for_epoch(
    notifier: &tokio::sync::watch::Sender<PinnedVersion>,
    wait_epoch: u64,
    table_id: TableId,
) -> StorageResult<PinnedVersion> {
    let mut prev_committed_epoch = None;
    let prev_committed_epoch = &mut prev_committed_epoch;
    let version = wait_for_update(
        notifier,
        |version| {
            let committed_epoch = version.table_committed_epoch(table_id);
            let ret = if let Some(committed_epoch) = committed_epoch {
                Ok(committed_epoch >= wait_epoch)
            } else if prev_committed_epoch.is_none() {
                // The table has not appeared in any inspected version yet: keep waiting.
                Ok(false)
            } else {
                // The table was present before but is now gone from the version: it was dropped.
                Err(HummockError::wait_epoch(format!(
                    "table {} has been dropped",
                    table_id
                )))
            };
            *prev_committed_epoch = committed_epoch;
            ret
        },
        || {
            format!(
                "wait_for_epoch: epoch: {}, table_id: {}",
                wait_epoch, table_id
            )
        },
    )
    .await?;
    Ok(version)
}

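/// Subscribes to version updates from `notifier` and returns the first version for which
/// `inspect_fn` returns `Ok(true)`. Logs a warning with `periodic_debug_info` every 30 seconds
/// while still waiting.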
pub(crate) async fn wait_for_update(
    notifier: &tokio::sync::watch::Sender<PinnedVersion>,
    mut inspect_fn: impl FnMut(&PinnedVersion) -> HummockResult<bool>,
    mut periodic_debug_info: impl FnMut() -> String,
) -> HummockResult<PinnedVersion> {
    let mut receiver = notifier.subscribe();
    {
        let version = receiver.borrow_and_update();
        if inspect_fn(&version)? {
            return Ok(version.clone());
        }
    }
    let start_time = Instant::now();
    loop {
        match tokio::time::timeout(Duration::from_secs(30), receiver.changed()).await {
            Err(_) => {
                // Capture a backtrace in debug builds to help locate the stuck waiter.
                let backtrace = cfg!(debug_assertions)
                    .then(Backtrace::capture)
                    .map(tracing::field::display);

                tracing::warn!(
                    info = periodic_debug_info(),
                    elapsed = ?start_time.elapsed(),
                    backtrace,
                    "timeout when waiting for version update",
                );
                continue;
            }
            Ok(Err(_)) => {
                return Err(HummockError::wait_epoch("tx dropped"));
            }
            Ok(Ok(_)) => {
                let version = receiver.borrow_and_update();
                if inspect_fn(&version)? {
                    return Ok(version.clone());
                }
            }
        }
    }
}

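/// Exposes the memory usage of Hummock's block/meta/vector caches, prefetch buffers and the
/// shared buffer via the [`MemoryCollector`] trait.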
pub struct HummockMemoryCollector {
    sstable_store: SstableStoreRef,
    limiter: Arc<MemoryLimiter>,
    storage_memory_config: StorageMemoryConfig,
}

impl HummockMemoryCollector {
    pub fn new(
        sstable_store: SstableStoreRef,
        limiter: Arc<MemoryLimiter>,
        storage_memory_config: StorageMemoryConfig,
    ) -> Self {
        Self {
            sstable_store,
            limiter,
            storage_memory_config,
        }
    }
}

impl MemoryCollector for HummockMemoryCollector {
    fn get_meta_memory_usage(&self) -> u64 {
        self.sstable_store.meta_cache().memory().usage() as _
    }

    fn get_data_memory_usage(&self) -> u64 {
        self.sstable_store.block_cache().memory().usage() as _
    }

    fn get_vector_meta_memory_usage(&self) -> u64 {
        self.sstable_store.vector_meta_cache.usage() as _
    }

    fn get_vector_data_memory_usage(&self) -> u64 {
        self.sstable_store.vector_block_cache.usage() as _
    }

    fn get_uploading_memory_usage(&self) -> u64 {
        self.limiter.get_memory_usage()
    }

    fn get_prefetch_memory_usage(&self) -> usize {
        self.sstable_store.get_prefetch_memory_usage()
    }

    fn get_meta_cache_memory_usage_ratio(&self) -> f64 {
        self.sstable_store.meta_cache().memory().usage() as f64
            / self.sstable_store.meta_cache().memory().capacity() as f64
    }

    fn get_block_cache_memory_usage_ratio(&self) -> f64 {
        self.sstable_store.block_cache().memory().usage() as f64
            / self.sstable_store.block_cache().memory().capacity() as f64
    }

    fn get_vector_meta_cache_memory_usage_ratio(&self) -> f64 {
        self.sstable_store.vector_meta_cache.usage() as f64
            / self.sstable_store.vector_meta_cache.capacity() as f64
    }

    fn get_vector_data_cache_memory_usage_ratio(&self) -> f64 {
        self.sstable_store.vector_block_cache.usage() as f64
            / self.sstable_store.vector_block_cache.capacity() as f64
    }

    fn get_shared_buffer_usage_ratio(&self) -> f64 {
        self.limiter.get_memory_usage() as f64
            / (self.storage_memory_config.shared_buffer_capacity_mb * 1024 * 1024) as f64
    }
}

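/// Merges a sorted mem-table iterator with the sorted stream from the underlying state store into
/// a single keyed-row stream, applying the mem-table ops (insert/update/delete) on top of the
/// stored values. When `rev` is true, the key comparison is reversed for descending iteration.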
#[try_stream(ok = StateStoreKeyedRow, error = StorageError)]
pub(crate) async fn merge_stream<'a>(
    mem_table_iter: impl Iterator<Item = (&'a TableKey<Bytes>, &'a KeyOp)> + 'a,
    inner_stream: impl Stream<Item = StorageResult<StateStoreKeyedRow>> + 'static,
    table_id: TableId,
    epoch: u64,
    rev: bool,
) {
    let inner_stream = inner_stream.peekable();
    pin_mut!(inner_stream);

    let mut mem_table_iter = mem_table_iter.fuse().peekable();

    loop {
        match (inner_stream.as_mut().peek().await, mem_table_iter.peek()) {
            (None, None) => break,
            // Only the storage side has data left.
            (Some(_), None) => {
                let (key, value) = inner_stream.next().await.unwrap()?;
                yield (key, value)
            }
            // Only the mem-table side has data left.
            (None, Some(_)) => {
                let (key, key_op) = mem_table_iter.next().unwrap();
                match key_op {
                    KeyOp::Insert(value) | KeyOp::Update((_, value)) => {
                        yield (FullKey::new(table_id, key.clone(), epoch), value.clone())
                    }
                    _ => {}
                }
            }
            (Some(Ok((inner_key, _))), Some((mem_table_key, _))) => {
                debug_assert_eq!(inner_key.user_key.table_id, table_id);
                let mut ret = inner_key.user_key.table_key.cmp(mem_table_key);
                if rev {
                    ret = ret.reverse();
                }
                match ret {
                    Ordering::Less => {
                        // The storage key comes first.
                        let (key, value) = inner_stream.next().await.unwrap()?;
                        yield (key, value);
                    }
                    Ordering::Equal => {
                        // Same key on both sides: the mem-table op overrides the stored value.
                        let (_, key_op) = mem_table_iter.next().unwrap();
                        let (key, old_value_in_inner) = inner_stream.next().await.unwrap()?;
                        match key_op {
                            KeyOp::Insert(value) => {
                                yield (key.clone(), value.clone());
                            }
                            KeyOp::Delete(_) => {}
                            KeyOp::Update((old_value, new_value)) => {
                                debug_assert!(old_value == &old_value_in_inner);

                                yield (key, new_value.clone());
                            }
                        }
                    }
                    Ordering::Greater => {
                        // The mem-table key comes first.
                        let (key, key_op) = mem_table_iter.next().unwrap();

                        match key_op {
                            KeyOp::Insert(value) => {
                                yield (FullKey::new(table_id, key.clone(), epoch), value.clone());
                            }
                            KeyOp::Delete(_) => {}
                            KeyOp::Update(_) => unreachable!(
                                "memtable update should always be paired with a storage key"
                            ),
                        }
                    }
                }
            }
            (Some(Err(_)), Some(_)) => {
                // Surface the storage error.
                return Err(inner_stream.next().await.unwrap().unwrap_err());
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::future::{Future, poll_fn};
    use std::sync::Arc;
    use std::task::Poll;

    use futures::FutureExt;
    use futures::future::join_all;
    use rand::random_range;

    use crate::hummock::utils::MemoryLimiter;

    async fn assert_pending(future: &mut (impl Future + Unpin)) {
        for _ in 0..10 {
            assert!(
                poll_fn(|cx| Poll::Ready(future.poll_unpin(cx)))
                    .await
                    .is_pending()
            );
        }
    }

    #[tokio::test]
    async fn test_loose_memory_limiter() {
        let quota = 5;
        let memory_limiter = MemoryLimiter::new(quota);
        drop(memory_limiter.require_memory(6).await);
        let tracker1 = memory_limiter.require_memory(3).await;
        assert_eq!(3, memory_limiter.get_memory_usage());
        let tracker2 = memory_limiter.require_memory(4).await;
        assert_eq!(7, memory_limiter.get_memory_usage());
        let mut future = memory_limiter.require_memory(5).boxed();
        assert_pending(&mut future).await;
        assert_eq!(7, memory_limiter.get_memory_usage());
        drop(tracker1);
        let tracker3 = future.await;
        assert_eq!(9, memory_limiter.get_memory_usage());
        drop(tracker2);
        assert_eq!(5, memory_limiter.get_memory_usage());
        drop(tracker3);
        assert_eq!(0, memory_limiter.get_memory_usage());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn test_multi_thread_acquire_memory() {
        const QUOTA: u64 = 10;
        let memory_limiter = Arc::new(MemoryLimiter::new(200));
        let mut handles = vec![];
        for _ in 0..40 {
            let limiter = memory_limiter.clone();
            let h = tokio::spawn(async move {
                let mut buffers = vec![];
                let mut current_buffer_usage = random_range(2..=9);
                for _ in 0..1000 {
                    if buffers.len() < current_buffer_usage
                        && let Some(tracker) = limiter.try_require_memory(QUOTA)
                    {
                        buffers.push(tracker);
                    } else {
                        buffers.clear();
                        current_buffer_usage = random_range(2..=9);
                        let req = limiter.require_memory(QUOTA);
                        match tokio::time::timeout(std::time::Duration::from_millis(1), req).await {
                            Ok(tracker) => {
                                buffers.push(tracker);
                            }
                            Err(_) => {
                                continue;
                            }
                        }
                    }
                    let sleep_time = random_range(1..=3);
                    tokio::time::sleep(std::time::Duration::from_millis(sleep_time)).await;
                }
            });
            handles.push(h);
        }
        let h = join_all(handles);
        let _ = h.await;
    }
}