risingwave_hummock_trace/replay/
worker.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::sync::Arc;
17
18use futures::StreamExt;
19use futures::stream::BoxStream;
20use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
21use tokio::task::JoinHandle;
22
23use super::{GlobalReplay, LocalReplay, ReplayRequest, WorkerId, WorkerResponse};
24use crate::{
25    LocalStorageId, Operation, OperationResult, Record, RecordId, ReplayItem, Result, StorageType,
26    TraceResult, TracedNewLocalOptions,
27};
28
/// Interface for dispatching trace records to replay workers and waiting for them to complete.
#[async_trait::async_trait]
pub trait ReplayWorkerScheduler {
    /// Schedule a replaying task for the given record.
    fn schedule(&mut self, record: Record);
    /// Send the result of an operation to a worker.
    fn send_result(&mut self, record: Record);
    /// Wait until an operation finishes.
    async fn wait_finish(&mut self, record: Record);
    /// Gracefully shut down all workers.
    async fn shutdown(self);
}
40
/// Scheduler that tracks one [`WorkerHandler`] per [`WorkerId`] and holds the [`GlobalReplay`]
/// implementation handed to the workers it spawns.
pub(crate) struct WorkerScheduler<G: GlobalReplay> {
    /// Handlers of the workers currently tracked, keyed by worker id.
    workers: HashMap<WorkerId, WorkerHandler>,
    /// Replay backend; a clone of this `Arc` is passed to each newly spawned worker.
    replay: Arc<G>,
}
45
46impl<G: GlobalReplay> WorkerScheduler<G> {
47    pub(crate) fn new(replay: Arc<G>) -> Self {
48        WorkerScheduler {
49            workers: HashMap::new(),
50            replay,
51        }
52    }
53}
54
55#[async_trait::async_trait]
56impl<G: GlobalReplay + 'static> ReplayWorkerScheduler for WorkerScheduler<G> {
57    fn schedule(&mut self, record: Record) {
58        let worker_id = allocate_worker_id(&record);
59        let handler = self
60            .workers
61            .entry(worker_id)
62            .or_insert_with(|| ReplayWorker::spawn(self.replay.clone()));
63
64        handler.replay(Some(record));
65    }
66
67    fn send_result(&mut self, record: Record) {
68        let worker_id = allocate_worker_id(&record);
69        let Record { operation, .. } = record;
70        // Check if the worker with the given ID exists in the workers map and the record contains a
71        // Result operation.
72        if let (Some(handler), Operation::Result(trace_result)) =
73            (self.workers.get_mut(&worker_id), operation)
74        {
75            // If the worker exists and the record contains a Result operation, send the result to
76            // the worker.
77            handler.send_result(trace_result);
78        }
79    }
80
81    async fn wait_finish(&mut self, record: Record) {
82        let worker_id = allocate_worker_id(&record);
83
84        // Check if the worker with the given ID exists in the workers map.
85        if let Some(handler) = self.workers.get_mut(&worker_id) {
86            // If the worker exists, wait for it to finish executing.
87            let resp = handler.wait(record.record_id).await;
88
89            // If the worker is a one-shot worker or local workers that should be closed, remove it
90            // from the workers map and call its finish method.
91            if matches!(worker_id, WorkerId::OneShot(_))
92                || matches!(resp, Some(WorkerResponse::Shutdown))
93            {
94                let handler = self.workers.remove(&worker_id).unwrap();
95                handler.finish();
96            }
97        }
98    }
99
100    async fn shutdown(self) {
101        // Iterate over the workers map, calling the finish and join methods on each worker.
102        for (_, handler) in self.workers {
103            handler.finish();
104            handler.join().await;
105        }
106    }
107}
108
/// Stateless type whose associated functions spawn a replay worker task and process the records
/// sent to it.
struct ReplayWorker {}
110
111impl ReplayWorker {
112    fn spawn(replay: Arc<impl GlobalReplay + 'static>) -> WorkerHandler {
113        let (req_tx, req_rx) = unbounded_channel();
114        let (resp_tx, resp_rx) = unbounded_channel();
115        let (res_tx, res_rx) = unbounded_channel();
116
117        let join = tokio::spawn(Self::run(req_rx, res_rx, resp_tx, replay));
118        WorkerHandler {
119            req_tx,
120            res_tx,
121            record_end_resp_rx: resp_rx,
122            join,
123            stacked_replay_reqs: HashMap::new(),
124        }
125    }
126
127    async fn run(
128        mut req_rx: UnboundedReceiver<ReplayRequest>,
129        mut res_rx: UnboundedReceiver<OperationResult>,
130        resp_tx: UnboundedSender<WorkerResponse>,
131        replay: Arc<impl GlobalReplay>,
132    ) {
133        let mut iters_map: HashMap<RecordId, BoxStream<'static, Result<ReplayItem>>> =
134            HashMap::new();
135        let mut local_storages = LocalStorages::new();
136        let mut should_shutdown = false;
137        let mut local_storage_opts_map = HashMap::new();
138
139        while let Some(Some(record)) = req_rx.recv().await {
140            Self::handle_record(
141                record.clone(),
142                &replay,
143                &mut res_rx,
144                &mut iters_map,
145                &mut local_storages,
146                &mut local_storage_opts_map,
147                &mut should_shutdown,
148            )
149            .await;
150
151            let message = if should_shutdown {
152                WorkerResponse::Shutdown
153            } else {
154                WorkerResponse::Continue
155            };
156
157            resp_tx.send(message).expect("Failed to send message");
158        }
159    }
160
161    async fn handle_record(
162        record: Record,
163        replay: &Arc<impl GlobalReplay>,
164        res_rx: &mut UnboundedReceiver<OperationResult>,
165        iters_map: &mut HashMap<RecordId, BoxStream<'static, Result<ReplayItem>>>,
166        local_storages: &mut LocalStorages,
167        local_storage_opts_map: &mut HashMap<LocalStorageId, TracedNewLocalOptions>,
168        should_shutdown: &mut bool,
169    ) {
170        let Record {
171            storage_type,
172            record_id,
173            operation,
174        } = record;
175
176        match operation {
177            Operation::Get {
178                key,
179                epoch,
180                read_options,
181            } => {
182                let actual = match storage_type {
183                    StorageType::Global => {
184                        // epoch must be Some
185                        let epoch = epoch.unwrap();
186                        replay.get(key, epoch, read_options).await
187                    }
188                    StorageType::Local(_, local_storage_id) => {
189                        let opts = local_storage_opts_map.get(&local_storage_id).unwrap();
190                        assert_eq!(opts.table_id, read_options.table_id);
191                        let s = local_storages.get_mut(&storage_type).unwrap();
192                        s.get(key, read_options).await
193                    }
194                };
195
196                let res = res_rx.recv().await.expect("recv result failed");
197                if let OperationResult::Get(expected) = res {
198                    assert_eq!(TraceResult::from(actual), expected, "get result wrong");
199                } else {
200                    panic!("expect get result, but got {:?}", res);
201                }
202            }
203            Operation::Insert {
204                key,
205                new_val,
206                old_val,
207            } => {
208                let local_storage = local_storages.get_mut(&storage_type).unwrap();
209                let actual = local_storage.insert(key, new_val, old_val);
210
211                let expected = res_rx.recv().await.expect("recv result failed");
212                if let OperationResult::Insert(expected) = expected {
213                    assert_eq!(TraceResult::from(actual), expected, "get result wrong");
214                } else {
215                    panic!("expect insert result, but got {:?}", expected);
216                }
217            }
218            Operation::Delete { key, old_val } => {
219                let local_storage = local_storages.get_mut(&storage_type).unwrap();
220                let actual = local_storage.delete(key, old_val);
221
222                let expected = res_rx.recv().await.expect("recv result failed");
223                if let OperationResult::Delete(expected) = expected {
224                    assert_eq!(TraceResult::from(actual), expected, "get result wrong");
225                } else {
226                    panic!("expect delete result, but got {:?}", expected);
227                }
228            }
229            Operation::Iter {
230                key_range,
231                epoch,
232                read_options,
233            } => {
234                let iter = match storage_type {
235                    StorageType::Global => {
236                        // Global Storage must have a epoch
237                        let epoch = epoch.unwrap();
238                        replay.iter(key_range, epoch, read_options).await
239                    }
240                    StorageType::Local(_, id) => {
241                        let opts = local_storage_opts_map.get(&id).unwrap();
242                        assert_eq!(opts.table_id, read_options.table_id);
243                        let s = local_storages.get_mut(&storage_type).unwrap();
244                        s.iter(key_range, read_options).await
245                    }
246                };
247                let res = res_rx.recv().await.expect("recv result failed");
248                if let OperationResult::Iter(expected) = res {
249                    if expected.is_ok() {
250                        let iter = iter.unwrap().boxed();
251                        let id = record_id;
252                        iters_map.insert(id, iter);
253                    } else {
254                        assert!(iter.is_err());
255                    }
256                } else {
257                    panic!("expect iter result, but got {:?}", res);
258                }
259            }
260            Operation::Sync(sync_table_epochs) => {
261                assert_eq!(storage_type, StorageType::Global);
262                let sync_result = replay.sync(sync_table_epochs).await.unwrap();
263                let res = res_rx.recv().await.expect("recv result failed");
264                if let OperationResult::Sync(expected) = res {
265                    assert_eq!(TraceResult::Ok(sync_result), expected, "sync failed");
266                } else {
267                    panic!("expect sync result, but got {:?}", res);
268                }
269            }
270            Operation::IterNext(id) => {
271                let iter = iters_map.get_mut(&id).expect("iter not in worker");
272                let actual = iter.next().await;
273                let actual = actual.map(|res| res.unwrap());
274                let res = res_rx.recv().await.expect("recv result failed");
275                if let OperationResult::IterNext(expected) = res {
276                    assert_eq!(TraceResult::Ok(actual), expected, "iter_next result wrong");
277                } else {
278                    panic!("expect iter_next result, but got {:?}", res);
279                }
280            }
281            Operation::NewLocalStorage(new_local_opts, id) => {
282                assert_ne!(storage_type, StorageType::Global);
283                local_storage_opts_map.insert(id, new_local_opts.clone());
284                let local_storage = replay.new_local(new_local_opts).await;
285                local_storages.insert(storage_type, local_storage);
286            }
287            Operation::DropLocalStorage => {
288                assert_ne!(storage_type, StorageType::Global);
289                if let StorageType::Local(_, _) = storage_type {
290                    local_storages.remove(&storage_type);
291                }
292                // All local storages have been dropped, we should shutdown this worker
293                // If there are incoming new_local, this ReplayWorker will spawn again
294                if local_storages.is_empty() {
295                    *should_shutdown = true;
296                }
297            }
298            Operation::MetaMessage(resp) => {
299                assert_eq!(storage_type, StorageType::Global);
300                let op = resp.0.operation();
301                if let Some(info) = resp.0.info {
302                    replay
303                        .notify_hummock(info, op, resp.0.version)
304                        .await
305                        .unwrap();
306                }
307            }
308            Operation::LocalStorageInit(options) => {
309                assert_ne!(storage_type, StorageType::Global);
310                let local_storage = local_storages.get_mut(&storage_type).unwrap();
311                local_storage.init(options).await.unwrap();
312            }
313            Operation::TryWaitEpoch(epoch, options) => {
314                assert_eq!(storage_type, StorageType::Global);
315                let res = res_rx.recv().await.expect("recv result failed");
316                if let OperationResult::TryWaitEpoch(expected) = res {
317                    let actual = replay.try_wait_epoch(epoch.into(), options).await;
318                    assert_eq!(TraceResult::from(actual), expected, "try_wait_epoch wrong");
319                } else {
320                    panic!(
321                        "wrong try_wait_epoch result, expect epoch result, but got {:?}",
322                        res
323                    );
324                }
325            }
326            Operation::SealCurrentEpoch { epoch, opts } => {
327                assert_ne!(storage_type, StorageType::Global);
328                let local_storage = local_storages.get_mut(&storage_type).unwrap();
329                local_storage.seal_current_epoch(epoch, opts);
330            }
331            Operation::Flush => {
332                assert_ne!(storage_type, StorageType::Global);
333                let local_storage = local_storages.get_mut(&storage_type).unwrap();
334                let res = res_rx.recv().await.expect("recv result failed");
335                if let OperationResult::Flush(expected) = res {
336                    let actual = local_storage.flush().await;
337                    assert_eq!(TraceResult::from(actual), expected, "flush wrong");
338                } else {
339                    panic!("wrong flush result, expect flush result, but got {:?}", res);
340                }
341            }
342            Operation::TryFlush => {
343                assert_ne!(storage_type, StorageType::Global);
344                let local_storage = local_storages.get_mut(&storage_type).unwrap();
345                let res = res_rx.recv().await.expect("recv result failed");
346                if let OperationResult::TryFlush(_) = res {
347                    let _ = local_storage.try_flush().await;
348                    // todo(wcy-fdu): unify try_flush and flush interface, do not return usize.
349                    // assert_eq!(TraceResult::from(actual), expected, "try flush wrong");
350                } else {
351                    panic!(
352                        "wrong try flush result, expect flush result, but got {:?}",
353                        res
354                    );
355                }
356            }
357
358            Operation::Finish => unreachable!(),
359            Operation::Result(_) => unreachable!(),
360        }
361    }
362}
363
364fn allocate_worker_id(record: &Record) -> WorkerId {
365    match record.storage_type() {
366        StorageType::Local(concurrent_id, id) => WorkerId::Local(*concurrent_id, *id),
367        StorageType::Global => WorkerId::OneShot(record.record_id()),
368    }
369}
370
/// Scheduler-side handle to one spawned replay worker task.
struct WorkerHandler {
    /// Sends replay requests to the worker task.
    req_tx: UnboundedSender<ReplayRequest>,
    /// Sends traced operation results to the worker task.
    res_tx: UnboundedSender<OperationResult>,
    /// Receives the `WorkerResponse` the worker task emits after handling a record.
    record_end_resp_rx: UnboundedReceiver<WorkerResponse>,
    /// Join handle of the spawned worker task.
    join: JoinHandle<()>,
    // Number of replay requests sent per record id that have not been waited for yet.
    // Used for ops like iter, since an iter may have multiple next
    // Example
    // Iter begin, Iter next, Iter next...., Iter finish
    // So replay requests may be stacked
    stacked_replay_reqs: HashMap<u64, u32>,
}
382
383impl WorkerHandler {
384    async fn join(self) {
385        self.join.await.expect("failed to stop worker");
386    }
387
388    fn finish(&self) {
389        self.send_replay_req(None);
390    }
391
392    fn replay(&mut self, req: ReplayRequest) {
393        if let Some(r) = &req {
394            let entry = self.stacked_replay_reqs.entry(r.record_id).or_insert(0);
395            *entry += 1;
396        }
397        self.send_replay_req(req);
398    }
399
400    async fn wait(&mut self, record_id: u64) -> Option<WorkerResponse> {
401        let mut stacked_replay_reqs = *self.stacked_replay_reqs.get(&record_id).unwrap();
402        assert!(
403            stacked_replay_reqs > 0,
404            "replay count should be 0, but found {}",
405            stacked_replay_reqs
406        );
407        let mut resp = None;
408
409        while stacked_replay_reqs > 0 {
410            resp = Some(
411                self.record_end_resp_rx
412                    .recv()
413                    .await
414                    .expect("failed to wait worker resp"),
415            );
416            stacked_replay_reqs -= 1;
417        }
418        // cleaned this record from replay worker
419        self.stacked_replay_reqs.remove(&record_id);
420        // impossible to be None
421        resp
422    }
423
424    fn send_replay_req(&self, req: ReplayRequest) {
425        self.req_tx
426            .send(req)
427            .expect("failed to send replay request");
428    }
429
430    fn send_result(&self, result: OperationResult) {
431        self.res_tx.send(result).expect("failed to send result");
432    }
433}
434
/// Collection of local replay storages owned by one worker, keyed by storage type.
struct LocalStorages {
    /// Boxed local storages, looked up by the `StorageType` they were inserted under.
    storages: HashMap<StorageType, Box<dyn LocalReplay>>,
}
438
impl LocalStorages {
    /// Creates an empty collection.
    fn new() -> Self {
        Self {
            storages: HashMap::new(),
        }
    }

