1use std::collections::{HashMap, HashSet};
16use std::ops::Bound::{Excluded, Included};
17use std::ops::{Deref, DerefMut};
18use std::sync::atomic::Ordering;
19
20use bytes::BytesMut;
21use risingwave_hummock_sdk::compaction_group::hummock_version_ext::object_size_map;
22use risingwave_hummock_sdk::version::HummockVersion;
23use risingwave_hummock_sdk::{HummockObjectId, HummockVersionId, get_stale_object_ids};
24use risingwave_pb::hummock::hummock_version_checkpoint::{PbStaleObjects, StaleObjects};
25use risingwave_pb::hummock::{
26 CheckpointCompressionAlgorithm, PbHummockVersion, PbHummockVersionArchive,
27 PbHummockVersionCheckpoint, PbHummockVersionCheckpointEnvelope, PbVectorIndexObject,
28 PbVectorIndexObjectType,
29};
30use thiserror_ext::AsReport;
31use tracing::warn;
32
33use crate::hummock::HummockManager;
34use crate::hummock::error::Result;
35use crate::hummock::manager::versioning::Versioning;
36use crate::hummock::metrics_utils::{trigger_gc_stat, trigger_split_stat};
37
38pub(crate) fn xxhash64_checksum(data: &[u8]) -> u64 {
41 use std::hash::Hasher;
42 let mut hasher = twox_hash::XxHash64::with_seed(0);
43 hasher.write(data);
44 hasher.finish()
45}
46
/// An in-memory representation of the persisted hummock version checkpoint:
/// a full version snapshot plus bookkeeping for objects that have become
/// unreferenced since earlier checkpoints.
#[derive(Default)]
pub struct HummockVersionCheckpoint {
    // The checkpointed version snapshot.
    pub version: HummockVersion,

