risingwave_stream/executor/
merge.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::VecDeque;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use anyhow::Context as _;
20use futures::future::try_join_all;
21use risingwave_common::array::StreamChunkBuilder;
22use risingwave_common::config::MetricLevel;
23use tokio::sync::mpsc;
24use tokio::time::Instant;
25
26use super::exchange::input::BoxedActorInput;
27use super::*;
28use crate::executor::exchange::input::{
29    assert_equal_dispatcher_barrier, new_input, process_dispatcher_msg,
30};
31use crate::executor::prelude::*;
32use crate::task::LocalBarrierManager;
33
/// Alias for [`DynamicReceivers`] instantiated with [`ActorId`] and `()`.
///
/// This is the receiver set that [`MergeExecutor`] polls for upstream messages.
pub type SelectReceivers = DynamicReceivers<ActorId, ()>;
35
/// The upstream side of a [`MergeExecutorInput`].
///
/// `MergeExecutorInput::into_executor` builds a `ReceiverExecutor` from `Singleton` and a
/// `MergeExecutor` from `Merge`.
pub(crate) enum MergeExecutorUpstream {
    /// A single boxed upstream input.
    Singleton(BoxedActorInput),
    /// A set of upstream receivers that are polled together.
    Merge(SelectReceivers),
}
40
/// An upstream input bundled with everything required to build the matching executor.
///
/// It can be polled directly as a `Stream` of dispatcher messages, or converted into an
/// `Executor` with `into_executor`.
pub(crate) struct MergeExecutorInput {
    /// The upstream to read from: either a single input or a set of receivers.
    upstream: MergeExecutorUpstream,
    /// Context of the actor that owns this input.
    actor_context: ActorContextRef,
    /// Upstream fragment id, forwarded to the executor built by `into_executor`.
    upstream_fragment_id: UpstreamFragmentId,
    /// Forwarded to the executor built by `into_executor`.
    local_barrier_manager: LocalBarrierManager,
    /// Streaming metrics, forwarded to the executor built by `into_executor`.
    executor_stats: Arc<StreamingMetrics>,
    /// Executor info; its schema is cloned when a `MergeExecutor` is built.
    pub(crate) info: ExecutorInfo,
    /// Chunk size passed to `MergeExecutor`; unused when the upstream is a singleton.
    chunk_size: usize,
}
50
51impl MergeExecutorInput {
52    pub(crate) fn new(
53        upstream: MergeExecutorUpstream,
54        actor_context: ActorContextRef,
55        upstream_fragment_id: UpstreamFragmentId,
56        local_barrier_manager: LocalBarrierManager,
57        executor_stats: Arc<StreamingMetrics>,
58        info: ExecutorInfo,
59        chunk_size: usize,
60    ) -> Self {
61        Self {
62            upstream,
63            actor_context,
64            upstream_fragment_id,
65            local_barrier_manager,
66            executor_stats,
67            info,
68            chunk_size,
69        }
70    }
71
72    pub(crate) fn into_executor(self, barrier_rx: mpsc::UnboundedReceiver<Barrier>) -> Executor {
73        let fragment_id = self.actor_context.fragment_id;
74        let executor = match self.upstream {
75            MergeExecutorUpstream::Singleton(input) => ReceiverExecutor::new(
76                self.actor_context,
77                fragment_id,
78                self.upstream_fragment_id,
79                input,
80                self.local_barrier_manager,
81                self.executor_stats,
82                barrier_rx,
83            )
84            .boxed(),
85            MergeExecutorUpstream::Merge(inputs) => MergeExecutor::new(
86                self.actor_context,
87                fragment_id,
88                self.upstream_fragment_id,
89                inputs,
90                self.local_barrier_manager,
91                self.executor_stats,
92                barrier_rx,
93                self.chunk_size,
94                self.info.schema.clone(),
95            )
96            .boxed(),
97        };
98        (self.info, executor).into()
99    }
100}
101
102impl Stream for MergeExecutorInput {
103    type Item = DispatcherMessageStreamItem;
104
105    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
106        match &mut self.get_mut().upstream {
107            MergeExecutorUpstream::Singleton(input) => input.poll_next_unpin(cx),
108            MergeExecutorUpstream::Merge(inputs) => inputs.poll_next_unpin(cx),
109        }
110    }
111}
112
/// `MergeExecutor` merges data from multiple channels. Dataflow from one channel
/// will be stopped on barrier.
pub struct MergeExecutor {
    /// The context of the actor.
    actor_context: ActorContextRef,

