risingwave_stream/executor/
merge.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::VecDeque;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use risingwave_common::array::StreamChunkBuilder;
20use risingwave_common::config::MetricLevel;
21use tokio::sync::mpsc;
22use tokio::time::Instant;
23
24use super::exchange::input::BoxedActorInput;
25use super::*;
26use crate::executor::prelude::*;
27use crate::task::LocalBarrierManager;
28
/// The dynamic multi-upstream receiver type used by [`MergeExecutor`].
pub type SelectReceivers = DynamicReceivers<ActorId, ()>;

/// The upstream side of a merge-style executor.
pub(crate) enum MergeExecutorUpstream {
    /// Exactly one upstream input; turned into a `ReceiverExecutor`.
    Singleton(BoxedActorInput),
    /// Multiple upstream inputs to be merged; turned into a `MergeExecutor`.
    Merge(SelectReceivers),
}
35
/// Everything needed to build the merge-side executor of an actor; converted
/// into a concrete `ReceiverExecutor` or `MergeExecutor` via `into_executor`.
pub(crate) struct MergeExecutorInput {
    upstream: MergeExecutorUpstream,
    actor_context: ActorContextRef,
    upstream_fragment_id: UpstreamFragmentId,
    local_barrier_manager: LocalBarrierManager,
    executor_stats: Arc<StreamingMetrics>,
    pub(crate) info: ExecutorInfo,
    /// Max chunk size for the `StreamChunkBuilder` inside `MergeExecutor`.
    chunk_size: usize,
}
45
46impl MergeExecutorInput {
47    pub(crate) fn new(
48        upstream: MergeExecutorUpstream,
49        actor_context: ActorContextRef,
50        upstream_fragment_id: UpstreamFragmentId,
51        local_barrier_manager: LocalBarrierManager,
52        executor_stats: Arc<StreamingMetrics>,
53        info: ExecutorInfo,
54        chunk_size: usize,
55    ) -> Self {
56        Self {
57            upstream,
58            actor_context,
59            upstream_fragment_id,
60            local_barrier_manager,
61            executor_stats,
62            info,
63            chunk_size,
64        }
65    }
66
67    pub(crate) fn into_executor(self, barrier_rx: mpsc::UnboundedReceiver<Barrier>) -> Executor {
68        let fragment_id = self.actor_context.fragment_id;
69        let executor = match self.upstream {
70            MergeExecutorUpstream::Singleton(input) => ReceiverExecutor::new(
71                self.actor_context,
72                fragment_id,
73                self.upstream_fragment_id,
74                input,
75                self.local_barrier_manager,
76                self.executor_stats,
77                barrier_rx,
78            )
79            .boxed(),
80            MergeExecutorUpstream::Merge(inputs) => MergeExecutor::new(
81                self.actor_context,
82                fragment_id,
83                self.upstream_fragment_id,
84                inputs,
85                self.local_barrier_manager,
86                self.executor_stats,
87                barrier_rx,
88                self.chunk_size,
89                self.info.schema.clone(),
90            )
91            .boxed(),
92        };
93        (self.info, executor).into()
94    }
95}
96
97impl Stream for MergeExecutorInput {
98    type Item = DispatcherMessageStreamItem;
99
100    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
101        match &mut self.get_mut().upstream {
102            MergeExecutorUpstream::Singleton(input) => input.poll_next_unpin(cx),
103            MergeExecutorUpstream::Merge(inputs) => inputs.poll_next_unpin(cx),
104        }
105    }
106}
107
/// `MergeExecutor` merges data from multiple channels. Dataflow from one channel
/// will be stopped on barrier.
pub struct MergeExecutor {
    /// The context of the actor.
    actor_context: ActorContextRef,

    /// Upstream channels.
    upstreams: SelectReceivers,

    /// Belonged fragment id.
    fragment_id: FragmentId,

    /// Upstream fragment id.
    upstream_fragment_id: FragmentId,

    /// Handed to the `DispatchBarrierBuffer` in `execute_inner`.
    local_barrier_manager: LocalBarrierManager,

    /// Streaming metrics.
    metrics: Arc<StreamingMetrics>,

    /// Barrier receiver, consumed by the `DispatchBarrierBuffer` in `execute_inner`.
    barrier_rx: mpsc::UnboundedReceiver<Barrier>,

    /// Chunk size for the `StreamChunkBuilder`
    chunk_size: usize,