    // Objects no longer referenced by the checkpointed version, keyed by the
    // version id at which they became stale (see `create_version_checkpoint`,
    // which inserts under `current_version.id` and prunes entries at or above
    // the minimum pinned version id).
    pub stale_objects: HashMap<HummockVersionId, PbStaleObjects>,
}
58
59impl HummockVersionCheckpoint {
60 pub fn from_protobuf(checkpoint: &PbHummockVersionCheckpoint) -> Self {
61 Self {
62 version: HummockVersion::from_persisted_protobuf(checkpoint.version.as_ref().unwrap()),
63 stale_objects: checkpoint
64 .stale_objects
65 .iter()
66 .map(|(version_id, objects)| (*version_id, objects.clone()))
67 .collect(),
68 }
69 }
70
71 pub fn from_protobuf_owned(checkpoint: PbHummockVersionCheckpoint) -> Self {
74 Self {
75 version: HummockVersion::from_persisted_protobuf_owned(checkpoint.version.unwrap()),
76 stale_objects: checkpoint.stale_objects,
77 }
78 }
79
80 pub fn to_protobuf(&self) -> PbHummockVersionCheckpoint {
81 PbHummockVersionCheckpoint {
82 version: Some(PbHummockVersion::from(&self.version)),
83 stale_objects: self
84 .stale_objects
85 .iter()
86 .map(|(version_id, objects)| (*version_id, objects.clone()))
87 .collect(),
88 }
89 }
90}
91
/// Decodes checkpoint bytes read from the object store.
///
/// Two on-disk formats are supported:
/// 1. The envelope format: a `PbHummockVersionCheckpointEnvelope` carrying a
///    (possibly compressed) payload plus an XXH64 checksum of that payload.
/// 2. The legacy format: a bare, uncompressed `PbHummockVersionCheckpoint`.
///
/// The envelope path is taken only when the bytes decode as an envelope AND
/// the checksum field is present; otherwise we fall through to the legacy
/// decoder. NOTE(review): prost decoding is lenient, so discrimination relies
/// on legacy checkpoints not populating the envelope's checksum field.
///
/// # Errors
/// Returns an error on checksum mismatch, unknown compression algorithm,
/// decompression failure, undecodable payload, or a missing `version` field.
fn decode_checkpoint_data(data: bytes::Bytes) -> Result<PbHummockVersionCheckpoint> {
    use anyhow::Context;
    use prost::Message;

    let data_size = data.len();

    if let Ok(envelope) = PbHummockVersionCheckpointEnvelope::decode(data.clone())
        && let Some(expected) = envelope.checksum
    {
        // Verify integrity before spending time on decompression.
        let actual = xxhash64_checksum(&envelope.payload);
        if actual != expected {
            return Err(anyhow::anyhow!(
                "checkpoint checksum mismatch: expected {:#x}, got {:#x}",
                expected,
                actual
            )
            .into());
        }

        let algo = CheckpointCompressionAlgorithm::try_from(envelope.compression_algorithm)
            .with_context(|| {
                format!(
                    "unknown checkpoint compression algorithm: {}",
                    envelope.compression_algorithm
                )
            })?;

        let decompressed = decompress_payload(algo, &envelope.payload)?;
        let ckpt = PbHummockVersionCheckpoint::decode(decompressed.as_ref())
            .context("failed to decode checkpoint envelope payload")?;
        // `version` is required by consumers (`from_protobuf_owned` unwraps it),
        // so reject its absence here with a clear error instead of panicking later.
        if ckpt.version.is_none() {
            return Err(anyhow::anyhow!("checkpoint missing required field `version`").into());
        }

        tracing::info!(
            compression = ?algo,
            compressed_size = envelope.payload.len(),
            decompressed_size = decompressed.len(),
            compression_ratio =
                format!("{:.2}x", decompressed.len() as f64 / envelope.payload.len().max(1) as f64),
            checksum = format!("{expected:#x}"),
            "decoded compressed checkpoint"
        );
        return Ok(ckpt);
    }

    // Fallback: bytes written before the envelope format was introduced.
    tracing::info!(
        data_size,
        "decoding checkpoint in legacy uncompressed format"
    );
    let ckpt =
        PbHummockVersionCheckpoint::decode(data).context("failed to decode legacy checkpoint")?;
    if ckpt.version.is_none() {
        return Err(anyhow::anyhow!("legacy checkpoint missing required field `version`").into());
    }
    Ok(ckpt)
}
165
166fn decompress_payload(
167 algo: CheckpointCompressionAlgorithm,
168 payload: &[u8],
169) -> Result<std::borrow::Cow<'_, [u8]>> {
170 use anyhow::Context;
171
172 match algo {
173 CheckpointCompressionAlgorithm::CheckpointCompressionUnspecified => Ok(payload.into()),
174 CheckpointCompressionAlgorithm::CheckpointCompressionZstd => {
175 Ok(zstd::stream::decode_all(payload)
176 .map(std::borrow::Cow::Owned)
177 .context("zstd decompression failed")?)
178 }
179 CheckpointCompressionAlgorithm::CheckpointCompressionLz4 => {
180 let mut decoder = lz4::Decoder::new(payload).context("lz4 decoder init failed")?;
181 let mut decompressed = Vec::new();
182 std::io::Read::read_to_end(&mut decoder, &mut decompressed)
183 .context("lz4 decompression failed")?;
184 Ok(decompressed.into())
185 }
186 }
187}
188
189pub(crate) fn compress_payload(
191 algo: risingwave_common::config::CheckpointCompression,
192 data: &[u8],
193) -> Result<Vec<u8>> {
194 use anyhow::Context;
195 use risingwave_common::config::CheckpointCompression;
196
197 match algo {
198 CheckpointCompression::None => Ok(data.to_vec()),
199 CheckpointCompression::Zstd => {
200 Ok(zstd::stream::encode_all(data, 3).context("zstd compression failed")?)
202 }
203 CheckpointCompression::Lz4 => {
204 let mut compressed = Vec::new();
205 let mut encoder = lz4::EncoderBuilder::new()
206 .level(4)
207 .build(&mut compressed)
208 .context("lz4 encoder init failed")?;
209 std::io::Write::write_all(&mut encoder, data)
210 .context("lz4 compression write failed")?;
211 let (_writer, result) = encoder.finish();
212 result.context("lz4 compression finish failed")?;
213 Ok(compressed)
214 }
215 }
216}
217
218async fn read_bytes_in_chunks<F, Fut>(
219 total_size: usize,
220 chunk_size: usize,
221 max_in_flight_chunks: usize,
222 mut read_range: F,
223) -> anyhow::Result<bytes::Bytes>
224where
225 F: FnMut(std::ops::Range<usize>) -> Fut,
226 Fut: std::future::Future<Output = anyhow::Result<bytes::Bytes>>,
227{
228 use anyhow::Context;
229 use futures::StreamExt;
230
231 let num_chunks = total_size.div_ceil(chunk_size);
232 let mut buf = BytesMut::with_capacity(total_size);
233
234 let mut chunk_stream = futures::stream::iter((0..total_size).step_by(chunk_size))
235 .enumerate()
236 .map(|(chunk_idx, offset)| {
237 let end = std::cmp::min(offset + chunk_size, total_size);
238 let range = offset..end;
239 let fut = read_range(range.clone());
240 async move {
241 fut.await.with_context(|| {
242 format!(
243 "read checkpoint chunk {}/{} range {}..{}",
244 chunk_idx + 1,
245 num_chunks,
246 range.start,
247 range.end
248 )
249 })
250 }
251 })
252 .buffered(max_in_flight_chunks);
253
254 while let Some(chunk) = chunk_stream.next().await {
255 let chunk = chunk?;
256 buf.extend_from_slice(&chunk);
257 }
258
259 Ok(buf.freeze())
260}
261
impl HummockManager {
    /// Reads and decodes the persisted version checkpoint from the object
    /// store.
    ///
    /// Returns `Ok(None)` when no checkpoint object exists yet (e.g. first
    /// boot). Objects larger than the configured chunk size are downloaded
    /// with concurrent ranged reads.
    pub async fn try_read_checkpoint(&self) -> Result<Option<HummockVersionCheckpoint>> {
        let object_metadata = match self
            .object_store
            .metadata(&self.version_checkpoint_path)
            .await
        {
            Ok(metadata) => metadata,
            Err(e) => {
                // A missing checkpoint is an expected state, not an error.
                if e.is_object_not_found_error() {
                    return Ok(None);
                }
                return Err(e.into());
            }
        };
        let total_size = object_metadata.total_size;

        let chunk_size = self.env.opts.checkpoint_read_chunk_size;
        let max_in_flight_chunks = self.env.opts.checkpoint_read_max_in_flight_chunks;

        let download_start = std::time::Instant::now();
        // Small objects: single read. Larger objects: concurrent ranged reads
        // reassembled in order by `read_bytes_in_chunks`.
        let data = if total_size <= chunk_size {
            self.object_store
                .read(&self.version_checkpoint_path, 0..total_size)
                .await?
        } else {
            let num_chunks = total_size.div_ceil(chunk_size);
            let data = read_bytes_in_chunks(
                total_size,
                chunk_size,
                max_in_flight_chunks,
                |range| async {
                    Ok(self
                        .object_store
                        .read(&self.version_checkpoint_path, range)
                        .await?)
                },
            )
            .await?;

            tracing::info!(
                total_size,
                num_chunks,
                chunk_size,
                max_in_flight_chunks,
                "chunked read complete"
            );
            data
        };
        let download_duration = download_start.elapsed();

        let decode_start = std::time::Instant::now();
        let ckpt = decode_checkpoint_data(data)?;
        let decode_duration = decode_start.elapsed();

        tracing::info!(
            total_size,
            download_ms = download_duration.as_millis() as u64,
            decode_ms = decode_duration.as_millis() as u64,
            "checkpoint read complete"
        );

        Ok(Some(HummockVersionCheckpoint::from_protobuf_owned(ckpt)))
    }