    /// Upstream channels.
    upstreams: SelectReceivers,

    /// Id of the fragment this executor belongs to.
    fragment_id: FragmentId,

    /// Upstream fragment id. Replaced when a merge update carries a new upstream fragment id.
    upstream_fragment_id: FragmentId,

    /// Used to create inputs for upstream actors added by a merge update.
    local_barrier_manager: LocalBarrierManager,

    /// Streaming metrics.
    metrics: Arc<StreamingMetrics>,

    /// Barrier receiver handed to `process_dispatcher_msg` for every upstream message.
    barrier_rx: mpsc::UnboundedReceiver<Barrier>,

    /// Chunk size for the `StreamChunkBuilder`
    chunk_size: usize,

    /// Data types for the `StreamChunkBuilder`
    schema: Schema,
}
141
142impl MergeExecutor {
143    #[allow(clippy::too_many_arguments)]
144    pub fn new(
145        ctx: ActorContextRef,
146        fragment_id: FragmentId,
147        upstream_fragment_id: FragmentId,
148        upstreams: SelectReceivers,
149        local_barrier_manager: LocalBarrierManager,
150        metrics: Arc<StreamingMetrics>,
151        barrier_rx: mpsc::UnboundedReceiver<Barrier>,
152        chunk_size: usize,
153        schema: Schema,
154    ) -> Self {
155        Self {
156            actor_context: ctx,
157            upstreams,
158            fragment_id,
159            upstream_fragment_id,
160            local_barrier_manager,
161            metrics,
162            barrier_rx,
163            chunk_size,
164            schema,
165        }
166    }
167
168    #[cfg(test)]
169    pub fn for_test(
170        actor_id: ActorId,
171        inputs: Vec<super::exchange::permit::Receiver>,
172        local_barrier_manager: crate::task::LocalBarrierManager,
173        schema: Schema,
174        chunk_size: usize,
175    ) -> Self {
176        use super::exchange::input::LocalInput;
177        use crate::executor::exchange::input::ActorInput;
178
179        let barrier_rx = local_barrier_manager.subscribe_barrier(actor_id);
180
181        let metrics = StreamingMetrics::unused();
182        let actor_ctx = ActorContext::for_test(actor_id);
183        let upstream = Self::new_select_receiver(
184            inputs
185                .into_iter()
186                .enumerate()
187                .map(|(idx, input)| LocalInput::new(input, idx as ActorId).boxed_input())
188                .collect(),
189            &metrics,
190            &actor_ctx,
191        );
192
193        Self::new(
194            actor_ctx,
195            514,
196            1919,
197            upstream,
198            local_barrier_manager,
199            metrics.into(),
200            barrier_rx,
201            chunk_size,
202            schema,
203        )
204    }
205
206    pub(crate) fn new_select_receiver(
207        upstreams: Vec<BoxedActorInput>,
208        metrics: &StreamingMetrics,
209        actor_context: &ActorContext,
210    ) -> SelectReceivers {
211        let merge_barrier_align_duration = if metrics.level >= MetricLevel::Debug {
212            Some(
213                metrics
214                    .merge_barrier_align_duration
215                    .with_guarded_label_values(&[
216                        &actor_context.id.to_string(),
217                        &actor_context.fragment_id.to_string(),
218                    ]),
219            )
220        } else {
221            None
222        };
223
224        // Futures of all active upstreams.
225        SelectReceivers::new(upstreams, None, merge_barrier_align_duration.clone())
226    }
227
228    #[try_stream(ok = Message, error = StreamExecutorError)]
229    pub(crate) async fn execute_inner(mut self: Box<Self>) {
230        let select_all = self.upstreams;
231        let select_all = BufferChunks::new(select_all, self.chunk_size, self.schema);
232        let actor_id = self.actor_context.id;
233
234        let mut metrics = self.metrics.new_actor_input_metrics(
235            actor_id,
236            self.fragment_id,
237            self.upstream_fragment_id,
238        );
239
240        // Channels that're blocked by the barrier to align.
241        let mut start_time = Instant::now();
242        pin_mut!(select_all);
243        while let Some(msg) = select_all.next().await {
244            metrics
245                .actor_input_buffer_blocking_duration_ns
246                .inc_by(start_time.elapsed().as_nanos() as u64);
247            let msg: DispatcherMessage = msg?;
248            let mut msg: Message = process_dispatcher_msg(msg, &mut self.barrier_rx).await?;
249
250            match &mut msg {
251                Message::Watermark(_) => {
252                    // Do nothing.
253                }
254                Message::Chunk(chunk) => {
255                    metrics.actor_in_record_cnt.inc_by(chunk.cardinality() as _);
256                }
257                Message::Barrier(barrier) => {
258                    tracing::debug!(
259                        target: "events::stream::barrier::path",
260                        actor_id = actor_id,
261                        "receiver receives barrier from path: {:?}",
262                        barrier.passed_actors
263                    );
264                    barrier.passed_actors.push(actor_id);
265
266                    if let Some(Mutation::Update(UpdateMutation { dispatchers, .. })) =
267                        barrier.mutation.as_deref()
268                        && select_all
269                            .upstream_input_ids()
270                            .any(|actor_id| dispatchers.contains_key(&actor_id))
271                    {
272                        // `Watermark` of upstream may become stale after downstream scaling.
273                        select_all.flush_buffered_watermarks();
274                    }
275
276                    if let Some(update) =
277                        barrier.as_update_merge(self.actor_context.id, self.upstream_fragment_id)
278                    {
279                        let new_upstream_fragment_id = update
280                            .new_upstream_fragment_id
281                            .unwrap_or(self.upstream_fragment_id);
282                        let removed_upstream_actor_id: HashSet<_> =
283                            if update.new_upstream_fragment_id.is_some() {
284                                select_all.upstream_input_ids().collect()
285                            } else {
286                                update.removed_upstream_actor_id.iter().copied().collect()
287                            };
288
289                        // `Watermark` of upstream may become stale after upstream scaling.
290                        select_all.flush_buffered_watermarks();
291
292                        if !update.added_upstream_actors.is_empty() {
293                            // Create new upstreams receivers.
294                            let new_upstreams: Vec<_> = try_join_all(
295                                update.added_upstream_actors.iter().map(|upstream_actor| {
296                                    new_input(
297                                        &self.local_barrier_manager,
298                                        self.metrics.clone(),
299                                        self.actor_context.id,
300                                        self.fragment_id,
301                                        upstream_actor,
302                                        new_upstream_fragment_id,
303                                    )
304                                }),
305                            )
306                            .await
307                            .context("failed to create upstream receivers")?;
308
309                            // Poll the first barrier from the new upstreams. It must be the same as
310                            // the one we polled from original upstreams.
311                            let mut select_new = SelectReceivers::new(
312                                new_upstreams,
313                                None,
314                                select_all.merge_barrier_align_duration(),
315                            );
316                            let new_barrier = expect_first_barrier(&mut select_new).await?;
317                            assert_equal_dispatcher_barrier(barrier, &new_barrier);
318
319                            // Add the new upstreams to select.
320                            select_all.add_upstreams_from(select_new);
321                        }
322
323                        if !removed_upstream_actor_id.is_empty() {
324                            // Remove upstreams.
325                            select_all.remove_upstreams(&removed_upstream_actor_id);
326                        }
327
328                        self.upstream_fragment_id = new_upstream_fragment_id;
329                        metrics = self.metrics.new_actor_input_metrics(
330                            actor_id,
331                            self.fragment_id,
332                            self.upstream_fragment_id,
333                        );
334                    }
335
336                    if barrier.is_stop(actor_id) {
337                        yield msg;
338                        break;
339                    }
340                }
341            }
342
343            yield msg;
344            start_time = Instant::now();
345        }
346    }
347}
348
/// Runs the executor by boxing the stream produced by `execute_inner`.
impl Execute for MergeExecutor {
    fn execute(self: Box<Self>) -> BoxedMessageStream {
        self.execute_inner().boxed()
    }
}
354
/// A wrapper that buffers the `StreamChunk`s from upstream until no more ready items are available.
/// Besides, any message other than `StreamChunk` will trigger the buffered `StreamChunk`s
/// to be emitted immediately along with the message itself.
struct BufferChunks<S: Stream> {
    /// The wrapped upstream stream.
    inner: S,
    /// Accumulates the rows of incoming chunks into output chunks.
    chunk_builder: StreamChunkBuilder,

