rw_futures_util/
buffered_with_fence.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::future::Future;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use futures::TryFutureExt;
20use futures::future::{FusedFuture, IntoFuture, TryFuture};
21use futures::stream::{
22    Fuse, FuturesOrdered, IntoStream, Stream, StreamExt, TryStream, TryStreamExt,
23};
24use pin_project_lite::pin_project;
25
/// Lets an item declare itself a *fence* for [`TryBufferedWithFence`].
///
/// After a fence is pulled from the source stream, no further items are pulled until
/// every queued future (the fence included) has been yielded. Types that never act as
/// fences can rely on the provided default.
pub trait MaybeFence {
    /// Returns `true` if this item is a fence; the default implementation says it is not.
    fn is_fence(&self) -> bool {
        false
    }
}
31
pin_project! {
    /// Stream adapter over a [`TryStream`] of futures.
    ///
    /// It runs up to `max` of those futures at once and yields their outputs in the order
    /// the futures were pulled. After pulling a fence (see `MaybeFence`), it stops pulling
    /// new futures until the queue has fully drained.
    #[derive(Debug)]
    #[must_use = "streams do nothing unless polled"]
    pub struct TryBufferedWithFence<St>
    where
        St: TryStream,
        St::Ok: TryFuture,
    {
        // Source of futures, fused so it is safe to poll after it has finished and so
        // `is_done` can tell when it is exhausted.
        #[pin]
        stream: Fuse<IntoStream<St>>,
        // Futures currently in flight; their outputs are yielded in submission order.
        in_progress_queue: FuturesOrdered<IntoFuture<St::Ok>>,
        // Set once a fence has been queued; while set, nothing more is pulled from
        // `stream`. Cleared in `poll_next` once `in_progress_queue` is empty.
        syncing: bool,
        // Upper bound on the number of futures held in `in_progress_queue`.
        max: usize,
    }
}
47
48impl<St> TryBufferedWithFence<St>
49where
50    St: TryStream,
51    St::Ok: TryFuture<Error = St::Error> + MaybeFence,
52{
53    pub(crate) fn new(stream: St, n: usize) -> Self {
54        Self {
55            stream: stream.into_stream().fuse(),
56            in_progress_queue: FuturesOrdered::new(),
57            syncing: false,
58            max: n,
59        }
60    }
61}
62
impl<St> Stream for TryBufferedWithFence<St>
where
    St: TryStream,
    St::Ok: TryFuture<Error = St::Error> + MaybeFence,
{
    type Item = Result<<St::Ok as TryFuture>::Ok, St::Error>;

    /// Yields the next output of the buffered futures, in the order they were pulled.
    ///
    /// Each call does two things:
    /// - It first tops up the queue from the inner stream: at most `max` futures in flight,
    ///   and nothing past an outstanding fence.
    /// - It then polls the queue.
    ///
    /// An `Err` yielded by the inner stream itself is returned immediately, ahead of any
    /// outputs still waiting in the queue. Errors produced by the futures come out in order.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        // The fence is lifted only after everything queued up to and including it has
        // been yielded, i.e. once the queue has fully drained.
        if *this.syncing && this.in_progress_queue.is_empty() {
            *this.syncing = false;
        }

        // First, push as many futures as capacity allows into the queue, but only while
        // no fence is outstanding. Errors from the stream are propagated immediately via `?`.
        while !*this.syncing && this.in_progress_queue.len() < *this.max {
            match this.stream.as_mut().poll_next(cx)? {
                Poll::Ready(Some(fut)) => {
                    // Read the flag before `push_back` consumes `fut`.
                    let is_fence = fut.is_fence();
                    this.in_progress_queue
                        .push_back(TryFutureExt::into_future(fut));
                    if is_fence {
                        // A fence was just queued: don't buffer more data until the
                        // queue drains.
                        *this.syncing = true;
                        break;
                    }
                }
                // Stream exhausted, or not ready yet (in which case the stream has
                // registered `cx`'s waker): stop filling and work with what is queued.
                Poll::Ready(None) | Poll::Pending => break,
            }
        }

        // Attempt to pull the next value from the in_progress_queue.
        match this.in_progress_queue.poll_next_unpin(cx) {
            x @ Poll::Pending | x @ Poll::Ready(Some(_)) => return x,
            Poll::Ready(None) => {}
        }

        // The queue is empty. We are done only if the stream is exhausted too; otherwise
        // the fill loop above left a pending stream poll that will wake this task.
        // NOTE(review): with `max == 0` the fill loop never runs, so no waker is
        // registered and this `Pending` would never be woken.
        if this.stream.is_done() {
            Poll::Ready(None)
        } else {
            Poll::Pending
        }
    }
}
109
110pin_project! {
111    #[must_use = "futures do nothing unless you `.await` or polled them"]
112    pub struct Fenced<Fut: Future> {
113        #[pin]
114        inner: Fut,
115        is_fence: bool,
116    }
117}
118
119impl<Fut> Fenced<Fut>
120where
121    Fut: Future,
122{
123    pub(crate) fn new(inner: Fut, is_fence: bool) -> Self {
124        Self { inner, is_fence }
125    }
126}
127
128impl<Fut> Future for Fenced<Fut>
129where
130    Fut: Future,
131{
132    type Output = <Fut as Future>::Output;
133
134    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
135        let this = self.project();
136
137        this.inner.poll(cx)
138    }
139}
140
141impl<Fut> FusedFuture for Fenced<Fut>
142where
143    Fut: FusedFuture,
144{
145    fn is_terminated(&self) -> bool {
146        self.inner.is_terminated()
147    }
148}
149
150impl<Fut> MaybeFence for Fenced<Fut>
151where
152    Fut: Future,
153{
154    fn is_fence(&self) -> bool {
155        self.is_fence
156    }
157}
158
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    use futures::stream::StreamExt;

    use crate::{RwFutureExt, RwTryStreamExt};

    /// Checks two things:
    /// - While a fence future is still running, no future queued after it has been polled.
    /// - Every item is still yielded.
    #[tokio::test]
    async fn test_buffered_with_fence() {
        let n = 10;
        // `polled[k]` flips to `true` the first time future `k` is polled.
        let polled: Vec<Arc<Mutex<bool>>> =
            (0..n).map(|_| Arc::new(Mutex::new(false))).collect();

        let futs = (0..n).map(|i| {
            let own_flag = Arc::clone(&polled[i]);
            let all_flags = polled.clone();
            let is_fence = matches!(i, 2 | 4 | 9);
            let delay = Duration::from_millis(10 * (n - i) as u64);

            async move {
                *own_flag.lock().unwrap() = true;
                tokio::time::sleep(delay).await;
                if is_fence {
                    // Nothing behind a fence may have started while the fence runs.
                    let later_untouched = all_flags[(i + 1)..n]
                        .iter()
                        .all(|flag| !*flag.lock().unwrap());
                    assert!(later_untouched);
                }
                tokio::time::sleep(delay).await;

                Ok::<_, ()>(())
            }
            .with_fence(is_fence)
        });

        let stream = futures::stream::iter(futs)
            .map(Ok)
            .try_buffered_with_fence(4);
        assert_eq!(stream.count().await, n);
    }
}