    /// Data types for the `StreamChunkBuilder`
    schema: Schema,
}
136
impl MergeExecutor {
    /// Creates a new `MergeExecutor` merging the given `upstreams`.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        ctx: ActorContextRef,
        fragment_id: FragmentId,
        upstream_fragment_id: FragmentId,
        upstreams: SelectReceivers,
        local_barrier_manager: LocalBarrierManager,
        metrics: Arc<StreamingMetrics>,
        barrier_rx: mpsc::UnboundedReceiver<Barrier>,
        chunk_size: usize,
        schema: Schema,
    ) -> Self {
        Self {
            actor_context: ctx,
            upstreams,
            fragment_id,
            upstream_fragment_id,
            local_barrier_manager,
            metrics,
            barrier_rx,
            chunk_size,
            schema,
        }
    }

    /// Test helper: builds a `MergeExecutor` over local permit channels, with
    /// unused metrics and arbitrary fragment ids.
    #[cfg(test)]
    pub fn for_test(
        actor_id: impl Into<ActorId>,
        inputs: Vec<super::exchange::permit::Receiver>,
        local_barrier_manager: crate::task::LocalBarrierManager,
        schema: Schema,
        chunk_size: usize,
        barrier_rx: Option<mpsc::UnboundedReceiver<Barrier>>,
    ) -> Self {
        let actor_id = actor_id.into();
        use super::exchange::input::LocalInput;
        use crate::executor::exchange::input::ActorInput;

        // Subscribe to barriers if the caller didn't provide a receiver.
        let barrier_rx =
            barrier_rx.unwrap_or_else(|| local_barrier_manager.subscribe_barrier(actor_id));

        let metrics = StreamingMetrics::unused();
        let actor_ctx = ActorContext::for_test(actor_id);
        // Each input's position doubles as its upstream actor id in tests.
        let upstream = Self::new_select_receiver(
            inputs
                .into_iter()
                .enumerate()
                .map(|(idx, input)| LocalInput::new(input, ActorId::new(idx as u32)).boxed_input())
                .collect(),
            &metrics,
            &actor_ctx,
        );

        Self::new(
            actor_ctx,
            514.into(),  // fragment_id: arbitrary test value
            1919.into(), // upstream_fragment_id: arbitrary test value
            upstream,
            local_barrier_manager,
            metrics.into(),
            barrier_rx,
            chunk_size,
            schema,
        )
    }

    /// Wraps the upstream inputs into a `SelectReceivers`, attaching the
    /// barrier-alignment duration metric only when the metric level is at
    /// least `Debug`.
    pub(crate) fn new_select_receiver(
        upstreams: Vec<BoxedActorInput>,
        metrics: &StreamingMetrics,
        actor_context: &ActorContext,
    ) -> SelectReceivers {
        let merge_barrier_align_duration = if metrics.level >= MetricLevel::Debug {
            Some(
                metrics
                    .merge_barrier_align_duration
                    .with_guarded_label_values(&[
                        &actor_context.id.to_string(),
                        &actor_context.fragment_id.to_string(),
                    ]),
            )
        } else {
            None
        };

        // Futures of all active upstreams.
        SelectReceivers::new(upstreams, None, merge_barrier_align_duration)
    }

    /// The main loop: pulls messages from the (chunk-buffered) select of all
    /// upstreams, tracks input metrics, and applies merge-update mutations
    /// (adding/removing upstreams) carried by barriers.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    pub(crate) async fn execute_inner(mut self: Box<Self>) {
        let select_all = self.upstreams;
        // Coalesce upstream chunks up to `chunk_size` rows before emitting.
        let select_all = BufferChunks::new(select_all, self.chunk_size, self.schema);
        let actor_id = self.actor_context.id;

        let mut metrics = self.metrics.new_actor_input_metrics(
            actor_id,
            self.fragment_id,
            self.upstream_fragment_id,
        );

        // Channels that're blocked by the barrier to align.
        let mut start_time = Instant::now();
        pin_mut!(select_all);

        let mut barrier_buffer = DispatchBarrierBuffer::new(
            self.barrier_rx,
            actor_id,
            self.upstream_fragment_id,
            self.local_barrier_manager,
            self.metrics.clone(),
            self.fragment_id,
        );

        loop {
            let msg = barrier_buffer.await_next_message(&mut select_all).await?;
            // Time spent blocked on upstream counts as input-buffer blocking.
            metrics
                .actor_input_buffer_blocking_duration_ns
                .inc_by(start_time.elapsed().as_nanos() as u64);

            let msg = match msg {
                DispatcherMessage::Watermark(watermark) => Message::Watermark(watermark),
                DispatcherMessage::Chunk(chunk) => {
                    metrics.actor_in_record_cnt.inc_by(chunk.cardinality() as _);
                    Message::Chunk(chunk)
                }
                DispatcherMessage::Barrier(barrier) => {
                    tracing::debug!(
                        target: "events::stream::barrier::path",
                        actor_id = %actor_id,
                        "receiver receives barrier from path: {:?}",
                        barrier.passed_actors
                    );
                    let (mut barrier, new_inputs) =
                        barrier_buffer.pop_barrier_with_inputs(barrier).await?;
                    barrier.passed_actors.push(actor_id);

                    if let Some(Mutation::Update(UpdateMutation { dispatchers, .. })) =
                        barrier.mutation.as_deref()
                        && select_all
                            .upstream_input_ids()
                            .any(|actor_id| dispatchers.contains_key(&actor_id))
                    {
                        // `Watermark` of upstream may become stale after downstream scaling.
                        select_all.flush_buffered_watermarks();
                    }

                    if let Some(update) =
                        barrier.as_update_merge(self.actor_context.id, self.upstream_fragment_id)
                    {
                        let new_upstream_fragment_id = update
                            .new_upstream_fragment_id
                            .unwrap_or(self.upstream_fragment_id);
                        // Switching upstream fragment replaces every current
                        // input; otherwise only the listed actors are removed.
                        let removed_upstream_actor_id: HashSet<_> =
                            if update.new_upstream_fragment_id.is_some() {
                                select_all.upstream_input_ids().collect()
                            } else {
                                update.removed_upstream_actor_id.iter().copied().collect()
                            };

                        // `Watermark` of upstream may become stale after upstream scaling.
                        select_all.flush_buffered_watermarks();

                        if let Some(new_upstreams) = new_inputs {
                            // Add the new upstreams to select.
                            select_all.add_upstreams_from(new_upstreams);
                        }

                        if !removed_upstream_actor_id.is_empty() {
                            // Remove upstreams.
                            select_all.remove_upstreams(&removed_upstream_actor_id);
                        }

                        self.upstream_fragment_id = new_upstream_fragment_id;
                        // Re-create metrics labeled with the new upstream fragment.
                        metrics = self.metrics.new_actor_input_metrics(
                            actor_id,
                            self.fragment_id,
                            self.upstream_fragment_id,
                        );
                    }

                    let is_stop = barrier.is_stop(actor_id);
                    let msg = Message::Barrier(barrier);
                    if is_stop {
                        // Emit the final barrier, then terminate the stream.
                        yield msg;
                        break;
                    }

                    msg
                }
            };

            yield msg;
            start_time = Instant::now();
        }
    }
}
334
impl Execute for MergeExecutor {
    fn execute(self: Box<Self>) -> BoxedMessageStream {
        // All executor logic lives in `execute_inner`; just box its stream.
        self.execute_inner().boxed()
    }
}
340
/// A wrapper that buffers the `StreamChunk`s from upstream until no more ready items are available.
/// Besides, any message other than `StreamChunk` will trigger the buffered `StreamChunk`s
/// to be emitted immediately along with the message itself.
struct BufferChunks<S: Stream> {
    /// The underlying message stream.
    inner: S,
    /// Accumulates rows from upstream chunks up to the configured chunk size.
    chunk_builder: StreamChunkBuilder,

