1use std::collections::HashMap;
16use std::sync::Arc;
17
18use futures::StreamExt;
19use futures::stream::BoxStream;
20use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
21use tokio::task::JoinHandle;
22
23use super::{GlobalReplay, LocalReplay, ReplayRequest, WorkerId, WorkerResponse};
24use crate::{
25 LocalStorageId, Operation, OperationResult, Record, RecordId, ReplayItem, Result, StorageType,
26 TraceResult, TracedNewLocalOptions,
27};
28
29#[async_trait::async_trait]
30pub trait ReplayWorkerScheduler {
31 fn schedule(&mut self, record: Record);
33 fn send_result(&mut self, record: Record);
35 async fn wait_finish(&mut self, record: Record);
37 async fn shutdown(self);
39}
40
41pub(crate) struct WorkerScheduler<G: GlobalReplay> {
42 workers: HashMap<WorkerId, WorkerHandler>,
43 replay: Arc<G>,
44}
45
46impl<G: GlobalReplay> WorkerScheduler<G> {
47 pub(crate) fn new(replay: Arc<G>) -> Self {
48 WorkerScheduler {
49 workers: HashMap::new(),
50 replay,
51 }
52 }
53}
54
55#[async_trait::async_trait]
56impl<G: GlobalReplay + 'static> ReplayWorkerScheduler for WorkerScheduler<G> {
57 fn schedule(&mut self, record: Record) {
58 let worker_id = allocate_worker_id(&record);
59 let handler = self
60 .workers
61 .entry(worker_id)
62 .or_insert_with(|| ReplayWorker::spawn(self.replay.clone()));
63
64 handler.replay(Some(record));
65 }
66
67 fn send_result(&mut self, record: Record) {
68 let worker_id = allocate_worker_id(&record);
69 let Record { operation, .. } = record;
70 if let (Some(handler), Operation::Result(trace_result)) =
73 (self.workers.get_mut(&worker_id), operation)
74 {
75 handler.send_result(trace_result);
78 }
79 }
80
81 async fn wait_finish(&mut self, record: Record) {
82 let worker_id = allocate_worker_id(&record);
83
84 if let Some(handler) = self.workers.get_mut(&worker_id) {
86 let resp = handler.wait(record.record_id).await;
88
89 if matches!(worker_id, WorkerId::OneShot(_))
92 || matches!(resp, Some(WorkerResponse::Shutdown))
93 {
94 let handler = self.workers.remove(&worker_id).unwrap();
95 handler.finish();
96 }
97 }
98 }
99
100 async fn shutdown(self) {
101 for (_, handler) in self.workers {
103 handler.finish();
104 handler.join().await;
105 }
106 }
107}
108
/// Stateless namespace for the replay worker task (`spawn`, `run`,
/// `handle_record`); all per-worker state lives inside the spawned task.
struct ReplayWorker;
110
111impl ReplayWorker {
112 fn spawn(replay: Arc<impl GlobalReplay + 'static>) -> WorkerHandler {
113 let (req_tx, req_rx) = unbounded_channel();
114 let (resp_tx, resp_rx) = unbounded_channel();
115 let (res_tx, res_rx) = unbounded_channel();
116
117 let join = tokio::spawn(Self::run(req_rx, res_rx, resp_tx, replay));
118 WorkerHandler {
119 req_tx,
120 res_tx,
121 record_end_resp_rx: resp_rx,
122 join,
123 stacked_replay_reqs: HashMap::new(),
124 }
125 }
126
127 async fn run(
128 mut req_rx: UnboundedReceiver<ReplayRequest>,
129 mut res_rx: UnboundedReceiver<OperationResult>,
130 resp_tx: UnboundedSender<WorkerResponse>,
131 replay: Arc<impl GlobalReplay>,
132 ) {
133 let mut iters_map: HashMap<RecordId, BoxStream<'static, Result<ReplayItem>>> =
134 HashMap::new();
135 let mut local_storages = LocalStorages::new();
136 let mut should_shutdown = false;
137 let mut local_storage_opts_map = HashMap::new();
138
139 while let Some(Some(record)) = req_rx.recv().await {
140 Self::handle_record(
141 record.clone(),
142 &replay,
143 &mut res_rx,
144 &mut iters_map,
145 &mut local_storages,
146 &mut local_storage_opts_map,
147 &mut should_shutdown,
148 )
149 .await;
150
151 let message = if should_shutdown {
152 WorkerResponse::Shutdown
153 } else {
154 WorkerResponse::Continue
155 };
156
157 resp_tx.send(message).expect("Failed to send message");
158 }
159 }
160
161 async fn handle_record(
162 record: Record,
163 replay: &Arc<impl GlobalReplay>,
164 res_rx: &mut UnboundedReceiver<OperationResult>,
165 iters_map: &mut HashMap<RecordId, BoxStream<'static, Result<ReplayItem>>>,
166 local_storages: &mut LocalStorages,
167 local_storage_opts_map: &mut HashMap<LocalStorageId, TracedNewLocalOptions>,
168 should_shutdown: &mut bool,
169 ) {
170 let Record {
171 storage_type,
172 record_id,
173 operation,
174 } = record;
175
176 match operation {
177 Operation::Get {
178 key,
179 epoch,
180 read_options,
181 } => {
182 let actual = match storage_type {
183 StorageType::Global => {
184 let epoch = epoch.unwrap();
186 replay.get(key, epoch, read_options).await
187 }
188 StorageType::Local(_, local_storage_id) => {
189 let opts = local_storage_opts_map.get(&local_storage_id).unwrap();
190 assert_eq!(opts.table_id, read_options.table_id);
191 let s = local_storages.get_mut(&storage_type).unwrap();
192 s.get(key, read_options).await
193 }
194 };
195
196 let res = res_rx.recv().await.expect("recv result failed");
197 if let OperationResult::Get(expected) = res {
198 assert_eq!(TraceResult::from(actual), expected, "get result wrong");
199 } else {
200 panic!("expect get result, but got {:?}", res);
201 }
202 }
203 Operation::Insert {
204 key,
205 new_val,
206 old_val,
207 } => {
208 let local_storage = local_storages.get_mut(&storage_type).unwrap();
209 let actual = local_storage.insert(key, new_val, old_val);
210
211 let expected = res_rx.recv().await.expect("recv result failed");
212 if let OperationResult::Insert(expected) = expected {
213 assert_eq!(TraceResult::from(actual), expected, "get result wrong");
214 } else {
215 panic!("expect insert result, but got {:?}", expected);
216 }
217 }
218 Operation::Delete { key, old_val } => {
219 let local_storage = local_storages.get_mut(&storage_type).unwrap();
220 let actual = local_storage.delete(key, old_val);
221
222 let expected = res_rx.recv().await.expect("recv result failed");
223 if let OperationResult::Delete(expected) = expected {
224 assert_eq!(TraceResult::from(actual), expected, "get result wrong");
225 } else {
226 panic!("expect delete result, but got {:?}", expected);
227 }
228 }
229 Operation::Iter {
230 key_range,
231 epoch,
232 read_options,
233 } => {
234 let iter = match storage_type {
235 StorageType::Global => {
236 let epoch = epoch.unwrap();
238 replay.iter(key_range, epoch, read_options).await
239 }
240 StorageType::Local(_, id) => {
241 let opts = local_storage_opts_map.get(&id).unwrap();
242 assert_eq!(opts.table_id, read_options.table_id);
243 let s = local_storages.get_mut(&storage_type).unwrap();
244 s.iter(key_range, read_options).await
245 }
246 };
247 let res = res_rx.recv().await.expect("recv result failed");
248 if let OperationResult::Iter(expected) = res {
249 if expected.is_ok() {
250 let iter = iter.unwrap().boxed();
251 let id = record_id;
252 iters_map.insert(id, iter);
253 } else {
254 assert!(iter.is_err());
255 }
256 } else {
257 panic!("expect iter result, but got {:?}", res);
258 }
259 }
260 Operation::Sync(sync_table_epochs) => {
261 assert_eq!(storage_type, StorageType::Global);
262 let sync_result = replay.sync(sync_table_epochs).await.unwrap();
263 let res = res_rx.recv().await.expect("recv result failed");
264 if let OperationResult::Sync(expected) = res {
265 assert_eq!(TraceResult::Ok(sync_result), expected, "sync failed");
266 } else {
267 panic!("expect sync result, but got {:?}", res);
268 }
269 }
270 Operation::IterNext(id) => {
271 let iter = iters_map.get_mut(&id).expect("iter not in worker");
272 let actual = iter.next().await;
273 let actual = actual.map(|res| res.unwrap());
274 let res = res_rx.recv().await.expect("recv result failed");
275 if let OperationResult::IterNext(expected) = res {
276 assert_eq!(TraceResult::Ok(actual), expected, "iter_next result wrong");
277 } else {
278 panic!("expect iter_next result, but got {:?}", res);
279 }
280 }
281 Operation::NewLocalStorage(new_local_opts, id) => {
282 assert_ne!(storage_type, StorageType::Global);
283 local_storage_opts_map.insert(id, new_local_opts.clone());
284 let local_storage = replay.new_local(new_local_opts).await;
285 local_storages.insert(storage_type, local_storage);
286 }
287 Operation::DropLocalStorage => {
288 assert_ne!(storage_type, StorageType::Global);
289 if let StorageType::Local(_, _) = storage_type {
290 local_storages.remove(&storage_type);
291 }
292 if local_storages.is_empty() {
295 *should_shutdown = true;
296 }
297 }
298 Operation::MetaMessage(resp) => {
299 assert_eq!(storage_type, StorageType::Global);
300 let op = resp.0.operation();
301 if let Some(info) = resp.0.info {
302 replay
303 .notify_hummock(info, op, resp.0.version)
304 .await
305 .unwrap();
306 }
307 }
308 Operation::LocalStorageInit(options) => {
309 assert_ne!(storage_type, StorageType::Global);
310 let local_storage = local_storages.get_mut(&storage_type).unwrap();
311 local_storage.init(options).await.unwrap();
312 }
313 Operation::TryWaitEpoch(epoch, options) => {
314 assert_eq!(storage_type, StorageType::Global);
315 let res = res_rx.recv().await.expect("recv result failed");
316 if let OperationResult::TryWaitEpoch(expected) = res {
317 let actual = replay.try_wait_epoch(epoch.into(), options).await;
318 assert_eq!(TraceResult::from(actual), expected, "try_wait_epoch wrong");
319 } else {
320 panic!(
321 "wrong try_wait_epoch result, expect epoch result, but got {:?}",
322 res
323 );
324 }
325 }
326 Operation::SealCurrentEpoch { epoch, opts } => {
327 assert_ne!(storage_type, StorageType::Global);
328 let local_storage = local_storages.get_mut(&storage_type).unwrap();
329 local_storage.seal_current_epoch(epoch, opts);
330 }
331 Operation::LocalStorageEpoch => {
332 assert_ne!(storage_type, StorageType::Global);
333 let local_storage = local_storages.get_mut(&storage_type).unwrap();
334 let res = res_rx.recv().await.expect("recv result failed");
335 if let OperationResult::LocalStorageEpoch(expected) = res {
336 let actual = local_storage.epoch();
337 assert_eq!(TraceResult::Ok(actual), expected, "epoch wrong");
338 } else {
339 panic!(
340 "wrong local storage epoch result, expect epoch result, but got {:?}",
341 res
342 );
343 }
344 }
345 Operation::LocalStorageIsDirty => {
346 assert_ne!(storage_type, StorageType::Global);
347 let local_storage = local_storages.get_mut(&storage_type).unwrap();
348 let res = res_rx.recv().await.expect("recv result failed");
349 if let OperationResult::LocalStorageIsDirty(expected) = res {
350 let actual = local_storage.is_dirty();
351 assert_eq!(
352 TraceResult::Ok(actual),
353 expected,
354 "is_dirty wrong, epoch: {}",
355 local_storage.epoch()
356 );
357 } else {
358 panic!(
359 "wrong local storage is_dirty result, expect is_dirty result, but got {:?}",
360 res
361 );
362 }
363 }
364 Operation::Flush => {
365 assert_ne!(storage_type, StorageType::Global);
366 let local_storage = local_storages.get_mut(&storage_type).unwrap();
367 let res = res_rx.recv().await.expect("recv result failed");
368 if let OperationResult::Flush(expected) = res {
369 let actual = local_storage.flush().await;
370 assert_eq!(TraceResult::from(actual), expected, "flush wrong");
371 } else {
372 panic!("wrong flush result, expect flush result, but got {:?}", res);
373 }
374 }
375 Operation::TryFlush => {
376 assert_ne!(storage_type, StorageType::Global);
377 let local_storage = local_storages.get_mut(&storage_type).unwrap();
378 let res = res_rx.recv().await.expect("recv result failed");
379 if let OperationResult::TryFlush(_) = res {
380 let _ = local_storage.try_flush().await;
381 } else {
384 panic!(
385 "wrong try flush result, expect flush result, but got {:?}",
386 res
387 );
388 }
389 }
390
391 Operation::Finish => unreachable!(),
392 Operation::Result(_) => unreachable!(),
393 }
394 }
395}
396
397fn allocate_worker_id(record: &Record) -> WorkerId {
398 match record.storage_type() {
399 StorageType::Local(concurrent_id, id) => WorkerId::Local(*concurrent_id, *id),
400 StorageType::Global => WorkerId::OneShot(record.record_id()),
401 }
402}
403
404struct WorkerHandler {
405 req_tx: UnboundedSender<ReplayRequest>,
406 res_tx: UnboundedSender<OperationResult>,
407 record_end_resp_rx: UnboundedReceiver<WorkerResponse>,
408 join: JoinHandle<()>,
409 stacked_replay_reqs: HashMap<u64, u32>,
414}
415
416impl WorkerHandler {
417 async fn join(self) {
418 self.join.await.expect("failed to stop worker");
419 }
420
421 fn finish(&self) {
422 self.send_replay_req(None);
423 }
424
425 fn replay(&mut self, req: ReplayRequest) {
426 if let Some(r) = &req {
427 let entry = self.stacked_replay_reqs.entry(r.record_id).or_insert(0);
428 *entry += 1;
429 }
430 self.send_replay_req(req);
431 }
432
433 async fn wait(&mut self, record_id: u64) -> Option<WorkerResponse> {
434 let mut stacked_replay_reqs = *self.stacked_replay_reqs.get(&record_id).unwrap();
435 assert!(
436 stacked_replay_reqs > 0,
437 "replay count should be 0, but found {}",
438 stacked_replay_reqs
439 );
440 let mut resp = None;
441
442 while stacked_replay_reqs > 0 {
443 resp = Some(
444 self.record_end_resp_rx
445 .recv()
446 .await
447 .expect("failed to wait worker resp"),
448 );
449 stacked_replay_reqs -= 1;
450 }
451 self.stacked_replay_reqs.remove(&record_id);
453 resp
455 }
456
457 fn send_replay_req(&self, req: ReplayRequest) {
458 self.req_tx
459 .send(req)
460 .expect("failed to send replay request");
461 }
462
463 fn send_result(&self, result: OperationResult) {
464 self.res_tx.send(result).expect("failed to send result");
465 }
466}
467
468struct LocalStorages {
469 storages: HashMap<StorageType, Box<dyn LocalReplay>>,
470}
471
472impl LocalStorages {
473 fn new() -> Self {
474 Self {
475 storages: HashMap::new(),
476 }
477 }
478
479 fn remove(&mut self, storage_type: &StorageType) {
480 self.storages.remove(storage_type);
481 }
482
483 fn get_mut(&mut self, storage_type: &StorageType) -> Option<&mut Box<dyn LocalReplay>> {
484 self.storages.get_mut(storage_type)
485 }
486
487 fn insert(&mut self, storage_type: StorageType, local_storage: Box<dyn LocalReplay>) {
488 self.storages.insert(storage_type, local_storage);
489 }
490
491 fn is_empty(&self) -> bool {
492 self.storages.is_empty()
493 }
494
495 #[cfg(test)]
496 fn len(&self) -> usize {
497 self.storages.len()
498 }
499}
500
#[cfg(test)]
mod tests {

