risingwave_object_store/object/
mem.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, VecDeque};
16use std::ops::Range;
17use std::pin::Pin;
18use std::sync::{Arc, LazyLock};
19use std::task::{Context, Poll};
20use std::time::{SystemTime, UNIX_EPOCH};
21
22use bytes::{BufMut, Bytes, BytesMut};
23use fail::fail_point;
24use futures::Stream;
25use itertools::Itertools;
26use risingwave_common::range::RangeBoundsExt;
27use thiserror::Error;
28use tokio::sync::Mutex;
29
30use super::{
31    ObjectError, ObjectMetadata, ObjectRangeBounds, ObjectResult, ObjectStore, StreamingUploader,
32};
33use crate::object::{ObjectDataStream, ObjectMetadataIter};
34
/// Errors raised by the in-memory object store.
#[derive(Error, Debug)]
pub enum Error {
    /// No object is stored at the requested path; the payload is a human-readable message.
    #[error("NotFound error: {0}")]
    NotFound(String),
    /// Any other failure, e.g. uploading an empty object or reading past the end of an object.
    #[error("Other error: {0}")]
    Other(String),
}
42
43impl Error {
44    pub fn is_object_not_found_error(&self) -> bool {
45        matches!(self, Error::NotFound(_))
46    }
47}
48
49impl Error {
50    fn not_found(msg: impl ToString) -> Self {
51        Error::NotFound(msg.to_string())
52    }
53
54    fn other(msg: impl ToString) -> Self {
55        Error::Other(msg.to_string())
56    }
57}
58
/// Streaming uploader of the in-memory object store.
///
/// Every chunk passed to `write_bytes` is appended to one contiguous buffer; `finish` freezes that
/// buffer and inserts it into the shared object map under `path`.
pub struct InMemStreamingUploader {
    /// Key under which the finished object is inserted.
    path: String,
    /// Bytes accumulated so far.
    buf: BytesMut,
    /// Object map shared with the store that created this uploader.
    objects: Arc<Mutex<HashMap<String, (ObjectMetadata, Bytes)>>>,
}
65
66impl StreamingUploader for InMemStreamingUploader {
67    async fn write_bytes(&mut self, data: Bytes) -> ObjectResult<()> {
68        fail_point!("mem_write_bytes_err", |_| Err(ObjectError::internal(
69            "mem write bytes error"
70        )));
71        self.buf.put(data);
72        Ok(())
73    }
74
75    async fn finish(self) -> ObjectResult<()> {
76        fail_point!("mem_finish_streaming_upload_err", |_| Err(
77            ObjectError::internal("mem finish streaming upload error")
78        ));
79        let obj = self.buf.freeze();
80        if obj.is_empty() {
81            Err(Error::other("upload empty object").into())
82        } else {
83            let metadata = get_obj_meta(&self.path, &obj)?;
84            self.objects.lock().await.insert(self.path, (metadata, obj));
85            Ok(())
86        }
87    }
88
89    fn get_memory_usage(&self) -> u64 {
90        self.buf.capacity() as u64
91    }
92}
93
/// In-memory object storage, useful for testing.
///
/// The object map lives behind an `Arc`, so a clone of the store is another handle to the same
/// objects rather than an independent copy.
#[derive(Default, Clone)]
pub struct InMemObjectStore {
    /// Maps an object path to the object's metadata and content.
    objects: Arc<Mutex<HashMap<String, (ObjectMetadata, Bytes)>>>,
}
99
100#[async_trait::async_trait]
101impl ObjectStore for InMemObjectStore {
102    type StreamingUploader = InMemStreamingUploader;
103
104    fn get_object_prefix(&self, _obj_id: u64, _use_new_object_prefix_strategy: bool) -> String {
105        String::default()
106    }
107
108    async fn upload(&self, path: &str, obj: Bytes) -> ObjectResult<()> {
109        fail_point!("mem_upload_err", |_| Err(ObjectError::internal(
110            "mem upload error"
111        )));
112        if obj.is_empty() {
113            Err(Error::other("upload empty object").into())
114        } else {
115            let metadata = get_obj_meta(path, &obj)?;
116            self.objects
117                .lock()
118                .await
119                .insert(path.into(), (metadata, obj));
120            Ok(())
121        }
122    }
123
124    async fn streaming_upload(&self, path: &str) -> ObjectResult<Self::StreamingUploader> {
125        Ok(InMemStreamingUploader {
126            path: path.to_owned(),
127            buf: BytesMut::new(),
128            objects: self.objects.clone(),
129        })
130    }
131
132    async fn read(&self, path: &str, range: impl ObjectRangeBounds) -> ObjectResult<Bytes> {
133        fail_point!("mem_read_err", |_| Err(ObjectError::internal(
134            "mem read error"
135        )));
136        self.get_object(path, range).await
137    }
138
139    /// Returns a stream reading the object specified in `path`. If given, the stream starts at the
140    /// byte with index `start_pos` (0-based). As far as possible, the stream only loads the amount
141    /// of data into memory that is read from the stream.
142    async fn streaming_read(
143        &self,
144        path: &str,
145        read_range: Range<usize>,
146    ) -> ObjectResult<ObjectDataStream> {
147        fail_point!("mem_streaming_read_err", |_| Err(ObjectError::internal(
148            "mem streaming read error"
149        )));
150        let bytes = self.get_object(path, read_range).await?;
151
152        Ok(Box::pin(InMemDataIterator::new(bytes)))
153    }
154
155    async fn metadata(&self, path: &str) -> ObjectResult<ObjectMetadata> {
156        self.objects
157            .lock()
158            .await
159            .get(path)
160            .map(|(metadata, _)| metadata)
161            .cloned()
162            .ok_or_else(|| Error::not_found(format!("no object at path '{}'", path)).into())
163    }
164
165    async fn delete(&self, path: &str) -> ObjectResult<()> {
166        fail_point!("mem_delete_err", |_| Err(ObjectError::internal(
167            "mem delete error"
168        )));
169        self.objects.lock().await.remove(path);
170        Ok(())
171    }
172
173    /// Deletes the objects with the given paths permanently from the storage. If an object
174    /// specified in the request is not found, it will be considered as successfully deleted.
175    async fn delete_objects(&self, paths: &[String]) -> ObjectResult<()> {
176        let mut guard = self.objects.lock().await;
177
178        for path in paths {
179            guard.remove(path);
180        }
181
182        Ok(())
183    }
184
185    async fn list(
186        &self,
187        prefix: &str,
188        start_after: Option<String>,
189        limit: Option<usize>,
190    ) -> ObjectResult<ObjectMetadataIter> {
191        let list_result = self
192            .objects
193            .lock()
194            .await
195            .iter()
196            .filter_map(|(path, (metadata, _))| {
197                if let Some(ref start_after) = start_after
198                    && metadata.key.le(start_after)
199                {
200                    return None;
201                }
202                if path.starts_with(prefix) {
203                    return Some(metadata.clone());
204                }
205                None
206            })
207            .sorted_by(|a, b| Ord::cmp(&a.key, &b.key))
208            .take(limit.unwrap_or(usize::MAX))
209            .collect_vec();
210        Ok(Box::pin(InMemObjectIter::new(list_result)))
211    }
212
213    fn store_media_type(&self) -> &'static str {
214        "mem"
215    }
216}
217
/// Stream that hands out an in-memory byte buffer in bounded-size chunks.
pub struct InMemDataIterator {
    /// The bytes to stream.
    data: Bytes,
    /// Index into `data` of the first byte that has not been yielded yet.
    offset: usize,
}
222
223impl InMemDataIterator {
224    pub fn new(data: Bytes) -> Self {
225        Self { data, offset: 0 }
226    }
227}
228
229impl Stream for InMemDataIterator {
230    type Item = ObjectResult<Bytes>;
231
232    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
233        const MAX_PACKET_SIZE: usize = 128 * 1024;
234        if self.offset >= self.data.len() {
235            return Poll::Ready(None);
236        }
237        let read_len = std::cmp::min(self.data.len() - self.offset, MAX_PACKET_SIZE);
238        let data = self.data.slice(self.offset..(self.offset + read_len));
239        self.offset += read_len;
240        Poll::Ready(Some(Ok(data)))
241    }
242}
243
/// Process-wide store handle: `InMemObjectStore::shared` returns a clone of it and
/// `InMemObjectStore::reset_shared` replaces it with a fresh store. Created lazily on first use.
static SHARED: LazyLock<spin::Mutex<InMemObjectStore>> =
    LazyLock::new(|| spin::Mutex::new(InMemObjectStore::new()));
