rw_futures_util/
pausable.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::sync::{Arc, Mutex};
18use std::task::{Context, Poll, Waker};
19
20use futures::Stream;
21use pin_project_lite::pin_project;
22
pin_project! {
    /// A stream adapter whose progress can be paused and resumed through a [`Valve`].
    ///
    /// While paused, polling does not poll the inner stream; it records the polling
    /// task's waker in a slot shared with the [`Valve`] and returns `Poll::Pending`.
    /// Construct it with `Pausable::new`, which also returns the controlling [`Valve`].
    #[derive(Debug)]
    #[must_use = "streams do nothing unless polled"]
    pub struct Pausable<St>
        where St: Stream
    {
        // The wrapped stream; structurally pinned so it can be polled through `Pin`.
        #[pin]
        stream: St,
        // Paused flag shared with the `Valve`; `true` means "do not poll `stream`".
        paused: Arc<AtomicBool>,
        // Waker slot shared with the `Valve`, filled when polled while paused.
        waker: Arc<Mutex<Option<Waker>>>,
    }
}
35
/// A valve is a handle that can control the [`Pausable`] stream.
///
/// Cloning a valve is cheap and yields another handle to the *same* stream: all
/// clones share the paused flag and the waker slot with the [`Pausable`] they control.
#[derive(Clone, Debug)]
pub struct Valve {
    /// Paused flag shared with the [`Pausable`] stream; `true` while paused.
    paused: Arc<AtomicBool>,
    /// Waker registered by the [`Pausable`] stream when it is polled while paused.
    waker: Arc<Mutex<Option<Waker>>>,
}
42
43impl Valve {
44    /// Pause the stream controlled by the valve.
45    pub fn pause(&self) {
46        self.paused.store(true, Ordering::Relaxed);
47    }
48
49    /// Resume the stream controlled by the valve.
50    pub fn resume(&self) {
51        self.paused.store(false, Ordering::Relaxed);
52        if let Some(waker) = self.waker.lock().unwrap().as_ref() {
53            waker.wake_by_ref()
54        }
55    }
56}
57
58impl<St> Pausable<St>
59where
60    St: Stream,
61{
62    pub(crate) fn new(stream: St) -> (Self, Valve) {
63        let paused = Arc::new(AtomicBool::new(false));
64        let waker = Arc::new(Mutex::new(None));
65        (
66            Pausable {
67                stream,
68                paused: paused.clone(),
69                waker: waker.clone(),
70            },
71            Valve { paused, waker },
72        )
73    }
74}
75
76impl<St> Stream for Pausable<St>
77where
78    St: Stream,
79{
80    type Item = St::Item;
81
82    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
83        let this = self.project();
84        if this.paused.load(Ordering::Relaxed) {
85            let mut waker = this.waker.lock().unwrap();
86            *waker = Some(cx.waker().clone());
87            Poll::Pending
88        } else {
89            this.stream.poll_next(cx)
90        }
91    }
92
93    fn size_hint(&self) -> (usize, Option<usize>) {
94        self.stream.size_hint()
95    }
96}