risingwave_stream/executor/
union.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18use std::time::Instant;
19
20use futures::stream::{FusedStream, FuturesUnordered};
21use pin_project::pin_project;
22
23use super::watermark::BufferedWatermarks;
24use crate::executor::prelude::*;
25use crate::task::FragmentId;
26
/// `UnionExecutor` merges data from multiple inputs.
///
/// The merged stream forwards chunks from every input and emits a barrier only once all
/// inputs have delivered one.
pub struct UnionExecutor {
    /// Upstream executors whose output streams are merged.
    inputs: Vec<Executor>,
    /// Streaming metrics; the barrier alignment duration is reported through it.
    metrics: Arc<StreamingMetrics>,
    /// Context of the owning actor; provides the fragment id and actor id used as metric labels.
    actor_context: ActorContextRef,
}
33
34impl std::fmt::Debug for UnionExecutor {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        f.debug_struct("UnionExecutor").finish()
37    }
38}
39
impl UnionExecutor {
    /// Creates a `UnionExecutor` that will merge the output of `inputs`.
    ///
    /// `metrics` and `actor_context` are stored as given. This only builds the struct; the
    /// inputs are not executed here.
    pub fn new(
        inputs: Vec<Executor>,
        metrics: Arc<StreamingMetrics>,
        actor_context: ActorContextRef,
    ) -> Self {
        Self {
            inputs,
            metrics,
            actor_context,
        }
    }
}
53
54impl Execute for UnionExecutor {
55    fn execute(self: Box<Self>) -> BoxedMessageStream {
56        let streams = self.inputs.into_iter().map(|e| e.execute()).collect();
57        merge(
58            streams,
59            self.metrics,
60            self.actor_context.fragment_id,
61            self.actor_context.id,
62        )
63        .boxed()
64    }
65}
66
/// One upstream of the union: a message stream tagged with the index of the input it came from.
#[pin_project]
struct Input {
    /// The wrapped upstream message stream; structurally pinned so it can be polled in place.
    #[pin]
    inner: BoxedMessageStream,
    /// Position of this upstream in the input list; used to identify it in watermark buffers
    /// and in error messages.
    id: usize,
}
73
74impl Stream for Input {
75    type Item = MessageStreamItem;
76
77    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
78        self.project().inner.poll_next(cx)
79    }
80}
81
82/// Merges input streams and aligns with barriers.
83#[try_stream(ok = Message, error = StreamExecutorError)]
84async fn merge(
85    inputs: Vec<BoxedMessageStream>,
86    metrics: Arc<StreamingMetrics>,
87    fragment_id: FragmentId,
88    actor_id: ActorId,
89) {
90    let input_num = inputs.len();
91    let mut active: FuturesUnordered<_> = inputs
92        .into_iter()
93        .enumerate()
94        .map(|(idx, input)| {
95            (Input {
96                id: idx,
97                inner: input,
98            })
99            .into_future()
100        })
101        .collect();
102    let mut blocked = vec![];
103    let mut current_barrier: Option<Barrier> = None;
104
105    // watermark column index -> `BufferedWatermarks`
106    let mut watermark_buffers = BTreeMap::<usize, BufferedWatermarks<usize>>::new();
107
108    let mut start_time = Instant::now();
109    let barrier_align = metrics.barrier_align_duration.with_guarded_label_values(&[
110        &actor_id.to_string(),
111        &fragment_id.to_string(),
112        "",
113        "Union",
114    ]);
115    loop {
116        match active.next().await {
117            Some((Some(Ok(message)), remaining)) => {
118                match message {
119                    Message::Chunk(chunk) => {
120                        // Continue polling this upstream by pushing it back to `active`.
121                        active.push(remaining.into_future());
122                        yield Message::Chunk(chunk);
123                    }
124                    Message::Watermark(watermark) => {
125                        let id = remaining.id;
126                        // Continue polling this upstream by pushing it back to `active`.
127                        active.push(remaining.into_future());
128                        let buffers = watermark_buffers
129                            .entry(watermark.col_idx)
130                            .or_insert_with(|| BufferedWatermarks::with_ids(0..input_num));
131                        if let Some(selected_watermark) =
132                            buffers.handle_watermark(id, watermark.clone())
133                        {
134                            yield Message::Watermark(selected_watermark)
135                        }
136                    }
137                    Message::Barrier(barrier) => {
138                        // Block this upstream by pushing it to `blocked`.
139                        if blocked.is_empty() {
140                            start_time = Instant::now();
141                        }
142                        blocked.push(remaining);
143                        if let Some(cur_barrier) = current_barrier.as_ref() {
144                            if barrier.epoch != cur_barrier.epoch {
145                                return Err(StreamExecutorError::align_barrier(
146                                    cur_barrier.clone(),
147                                    barrier,
148                                ));
149                            }
150                        } else {
151                            current_barrier = Some(barrier);
152                        }
153                    }
154                }
155            }
156            Some((Some(Err(e)), _)) => return Err(e),
157            Some((None, remaining)) => {
158                // tracing::error!("Union from upstream {} closed unexpectedly", remaining.id);
159                return Err(StreamExecutorError::channel_closed(format!(
160                    "Union from upstream {} closed unexpectedly",
161                    remaining.id,
162                )));
163            }
164            None => {
165                assert!(active.is_terminated());
166                let barrier = current_barrier.take().unwrap();
167                barrier_align.inc_by(start_time.elapsed().as_nanos() as u64);
168
169                let upstreams = std::mem::take(&mut blocked);
170                active.extend(upstreams.into_iter().map(|upstream| upstream.into_future()));
171
172                yield Message::Barrier(barrier)
173            }
174        }
175    }
176}
177
#[cfg(test)]
mod tests {
    use async_stream::try_stream;
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::util::epoch::test_epoch;

