risingwave_stream/executor/
stream_reader.rs1use std::pin::Pin;
16use std::task::Poll;
17
18use either::Either;
19use futures::stream::BoxStream;
20use futures::{Stream, StreamExt, TryStreamExt};
21
22use crate::executor::Message;
23use crate::executor::error::StreamExecutorResult;
24
25type ExecutorMessageStream = BoxStream<'static, StreamExecutorResult<Message>>;
26type StreamReaderData<M> = StreamExecutorResult<Either<Message, M>>;
27type ReaderArm<M> = BoxStream<'static, StreamReaderData<M>>;
28
29mod stream_reader_with_pause {
30 use futures::stream::{PollNext, SelectWithStrategy, select_with_strategy};
31
32 use crate::executor::stream_reader::ReaderArm;
33
34 pub(super) type StreamReaderWithPauseInner<M, const BIASED: bool> = SelectWithStrategy<
35 ReaderArm<M>,
36 ReaderArm<M>,
37 impl FnMut(&mut PollNext) -> PollNext,
38 PollNext,
39 >;
40
41 pub(super) fn new_inner<M, const BIASED: bool>(
42 message_stream: ReaderArm<M>,
43 data_stream: ReaderArm<M>,
44 ) -> StreamReaderWithPauseInner<M, BIASED> {
45 let strategy = if BIASED {
46 |_: &mut PollNext| PollNext::Left
47 } else {
48 |last: &mut PollNext| last.toggle()
50 };
51 select_with_strategy(message_stream, data_stream, strategy)
52 }
53}
54
55use stream_reader_with_pause::*;
56
57pub(super) struct StreamReaderWithPause<const BIASED: bool, M> {
69 inner: StreamReaderWithPauseInner<M, BIASED>,
70 paused: bool,
72}
73
74impl<const BIASED: bool, M: Send + 'static> StreamReaderWithPause<BIASED, M> {
75 pub fn new(
78 message_stream: ExecutorMessageStream,
79 data_stream: impl Stream<Item = StreamExecutorResult<M>> + Send + 'static,
80 ) -> Self {
81 let message_stream_arm = message_stream.map_ok(Either::Left).boxed();
82 let data_stream_arm = data_stream.map_ok(Either::Right).boxed();
83 let inner = new_inner(message_stream_arm, data_stream_arm);
84 Self {
85 inner,
86 paused: false,
87 }
88 }
89
90 pub fn replace_data_stream(
92 &mut self,
93 data_stream: impl Stream<Item = StreamExecutorResult<M>> + Send + 'static,
94 ) {
95 let barrier_receiver_arm = std::mem::replace(
97 self.inner.get_mut().0,
98 futures::stream::once(async { unreachable!("placeholder") }).boxed(),
99 );
100
101 self.inner = new_inner(
104 barrier_receiver_arm,
105 data_stream.map_ok(Either::Right).boxed(),
106 );
107 }
108
109 pub fn pause_stream(&mut self) {
111 if self.paused {
112 tracing::warn!("already paused");
113 }
114 tracing::info!("data stream paused");
115 self.paused = true;
116 }
117
118 pub fn resume_stream(&mut self) {
120 if !self.paused {
121 tracing::warn!("not paused");
122 }
123 tracing::info!("data stream resumed");
124 self.paused = false;
125 }
126}
127
128impl<const BIASED: bool, M> Stream for StreamReaderWithPause<BIASED, M> {
129 type Item = StreamReaderData<M>;
130
131 fn poll_next(
132 mut self: Pin<&mut Self>,
133 ctx: &mut std::task::Context<'_>,
134 ) -> Poll<Option<Self::Item>> {
135 if self.paused {
136 self.inner.get_mut().0.poll_next_unpin(ctx)
140 } else {
141 self.inner.poll_next_unpin(ctx)
144 }
145 }
146}
147
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::{FutureExt, pin_mut};
    use risingwave_common::array::StreamChunk;
    use risingwave_common::transaction::transaction_id::TxnId;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_dml::TableDmlHandle;
    use tokio::sync::mpsc;

    use super::*;
    use crate::executor::source::barrier_to_message_stream;
    use crate::executor::{Barrier, StreamExecutorError};

    const TEST_TRANSACTION_ID1: TxnId = 0;
    const TEST_TRANSACTION_ID2: TxnId = 1;
    const TEST_SESSION_ID: u32 = 0;
    const TEST_DML_CHANNEL_INIT_PERMITS: usize = 32768;

    /// Polls `stream` exactly once without waiting.
    ///
    /// Returns `None` when the stream is not ready or has ended, and panics if it yields an error.
    fn next_ready<S>(stream: &mut S) -> Option<Either<Message, StreamChunk>>
    where
        S: Stream<Item = StreamReaderData<StreamChunk>> + Unpin,
    {
        stream
            .next()
            .now_or_never()
            .flatten()
            .map(|result| result.unwrap())
    }

    #[tokio::test]
    async fn test_pause_and_resume() {
        let (barrier_tx, barrier_rx) = mpsc::unbounded_channel();

        let table_dml_handle = TableDmlHandle::new(vec![], TEST_DML_CHANNEL_INIT_PERMITS);
        let source_stream = table_dml_handle
            .stream_reader()
            .into_data_stream_for_test()
            .map_err(StreamExecutorError::from);

        let mut first_writer = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID1)
            .unwrap();
        let mut second_writer = table_dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID2)
            .unwrap();

        let barrier_stream = barrier_to_message_stream(barrier_rx).boxed();
        let stream = StreamReaderWithPause::<true, StreamChunk>::new(barrier_stream, source_stream);
        pin_mut!(stream);

        // Running: a data chunk comes out of the right arm...
        first_writer.begin().unwrap();
        first_writer
            .write_chunk(StreamChunk::default())
            .await
            .unwrap();
        assert_matches!(next_ready(&mut stream).unwrap(), Either::Right(_));

        // ...and a barrier comes out of the left arm.
        barrier_tx
            .send(Barrier::new_test_barrier(test_epoch(1)))
            .unwrap();
        assert_matches!(next_ready(&mut stream).unwrap(), Either::Left(_));

        stream.pause_stream();

        // Paused: the barrier is still delivered, but the freshly written chunk is not.
        barrier_tx
            .send(Barrier::new_test_barrier(test_epoch(2)))
            .unwrap();
        second_writer.begin().unwrap();
        second_writer
            .write_chunk(StreamChunk::default())
            .await
            .unwrap();
        assert_matches!(next_ready(&mut stream).unwrap(), Either::Left(_));
        assert!(next_ready(&mut stream).is_none());

        stream.resume_stream();

        // Resumed: the held-back chunk is delivered.
        assert_matches!(next_ready(&mut stream).unwrap(), Either::Right(_));
    }
}