risingwave_stream/executor/
stream_reader.rs1use std::pin::Pin;
16use std::task::Poll;
17
18use either::Either;
19use futures::stream::BoxStream;
20use futures::{Stream, StreamExt, TryStreamExt};
21
22use crate::executor::Message;
23use crate::executor::error::StreamExecutorResult;
24
25type ExecutorMessageStream = BoxStream<'static, StreamExecutorResult<Message>>;
26type StreamReaderData<M> = StreamExecutorResult<Either<Message, M>>;
27type ReaderArm<M> = BoxStream<'static, StreamReaderData<M>>;
28
29mod stream_reader_with_pause {
30 use futures::stream::{PollNext, SelectWithStrategy, select_with_strategy};
31
32 use crate::executor::stream_reader::ReaderArm;
33
34 pub(super) type StreamReaderWithPauseInner<M, const BIASED: bool> = SelectWithStrategy<
35 ReaderArm<M>,
36 ReaderArm<M>,
37 impl FnMut(&mut PollNext) -> PollNext,
38 PollNext,
39 >;
40
41 #[define_opaque(StreamReaderWithPauseInner)]
42 pub(super) fn new_inner<M, const BIASED: bool>(
43 message_stream: ReaderArm<M>,
44 data_stream: ReaderArm<M>,
45 ) -> StreamReaderWithPauseInner<M, BIASED> {
46 let strategy = if BIASED {
47 |_: &mut PollNext| PollNext::Left
48 } else {
49 |last: &mut PollNext| last.toggle()
51 };
52 select_with_strategy(message_stream, data_stream, strategy)
53 }
54}
55
56use stream_reader_with_pause::*;
57
58pub(super) struct StreamReaderWithPause<const BIASED: bool, M> {
70 inner: StreamReaderWithPauseInner<M, BIASED>,
71 paused: bool,
73}
74
75impl<const BIASED: bool, M: Send + 'static> StreamReaderWithPause<BIASED, M> {
76 pub fn new(
79 message_stream: ExecutorMessageStream,
80 data_stream: impl Stream<Item = StreamExecutorResult<M>> + Send + 'static,
81 ) -> Self {
82 let message_stream_arm = message_stream.map_ok(Either::Left).boxed();
83 let data_stream_arm = data_stream.map_ok(Either::Right).boxed();
84 let inner = new_inner(message_stream_arm, data_stream_arm);
85 Self {
86 inner,
87 paused: false,
88 }
89 }
90
91 pub fn replace_data_stream(
93 &mut self,
94 data_stream: impl Stream<Item = StreamExecutorResult<M>> + Send + 'static,
95 ) {
96 let barrier_receiver_arm = std::mem::replace(
98 self.inner.get_mut().0,
99 futures::stream::once(async { unreachable!("placeholder") }).boxed(),
100 );
101
102 self.inner = new_inner(
105 barrier_receiver_arm,
106 data_stream.map_ok(Either::Right).boxed(),
107 );
108 }
109
110 pub fn pause_stream(&mut self) {
112 if self.paused {
113 tracing::warn!("already paused");
114 }
115 tracing::info!("data stream paused");
116 self.paused = true;
117 }
118
119 pub fn resume_stream(&mut self) {
121 if !self.paused {
122 tracing::warn!("not paused");
123 }
124 tracing::info!("data stream resumed");
125 self.paused = false;
126 }
127}
128
129impl<const BIASED: bool, M> Stream for StreamReaderWithPause<BIASED, M> {
130 type Item = StreamReaderData<M>;
131
132 fn poll_next(
133 mut self: Pin<&mut Self>,
134 ctx: &mut std::task::Context<'_>,
135 ) -> Poll<Option<Self::Item>> {
136 if self.paused {
137 self.inner.get_mut().0.poll_next_unpin(ctx)
141 } else {
142 self.inner.poll_next_unpin(ctx)
145 }
146 }
147}
148
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::{FutureExt, pin_mut};
    use risingwave_common::array::StreamChunk;
    use risingwave_common::transaction::transaction_id::TxnId;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_dml::TableDmlHandle;
    use tokio::sync::mpsc;

    use super::*;
    use crate::executor::source::barrier_to_message_stream;
    use crate::executor::{Barrier, StreamExecutorError};

    const TEST_TRANSACTION_ID1: TxnId = 0;
    const TEST_TRANSACTION_ID2: TxnId = 1;
    const TEST_SESSION_ID: u32 = 0;
    const TEST_DML_CHANNEL_INIT_PERMITS: usize = 32768;

    /// Pausing stops data items from being yielded while messages keep flowing; resuming makes
    /// the pending data item visible again.
    #[tokio::test]
    async fn test_pause_and_resume() {
        let (barrier_tx, barrier_rx) = mpsc::unbounded_channel();

        let table_dml_handle = TableDmlHandle::new(vec![], TEST_DML_CHANNEL_INIT_PERMITS);

        // Data arm: chunks written through the DML handle's write handles.
        let source_stream = table_dml_handle
            .stream_reader()
            .into_data_stream_for_test()
            .map_err(StreamExecutorError::from);

        let mut write_handle1 = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID1)
            .unwrap();
        let mut write_handle2 = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID2)
            .unwrap();

        // Message arm: barriers sent on `barrier_tx`. `BIASED = true` polls this arm first.
        let barrier_stream = barrier_to_message_stream(barrier_rx).boxed();
        let stream = StreamReaderWithPause::<true, StreamChunk>::new(barrier_stream, source_stream);
        pin_mut!(stream);

        // Polls the reader once without waiting. Evaluates to `None` if nothing is ready (or the
        // stream has ended), otherwise to the unwrapped item (panicking on an error item).
        macro_rules! next {
            () => {
                stream
                    .next()
                    .now_or_never()
                    .flatten()
                    .map(|result| result.unwrap())
            };
        }

        // With no barrier pending, a written chunk comes out of the data arm.
        write_handle1.begin().unwrap();
        write_handle1
            .write_chunk(StreamChunk::default())
            .await
            .unwrap();
        assert_matches!(next!().unwrap(), Either::Right(_));
        // A sent barrier comes out of the message arm.
        barrier_tx
            .send(Barrier::new_test_barrier(test_epoch(1)))
            .unwrap();
        assert_matches!(next!().unwrap(), Either::Left(_));

        stream.pause_stream();

        // While paused, send a barrier and write another chunk.
        barrier_tx
            .send(Barrier::new_test_barrier(test_epoch(2)))
            .unwrap();
        write_handle2.begin().unwrap();
        write_handle2
            .write_chunk(StreamChunk::default())
            .await
            .unwrap();
        // The barrier is still delivered...
        assert_matches!(next!().unwrap(), Either::Left(_));
        // ...but the chunk is not, because the data arm is not polled while paused.
        assert!(next!().is_none());

        stream.resume_stream();
        // After resuming, the pending chunk is delivered.
        assert_matches!(next!().unwrap(), Either::Right(_));
    }
}