risingwave_stream/executor/
stream_reader.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::task::Poll;
17
18use either::Either;
19use futures::stream::BoxStream;
20use futures::{Stream, StreamExt, TryStreamExt};
21
22use crate::executor::Message;
23use crate::executor::error::StreamExecutorResult;
24
25type ExecutorMessageStream = BoxStream<'static, StreamExecutorResult<Message>>;
26type StreamReaderData<M> = StreamExecutorResult<Either<Message, M>>;
27type ReaderArm<M> = BoxStream<'static, StreamReaderData<M>>;
28
29mod stream_reader_with_pause {
30    use futures::stream::{PollNext, SelectWithStrategy, select_with_strategy};
31
32    use crate::executor::stream_reader::ReaderArm;
33
34    pub(super) type StreamReaderWithPauseInner<M, const BIASED: bool> = SelectWithStrategy<
35        ReaderArm<M>,
36        ReaderArm<M>,
37        impl FnMut(&mut PollNext) -> PollNext,
38        PollNext,
39    >;
40
41    #[define_opaque(StreamReaderWithPauseInner)]
42    pub(super) fn new_inner<M, const BIASED: bool>(
43        message_stream: ReaderArm<M>,
44        data_stream: ReaderArm<M>,
45    ) -> StreamReaderWithPauseInner<M, BIASED> {
46        let strategy = if BIASED {
47            |_: &mut PollNext| PollNext::Left
48        } else {
49            // The poll strategy is not biased: we poll the two streams in a round robin way.
50            |last: &mut PollNext| last.toggle()
51        };
52        select_with_strategy(message_stream, data_stream, strategy)
53    }
54}
55
56use stream_reader_with_pause::*;
57
58/// [`StreamReaderWithPause`] merges two streams, with one receiving barriers (and maybe other types
59/// of messages) and the other receiving data only (no barrier). The merged stream can be paused
60/// (`StreamReaderWithPause::pause_stream`) and resumed (`StreamReaderWithPause::resume_stream`).
61/// A paused stream will not receive any data from either original stream until a barrier arrives
62/// and the stream is resumed.
63///
64/// ## Priority
65///
66/// If `BIASED` is `true`, the left-hand stream (the one receiving barriers) will get a higher
67/// priority over the right-hand one. Otherwise, the two streams will be polled in a round robin
68/// fashion.
69pub(super) struct StreamReaderWithPause<const BIASED: bool, M> {
70    inner: StreamReaderWithPauseInner<M, BIASED>,
71    /// Whether the source stream is paused.
72    paused: bool,
73}
74
75impl<const BIASED: bool, M: Send + 'static> StreamReaderWithPause<BIASED, M> {
76    /// Construct a `StreamReaderWithPause` with one stream receiving barrier messages (and maybe
77    /// other types of messages) and the other receiving data only (no barrier).
78    pub fn new(
79        message_stream: ExecutorMessageStream,
80        data_stream: impl Stream<Item = StreamExecutorResult<M>> + Send + 'static,
81    ) -> Self {
82        let message_stream_arm = message_stream.map_ok(Either::Left).boxed();
83        let data_stream_arm = data_stream.map_ok(Either::Right).boxed();
84        let inner = new_inner(message_stream_arm, data_stream_arm);
85        Self {
86            inner,
87            paused: false,
88        }
89    }
90
91    #[allow(dead_code)]
92    pub fn only_left(message_stream: ExecutorMessageStream) -> Self {
93        Self::new(message_stream, futures::stream::empty().boxed())
94    }
95
96    /// Replace the data stream with a new one for given `stream`. Used for split change.
97    pub fn replace_data_stream(
98        &mut self,
99        data_stream: impl Stream<Item = StreamExecutorResult<M>> + Send + 'static,
100    ) {
101        // Take the barrier receiver arm.
102        let barrier_receiver_arm = std::mem::replace(
103            self.inner.get_mut().0,
104            futures::stream::once(async { unreachable!("placeholder") }).boxed(),
105        );
106
107        // Note: create a new `SelectWithStrategy` instead of replacing the source stream arm here,
108        // to ensure the internal state of the `SelectWithStrategy` is reset. (#6300)
109        self.inner = new_inner(
110            barrier_receiver_arm,
111            data_stream.map_ok(Either::Right).boxed(),
112        );
113    }
114
115    /// Pause the data stream.
116    pub fn pause_stream(&mut self) {
117        if self.paused {
118            tracing::warn!("already paused");
119        }
120        tracing::info!("data stream paused");
121        self.paused = true;
122    }
123
124    /// Resume the data stream.
125    pub fn resume_stream(&mut self) {
126        if !self.paused {
127            tracing::warn!("not paused");
128        }
129        tracing::info!("data stream resumed");
130        self.paused = false;
131    }
132}
133
134impl<const BIASED: bool, M> Stream for StreamReaderWithPause<BIASED, M> {
135    type Item = StreamReaderData<M>;
136
137    fn poll_next(
138        mut self: Pin<&mut Self>,
139        ctx: &mut std::task::Context<'_>,
140    ) -> Poll<Option<Self::Item>> {
141        if self.paused {
142            // Note: It is safe here to poll the left arm even if it contains streaming messages
143            // other than barriers: after the upstream executor sends a `Mutation::Pause`, there
144            // should be no more message until a `Mutation::Update` and a 'Mutation::Resume`.
145            self.inner.get_mut().0.poll_next_unpin(ctx)
146        } else {
147            // TODO: We may need to prioritize the data stream (right-hand stream) after resuming
148            // from the paused state.
149            self.inner.poll_next_unpin(ctx)
150        }
151    }
152}
153
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::{FutureExt, pin_mut};
    use risingwave_common::array::StreamChunk;
    use risingwave_common::transaction::transaction_id::TxnId;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_dml::TableDmlHandle;
    use tokio::sync::mpsc;

    use super::*;
    use crate::executor::source::barrier_to_message_stream;
    use crate::executor::{Barrier, StreamExecutorError};

    const TEST_TRANSACTION_ID1: TxnId = 0;
    const TEST_TRANSACTION_ID2: TxnId = 1;
    const TEST_SESSION_ID: u32 = 0;
    const TEST_DML_CHANNEL_INIT_PERMITS: usize = 32768;

    /// Checks that a paused reader still yields barriers but holds back data chunks, and that
    /// the held-back chunk is delivered once the reader is resumed.
    #[tokio::test]
    async fn test_pause_and_resume() {
        let (barrier_tx, barrier_rx) = mpsc::unbounded_channel();

        let dml_handle = TableDmlHandle::new(vec![], TEST_DML_CHANNEL_INIT_PERMITS);

        let data_stream = dml_handle
            .stream_reader()
            .into_data_stream_for_test()
            .map_err(StreamExecutorError::from);

        let mut first_writer = dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID1)
            .unwrap();
        let mut second_writer = dml_handle
            .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID2)
            .unwrap();

        let reader = StreamReaderWithPause::<true, StreamChunk>::new(
            barrier_to_message_stream(barrier_rx).boxed(),
            data_stream,
        );
        pin_mut!(reader);

        // Sends a test barrier for the given test epoch into the barrier channel.
        let send_barrier = |epoch| {
            barrier_tx
                .send(Barrier::new_test_barrier(test_epoch(epoch)))
                .unwrap()
        };

        // Polls the reader once without waiting: `None` if nothing is ready right now,
        // otherwise the unwrapped item.
        macro_rules! poll_ready {
            () => {
                reader
                    .next()
                    .now_or_never()
                    .flatten()
                    .map(|item| item.unwrap())
            };
        }

        // Write a chunk, and we should receive it.
        // We don't call end() here, since we test `StreamChunk` instead of `TxnMsg`.
        first_writer.begin().unwrap();
        first_writer
            .write_chunk(StreamChunk::default())
            .await
            .unwrap();
        assert_matches!(poll_ready!().unwrap(), Either::Right(_));

        // Write a barrier, and we should receive it.
        send_barrier(1);
        assert_matches!(poll_ready!().unwrap(), Either::Left(_));

        // Pause the stream.
        reader.pause_stream();

        // Write a barrier first, then a chunk.
        // Again no end(), since we test `StreamChunk` instead of `TxnMsg`.
        send_barrier(2);
        second_writer.begin().unwrap();
        second_writer
            .write_chunk(StreamChunk::default())
            .await
            .unwrap();

        // We should receive the barrier.
        assert_matches!(poll_ready!().unwrap(), Either::Left(_));
        // We shouldn't receive the chunk.
        assert!(poll_ready!().is_none());

        // Resume the stream.
        reader.resume_stream();
        // Then we can receive the chunk sent when the stream is paused.
        assert_matches!(poll_ready!().unwrap(), Either::Right(_));
    }
}
246}