    /// The items to be emitted. Whenever there's something here, we should return a `Poll::Ready` immediately.
    pending_items: VecDeque<S::Item>,
}
365
366impl<S: Stream> BufferChunks<S> {
367    pub(super) fn new(inner: S, chunk_size: usize, schema: Schema) -> Self {
368        assert!(chunk_size > 0);
369        let chunk_builder = StreamChunkBuilder::new(chunk_size, schema.data_types());
370        Self {
371            inner,
372            chunk_builder,
373            pending_items: VecDeque::new(),
374        }
375    }
376}
377
/// Exposes the wrapped stream so that its methods can be called directly on the wrapper.
impl<S: Stream> std::ops::Deref for BufferChunks<S> {
    type Target = S;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
385
/// Mutable counterpart of the `Deref` impl: gives mutable access to the wrapped stream.
impl<S: Stream> std::ops::DerefMut for BufferChunks<S> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
391
392impl<S: Stream> Stream for BufferChunks<S>
393where
394    S: Stream<Item = DispatcherMessageStreamItem> + Unpin,
395{
396    type Item = S::Item;
397
398    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
399        loop {
400            if let Some(item) = self.pending_items.pop_front() {
401                return Poll::Ready(Some(item));
402            }
403
404            match self.inner.poll_next_unpin(cx) {
405                Poll::Pending => {
406                    return if let Some(chunk_out) = self.chunk_builder.take() {
407                        Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
408                    } else {
409                        Poll::Pending
410                    };
411                }
412
413                Poll::Ready(Some(result)) => {
414                    if let Ok(MessageInner::Chunk(chunk)) = result {
415                        for row in chunk.records() {
416                            if let Some(chunk_out) = self.chunk_builder.append_record(row) {
417                                self.pending_items
418                                    .push_back(Ok(MessageInner::Chunk(chunk_out)));
419                            }
420                        }
421                    } else {
422                        return if let Some(chunk_out) = self.chunk_builder.take() {
423                            self.pending_items.push_back(result);
424                            Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
425                        } else {
426                            Poll::Ready(Some(result))
427                        };
428                    }
429                }
430
431                Poll::Ready(None) => {
432                    // See also the comments in `SelectReceivers::poll_next`.
433                    unreachable!("SelectReceivers should never return None");
434                }
435            }
436        }
437    }
438}
439
440#[cfg(test)]
441mod tests {
442    use std::sync::atomic::{AtomicBool, Ordering};
443    use std::time::Duration;
444
445    use assert_matches::assert_matches;
446    use futures::FutureExt;
447    use futures::future::try_join_all;
448    use risingwave_common::array::Op;
449    use risingwave_common::util::epoch::test_epoch;
450    use risingwave_pb::task_service::exchange_service_server::{
451        ExchangeService, ExchangeServiceServer,
452    };
453    use risingwave_pb::task_service::{
454        GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits,
455    };
456    use tokio::time::sleep;
457    use tokio_stream::wrappers::ReceiverStream;
458    use tonic::{Request, Response, Status, Streaming};
459
460    use super::*;
461    use crate::executor::exchange::input::{ActorInput, LocalInput, RemoteInput};
462    use crate::executor::exchange::permit::channel_for_test;
463    use crate::executor::{BarrierInner as Barrier, MessageInner as Message};
464    use crate::task::NewOutputRequest;
465    use crate::task::barrier_test_utils::LocalBarrierTestEnv;
466    use crate::task::test_utils::helper_make_local_actor;
467
468    fn build_test_chunk(size: u64) -> StreamChunk {
469        let ops = vec![Op::Insert; size as usize];
470        StreamChunk::new(ops, vec![])
471    }
472
    /// Checks that `BufferChunks` merges consecutive chunks into one and emits the merged chunk
    /// before a following watermark or barrier.
    #[tokio::test]
    async fn test_buffer_chunks() {
        let test_env = LocalBarrierTestEnv::for_test().await;

        let (tx, rx) = channel_for_test();
        let input = LocalInput::new(rx, 1).boxed_input();
        let mut buffer = BufferChunks::new(input, 100, Schema::new(vec![]));

        // Send a chunk
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 10);
        });

        // Send 2 chunks and expect them to be merged.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });

        // Send a watermark.
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Send 2 chunks before a watermark. Expect the 2 chunks to be merged and the watermark to be emitted.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Send a barrier.
        let barrier = Barrier::new_test_barrier(test_epoch(1));
        test_env.inject_barrier(&barrier, [2]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });

        // Send 2 chunks before a barrier. Expect the 2 chunks to be merged and the barrier to be emitted.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        let barrier = Barrier::new_test_barrier(test_epoch(2));
        test_env.inject_barrier(&barrier, [2]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(2));
        });
    }