    /// The items to be emitted. Whenever there's something here, we should return a `Poll::Ready` immediately.
    pending_items: VecDeque<S::Item>,
}
351
352impl<S: Stream> BufferChunks<S> {
353    pub(super) fn new(inner: S, chunk_size: usize, schema: Schema) -> Self {
354        assert!(chunk_size > 0);
355        let chunk_builder = StreamChunkBuilder::new(chunk_size, schema.data_types());
356        Self {
357            inner,
358            chunk_builder,
359            pending_items: VecDeque::new(),
360        }
361    }
362}
363
impl<S: Stream> std::ops::Deref for BufferChunks<S> {
    type Target = S;

    // Expose the inner stream's methods (e.g. `upstream_input_ids`) through
    // the wrapper.
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
371
impl<S: Stream> std::ops::DerefMut for BufferChunks<S> {
    // Mutable counterpart of `Deref`: lets callers mutate the inner stream
    // (e.g. add/remove upstreams) through the wrapper.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
377
impl<S: Stream> Stream for BufferChunks<S>
where
    S: Stream<Item = DispatcherMessageStreamItem> + Unpin,
{
    type Item = S::Item;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Emit anything already queued before polling upstream again.
            if let Some(item) = self.pending_items.pop_front() {
                return Poll::Ready(Some(item));
            }

            match self.inner.poll_next_unpin(cx) {
                Poll::Pending => {
                    // No more ready items upstream: flush the partially-built
                    // chunk (if any) rather than waiting for it to fill up.
                    return if let Some(chunk_out) = self.chunk_builder.take() {
                        Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
                    } else {
                        Poll::Pending
                    };
                }

                Poll::Ready(Some(result)) => {
                    if let Ok(MessageInner::Chunk(chunk)) = result {
                        // Fold the chunk's rows into the builder; every time it
                        // fills up, queue the completed chunk for emission.
                        for row in chunk.records() {
                            if let Some(chunk_out) = self.chunk_builder.append_record(row) {
                                self.pending_items
                                    .push_back(Ok(MessageInner::Chunk(chunk_out)));
                            }
                        }
                    } else {
                        // Non-chunk message (or error): flush the buffered
                        // chunk first so it is emitted before the message.
                        return if let Some(chunk_out) = self.chunk_builder.take() {
                            self.pending_items.push_back(result);
                            Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
                        } else {
                            Poll::Ready(Some(result))
                        };
                    }
                }

                Poll::Ready(None) => {
                    // See also the comments in `DynamicReceivers::poll_next`.
                    unreachable!("Merge should always have upstream inputs");
                }
            }
        }
    }
}
425
426#[cfg(test)]
427mod tests {
428    use std::sync::atomic::{AtomicBool, Ordering};
429    use std::time::Duration;
430
431    use assert_matches::assert_matches;
432    use futures::FutureExt;
433    use futures::future::try_join_all;
434    use risingwave_common::array::Op;
435    use risingwave_common::util::epoch::test_epoch;
436    use risingwave_pb::task_service::exchange_service_server::{
437        ExchangeService, ExchangeServiceServer,
438    };
439    use risingwave_pb::task_service::{
440        GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits,
441    };
442    use tokio::time::sleep;
443    use tokio_stream::wrappers::ReceiverStream;
444    use tonic::{Request, Response, Status, Streaming};
445
446    use super::*;
447    use crate::executor::exchange::input::{ActorInput, LocalInput, RemoteInput};
448    use crate::executor::exchange::permit::channel_for_test;
449    use crate::executor::{BarrierInner as Barrier, MessageInner as Message};
450    use crate::task::NewOutputRequest;
451    use crate::task::barrier_test_utils::LocalBarrierTestEnv;
452    use crate::task::test_utils::helper_make_local_actor;
453
454    fn build_test_chunk(size: u64) -> StreamChunk {
455        let ops = vec![Op::Insert; size as usize];
456        StreamChunk::new(ops, vec![])
457    }
458
    // Verifies `BufferChunks`: consecutive chunks are coalesced (up to the
    // 100-row limit), and any buffered rows are flushed before a watermark or
    // barrier is emitted.
    #[tokio::test]
    async fn test_buffer_chunks() {
        let test_env = LocalBarrierTestEnv::for_test().await;

        let (tx, rx) = channel_for_test();
        let input = LocalInput::new(rx, 1.into()).boxed_input();
        let mut buffer = BufferChunks::new(input, 100, Schema::new(vec![]));

        // Send a chunk
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 10);
        });

        // Send 2 chunks and expect them to be merged.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });

        // Send a watermark.
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Send 2 chunks before a watermark. Expect the 2 chunks to be merged and the watermark to be emitted.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Send a barrier.
        // NOTE: barriers must be injected into the test env before they are
        // sent through the channel.
        let barrier = Barrier::new_test_barrier(test_epoch(1));
        test_env.inject_barrier(&barrier, [2.into()]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });

        // Send 2 chunks before a barrier. Expect the 2 chunks to be merged and the barrier to be emitted.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        let barrier = Barrier::new_test_barrier(test_epoch(2));
        test_env.inject_barrier(&barrier, [2.into()]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(2));
        });
    }