    /// Serializes `checkpoint`, compresses it with the configured algorithm,
    /// wraps it in a checksummed envelope, and uploads it to the checkpoint
    /// path (decoded later by `decode_checkpoint_data`).
    pub(super) async fn write_checkpoint(
        &self,
        checkpoint: &HummockVersionCheckpoint,
    ) -> Result<()> {
        use prost::Message;
        let raw_bytes = checkpoint.to_protobuf().encode_to_vec();
        let raw_size = raw_bytes.len();

        let compression = self.env.opts.checkpoint_compression_algorithm;
        let compressed = compress_payload(compression, &raw_bytes)?;
        // Checksum is computed over the compressed payload, matching the
        // verification order in `decode_checkpoint_data`.
        let checksum = xxhash64_checksum(&compressed);

        tracing::info!(
            raw_size,
            compressed_size = compressed.len(),
            compression_ratio =
                format!("{:.2}x", raw_size as f64 / compressed.len().max(1) as f64),
            compression = ?compression,
            checksum = format!("{:#x}", checksum),
            "writing compressed checkpoint"
        );

        // NOTE(review): `compression as i32` assumes the config enum's
        // discriminants match `CheckpointCompressionAlgorithm` (the pb enum
        // used on decode) — verify when adding algorithms.
        let envelope = PbHummockVersionCheckpointEnvelope {
            compression_algorithm: compression as i32,
            payload: compressed,
            checksum: Some(checksum),
        };

        let buf = envelope.encode_to_vec();
        self.object_store
            .upload(&self.version_checkpoint_path, buf.into())
            .await?;
        Ok(())
    }

