risingwave_object_store/object/
mem.rs1use std::collections::{HashMap, VecDeque};
16use std::ops::Range;
17use std::pin::Pin;
18use std::sync::{Arc, LazyLock};
19use std::task::{Context, Poll};
20use std::time::{SystemTime, UNIX_EPOCH};
21
22use bytes::{BufMut, Bytes, BytesMut};
23use fail::fail_point;
24use futures::Stream;
25use itertools::Itertools;
26use risingwave_common::range::RangeBoundsExt;
27use thiserror::Error;
28use tokio::sync::Mutex;
29
30use super::{
31 ObjectError, ObjectMetadata, ObjectRangeBounds, ObjectResult, ObjectStore, StreamingUploader,
32};
33use crate::object::{ObjectDataStream, ObjectMetadataIter};
34
35#[derive(Error, Debug)]
36pub enum Error {
37 #[error("NotFound error: {0}")]
38 NotFound(String),
39 #[error("Other error: {0}")]
40 Other(String),
41}
42
43impl Error {
44 pub fn is_object_not_found_error(&self) -> bool {
45 matches!(self, Error::NotFound(_))
46 }
47}
48
49impl Error {
50 fn not_found(msg: impl ToString) -> Self {
51 Error::NotFound(msg.to_string())
52 }
53
54 fn other(msg: impl ToString) -> Self {
55 Error::Other(msg.to_string())
56 }
57}
58
59pub struct InMemStreamingUploader {
61 path: String,
62 buf: BytesMut,
63 objects: Arc<Mutex<HashMap<String, (ObjectMetadata, Bytes)>>>,
64}
65
66impl StreamingUploader for InMemStreamingUploader {
67 async fn write_bytes(&mut self, data: Bytes) -> ObjectResult<()> {
68 fail_point!("mem_write_bytes_err", |_| Err(ObjectError::internal(
69 "mem write bytes error"
70 )));
71 self.buf.put(data);
72 Ok(())
73 }
74
75 async fn finish(self) -> ObjectResult<()> {
76 fail_point!("mem_finish_streaming_upload_err", |_| Err(
77 ObjectError::internal("mem finish streaming upload error")
78 ));
79 let obj = self.buf.freeze();
80 if obj.is_empty() {
81 Err(Error::other("upload empty object").into())
82 } else {
83 let metadata = get_obj_meta(&self.path, &obj)?;
84 self.objects.lock().await.insert(self.path, (metadata, obj));
85 Ok(())
86 }
87 }
88
89 fn get_memory_usage(&self) -> u64 {
90 self.buf.capacity() as u64
91 }
92}
93
94#[derive(Default, Clone)]
96pub struct InMemObjectStore {
97 objects: Arc<Mutex<HashMap<String, (ObjectMetadata, Bytes)>>>,
98}
99
100#[async_trait::async_trait]
101impl ObjectStore for InMemObjectStore {
102 type StreamingUploader = InMemStreamingUploader;
103
104 fn get_object_prefix(&self, _obj_id: u64, _use_new_object_prefix_strategy: bool) -> String {
105 String::default()
106 }
107
108 async fn upload(&self, path: &str, obj: Bytes) -> ObjectResult<()> {
109 fail_point!("mem_upload_err", |_| Err(ObjectError::internal(
110 "mem upload error"
111 )));
112 if obj.is_empty() {
113 Err(Error::other("upload empty object").into())
114 } else {
115 let metadata = get_obj_meta(path, &obj)?;
116 self.objects
117 .lock()
118 .await
119 .insert(path.into(), (metadata, obj));
120 Ok(())
121 }
122 }
123
124 async fn streaming_upload(&self, path: &str) -> ObjectResult<Self::StreamingUploader> {
125 Ok(InMemStreamingUploader {
126 path: path.to_owned(),
127 buf: BytesMut::new(),
128 objects: self.objects.clone(),
129 })
130 }
131
132 async fn read(&self, path: &str, range: impl ObjectRangeBounds) -> ObjectResult<Bytes> {
133 fail_point!("mem_read_err", |_| Err(ObjectError::internal(
134 "mem read error"
135 )));
136 self.get_object(path, range).await
137 }
138
139 async fn streaming_read(
143 &self,
144 path: &str,
145 read_range: Range<usize>,
146 ) -> ObjectResult<ObjectDataStream> {
147 fail_point!("mem_streaming_read_err", |_| Err(ObjectError::internal(
148 "mem streaming read error"
149 )));
150 let bytes = self.get_object(path, read_range).await?;
151
152 Ok(Box::pin(InMemDataIterator::new(bytes)))
153 }
154
155 async fn metadata(&self, path: &str) -> ObjectResult<ObjectMetadata> {
156 self.objects
157 .lock()
158 .await
159 .get(path)
160 .map(|(metadata, _)| metadata)
161 .cloned()
162 .ok_or_else(|| Error::not_found(format!("no object at path '{}'", path)).into())
163 }
164
165 async fn delete(&self, path: &str) -> ObjectResult<()> {
166 fail_point!("mem_delete_err", |_| Err(ObjectError::internal(
167 "mem delete error"
168 )));
169 self.objects.lock().await.remove(path);
170 Ok(())
171 }
172
173 async fn delete_objects(&self, paths: &[String]) -> ObjectResult<()> {
176 let mut guard = self.objects.lock().await;
177
178 for path in paths {
179 guard.remove(path);
180 }
181
182 Ok(())
183 }
184
185 async fn list(
186 &self,
187 prefix: &str,
188 start_after: Option<String>,
189 limit: Option<usize>,
190 ) -> ObjectResult<ObjectMetadataIter> {
191 let list_result = self
192 .objects
193 .lock()
194 .await
195 .iter()
196 .filter_map(|(path, (metadata, _))| {
197 if let Some(ref start_after) = start_after
198 && metadata.key.le(start_after)
199 {
200 return None;
201 }
202 if path.starts_with(prefix) {
203 return Some(metadata.clone());
204 }
205 None
206 })
207 .sorted_by(|a, b| Ord::cmp(&a.key, &b.key))
208 .take(limit.unwrap_or(usize::MAX))
209 .collect_vec();
210 Ok(Box::pin(InMemObjectIter::new(list_result)))
211 }
212
213 fn store_media_type(&self) -> &'static str {
214 "mem"
215 }
216}
217
218pub struct InMemDataIterator {
219 data: Bytes,
220 offset: usize,
221}
222
223impl InMemDataIterator {
224 pub fn new(data: Bytes) -> Self {
225 Self { data, offset: 0 }
226 }
227}
228
229impl Stream for InMemDataIterator {
230 type Item = ObjectResult<Bytes>;
231
232 fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
233 const MAX_PACKET_SIZE: usize = 128 * 1024;
234 if self.offset >= self.data.len() {
235 return Poll::Ready(None);
236 }
237 let read_len = std::cmp::min(self.data.len() - self.offset, MAX_PACKET_SIZE);
238 let data = self.data.slice(self.offset..(self.offset + read_len));
239 self.offset += read_len;
240 Poll::Ready(Some(Ok(data)))
241 }
242}
243
244static SHARED: LazyLock<spin::Mutex<InMemObjectStore>> =
245 LazyLock::new(|| spin::Mutex::new(InMemObjectStore::new()));
246
247impl InMemObjectStore {
248 fn new() -> Self {
249 Self {
250 objects: Arc::new(Mutex::new(HashMap::new())),
251 }
252 }
253
254 pub fn for_test() -> Self {
256 Self::new()
257 }
258
259 pub(super) fn shared() -> Self {
264 SHARED.lock().clone()
265 }
266
267 pub fn reset_shared() {
269 *SHARED.lock() = InMemObjectStore::new();
270 }
271
272 async fn get_object(&self, path: &str, range: impl ObjectRangeBounds) -> ObjectResult<Bytes> {
273 let objects = self.objects.lock().await;
274
275 let obj = objects
276 .get(path)
277 .map(|(_, obj)| obj)
278 .ok_or_else(|| Error::not_found(format!("no object at path '{}'", path)))?;
279
280 if let Some(end) = range.end()
281 && end > obj.len()
282 {
283 return Err(Error::other("bad block offset and size").into());
284 }
285
286 Ok(obj.slice(range))
287 }
288}
289
290fn get_obj_meta(path: &str, obj: &Bytes) -> ObjectResult<ObjectMetadata> {
291 Ok(ObjectMetadata {
292 key: path.to_owned(),
293 last_modified: SystemTime::now()
294 .duration_since(UNIX_EPOCH)
295 .map_err(ObjectError::internal)?
296 .as_secs_f64(),
297 total_size: obj.len(),
298 })
299}
300
301struct InMemObjectIter {
302 list_result: VecDeque<ObjectMetadata>,
303}
304
305impl InMemObjectIter {
306 fn new(list_result: Vec<ObjectMetadata>) -> Self {
307 Self {
308 list_result: list_result.into(),
309 }
310 }
311}
312
313impl Stream for InMemObjectIter {
314 type Item = ObjectResult<ObjectMetadata>;
315
316 fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
317 if let Some(i) = self.list_result.pop_front() {
318 return Poll::Ready(Some(Ok(i)));
319 }
320 Poll::Ready(None)
321 }
322}
323
#[cfg(test)]
mod tests {
    use futures::TryStreamExt;

