risingwave_stream/executor/
merge.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::VecDeque;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use risingwave_common::array::StreamChunkBuilder;
20use risingwave_common::config::MetricLevel;
21use tokio::sync::mpsc;
22use tokio::time::Instant;
23
24use super::exchange::input::BoxedActorInput;
25use super::*;
26use crate::executor::prelude::*;
27use crate::task::LocalBarrierManager;
28
/// Dynamic set of upstream receivers selected over by the merge executor,
/// keyed by the upstream `ActorId`.
pub type SelectReceivers = DynamicReceivers<ActorId, ()>;
30
/// The upstream side of a merge-style executor: either a single exchange input
/// (served by `ReceiverExecutor`) or multiple inputs to merge (served by
/// `MergeExecutor`). See `MergeExecutorInput::into_executor`.
pub(crate) enum MergeExecutorUpstream {
    /// Exactly one upstream input.
    Singleton(BoxedActorInput),
    /// Multiple upstream inputs, merged via `SelectReceivers`.
    Merge(SelectReceivers),
}
35
/// A not-yet-built merge-style executor. It carries everything needed to
/// construct the concrete executor later via [`MergeExecutorInput::into_executor`],
/// and is itself a [`Stream`] of dispatcher messages that polls whichever
/// upstream variant it holds.
pub(crate) struct MergeExecutorInput {
    /// The upstream(s) to read from.
    upstream: MergeExecutorUpstream,
    /// Context of the owning actor.
    actor_context: ActorContextRef,
    /// Id of the fragment these upstreams belong to.
    upstream_fragment_id: UpstreamFragmentId,
    /// Handed to the built executor for barrier coordination.
    local_barrier_manager: LocalBarrierManager,
    /// Streaming metrics registry.
    executor_stats: Arc<StreamingMetrics>,
    /// Executor info (schema etc.) used when building the final `Executor`.
    pub(crate) info: ExecutorInfo,
    /// Max rows per output chunk for the `StreamChunkBuilder`.
    chunk_size: usize,
}
45
46impl MergeExecutorInput {
47    pub(crate) fn new(
48        upstream: MergeExecutorUpstream,
49        actor_context: ActorContextRef,
50        upstream_fragment_id: UpstreamFragmentId,
51        local_barrier_manager: LocalBarrierManager,
52        executor_stats: Arc<StreamingMetrics>,
53        info: ExecutorInfo,
54        chunk_size: usize,
55    ) -> Self {
56        Self {
57            upstream,
58            actor_context,
59            upstream_fragment_id,
60            local_barrier_manager,
61            executor_stats,
62            info,
63            chunk_size,
64        }
65    }
66
67    pub(crate) fn into_executor(self, barrier_rx: mpsc::UnboundedReceiver<Barrier>) -> Executor {
68        let fragment_id = self.actor_context.fragment_id;
69        let executor = match self.upstream {
70            MergeExecutorUpstream::Singleton(input) => ReceiverExecutor::new(
71                self.actor_context,
72                fragment_id,
73                self.upstream_fragment_id,
74                input,
75                self.local_barrier_manager,
76                self.executor_stats,
77                barrier_rx,
78            )
79            .boxed(),
80            MergeExecutorUpstream::Merge(inputs) => MergeExecutor::new(
81                self.actor_context,
82                fragment_id,
83                self.upstream_fragment_id,
84                inputs,
85                self.local_barrier_manager,
86                self.executor_stats,
87                barrier_rx,
88                self.chunk_size,
89                self.info.schema.clone(),
90            )
91            .boxed(),
92        };
93        (self.info, executor).into()
94    }
95}
96
97impl Stream for MergeExecutorInput {
98    type Item = DispatcherMessageStreamItem;
99
100    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
101        match &mut self.get_mut().upstream {
102            MergeExecutorUpstream::Singleton(input) => input.poll_next_unpin(cx),
103            MergeExecutorUpstream::Merge(inputs) => inputs.poll_next_unpin(cx),
104        }
105    }
106}
107
/// `MergeExecutor` merges data from multiple channels. Dataflow from one channel
/// will be stopped on barrier.
pub struct MergeExecutor {
    /// The context of the actor.
    actor_context: ActorContextRef,

    /// Upstream channels.
    upstreams: SelectReceivers,