554
    // End-to-end merge test: 10 upstream channels each interleave chunks /
    // watermarks with aligned barriers; the merger must coalesce chunks to
    // 100 rows, align barriers across channels, and only release a watermark
    // once all channels have advanced past it.
    #[tokio::test]
    async fn test_merger() {
        const CHANNEL_NUMBER: usize = 10;
        let mut txs = Vec::with_capacity(CHANNEL_NUMBER);
        let mut rxs = Vec::with_capacity(CHANNEL_NUMBER);
        for _i in 0..CHANNEL_NUMBER {
            let (tx, rx) = channel_for_test();
            txs.push(tx);
            rxs.push(rx);
        }
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let actor_id = 233.into();
        let mut handles = Vec::with_capacity(CHANNEL_NUMBER);

        // Pre-inject one barrier per epoch, plus a final stop barrier `b2`.
        let epochs = (10..1000u64)
            .step_by(10)
            .map(|idx| (idx, test_epoch(idx)))
            .collect_vec();
        let mut prev_epoch = 0;
        let prev_epoch = &mut prev_epoch;
        let barriers: HashMap<_, _> = epochs
            .iter()
            .map(|(_, epoch)| {
                let barrier = Barrier::with_prev_epoch_for_test(*epoch, *prev_epoch);
                *prev_epoch = *epoch;
                barrier_test_env.inject_barrier(&barrier, [actor_id]);
                (*epoch, barrier)
            })
            .collect();
        let b2 = Barrier::with_prev_epoch_for_test(test_epoch(1000), *prev_epoch)
            .with_mutation(Mutation::Stop(StopMutation::default()));
        barrier_test_env.inject_barrier(&b2, [actor_id]);
        barrier_test_env.flush_all_events().await;

        // One producer task per upstream channel.
        for (tx_id, tx) in txs.into_iter().enumerate() {
            let epochs = epochs.clone();
            let barriers = barriers.clone();
            let b2 = b2.clone();
            let handle = tokio::spawn(async move {
                for (idx, epoch) in epochs {
                    // Alternate between a chunk and a watermark per epoch,
                    // always followed by that epoch's barrier.
                    if idx % 20 == 0 {
                        tx.send(Message::Chunk(build_test_chunk(10)).into())
                            .await
                            .unwrap();
                    } else {
                        tx.send(
                            Message::Watermark(Watermark {
                                col_idx: (idx as usize / 20 + tx_id) % CHANNEL_NUMBER,
                                data_type: DataType::Int64,
                                val: ScalarImpl::Int64(idx as i64),
                            })
                            .into(),
                        )
                        .await
                        .unwrap();
                    }
                    tx.send(Message::Barrier(barriers[&epoch].clone().into_dispatcher()).into())
                        .await
                        .unwrap();
                    sleep(Duration::from_millis(1)).await;
                }
                tx.send(Message::Barrier(b2.clone().into_dispatcher()).into())
                    .await
                    .unwrap();
            });
            handles.push(handle);
        }

        let merger = MergeExecutor::for_test(
            actor_id,
            rxs,
            barrier_test_env.local_barrier_manager.clone(),
            Schema::new(vec![]),
            100,
            None,
        );
        let mut merger = merger.boxed().execute();
        for (idx, epoch) in epochs {
            if idx % 20 == 0 {
                // expect 1 or more chunks with 100 rows in total
                let mut count = 0usize;
                while count < 100 {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
                        count += chunk.ops().len();
                    });
                }
                assert_eq!(count, 100);
            } else if idx as usize / 20 >= CHANNEL_NUMBER - 1 {
                // expect n watermarks
                for _ in 0..CHANNEL_NUMBER {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
                        assert_eq!(watermark.val, ScalarImpl::Int64((idx - 20 * (CHANNEL_NUMBER as u64 - 1)) as i64));
                    });
                }
            }
            // expect a barrier
            assert_matches!(merger.next().await.unwrap().unwrap(), Message::Barrier(Barrier{epoch:barrier_epoch,mutation:_,..}) => {
                assert_eq!(barrier_epoch.curr, epoch);
            });
        }
        // The final barrier must carry the stop mutation.
        assert_matches!(
            merger.next().await.unwrap().unwrap(),
            Message::Barrier(Barrier {
                mutation,
                ..
            }) if mutation.as_deref().unwrap().is_stop()
        );

        for handle in handles {
            handle.await.unwrap();
        }
    }