    /// Uploads a version archive under `version_archive_dir`, named by the
    /// archived version's id.
    ///
    /// # Panics
    /// Panics if `archive.version` is `None` (callers always populate it).
    pub(super) async fn write_version_archive(
        &self,
        archive: &PbHummockVersionArchive,
    ) -> Result<()> {
        use prost::Message;
        let buf = archive.encode_to_vec();
        let archive_path = format!(
            "{}/{}",
            self.version_archive_dir,
            archive.version.as_ref().unwrap().id
        );
        self.object_store.upload(&archive_path, buf.into()).await?;
        Ok(())
    }

    /// Creates a new checkpoint if at least `min_delta_log_num` deltas have
    /// accumulated since the previous one, uploading it and updating the
    /// in-memory `versioning.checkpoint`.
    ///
    /// Returns the number of version ids advanced (0 if skipped). Stale
    /// objects below the minimum pinned version id are handed to the GC
    /// manager and pruned from the new checkpoint's bookkeeping.
    pub async fn create_version_checkpoint(&self, min_delta_log_num: u64) -> Result<u64> {
        let timer = self.metrics.version_checkpoint_latency.start_timer();
        let versioning_guard = self.versioning.read().await;
        let versioning: &Versioning = versioning_guard.deref();
        let current_version: &HummockVersion = &versioning.current_version;
        let old_checkpoint: &HummockVersionCheckpoint = &versioning.checkpoint;
        let new_checkpoint_id = current_version.id;
        let old_checkpoint_id = old_checkpoint.version.id;
        // Not enough deltas accumulated yet: skip this round.
        if new_checkpoint_id < old_checkpoint_id + min_delta_log_num {
            return Ok(0);
        }
        // Test-only: when nothing advanced, still refresh GC stats so tests
        // observing the metrics see up-to-date values.
        if cfg!(test) && new_checkpoint_id == old_checkpoint_id {
            drop(versioning_guard);
            let versioning = self.versioning.read().await;
            let context_info = self.context_info.read().await;
            let min_pinned_version_id = context_info.min_pinned_version_id();
            trigger_gc_stat(&self.metrics, &versioning.checkpoint, min_pinned_version_id);
            return Ok(0);
        }
        assert!(new_checkpoint_id > old_checkpoint_id);
        let mut archive: Option<PbHummockVersionArchive> = None;
        let mut stale_objects = old_checkpoint.stale_objects.clone();
        let mut object_sizes = object_size_map(&old_checkpoint.version);
        // Start from the objects referenced by the old checkpoint, then fold
        // in everything added by the deltas in (old, new].
        let mut versions_object_ids: HashSet<_> =
            old_checkpoint.version.get_object_ids(false).collect();
        for (_, version_delta) in versioning
            .hummock_version_deltas
            .range((Excluded(old_checkpoint_id), Included(new_checkpoint_id)))
        {
            // Exhaustiveness guard (no runtime effect): adding a new
            // `HummockObjectId` variant breaks compilation here, reminding the
            // author to account for it in the accumulation below.
            match HummockObjectId::Sstable(0.into()) {
                HummockObjectId::Sstable(_) => {}
                HummockObjectId::VectorFile(_) => {}
                HummockObjectId::HnswGraphFile(_) => {}
            };
            for (object_id, file_size) in version_delta
                .newly_added_sst_infos(false)
                .map(|sst| (HummockObjectId::Sstable(sst.object_id), sst.file_size))
                .chain(
                    version_delta
                        .vector_index_delta
                        .values()
                        .flat_map(|delta| delta.newly_added_objects()),
                )
            {
                object_sizes.insert(object_id, file_size);
                versions_object_ids.insert(object_id);
            }
        }

        // Objects referenced at some point since the old checkpoint but no
        // longer referenced by the current version are now stale.
        let removed_object_ids =
            &versions_object_ids - &current_version.get_object_ids(false).collect();
        let total_file_size = removed_object_ids
            .iter()
            .map(|t| {
                object_sizes.get(t).copied().unwrap_or_else(|| {
                    // Size unknown: count as 0 rather than failing checkpointing.
                    warn!(object_id = ?t, "unable to get size of removed object id");
                    0
                })
            })
            .sum::<u64>();
        stale_objects.insert(current_version.id, {
            let mut sst_ids = vec![];
            let mut vector_files = vec![];
            for object_id in removed_object_ids {
                match object_id {
                    HummockObjectId::Sstable(sst_id) => sst_ids.push(sst_id),
                    HummockObjectId::VectorFile(vector_file_id) => {
                        vector_files.push(PbVectorIndexObject {
                            id: vector_file_id.as_raw(),
                            object_type: PbVectorIndexObjectType::VectorIndexObjectVector as _,
                        })
                    }
                    HummockObjectId::HnswGraphFile(graph_file_id) => {
                        vector_files.push(PbVectorIndexObject {
                            id: graph_file_id.as_raw(),
                            object_type: PbVectorIndexObjectType::VectorIndexObjectHnswGraph as _,
                        });
                    }
                }
            }
            StaleObjects {
                id: sst_ids,
                total_file_size,
                vector_files,
            }
        });
        if self.env.opts.enable_hummock_data_archive {
            archive = Some(PbHummockVersionArchive {
                version: Some(PbHummockVersion::from(&old_checkpoint.version)),
                version_deltas: versioning
                    .hummock_version_deltas
                    .range((Excluded(old_checkpoint_id), Included(new_checkpoint_id)))
                    .map(|(_, version_delta)| version_delta.into())
                    .collect(),
            });
        }
        // Stale objects of versions below the pin watermark can never be read
        // again: forward them to GC and drop them from checkpoint bookkeeping.
        let min_pinned_version_id = self.context_info.read().await.min_pinned_version_id();
        let may_delete_object = stale_objects
            .iter()
            .filter_map(|(version_id, object_ids)| {
                if *version_id >= min_pinned_version_id {
                    return None;
                }
                Some(get_stale_object_ids(object_ids))
            })
            .flatten();
        self.gc_manager.add_may_delete_object_ids(may_delete_object);
        stale_objects.retain(|version_id, _| *version_id >= min_pinned_version_id);
        let new_checkpoint = HummockVersionCheckpoint {
            version: current_version.clone(),
            stale_objects,
        };
        // Release the read lock before the (slow) upload so other readers and
        // writers are not blocked during object-store I/O.
        drop(versioning_guard);
        self.write_checkpoint(&new_checkpoint).await?;
        // Archive upload failure is tolerated: the checkpoint itself is safe.
        if let Some(archive) = archive
            && let Err(e) = self.write_version_archive(&archive).await
        {
            tracing::warn!(
                error = %e.as_report(),
                "failed to write version archive {}",
                archive.version.as_ref().unwrap().id
            );
        }
        // Re-acquire for write to publish the new checkpoint in memory.
        let mut versioning_guard = self.versioning.write().await;
        let versioning = versioning_guard.deref_mut();
        assert!(new_checkpoint.version.id > versioning.checkpoint.version.id);
        versioning.checkpoint = new_checkpoint;
        let min_pinned_version_id = self.context_info.read().await.min_pinned_version_id();
        trigger_gc_stat(&self.metrics, &versioning.checkpoint, min_pinned_version_id);
        trigger_split_stat(&self.metrics, &versioning.current_version);
        drop(versioning_guard);
        timer.observe_duration();
        self.metrics
            .checkpoint_version_id
            .set(new_checkpoint_id.as_i64_id());

        Ok(new_checkpoint_id - old_checkpoint_id)
    }