    /// Belonged fragment id.
    fragment_id: FragmentId,

    /// Upstream fragment id.
    upstream_fragment_id: FragmentId,

    /// Handed to the `DispatchBarrierBuffer` during execution (e.g. for
    /// resolving new upstream inputs on configuration changes).
    local_barrier_manager: LocalBarrierManager,

    /// Streaming metrics.
    metrics: Arc<StreamingMetrics>,

    /// Receiver of barriers injected for this actor, consumed by the
    /// `DispatchBarrierBuffer` in `execute_inner`.
    barrier_rx: mpsc::UnboundedReceiver<Barrier>,

    /// Chunk size for the `StreamChunkBuilder`
    chunk_size: usize,

    /// Data types for the `StreamChunkBuilder`
    schema: Schema,
}
136
137impl MergeExecutor {
138    #[allow(clippy::too_many_arguments)]
139    pub fn new(
140        ctx: ActorContextRef,
141        fragment_id: FragmentId,
142        upstream_fragment_id: FragmentId,
143        upstreams: SelectReceivers,
144        local_barrier_manager: LocalBarrierManager,
145        metrics: Arc<StreamingMetrics>,
146        barrier_rx: mpsc::UnboundedReceiver<Barrier>,
147        chunk_size: usize,
148        schema: Schema,
149    ) -> Self {
150        Self {
151            actor_context: ctx,
152            upstreams,
153            fragment_id,
154            upstream_fragment_id,
155            local_barrier_manager,
156            metrics,
157            barrier_rx,
158            chunk_size,
159            schema,
160        }
161    }
162
163    #[cfg(test)]
164    pub fn for_test(
165        actor_id: ActorId,
166        inputs: Vec<super::exchange::permit::Receiver>,
167        local_barrier_manager: crate::task::LocalBarrierManager,
168        schema: Schema,
169        chunk_size: usize,
170        barrier_rx: Option<mpsc::UnboundedReceiver<Barrier>>,
171    ) -> Self {
172        use super::exchange::input::LocalInput;
173        use crate::executor::exchange::input::ActorInput;
174
175        let barrier_rx =
176            barrier_rx.unwrap_or_else(|| local_barrier_manager.subscribe_barrier(actor_id));
177
178        let metrics = StreamingMetrics::unused();
179        let actor_ctx = ActorContext::for_test(actor_id);
180        let upstream = Self::new_select_receiver(
181            inputs
182                .into_iter()
183                .enumerate()
184                .map(|(idx, input)| LocalInput::new(input, idx as ActorId).boxed_input())
185                .collect(),
186            &metrics,
187            &actor_ctx,
188        );
189
190        Self::new(
191            actor_ctx,
192            514,
193            1919,
194            upstream,
195            local_barrier_manager,
196            metrics.into(),
197            barrier_rx,
198            chunk_size,
199            schema,
200        )
201    }
202
203    pub(crate) fn new_select_receiver(
204        upstreams: Vec<BoxedActorInput>,
205        metrics: &StreamingMetrics,
206        actor_context: &ActorContext,
207    ) -> SelectReceivers {
208        let merge_barrier_align_duration = if metrics.level >= MetricLevel::Debug {
209            Some(
210                metrics
211                    .merge_barrier_align_duration
212                    .with_guarded_label_values(&[
213                        &actor_context.id.to_string(),
214                        &actor_context.fragment_id.to_string(),
215                    ]),
216            )
217        } else {
218            None
219        };
220
221        // Futures of all active upstreams.
222        SelectReceivers::new(upstreams, None, merge_barrier_align_duration.clone())
223    }
224
225    #[try_stream(ok = Message, error = StreamExecutorError)]
226    pub(crate) async fn execute_inner(mut self: Box<Self>) {
227        let select_all = self.upstreams;
228        let select_all = BufferChunks::new(select_all, self.chunk_size, self.schema);
229        let actor_id = self.actor_context.id;
230
231        let mut metrics = self.metrics.new_actor_input_metrics(
232            actor_id,
233            self.fragment_id,
234            self.upstream_fragment_id,
235        );
236
237        // Channels that're blocked by the barrier to align.
238        let mut start_time = Instant::now();
239        pin_mut!(select_all);
240
241        let mut barrier_buffer = DispatchBarrierBuffer::new(
242            self.barrier_rx,
243            actor_id,
244            self.upstream_fragment_id,
245            self.local_barrier_manager,
246            self.metrics.clone(),
247            self.fragment_id,
248        );
249
250        loop {
251            let msg = barrier_buffer.await_next_message(&mut select_all).await?;
252            metrics
253                .actor_input_buffer_blocking_duration_ns
254                .inc_by(start_time.elapsed().as_nanos() as u64);
255
256            let msg = match msg {
257                DispatcherMessage::Watermark(watermark) => Message::Watermark(watermark),
258                DispatcherMessage::Chunk(chunk) => {
259                    metrics.actor_in_record_cnt.inc_by(chunk.cardinality() as _);
260                    Message::Chunk(chunk)
261                }
262                DispatcherMessage::Barrier(barrier) => {
263                    tracing::debug!(
264                        target: "events::stream::barrier::path",
265                        actor_id = actor_id,
266                        "receiver receives barrier from path: {:?}",
267                        barrier.passed_actors
268                    );
269                    let (mut barrier, new_inputs) =
270                        barrier_buffer.pop_barrier_with_inputs(barrier).await?;
271                    barrier.passed_actors.push(actor_id);
272
273                    if let Some(Mutation::Update(UpdateMutation { dispatchers, .. })) =
274                        barrier.mutation.as_deref()
275                        && select_all
276                            .upstream_input_ids()
277                            .any(|actor_id| dispatchers.contains_key(&actor_id))
278                    {
279                        // `Watermark` of upstream may become stale after downstream scaling.
280                        select_all.flush_buffered_watermarks();
281                    }
282
283                    if let Some(update) =
284                        barrier.as_update_merge(self.actor_context.id, self.upstream_fragment_id)
285                    {
286                        let new_upstream_fragment_id = update
287                            .new_upstream_fragment_id
288                            .unwrap_or(self.upstream_fragment_id);
289                        let removed_upstream_actor_id: HashSet<_> =
290                            if update.new_upstream_fragment_id.is_some() {
291                                select_all.upstream_input_ids().collect()
292                            } else {
293                                update.removed_upstream_actor_id.iter().copied().collect()
294                            };
295
296                        // `Watermark` of upstream may become stale after upstream scaling.
297                        select_all.flush_buffered_watermarks();
298
299                        if let Some(new_upstreams) = new_inputs {
300                            // Add the new upstreams to select.
301                            select_all.add_upstreams_from(new_upstreams);
302                        }
303
304                        if !removed_upstream_actor_id.is_empty() {
305                            // Remove upstreams.
306                            select_all.remove_upstreams(&removed_upstream_actor_id);
307                        }
308
309                        self.upstream_fragment_id = new_upstream_fragment_id;
310                        metrics = self.metrics.new_actor_input_metrics(
311                            actor_id,
312                            self.fragment_id,
313                            self.upstream_fragment_id,
314                        );
315                    }
316
317                    let is_stop = barrier.is_stop(actor_id);
318                    let msg = Message::Barrier(barrier);
319                    if is_stop {
320                        yield msg;
321                        break;
322                    }
323
324                    msg
325                }
326            };
327
328            yield msg;
329            start_time = Instant::now();
330        }
331    }
332}
333
334impl Execute for MergeExecutor {
335    fn execute(self: Box<Self>) -> BoxedMessageStream {
336        self.execute_inner().boxed()
337    }
338}
339
/// A wrapper that buffers the `StreamChunk`s from upstream until no more ready items are available.
/// Besides, any message other than `StreamChunk` will trigger the buffered `StreamChunk`s
/// to be emitted immediately along with the message itself.
struct BufferChunks<S: Stream> {
    /// The underlying message stream being compacted.
    inner: S,
    /// Accumulates rows from consecutive chunks up to the configured chunk size.
    chunk_builder: StreamChunkBuilder,

