risingwave_stream/executor/
merge.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::VecDeque;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use anyhow::Context as _;
20use futures::future::try_join_all;
21use risingwave_common::array::StreamChunkBuilder;
22use risingwave_common::config::MetricLevel;
23use tokio::sync::mpsc;
24use tokio::time::Instant;
25
26use super::exchange::input::BoxedActorInput;
27use super::*;
28use crate::executor::exchange::input::{
29    assert_equal_dispatcher_barrier, new_input, process_dispatcher_msg,
30};
31use crate::executor::prelude::*;
32use crate::task::LocalBarrierManager;
33
/// Receiver set over all upstream inputs, keyed by upstream `ActorId`.
pub type SelectReceivers = DynamicReceivers<ActorId, ()>;
35
/// The upstream side of a merge, before being converted into a concrete executor.
pub(crate) enum MergeExecutorUpstream {
    /// Exactly one upstream input; becomes a `ReceiverExecutor` in `into_executor`.
    Singleton(BoxedActorInput),
    /// Multiple upstream inputs to be merged; becomes a `MergeExecutor` in `into_executor`.
    Merge(SelectReceivers),
}
40
/// Pre-executor form of a merge: holds everything needed to build the final executor,
/// while itself implementing `Stream` so the upstream can be polled before conversion.
pub(crate) struct MergeExecutorInput {
    /// Either a single upstream input or a set of inputs to merge.
    upstream: MergeExecutorUpstream,
    /// The context of the actor this input belongs to.
    actor_context: ActorContextRef,
    /// Id of the fragment the upstream inputs come from.
    upstream_fragment_id: UpstreamFragmentId,
    local_barrier_manager: LocalBarrierManager,
    /// Streaming metrics, forwarded to the built executor.
    executor_stats: Arc<StreamingMetrics>,
    pub(crate) info: ExecutorInfo,
    /// Chunk size for the `StreamChunkBuilder` of the built `MergeExecutor`.
    chunk_size: usize,
}
50
51impl MergeExecutorInput {
52    pub(crate) fn new(
53        upstream: MergeExecutorUpstream,
54        actor_context: ActorContextRef,
55        upstream_fragment_id: UpstreamFragmentId,
56        local_barrier_manager: LocalBarrierManager,
57        executor_stats: Arc<StreamingMetrics>,
58        info: ExecutorInfo,
59        chunk_size: usize,
60    ) -> Self {
61        Self {
62            upstream,
63            actor_context,
64            upstream_fragment_id,
65            local_barrier_manager,
66            executor_stats,
67            info,
68            chunk_size,
69        }
70    }
71
72    pub(crate) fn into_executor(self, barrier_rx: mpsc::UnboundedReceiver<Barrier>) -> Executor {
73        let fragment_id = self.actor_context.fragment_id;
74        let executor = match self.upstream {
75            MergeExecutorUpstream::Singleton(input) => ReceiverExecutor::new(
76                self.actor_context,
77                fragment_id,
78                self.upstream_fragment_id,
79                input,
80                self.local_barrier_manager,
81                self.executor_stats,
82                barrier_rx,
83            )
84            .boxed(),
85            MergeExecutorUpstream::Merge(inputs) => MergeExecutor::new(
86                self.actor_context,
87                fragment_id,
88                self.upstream_fragment_id,
89                inputs,
90                self.local_barrier_manager,
91                self.executor_stats,
92                barrier_rx,
93                self.chunk_size,
94                self.info.schema.clone(),
95            )
96            .boxed(),
97        };
98        (self.info, executor).into()
99    }
100}
101
102impl Stream for MergeExecutorInput {
103    type Item = DispatcherMessageStreamItem;
104
105    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
106        match &mut self.get_mut().upstream {
107            MergeExecutorUpstream::Singleton(input) => input.poll_next_unpin(cx),
108            MergeExecutorUpstream::Merge(inputs) => inputs.poll_next_unpin(cx),
109        }
110    }
111}
112
/// `MergeExecutor` merges data from multiple channels. Dataflow from one channel
/// will be stopped on barrier.
pub struct MergeExecutor {
    /// The context of the actor.
    actor_context: ActorContextRef,

