risingwave_hummock_trace/replay/
worker.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::sync::Arc;
17
18use futures::StreamExt;
19use futures::stream::BoxStream;
20use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
21use tokio::task::JoinHandle;
22
23use super::{GlobalReplay, LocalReplay, ReplayRequest, WorkerId, WorkerResponse};
24use crate::{
25    LocalStorageId, Operation, OperationResult, Record, RecordId, ReplayItem, Result, StorageType,
26    TraceResult, TracedNewLocalOptions,
27};
28
/// Interface used to drive replay workers: dispatch records, forward traced results, wait for
/// completion and shut everything down.
#[async_trait::async_trait]
pub trait ReplayWorkerScheduler {
    /// Schedules a replaying task for the given record.
    fn schedule(&mut self, record: Record);
    /// Sends the result of an operation to a worker.
    fn send_result(&mut self, record: Record);
    /// Waits until an operation finishes.
    async fn wait_finish(&mut self, record: Record);
    /// Gracefully shuts down all workers; consumes the scheduler.
    async fn shutdown(self);
}
40
41pub(crate) struct WorkerScheduler<G: GlobalReplay> {
42    workers: HashMap<WorkerId, WorkerHandler>,
43    replay: Arc<G>,
44}
45
46impl<G: GlobalReplay> WorkerScheduler<G> {
47    pub(crate) fn new(replay: Arc<G>) -> Self {
48        WorkerScheduler {
49            workers: HashMap::new(),
50            replay,
51        }
52    }
53}
54
55#[async_trait::async_trait]
56impl<G: GlobalReplay + 'static> ReplayWorkerScheduler for WorkerScheduler<G> {
57    fn schedule(&mut self, record: Record) {
58        let worker_id = allocate_worker_id(&record);
59        let handler = self
60            .workers
61            .entry(worker_id)
62            .or_insert_with(|| ReplayWorker::spawn(self.replay.clone()));
63
64        handler.replay(Some(record));
65    }
66
67    fn send_result(&mut self, record: Record) {
68        let worker_id = allocate_worker_id(&record);
69        let Record { operation, .. } = record;
70        // Check if the worker with the given ID exists in the workers map and the record contains a
71        // Result operation.
72        if let (Some(handler), Operation::Result(trace_result)) =
73            (self.workers.get_mut(&worker_id), operation)
74        {
75            // If the worker exists and the record contains a Result operation, send the result to
76            // the worker.
77            handler.send_result(trace_result);
78        }
79    }
80
81    async fn wait_finish(&mut self, record: Record) {
82        let worker_id = allocate_worker_id(&record);
83
84        // Check if the worker with the given ID exists in the workers map.
85        if let Some(handler) = self.workers.get_mut(&worker_id) {
86            // If the worker exists, wait for it to finish executing.
87            let resp = handler.wait(record.record_id).await;
88
89            // If the worker is a one-shot worker or local workers that should be closed, remove it
90            // from the workers map and call its finish method.
91            if matches!(worker_id, WorkerId::OneShot(_))
92                || matches!(resp, Some(WorkerResponse::Shutdown))
93            {
94                let handler = self.workers.remove(&worker_id).unwrap();
95                handler.finish();
96            }
97        }
98    }
99
100    async fn shutdown(self) {
101        // Iterate over the workers map, calling the finish and join methods on each worker.
102        for (_, handler) in self.workers {
103            handler.finish();
104            handler.join().await;
105        }
106    }
107}
108
/// Field-less type that groups the worker task logic as associated functions.
struct ReplayWorker {}
110
111impl ReplayWorker {
112    fn spawn(replay: Arc<impl GlobalReplay + 'static>) -> WorkerHandler {
113        let (req_tx, req_rx) = unbounded_channel();
114        let (resp_tx, resp_rx) = unbounded_channel();
115        let (res_tx, res_rx) = unbounded_channel();
116
117        let join = tokio::spawn(Self::run(req_rx, res_rx, resp_tx, replay));
118        WorkerHandler {
119            req_tx,
120            res_tx,
121            record_end_resp_rx: resp_rx,
122            join,
123            stacked_replay_reqs: HashMap::new(),
124        }
125    }
126
127    async fn run(
128        mut req_rx: UnboundedReceiver<ReplayRequest>,
129        mut res_rx: UnboundedReceiver<OperationResult>,
130        resp_tx: UnboundedSender<WorkerResponse>,
131        replay: Arc<impl GlobalReplay>,
132    ) {
133        let mut iters_map: HashMap<RecordId, BoxStream<'static, Result<ReplayItem>>> =
134            HashMap::new();
135        let mut local_storages = LocalStorages::new();
136        let mut should_shutdown = false;
137        let mut local_storage_opts_map = HashMap::new();
138
139        while let Some(Some(record)) = req_rx.recv().await {
140            Self::handle_record(
141                record.clone(),
142                &replay,
143                &mut res_rx,
144                &mut iters_map,
145                &mut local_storages,
146                &mut local_storage_opts_map,
147                &mut should_shutdown,
148            )
149            .await;
150
151            let message = if should_shutdown {
152                WorkerResponse::Shutdown
153            } else {
154                WorkerResponse::Continue
155            };
156
157            resp_tx.send(message).expect("Failed to send message");
158        }
159    }
160
161    async fn handle_record(
162        record: Record,
163        replay: &Arc<impl GlobalReplay>,
164        res_rx: &mut UnboundedReceiver<OperationResult>,
165        iters_map: &mut HashMap<RecordId, BoxStream<'static, Result<ReplayItem>>>,
166        local_storages: &mut LocalStorages,
167        local_storage_opts_map: &mut HashMap<LocalStorageId, TracedNewLocalOptions>,
168        should_shutdown: &mut bool,
169    ) {
170        let Record {
171            storage_type,
172            record_id,
173            operation,
174        } = record;
175
176        match operation {
177            Operation::Get {
178                key,
179                epoch,
180                read_options,
181            } => {
182                let actual = match storage_type {
183                    StorageType::Global => {
184                        // epoch must be Some
185                        let epoch = epoch.unwrap();
186                        replay.get(key, epoch, read_options).await
187                    }
188                    StorageType::Local(_, local_storage_id) => {
189                        let opts = local_storage_opts_map.get(&local_storage_id).unwrap();
190                        assert_eq!(opts.table_id, read_options.table_id);
191                        let s = local_storages.get_mut(&storage_type).unwrap();
192                        s.get(key, read_options).await
193                    }
194                };
195
196                let res = res_rx.recv().await.expect("recv result failed");
197                if let OperationResult::Get(expected) = res {
198                    assert_eq!(TraceResult::from(actual), expected, "get result wrong");
199                } else {
200                    panic!("expect get result, but got {:?}", res);
201                }
202            }
203            Operation::Insert {
204                key,
205                new_val,
206                old_val,
207            } => {
208                let local_storage = local_storages.get_mut(&storage_type).unwrap();
209                let actual = local_storage.insert(key, new_val, old_val);
210
211                let expected = res_rx.recv().await.expect("recv result failed");
212                if let OperationResult::Insert(expected) = expected {
213                    assert_eq!(TraceResult::from(actual), expected, "get result wrong");
214                } else {
215                    panic!("expect insert result, but got {:?}", expected);
216                }
217            }
218            Operation::Delete { key, old_val } => {
219                let local_storage = local_storages.get_mut(&storage_type).unwrap();
220                let actual = local_storage.delete(key, old_val);
221
222                let expected = res_rx.recv().await.expect("recv result failed");
223                if let OperationResult::Delete(expected) = expected {
224                    assert_eq!(TraceResult::from(actual), expected, "get result wrong");
225                } else {
226                    panic!("expect delete result, but got {:?}", expected);
227                }
228            }
229            Operation::Iter {
230                key_range,
231                epoch,
232                read_options,
233            } => {
234                let iter = match storage_type {
235                    StorageType::Global => {
236                        // Global Storage must have a epoch
237                        let epoch = epoch.unwrap();
238                        replay.iter(key_range, epoch, read_options).await
239                    }
240                    StorageType::Local(_, id) => {
241                        let opts = local_storage_opts_map.get(&id).unwrap();
242                        assert_eq!(opts.table_id, read_options.table_id);
243                        let s = local_storages.get_mut(&storage_type).unwrap();
244                        s.iter(key_range, read_options).await
245                    }
246                };
247                let res = res_rx.recv().await.expect("recv result failed");
248                if let OperationResult::Iter(expected) = res {
249                    if expected.is_ok() {
250                        let iter = iter.unwrap().boxed();
251                        let id = record_id;
252                        iters_map.insert(id, iter);
253                    } else {
254                        assert!(iter.is_err());
255                    }
256                } else {
257                    panic!("expect iter result, but got {:?}", res);
258                }
259            }
260            Operation::Sync(sync_table_epochs) => {
261                assert_eq!(storage_type, StorageType::Global);
262                let sync_result = replay.sync(sync_table_epochs).await.unwrap();
263                let res = res_rx.recv().await.expect("recv result failed");
264                if let OperationResult::Sync(expected) = res {
265                    assert_eq!(TraceResult::Ok(sync_result), expected, "sync failed");
266                } else {
267                    panic!("expect sync result, but got {:?}", res);
268                }
269            }
270            Operation::IterNext(id) => {
271                let iter = iters_map.get_mut(&id).expect("iter not in worker");
272                let actual = iter.next().await;
273                let actual = actual.map(|res| res.unwrap());
274                let res = res_rx.recv().await.expect("recv result failed");
275                if let OperationResult::IterNext(expected) = res {
276                    assert_eq!(TraceResult::Ok(actual), expected, "iter_next result wrong");
277                } else {
278                    panic!("expect iter_next result, but got {:?}", res);
279                }
280            }
281            Operation::NewLocalStorage(new_local_opts, id) => {
282                assert_ne!(storage_type, StorageType::Global);
283                local_storage_opts_map.insert(id, new_local_opts.clone());
284                let local_storage = replay.new_local(new_local_opts).await;
285                local_storages.insert(storage_type, local_storage);
286            }
287            Operation::DropLocalStorage => {
288                assert_ne!(storage_type, StorageType::Global);
289                if let StorageType::Local(_, _) = storage_type {
290                    local_storages.remove(&storage_type);
291                }
292                // All local storages have been dropped, we should shutdown this worker
293                // If there are incoming new_local, this ReplayWorker will spawn again
294                if local_storages.is_empty() {
295                    *should_shutdown = true;
296                }
297            }
298            Operation::MetaMessage(resp) => {
299                assert_eq!(storage_type, StorageType::Global);
300                let op = resp.0.operation();
301                if let Some(info) = resp.0.info {
302                    replay
303                        .notify_hummock(info, op, resp.0.version)
304                        .await
305                        .unwrap();
306                }
307            }
308            Operation::LocalStorageInit(options) => {
309                assert_ne!(storage_type, StorageType::Global);
310                let local_storage = local_storages.get_mut(&storage_type).unwrap();
311                local_storage.init(options).await.unwrap();
312            }
313            Operation::TryWaitEpoch(epoch, options) => {
314                assert_eq!(storage_type, StorageType::Global);
315                let res = res_rx.recv().await.expect("recv result failed");
316                if let OperationResult::TryWaitEpoch(expected) = res {
317                    let actual = replay.try_wait_epoch(epoch.into(), options).await;
318                    assert_eq!(TraceResult::from(actual), expected, "try_wait_epoch wrong");
319                } else {
320                    panic!(
321                        "wrong try_wait_epoch result, expect epoch result, but got {:?}",
322                        res
323                    );
324                }
325            }
326            Operation::SealCurrentEpoch { epoch, opts } => {
327                assert_ne!(storage_type, StorageType::Global);
328                let local_storage = local_storages.get_mut(&storage_type).unwrap();
329                local_storage.seal_current_epoch(epoch, opts);
330            }
331            Operation::LocalStorageEpoch => {
332                assert_ne!(storage_type, StorageType::Global);
333                let local_storage = local_storages.get_mut(&storage_type).unwrap();
334                let res = res_rx.recv().await.expect("recv result failed");
335                if let OperationResult::LocalStorageEpoch(expected) = res {
336                    let actual = local_storage.epoch();
337                    assert_eq!(TraceResult::Ok(actual), expected, "epoch wrong");
338                } else {
339                    panic!(
340                        "wrong local storage epoch result, expect epoch result, but got {:?}",
341                        res
342                    );
343                }
344            }
345            Operation::LocalStorageIsDirty => {
346                assert_ne!(storage_type, StorageType::Global);
347                let local_storage = local_storages.get_mut(&storage_type).unwrap();
348                let res = res_rx.recv().await.expect("recv result failed");
349                if let OperationResult::LocalStorageIsDirty(expected) = res {
350                    let actual = local_storage.is_dirty();
351                    assert_eq!(
352                        TraceResult::Ok(actual),
353                        expected,
354                        "is_dirty wrong, epoch: {}",
355                        local_storage.epoch()
356                    );
357                } else {
358                    panic!(
359                        "wrong local storage is_dirty result, expect is_dirty result, but got {:?}",
360                        res
361                    );
362                }
363            }
364            Operation::Flush => {
365                assert_ne!(storage_type, StorageType::Global);
366                let local_storage = local_storages.get_mut(&storage_type).unwrap();
367                let res = res_rx.recv().await.expect("recv result failed");
368                if let OperationResult::Flush(expected) = res {
369                    let actual = local_storage.flush().await;
370                    assert_eq!(TraceResult::from(actual), expected, "flush wrong");
371                } else {
372                    panic!("wrong flush result, expect flush result, but got {:?}", res);
373                }
374            }
375            Operation::TryFlush => {
376                assert_ne!(storage_type, StorageType::Global);
377                let local_storage = local_storages.get_mut(&storage_type).unwrap();
378                let res = res_rx.recv().await.expect("recv result failed");
379                if let OperationResult::TryFlush(_) = res {
380                    let _ = local_storage.try_flush().await;
381                    // todo(wcy-fdu): unify try_flush and flush interface, do not return usize.
382                    // assert_eq!(TraceResult::from(actual), expected, "try flush wrong");
383                } else {
384                    panic!(
385                        "wrong try flush result, expect flush result, but got {:?}",
386                        res
387                    );
388                }
389            }
390
391            Operation::Finish => unreachable!(),
392            Operation::Result(_) => unreachable!(),
393        }
394    }
395}
396
397fn allocate_worker_id(record: &Record) -> WorkerId {
398    match record.storage_type() {
399        StorageType::Local(concurrent_id, id) => WorkerId::Local(*concurrent_id, *id),
400        StorageType::Global => WorkerId::OneShot(record.record_id()),
401    }
402}
403
404struct WorkerHandler {
405    req_tx: UnboundedSender<ReplayRequest>,
406    res_tx: UnboundedSender<OperationResult>,
407    record_end_resp_rx: UnboundedReceiver<WorkerResponse>,
408    join: JoinHandle<()>,
409    // Used for ops like iter, since a iter may have multiple next
410    // Example
411    // Iter begin, Iter next, Iter next...., Iter finish
412    // So replay requests may be stacked
413    stacked_replay_reqs: HashMap<u64, u32>,
414}
415
416impl WorkerHandler {
417    async fn join(self) {
418        self.join.await.expect("failed to stop worker");
419    }
420
421    fn finish(&self) {
422        self.send_replay_req(None);
423    }
424
425    fn replay(&mut self, req: ReplayRequest) {
426        if let Some(r) = &req {
427            let entry = self.stacked_replay_reqs.entry(r.record_id).or_insert(0);
428            *entry += 1;
429        }
430        self.send_replay_req(req);
431    }
432
433    async fn wait(&mut self, record_id: u64) -> Option<WorkerResponse> {
434        let mut stacked_replay_reqs = *self.stacked_replay_reqs.get(&record_id).unwrap();
435        assert!(
436            stacked_replay_reqs > 0,
437            "replay count should be 0, but found {}",
438            stacked_replay_reqs
439        );
440        let mut resp = None;
441
442        while stacked_replay_reqs > 0 {
443            resp = Some(
444                self.record_end_resp_rx
445                    .recv()
446                    .await
447                    .expect("failed to wait worker resp"),
448            );
449            stacked_replay_reqs -= 1;
450        }
451        // cleaned this record from replay worker
452        self.stacked_replay_reqs.remove(&record_id);
453        // impossible to be None
454        resp
455    }
456
457    fn send_replay_req(&self, req: ReplayRequest) {
458        self.req_tx
459            .send(req)
460            .expect("failed to send replay request");
461    }
462
463    fn send_result(&self, result: OperationResult) {
464        self.res_tx.send(result).expect("failed to send result");
465    }
466}
467
468struct LocalStorages {
469    storages: HashMap<StorageType, Box<dyn LocalReplay>>,
470}
471
472impl LocalStorages {
473    fn new() -> Self {
474        Self {
475            storages: HashMap::new(),
476        }
477    }
478
479    fn remove(&mut self, storage_type: &StorageType) {
480        self.storages.remove(storage_type);
481    }
482
483    fn get_mut(&mut self, storage_type: &StorageType) -> Option<&mut Box<dyn LocalReplay>> {
484        self.storages.get_mut(storage_type)
485    }
486
487    fn insert(&mut self, storage_type: StorageType, local_storage: Box<dyn LocalReplay>) {
488        self.storages.insert(storage_type, local_storage);
489    }
490
491    fn is_empty(&self) -> bool {
492        self.storages.is_empty()
493    }
494
495    #[cfg(test)]
496    fn len(&self) -> usize {
497        self.storages.len()
498    }
499}
500
501#[cfg(test)]
502mod tests {
503
504    use std::ops::Bound;
505
506    use bytes::Bytes;
507    use mockall::predicate;
508
509    use super::*;
510    use crate::replay::{MockGlobalReplayInterface, MockLocalReplayInterface};
511    use crate::{MockReplayIterStream, TracedBytes, TracedReadOptions};
512
    /// Drives `ReplayWorker::handle_record` directly (no worker task) through this sequence:
    /// create a local storage, `Get` on it, create a second local storage, `Iter` on it, then
    /// one `IterNext`. The expected result of each checked operation is queued on `res_tx`
    /// before the record is handled.
    #[tokio::test]
    async fn test_handle_record() {
        // Worker state that `handle_record` normally receives from `ReplayWorker::run`.
        let mut iters_map = HashMap::new();
        let mut local_storage_opts_map = HashMap::new();
        let mut local_storages = LocalStorages::new();
        let (res_tx, mut res_rx) = unbounded_channel();
        let get_table_id = 12;
        let iter_table_id = 14654;
        let read_options = TracedReadOptions::for_test(get_table_id);
        let iter_read_options = TracedReadOptions::for_test(iter_table_id);
        let op = Operation::get(Bytes::from(vec![123]), Some(123), read_options);

        let new_local_opts = TracedNewLocalOptions::for_test(get_table_id);

        let iter_local_opts = TracedNewLocalOptions::for_test(iter_table_id);
        let mut should_exit = false;
        let get_storage_type = StorageType::Local(0, 0);
        let record = Record::new(get_storage_type, 1, op);
        let mut mock_replay = MockGlobalReplayInterface::new();

        // First `new_local` expectation (one call): a local storage that answers `get`.
        // NOTE(review): the test relies on the two expectations being matched in declaration
        // order — confirm against mockall's matching rules.
        mock_replay.expect_new_local().times(1).returning(move |_| {
            let mut mock_local = MockLocalReplayInterface::new();

            mock_local
                .expect_get()
                .with(
                    predicate::eq(TracedBytes::from(vec![123])),
                    predicate::always(),
                )
                .returning(|_, _| Ok(Some(TracedBytes::from(vec![120]))));

            Box::new(mock_local)
        });

        // Second `new_local` expectation (one call): a local storage that answers `iter` with a
        // single-item stream.
        mock_replay.expect_new_local().times(1).returning(move |_| {
            let mut mock_local = MockLocalReplayInterface::new();

            mock_local
                .expect_iter()
                .with(
                    predicate::eq((Bound::Unbounded, Bound::Unbounded)),
                    predicate::always(),
                )
                .returning(move |_, _| {
                    let iter = MockReplayIterStream::new(vec![(
                        TracedBytes::from(vec![1]),
                        TracedBytes::from(vec![0]),
                    )]);
                    Ok(iter.into_stream().boxed())
                });

            Box::new(mock_local)
        });

        let replay = Arc::new(mock_replay);

        // Step 1: register the local storage used by the `Get` record.
        ReplayWorker::handle_record(
            Record::new(
                get_storage_type,
                0,
                Operation::NewLocalStorage(new_local_opts, 0),
            ),
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        // Step 2: queue the expected `Get` result, then replay the `Get`.
        res_tx
            .send(OperationResult::Get(TraceResult::Ok(Some(
                TracedBytes::from(vec![120]),
            ))))
            .unwrap();

        ReplayWorker::handle_record(
            record,
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        assert_eq!(local_storages.len(), 1);
        assert!(iters_map.is_empty());

        let op = Operation::Iter {
            key_range: (Bound::Unbounded, Bound::Unbounded),
            epoch: Some(45),
            read_options: iter_read_options,
        };

        let iter_storage_type = StorageType::Local(0, 1);

        // Step 3: register the second local storage, used by the iterator records.
        ReplayWorker::handle_record(
            Record::new(
                iter_storage_type,
                2,
                Operation::NewLocalStorage(iter_local_opts, 1),
            ),
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        // Step 4: replay the `Iter`; its record id (1) is the key under which the stream is
        // stored in `iters_map`.
        let record = Record::new(iter_storage_type, 1, op);
        res_tx
            .send(OperationResult::Iter(TraceResult::Ok(())))
            .unwrap();

        ReplayWorker::handle_record(
            record,
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        assert_eq!(local_storages.len(), 2);
        assert_eq!(iters_map.len(), 1);

        // Step 5: advance the iterator stored under id 1 and expect its single item.
        let op = Operation::IterNext(1);
        let record = Record::new(iter_storage_type, 3, op);
        res_tx
            .send(OperationResult::IterNext(TraceResult::Ok(Some((
                TracedBytes::from(vec![1]),
                TracedBytes::from(vec![0]),
            )))))
            .unwrap();

        ReplayWorker::handle_record(
            record,
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        // The iterator stays registered after being advanced.
        assert_eq!(local_storages.len(), 2);
        assert_eq!(iters_map.len(), 1);
    }
669
670    #[tokio::test]
671    async fn test_worker_scheduler() {
672        // Create a mock GlobalReplay and a ReplayWorkerScheduler that uses the mock GlobalReplay.
673        let mut mock_replay = MockGlobalReplayInterface::default();
674        let record_id = 29053;
675        let key = TracedBytes::from(vec![1]);
676        let epoch = 2596;
677        let read_options = TracedReadOptions::for_test(1);
678
679        let res_bytes = TracedBytes::from(vec![58, 54, 35]);
680
681        mock_replay
682            .expect_get()
683            .with(
684                predicate::eq(key.clone()),
685                predicate::eq(epoch),
686                predicate::eq(read_options.clone()),
687            )
688            .returning(move |_, _, _| Ok(Some(TracedBytes::from(vec![58, 54, 35]))));
689
690        let mut scheduler = WorkerScheduler::new(Arc::new(mock_replay));
691        // Schedule a record for replay.
692        let record = Record::new(
693            StorageType::Global,
694            record_id,
695            Operation::get(key.into(), Some(epoch), read_options),
696        );
697        scheduler.schedule(record);
698
699        let result = Record::new(
700            StorageType::Global,
701            record_id,
702            Operation::Result(OperationResult::Get(TraceResult::Ok(Some(res_bytes)))),
703        );
704
705        scheduler.send_result(result);
706
707        let fin = Record::new(StorageType::Global, record_id, Operation::Finish);
708        scheduler.wait_finish(fin).await;
709
710        scheduler.shutdown().await;
711    }
712}