568
    /// Feeds `CHANNEL_NUMBER` upstream channels from separate tasks and checks what the merge
    /// executor emits for each epoch: 100 chunk rows or a batch of watermarks, followed by the
    /// barrier, and finally the stop barrier.
    #[tokio::test]
    async fn test_merger() {
        const CHANNEL_NUMBER: usize = 10;
        let mut txs = Vec::with_capacity(CHANNEL_NUMBER);
        let mut rxs = Vec::with_capacity(CHANNEL_NUMBER);
        for _i in 0..CHANNEL_NUMBER {
            let (tx, rx) = channel_for_test();
            txs.push(tx);
            rxs.push(rx);
        }
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let actor_id = 233;
        let mut handles = Vec::with_capacity(CHANNEL_NUMBER);

        // Epochs 10, 20, ..., 990, each paired with its index.
        let epochs = (10..1000u64)
            .step_by(10)
            .map(|idx| (idx, test_epoch(idx)))
            .collect_vec();
        let mut prev_epoch = 0;
        let prev_epoch = &mut prev_epoch;
        // Build and inject one barrier per epoch, chaining each to the previous epoch.
        let barriers: HashMap<_, _> = epochs
            .iter()
            .map(|(_, epoch)| {
                let barrier = Barrier::with_prev_epoch_for_test(*epoch, *prev_epoch);
                *prev_epoch = *epoch;
                barrier_test_env.inject_barrier(&barrier, [actor_id]);
                (*epoch, barrier)
            })
            .collect();
        // The final barrier carries a stop mutation.
        let b2 = Barrier::with_prev_epoch_for_test(test_epoch(1000), *prev_epoch)
            .with_mutation(Mutation::Stop(HashSet::default()));
        barrier_test_env.inject_barrier(&b2, [actor_id]);
        barrier_test_env.flush_all_events().await;

        // One sender task per channel: for every epoch send either a 10-row chunk or a
        // watermark, then the epoch's barrier; finish with the stop barrier.
        for (tx_id, tx) in txs.into_iter().enumerate() {
            let epochs = epochs.clone();
            let barriers = barriers.clone();
            let b2 = b2.clone();
            let handle = tokio::spawn(async move {
                for (idx, epoch) in epochs {
                    if idx % 20 == 0 {
                        tx.send(Message::Chunk(build_test_chunk(10)).into())
                            .await
                            .unwrap();
                    } else {
                        tx.send(
                            Message::Watermark(Watermark {
                                col_idx: (idx as usize / 20 + tx_id) % CHANNEL_NUMBER,
                                data_type: DataType::Int64,
                                val: ScalarImpl::Int64(idx as i64),
                            })
                            .into(),
                        )
                        .await
                        .unwrap();
                    }
                    tx.send(Message::Barrier(barriers[&epoch].clone().into_dispatcher()).into())
                        .await
                        .unwrap();
                    sleep(Duration::from_millis(1)).await;
                }
                tx.send(Message::Barrier(b2.clone().into_dispatcher()).into())
                    .await
                    .unwrap();
            });
            handles.push(handle);
        }

        let merger = MergeExecutor::for_test(
            actor_id,
            rxs,
            barrier_test_env.local_barrier_manager.clone(),
            Schema::new(vec![]),
            100,
        );
        let mut merger = merger.boxed().execute();
        for (idx, epoch) in epochs {
            if idx % 20 == 0 {
                // expect 1 or more chunks with 100 rows in total
                let mut count = 0usize;
                while count < 100 {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
                        count += chunk.ops().len();
                    });
                }
                assert_eq!(count, 100);
            } else if idx as usize / 20 >= CHANNEL_NUMBER - 1 {
                // expect n watermarks
                for _ in 0..CHANNEL_NUMBER {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
                        assert_eq!(watermark.val, ScalarImpl::Int64((idx - 20 * (CHANNEL_NUMBER as u64 - 1)) as i64));
                    });
                }
            }
            // expect a barrier
            assert_matches!(merger.next().await.unwrap().unwrap(), Message::Barrier(Barrier{epoch:barrier_epoch,mutation:_,..}) => {
                assert_eq!(barrier_epoch.curr, epoch);
            });
        }
        // The last message must be the stop barrier.
        assert_matches!(
            merger.next().await.unwrap().unwrap(),
            Message::Barrier(Barrier {
                mutation,
                ..
            }) if mutation.as_deref().unwrap().is_stop()
        );

        for handle in handles {
            handle.await.unwrap();
        }
    }