    /// Upstream channels.
    upstreams: SelectReceivers,

    /// Belonged fragment id.
    fragment_id: FragmentId,

    /// Upstream fragment id.
    upstream_fragment_id: FragmentId,

    /// Used to create new upstream inputs on scaling (`Update` barrier mutations).
    local_barrier_manager: LocalBarrierManager,

    /// Streaming metrics.
    metrics: Arc<StreamingMetrics>,

    /// Receiver of the actual `Barrier`s corresponding to dispatched barrier messages.
    barrier_rx: mpsc::UnboundedReceiver<Barrier>,

    /// Chunk size for the `StreamChunkBuilder`
    chunk_size: usize,

    /// Data types for the `StreamChunkBuilder`
    schema: Schema,
}
141
142impl MergeExecutor {
143    #[allow(clippy::too_many_arguments)]
144    pub fn new(
145        ctx: ActorContextRef,
146        fragment_id: FragmentId,
147        upstream_fragment_id: FragmentId,
148        upstreams: SelectReceivers,
149        local_barrier_manager: LocalBarrierManager,
150        metrics: Arc<StreamingMetrics>,
151        barrier_rx: mpsc::UnboundedReceiver<Barrier>,
152        chunk_size: usize,
153        schema: Schema,
154    ) -> Self {
155        Self {
156            actor_context: ctx,
157            upstreams,
158            fragment_id,
159            upstream_fragment_id,
160            local_barrier_manager,
161            metrics,
162            barrier_rx,
163            chunk_size,
164            schema,
165        }
166    }
167
168    #[cfg(test)]
169    pub fn for_test(
170        actor_id: ActorId,
171        inputs: Vec<super::exchange::permit::Receiver>,
172        local_barrier_manager: crate::task::LocalBarrierManager,
173        schema: Schema,
174        chunk_size: usize,
175        barrier_rx: Option<mpsc::UnboundedReceiver<Barrier>>,
176    ) -> Self {
177        use super::exchange::input::LocalInput;
178        use crate::executor::exchange::input::ActorInput;
179
180        let barrier_rx =
181            barrier_rx.unwrap_or_else(|| local_barrier_manager.subscribe_barrier(actor_id));
182
183        let metrics = StreamingMetrics::unused();
184        let actor_ctx = ActorContext::for_test(actor_id);
185        let upstream = Self::new_select_receiver(
186            inputs
187                .into_iter()
188                .enumerate()
189                .map(|(idx, input)| LocalInput::new(input, idx as ActorId).boxed_input())
190                .collect(),
191            &metrics,
192            &actor_ctx,
193        );
194
195        Self::new(
196            actor_ctx,
197            514,
198            1919,
199            upstream,
200            local_barrier_manager,
201            metrics.into(),
202            barrier_rx,
203            chunk_size,
204            schema,
205        )
206    }
207
208    pub(crate) fn new_select_receiver(
209        upstreams: Vec<BoxedActorInput>,
210        metrics: &StreamingMetrics,
211        actor_context: &ActorContext,
212    ) -> SelectReceivers {
213        let merge_barrier_align_duration = if metrics.level >= MetricLevel::Debug {
214            Some(
215                metrics
216                    .merge_barrier_align_duration
217                    .with_guarded_label_values(&[
218                        &actor_context.id.to_string(),
219                        &actor_context.fragment_id.to_string(),
220                    ]),
221            )
222        } else {
223            None
224        };
225
226        // Futures of all active upstreams.
227        SelectReceivers::new(upstreams, None, merge_barrier_align_duration.clone())
228    }
229
230    #[try_stream(ok = Message, error = StreamExecutorError)]
231    pub(crate) async fn execute_inner(mut self: Box<Self>) {
232        let select_all = self.upstreams;
233        let select_all = BufferChunks::new(select_all, self.chunk_size, self.schema);
234        let actor_id = self.actor_context.id;
235
236        let mut metrics = self.metrics.new_actor_input_metrics(
237            actor_id,
238            self.fragment_id,
239            self.upstream_fragment_id,
240        );
241
242        // Channels that're blocked by the barrier to align.
243        let mut start_time = Instant::now();
244        pin_mut!(select_all);
245        while let Some(msg) = select_all.next().await {
246            metrics
247                .actor_input_buffer_blocking_duration_ns
248                .inc_by(start_time.elapsed().as_nanos() as u64);
249            let msg: DispatcherMessage = msg?;
250            let mut msg: Message = process_dispatcher_msg(msg, &mut self.barrier_rx).await?;
251
252            match &mut msg {
253                Message::Watermark(_) => {
254                    // Do nothing.
255                }
256                Message::Chunk(chunk) => {
257                    metrics.actor_in_record_cnt.inc_by(chunk.cardinality() as _);
258                }
259                Message::Barrier(barrier) => {
260                    tracing::debug!(
261                        target: "events::stream::barrier::path",
262                        actor_id = actor_id,
263                        "receiver receives barrier from path: {:?}",
264                        barrier.passed_actors
265                    );
266                    barrier.passed_actors.push(actor_id);
267
268                    if let Some(Mutation::Update(UpdateMutation { dispatchers, .. })) =
269                        barrier.mutation.as_deref()
270                        && select_all
271                            .upstream_input_ids()
272                            .any(|actor_id| dispatchers.contains_key(&actor_id))
273                    {
274                        // `Watermark` of upstream may become stale after downstream scaling.
275                        select_all.flush_buffered_watermarks();
276                    }
277
278                    if let Some(update) =
279                        barrier.as_update_merge(self.actor_context.id, self.upstream_fragment_id)
280                    {
281                        let new_upstream_fragment_id = update
282                            .new_upstream_fragment_id
283                            .unwrap_or(self.upstream_fragment_id);
284                        let removed_upstream_actor_id: HashSet<_> =
285                            if update.new_upstream_fragment_id.is_some() {
286                                select_all.upstream_input_ids().collect()
287                            } else {
288                                update.removed_upstream_actor_id.iter().copied().collect()
289                            };
290
291                        // `Watermark` of upstream may become stale after upstream scaling.
292                        select_all.flush_buffered_watermarks();
293
294                        if !update.added_upstream_actors.is_empty() {
295                            // Create new upstreams receivers.
296                            let mut new_upstreams: Vec<_> = try_join_all(
297                                update.added_upstream_actors.iter().map(|upstream_actor| {
298                                    new_input(
299                                        &self.local_barrier_manager,
300                                        self.metrics.clone(),
301                                        self.actor_context.id,
302                                        self.fragment_id,
303                                        upstream_actor,
304                                        new_upstream_fragment_id,
305                                    )
306                                }),
307                            )
308                            .await
309                            .context("failed to create upstream receivers")?;
310
311                            // Poll the first barrier from the new upstreams. It must be the same as
312                            // the one we polled from original upstreams.
313                            for upstream in &mut new_upstreams {
314                                let new_barrier = expect_first_barrier(upstream).await?;
315                                assert_equal_dispatcher_barrier(barrier, &new_barrier);
316                            }
317
318                            // Add the new upstreams to select.
319                            select_all.add_upstreams_from(new_upstreams);
320                        }
321
322                        if !removed_upstream_actor_id.is_empty() {
323                            // Remove upstreams.
324                            select_all.remove_upstreams(&removed_upstream_actor_id);
325                        }
326
327                        self.upstream_fragment_id = new_upstream_fragment_id;
328                        metrics = self.metrics.new_actor_input_metrics(
329                            actor_id,
330                            self.fragment_id,
331                            self.upstream_fragment_id,
332                        );
333                    }
334
335                    if barrier.is_stop(actor_id) {
336                        yield msg;
337                        break;
338                    }
339                }
340            }
341
342            yield msg;
343            start_time = Instant::now();
344        }
345    }
346}
347
impl Execute for MergeExecutor {
    fn execute(self: Box<Self>) -> BoxedMessageStream {
        // Delegate to the `try_stream`-generated state machine in `execute_inner`.
        self.execute_inner().boxed()
    }
}
353
/// A wrapper that buffers the `StreamChunk`s from upstream until no more ready items are available.
/// Besides, any message other than `StreamChunk` will trigger the buffered `StreamChunk`s
/// to be emitted immediately along with the message itself.
struct BufferChunks<S: Stream> {
    /// The underlying message stream being buffered.
    inner: S,
    /// Accumulates rows across incoming chunks, yielding a chunk once full.
    chunk_builder: StreamChunkBuilder,

