risingwave_stream/executor/
stream_reader.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::task::Poll;
17
18use either::Either;
19use futures::stream::BoxStream;
20use futures::{Stream, StreamExt, TryStreamExt};
21
22use crate::executor::Message;
23use crate::executor::error::StreamExecutorResult;
24
25type ExecutorMessageStream = BoxStream<'static, StreamExecutorResult<Message>>;
26type StreamReaderData<M> = StreamExecutorResult<Either<Message, M>>;
27type ReaderArm<M> = BoxStream<'static, StreamReaderData<M>>;
28
29mod stream_reader_with_pause {
30    use futures::stream::{PollNext, SelectWithStrategy, select_with_strategy};
31
32    use crate::executor::stream_reader::ReaderArm;
33
34    pub(super) type StreamReaderWithPauseInner<M, const BIASED: bool> = SelectWithStrategy<
35        ReaderArm<M>,
36        ReaderArm<M>,
37        impl FnMut(&mut PollNext) -> PollNext,
38        PollNext,
39    >;
40
41    #[define_opaque(StreamReaderWithPauseInner)]
42    pub(super) fn new_inner<M, const BIASED: bool>(
43        message_stream: ReaderArm<M>,
44        data_stream: ReaderArm<M>,
45    ) -> StreamReaderWithPauseInner<M, BIASED> {
46        let strategy = if BIASED {
47            |_: &mut PollNext| PollNext::Left
48        } else {
49            // The poll strategy is not biased: we poll the two streams in a round robin way.
50            |last: &mut PollNext| last.toggle()
51        };
52        select_with_strategy(message_stream, data_stream, strategy)
53    }
54}
55
56use stream_reader_with_pause::*;
57
58/// [`StreamReaderWithPause`] merges two streams, with one receiving barriers (and maybe other types
59/// of messages) and the other receiving data only (no barrier). The merged stream can be paused
60/// (`StreamReaderWithPause::pause_stream`) and resumed (`StreamReaderWithPause::resume_stream`).
61/// A paused stream will not receive any data from either original stream until a barrier arrives
62/// and the stream is resumed.
63///
64/// ## Priority
65///
66/// If `BIASED` is `true`, the left-hand stream (the one receiving barriers) will get a higher
67/// priority over the right-hand one. Otherwise, the two streams will be polled in a round robin
68/// fashion.
69pub(super) struct StreamReaderWithPause<const BIASED: bool, M> {
70    inner: StreamReaderWithPauseInner<M, BIASED>,
71    /// Whether the source stream is paused.
72    paused: bool,
73}
74
75impl<const BIASED: bool, M: Send + 'static> StreamReaderWithPause<BIASED, M> {
76    /// Construct a `StreamReaderWithPause` with one stream receiving barrier messages (and maybe
77    /// other types of messages) and the other receiving data only (no barrier).
78    pub fn new(
79        message_stream: ExecutorMessageStream,
80        data_stream: impl Stream<Item = StreamExecutorResult<M>> + Send + 'static,
81    ) -> Self {
82        let message_stream_arm = message_stream.map_ok(Either::Left).boxed();
83        let data_stream_arm = data_stream.map_ok(Either::Right).boxed();
84        let inner = new_inner(message_stream_arm, data_stream_arm);
85        Self {
86            inner,
87            paused: false,
88        }
89    }
90
91    /// Replace the data stream with a new one for given `stream`. Used for split change.
92    pub fn replace_data_stream(
93        &mut self,
94        data_stream: impl Stream<Item = StreamExecutorResult<M>> + Send + 'static,
95    ) {
96        // Take the barrier receiver arm.
97        let barrier_receiver_arm = std::mem::replace(
98            self.inner.get_mut().0,
99            futures::stream::once(async { unreachable!("placeholder") }).boxed(),
100        );
101
102        // Note: create a new `SelectWithStrategy` instead of replacing the source stream arm here,
103        // to ensure the internal state of the `SelectWithStrategy` is reset. (#6300)
104        self.inner = new_inner(
105            barrier_receiver_arm,
106            data_stream.map_ok(Either::Right).boxed(),
107        );
108    }
109
110    /// Pause the data stream.
111    pub fn pause_stream(&mut self) {
112        if self.paused {
113            tracing::warn!("already paused");
114        }
115        tracing::info!("data stream paused");
116        self.paused = true;
117    }
118
119    /// Resume the data stream.
120    pub fn resume_stream(&mut self) {
121        if !self.paused {
122            tracing::warn!("not paused");
123        }
124        tracing::info!("data stream resumed");
125        self.paused = false;
126    }
127}
128
129impl<const BIASED: bool, M> Stream for StreamReaderWithPause<BIASED, M> {
130    type Item = StreamReaderData<M>;
131
132    fn poll_next(
133        mut self: Pin<&mut Self>,
134        ctx: &mut std::task::Context<'_>,
135    ) -> Poll<Option<Self::Item>> {
136        if self.paused {
137            // Note: It is safe here to poll the left arm even if it contains streaming messages
138            // other than barriers: after the upstream executor sends a `Mutation::Pause`, there
139            // should be no more message until a `Mutation::Update` and a 'Mutation::Resume`.
140            self.inner.get_mut().0.poll_next_unpin(ctx)
141        } else {
142            // TODO: We may need to prioritize the data stream (right-hand stream) after resuming
143            // from the paused state.
144            self.inner.poll_next_unpin(ctx)
145        }
146    }
147}
148
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::{FutureExt, pin_mut};
    use risingwave_common::array::StreamChunk;
    use risingwave_common::transaction::transaction_id::TxnId;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_dml::TableDmlHandle;
    use tokio::sync::mpsc;

    use super::*;
    use crate::executor::source::barrier_to_message_stream;
    use crate::executor::{Barrier, StreamExecutorError};

    const TEST_TRANSACTION_ID1: TxnId = 0;
    const TEST_TRANSACTION_ID2: TxnId = 1;
    const TEST_SESSION_ID: u32 = 0;
    const TEST_DML_CHANNEL_INIT_PERMITS: usize = 32768;

    /// Checks that a paused reader still delivers barriers but holds back data chunks, and
    /// that a chunk written during the pause is delivered once the reader is resumed.
    #[tokio::test]
    async fn test_pause_and_resume() {
        let (barrier_sender, barrier_receiver) = mpsc::unbounded_channel();

        let dml_handle = TableDmlHandle::new(vec![], TEST_DML_CHANNEL_INIT_PERMITS);

        let data_stream = dml_handle
            .stream_reader()
            .into_data_stream_for_test()
            .map_err(StreamExecutorError::from);

        let mut first_writer = dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID1)
            .unwrap();
        let mut second_writer = dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID2)
            .unwrap();

        let message_stream = barrier_to_message_stream(barrier_receiver).boxed();
        let reader = StreamReaderWithPause::<true, StreamChunk>::new(message_stream, data_stream);
        pin_mut!(reader);

        // Polls the reader exactly once without waiting: `None` when nothing is ready (or the
        // stream has ended), otherwise the unwrapped item.
        macro_rules! poll_once {
            () => {
                reader
                    .next()
                    .now_or_never()
                    .flatten()
                    .map(Result::unwrap)
            };
        }

        // A chunk written while running must come out on the data side.

        first_writer.begin().unwrap();
        first_writer
            .write_chunk(StreamChunk::default())
            .await
            .unwrap();
        // `end()` is deliberately not called: this test looks at `StreamChunk`, not `TxnMsg`.

        assert_matches!(poll_once!().unwrap(), Either::Right(_));
        // A barrier sent while running must come out on the message side.
        barrier_sender
            .send(Barrier::new_test_barrier(test_epoch(1)))
            .unwrap();
        assert_matches!(poll_once!().unwrap(), Either::Left(_));

        // Now pause the reader.
        reader.pause_stream();

        // Send a barrier first...
        barrier_sender
            .send(Barrier::new_test_barrier(test_epoch(2)))
            .unwrap();
        // ...and then write a chunk.
        second_writer.begin().unwrap();
        second_writer
            .write_chunk(StreamChunk::default())
            .await
            .unwrap();
        // `end()` is deliberately not called: this test looks at `StreamChunk`, not `TxnMsg`.

        // The barrier still gets through while paused.
        assert_matches!(poll_once!().unwrap(), Either::Left(_));
        // The chunk does not.
        assert!(poll_once!().is_none());

        // Resume the reader.
        reader.resume_stream();
        // The chunk written during the pause is now delivered.
        assert_matches!(poll_once!().unwrap(), Either::Right(_));
    }
}
241}