680
    /// Checks that a merge update carried by a barrier swaps upstream `old` for upstream `new`,
    /// and that the barrier is held back until the new upstream has delivered it as well.
    #[tokio::test]
    async fn test_configuration_change() {
        let actor_id = 233;
        let (untouched, old, new) = (234, 235, 238); // upstream actors
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let metrics = Arc::new(StreamingMetrics::unused());

        // untouched -> actor_id
        // old -> actor_id
        // new -> actor_id

        let (upstream_fragment_id, fragment_id) = (10, 18);

        // The executor starts with `untouched` and `old` as its upstreams.
        let inputs: Vec<_> =
            try_join_all([untouched, old].into_iter().map(async |upstream_actor_id| {
                new_input(
                    &barrier_test_env.local_barrier_manager,
                    metrics.clone(),
                    actor_id,
                    fragment_id,
                    &helper_make_local_actor(upstream_actor_id),
                    upstream_fragment_id,
                )
                .await
            }))
            .await
            .unwrap();

        // The update adds `new` and removes `old`, keeping the same upstream fragment.
        let merge_updates = maplit::hashmap! {
            (actor_id, upstream_fragment_id) => MergeUpdate {
                actor_id,
                upstream_fragment_id,
                new_upstream_fragment_id: None,
                added_upstream_actors: vec![helper_make_local_actor(new)],
                removed_upstream_actor_id: vec![old],
            }
        };

        let b1 = Barrier::new_test_barrier(test_epoch(1)).with_mutation(Mutation::Update(
            UpdateMutation {
                dispatchers: Default::default(),
                merges: merge_updates,
                vnode_bitmaps: Default::default(),
                dropped_actors: Default::default(),
                actor_splits: Default::default(),
                actor_new_dispatchers: Default::default(),
                actor_cdc_table_snapshot_splits: Default::default(),
            },
        ));
        barrier_test_env.inject_barrier(&b1, [actor_id]);
        barrier_test_env.flush_all_events().await;

        let barrier_rx = barrier_test_env
            .local_barrier_manager
            .subscribe_barrier(actor_id);
        let actor_ctx = ActorContext::for_test(actor_id);
        let upstream = MergeExecutor::new_select_receiver(inputs, &metrics, &actor_ctx);

        let mut merge = MergeExecutor::new(
            actor_ctx,
            fragment_id,
            upstream_fragment_id,
            upstream,
            barrier_test_env.local_barrier_manager.clone(),
            metrics.clone(),
            barrier_rx,
            100,
            Schema::new(vec![]),
        )
        .boxed()
        .execute();

        // Senders towards the executor, keyed by upstream actor id.
        let mut txs = HashMap::new();
        // Sends `$msg` through the sender of every listed upstream actor.
        macro_rules! send {
            ($actors:expr, $msg:expr) => {
                for actor in $actors {
                    txs.get(&actor).unwrap().send($msg).await.unwrap();
                }
            };
        }

        // Asserts that the executor has no message ready right now.
        macro_rules! assert_recv_pending {
            () => {
                assert!(
                    merge
                        .next()
                        .now_or_never()
                        .flatten()
                        .transpose()
                        .unwrap()
                        .is_none()
                );
            };
        }
        // Awaits the next message from the executor, panicking on an error.
        macro_rules! recv {
            () => {
                merge.next().await.transpose().unwrap()
            };
        }

        // Takes the single pending output request of each listed upstream actor and stores
        // its local sender in `txs`.
        macro_rules! collect_upstream_tx {
            ($actors:expr) => {
                for upstream_id in $actors {
                    let mut output_requests = barrier_test_env
                        .take_pending_new_output_requests(upstream_id)
                        .await;
                    assert_eq!(output_requests.len(), 1);
                    let (downstream_actor_id, request) = output_requests.pop().unwrap();
                    assert_eq!(actor_id, downstream_actor_id);
                    let NewOutputRequest::Local(tx) = request else {
                        unreachable!()
                    };
                    txs.insert(upstream_id, tx);
                }
            };
        }

        // 1. Nothing has been sent yet, so the executor must be pending.
        assert_recv_pending!();
        barrier_test_env.flush_all_events().await;

        // 2. Take the senders of the initial upstream actors.
        collect_upstream_tx!([untouched, old]);

        // 3. Send a chunk.
        send!([untouched, old], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality()); // The two 1-row chunks arrive merged into one chunk of cardinality 2.
        assert_recv_pending!();

        // 4. Send the barrier carrying the merge update.
        send!(
            [untouched, old],
            Message::Barrier(b1.clone().into_dispatcher()).into()
        );
        assert_recv_pending!(); // We should not receive the barrier, since the merger is waiting for the new upstream `new`.

        collect_upstream_tx!([new]);

        send!([new], Message::Barrier(b1.clone().into_dispatcher()).into());
        recv!().unwrap().as_barrier().unwrap(); // We should now receive the barrier.

        // 5. Send a chunk.
        send!([untouched, new], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality()); // The two 1-row chunks arrive merged into one chunk of cardinality 2.
        assert_recv_pending!();
    }