    /// The items to be emitted. Whenever there's something here, we should return a `Poll::Ready` immediately.
    pending_items: VecDeque<S::Item>,
}
364
365impl<S: Stream> BufferChunks<S> {
366    pub(super) fn new(inner: S, chunk_size: usize, schema: Schema) -> Self {
367        assert!(chunk_size > 0);
368        let chunk_builder = StreamChunkBuilder::new(chunk_size, schema.data_types());
369        Self {
370            inner,
371            chunk_builder,
372            pending_items: VecDeque::new(),
373        }
374    }
375}
376
// Deref to the inner stream so callers can use its inherent methods
// (e.g. `upstream_input_ids`) directly on the wrapper.
impl<S: Stream> std::ops::Deref for BufferChunks<S> {
    type Target = S;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
384
// Mutable counterpart of the `Deref` impl above, for inner methods taking `&mut self`.
impl<S: Stream> std::ops::DerefMut for BufferChunks<S> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
390
impl<S: Stream> Stream for BufferChunks<S>
where
    S: Stream<Item = DispatcherMessageStreamItem> + Unpin,
{
    type Item = S::Item;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Drain already-queued items before polling upstream again.
            if let Some(item) = self.pending_items.pop_front() {
                return Poll::Ready(Some(item));
            }

            match self.inner.poll_next_unpin(cx) {
                Poll::Pending => {
                    // Upstream has nothing ready: flush the partially-built chunk (if any)
                    // instead of holding its rows until the next wakeup.
                    return if let Some(chunk_out) = self.chunk_builder.take() {
                        Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
                    } else {
                        Poll::Pending
                    };
                }

                Poll::Ready(Some(result)) => {
                    if let Ok(MessageInner::Chunk(chunk)) = result {
                        // Accumulate rows; each time the builder fills up it yields a
                        // full chunk, which is queued for emission.
                        for row in chunk.records() {
                            if let Some(chunk_out) = self.chunk_builder.append_record(row) {
                                self.pending_items
                                    .push_back(Ok(MessageInner::Chunk(chunk_out)));
                            }
                        }
                    } else {
                        // Non-chunk message (barrier/watermark) or error: emit the buffered
                        // chunk first, queueing the message itself right behind it.
                        return if let Some(chunk_out) = self.chunk_builder.take() {
                            self.pending_items.push_back(result);
                            Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
                        } else {
                            Poll::Ready(Some(result))
                        };
                    }
                }

                Poll::Ready(None) => {
                    // See also the comments in `DynamicReceivers::poll_next`.
                    unreachable!("Merge should always have upstream inputs");
                }
            }
        }
    }
}
438
439#[cfg(test)]
440mod tests {
441    use std::sync::atomic::{AtomicBool, Ordering};
442    use std::time::Duration;
443
444    use assert_matches::assert_matches;
445    use futures::FutureExt;
446    use futures::future::try_join_all;
447    use risingwave_common::array::Op;
448    use risingwave_common::util::epoch::test_epoch;
449    use risingwave_pb::task_service::exchange_service_server::{
450        ExchangeService, ExchangeServiceServer,
451    };
452    use risingwave_pb::task_service::{
453        GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits,
454    };
455    use tokio::time::sleep;
456    use tokio_stream::wrappers::ReceiverStream;
457    use tonic::{Request, Response, Status, Streaming};
458
459    use super::*;
460    use crate::executor::exchange::input::{ActorInput, LocalInput, RemoteInput};
461    use crate::executor::exchange::permit::channel_for_test;
462    use crate::executor::{BarrierInner as Barrier, MessageInner as Message};
463    use crate::task::NewOutputRequest;
464    use crate::task::barrier_test_utils::LocalBarrierTestEnv;
465    use crate::task::test_utils::helper_make_local_actor;
466
467    fn build_test_chunk(size: u64) -> StreamChunk {
468        let ops = vec![Op::Insert; size as usize];
469        StreamChunk::new(ops, vec![])
470    }
471
    /// Verifies [`BufferChunks`]: consecutive small chunks are merged up to the chunk size,
    /// and a watermark or barrier flushes any buffered rows before being emitted itself.
    #[tokio::test]
    async fn test_buffer_chunks() {
        let test_env = LocalBarrierTestEnv::for_test().await;

        let (tx, rx) = channel_for_test();
        let input = LocalInput::new(rx, 1).boxed_input();
        let mut buffer = BufferChunks::new(input, 100, Schema::new(vec![]));

        // Send a chunk
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 10);
        });

        // Send 2 chunks and expect them to be merged.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });

        // Send a watermark.
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Send 2 chunks before a watermark. Expect the 2 chunks to be merged and the watermark to be emitted.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Send a barrier.
        let barrier = Barrier::new_test_barrier(test_epoch(1));
        test_env.inject_barrier(&barrier, [2]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });

        // Send 2 chunks before a barrier. Expect the 2 chunks to be merged and the barrier to be emitted.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        let barrier = Barrier::new_test_barrier(test_epoch(2));
        test_env.inject_barrier(&barrier, [2]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(2));
        });
    }