667
    // Exercises a merge-update barrier: upstream `old` is removed and `new`
    // is added, while `untouched` keeps flowing. The barrier must not pass
    // through until the newly-added upstream has also delivered it.
    #[tokio::test]
    async fn test_configuration_change() {
        let actor_id = 233.into();
        let (untouched, old, new) = (234.into(), 235.into(), 238.into()); // upstream actors
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let metrics = Arc::new(StreamingMetrics::unused());

        // untouched -> actor_id
        // old -> actor_id
        // new -> actor_id

        let (upstream_fragment_id, fragment_id) = (10.into(), 18.into());

        let inputs: Vec<_> =
            try_join_all([untouched, old].into_iter().map(async |upstream_actor_id| {
                new_input(
                    &barrier_test_env.local_barrier_manager,
                    metrics.clone(),
                    actor_id,
                    fragment_id,
                    &helper_make_local_actor(upstream_actor_id),
                    upstream_fragment_id,
                )
                .await
            }))
            .await
            .unwrap();

        // The update barrier adds `new` and removes `old` for this merge.
        let merge_updates = maplit::hashmap! {
            (actor_id, upstream_fragment_id) => MergeUpdate {
                actor_id,
                upstream_fragment_id,
                new_upstream_fragment_id: None,
                added_upstream_actors: vec![helper_make_local_actor(new)],
                removed_upstream_actor_id: vec![old],
            }
        };

        let b1 = Barrier::new_test_barrier(test_epoch(1)).with_mutation(Mutation::Update(
            UpdateMutation {
                merges: merge_updates,
                ..Default::default()
            },
        ));
        barrier_test_env.inject_barrier(&b1, [actor_id]);
        barrier_test_env.flush_all_events().await;

        let barrier_rx = barrier_test_env
            .local_barrier_manager
            .subscribe_barrier(actor_id);
        let actor_ctx = ActorContext::for_test(actor_id);
        let upstream = MergeExecutor::new_select_receiver(inputs, &metrics, &actor_ctx);

        let mut merge = MergeExecutor::new(
            actor_ctx,
            fragment_id,
            upstream_fragment_id,
            upstream,
            barrier_test_env.local_barrier_manager.clone(),
            metrics.clone(),
            barrier_rx,
            100,
            Schema::new(vec![]),
        )
        .boxed()
        .execute();

        let mut txs = HashMap::new();
        // Sends `$msg` to every actor's sender in `$actors`.
        macro_rules! send {
            ($actors:expr, $msg:expr) => {
                for actor in $actors {
                    txs.get(&actor).unwrap().send($msg).await.unwrap();
                }
            };
        }

        // Asserts the merger currently has nothing ready to emit.
        macro_rules! assert_recv_pending {
            () => {
                assert!(
                    merge
                        .next()
                        .now_or_never()
                        .flatten()
                        .transpose()
                        .unwrap()
                        .is_none()
                );
            };
        }
        macro_rules! recv {
            () => {
                merge.next().await.transpose().unwrap()
            };
        }

        // Takes the pending local output requests of `$actors` and keeps
        // their senders in `txs`.
        macro_rules! collect_upstream_tx {
            ($actors:expr) => {
                for upstream_id in $actors {
                    let mut output_requests = barrier_test_env
                        .take_pending_new_output_requests(upstream_id)
                        .await;
                    assert_eq!(output_requests.len(), 1);
                    let (downstream_actor_id, request) = output_requests.pop().unwrap();
                    assert_eq!(downstream_actor_id, actor_id);
                    let NewOutputRequest::Local(tx) = request else {
                        unreachable!()
                    };
                    txs.insert(upstream_id, tx);
                }
            };
        }

        assert_recv_pending!();
        barrier_test_env.flush_all_events().await;

        // 2. Take downstream receivers.
        collect_upstream_tx!([untouched, old]);

        // 3. Send a chunk.
        send!([untouched, old], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality()); // We should be able to receive the chunk twice.
        assert_recv_pending!();

        send!(
            [untouched, old],
            Message::Barrier(b1.clone().into_dispatcher()).into()
        );
        assert_recv_pending!(); // We should not receive the barrier, since merger is waiting for the new upstream new.

        collect_upstream_tx!([new]);

        send!([new], Message::Barrier(b1.clone().into_dispatcher()).into());
        recv!().unwrap().as_barrier().unwrap(); // We should now receive the barrier.

        // 5. Send a chunk.
        send!([untouched, new], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality()); // We should be able to receive the chunk twice.
        assert_recv_pending!();
    }