825
826    struct FakeExchangeService {
827        rpc_called: Arc<AtomicBool>,
828    }
829
830    fn exchange_client_test_barrier() -> crate::executor::Barrier {
831        Barrier::new_test_barrier(test_epoch(1))
832    }
833
834    #[async_trait::async_trait]
835    impl ExchangeService for FakeExchangeService {
836        type GetDataStream = ReceiverStream<std::result::Result<GetDataResponse, Status>>;
837        type GetStreamStream = ReceiverStream<std::result::Result<GetStreamResponse, Status>>;
838
839        async fn get_data(
840            &self,
841            _: Request<GetDataRequest>,
842        ) -> std::result::Result<Response<Self::GetDataStream>, Status> {
843            unimplemented!()
844        }
845
846        async fn get_stream(
847            &self,
848            _request: Request<Streaming<GetStreamRequest>>,
849        ) -> std::result::Result<Response<Self::GetStreamStream>, Status> {
850            let (tx, rx) = tokio::sync::mpsc::channel(10);
851            self.rpc_called.store(true, Ordering::SeqCst);
852            // send stream_chunk
853            let stream_chunk = StreamChunk::default().to_protobuf();
854            tx.send(Ok(GetStreamResponse {
855                message: Some(PbStreamMessageBatch {
856                    stream_message_batch: Some(
857                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::StreamChunk(
858                            stream_chunk,
859                        ),
860                    ),
861                }),
862                permits: Some(PbPermits::default()),
863            }))
864            .await
865            .unwrap();
866            // send barrier
867            let barrier = exchange_client_test_barrier();
868            tx.send(Ok(GetStreamResponse {
869                message: Some(PbStreamMessageBatch {
870                    stream_message_batch: Some(
871                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::BarrierBatch(
872                            BarrierBatch {
873                                barriers: vec![barrier.to_protobuf()],
874                            },
875                        ),
876                    ),
877                }),
878                permits: Some(PbPermits::default()),
879            }))
880            .await
881            .unwrap();
882            Ok(Response::new(ReceiverStream::new(rx)))
883        }
884    }
885
886    #[tokio::test]
887    async fn test_stream_exchange_client() {
888        let rpc_called = Arc::new(AtomicBool::new(false));
889        let server_run = Arc::new(AtomicBool::new(false));
890        let addr = "127.0.0.1:12348".parse().unwrap();
891
892        // Start a server.
893        let (shutdown_send, shutdown_recv) = tokio::sync::oneshot::channel();
894        let exchange_svc = ExchangeServiceServer::new(FakeExchangeService {
895            rpc_called: rpc_called.clone(),
896        });
897        let cp_server_run = server_run.clone();
898        let join_handle = tokio::spawn(async move {
899            cp_server_run.store(true, Ordering::SeqCst);
900            tonic::transport::Server::builder()
901                .add_service(exchange_svc)
902                .serve_with_shutdown(addr, async move {
903                    shutdown_recv.await.unwrap();
904                })
905                .await
906                .unwrap();
907        });
908
909        sleep(Duration::from_secs(1)).await;
910        assert!(server_run.load(Ordering::SeqCst));
911
912        let test_env = LocalBarrierTestEnv::for_test().await;
913
914        let remote_input = {
915            RemoteInput::new(
916                &test_env.local_barrier_manager,
917                addr.into(),
918                (0, 0),
919                (0, 0),
920                Arc::new(StreamingMetrics::unused()),
921            )
922            .await
923            .unwrap()
924        };
925
926        test_env.inject_barrier(&exchange_client_test_barrier(), [remote_input.id()]);
927
928        pin_mut!(remote_input);
929
930        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
931            let (ops, columns, visibility) = chunk.into_inner();
932            assert!(ops.is_empty());
933            assert!(columns.is_empty());
934            assert!(visibility.is_empty());
935        });
936        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
937            assert_eq!(barrier_epoch.curr, test_epoch(1));
938        });
939        assert!(rpc_called.load(Ordering::SeqCst));
940
941        shutdown_send.send(()).unwrap();
942        join_handle.await.unwrap();
943    }
944}