567
    /// Drives a `MergeExecutor` with `CHANNEL_NUMBER` upstream channels, each interleaving
    /// chunks, watermarks, and aligned barriers, and checks that the merged output contains
    /// the expected merged chunks, aligned watermarks, barriers, and a final stop barrier.
    #[tokio::test]
    async fn test_merger() {
        const CHANNEL_NUMBER: usize = 10;
        let mut txs = Vec::with_capacity(CHANNEL_NUMBER);
        let mut rxs = Vec::with_capacity(CHANNEL_NUMBER);
        for _i in 0..CHANNEL_NUMBER {
            let (tx, rx) = channel_for_test();
            txs.push(tx);
            rxs.push(rx);
        }
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let actor_id = 233;
        let mut handles = Vec::with_capacity(CHANNEL_NUMBER);

        // Pre-inject a chain of barriers at epochs 10, 20, ..., 990, each linked to the
        // previous one, followed by a final stop barrier at epoch 1000.
        let epochs = (10..1000u64)
            .step_by(10)
            .map(|idx| (idx, test_epoch(idx)))
            .collect_vec();
        let mut prev_epoch = 0;
        let prev_epoch = &mut prev_epoch;
        let barriers: HashMap<_, _> = epochs
            .iter()
            .map(|(_, epoch)| {
                let barrier = Barrier::with_prev_epoch_for_test(*epoch, *prev_epoch);
                *prev_epoch = *epoch;
                barrier_test_env.inject_barrier(&barrier, [actor_id]);
                (*epoch, barrier)
            })
            .collect();
        let b2 = Barrier::with_prev_epoch_for_test(test_epoch(1000), *prev_epoch)
            .with_mutation(Mutation::Stop(StopMutation::default()));
        barrier_test_env.inject_barrier(&b2, [actor_id]);
        barrier_test_env.flush_all_events().await;

        // Each sender task alternates a chunk (every other epoch) or a watermark,
        // then the aligned barrier for that epoch.
        for (tx_id, tx) in txs.into_iter().enumerate() {
            let epochs = epochs.clone();
            let barriers = barriers.clone();
            let b2 = b2.clone();
            let handle = tokio::spawn(async move {
                for (idx, epoch) in epochs {
                    if idx % 20 == 0 {
                        tx.send(Message::Chunk(build_test_chunk(10)).into())
                            .await
                            .unwrap();
                    } else {
                        tx.send(
                            Message::Watermark(Watermark {
                                col_idx: (idx as usize / 20 + tx_id) % CHANNEL_NUMBER,
                                data_type: DataType::Int64,
                                val: ScalarImpl::Int64(idx as i64),
                            })
                            .into(),
                        )
                        .await
                        .unwrap();
                    }
                    tx.send(Message::Barrier(barriers[&epoch].clone().into_dispatcher()).into())
                        .await
                        .unwrap();
                    sleep(Duration::from_millis(1)).await;
                }
                tx.send(Message::Barrier(b2.clone().into_dispatcher()).into())
                    .await
                    .unwrap();
            });
            handles.push(handle);
        }

        let merger = MergeExecutor::for_test(
            actor_id,
            rxs,
            barrier_test_env.local_barrier_manager.clone(),
            Schema::new(vec![]),
            100,
            None,
        );
        let mut merger = merger.boxed().execute();
        for (idx, epoch) in epochs {
            if idx % 20 == 0 {
                // expect 1 or more chunks with 100 rows in total
                let mut count = 0usize;
                while count < 100 {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
                        count += chunk.ops().len();
                    });
                }
                assert_eq!(count, 100);
            } else if idx as usize / 20 >= CHANNEL_NUMBER - 1 {
                // expect n watermarks
                for _ in 0..CHANNEL_NUMBER {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
                        assert_eq!(watermark.val, ScalarImpl::Int64((idx - 20 * (CHANNEL_NUMBER as u64 - 1)) as i64));
                    });
                }
            }
            // expect a barrier
            assert_matches!(merger.next().await.unwrap().unwrap(), Message::Barrier(Barrier{epoch:barrier_epoch,mutation:_,..}) => {
                assert_eq!(barrier_epoch.curr, epoch);
            });
        }
        // The merger terminates with the stop barrier.
        assert_matches!(
            merger.next().await.unwrap().unwrap(),
            Message::Barrier(Barrier {
                mutation,
                ..
            }) if mutation.as_deref().unwrap().is_stop()
        );

        for handle in handles {
            handle.await.unwrap();
        }
    }