    use super::*;

    /// Collects every object listed under `prefix`, panicking on any error.
    async fn list_all(prefix: &str, store: &InMemObjectStore) -> Vec<ObjectMetadata> {
        let stream = store.list(prefix, None, None).await.unwrap();
        stream.try_collect::<Vec<_>>().await.unwrap()
    }

    #[tokio::test]
    async fn test_upload() {
        let store = InMemObjectStore::for_test();
        store.upload("/abc", Bytes::from("123456")).await.unwrap();

        // A path that is merely a prefix of a stored key must not resolve.
        let missing = store.read("/ab", 0..3).await.unwrap_err();
        assert!(missing.is_object_not_found_error());

        let tail = store.read("/abc", 4..6).await.unwrap();
        assert_eq!(&tail[..], b"56");

        // Reading past the end of the object is rejected.
        store.read("/abc", 4..8).await.unwrap_err();

        store.delete("/abc").await.unwrap();
        store.read("/abc", 0..3).await.unwrap_err();
    }

    #[tokio::test]
    async fn test_streaming_upload() {
        let store = InMemObjectStore::for_test();
        let mut uploader = store.streaming_upload("/abc").await.unwrap();
        for chunk in ["123", "456", "789"] {
            uploader.write_bytes(Bytes::from(chunk)).await.unwrap();
        }
        uploader.finish().await.unwrap();

        let whole = store.read("/abc", ..).await.unwrap();
        assert_eq!(whole, Bytes::from("123456789"));

        let middle = store.read("/abc", 4..6).await.unwrap();
        assert_eq!(&middle[..], b"56");
    }

