risingwave_stream/executor/test_utils/
mock_source.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::*;
16pub struct MockSource {
17    rx: mpsc::UnboundedReceiver<Message>,
18
19    /// Whether to send a `Stop` barrier on stream finish.
20    stop_on_finish: bool,
21}
22
23/// A wrapper around `Sender<Message>`.
24pub struct MessageSender(mpsc::UnboundedSender<Message>);
25
26impl MessageSender {
27    #[allow(dead_code)]
28    pub fn push_chunk(&mut self, chunk: StreamChunk) {
29        self.0.send(Message::Chunk(chunk)).unwrap();
30    }
31
32    #[allow(dead_code)]
33    pub fn push_barrier(&mut self, epoch: u64, stop: bool) {
34        let mut barrier = Barrier::new_test_barrier(epoch);
35        if stop {
36            barrier = barrier.with_stop();
37        }
38        self.0.send(Message::Barrier(barrier)).unwrap();
39    }
40
41    pub fn send_barrier(&self, barrier: Barrier) {
42        self.0.send(Message::Barrier(barrier)).unwrap();
43    }
44
45    #[allow(dead_code)]
46    pub fn push_barrier_with_prev_epoch_for_test(
47        &mut self,
48        cur_epoch: u64,
49        prev_epoch: u64,
50        stop: bool,
51    ) {
52        let mut barrier = Barrier::with_prev_epoch_for_test(cur_epoch, prev_epoch);
53        if stop {
54            barrier = barrier.with_stop();
55        }
56        self.0.send(Message::Barrier(barrier)).unwrap();
57    }
58
59    #[allow(dead_code)]
60    pub fn push_watermark(&mut self, col_idx: usize, data_type: DataType, val: ScalarImpl) {
61        self.0
62            .send(Message::Watermark(Watermark {
63                col_idx,
64                data_type,
65                val,
66            }))
67            .unwrap();
68    }
69
70    #[allow(dead_code)]
71    pub fn push_int64_watermark(&mut self, col_idx: usize, val: i64) {
72        self.push_watermark(col_idx, DataType::Int64, ScalarImpl::Int64(val));
73    }
74}
75
76impl std::fmt::Debug for MockSource {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        f.debug_struct("MockSource").finish()
79    }
80}
81
82impl MockSource {
83    #[allow(dead_code)]
84    pub fn channel() -> (MessageSender, Self) {
85        let (tx, rx) = mpsc::unbounded_channel();
86        let source = Self {
87            rx,
88            stop_on_finish: true,
89        };
90        (MessageSender(tx), source)
91    }
92
93    #[allow(dead_code)]
94    pub fn with_messages(msgs: Vec<Message>) -> Self {
95        let (tx, source) = Self::channel();
96        for msg in msgs {
97            tx.0.send(msg).unwrap();
98        }
99        source
100    }
101
102    pub fn with_chunks(chunks: Vec<StreamChunk>) -> Self {
103        let (tx, source) = Self::channel();
104        for chunk in chunks {
105            tx.0.send(Message::Chunk(chunk)).unwrap();
106        }
107        source
108    }
109
110    #[allow(dead_code)]
111    #[must_use]
112    pub fn stop_on_finish(self, stop_on_finish: bool) -> Self {
113        Self {
114            stop_on_finish,
115            ..self
116        }
117    }
118
119    pub fn into_executor(self, schema: Schema, pk_indices: Vec<usize>) -> Executor {
120        Executor::new(
121            ExecutorInfo::new(schema, pk_indices, "MockSource".to_owned(), 0),
122            self.boxed(),
123        )
124    }
125
126    #[try_stream(ok = Message, error = StreamExecutorError)]
127    async fn execute_inner(mut self: Box<Self>) {
128        let mut epoch = test_epoch(1);
129
130        while let Some(msg) = self.rx.recv().await {
131            epoch.inc_epoch();
132            yield msg;
133        }
134
135        if self.stop_on_finish {
136            yield Message::Barrier(Barrier::new_test_barrier(epoch).with_stop());
137        }
138    }
139}
140
141impl Execute for MockSource {
142    fn execute(self: Box<Self>) -> super::BoxedMessageStream {
143        self.execute_inner().boxed()
144    }
145}