246
247impl InMemObjectStore {
248    fn new() -> Self {
249        Self {
250            objects: Arc::new(Mutex::new(HashMap::new())),
251        }
252    }
253
254    /// Create a new in-memory object store for testing, isolated with others.
255    pub fn for_test() -> Self {
256        Self::new()
257    }
258
259    /// Get a reference to the in-memory object store shared in this process.
260    ///
261    /// Note: Should only be used for `risedev playground`, when there're multiple compute-nodes or
262    /// compactors in the same process.
263    pub(super) fn shared() -> Self {
264        SHARED.lock().clone()
265    }
266
267    /// Reset the shared in-memory object store.
268    pub fn reset_shared() {
269        *SHARED.lock() = InMemObjectStore::new();
270    }
271
272    async fn get_object(&self, path: &str, range: impl ObjectRangeBounds) -> ObjectResult<Bytes> {
273        let objects = self.objects.lock().await;
274
275        let obj = objects
276            .get(path)
277            .map(|(_, obj)| obj)
278            .ok_or_else(|| Error::not_found(format!("no object at path '{}'", path)))?;
279
280        if let Some(end) = range.end()
281            && end > obj.len()
282        {
283            return Err(Error::other("bad block offset and size").into());
284        }
285
286        Ok(obj.slice(range))
287    }
288}
289
290fn get_obj_meta(path: &str, obj: &Bytes) -> ObjectResult<ObjectMetadata> {
291    Ok(ObjectMetadata {
292        key: path.to_owned(),
293        last_modified: SystemTime::now()
294            .duration_since(UNIX_EPOCH)
295            .map_err(ObjectError::internal)?
296            .as_secs_f64(),
297        total_size: obj.len(),
298    })
299}
300
/// Stream over an already materialized listing result.
struct InMemObjectIter {
    /// Entries still to be yielded; the front is yielded next.
    list_result: VecDeque<ObjectMetadata>,
}
304
305impl InMemObjectIter {
306    fn new(list_result: Vec<ObjectMetadata>) -> Self {
307        Self {
308            list_result: list_result.into(),
309        }
310    }
311}
312
313impl Stream for InMemObjectIter {
314    type Item = ObjectResult<ObjectMetadata>;
315
316    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
317        if let Some(i) = self.list_result.pop_front() {
318            return Poll::Ready(Some(Ok(i)));
319        }
320        Poll::Ready(None)
321    }
322}
323
#[cfg(test)]
mod tests {
    use futures::TryStreamExt;

