1use std::collections::HashMap;
16use std::sync::Arc;
17
18use futures::StreamExt;
19use futures::stream::BoxStream;
20use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
21use tokio::task::JoinHandle;
22
23use super::{GlobalReplay, LocalReplay, ReplayRequest, WorkerId, WorkerResponse};
24use crate::{
25 LocalStorageId, Operation, OperationResult, Record, RecordId, ReplayItem, Result, StorageType,
26 TraceResult, TracedNewLocalOptions,
27};
28
/// Interface for dispatching traced records to replay workers and collecting their completion.
///
/// The `async` methods are made usable in a trait through `async_trait`.
#[async_trait::async_trait]
pub trait ReplayWorkerScheduler {
    /// Schedules `record` for replay.
    fn schedule(&mut self, record: Record);
    /// Delivers the result carried by `record`, which replay compares its own outcome against.
    fn send_result(&mut self, record: Record);
    /// Waits until the replay work identified by `record` has finished.
    async fn wait_finish(&mut self, record: Record);
    /// Shuts the scheduler down, consuming it.
    async fn shutdown(self);
}
40
41pub(crate) struct WorkerScheduler<G: GlobalReplay> {
42 workers: HashMap<WorkerId, WorkerHandler>,
43 replay: Arc<G>,
44}
45
46impl<G: GlobalReplay> WorkerScheduler<G> {
47 pub(crate) fn new(replay: Arc<G>) -> Self {
48 WorkerScheduler {
49 workers: HashMap::new(),
50 replay,
51 }
52 }
53}
54
55#[async_trait::async_trait]
56impl<G: GlobalReplay + 'static> ReplayWorkerScheduler for WorkerScheduler<G> {
57 fn schedule(&mut self, record: Record) {
58 let worker_id = allocate_worker_id(&record);
59 let handler = self
60 .workers
61 .entry(worker_id)
62 .or_insert_with(|| ReplayWorker::spawn(self.replay.clone()));
63
64 handler.replay(Some(record));
65 }
66
67 fn send_result(&mut self, record: Record) {
68 let worker_id = allocate_worker_id(&record);
69 let Record { operation, .. } = record;
70 if let (Some(handler), Operation::Result(trace_result)) =
73 (self.workers.get_mut(&worker_id), operation)
74 {
75 handler.send_result(trace_result);
78 }
79 }
80
81 async fn wait_finish(&mut self, record: Record) {
82 let worker_id = allocate_worker_id(&record);
83
84 if let Some(handler) = self.workers.get_mut(&worker_id) {
86 let resp = handler.wait(record.record_id).await;
88
89 if matches!(worker_id, WorkerId::OneShot(_))
92 || matches!(resp, Some(WorkerResponse::Shutdown))
93 {
94 let handler = self.workers.remove(&worker_id).unwrap();
95 handler.finish();
96 }
97 }
98 }
99
100 async fn shutdown(self) {
101 for (_, handler) in self.workers {
103 handler.finish();
104 handler.join().await;
105 }
106 }
107}
108
109struct ReplayWorker {}
110
111impl ReplayWorker {
112 fn spawn(replay: Arc<impl GlobalReplay + 'static>) -> WorkerHandler {
113 let (req_tx, req_rx) = unbounded_channel();
114 let (resp_tx, resp_rx) = unbounded_channel();
115 let (res_tx, res_rx) = unbounded_channel();
116
117 let join = tokio::spawn(Self::run(req_rx, res_rx, resp_tx, replay));
118 WorkerHandler {
119 req_tx,
120 res_tx,
121 record_end_resp_rx: resp_rx,
122 join,
123 stacked_replay_reqs: HashMap::new(),
124 }
125 }
126
127 async fn run(
128 mut req_rx: UnboundedReceiver<ReplayRequest>,
129 mut res_rx: UnboundedReceiver<OperationResult>,
130 resp_tx: UnboundedSender<WorkerResponse>,
131 replay: Arc<impl GlobalReplay>,
132 ) {
133 let mut iters_map: HashMap<RecordId, BoxStream<'static, Result<ReplayItem>>> =
134 HashMap::new();
135 let mut local_storages = LocalStorages::new();
136 let mut should_shutdown = false;
137 let mut local_storage_opts_map = HashMap::new();
138
139 while let Some(Some(record)) = req_rx.recv().await {
140 Self::handle_record(
141 record.clone(),
142 &replay,
143 &mut res_rx,
144 &mut iters_map,
145 &mut local_storages,
146 &mut local_storage_opts_map,
147 &mut should_shutdown,
148 )
149 .await;
150
151 let message = if should_shutdown {
152 WorkerResponse::Shutdown
153 } else {
154 WorkerResponse::Continue
155 };
156
157 resp_tx.send(message).expect("Failed to send message");
158 }
159 }
160
161 async fn handle_record(
162 record: Record,
163 replay: &Arc<impl GlobalReplay>,
164 res_rx: &mut UnboundedReceiver<OperationResult>,
165 iters_map: &mut HashMap<RecordId, BoxStream<'static, Result<ReplayItem>>>,
166 local_storages: &mut LocalStorages,
167 local_storage_opts_map: &mut HashMap<LocalStorageId, TracedNewLocalOptions>,
168 should_shutdown: &mut bool,
169 ) {
170 let Record {
171 storage_type,
172 record_id,
173 operation,
174 } = record;
175
176 match operation {
177 Operation::Get {
178 key,
179 epoch,
180 read_options,
181 } => {
182 let actual = match storage_type {
183 StorageType::Global => {
184 let epoch = epoch.unwrap();
186 replay.get(key, epoch, read_options).await
187 }
188 StorageType::Local(_, local_storage_id) => {
189 let opts = local_storage_opts_map.get(&local_storage_id).unwrap();
190 assert_eq!(opts.table_id, read_options.table_id);
191 let s = local_storages.get_mut(&storage_type).unwrap();
192 s.get(key, read_options).await
193 }
194 };
195
196 let res = res_rx.recv().await.expect("recv result failed");
197 if let OperationResult::Get(expected) = res {
198 assert_eq!(TraceResult::from(actual), expected, "get result wrong");
199 } else {
200 panic!("expect get result, but got {:?}", res);
201 }
202 }
203 Operation::Insert {
204 key,
205 new_val,
206 old_val,
207 } => {
208 let local_storage = local_storages.get_mut(&storage_type).unwrap();
209 let actual = local_storage.insert(key, new_val, old_val);
210
211 let expected = res_rx.recv().await.expect("recv result failed");
212 if let OperationResult::Insert(expected) = expected {
213 assert_eq!(TraceResult::from(actual), expected, "get result wrong");
214 } else {
215 panic!("expect insert result, but got {:?}", expected);
216 }
217 }
218 Operation::Delete { key, old_val } => {
219 let local_storage = local_storages.get_mut(&storage_type).unwrap();
220 let actual = local_storage.delete(key, old_val);
221
222 let expected = res_rx.recv().await.expect("recv result failed");
223 if let OperationResult::Delete(expected) = expected {
224 assert_eq!(TraceResult::from(actual), expected, "get result wrong");
225 } else {
226 panic!("expect delete result, but got {:?}", expected);
227 }
228 }
229 Operation::Iter {
230 key_range,
231 epoch,
232 read_options,
233 } => {
234 let iter = match storage_type {
235 StorageType::Global => {
236 let epoch = epoch.unwrap();
238 replay.iter(key_range, epoch, read_options).await
239 }
240 StorageType::Local(_, id) => {
241 let opts = local_storage_opts_map.get(&id).unwrap();
242 assert_eq!(opts.table_id, read_options.table_id);
243 let s = local_storages.get_mut(&storage_type).unwrap();
244 s.iter(key_range, read_options).await
245 }
246 };
247 let res = res_rx.recv().await.expect("recv result failed");
248 if let OperationResult::Iter(expected) = res {
249 if expected.is_ok() {
250 let iter = iter.unwrap().boxed();
251 let id = record_id;
252 iters_map.insert(id, iter);
253 } else {
254 assert!(iter.is_err());
255 }
256 } else {
257 panic!("expect iter result, but got {:?}", res);
258 }
259 }
260 Operation::Sync(sync_table_epochs) => {
261 assert_eq!(storage_type, StorageType::Global);
262 let sync_result = replay.sync(sync_table_epochs).await.unwrap();
263 let res = res_rx.recv().await.expect("recv result failed");
264 if let OperationResult::Sync(expected) = res {
265 assert_eq!(TraceResult::Ok(sync_result), expected, "sync failed");
266 } else {
267 panic!("expect sync result, but got {:?}", res);
268 }
269 }
270 Operation::IterNext(id) => {
271 let iter = iters_map.get_mut(&id).expect("iter not in worker");
272 let actual = iter.next().await;
273 let actual = actual.map(|res| res.unwrap());
274 let res = res_rx.recv().await.expect("recv result failed");
275 if let OperationResult::IterNext(expected) = res {
276 assert_eq!(TraceResult::Ok(actual), expected, "iter_next result wrong");
277 } else {
278 panic!("expect iter_next result, but got {:?}", res);
279 }
280 }
281 Operation::NewLocalStorage(new_local_opts, id) => {
282 assert_ne!(storage_type, StorageType::Global);
283 local_storage_opts_map.insert(id, new_local_opts.clone());
284 let local_storage = replay.new_local(new_local_opts).await;
285 local_storages.insert(storage_type, local_storage);
286 }
287 Operation::DropLocalStorage => {
288 assert_ne!(storage_type, StorageType::Global);
289 if let StorageType::Local(_, _) = storage_type {
290 local_storages.remove(&storage_type);
291 }
292 if local_storages.is_empty() {
295 *should_shutdown = true;
296 }
297 }
298 Operation::MetaMessage(resp) => {
299 assert_eq!(storage_type, StorageType::Global);
300 let op = resp.0.operation();
301 if let Some(info) = resp.0.info {
302 replay
303 .notify_hummock(info, op, resp.0.version)
304 .await
305 .unwrap();
306 }
307 }
308 Operation::LocalStorageInit(options) => {
309 assert_ne!(storage_type, StorageType::Global);
310 let local_storage = local_storages.get_mut(&storage_type).unwrap();
311 local_storage.init(options).await.unwrap();
312 }
313 Operation::TryWaitEpoch(epoch, options) => {
314 assert_eq!(storage_type, StorageType::Global);
315 let res = res_rx.recv().await.expect("recv result failed");
316 if let OperationResult::TryWaitEpoch(expected) = res {
317 let actual = replay.try_wait_epoch(epoch.into(), options).await;
318 assert_eq!(TraceResult::from(actual), expected, "try_wait_epoch wrong");
319 } else {
320 panic!(
321 "wrong try_wait_epoch result, expect epoch result, but got {:?}",
322 res
323 );
324 }
325 }
326 Operation::SealCurrentEpoch { epoch, opts } => {
327 assert_ne!(storage_type, StorageType::Global);
328 let local_storage = local_storages.get_mut(&storage_type).unwrap();
329 local_storage.seal_current_epoch(epoch, opts);
330 }
331 Operation::Flush => {
332 assert_ne!(storage_type, StorageType::Global);
333 let local_storage = local_storages.get_mut(&storage_type).unwrap();
334 let res = res_rx.recv().await.expect("recv result failed");
335 if let OperationResult::Flush(expected) = res {
336 let actual = local_storage.flush().await;
337 assert_eq!(TraceResult::from(actual), expected, "flush wrong");
338 } else {
339 panic!("wrong flush result, expect flush result, but got {:?}", res);
340 }
341 }
342 Operation::TryFlush => {
343 assert_ne!(storage_type, StorageType::Global);
344 let local_storage = local_storages.get_mut(&storage_type).unwrap();
345 let res = res_rx.recv().await.expect("recv result failed");
346 if let OperationResult::TryFlush(_) = res {
347 let _ = local_storage.try_flush().await;
348 } else {
351 panic!(
352 "wrong try flush result, expect flush result, but got {:?}",
353 res
354 );
355 }
356 }
357
358 Operation::Finish => unreachable!(),
359 Operation::Result(_) => unreachable!(),
360 }
361 }
362}
363
364fn allocate_worker_id(record: &Record) -> WorkerId {
365 match record.storage_type() {
366 StorageType::Local(concurrent_id, id) => WorkerId::Local(*concurrent_id, *id),
367 StorageType::Global => WorkerId::OneShot(record.record_id()),
368 }
369}
370
371struct WorkerHandler {
372 req_tx: UnboundedSender<ReplayRequest>,
373 res_tx: UnboundedSender<OperationResult>,
374 record_end_resp_rx: UnboundedReceiver<WorkerResponse>,
375 join: JoinHandle<()>,
376 stacked_replay_reqs: HashMap<u64, u32>,
381}
382
383impl WorkerHandler {
384 async fn join(self) {
385 self.join.await.expect("failed to stop worker");
386 }
387
388 fn finish(&self) {
389 self.send_replay_req(None);
390 }
391
392 fn replay(&mut self, req: ReplayRequest) {
393 if let Some(r) = &req {
394 let entry = self.stacked_replay_reqs.entry(r.record_id).or_insert(0);
395 *entry += 1;
396 }
397 self.send_replay_req(req);
398 }
399
400 async fn wait(&mut self, record_id: u64) -> Option<WorkerResponse> {
401 let mut stacked_replay_reqs = *self.stacked_replay_reqs.get(&record_id).unwrap();
402 assert!(
403 stacked_replay_reqs > 0,
404 "replay count should be 0, but found {}",
405 stacked_replay_reqs
406 );
407 let mut resp = None;
408
409 while stacked_replay_reqs > 0 {
410 resp = Some(
411 self.record_end_resp_rx
412 .recv()
413 .await
414 .expect("failed to wait worker resp"),
415 );
416 stacked_replay_reqs -= 1;
417 }
418 self.stacked_replay_reqs.remove(&record_id);
420 resp
422 }
423
424 fn send_replay_req(&self, req: ReplayRequest) {
425 self.req_tx
426 .send(req)
427 .expect("failed to send replay request");
428 }
429
430 fn send_result(&self, result: OperationResult) {
431 self.res_tx.send(result).expect("failed to send result");
432 }
433}
434
435struct LocalStorages {
436 storages: HashMap<StorageType, Box<dyn LocalReplay>>,
437}
438
439impl LocalStorages {
440 fn new() -> Self {
441 Self {
442 storages: HashMap::new(),
443 }
444 }
445
446 fn remove(&mut self, storage_type: &StorageType) {
447 self.storages.remove(storage_type);
448 }
449
450 fn get_mut(&mut self, storage_type: &StorageType) -> Option<&mut Box<dyn LocalReplay>> {
451 self.storages.get_mut(storage_type)
452 }
453
454 fn insert(&mut self, storage_type: StorageType, local_storage: Box<dyn LocalReplay>) {
455 self.storages.insert(storage_type, local_storage);
456 }
457
458 fn is_empty(&self) -> bool {
459 self.storages.is_empty()
460 }
461
462 #[cfg(test)]
463 fn len(&self) -> usize {
464 self.storages.len()
465 }
466}
467
// Unit tests for the replay worker and its scheduler, run against mocked replay backends.
#[cfg(test)]
mod tests {