    /// Drops the storage registered under `storage_type`, if there is one.
    fn remove(&mut self, storage_type: &StorageType) {
        self.storages.remove(storage_type);
    }

    /// Returns a mutable reference to the storage registered under `storage_type`, or `None`
    /// when nothing is registered under it.
    fn get_mut(&mut self, storage_type: &StorageType) -> Option<&mut Box<dyn LocalReplay>> {
        self.storages.get_mut(storage_type)
    }

    /// Registers `local_storage` under `storage_type`, replacing any previous entry.
    fn insert(&mut self, storage_type: StorageType, local_storage: Box<dyn LocalReplay>) {
        self.storages.insert(storage_type, local_storage);
    }

    /// Returns `true` when no storage is registered.
    fn is_empty(&self) -> bool {
        self.storages.is_empty()
    }

    /// Number of registered storages; only compiled for tests.
    #[cfg(test)]
    fn len(&self) -> usize {
        self.storages.len()
    }
}
467
#[cfg(test)]
mod tests {

    use std::ops::Bound;

    use bytes::Bytes;
    use mockall::predicate;

    use super::*;
    use crate::replay::{MockGlobalReplayInterface, MockLocalReplayInterface};
    use crate::{MockReplayIterStream, TracedBytes, TracedReadOptions};

