risingwave_stream/executor/union.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use crate::executor::DynamicReceivers;
19use crate::executor::exchange::input::{BoxedInput, Input};
20use crate::executor::prelude::*;
21
22/// `UnionExecutor` merges data from multiple inputs.
23pub struct UnionExecutor {
24    inputs: Vec<Executor>,
25    metrics: Arc<StreamingMetrics>,
26    actor_context: ActorContextRef,
27}
28
29impl std::fmt::Debug for UnionExecutor {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("UnionExecutor").finish()
32    }
33}
34
35impl UnionExecutor {
36    pub fn new(
37        inputs: Vec<Executor>,
38        metrics: Arc<StreamingMetrics>,
39        actor_context: ActorContextRef,
40    ) -> Self {
41        Self {
42            inputs,
43            metrics,
44            actor_context,
45        }
46    }
47}
48
49impl Execute for UnionExecutor {
50    fn execute(self: Box<Self>) -> BoxedMessageStream {
51        let upstreams = self
52            .inputs
53            .into_iter()
54            .map(|e| e.execute())
55            .enumerate()
56            .map(|(id, input)| {
57                Box::pin(UnionExecutorInput { id, inner: input })
58                    as BoxedInput<usize, MessageStreamItem>
59            })
60            .collect();
61
62        let barrier_align = self
63            .metrics
64            .barrier_align_duration
65            .with_guarded_label_values(&[
66                self.actor_context.id.to_string().as_str(),
67                self.actor_context.fragment_id.to_string().as_str(),
68                "",
69                "Union",
70            ]);
71
72        let union_receivers = DynamicReceivers::new(upstreams, Some(barrier_align), None);
73
74        union_receivers.boxed()
75    }
76}
77
78struct UnionExecutorInput {
79    id: usize,
80    inner: BoxedMessageStream,
81}
82
83impl Input for UnionExecutorInput {
84    type InputId = usize;
85
86    fn id(&self) -> Self::InputId {
87        self.id
88    }
89}
90
91impl Stream for UnionExecutorInput {
92    type Item = MessageStreamItem;
93
94    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
95        self.inner.as_mut().poll_next(cx)
96    }
97}
98
#[cfg(test)]
mod tests {
    use async_stream::try_stream;
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::util::epoch::test_epoch;

    use super::*;

    /// Merges two inputs through `DynamicReceivers`, wired the same way as
    /// `UnionExecutor::execute` but without the metric. Checks the exact
    /// output order:
    /// - chunks arriving between two barriers are passed through;
    /// - each barrier epoch is emitted once;
    /// - of the two watermarks (4 and 5), the value 4 is the one emitted
    ///   before the final barrier.
    #[tokio::test]
    async fn union() {
        let first = try_stream! {
            yield Message::Chunk(StreamChunk::from_pretty("I\n + 1"));
            yield Message::Barrier(Barrier::new_test_barrier(test_epoch(1)));
            yield Message::Chunk(StreamChunk::from_pretty("I\n + 2"));
            yield Message::Barrier(Barrier::new_test_barrier(test_epoch(2)));
            yield Message::Barrier(Barrier::new_test_barrier(test_epoch(3)));
            yield Message::Watermark(Watermark::new(0, DataType::Int64, ScalarImpl::Int64(4)));
            yield Message::Barrier(Barrier::new_test_barrier(test_epoch(4)));
        }
        .boxed();
        let second = try_stream! {
            yield Message::Chunk(StreamChunk::from_pretty("I\n + 1"));
            yield Message::Barrier(Barrier::new_test_barrier(test_epoch(1)));
            yield Message::Barrier(Barrier::new_test_barrier(test_epoch(2)));
            yield Message::Chunk(StreamChunk::from_pretty("I\n + 3"));
            yield Message::Barrier(Barrier::new_test_barrier(test_epoch(3)));
            yield Message::Watermark(Watermark::new(0, DataType::Int64, ScalarImpl::Int64(5)));
            yield Message::Barrier(Barrier::new_test_barrier(test_epoch(4)));
        }
        .boxed();

        // Tag each stream with its position, exactly as `execute` does.
        let upstreams = [first, second]
            .into_iter()
            .enumerate()
            .map(|(id, inner)| {
                Box::pin(UnionExecutorInput { id, inner })
                    as BoxedInput<usize, MessageStreamItem>
            })
            .collect();
        let mut merged = DynamicReceivers::new(upstreams, None, None).boxed();

        let expected = vec![
            Message::Chunk(StreamChunk::from_pretty("I\n + 1")),
            Message::Chunk(StreamChunk::from_pretty("I\n + 1")),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(1))),
            Message::Chunk(StreamChunk::from_pretty("I\n + 2")),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(2))),
            Message::Chunk(StreamChunk::from_pretty("I\n + 3")),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(3))),
            Message::Watermark(Watermark::new(0, DataType::Int64, ScalarImpl::Int64(4))),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(4))),
        ];

        // Pull exactly as many messages as are expected.
        let mut output = Vec::with_capacity(expected.len());
        while output.len() < expected.len() {
            output.push(merged.next().await.unwrap().unwrap());
        }
        assert_eq!(output, expected);
    }
}