risingwave_stream/executor/
stream_reader.rs1use std::pin::Pin;
16use std::task::Poll;
17
18use either::Either;
19use futures::stream::BoxStream;
20use futures::{Stream, StreamExt, TryStreamExt};
21
22use crate::executor::Message;
23use crate::executor::error::StreamExecutorResult;
24
25type ExecutorMessageStream = BoxStream<'static, StreamExecutorResult<Message>>;
26type StreamReaderData<M> = StreamExecutorResult<Either<Message, M>>;
27type ReaderArm<M> = BoxStream<'static, StreamReaderData<M>>;
28
29mod stream_reader_with_pause {
30 use futures::stream::{PollNext, SelectWithStrategy, select_with_strategy};
31
32 use crate::executor::stream_reader::ReaderArm;
33
34 pub(super) type StreamReaderWithPauseInner<M, const BIASED: bool> = SelectWithStrategy<
35 ReaderArm<M>,
36 ReaderArm<M>,
37 impl FnMut(&mut PollNext) -> PollNext,
38 PollNext,
39 >;
40
41 #[define_opaque(StreamReaderWithPauseInner)]
42 pub(super) fn new_inner<M, const BIASED: bool>(
43 message_stream: ReaderArm<M>,
44 data_stream: ReaderArm<M>,
45 ) -> StreamReaderWithPauseInner<M, BIASED> {
46 let strategy = if BIASED {
47 |_: &mut PollNext| PollNext::Left
48 } else {
49 |last: &mut PollNext| last.toggle()
51 };
52 select_with_strategy(message_stream, data_stream, strategy)
53 }
54}
55
56use stream_reader_with_pause::*;
57
58pub(super) struct StreamReaderWithPause<const BIASED: bool, M> {
70 inner: StreamReaderWithPauseInner<M, BIASED>,
71 paused: bool,
73}
74
75impl<const BIASED: bool, M: Send + 'static> StreamReaderWithPause<BIASED, M> {
76 pub fn new(
79 message_stream: ExecutorMessageStream,
80 data_stream: impl Stream<Item = StreamExecutorResult<M>> + Send + 'static,
81 ) -> Self {
82 let message_stream_arm = message_stream.map_ok(Either::Left).boxed();
83 let data_stream_arm = data_stream.map_ok(Either::Right).boxed();
84 let inner = new_inner(message_stream_arm, data_stream_arm);
85 Self {
86 inner,
87 paused: false,
88 }
89 }
90
91 #[allow(dead_code)]
92 pub fn only_left(message_stream: ExecutorMessageStream) -> Self {
93 Self::new(message_stream, futures::stream::empty().boxed())
94 }
95
96 pub fn replace_data_stream(
98 &mut self,
99 data_stream: impl Stream<Item = StreamExecutorResult<M>> + Send + 'static,
100 ) {
101 let barrier_receiver_arm = std::mem::replace(
103 self.inner.get_mut().0,
104 futures::stream::once(async { unreachable!("placeholder") }).boxed(),
105 );
106
107 self.inner = new_inner(
110 barrier_receiver_arm,
111 data_stream.map_ok(Either::Right).boxed(),
112 );
113 }
114
115 pub fn pause_stream(&mut self) {
117 if self.paused {
118 tracing::warn!("already paused");
119 }
120 tracing::info!("data stream paused");
121 self.paused = true;
122 }
123
124 pub fn resume_stream(&mut self) {
126 if !self.paused {
127 tracing::warn!("not paused");
128 }
129 tracing::info!("data stream resumed");
130 self.paused = false;
131 }
132}
133
134impl<const BIASED: bool, M> Stream for StreamReaderWithPause<BIASED, M> {
135 type Item = StreamReaderData<M>;
136
137 fn poll_next(
138 mut self: Pin<&mut Self>,
139 ctx: &mut std::task::Context<'_>,
140 ) -> Poll<Option<Self::Item>> {
141 if self.paused {
142 self.inner.get_mut().0.poll_next_unpin(ctx)
146 } else {
147 self.inner.poll_next_unpin(ctx)
150 }
151 }
152}
153
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::{FutureExt, pin_mut};
    use risingwave_common::array::StreamChunk;
    use risingwave_common::transaction::transaction_id::TxnId;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_dml::TableDmlHandle;
    use tokio::sync::mpsc;

    use super::*;
    use crate::executor::source::barrier_to_message_stream;
    use crate::executor::{Barrier, StreamExecutorError};

    const TEST_TRANSACTION_ID1: TxnId = 0;
    const TEST_TRANSACTION_ID2: TxnId = 1;
    const TEST_SESSION_ID: u32 = 0;
    const TEST_DML_CHANNEL_INIT_PERMITS: usize = 32768;

    /// Polls `stream` a single time without waiting.
    ///
    /// Returns `None` if the stream is pending or has ended, and panics if it yields an error.
    fn poll_ready<S, T>(stream: &mut S) -> Option<T>
    where
        S: Stream<Item = StreamExecutorResult<T>> + Unpin,
    {
        stream.next().now_or_never().flatten().map(Result::unwrap)
    }

    #[tokio::test]
    async fn test_pause_and_resume() {
        let (barrier_tx, barrier_rx) = mpsc::unbounded_channel();

        let dml_handle = TableDmlHandle::new(vec![], TEST_DML_CHANNEL_INIT_PERMITS);
        let data_stream = dml_handle
            .stream_reader()
            .into_data_stream_for_test()
            .map_err(StreamExecutorError::from);

        let mut first_writer = dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID1)
            .unwrap();
        let mut second_writer = dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID2)
            .unwrap();

        let message_stream = barrier_to_message_stream(barrier_rx).boxed();
        let reader = StreamReaderWithPause::<true, StreamChunk>::new(message_stream, data_stream);
        pin_mut!(reader);

        // Not paused: a written chunk comes out on the right arm...
        first_writer.begin().unwrap();
        first_writer
            .write_chunk(StreamChunk::default())
            .await
            .unwrap();
        assert_matches!(poll_ready(&mut reader).unwrap(), Either::Right(_));

        // ...and a barrier comes out on the left arm.
        barrier_tx
            .send(Barrier::new_test_barrier(test_epoch(1)))
            .unwrap();
        assert_matches!(poll_ready(&mut reader).unwrap(), Either::Left(_));

        // Paused: with both a barrier and a chunk available, only the barrier is yielded,
        // and the following poll finds nothing ready.
        reader.pause_stream();
        barrier_tx
            .send(Barrier::new_test_barrier(test_epoch(2)))
            .unwrap();
        second_writer.begin().unwrap();
        second_writer
            .write_chunk(StreamChunk::default())
            .await
            .unwrap();
        assert_matches!(poll_ready(&mut reader).unwrap(), Either::Left(_));
        assert!(poll_ready(&mut reader).is_none());

        // Resumed: the chunk held back during the pause is now delivered.
        reader.resume_stream();
        assert_matches!(poll_ready(&mut reader).unwrap(), Either::Right(_));
    }
}