risingwave_stream/executor/
union.rs1use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use crate::executor::DynamicReceivers;
19use crate::executor::exchange::input::{BoxedInput, Input};
20use crate::executor::prelude::*;
21
22pub struct UnionExecutor {
24 inputs: Vec<Executor>,
25 metrics: Arc<StreamingMetrics>,
26 actor_context: ActorContextRef,
27}
28
29impl std::fmt::Debug for UnionExecutor {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("UnionExecutor").finish()
32 }
33}
34
35impl UnionExecutor {
36 pub fn new(
37 inputs: Vec<Executor>,
38 metrics: Arc<StreamingMetrics>,
39 actor_context: ActorContextRef,
40 ) -> Self {
41 Self {
42 inputs,
43 metrics,
44 actor_context,
45 }
46 }
47}
48
49impl Execute for UnionExecutor {
50 fn execute(self: Box<Self>) -> BoxedMessageStream {
51 let upstreams = self
52 .inputs
53 .into_iter()
54 .map(|e| e.execute())
55 .enumerate()
56 .map(|(id, input)| {
57 Box::pin(UnionExecutorInput { id, inner: input })
58 as BoxedInput<usize, MessageStreamItem>
59 })
60 .collect();
61
62 let barrier_align = self
63 .metrics
64 .barrier_align_duration
65 .with_guarded_label_values(&[
66 self.actor_context.id.to_string().as_str(),
67 self.actor_context.fragment_id.to_string().as_str(),
68 "",
69 "Union",
70 ]);
71
72 let union_receivers = DynamicReceivers::new(upstreams, Some(barrier_align), None);
73
74 union_receivers.boxed()
75 }
76}
77
78struct UnionExecutorInput {
79 id: usize,
80 inner: BoxedMessageStream,
81}
82
83impl Input for UnionExecutorInput {
84 type InputId = usize;
85
86 fn id(&self) -> Self::InputId {
87 self.id
88 }
89}
90
91impl Stream for UnionExecutorInput {
92 type Item = MessageStreamItem;
93
94 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
95 self.inner.as_mut().poll_next(cx)
96 }
97}
98
#[cfg(test)]
mod tests {
    use async_stream::try_stream;
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::util::epoch::test_epoch;

    use super::*;

    /// Merges two upstreams that both deliver barriers for epochs 1-4, with different chunks
    /// and watermarks in between, and checks the exact merged sequence. Each barrier is emitted
    /// once, every chunk appears before the next barrier, and a single watermark with value 4
    /// is forwarded.
    #[tokio::test]
    async fn union() {
        let sources = vec![
            try_stream! {
                yield Message::Chunk(StreamChunk::from_pretty("I\n + 1"));
                yield Message::Barrier(Barrier::new_test_barrier(test_epoch(1)));
                yield Message::Chunk(StreamChunk::from_pretty("I\n + 2"));
                yield Message::Barrier(Barrier::new_test_barrier(test_epoch(2)));
                yield Message::Barrier(Barrier::new_test_barrier(test_epoch(3)));
                yield Message::Watermark(Watermark::new(0, DataType::Int64, ScalarImpl::Int64(4)));
                yield Message::Barrier(Barrier::new_test_barrier(test_epoch(4)));
            }
            .boxed(),
            try_stream! {
                yield Message::Chunk(StreamChunk::from_pretty("I\n + 1"));
                yield Message::Barrier(Barrier::new_test_barrier(test_epoch(1)));
                yield Message::Barrier(Barrier::new_test_barrier(test_epoch(2)));
                yield Message::Chunk(StreamChunk::from_pretty("I\n + 3"));
                yield Message::Barrier(Barrier::new_test_barrier(test_epoch(3)));
                yield Message::Watermark(Watermark::new(0, DataType::Int64, ScalarImpl::Int64(5)));
                yield Message::Barrier(Barrier::new_test_barrier(test_epoch(4)));
            }
            .boxed(),
        ];

        // Wrap each source exactly like `UnionExecutor::execute` does.
        let inputs = sources
            .into_iter()
            .enumerate()
            .map(|(id, inner)| {
                Box::pin(UnionExecutorInput { id, inner }) as BoxedInput<usize, MessageStreamItem>
            })
            .collect();
        let mut merged = DynamicReceivers::new(inputs, None, None).boxed();

        let expected = vec![
            Message::Chunk(StreamChunk::from_pretty("I\n + 1")),
            Message::Chunk(StreamChunk::from_pretty("I\n + 1")),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(1))),
            Message::Chunk(StreamChunk::from_pretty("I\n + 2")),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(2))),
            Message::Chunk(StreamChunk::from_pretty("I\n + 3")),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(3))),
            Message::Watermark(Watermark::new(0, DataType::Int64, ScalarImpl::Int64(4))),
            Message::Barrier(Barrier::new_test_barrier(test_epoch(4))),
        ];

        // Pull exactly as many messages as we expect; any early end or error panics.
        let mut output = Vec::with_capacity(expected.len());
        while output.len() < expected.len() {
            let message = merged.next().await.unwrap().unwrap();
            output.push(message);
        }
        assert_eq!(output, expected);
    }
}