    #[tokio::test]
    async fn test_metadata() {
        let store = InMemObjectStore::for_test();
        store.upload("/abc", Bytes::from("123456")).await.unwrap();

        let missing = store.metadata("/not_exist").await.unwrap_err();
        assert!(missing.is_object_not_found_error());

        let meta = store.metadata("/abc").await.unwrap();
        assert_eq!(meta.total_size, 6);
    }

    #[tokio::test]
    async fn test_list() {
        let payload = Bytes::from("123456");
        let store = InMemObjectStore::for_test();
        assert!(list_all("", &store).await.is_empty());

        let paths = vec!["001/002/test.obj", "001/003/test.obj"];
        for (uploaded, path) in paths.iter().copied().enumerate() {
            assert_eq!(list_all("", &store).await.len(), uploaded);
            store.upload(path, payload.clone()).await.unwrap();
            assert_eq!(list_all("", &store).await.len(), uploaded + 1);
        }

        let keys = list_all("", &store)
            .await
            .into_iter()
            .map(|meta| meta.key)
            .collect_vec();
        assert_eq!(keys, paths);

        // Both paths share the first six characters ("001/00"); longer prefixes match one.
        for last in 0..=5 {
            assert_eq!(list_all(&paths[0][..=last], &store).await.len(), 2);
        }
        for last in 6..paths[0].len() {
            assert_eq!(list_all(&paths[0][..=last], &store).await.len(), 1);
        }
        assert!(list_all("003", &store).await.is_empty());

        for (deleted, path) in paths.iter().copied().enumerate() {
            assert_eq!(list_all("", &store).await.len(), paths.len() - deleted);
            store.delete(path).await.unwrap();
            assert_eq!(list_all("", &store).await.len(), paths.len() - deleted - 1);
        }
    }

    #[tokio::test]
    async fn test_delete_objects() {
        let store = InMemObjectStore::for_test();
        store.upload("/abc", Bytes::from("123456")).await.unwrap();
        store.upload("/klm", Bytes::from("987654")).await.unwrap();
        assert_eq!(list_all("", &store).await.len(), 2);

        // "/xyz" was never uploaded; deleting it alongside real objects must still succeed.
        let targets = ["/abc", "/klm", "/xyz"].map(String::from);
        store.delete_objects(&targets).await.unwrap();

        assert!(list_all("", &store).await.is_empty());
    }
}