    use super::*;

    /// Upload, ranged read, out-of-bounds read and read-after-delete on a single object.
    #[tokio::test]
    async fn test_upload() {
        let store = InMemObjectStore::for_test();
        store.upload("/abc", Bytes::from("123456")).await.unwrap();

        // No such object.
        let missing = store.read("/ab", 0..3).await.unwrap_err();
        assert!(missing.is_object_not_found_error());

        let tail = store.read("/abc", 4..6).await.unwrap();
        assert_eq!(String::from_utf8(tail.to_vec()).unwrap(), "56".to_owned());

        // Overflow.
        store.read("/abc", 4..8).await.unwrap_err();

        store.delete("/abc").await.unwrap();

        // No such object.
        store.read("/abc", 0..3).await.unwrap_err();
    }

    /// Parts written through the streaming uploader are concatenated into one object.
    #[tokio::test]
    async fn test_streaming_upload() {
        let expected = Bytes::from("123456789");

        let store = InMemObjectStore::for_test();
        let mut uploader = store.streaming_upload("/abc").await.unwrap();

        for part in ["123", "456", "789"] {
            uploader.write_bytes(Bytes::from(part)).await.unwrap();
        }
        uploader.finish().await.unwrap();

        // Read whole object.
        let whole = store.read("/abc", ..).await.unwrap();
        assert!(whole.eq(&expected));

        // Read part of the object.
        let part = store.read("/abc", 4..6).await.unwrap();
        assert_eq!(String::from_utf8(part.to_vec()).unwrap(), "56".to_owned());
    }