    // Drives `ReplayWorker::handle_record` directly, without spawning a worker task: creates two
    // local storages, replays a `Get` on the first and an `Iter` + `IterNext` on the second, and
    // checks the bookkeeping maps after each step.
    #[tokio::test]
    async fn test_handle_record() {
        let mut iters_map = HashMap::new();
        let mut local_storage_opts_map = HashMap::new();
        let mut local_storages = LocalStorages::new();
        let (res_tx, mut res_rx) = unbounded_channel();
        let get_table_id = 12;
        let iter_table_id = 14654;
        let read_options = TracedReadOptions::for_test(get_table_id);
        let iter_read_options = TracedReadOptions::for_test(iter_table_id);
        let op = Operation::get(Bytes::from(vec![123]), Some(123), read_options);

        let new_local_opts = TracedNewLocalOptions::for_test(get_table_id);

        let iter_local_opts = TracedNewLocalOptions::for_test(iter_table_id);
        let mut should_exit = false;
        let get_storage_type = StorageType::Local(0, 0);
        let record = Record::new(get_storage_type, 1, op);
        let mut mock_replay = MockGlobalReplayInterface::new();

        // Local storage mock that answers `get([123])` with `[120]`.
        // NOTE(review): presumably the two `expect_new_local` expectations are consumed in
        // declaration order, so this one serves the first `NewLocalStorage` — confirm against
        // mockall's matching rules.
        mock_replay.expect_new_local().times(1).returning(move |_| {
            let mut mock_local = MockLocalReplayInterface::new();

            mock_local
                .expect_get()
                .with(
                    predicate::eq(TracedBytes::from(vec![123])),
                    predicate::always(),
                )
                .returning(|_, _| Ok(Some(TracedBytes::from(vec![120]))));

            Box::new(mock_local)
        });

        // Local storage mock whose unbounded `iter` yields the single pair `([1], [0])`.
        mock_replay.expect_new_local().times(1).returning(move |_| {
            let mut mock_local = MockLocalReplayInterface::new();

            mock_local
                .expect_iter()
                .with(
                    predicate::eq((Bound::Unbounded, Bound::Unbounded)),
                    predicate::always(),
                )
                .returning(move |_, _| {
                    let iter = MockReplayIterStream::new(vec![(
                        TracedBytes::from(vec![1]),
                        TracedBytes::from(vec![0]),
                    )]);
                    Ok(iter.into_stream().boxed())
                });

            Box::new(mock_local)
        });

        let replay = Arc::new(mock_replay);

        // Step 1: create the local storage used by the `Get`.
        ReplayWorker::handle_record(
            Record::new(
                get_storage_type,
                0,
                Operation::NewLocalStorage(new_local_opts, 0),
            ),
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        // Step 2: queue the traced result first, since `handle_record` receives it while
        // handling the `Get`.
        res_tx
            .send(OperationResult::Get(TraceResult::Ok(Some(
                TracedBytes::from(vec![120]),
            ))))
            .unwrap();

        ReplayWorker::handle_record(
            record,
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        // A `Get` registers no iterator.
        assert_eq!(local_storages.len(), 1);
        assert!(iters_map.is_empty());

        let op = Operation::Iter {
            key_range: (Bound::Unbounded, Bound::Unbounded),
            epoch: Some(45),
            read_options: iter_read_options,
        };

        let iter_storage_type = StorageType::Local(0, 1);

        // Step 3: create the second local storage, used by the `Iter`.
        ReplayWorker::handle_record(
            Record::new(
                iter_storage_type,
                2,
                Operation::NewLocalStorage(iter_local_opts, 1),
            ),
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        // Step 4: replay the `Iter` under record id 1; a successful traced result makes the
        // worker store the iterator under that id.
        let record = Record::new(iter_storage_type, 1, op);
        res_tx
            .send(OperationResult::Iter(TraceResult::Ok(())))
            .unwrap();

        ReplayWorker::handle_record(
            record,
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        assert_eq!(local_storages.len(), 2);
        assert_eq!(iters_map.len(), 1);

        // Step 5: advance the iterator stored under id 1 and compare with the traced item.
        let op = Operation::IterNext(1);
        let record = Record::new(iter_storage_type, 3, op);
        res_tx
            .send(OperationResult::IterNext(TraceResult::Ok(Some((
                TracedBytes::from(vec![1]),
                TracedBytes::from(vec![0]),
            )))))
            .unwrap();

        ReplayWorker::handle_record(
            record,
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        // `IterNext` neither adds nor removes storages or iterators.
        assert_eq!(local_storages.len(), 2);
        assert_eq!(iters_map.len(), 1);
    }

    // End-to-end run through `WorkerScheduler`: schedule a global `Get`, feed its traced result,
    // wait for it to finish, then shut the scheduler down.
    #[tokio::test]
    async fn test_worker_scheduler() {
        // Create a mock GlobalReplay and a ReplayWorkerScheduler that uses the mock GlobalReplay.
        let mut mock_replay = MockGlobalReplayInterface::default();
        let record_id = 29053;
        let key = TracedBytes::from(vec![1]);
        let epoch = 2596;
        let read_options = TracedReadOptions::for_test(1);

        let res_bytes = TracedBytes::from(vec![58, 54, 35]);

        // The mocked `get` returns the same bytes as the traced result sent below.
        mock_replay
            .expect_get()
            .with(
                predicate::eq(key.clone()),
                predicate::eq(epoch),
                predicate::eq(read_options.clone()),
            )
            .returning(move |_, _, _| Ok(Some(TracedBytes::from(vec![58, 54, 35]))));

        let mut scheduler = WorkerScheduler::new(Arc::new(mock_replay));
        // Schedule a record for replay.
        let record = Record::new(
            StorageType::Global,
            record_id,
            Operation::get(key.into(), Some(epoch), read_options),
        );
        scheduler.schedule(record);

        // Same storage type and record id as the scheduled record, so it maps to the same worker.
        let result = Record::new(
            StorageType::Global,
            record_id,
            Operation::Result(OperationResult::Get(TraceResult::Ok(Some(res_bytes)))),
        );

        scheduler.send_result(result);

        // Global records use one-shot workers, which `wait_finish` removes after waiting.
        let fin = Record::new(StorageType::Global, record_id, Operation::Finish);
        scheduler.wait_finish(fin).await;

        scheduler.shutdown().await;
    }
}