risingwave_stream/executor/
stream_reader.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::task::Poll;
17
18use either::Either;
19use futures::stream::BoxStream;
20use futures::{Stream, StreamExt, TryStreamExt};
21
22use crate::executor::Message;
23use crate::executor::error::StreamExecutorResult;
24
25type ExecutorMessageStream = BoxStream<'static, StreamExecutorResult<Message>>;
26type StreamReaderData<M> = StreamExecutorResult<Either<Message, M>>;
27type ReaderArm<M> = BoxStream<'static, StreamReaderData<M>>;
28
29mod stream_reader_with_pause {
30    use futures::stream::{PollNext, SelectWithStrategy, select_with_strategy};
31
32    use crate::executor::stream_reader::ReaderArm;
33
34    pub(super) type StreamReaderWithPauseInner<M, const BIASED: bool> = SelectWithStrategy<
35        ReaderArm<M>,
36        ReaderArm<M>,
37        impl FnMut(&mut PollNext) -> PollNext,
38        PollNext,
39    >;
40
41    pub(super) fn new_inner<M, const BIASED: bool>(
42        message_stream: ReaderArm<M>,
43        data_stream: ReaderArm<M>,
44    ) -> StreamReaderWithPauseInner<M, BIASED> {
45        let strategy = if BIASED {
46            |_: &mut PollNext| PollNext::Left
47        } else {
48            // The poll strategy is not biased: we poll the two streams in a round robin way.
49            |last: &mut PollNext| last.toggle()
50        };
51        select_with_strategy(message_stream, data_stream, strategy)
52    }
53}
54
55use stream_reader_with_pause::*;
56
57/// [`StreamReaderWithPause`] merges two streams, with one receiving barriers (and maybe other types
58/// of messages) and the other receiving data only (no barrier). The merged stream can be paused
59/// (`StreamReaderWithPause::pause_stream`) and resumed (`StreamReaderWithPause::resume_stream`).
60/// A paused stream will not receive any data from either original stream until a barrier arrives
61/// and the stream is resumed.
62///
63/// ## Priority
64///
65/// If `BIASED` is `true`, the left-hand stream (the one receiving barriers) will get a higher
66/// priority over the right-hand one. Otherwise, the two streams will be polled in a round robin
67/// fashion.
68pub(super) struct StreamReaderWithPause<const BIASED: bool, M> {
69    inner: StreamReaderWithPauseInner<M, BIASED>,
70    /// Whether the source stream is paused.
71    paused: bool,
72}
73
74impl<const BIASED: bool, M: Send + 'static> StreamReaderWithPause<BIASED, M> {
75    /// Construct a `StreamReaderWithPause` with one stream receiving barrier messages (and maybe
76    /// other types of messages) and the other receiving data only (no barrier).
77    pub fn new(
78        message_stream: ExecutorMessageStream,
79        data_stream: impl Stream<Item = StreamExecutorResult<M>> + Send + 'static,
80    ) -> Self {
81        let message_stream_arm = message_stream.map_ok(Either::Left).boxed();
82        let data_stream_arm = data_stream.map_ok(Either::Right).boxed();
83        let inner = new_inner(message_stream_arm, data_stream_arm);
84        Self {
85            inner,
86            paused: false,
87        }
88    }
89
90    /// Replace the data stream with a new one for given `stream`. Used for split change.
91    pub fn replace_data_stream(
92        &mut self,
93        data_stream: impl Stream<Item = StreamExecutorResult<M>> + Send + 'static,
94    ) {
95        // Take the barrier receiver arm.
96        let barrier_receiver_arm = std::mem::replace(
97            self.inner.get_mut().0,
98            futures::stream::once(async { unreachable!("placeholder") }).boxed(),
99        );
100
101        // Note: create a new `SelectWithStrategy` instead of replacing the source stream arm here,
102        // to ensure the internal state of the `SelectWithStrategy` is reset. (#6300)
103        self.inner = new_inner(
104            barrier_receiver_arm,
105            data_stream.map_ok(Either::Right).boxed(),
106        );
107    }
108
109    /// Pause the data stream.
110    pub fn pause_stream(&mut self) {
111        if self.paused {
112            tracing::warn!("already paused");
113        }
114        tracing::info!("data stream paused");
115        self.paused = true;
116    }
117
118    /// Resume the data stream.
119    pub fn resume_stream(&mut self) {
120        if !self.paused {
121            tracing::warn!("not paused");
122        }
123        tracing::info!("data stream resumed");
124        self.paused = false;
125    }
126}
127
128impl<const BIASED: bool, M> Stream for StreamReaderWithPause<BIASED, M> {
129    type Item = StreamReaderData<M>;
130
131    fn poll_next(
132        mut self: Pin<&mut Self>,
133        ctx: &mut std::task::Context<'_>,
134    ) -> Poll<Option<Self::Item>> {
135        if self.paused {
136            // Note: It is safe here to poll the left arm even if it contains streaming messages
137            // other than barriers: after the upstream executor sends a `Mutation::Pause`, there
138            // should be no more message until a `Mutation::Update` and a 'Mutation::Resume`.
139            self.inner.get_mut().0.poll_next_unpin(ctx)
140        } else {
141            // TODO: We may need to prioritize the data stream (right-hand stream) after resuming
142            // from the paused state.
143            self.inner.poll_next_unpin(ctx)
144        }
145    }
146}
147
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::{FutureExt, pin_mut};
    use risingwave_common::array::StreamChunk;
    use risingwave_common::transaction::transaction_id::TxnId;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_dml::TableDmlHandle;
    use tokio::sync::mpsc;

    use super::*;
    use crate::executor::source::barrier_to_message_stream;
    use crate::executor::{Barrier, StreamExecutorError};

    const TEST_TRANSACTION_ID1: TxnId = 0;
    const TEST_TRANSACTION_ID2: TxnId = 1;
    const TEST_SESSION_ID: u32 = 0;
    const TEST_DML_CHANNEL_INIT_PERMITS: usize = 32768;

    /// Checks that chunks and barriers both flow while the reader is running, that only barriers
    /// flow while it is paused, and that a chunk written during the pause is delivered after
    /// resuming.
    #[tokio::test]
    async fn test_pause_and_resume() {
        let (barrier_tx, barrier_rx) = mpsc::unbounded_channel();
        let send_barrier = |epoch| {
            barrier_tx
                .send(Barrier::new_test_barrier(test_epoch(epoch)))
                .unwrap()
        };

        let dml_handle = TableDmlHandle::new(vec![], TEST_DML_CHANNEL_INIT_PERMITS);

        let chunk_stream = dml_handle
            .stream_reader()
            .into_data_stream_for_test()
            .map_err(StreamExecutorError::from);

        let mut first_writer = dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID1)
            .unwrap();
        let mut second_writer = dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID2)
            .unwrap();

        let message_stream = barrier_to_message_stream(barrier_rx).boxed();
        let reader = StreamReaderWithPause::<true, StreamChunk>::new(message_stream, chunk_stream);
        pin_mut!(reader);

        // Poll the reader exactly once without waiting: `None` means nothing is ready (or the
        // stream has ended), `Some(item)` is the unwrapped successful item.
        macro_rules! poll_once {
            () => {
                reader
                    .next()
                    .now_or_never()
                    .flatten()
                    .map(Result::unwrap)
            };
        }

        // A chunk written while running comes out on the data side.
        first_writer.begin().unwrap();
        first_writer
            .write_chunk(StreamChunk::default())
            .await
            .unwrap();
        // `end()` is deliberately not called: this test observes `StreamChunk`, not `TxnMsg`.

        assert_matches!(poll_once!().unwrap(), Either::Right(_));

        // A barrier sent while running comes out on the message side.
        send_barrier(1);
        assert_matches!(poll_once!().unwrap(), Either::Left(_));

        // Pause, then queue a barrier followed by a chunk.
        reader.pause_stream();

        send_barrier(2);

        second_writer.begin().unwrap();
        second_writer
            .write_chunk(StreamChunk::default())
            .await
            .unwrap();
        // `end()` is deliberately not called: this test observes `StreamChunk`, not `TxnMsg`.

        // The barrier still gets through while paused...
        assert_matches!(poll_once!().unwrap(), Either::Left(_));
        // ...but the chunk does not.
        assert!(poll_once!().is_none());

        // After resuming, the chunk written during the pause is delivered.
        reader.resume_stream();
        assert_matches!(poll_once!().unwrap(), Either::Right(_));
    }
}