rw_futures_util/
buffered_with_fence.rs1use std::future::Future;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use futures::TryFutureExt;
20use futures::future::{FusedFuture, IntoFuture, TryFuture};
21use futures::stream::{
22 Fuse, FuturesOrdered, IntoStream, Stream, StreamExt, TryStream, TryStreamExt,
23};
24use pin_project_lite::pin_project;
25
/// Lets an item of a [`TryBufferedWithFence`] source declare itself a fence.
///
/// After a fence has been queued, the adapter stops pulling further items from
/// its source until every queued item (the fence included) has been yielded.
/// Implementors that never act as fences can rely on the default, which
/// reports `false`.
pub trait MaybeFence {
    /// Returns `true` when this item is a fence. Defaults to `false`.
    fn is_fence(&self) -> bool {
        false
    }
}
31
32pin_project! {
33 #[derive(Debug)]
34 #[must_use = "streams do nothing unless polled"]
35 pub struct TryBufferedWithFence<St>
36 where
37 St: TryStream,
38 St::Ok: TryFuture,
39 {
40 #[pin]
41 stream: Fuse<IntoStream<St>>,
42 in_progress_queue: FuturesOrdered<IntoFuture<St::Ok>>,
43 syncing: bool,
44 max: usize,
45 }
46}
47
48impl<St> TryBufferedWithFence<St>
49where
50 St: TryStream,
51 St::Ok: TryFuture<Error = St::Error> + MaybeFence,
52{
53 pub(crate) fn new(stream: St, n: usize) -> Self {
54 Self {
55 stream: stream.into_stream().fuse(),
56 in_progress_queue: FuturesOrdered::new(),
57 syncing: false,
58 max: n,
59 }
60 }
61}
62
63impl<St> Stream for TryBufferedWithFence<St>
64where
65 St: TryStream,
66 St::Ok: TryFuture<Error = St::Error> + MaybeFence,
67{
68 type Item = Result<<St::Ok as TryFuture>::Ok, St::Error>;
69
70 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
71 let mut this = self.project();
72
73 if *this.syncing && this.in_progress_queue.is_empty() {
74 *this.syncing = false;
75 }
76
77 while !*this.syncing && this.in_progress_queue.len() < *this.max {
80 match this.stream.as_mut().poll_next(cx)? {
81 Poll::Ready(Some(fut)) => {
82 let is_fence = fut.is_fence();
83 this.in_progress_queue
84 .push_back(TryFutureExt::into_future(fut));
85 if is_fence {
86 *this.syncing = true;
88 break;
89 }
90 }
91 Poll::Ready(None) | Poll::Pending => break,
92 }
93 }
94
95 match this.in_progress_queue.poll_next_unpin(cx) {
97 x @ Poll::Pending | x @ Poll::Ready(Some(_)) => return x,
98 Poll::Ready(None) => {}
99 }
100
101 if this.stream.is_done() {
103 Poll::Ready(None)
104 } else {
105 Poll::Pending
106 }
107 }
108}
109
110pin_project! {
111 #[must_use = "futures do nothing unless you `.await` or polled them"]
112 pub struct Fenced<Fut: Future> {
113 #[pin]
114 inner: Fut,
115 is_fence: bool,
116 }
117}
118
119impl<Fut> Fenced<Fut>
120where
121 Fut: Future,
122{
123 pub(crate) fn new(inner: Fut, is_fence: bool) -> Self {
124 Self { inner, is_fence }
125 }
126}
127
128impl<Fut> Future for Fenced<Fut>
129where
130 Fut: Future,
131{
132 type Output = <Fut as Future>::Output;
133
134 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
135 let this = self.project();
136
137 this.inner.poll(cx)
138 }
139}
140
141impl<Fut> FusedFuture for Fenced<Fut>
142where
143 Fut: FusedFuture,
144{
145 fn is_terminated(&self) -> bool {
146 self.inner.is_terminated()
147 }
148}
149
150impl<Fut> MaybeFence for Fenced<Fut>
151where
152 Fut: Future,
153{
154 fn is_fence(&self) -> bool {
155 self.is_fence
156 }
157}
158
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    use futures::stream::StreamExt;

    use crate::{RwFutureExt, RwTryStreamExt};

    /// Runs ten futures through `try_buffered_with_fence(4)` with fences at
    /// indices 2, 4 and 9. Each fence checks, after its first sleep, that no
    /// future queued after it has been polled yet. Finally, every item must
    /// be yielded.
    #[tokio::test]
    async fn test_buffered_with_fence() {
        let n = 10;
        let fence_indices = [2, 4, 9];
        let started: Vec<Arc<Mutex<bool>>> =
            (0..n).map(|_| Arc::new(Mutex::new(false))).collect();

        let futs = (0..n).map(|i| {
            let own_flag = Arc::clone(&started[i]);
            let all_flags = started.clone();
            let is_fence = fence_indices.contains(&i);
            // Later futures sleep for shorter periods.
            let delay = Duration::from_millis(10 * (n - i) as u64);

            async move {
                // Record that this future has been polled; the guard is a
                // statement temporary, so it is released before any `.await`.
                *own_flag.lock().unwrap() = true;
                tokio::time::sleep(delay).await;
                if is_fence {
                    let later_untouched = all_flags[(i + 1)..n].iter().all(|flag| {
                        let flag = flag.lock().unwrap();
                        !*flag
                    });
                    assert!(later_untouched);
                }
                tokio::time::sleep(delay).await;

                Ok::<_, ()>(())
            }
            .with_fence(is_fence)
        });

        let yielded = futures::stream::iter(futs)
            .map(Ok)
            .try_buffered_with_fence(4)
            .count()
            .await;
        assert_eq!(yielded, n);
    }
}