807
    /// Stub exchange server used to test the remote-input exchange client.
    struct FakeExchangeService {
        // Set to `true` once `get_stream` has been invoked.
        rpc_called: Arc<AtomicBool>,
    }
811
    /// The barrier that `FakeExchangeService::get_stream` sends after its chunk.
    fn exchange_client_test_barrier() -> crate::executor::Barrier {
        Barrier::new_test_barrier(test_epoch(1))
    }
815
    // Serves a canned stream: one empty `StreamChunk` followed by one barrier.
    #[async_trait::async_trait]
    impl ExchangeService for FakeExchangeService {
        type GetDataStream = ReceiverStream<std::result::Result<GetDataResponse, Status>>;
        type GetStreamStream = ReceiverStream<std::result::Result<GetStreamResponse, Status>>;

        // Batch-exchange path; not exercised by these tests.
        async fn get_data(
            &self,
            _: Request<GetDataRequest>,
        ) -> std::result::Result<Response<Self::GetDataStream>, Status> {
            unimplemented!()
        }

        async fn get_stream(
            &self,
            _request: Request<Streaming<GetStreamRequest>>,
        ) -> std::result::Result<Response<Self::GetStreamStream>, Status> {
            let (tx, rx) = tokio::sync::mpsc::channel(10);
            // Record that the client actually reached the server.
            self.rpc_called.store(true, Ordering::SeqCst);
            // send stream_chunk
            let stream_chunk = StreamChunk::default().to_protobuf();
            tx.send(Ok(GetStreamResponse {
                message: Some(PbStreamMessageBatch {
                    stream_message_batch: Some(
                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::StreamChunk(
                            stream_chunk,
                        ),
                    ),
                }),
                permits: Some(PbPermits::default()),
            }))
            .await
            .unwrap();
            // send barrier
            let barrier = exchange_client_test_barrier();
            tx.send(Ok(GetStreamResponse {
                message: Some(PbStreamMessageBatch {
                    stream_message_batch: Some(
                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::BarrierBatch(
                            BarrierBatch {
                                barriers: vec![barrier.to_protobuf()],
                            },
                        ),
                    ),
                }),
                permits: Some(PbPermits::default()),
            }))
            .await
            .unwrap();
            Ok(Response::new(ReceiverStream::new(rx)))
        }
    }
