sync_point/
lib.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::future::Future;
17use std::sync::Arc;
18use std::time::Duration;
19
20use futures_util::future::{BoxFuture, FutureExt};
21
22#[derive(thiserror::Error, Debug)]
23pub enum Error {
24    #[error("Wait for signal {0} timeout")]
25    WaitTimeout(&'static str),
26}
27
28pub type SyncPoint = &'static str;
29type Action = Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
30
/// Lazily-initialized global facility.
///
/// It stays uninitialized (so the public `on` is a no-op) until `reset` or `remove_action`
/// initializes it through `call_once`.
static SYNC_FACILITY: spin::Once<SyncFacility> = spin::Once::new();
32
33struct SyncFacility {
34    /// `Notify` for each sync point.
35    notifies: spin::Mutex<HashMap<SyncPoint, Arc<tokio::sync::Notify>>>,
36    /// Actions for each sync point.
37    actions: spin::Mutex<HashMap<SyncPoint, Action>>,
38}
39
40impl SyncFacility {
41    fn new() -> Self {
42        Self {
43            notifies: Default::default(),
44            actions: Default::default(),
45        }
46    }
47
48    fn get() -> &'static Self {
49        SYNC_FACILITY.get().expect("sync point not enabled")
50    }
51
52    async fn wait(
53        &self,
54        sync_point: SyncPoint,
55        timeout: Duration,
56        relay: bool,
57    ) -> Result<(), Error> {
58        let entry = self.notifies.lock().entry(sync_point).or_default().clone();
59        match tokio::time::timeout(timeout, entry.notified()).await {
60            Ok(_) if relay => entry.notify_one(),
61            Ok(_) => {}
62            Err(_) => return Err(Error::WaitTimeout(sync_point)),
63        }
64        Ok(())
65    }
66
67    fn emit(&self, sync_point: SyncPoint) {
68        self.notifies
69            .lock()
70            .entry(sync_point)
71            .or_default()
72            .notify_one();
73    }
74
75    fn hook(&self, sync_point: SyncPoint, action: Action) {
76        self.actions.lock().insert(sync_point, action);
77    }
78
79    fn reset(&self) {
80        self.actions.lock().clear();
81        self.notifies.lock().clear();
82    }
83
84    fn remove_action(&self, sync_point: SyncPoint) {
85        self.actions.lock().remove(&sync_point);
86    }
87
88    async fn on(&self, sync_point: SyncPoint) {
89        let action = self.actions.lock().get(sync_point).map(|action| action());
90        if let Some(action) = action {
91            action.await;
92        }
93        self.emit(sync_point);
94    }
95}
96
97/// Enable or reset the global sync facility.
98pub fn reset() {
99    SYNC_FACILITY.call_once(SyncFacility::new).reset();
100}
101
102/// Remove a sync point's action.
103pub fn remove_action(sync_point: SyncPoint) {
104    SYNC_FACILITY
105        .call_once(SyncFacility::new)
106        .remove_action(sync_point);
107}
108
109/// Mark a sync point.
110pub async fn on(sync_point: SyncPoint) {
111    if let Some(sync_facility) = SYNC_FACILITY.get() {
112        sync_facility.on(sync_point).await;
113    }
114}
115
116/// Hook a sync point with action.
117///
118/// The action will be executed before reaching the sync point.
119pub fn hook<F, Fut>(sync_point: SyncPoint, action: F)
120where
121    F: Fn() -> Fut + Send + Sync + 'static,
122    Fut: Future<Output = ()> + Send + 'static,
123{
124    let action = Arc::new(move || action().boxed());
125    SyncFacility::get().hook(sync_point, action);
126}
127
128/// Wait for a sync point to be reached with timeout.
129///
130/// If the sync point is reached before this call, it will consume this event and return
131/// immediately.
132pub async fn wait_timeout(sync_point: SyncPoint, dur: Duration) -> Result<(), Error> {
133    SyncFacility::get().wait(sync_point, dur, false).await
134}
135
/// Marks the sync point named by `$name` (a `&'static str` expression).
///
/// Expands to `.await` on this crate's `on`, so it may only be used inside an `async`
/// context. The path is spelled with `$crate` so it resolves to this crate regardless of how
/// the caller imported or renamed it, or whether it has a local item named `sync_point`.
#[macro_export]
#[cfg(feature = "sync_point")]
macro_rules! sync_point {
    ($name:expr) => {{
        $crate::on($name).await;
    }};
}
143
/// Compiled-out variant used when the `sync_point` feature is disabled.
///
/// Expands to an empty block; the sync-point expression is not evaluated.
#[macro_export]
#[cfg(not(feature = "sync_point"))]
macro_rules! sync_point {
    ($point:expr) => {{}};
}