680
    /// Verifies that a merge executor applies an `Update` barrier mutation: the `old`
    /// upstream is removed, the `new` upstream is added (and must deliver the same barrier
    /// before it is forwarded), while `untouched` keeps working throughout.
    #[tokio::test]
    async fn test_configuration_change() {
        let actor_id = 233;
        let (untouched, old, new) = (234, 235, 238); // upstream actors
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let metrics = Arc::new(StreamingMetrics::unused());

        // untouched -> actor_id
        // old -> actor_id
        // new -> actor_id

        let (upstream_fragment_id, fragment_id) = (10, 18);

        let inputs: Vec<_> =
            try_join_all([untouched, old].into_iter().map(async |upstream_actor_id| {
                new_input(
                    &barrier_test_env.local_barrier_manager,
                    metrics.clone(),
                    actor_id,
                    fragment_id,
                    &helper_make_local_actor(upstream_actor_id),
                    upstream_fragment_id,
                )
                .await
            }))
            .await
            .unwrap();

        // The update: add `new`, remove `old`, keep the same upstream fragment.
        let merge_updates = maplit::hashmap! {
            (actor_id, upstream_fragment_id) => MergeUpdate {
                actor_id,
                upstream_fragment_id,
                new_upstream_fragment_id: None,
                added_upstream_actors: vec![helper_make_local_actor(new)],
                removed_upstream_actor_id: vec![old],
            }
        };

        let b1 = Barrier::new_test_barrier(test_epoch(1)).with_mutation(Mutation::Update(
            UpdateMutation {
                merges: merge_updates,
                ..Default::default()
            },
        ));
        barrier_test_env.inject_barrier(&b1, [actor_id]);
        barrier_test_env.flush_all_events().await;

        let barrier_rx = barrier_test_env
            .local_barrier_manager
            .subscribe_barrier(actor_id);
        let actor_ctx = ActorContext::for_test(actor_id);
        let upstream = MergeExecutor::new_select_receiver(inputs, &metrics, &actor_ctx);

        let mut merge = MergeExecutor::new(
            actor_ctx,
            fragment_id,
            upstream_fragment_id,
            upstream,
            barrier_test_env.local_barrier_manager.clone(),
            metrics.clone(),
            barrier_rx,
            100,
            Schema::new(vec![]),
        )
        .boxed()
        .execute();

        let mut txs = HashMap::new();
        // Send `$msg` to each actor's upstream channel.
        macro_rules! send {
            ($actors:expr, $msg:expr) => {
                for actor in $actors {
                    txs.get(&actor).unwrap().send($msg).await.unwrap();
                }
            };
        }

        // Assert that the merger has nothing ready to emit right now.
        macro_rules! assert_recv_pending {
            () => {
                assert!(
                    merge
                        .next()
                        .now_or_never()
                        .flatten()
                        .transpose()
                        .unwrap()
                        .is_none()
                );
            };
        }
        macro_rules! recv {
            () => {
                merge.next().await.transpose().unwrap()
            };
        }

        // Collect the local output channel each upstream actor created toward `actor_id`.
        macro_rules! collect_upstream_tx {
            ($actors:expr) => {
                for upstream_id in $actors {
                    let mut output_requests = barrier_test_env
                        .take_pending_new_output_requests(upstream_id)
                        .await;
                    assert_eq!(output_requests.len(), 1);
                    let (downstream_actor_id, request) = output_requests.pop().unwrap();
                    assert_eq!(actor_id, downstream_actor_id);
                    let NewOutputRequest::Local(tx) = request else {
                        unreachable!()
                    };
                    txs.insert(upstream_id, tx);
                }
            };
        }

        // 1. Nothing should be emitted before any upstream sends.
        assert_recv_pending!();
        barrier_test_env.flush_all_events().await;

        // 2. Take downstream receivers.
        collect_upstream_tx!([untouched, old]);

        // 3. Send a chunk.
        send!([untouched, old], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality()); // We should be able to receive the chunk twice.
        assert_recv_pending!();

        // 4. Send the update barrier from the existing upstreams only.
        send!(
            [untouched, old],
            Message::Barrier(b1.clone().into_dispatcher()).into()
        );
        assert_recv_pending!(); // We should not receive the barrier, since merger is waiting for the new upstream new.

        collect_upstream_tx!([new]);

        send!([new], Message::Barrier(b1.clone().into_dispatcher()).into());
        recv!().unwrap().as_barrier().unwrap(); // We should now receive the barrier.

        // 5. Send a chunk.
        send!([untouched, new], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality()); // We should be able to receive the chunk twice.
        assert_recv_pending!();
    }
