1use std::collections::HashMap;
16use std::future::Future;
17use std::sync::Arc;
18use std::time::Duration;
19
20use futures_util::future::{BoxFuture, FutureExt};
21
22#[derive(thiserror::Error, Debug)]
23pub enum Error {
24 #[error("Wait for signal {0} timeout")]
25 WaitTimeout(&'static str),
26}
27
28pub type SyncPoint = &'static str;
29type Action = Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
30
31static SYNC_FACILITY: spin::Once<SyncFacility> = spin::Once::new();
32
33struct SyncFacility {
34 notifies: spin::Mutex<HashMap<SyncPoint, Arc<tokio::sync::Notify>>>,
36 actions: spin::Mutex<HashMap<SyncPoint, Action>>,
38}
39
40impl SyncFacility {
41 fn new() -> Self {
42 Self {
43 notifies: Default::default(),
44 actions: Default::default(),
45 }
46 }
47
48 fn get() -> &'static Self {
49 SYNC_FACILITY.get().expect("sync point not enabled")
50 }
51
52 async fn wait(
53 &self,
54 sync_point: SyncPoint,
55 timeout: Duration,
56 relay: bool,
57 ) -> Result<(), Error> {
58 let entry = self.notifies.lock().entry(sync_point).or_default().clone();
59 match tokio::time::timeout(timeout, entry.notified()).await {
60 Ok(_) if relay => entry.notify_one(),
61 Ok(_) => {}
62 Err(_) => return Err(Error::WaitTimeout(sync_point)),
63 }
64 Ok(())
65 }
66
67 fn emit(&self, sync_point: SyncPoint) {
68 self.notifies
69 .lock()
70 .entry(sync_point)
71 .or_default()
72 .notify_one();
73 }
74
75 fn hook(&self, sync_point: SyncPoint, action: Action) {
76 self.actions.lock().insert(sync_point, action);
77 }
78
79 fn reset(&self) {
80 self.actions.lock().clear();
81 self.notifies.lock().clear();
82 }
83
84 fn remove_action(&self, sync_point: SyncPoint) {
85 self.actions.lock().remove(&sync_point);
86 }
87
88 async fn on(&self, sync_point: SyncPoint) {
89 let action = self.actions.lock().get(sync_point).map(|action| action());
90 if let Some(action) = action {
91 action.await;
92 }
93 self.emit(sync_point);
94 }
95}
96
97pub fn reset() {
99 SYNC_FACILITY.call_once(SyncFacility::new).reset();
100}
101
102pub fn remove_action(sync_point: SyncPoint) {
104 SYNC_FACILITY
105 .call_once(SyncFacility::new)
106 .remove_action(sync_point);
107}
108
109pub async fn on(sync_point: SyncPoint) {
111 if let Some(sync_facility) = SYNC_FACILITY.get() {
112 sync_facility.on(sync_point).await;
113 }
114}
115
116pub fn hook<F, Fut>(sync_point: SyncPoint, action: F)
120where
121 F: Fn() -> Fut + Send + Sync + 'static,
122 Fut: Future<Output = ()> + Send + 'static,
123{
124 let action = Arc::new(move || action().boxed());
125 SyncFacility::get().hook(sync_point, action);
126}
127
128pub async fn wait_timeout(sync_point: SyncPoint, dur: Duration) -> Result<(), Error> {
133 SyncFacility::get().wait(sync_point, dur, false).await
134}
135
/// Marks a sync point at the call site (variant compiled when the
/// `sync_point` feature is enabled).
///
/// Expands to `sync_point::on($point).await;`, so it must be used inside an
/// async context and `sync_point` must resolve at the call site.
#[cfg(feature = "sync_point")]
#[macro_export]
macro_rules! sync_point {
    ($point:expr) => {{
        sync_point::on($point).await;
    }};
}
143
/// Marks a sync point at the call site (variant compiled when the
/// `sync_point` feature is disabled).
///
/// Expands to an empty block; the argument is discarded and never evaluated.
#[cfg(not(feature = "sync_point"))]
#[macro_export]
macro_rules! sync_point {
    ($point:expr) => {{}};
}