    /// The items to be emitted. Whenever there's something here, we should return a `Poll::Ready` immediately.
    pending_items: VecDeque<S::Item>,
}
350
351impl<S: Stream> BufferChunks<S> {
352    pub(super) fn new(inner: S, chunk_size: usize, schema: Schema) -> Self {
353        assert!(chunk_size > 0);
354        let chunk_builder = StreamChunkBuilder::new(chunk_size, schema.data_types());
355        Self {
356            inner,
357            chunk_builder,
358            pending_items: VecDeque::new(),
359        }
360    }
361}
362
// Expose the inner stream's methods (e.g. `upstream_input_ids`) transparently.
impl<S: Stream> std::ops::Deref for BufferChunks<S> {
    type Target = S;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
370
// Mutable counterpart, used e.g. for `add_upstreams_from` / `remove_upstreams`.
impl<S: Stream> std::ops::DerefMut for BufferChunks<S> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
376
impl<S: Stream> Stream for BufferChunks<S>
where
    S: Stream<Item = DispatcherMessageStreamItem> + Unpin,
{
    type Item = S::Item;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Emit anything queued by a previous poll before touching upstream.
            if let Some(item) = self.pending_items.pop_front() {
                return Poll::Ready(Some(item));
            }

            match self.inner.poll_next_unpin(cx) {
                Poll::Pending => {
                    // No more ready items: flush the partially-built chunk (if any)
                    // instead of holding rows back while upstream is idle.
                    return if let Some(chunk_out) = self.chunk_builder.take() {
                        Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
                    } else {
                        Poll::Pending
                    };
                }

                Poll::Ready(Some(result)) => {
                    if let Ok(MessageInner::Chunk(chunk)) = result {
                        // Accumulate rows; each time the builder fills up it yields a
                        // compacted chunk, which is queued for emission in order.
                        for row in chunk.records() {
                            if let Some(chunk_out) = self.chunk_builder.append_record(row) {
                                self.pending_items
                                    .push_back(Ok(MessageInner::Chunk(chunk_out)));
                            }
                        }
                    } else {
                        // Non-chunk messages (barrier/watermark) and errors must not be
                        // reordered before buffered rows: flush the builder first, then
                        // emit the message itself from the pending queue.
                        return if let Some(chunk_out) = self.chunk_builder.take() {
                            self.pending_items.push_back(result);
                            Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
                        } else {
                            Poll::Ready(Some(result))
                        };
                    }
                }

                Poll::Ready(None) => {
                    // See also the comments in `DynamicReceivers::poll_next`.
                    unreachable!("Merge should always have upstream inputs");
                }
            }
        }
    }
}
424
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    use assert_matches::assert_matches;
    use futures::FutureExt;
    use futures::future::try_join_all;
    use risingwave_common::array::Op;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_pb::task_service::exchange_service_server::{
        ExchangeService, ExchangeServiceServer,
    };
    use risingwave_pb::task_service::{
        GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits,
    };
    use tokio::time::sleep;
    use tokio_stream::wrappers::ReceiverStream;
    use tonic::{Request, Response, Status, Streaming};

    use super::*;
    use crate::executor::exchange::input::{ActorInput, LocalInput, RemoteInput};
    use crate::executor::exchange::permit::channel_for_test;
    use crate::executor::{BarrierInner as Barrier, MessageInner as Message};
    use crate::task::NewOutputRequest;
    use crate::task::barrier_test_utils::LocalBarrierTestEnv;
    use crate::task::test_utils::helper_make_local_actor;

    /// Build a chunk with `size` insert ops and no columns.
    fn build_test_chunk(size: u64) -> StreamChunk {
        let ops = vec![Op::Insert; size as usize];
        StreamChunk::new(ops, vec![])
    }

    /// Verify `BufferChunks` compaction: consecutive chunks are merged, and any
    /// watermark/barrier flushes the buffered rows before being emitted.
    #[tokio::test]
    async fn test_buffer_chunks() {
        let test_env = LocalBarrierTestEnv::for_test().await;

        let (tx, rx) = channel_for_test();
        let input = LocalInput::new(rx, 1).boxed_input();
        let mut buffer = BufferChunks::new(input, 100, Schema::new(vec![]));

        // Send a chunk
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 10);
        });

        // Send 2 chunks and expect them to be merged.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });

        // Send a watermark.
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Send 2 chunks before a watermark. Expect the 2 chunks to be merged and the watermark to be emitted.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Send a barrier.
        let barrier = Barrier::new_test_barrier(test_epoch(1));
        test_env.inject_barrier(&barrier, [2]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });

        // Send 2 chunks before a barrier. Expect the 2 chunks to be merged and the barrier to be emitted.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        let barrier = Barrier::new_test_barrier(test_epoch(2));
        test_env.inject_barrier(&barrier, [2]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(2));
        });
    }

    /// End-to-end merge over many local channels: chunks are combined up to the
    /// chunk size, watermarks are aligned across channels, barriers are aligned,
    /// and a stop barrier terminates the stream.
    #[tokio::test]
    async fn test_merger() {
        const CHANNEL_NUMBER: usize = 10;
        let mut txs = Vec::with_capacity(CHANNEL_NUMBER);
        let mut rxs = Vec::with_capacity(CHANNEL_NUMBER);
        for _i in 0..CHANNEL_NUMBER {
            let (tx, rx) = channel_for_test();
            txs.push(tx);
            rxs.push(rx);
        }
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let actor_id = 233;
        let mut handles = Vec::with_capacity(CHANNEL_NUMBER);

        let epochs = (10..1000u64)
            .step_by(10)
            .map(|idx| (idx, test_epoch(idx)))
            .collect_vec();
        let mut prev_epoch = 0;
        let prev_epoch = &mut prev_epoch;
        let barriers: HashMap<_, _> = epochs
            .iter()
            .map(|(_, epoch)| {
                let barrier = Barrier::with_prev_epoch_for_test(*epoch, *prev_epoch);
                *prev_epoch = *epoch;
                barrier_test_env.inject_barrier(&barrier, [actor_id]);
                (*epoch, barrier)
            })
            .collect();
        // Final stop barrier that should terminate the merger.
        let b2 = Barrier::with_prev_epoch_for_test(test_epoch(1000), *prev_epoch)
            .with_mutation(Mutation::Stop(StopMutation::default()));
        barrier_test_env.inject_barrier(&b2, [actor_id]);
        barrier_test_env.flush_all_events().await;

        // Each sender task alternates chunks and watermarks, interleaved with barriers.
        for (tx_id, tx) in txs.into_iter().enumerate() {
            let epochs = epochs.clone();
            let barriers = barriers.clone();
            let b2 = b2.clone();
            let handle = tokio::spawn(async move {
                for (idx, epoch) in epochs {
                    if idx % 20 == 0 {
                        tx.send(Message::Chunk(build_test_chunk(10)).into())
                            .await
                            .unwrap();
                    } else {
                        tx.send(
                            Message::Watermark(Watermark {
                                col_idx: (idx as usize / 20 + tx_id) % CHANNEL_NUMBER,
                                data_type: DataType::Int64,
                                val: ScalarImpl::Int64(idx as i64),
                            })
                            .into(),
                        )
                        .await
                        .unwrap();
                    }
                    tx.send(Message::Barrier(barriers[&epoch].clone().into_dispatcher()).into())
                        .await
                        .unwrap();
                    sleep(Duration::from_millis(1)).await;
                }
                tx.send(Message::Barrier(b2.clone().into_dispatcher()).into())
                    .await
                    .unwrap();
            });
            handles.push(handle);
        }

        let merger = MergeExecutor::for_test(
            actor_id,
            rxs,
            barrier_test_env.local_barrier_manager.clone(),
            Schema::new(vec![]),
            100,
            None,
        );
        let mut merger = merger.boxed().execute();
        for (idx, epoch) in epochs {
            if idx % 20 == 0 {
                // expect 1 or more chunks with 100 rows in total
                let mut count = 0usize;
                while count < 100 {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
                        count += chunk.ops().len();
                    });
                }
                assert_eq!(count, 100);
            } else if idx as usize / 20 >= CHANNEL_NUMBER - 1 {
                // expect n watermarks
                for _ in 0..CHANNEL_NUMBER {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
                        assert_eq!(watermark.val, ScalarImpl::Int64((idx - 20 * (CHANNEL_NUMBER as u64 - 1)) as i64));
                    });
                }
            }
            // expect a barrier
            assert_matches!(merger.next().await.unwrap().unwrap(), Message::Barrier(Barrier{epoch:barrier_epoch,mutation:_,..}) => {
                assert_eq!(barrier_epoch.curr, epoch);
            });
        }
        // Finally, the stop barrier is delivered and the stream ends.
        assert_matches!(
            merger.next().await.unwrap().unwrap(),
            Message::Barrier(Barrier {
                mutation,
                ..
            }) if mutation.as_deref().unwrap().is_stop()
        );

        for handle in handles {
            handle.await.unwrap();
        }
    }

    /// Verify that a `Mutation::Update` barrier adds the new upstream, removes the
    /// old one, and that the barrier is only emitted after the new upstream has
    /// also delivered it.
    #[tokio::test]
    async fn test_configuration_change() {
        let actor_id = 233;
        let (untouched, old, new) = (234, 235, 238); // upstream actors
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let metrics = Arc::new(StreamingMetrics::unused());

        // untouched -> actor_id
        // old -> actor_id
        // new -> actor_id

        let (upstream_fragment_id, fragment_id) = (10, 18);

        let inputs: Vec<_> =
            try_join_all([untouched, old].into_iter().map(async |upstream_actor_id| {
                new_input(
                    &barrier_test_env.local_barrier_manager,
                    metrics.clone(),
                    actor_id,
                    fragment_id,
                    &helper_make_local_actor(upstream_actor_id),
                    upstream_fragment_id,
                )
                .await
            }))
            .await
            .unwrap();

        let merge_updates = maplit::hashmap! {
            (actor_id, upstream_fragment_id) => MergeUpdate {
                actor_id,
                upstream_fragment_id,
                new_upstream_fragment_id: None,
                added_upstream_actors: vec![helper_make_local_actor(new)],
                removed_upstream_actor_id: vec![old],
            }
        };

        let b1 = Barrier::new_test_barrier(test_epoch(1)).with_mutation(Mutation::Update(
            UpdateMutation {
                merges: merge_updates,
                ..Default::default()
            },
        ));
        barrier_test_env.inject_barrier(&b1, [actor_id]);
        barrier_test_env.flush_all_events().await;

        let barrier_rx = barrier_test_env
            .local_barrier_manager
            .subscribe_barrier(actor_id);
        let actor_ctx = ActorContext::for_test(actor_id);
        let upstream = MergeExecutor::new_select_receiver(inputs, &metrics, &actor_ctx);

        let mut merge = MergeExecutor::new(
            actor_ctx,
            fragment_id,
            upstream_fragment_id,
            upstream,
            barrier_test_env.local_barrier_manager.clone(),
            metrics.clone(),
            barrier_rx,
            100,
            Schema::new(vec![]),
        )
        .boxed()
        .execute();

        let mut txs = HashMap::new();
        // Send `$msg` to every upstream actor in `$actors`.
        macro_rules! send {
            ($actors:expr, $msg:expr) => {
                for actor in $actors {
                    txs.get(&actor).unwrap().send($msg).await.unwrap();
                }
            };
        }

        // Assert that polling the merger yields nothing right now.
        macro_rules! assert_recv_pending {
            () => {
                assert!(
                    merge
                        .next()
                        .now_or_never()
                        .flatten()
                        .transpose()
                        .unwrap()
                        .is_none()
                );
            };
        }
        macro_rules! recv {
            () => {
                merge.next().await.transpose().unwrap()
            };
        }

        // Grab the local output channels that the given upstream actors created
        // towards this actor.
        macro_rules! collect_upstream_tx {
            ($actors:expr) => {
                for upstream_id in $actors {
                    let mut output_requests = barrier_test_env
                        .take_pending_new_output_requests(upstream_id)
                        .await;
                    assert_eq!(output_requests.len(), 1);
                    let (downstream_actor_id, request) = output_requests.pop().unwrap();
                    assert_eq!(actor_id, downstream_actor_id);
                    let NewOutputRequest::Local(tx) = request else {
                        unreachable!()
                    };
                    txs.insert(upstream_id, tx);
                }
            };
        }

        assert_recv_pending!();
        barrier_test_env.flush_all_events().await;

        // 2. Take downstream receivers.
        collect_upstream_tx!([untouched, old]);

        // 3. Send a chunk.
        send!([untouched, old], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality()); // The 2 single-row chunks are compacted into one chunk of cardinality 2.
        assert_recv_pending!();

        send!(
            [untouched, old],
            Message::Barrier(b1.clone().into_dispatcher()).into()
        );
        assert_recv_pending!(); // We should not receive the barrier, since the merger is still waiting for the barrier from the newly-added upstream.

        collect_upstream_tx!([new]);

        send!([new], Message::Barrier(b1.clone().into_dispatcher()).into());
        recv!().unwrap().as_barrier().unwrap(); // We should now receive the barrier.

        // 5. Send a chunk.
        send!([untouched, new], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality()); // Again compacted into one chunk of cardinality 2; `old` is removed.
        assert_recv_pending!();
    }

    /// Fake gRPC exchange service that serves one chunk and one barrier, and
    /// records that `get_stream` was called.
    struct FakeExchangeService {
        rpc_called: Arc<AtomicBool>,
    }

    fn exchange_client_test_barrier() -> crate::executor::Barrier {
        Barrier::new_test_barrier(test_epoch(1))
    }

    #[async_trait::async_trait]
    impl ExchangeService for FakeExchangeService {
        type GetDataStream = ReceiverStream<std::result::Result<GetDataResponse, Status>>;
        type GetStreamStream = ReceiverStream<std::result::Result<GetStreamResponse, Status>>;

        async fn get_data(
            &self,
            _: Request<GetDataRequest>,
        ) -> std::result::Result<Response<Self::GetDataStream>, Status> {
            unimplemented!()
        }

        async fn get_stream(
            &self,
            _request: Request<Streaming<GetStreamRequest>>,
        ) -> std::result::Result<Response<Self::GetStreamStream>, Status> {
            let (tx, rx) = tokio::sync::mpsc::channel(10);
            self.rpc_called.store(true, Ordering::SeqCst);
            // send stream_chunk
            let stream_chunk = StreamChunk::default().to_protobuf();
            tx.send(Ok(GetStreamResponse {
                message: Some(PbStreamMessageBatch {
                    stream_message_batch: Some(
                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::StreamChunk(
                            stream_chunk,
                        ),
                    ),
                }),
                permits: Some(PbPermits::default()),
            }))
            .await
            .unwrap();
            // send barrier
            let barrier = exchange_client_test_barrier();
            tx.send(Ok(GetStreamResponse {
                message: Some(PbStreamMessageBatch {
                    stream_message_batch: Some(
                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::BarrierBatch(
                            BarrierBatch {
                                barriers: vec![barrier.to_protobuf()],
                            },
                        ),
                    ),
                }),
                permits: Some(PbPermits::default()),
            }))
            .await
            .unwrap();
            Ok(Response::new(ReceiverStream::new(rx)))
        }
    }

    /// Exercise `RemoteInput` against the fake exchange server above: it should
    /// yield the chunk, then the barrier, having actually issued the RPC.
    #[tokio::test]
    async fn test_stream_exchange_client() {
        let rpc_called = Arc::new(AtomicBool::new(false));
        let server_run = Arc::new(AtomicBool::new(false));
        let addr = "127.0.0.1:12348".parse().unwrap();

        // Start a server.
        let (shutdown_send, shutdown_recv) = tokio::sync::oneshot::channel();
        let exchange_svc = ExchangeServiceServer::new(FakeExchangeService {
            rpc_called: rpc_called.clone(),
        });
        let cp_server_run = server_run.clone();
        let join_handle = tokio::spawn(async move {
            cp_server_run.store(true, Ordering::SeqCst);
            tonic::transport::Server::builder()
                .add_service(exchange_svc)
                .serve_with_shutdown(addr, async move {
                    shutdown_recv.await.unwrap();
                })
                .await
                .unwrap();
        });

        sleep(Duration::from_secs(1)).await;
        assert!(server_run.load(Ordering::SeqCst));

        let test_env = LocalBarrierTestEnv::for_test().await;

        let remote_input = {
            RemoteInput::new(
                &test_env.local_barrier_manager,
                addr.into(),
                (0, 0),
                (0, 0),
                Arc::new(StreamingMetrics::unused()),
            )
            .await
            .unwrap()
        };

        test_env.inject_barrier(&exchange_client_test_barrier(), [remote_input.id()]);

        pin_mut!(remote_input);

        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            let (ops, columns, visibility) = chunk.into_inner();
            assert!(ops.is_empty());
            assert!(columns.is_empty());
            assert!(visibility.is_empty());
        });
        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });
        assert!(rpc_called.load(Ordering::SeqCst));

        shutdown_send.send(()).unwrap();
        join_handle.await.unwrap();
    }
}