820
    /// A mock exchange server used by `test_stream_exchange_client`: its
    /// `get_stream` implementation replies with one default stream chunk
    /// followed by one test barrier (see the `ExchangeService` impl below).
    struct FakeExchangeService {
        // Shared flag flipped to `true` when `get_stream` is invoked, so the
        // test can assert that the RPC was actually made.
        rpc_called: Arc<AtomicBool>,
    }
824
825    fn exchange_client_test_barrier() -> crate::executor::Barrier {
826        Barrier::new_test_barrier(test_epoch(1))
827    }
828
829    #[async_trait::async_trait]
830    impl ExchangeService for FakeExchangeService {
831        type GetDataStream = ReceiverStream<std::result::Result<GetDataResponse, Status>>;
832        type GetStreamStream = ReceiverStream<std::result::Result<GetStreamResponse, Status>>;
833
834        async fn get_data(
835            &self,
836            _: Request<GetDataRequest>,
837        ) -> std::result::Result<Response<Self::GetDataStream>, Status> {
838            unimplemented!()
839        }
840
841        async fn get_stream(
842            &self,
843            _request: Request<Streaming<GetStreamRequest>>,
844        ) -> std::result::Result<Response<Self::GetStreamStream>, Status> {
845            let (tx, rx) = tokio::sync::mpsc::channel(10);
846            self.rpc_called.store(true, Ordering::SeqCst);
847            // send stream_chunk
848            let stream_chunk = StreamChunk::default().to_protobuf();
849            tx.send(Ok(GetStreamResponse {
850                message: Some(PbStreamMessageBatch {
851                    stream_message_batch: Some(
852                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::StreamChunk(
853                            stream_chunk,
854                        ),
855                    ),
856                }),
857                permits: Some(PbPermits::default()),
858            }))
859            .await
860            .unwrap();
861            // send barrier
862            let barrier = exchange_client_test_barrier();
863            tx.send(Ok(GetStreamResponse {
864                message: Some(PbStreamMessageBatch {
865                    stream_message_batch: Some(
866                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::BarrierBatch(
867                            BarrierBatch {
868                                barriers: vec![barrier.to_protobuf()],
869                            },
870                        ),
871                    ),
872                }),
873                permits: Some(PbPermits::default()),
874            }))
875            .await
876            .unwrap();
877            Ok(Response::new(ReceiverStream::new(rx)))
878        }
879    }
880
881    #[tokio::test]
882    async fn test_stream_exchange_client() {
883        let rpc_called = Arc::new(AtomicBool::new(false));
884        let server_run = Arc::new(AtomicBool::new(false));
885        let addr = "127.0.0.1:12348".parse().unwrap();
886
887        // Start a server.
888        let (shutdown_send, shutdown_recv) = tokio::sync::oneshot::channel();
889        let exchange_svc = ExchangeServiceServer::new(FakeExchangeService {
890            rpc_called: rpc_called.clone(),
891        });
892        let cp_server_run = server_run.clone();
893        let join_handle = tokio::spawn(async move {
894            cp_server_run.store(true, Ordering::SeqCst);
895            tonic::transport::Server::builder()
896                .add_service(exchange_svc)
897                .serve_with_shutdown(addr, async move {
898                    shutdown_recv.await.unwrap();
899                })
900                .await
901                .unwrap();
902        });
903
904        sleep(Duration::from_secs(1)).await;
905        assert!(server_run.load(Ordering::SeqCst));
906
907        let test_env = LocalBarrierTestEnv::for_test().await;
908
909        let remote_input = {
910            RemoteInput::new(
911                &test_env.local_barrier_manager,
912                addr.into(),
913                (0, 0),
914                (0, 0),
915                Arc::new(StreamingMetrics::unused()),
916            )
917            .await
918            .unwrap()
919        };
920
921        test_env.inject_barrier(&exchange_client_test_barrier(), [remote_input.id()]);
922
923        pin_mut!(remote_input);
924
925        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
926            let (ops, columns, visibility) = chunk.into_inner();
927            assert!(ops.is_empty());
928            assert!(columns.is_empty());
929            assert!(visibility.is_empty());
930        });
931        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
932            assert_eq!(barrier_epoch.curr, test_epoch(1));
933        });
934        assert!(rpc_called.load(Ordering::SeqCst));
935
936        shutdown_send.send(()).unwrap();
937        join_handle.await.unwrap();
938    }
939}