// risingwave_stream/executor/wrapper/epoch_check.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::sync::Arc;
16
17use futures::{StreamExt, pin_mut};
18use futures_async_stream::try_stream;
19
20use crate::executor::error::StreamExecutorError;
21use crate::executor::{ExecutorInfo, Message, MessageStream};
22
23/// Streams wrapped by `epoch_check` will check whether the first message received is a barrier, and
24/// the epoch in the barriers are monotonically increasing.
25#[try_stream(ok = Message, error = StreamExecutorError)]
26pub async fn epoch_check(info: Arc<ExecutorInfo>, input: impl MessageStream) {
27    // Epoch number recorded from last barrier message.
28    let mut last_epoch = None;
29
30    pin_mut!(input);
31    while let Some(message) = input.next().await {
32        let message = message?;
33
34        if let Message::Barrier(b) = &message {
35            let new_epoch = b.epoch.curr;
36            let stale = last_epoch
37                .map(|last_epoch| last_epoch > new_epoch)
38                .unwrap_or(false);
39
40            if stale {
41                panic!(
42                    "epoch check failed on {}: last epoch is {:?}, while the epoch of incoming barrier is {}.\nstale barrier: {:?}",
43                    info.identity, last_epoch, new_epoch, b
44                );
45            }
46
47            if let Some(last_epoch) = last_epoch
48                && !b.is_with_stop_mutation()
49            {
50                assert_eq!(
51                    b.epoch.prev, last_epoch,
52                    "missing barrier: last barrier's epoch = {}, while current barrier prev={} curr={}",
53                    last_epoch, b.epoch.prev, b.epoch.curr
54                );
55            }
56
57            last_epoch = Some(new_epoch);
58        } else if last_epoch.is_none() && !info.identity.contains("BatchQuery") {
59            panic!(
60                "epoch check failed on {}: the first message must be a barrier",
61                info.identity
62            )
63        }
64
65        yield message;
66    }
67}
68
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::pin_mut;
    use risingwave_common::array::StreamChunk;
    use risingwave_common::util::epoch::test_epoch;

    use super::*;
    use crate::executor::test_utils::MockSource;

    /// Consecutive barriers interleaved with a data chunk pass through unchanged.
    #[tokio::test]
    async fn test_epoch_ok() {
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(Default::default(), vec![]);
        tx.push_barrier(test_epoch(1), false);
        tx.push_chunk(StreamChunk::default());
        tx.push_barrier(test_epoch(2), false);
        tx.push_barrier(test_epoch(3), false);
        tx.push_barrier(test_epoch(4), false);

        let checked = epoch_check(source.info().clone().into(), source.execute());
        pin_mut!(checked);

        assert_matches!(checked.next().await.unwrap().unwrap(), Message::Barrier(b) if b.epoch.curr == test_epoch(1));
        assert_matches!(checked.next().await.unwrap().unwrap(), Message::Chunk(_));
        assert_matches!(checked.next().await.unwrap().unwrap(), Message::Barrier(b) if b.epoch.curr == test_epoch(2));
        assert_matches!(checked.next().await.unwrap().unwrap(), Message::Barrier(b) if b.epoch.curr == test_epoch(3));
        assert_matches!(checked.next().await.unwrap().unwrap(), Message::Barrier(b) if b.epoch.curr == test_epoch(4));
    }

    /// A barrier whose `curr` epoch goes backwards trips the stale-epoch check.
    ///
    /// The barriers before the stale one use consecutive test epochs (as in `test_epoch_ok`), so
    /// they pass every check. The `expected` message pins the panic to the stale-epoch check
    /// rather than the missing-barrier assertion.
    #[should_panic(expected = "stale barrier")]
    #[tokio::test]
    async fn test_epoch_bad() {
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(Default::default(), vec![]);
        tx.push_barrier(test_epoch(100), false);
        tx.push_chunk(StreamChunk::default());
        tx.push_barrier(test_epoch(101), false);
        tx.push_barrier(test_epoch(102), false);
        tx.push_barrier(test_epoch(99), false);

        let checked = epoch_check(source.info().clone().into(), source.execute());
        pin_mut!(checked);

        assert_matches!(checked.next().await.unwrap().unwrap(), Message::Barrier(b) if b.epoch.curr == test_epoch(100));
        assert_matches!(checked.next().await.unwrap().unwrap(), Message::Chunk(_));
        assert_matches!(checked.next().await.unwrap().unwrap(), Message::Barrier(b) if b.epoch.curr == test_epoch(101));
        assert_matches!(checked.next().await.unwrap().unwrap(), Message::Barrier(b) if b.epoch.curr == test_epoch(102));

        checked.next().await.unwrap().unwrap(); // should panic: epoch 99 < 102
    }

    /// A barrier whose `prev` epoch does not match the previous barrier's `curr` epoch trips the
    /// missing-barrier check.
    #[should_panic(expected = "missing barrier")]
    #[tokio::test]
    async fn test_epoch_missing_barrier() {
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(Default::default(), vec![]);
        tx.push_barrier(test_epoch(1), false);
        tx.push_barrier(test_epoch(3), false);

        let checked = epoch_check(source.info().clone().into(), source.execute());
        pin_mut!(checked);

        assert_matches!(checked.next().await.unwrap().unwrap(), Message::Barrier(b) if b.epoch.curr == test_epoch(1));

        checked.next().await.unwrap().unwrap(); // should panic: the barrier for epoch 2 was skipped
    }

    /// A data chunk arriving before any barrier is rejected.
    #[should_panic(expected = "the first message must be a barrier")]
    #[tokio::test]
    async fn test_epoch_first_not_barrier() {
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(Default::default(), vec![]);
        tx.push_chunk(StreamChunk::default());
        tx.push_barrier(test_epoch(114), false);

        let checked = epoch_check(source.info().clone().into(), source.execute());
        pin_mut!(checked);

        checked.next().await.unwrap().unwrap(); // should panic
    }

    /// With the sender dropped and `stop_on_finish(false)`, the source emits nothing, and the
    /// checked stream ends without error.
    #[tokio::test]
    async fn test_empty() {
        let (_, source) = MockSource::channel();
        let source = source
            .stop_on_finish(false)
            .into_executor(Default::default(), vec![]);
        let checked = epoch_check(source.info().clone().into(), source.execute());
        pin_mut!(checked);

        assert!(checked.next().await.transpose().unwrap().is_none());
    }
}