    use super::*;

    /// Builds a chunk message from its pretty-printed form.
    fn chunk_msg(pretty: &str) -> Message {
        Message::Chunk(StreamChunk::from_pretty(pretty))
    }

    /// Builds a test barrier message for the given test epoch number.
    fn barrier_msg(epoch: u64) -> Message {
        Message::Barrier(Barrier::new_test_barrier(test_epoch(epoch)))
    }

    /// Builds an `Int64` watermark message on column 0.
    fn watermark_msg(value: i64) -> Message {
        let watermark = Watermark::new(0, DataType::Int64, ScalarImpl::Int64(value));
        Message::Watermark(watermark)
    }

    /// Merges two upstreams and checks the order of forwarded chunks, the aligned barriers
    /// and the watermark that is emitted.
    #[tokio::test]
    async fn union() {
        let upstreams = vec![
            try_stream! {
                yield chunk_msg("I\n + 1");
                yield barrier_msg(1);
                yield chunk_msg("I\n + 2");
                yield barrier_msg(2);
                yield barrier_msg(3);
                yield watermark_msg(4);
                yield barrier_msg(4);
            }
            .boxed(),
            try_stream! {
                yield chunk_msg("I\n + 1");
                yield barrier_msg(1);
                yield barrier_msg(2);
                yield chunk_msg("I\n + 3");
                yield barrier_msg(3);
                yield watermark_msg(5);
                yield barrier_msg(4);
            }
            .boxed(),
        ];
        let mut merged = merge(upstreams, Arc::new(StreamingMetrics::unused()), 0, 0).boxed();

        let expected = vec![
            chunk_msg("I\n + 1"),
            chunk_msg("I\n + 1"),
            barrier_msg(1),
            chunk_msg("I\n + 2"),
            barrier_msg(2),
            chunk_msg("I\n + 3"),
            barrier_msg(3),
            watermark_msg(4),
            barrier_msg(4),
        ];

        // Pull exactly as many messages as expected; the merged stream itself never ends
        // normally.
        let mut actual = Vec::with_capacity(expected.len());
        while actual.len() < expected.len() {
            actual.push(merged.next().await.unwrap().unwrap());
        }
        assert_eq!(actual, expected);
    }
}