    use std::ops::Bound;

    use bytes::Bytes;
    use mockall::predicate;

    use super::*;
    use crate::replay::{MockGlobalReplayInterface, MockLocalReplayInterface};
    use crate::{MockReplayIterStream, TracedBytes, TracedReadOptions};

    /// Calls `ReplayWorker::handle_record` directly, with no spawned task.
    /// It creates two local storages, replays a get against the first and an
    /// iter against the second, then advances the iterator once. After each
    /// step it checks how many storages and iterators the worker tracks.
    #[tokio::test]
    async fn test_handle_record() {
        let mut iters_map = HashMap::new();
        let mut local_storage_opts_map = HashMap::new();
        let mut local_storages = LocalStorages::new();
        // The traced result for each record is pushed into `res_tx` before
        // `handle_record` runs, just as the scheduler's `send_result` would.
        let (res_tx, mut res_rx) = unbounded_channel();
        let get_table_id = 12;
        let iter_table_id = 14654;
        let read_options = TracedReadOptions::for_test(get_table_id);
        let iter_read_options = TracedReadOptions::for_test(iter_table_id);
        let op = Operation::get(Bytes::from(vec![123]), Some(123), read_options);

        let new_local_opts = TracedNewLocalOptions::for_test(get_table_id);

        let iter_local_opts = TracedNewLocalOptions::for_test(iter_table_id);
        let mut should_exit = false;
        let get_storage_type = StorageType::Local(0, 0);
        let record = Record::new(get_storage_type, 1, op);
        let mut mock_replay = MockGlobalReplayInterface::new();

        // Mock local storage for the get path: key [123] returns value [120].
        mock_replay.expect_new_local().times(1).returning(move |_| {
            let mut mock_local = MockLocalReplayInterface::new();

            mock_local
                .expect_get()
                .with(
                    predicate::eq(TracedBytes::from(vec![123])),
                    predicate::always(),
                )
                .returning(|_, _| Ok(Some(TracedBytes::from(vec![120]))));

            Box::new(mock_local)
        });

        // Mock local storage for the iter path: a full-range iter yields one
        // ([1], [0]) pair.
        mock_replay.expect_new_local().times(1).returning(move |_| {
            let mut mock_local = MockLocalReplayInterface::new();

            mock_local
                .expect_iter()
                .with(
                    predicate::eq((Bound::Unbounded, Bound::Unbounded)),
                    predicate::always(),
                )
                .returning(move |_, _| {
                    let iter = MockReplayIterStream::new(vec![(
                        TracedBytes::from(vec![1]),
                        TracedBytes::from(vec![0]),
                    )]);
                    Ok(iter.into_stream().boxed())
                });

            Box::new(mock_local)
        });

        let replay = Arc::new(mock_replay);

        // Create local storage (0, 0) for the get table; no traced result is consumed.
        ReplayWorker::handle_record(
            Record::new(
                get_storage_type,
                0,
                Operation::NewLocalStorage(new_local_opts, 0),
            ),
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        res_tx
            .send(OperationResult::Get(TraceResult::Ok(Some(
                TracedBytes::from(vec![120]),
            ))))
            .unwrap();

        ReplayWorker::handle_record(
            record,
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        assert_eq!(local_storages.len(), 1);
        assert!(iters_map.is_empty());

        let op = Operation::Iter {
            key_range: (Bound::Unbounded, Bound::Unbounded),
            epoch: Some(45),
            read_options: iter_read_options,
        };

        let iter_storage_type = StorageType::Local(0, 1);

        // Create local storage (0, 1) for the iter table.
        ReplayWorker::handle_record(
            Record::new(
                iter_storage_type,
                2,
                Operation::NewLocalStorage(iter_local_opts, 1),
            ),
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        // The iter is stored in `iters_map` under this record id (1), which
        // the `IterNext(1)` record below refers to.
        let record = Record::new(iter_storage_type, 1, op);
        res_tx
            .send(OperationResult::Iter(TraceResult::Ok(())))
            .unwrap();

        ReplayWorker::handle_record(
            record,
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        assert_eq!(local_storages.len(), 2);
        assert_eq!(iters_map.len(), 1);

        let op = Operation::IterNext(1);
        let record = Record::new(iter_storage_type, 3, op);
        res_tx
            .send(OperationResult::IterNext(TraceResult::Ok(Some((
                TracedBytes::from(vec![1]),
                TracedBytes::from(vec![0]),
            )))))
            .unwrap();

        ReplayWorker::handle_record(
            record,
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        // Advancing an iterator does not remove it from the map.
        assert_eq!(local_storages.len(), 2);
        assert_eq!(iters_map.len(), 1);
    }

    /// Runs a global get through the scheduler: schedule the record, feed
    /// its traced result, wait for the one-shot worker to finish, then shut
    /// down.
    #[tokio::test]
    async fn test_worker_scheduler() {
        let mut mock_replay = MockGlobalReplayInterface::default();
        let record_id = 29053;
        let key = TracedBytes::from(vec![1]);
        let epoch = 2596;
        let read_options = TracedReadOptions::for_test(1);

        let res_bytes = TracedBytes::from(vec![58, 54, 35]);

        mock_replay
            .expect_get()
            .with(
                predicate::eq(key.clone()),
                predicate::eq(epoch),
                predicate::eq(read_options.clone()),
            )
            .returning(move |_, _, _| Ok(Some(TracedBytes::from(vec![58, 54, 35]))));

        let mut scheduler = WorkerScheduler::new(Arc::new(mock_replay));
        let record = Record::new(
            StorageType::Global,
            record_id,
            Operation::get(key.into(), Some(epoch), read_options),
        );
        scheduler.schedule(record);

        // The traced result is routed to the worker allocated for the same
        // (global storage, record id) pair.
        let result = Record::new(
            StorageType::Global,
            record_id,
            Operation::Result(OperationResult::Get(TraceResult::Ok(Some(res_bytes)))),
        );

        scheduler.send_result(result);

        // A global-storage worker is one-shot, so it is removed after this wait.
        let fin = Record::new(StorageType::Global, record_id, Operation::Finish);
        scheduler.wait_finish(fin).await;

        scheduler.shutdown().await;
    }
}