867
868    #[tokio::test]
869    async fn test_stream_exchange_client() {
870        let rpc_called = Arc::new(AtomicBool::new(false));
871        let server_run = Arc::new(AtomicBool::new(false));
872        let addr = "127.0.0.1:12348".parse().unwrap();
873
874        // Start a server.
875        let (shutdown_send, shutdown_recv) = tokio::sync::oneshot::channel();
876        let exchange_svc = ExchangeServiceServer::new(FakeExchangeService {
877            rpc_called: rpc_called.clone(),
878        });
879        let cp_server_run = server_run.clone();
880        let join_handle = tokio::spawn(async move {
881            cp_server_run.store(true, Ordering::SeqCst);
882            tonic::transport::Server::builder()
883                .add_service(exchange_svc)
884                .serve_with_shutdown(addr, async move {
885                    shutdown_recv.await.unwrap();
886                })
887                .await
888                .unwrap();
889        });
890
891        sleep(Duration::from_secs(1)).await;
892        assert!(server_run.load(Ordering::SeqCst));
893
894        let test_env = LocalBarrierTestEnv::for_test().await;
895
896        let remote_input = {
897            RemoteInput::new(
898                &test_env.local_barrier_manager,
899                addr.into(),
900                (0.into(), 0.into()),
901                (0.into(), 0.into()),
902                Arc::new(StreamingMetrics::unused()),
903            )
904            .await
905            .unwrap()
906        };
907
908        test_env.inject_barrier(&exchange_client_test_barrier(), [remote_input.id()]);
909
910        pin_mut!(remote_input);
911
912        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
913            let (ops, columns, visibility) = chunk.into_inner();
914            assert!(ops.is_empty());
915            assert!(columns.is_empty());
916            assert!(visibility.is_empty());
917        });
918        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
919            assert_eq!(barrier_epoch.curr, test_epoch(1));
920        });
921        assert!(rpc_called.load(Ordering::SeqCst));
922
923        shutdown_send.send(()).unwrap();
924        join_handle.await.unwrap();
925    }
926}