    /// Metadata lookup reports the object size and a not-found error for unknown paths.
    #[tokio::test]
    async fn test_metadata() {
        let store = InMemObjectStore::for_test();
        store.upload("/abc", Bytes::from("123456")).await.unwrap();

        let missing = store.metadata("/not_exist").await.unwrap_err();
        assert!(missing.is_object_not_found_error());

        let found = store.metadata("/abc").await.unwrap();
        assert_eq!(found.total_size, 6);
    }

    /// Collects the complete, unpaginated listing for `prefix`.
    async fn list_all(prefix: &str, store: &InMemObjectStore) -> Vec<ObjectMetadata> {
        let listing = store.list(prefix, None, None).await.unwrap();
        listing.try_collect::<Vec<_>>().await.unwrap()
    }

    /// Listing reflects uploads and deletes, is ordered by key, and honors the prefix.
    #[tokio::test]
    async fn test_list() {
        let payload = Bytes::from("123456");
        let store = InMemObjectStore::for_test();
        assert!(list_all("", &store).await.is_empty());

        let paths = vec!["001/002/test.obj", "001/003/test.obj"];
        for (uploaded, path) in paths.iter().copied().enumerate() {
            assert_eq!(list_all("", &store).await.len(), uploaded);
            store.upload(path, payload.clone()).await.unwrap();
            assert_eq!(list_all("", &store).await.len(), uploaded + 1);
        }

        let listed_keys = list_all("", &store)
            .await
            .into_iter()
            .map(|metadata| metadata.key)
            .collect_vec();
        assert_eq!(listed_keys, paths);

        // The two paths share their first 6 bytes ("001/00"): prefixes up to that length match
        // both objects, longer prefixes of the first path match only the first object.
        for prefix_len in 1..=paths[0].len() {
            let expected = if prefix_len <= 6 { 2 } else { 1 };
            assert_eq!(
                list_all(&paths[0][..prefix_len], &store).await.len(),
                expected
            );
        }
        assert!(list_all("003", &store).await.is_empty());

        for (deleted, path) in paths.iter().copied().enumerate() {
            assert_eq!(list_all("", &store).await.len(), paths.len() - deleted);
            store.delete(path).await.unwrap();
            assert_eq!(list_all("", &store).await.len(), paths.len() - deleted - 1);
        }
    }

    /// Batch deletion removes every listed object and tolerates paths that do not exist.
    #[tokio::test]
    async fn test_delete_objects() {
        let store = InMemObjectStore::for_test();
        store.upload("/abc", Bytes::from("123456")).await.unwrap();
        store.upload("/klm", Bytes::from("987654")).await.unwrap();

        assert_eq!(list_all("", &store).await.len(), 2);

        // "/xyz" was never uploaded.
        let doomed = ["/abc", "/klm", "/xyz"].map(String::from);

        store.delete_objects(&doomed).await.unwrap();

        assert_eq!(list_all("", &store).await.len(), 0);
    }
}