use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use futures_util::future::{BoxFuture, FutureExt};
/// Errors produced by the sync point facility.
#[derive(thiserror::Error, Debug)]
pub enum Error {
/// Waiting on the named sync point did not complete before the timeout elapsed.
/// The payload is the name of the sync point that was being waited on.
#[error("Wait for signal {0} timeout")]
WaitTimeout(&'static str),
}
/// Name of a sync point; used as the key for both notifiers and hooked actions.
pub type SyncPoint = &'static str;
/// Shared, thread-safe callback that yields a boxed `'static` future each time it is called.
type Action = Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
/// Global facility instance. It is initialized lazily by `reset` and `remove_action`;
/// `on` silently does nothing while it is uninitialized, whereas `SyncFacility::get` panics.
static SYNC_FACILITY: spin::Once<SyncFacility> = spin::Once::new();
/// Registry holding, per sync point, a notifier used for signalling and an optional action.
struct SyncFacility {
/// One `Notify` per sync point, created on demand by `wait` and `emit`.
notifies: spin::Mutex<HashMap<SyncPoint, Arc<tokio::sync::Notify>>>,
/// Action registered through `hook`; `on` runs it before emitting the sync point.
actions: spin::Mutex<HashMap<SyncPoint, Action>>,
}
impl SyncFacility {
    /// Creates an empty facility with no notifiers and no hooked actions.
    fn new() -> Self {
        Self {
            notifies: Default::default(),
            actions: Default::default(),
        }
    }

    /// Returns the global facility.
    ///
    /// # Panics
    ///
    /// Panics with "sync point not enabled" if `SYNC_FACILITY` has not been initialized.
    fn get() -> &'static Self {
        SYNC_FACILITY.get().expect("sync point not enabled")
    }

    /// Waits until `sync_point` is signalled or `timeout` elapses.
    ///
    /// When `relay` is `true`, a successful wake-up is passed on by calling `notify_one`
    /// again on the same notifier.
    ///
    /// # Errors
    ///
    /// Returns [`Error::WaitTimeout`] carrying `sync_point` if the timeout elapses first.
    async fn wait(
        &self,
        sync_point: SyncPoint,
        timeout: Duration,
        relay: bool,
    ) -> Result<(), Error> {
        // Clone the notifier out of the map so the spin lock is released before awaiting.
        let entry = self.notifies.lock().entry(sync_point).or_default().clone();
        match tokio::time::timeout(timeout, entry.notified()).await {
            Ok(_) if relay => entry.notify_one(),
            Ok(_) => {}
            Err(_) => return Err(Error::WaitTimeout(sync_point)),
        }
        Ok(())
    }

    /// Signals `sync_point` via `notify_one`, creating its notifier if it does not exist yet.
    fn emit(&self, sync_point: SyncPoint) {
        self.notifies
            .lock()
            .entry(sync_point)
            .or_default()
            .notify_one();
    }

    /// Registers `action` for `sync_point`, replacing any action registered before.
    fn hook(&self, sync_point: SyncPoint, action: Action) {
        self.actions.lock().insert(sync_point, action);
    }

    /// Drops every hooked action and every notifier.
    fn reset(&self) {
        self.actions.lock().clear();
        self.notifies.lock().clear();
    }

    /// Removes the action registered for `sync_point`, if any.
    fn remove_action(&self, sync_point: SyncPoint) {
        self.actions.lock().remove(&sync_point);
    }

    /// Runs the action hooked on `sync_point` (if any) to completion, then emits the
    /// sync point.
    async fn on(&self, sync_point: SyncPoint) {
        // Only clone the `Arc` while the lock is held; the guard is a temporary that is
        // dropped at the end of this statement. Invoking the user closure after the lock is
        // released keeps arbitrary code out of the spin lock and lets a hook call
        // `hook`/`remove_action`/`reset` itself without deadlocking on the non-reentrant
        // mutex.
        let action = self.actions.lock().get(sync_point).cloned();
        if let Some(action) = action {
            action().await;
        }
        self.emit(sync_point);
    }
}
/// Initializes the global sync facility if it does not exist yet, then clears all hooked
/// actions and all notifiers it holds.
pub fn reset() {
    let facility = SYNC_FACILITY.call_once(SyncFacility::new);
    facility.reset();
}
/// Removes the action hooked on `sync_point`, if one is registered.
///
/// The global sync facility is initialized first if it does not exist yet.
pub fn remove_action(sync_point: SyncPoint) {
    let facility = SYNC_FACILITY.call_once(SyncFacility::new);
    facility.remove_action(sync_point);
}
/// Triggers `sync_point`: runs its hooked action (if any) and then signals it.
///
/// Does nothing when the global sync facility has not been initialized.
pub async fn on(sync_point: SyncPoint) {
    let facility = match SYNC_FACILITY.get() {
        Some(facility) => facility,
        // Facility was never initialized: the sync point is a no-op.
        None => return,
    };
    facility.on(sync_point).await;
}
pub fn hook<F, Fut>(sync_point: SyncPoint, action: F)
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let action = Arc::new(move || action().boxed());
SyncFacility::get().hook(sync_point, action);
}
/// Waits for `sync_point` to be signalled, giving up after `dur`.
///
/// # Errors
///
/// Returns [`Error::WaitTimeout`] if `dur` elapses before the sync point is signalled.
///
/// # Panics
///
/// Panics with "sync point not enabled" if the global sync facility has not been
/// initialized.
pub async fn wait_timeout(sync_point: SyncPoint, dur: Duration) -> Result<(), Error> {
    let facility = SyncFacility::get();
    // `false`: do not relay the notification after waking up.
    facility.wait(sync_point, dur, false).await
}
/// Marks a sync point named `$name` at the call site by awaiting `sync_point::on($name)`.
///
/// This definition is compiled only when the `sync_point` feature is enabled. The expansion
/// contains `.await`, so it can only be used inside an async context.
// NOTE(review): the expansion names `sync_point::on` rather than `$crate::on`, so the path is
// resolved at the call site — presumably the crate is named `sync_point`; confirm.
#[macro_export]
#[cfg(feature = "sync_point")]
macro_rules! sync_point {
($name:expr) => {{
sync_point::on($name).await;
}};
}
/// No-op variant of `sync_point!`, compiled when the `sync_point` feature is disabled.
///
/// Expands to an empty block; the `$name` expression is not evaluated.
#[macro_export]
#[cfg(not(feature = "sync_point"))]
macro_rules! sync_point {
($name:expr) => {{}};
}