    /// Pauses background checkpointing (checked via
    /// `is_version_checkpoint_paused`).
    pub fn pause_version_checkpoint(&self) {
        self.pause_version_checkpoint.store(true, Ordering::Relaxed);
        tracing::info!("hummock version checkpoint is paused.");
    }

    /// Resumes background checkpointing.
    pub fn resume_version_checkpoint(&self) {
        self.pause_version_checkpoint
            .store(false, Ordering::Relaxed);
        tracing::info!("hummock version checkpoint is resumed.");
    }

    /// Returns whether checkpointing is currently paused.
    pub fn is_version_checkpoint_paused(&self) -> bool {
        self.pause_version_checkpoint.load(Ordering::Relaxed)
    }

    /// Returns a clone of the currently checkpointed version.
    pub async fn get_checkpoint_version(&self) -> HummockVersion {
        let versioning_guard = self.versioning.read().await;
        versioning_guard.checkpoint.version.clone()
    }
}
554
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use prost::Message;
    use risingwave_common::config::CheckpointCompression;
    use risingwave_pb::hummock::hummock_version_checkpoint::StaleObjects;
    use risingwave_pb::hummock::{
        PbHummockVersion, PbHummockVersionCheckpoint, PbHummockVersionCheckpointEnvelope,
    };

    use super::{
        compress_payload, decode_checkpoint_data, read_bytes_in_chunks, xxhash64_checksum,
    };

    // Builds a minimal version message; `max_committed_epoch` is deprecated,
    // hence the allow.
    #[allow(deprecated)]
    fn make_version(id: u64) -> PbHummockVersion {
        PbHummockVersion {
            id: id.into(),
            levels: Default::default(),
            max_committed_epoch: 0,
            table_watermarks: Default::default(),
            table_change_logs: Default::default(),
            state_table_info: Default::default(),
            vector_indexes: Default::default(),
        }
    }

    // Builds a checkpoint with one stale-objects entry for roundtrip tests.
    fn make_checkpoint(version_id: u64) -> PbHummockVersionCheckpoint {
        let stale = StaleObjects {
            id: vec![1u64.into(), 2u64.into(), 3u64.into()],
            total_file_size: 123,
            vector_files: vec![],
        };

        PbHummockVersionCheckpoint {
            version: Some(make_version(version_id)),
            stale_objects: [(1u64.into(), stale)].into_iter().collect(),
        }
    }

    // Encodes `checkpoint` as an envelope; `checksum: None` means "compute
    // the correct checksum", `Some(x)` forces a specific (possibly wrong) one.
    fn make_envelope_bytes(
        checkpoint: &PbHummockVersionCheckpoint,
        compression: CheckpointCompression,
        checksum: Option<u64>,
    ) -> Bytes {
        let raw = checkpoint.encode_to_vec();
        let payload = compress_payload(compression, &raw)
            .expect("compress checkpoint payload should succeed");
        let checksum = checksum.unwrap_or_else(|| xxhash64_checksum(&payload));
        let envelope = PbHummockVersionCheckpointEnvelope {
            compression_algorithm: compression as i32,
            payload,
            checksum: Some(checksum),
        };
        Bytes::from(envelope.encode_to_vec())
    }

    // A bare checkpoint message (no envelope) must still decode.
    #[test]
    fn decode_checkpoint_data_falls_back_to_legacy_format() {
        let checkpoint = make_checkpoint(42);
        let raw = Bytes::from(checkpoint.encode_to_vec());
        let decoded = decode_checkpoint_data(raw).expect("legacy checkpoint should decode");
        assert_eq!(decoded, checkpoint);
    }

    // Encode → decode roundtrip for every supported compression algorithm.
    #[test]
    fn decode_checkpoint_data_roundtrips_envelope_with_checksum() {
        let checkpoint = make_checkpoint(42);
        for compression in [
            CheckpointCompression::None,
            CheckpointCompression::Zstd,
            CheckpointCompression::Lz4,
        ] {
            let data = make_envelope_bytes(&checkpoint, compression, None);
            let decoded = decode_checkpoint_data(data).expect("envelope checkpoint should decode");
            assert_eq!(decoded, checkpoint);
        }
    }

    // Corrupting one payload byte after checksumming must be detected.
    #[test]
    fn decode_checkpoint_data_returns_error_on_checksum_mismatch() {
        let checkpoint = make_checkpoint(42);
        let raw = checkpoint.encode_to_vec();
        let mut payload = compress_payload(CheckpointCompression::Zstd, &raw)
            .expect("compress checkpoint payload should succeed");
        let expected = xxhash64_checksum(&payload);
        payload[0] ^= 0x01;
        let envelope = PbHummockVersionCheckpointEnvelope {
            compression_algorithm: CheckpointCompression::Zstd as i32,
            payload,
            checksum: Some(expected),
        };
        let data = Bytes::from(envelope.encode_to_vec());
        let err = decode_checkpoint_data(data).expect_err("checksum mismatch should error");
        assert!(err.to_string().contains("checksum mismatch"), "{err:?}");
    }

    // An out-of-range algorithm discriminant must produce a clear error.
    #[test]
    fn decode_checkpoint_data_returns_error_on_unknown_compression_algorithm() {
        let checkpoint = make_checkpoint(42);
        let payload = checkpoint.encode_to_vec();
        let checksum = xxhash64_checksum(&payload);
        let envelope = PbHummockVersionCheckpointEnvelope {
            compression_algorithm: 123,
            payload,
            checksum: Some(checksum),
        };
        let data = Bytes::from(envelope.encode_to_vec());
        let err =
            decode_checkpoint_data(data).expect_err("unknown compression algorithm should error");
        assert!(
            err.to_string()
                .contains("unknown checkpoint compression algorithm"),
            "{err:?}"
        );
    }

    // Legacy-format bytes missing the required `version` field are rejected.
    #[test]
    fn decode_checkpoint_data_returns_error_on_legacy_missing_version() {
        let checkpoint = PbHummockVersionCheckpoint {
            version: None,
            stale_objects: Default::default(),
        };
        let data = Bytes::from(checkpoint.encode_to_vec());
        let err = decode_checkpoint_data(data).expect_err("missing version should error");
        assert!(
            err.to_string()
                .contains("legacy checkpoint missing required field `version`"),
            "{err:?}"
        );
    }

    // A valid checksum over an undecodable payload must surface a decode error.
    #[test]
    fn decode_checkpoint_data_returns_error_on_corrupt_envelope_payload() {
        let garbage = b"not a valid protobuf";
        let checksum = xxhash64_checksum(garbage);
        let envelope = PbHummockVersionCheckpointEnvelope {
            compression_algorithm: CheckpointCompression::None as i32,
            payload: garbage.to_vec(),
            checksum: Some(checksum),
        };
        let data = Bytes::from(envelope.encode_to_vec());
        let err = decode_checkpoint_data(data).expect_err("corrupt envelope payload should error");
        assert!(
            err.to_string()
                .contains("failed to decode checkpoint envelope payload"),
            "{err:?}"
        );
    }

    // Empty input decodes as an empty legacy message, then fails the
    // missing-`version` check.
    #[test]
    fn decode_checkpoint_data_returns_error_on_empty_input() {
        let err = decode_checkpoint_data(Bytes::new()).expect_err("empty checkpoint should fail");
        assert!(
            err.to_string()
                .contains("legacy checkpoint missing required field `version`"),
            "{err:?}"
        );
    }

    // The envelope path also enforces the required `version` field.
    #[test]
    fn decode_checkpoint_data_returns_error_on_envelope_missing_version() {
        let checkpoint = PbHummockVersionCheckpoint {
            version: None,
            stale_objects: Default::default(),
        };
        let raw = checkpoint.encode_to_vec();
        let checksum = xxhash64_checksum(&raw);
        let envelope = PbHummockVersionCheckpointEnvelope {
            compression_algorithm: CheckpointCompression::None as i32,
            payload: raw,
            checksum: Some(checksum),
        };
        let data = Bytes::from(envelope.encode_to_vec());
        let err =
            decode_checkpoint_data(data).expect_err("envelope with missing version should error");
        assert!(
            err.to_string()
                .contains("checkpoint missing required field `version`"),
            "{err:?}"
        );
    }

    // Verifies both the concurrency cap and in-order reassembly: in-flight
    // count never exceeds the limit, yet the output equals the source bytes.
    #[tokio::test]
    async fn read_bytes_in_chunks_respects_concurrency_limit_and_reassembles() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};

        use tokio::time::{Duration, sleep};

        let total_size = 100usize;
        let chunk_size = 10usize;
        let max_in_flight = 3usize;

        let data: Arc<Vec<u8>> = Arc::new((0..total_size).map(|i| (i % 256) as u8).collect());
        let in_flight = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));

        let out = read_bytes_in_chunks(total_size, chunk_size, max_in_flight, {
            let data = data.clone();
            let in_flight = in_flight.clone();
            let max_seen = max_seen.clone();
            move |range: std::ops::Range<usize>| {
                let data = data.clone();
                let in_flight = in_flight.clone();
                let max_seen = max_seen.clone();
                async move {
                    let cur = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                    max_seen.fetch_max(cur, Ordering::SeqCst);

                    // Keep chunks in flight long enough to overlap.
                    sleep(Duration::from_millis(30)).await;

                    let bytes = Bytes::copy_from_slice(&data[range]);
                    in_flight.fetch_sub(1, Ordering::SeqCst);
                    Ok(bytes)
                }
            }
        })
        .await
        .expect("chunked read should succeed");

        assert_eq!(out.as_ref(), data.as_slice());
        let max_seen = max_seen.load(Ordering::SeqCst);
        assert!(max_seen <= max_in_flight, "max_seen={max_seen}");
        assert!(
            max_seen > 1,
            "expected some concurrency, max_seen={max_seen}"
        );
    }

    // A failing chunk read must be wrapped with its chunk index and byte range.
    #[tokio::test]
    async fn read_bytes_in_chunks_adds_range_context_on_error() {
        let total_size = 30usize;
        let chunk_size = 10usize;
        let max_in_flight = 2usize;

        let err = read_bytes_in_chunks(total_size, chunk_size, max_in_flight, |range| async move {
            if range.start == 10 {
                anyhow::bail!("boom");
            }
            Ok(Bytes::copy_from_slice(&vec![0u8; range.len()]))
        })
        .await
        .expect_err("should fail");

        let msg = err.to_string();
        assert!(
            msg.contains("read checkpoint chunk 2/3 range 10..20"),
            "unexpected error message: {msg}"
        );
        let msg_with_chain = format!("{err:#}");
        assert!(
            msg_with_chain.contains("boom"),
            "unexpected error message: {msg_with_chain}"
        );
    }
}