    use std::ops::Bound;

    use bytes::Bytes;
    use mockall::predicate;

    use super::*;
    use crate::replay::{MockGlobalReplayInterface, MockLocalReplayInterface};
    use crate::{MockReplayIterStream, TracedBytes, TracedReadOptions};

    /// Drives `ReplayWorker::handle_record` directly through a local-storage get, an iterator
    /// creation and one `IterNext`, checking the worker's bookkeeping after each step.
    #[tokio::test]
    async fn test_handle_record() {
        let mut iters_map = HashMap::new();
        let mut local_storage_opts_map = HashMap::new();
        let mut local_storages = LocalStorages::new();
        let (res_tx, mut res_rx) = unbounded_channel();
        let get_table_id = 12;
        let iter_table_id = 14654;
        let read_options = TracedReadOptions::for_test(get_table_id);
        let iter_read_options = TracedReadOptions::for_test(iter_table_id);
        let op = Operation::get(Bytes::from(vec![123]), Some(123), read_options);

        let new_local_opts = TracedNewLocalOptions::for_test(get_table_id);

        let iter_local_opts = TracedNewLocalOptions::for_test(iter_table_id);
        let mut should_exit = false;
        let get_storage_type = StorageType::Local(0, 0);
        let record = Record::new(get_storage_type, 1, op);
        let mut mock_replay = MockGlobalReplayInterface::new();

        // Local storage used by the get step: it answers a get on key [123] with [120].
        // NOTE(review): this relies on the two `expect_new_local` expectations being consumed
        // in declaration order — confirm against mockall's matching rules.
        mock_replay.expect_new_local().times(1).returning(move |_| {
            let mut mock_local = MockLocalReplayInterface::new();

            mock_local
                .expect_get()
                .with(
                    predicate::eq(TracedBytes::from(vec![123])),
                    predicate::always(),
                )
                .returning(|_, _| Ok(Some(TracedBytes::from(vec![120]))));

            Box::new(mock_local)
        });

        // Local storage used by the iter step: an unbounded scan yields a single ([1], [0]) item.
        mock_replay.expect_new_local().times(1).returning(move |_| {
            let mut mock_local = MockLocalReplayInterface::new();

            mock_local
                .expect_iter()
                .with(
                    predicate::eq((Bound::Unbounded, Bound::Unbounded)),
                    predicate::always(),
                )
                .returning(move |_, _| {
                    let iter = MockReplayIterStream::new(vec![(
                        TracedBytes::from(vec![1]),
                        TracedBytes::from(vec![0]),
                    )]);
                    Ok(iter.into_stream().boxed())
                });

            Box::new(mock_local)
        });

        let replay = Arc::new(mock_replay);

        // Create the local storage the get will read from.
        ReplayWorker::handle_record(
            Record::new(
                get_storage_type,
                0,
                Operation::NewLocalStorage(new_local_opts, 0),
            ),
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        // The expected result must already be queued: `handle_record` receives it after
        // performing the get.
        res_tx
            .send(OperationResult::Get(TraceResult::Ok(Some(
                TracedBytes::from(vec![120]),
            ))))
            .unwrap();

        ReplayWorker::handle_record(
            record,
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        // A get registers no iterator.
        assert_eq!(local_storages.len(), 1);
        assert!(iters_map.is_empty());

        let op = Operation::Iter {
            key_range: (Bound::Unbounded, Bound::Unbounded),
            epoch: Some(45),
            read_options: iter_read_options,
        };

        let iter_storage_type = StorageType::Local(0, 1);

        // Create the second local storage, the one the iterator will scan.
        ReplayWorker::handle_record(
            Record::new(
                iter_storage_type,
                2,
                Operation::NewLocalStorage(iter_local_opts, 1),
            ),
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        // The iter is replayed under record id 1, which becomes its key in `iters_map`.
        let record = Record::new(iter_storage_type, 1, op);
        res_tx
            .send(OperationResult::Iter(TraceResult::Ok(())))
            .unwrap();

        ReplayWorker::handle_record(
            record,
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        assert_eq!(local_storages.len(), 2);
        assert_eq!(iters_map.len(), 1);

        // `IterNext(1)` refers to the iterator stored under record id 1 above.
        let op = Operation::IterNext(1);
        let record = Record::new(iter_storage_type, 3, op);
        res_tx
            .send(OperationResult::IterNext(TraceResult::Ok(Some((
                TracedBytes::from(vec![1]),
                TracedBytes::from(vec![0]),
            )))))
            .unwrap();

        ReplayWorker::handle_record(
            record,
            &replay,
            &mut res_rx,
            &mut iters_map,
            &mut local_storages,
            &mut local_storage_opts_map,
            &mut should_exit,
        )
        .await;

        // Advancing the iterator leaves it registered.
        assert_eq!(local_storages.len(), 2);
        assert_eq!(iters_map.len(), 1);
    }

    /// Runs one global get through the scheduler: schedule, send result, wait, shut down.
    #[tokio::test]
    async fn test_worker_scheduler() {
        let mut mock_replay = MockGlobalReplayInterface::default();
        let record_id = 29053;
        let key = TracedBytes::from(vec![1]);
        let epoch = 2596;
        let read_options = TracedReadOptions::for_test(1);

        let res_bytes = TracedBytes::from(vec![58, 54, 35]);

        // The mocked get returns the same bytes that are later sent as the expected result.
        mock_replay
            .expect_get()
            .with(
                predicate::eq(key.clone()),
                predicate::eq(epoch),
                predicate::eq(read_options.clone()),
            )
            .returning(move |_, _, _| Ok(Some(TracedBytes::from(vec![58, 54, 35]))));

        let mut scheduler = WorkerScheduler::new(Arc::new(mock_replay));
        // A record on the global storage is served by a one-shot worker.
        let record = Record::new(
            StorageType::Global,
            record_id,
            Operation::get(key.into(), Some(epoch), read_options),
        );
        scheduler.schedule(record);

        let result = Record::new(
            StorageType::Global,
            record_id,
            Operation::Result(OperationResult::Get(TraceResult::Ok(Some(res_bytes)))),
        );

        scheduler.send_result(result);

        // Waiting on a one-shot worker also removes it from the scheduler and tells it to finish.
        let fin = Record::new(StorageType::Global, record_id, Operation::Finish);
        scheduler.wait_finish(fin).await;

        scheduler.shutdown().await;
    }
}