risingwave_stream/executor/
union.rs1use std::collections::BTreeMap;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18use std::time::Instant;
19
20use futures::stream::{FusedStream, FuturesUnordered};
21use pin_project::pin_project;
22
23use super::watermark::BufferedWatermarks;
24use crate::executor::prelude::*;
25use crate::task::FragmentId;
26
27pub struct UnionExecutor {
29 inputs: Vec<Executor>,
30 metrics: Arc<StreamingMetrics>,
31 actor_context: ActorContextRef,
32}
33
34impl std::fmt::Debug for UnionExecutor {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("UnionExecutor").finish()
37 }
38}
39
40impl UnionExecutor {
41 pub fn new(
42 inputs: Vec<Executor>,
43 metrics: Arc<StreamingMetrics>,
44 actor_context: ActorContextRef,
45 ) -> Self {
46 Self {
47 inputs,
48 metrics,
49 actor_context,
50 }
51 }
52}
53
54impl Execute for UnionExecutor {
55 fn execute(self: Box<Self>) -> BoxedMessageStream {
56 let streams = self.inputs.into_iter().map(|e| e.execute()).collect();
57 merge(
58 streams,
59 self.metrics,
60 self.actor_context.fragment_id,
61 self.actor_context.id,
62 )
63 .boxed()
64 }
65}
66
67#[pin_project]
68struct Input {
69 #[pin]
70 inner: BoxedMessageStream,
71 id: usize,
72}
73
74impl Stream for Input {
75 type Item = MessageStreamItem;
76
77 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
78 self.project().inner.poll_next(cx)
79 }
80}
81
82#[try_stream(ok = Message, error = StreamExecutorError)]
84async fn merge(
85 inputs: Vec<BoxedMessageStream>,
86 metrics: Arc<StreamingMetrics>,
87 fragment_id: FragmentId,
88 actor_id: ActorId,
89) {
90 let input_num = inputs.len();
91 let mut active: FuturesUnordered<_> = inputs
92 .into_iter()
93 .enumerate()
94 .map(|(idx, input)| {
95 (Input {
96 id: idx,
97 inner: input,
98 })
99 .into_future()
100 })
101 .collect();
102 let mut blocked = vec![];
103 let mut current_barrier: Option<Barrier> = None;
104
105 let mut watermark_buffers = BTreeMap::<usize, BufferedWatermarks<usize>>::new();
107
108 let mut start_time = Instant::now();
109 let barrier_align = metrics.barrier_align_duration.with_guarded_label_values(&[
110 &actor_id.to_string(),
111 &fragment_id.to_string(),
112 "",
113 "Union",
114 ]);
115 loop {
116 match active.next().await {
117 Some((Some(Ok(message)), remaining)) => {
118 match message {
119 Message::Chunk(chunk) => {
120 active.push(remaining.into_future());
122 yield Message::Chunk(chunk);
123 }
124 Message::Watermark(watermark) => {
125 let id = remaining.id;
126 active.push(remaining.into_future());
128 let buffers = watermark_buffers
129 .entry(watermark.col_idx)
130 .or_insert_with(|| BufferedWatermarks::with_ids(0..input_num));
131 if let Some(selected_watermark) =
132 buffers.handle_watermark(id, watermark.clone())
133 {
134 yield Message::Watermark(selected_watermark)
135 }
136 }
137 Message::Barrier(barrier) => {
138 if blocked.is_empty() {
140 start_time = Instant::now();
141 }
142 blocked.push(remaining);
143 if let Some(cur_barrier) = current_barrier.as_ref() {
144 if barrier.epoch != cur_barrier.epoch {
145 return Err(StreamExecutorError::align_barrier(
146 cur_barrier.clone(),
147 barrier,
148 ));
149 }
150 } else {
151 current_barrier = Some(barrier);
152 }
153 }
154 }
155 }
156 Some((Some(Err(e)), _)) => return Err(e),
157 Some((None, remaining)) => {
158 return Err(StreamExecutorError::channel_closed(format!(
160 "Union from upstream {} closed unexpectedly",
161 remaining.id,
162 )));
163 }
164 None => {
165 assert!(active.is_terminated());
166 let barrier = current_barrier.take().unwrap();
167 barrier_align.inc_by(start_time.elapsed().as_nanos() as u64);
168
169 let upstreams = std::mem::take(&mut blocked);
170 active.extend(upstreams.into_iter().map(|upstream| upstream.into_future()));
171
172 yield Message::Barrier(barrier)
173 }
174 }
175 }
176}
177
#[cfg(test)]
mod tests {
    use async_stream::try_stream;
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::util::epoch::test_epoch;

    use super::*;

    /// Chunk message built from the pretty-printed table `pretty`.
    fn chunk(pretty: &str) -> Message {
        Message::Chunk(StreamChunk::from_pretty(pretty))
    }

    /// Test barrier message for the test epoch numbered `n`.
    fn barrier(n: u64) -> Message {
        Message::Barrier(Barrier::new_test_barrier(test_epoch(n)))
    }

    /// `Int64` watermark message on column 0 carrying `value`.
    fn watermark(value: i64) -> Message {
        Message::Watermark(Watermark::new(0, DataType::Int64, ScalarImpl::Int64(value)))
    }

    /// Feeds two scripted upstreams through `merge` and checks the exact output sequence.
    #[tokio::test]
    async fn union() {
        let upstreams = vec![
            try_stream! {
                yield chunk("I\n + 1");
                yield barrier(1);
                yield chunk("I\n + 2");
                yield barrier(2);
                yield barrier(3);
                yield watermark(4);
                yield barrier(4);
            }
            .boxed(),
            try_stream! {
                yield chunk("I\n + 1");
                yield barrier(1);
                yield barrier(2);
                yield chunk("I\n + 3");
                yield barrier(3);
                yield watermark(5);
                yield barrier(4);
            }
            .boxed(),
        ];

        let expected = vec![
            chunk("I\n + 1"),
            chunk("I\n + 1"),
            barrier(1),
            chunk("I\n + 2"),
            barrier(2),
            chunk("I\n + 3"),
            barrier(3),
            watermark(4),
            barrier(4),
        ];

        let mut merged = merge(upstreams, Arc::new(StreamingMetrics::unused()), 0, 0).boxed();

        // Pull exactly as many messages as are expected.
        let mut actual = Vec::with_capacity(expected.len());
        while actual.len() < expected.len() {
            actual.push(merged.next().await.unwrap().unwrap());
        }
        assert_eq!(actual, expected);
    }
}