// risingwave_stream/executor/merge.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::VecDeque;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use risingwave_common::array::StreamChunkBuilder;
20use risingwave_common::config::MetricLevel;
21use tokio::sync::mpsc;
22
23use super::exchange::input::BoxedActorInput;
24use super::*;
25use crate::executor::prelude::*;
26use crate::task::LocalBarrierManager;
27
/// Dynamic set of upstream inputs keyed by [`ActorId`], with no extra per-input payload.
pub type SelectReceivers = DynamicReceivers<ActorId, ()>;
29
/// The upstream(s) of a merge-like executor, chosen by `into_executor` below.
pub(crate) enum MergeExecutorUpstream {
    /// Exactly one upstream input; handled by the simpler `ReceiverExecutor`.
    Singleton(BoxedActorInput),
    /// Multiple upstream inputs to be merged; handled by `MergeExecutor`.
    Merge(SelectReceivers),
}
34
/// Everything needed to build a merge/receiver executor once the barrier channel is known.
/// Also implements [`Stream`] itself, delegating to the underlying upstream(s).
pub(crate) struct MergeExecutorInput {
    // Upstream input(s); singleton vs. merged decides the executor kind.
    upstream: MergeExecutorUpstream,
    // Actor context; its `fragment_id` is read when building the executor.
    actor_context: ActorContextRef,
    // Id of the fragment the upstream(s) belong to.
    upstream_fragment_id: UpstreamFragmentId,
    local_barrier_manager: LocalBarrierManager,
    // Streaming metrics registry passed through to the built executor.
    executor_stats: Arc<StreamingMetrics>,
    pub(crate) info: ExecutorInfo,
    // Chunk size forwarded to the `StreamChunkBuilder` in `MergeExecutor`.
    chunk_size: usize,
}
44
45impl MergeExecutorInput {
46    pub(crate) fn new(
47        upstream: MergeExecutorUpstream,
48        actor_context: ActorContextRef,
49        upstream_fragment_id: UpstreamFragmentId,
50        local_barrier_manager: LocalBarrierManager,
51        executor_stats: Arc<StreamingMetrics>,
52        info: ExecutorInfo,
53        chunk_size: usize,
54    ) -> Self {
55        Self {
56            upstream,
57            actor_context,
58            upstream_fragment_id,
59            local_barrier_manager,
60            executor_stats,
61            info,
62            chunk_size,
63        }
64    }
65
66    pub(crate) fn into_executor(self, barrier_rx: mpsc::UnboundedReceiver<Barrier>) -> Executor {
67        let fragment_id = self.actor_context.fragment_id;
68        let executor = match self.upstream {
69            MergeExecutorUpstream::Singleton(input) => ReceiverExecutor::new(
70                self.actor_context,
71                fragment_id,
72                self.upstream_fragment_id,
73                input,
74                self.local_barrier_manager,
75                self.executor_stats,
76                barrier_rx,
77            )
78            .boxed(),
79            MergeExecutorUpstream::Merge(inputs) => MergeExecutor::new(
80                self.actor_context,
81                fragment_id,
82                self.upstream_fragment_id,
83                inputs,
84                self.local_barrier_manager,
85                self.executor_stats,
86                barrier_rx,
87                self.chunk_size,
88                self.info.schema.clone(),
89            )
90            .boxed(),
91        };
92        (self.info, executor).into()
93    }
94}
95
96impl Stream for MergeExecutorInput {
97    type Item = DispatcherMessageStreamItem;
98
99    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
100        match &mut self.get_mut().upstream {
101            MergeExecutorUpstream::Singleton(input) => input.poll_next_unpin(cx),
102            MergeExecutorUpstream::Merge(inputs) => inputs.poll_next_unpin(cx),
103        }
104    }
105}
106
/// `MergeExecutor` merges data from multiple channels. Dataflow from one channel
/// will be stopped on barrier.
pub struct MergeExecutor {
    /// The context of the actor.
    actor_context: ActorContextRef,

    /// Upstream channels.
    upstreams: SelectReceivers,

    /// Belonged fragment id.
    fragment_id: FragmentId,

    /// Upstream fragment id.
    upstream_fragment_id: FragmentId,

    /// Handed to `DispatchBarrierBuffer` in `execute_inner` for barrier handling.
    local_barrier_manager: LocalBarrierManager,

    /// Streaming metrics.
    metrics: Arc<StreamingMetrics>,

    /// Channel on which barriers for this actor are received.
    barrier_rx: mpsc::UnboundedReceiver<Barrier>,

    /// Chunk size for the `StreamChunkBuilder`
    chunk_size: usize,

    /// Data types for the `StreamChunkBuilder`
    schema: Schema,
}
135
impl MergeExecutor {
    /// Create a `MergeExecutor` over an already-built set of upstream inputs.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        ctx: ActorContextRef,
        fragment_id: FragmentId,
        upstream_fragment_id: FragmentId,
        upstreams: SelectReceivers,
        local_barrier_manager: LocalBarrierManager,
        metrics: Arc<StreamingMetrics>,
        barrier_rx: mpsc::UnboundedReceiver<Barrier>,
        chunk_size: usize,
        schema: Schema,
    ) -> Self {
        Self {
            actor_context: ctx,
            upstreams,
            fragment_id,
            upstream_fragment_id,
            local_barrier_manager,
            metrics,
            barrier_rx,
            chunk_size,
            schema,
        }
    }

    /// Test helper: build a `MergeExecutor` over local permit-based channels, with
    /// arbitrary (fake) fragment ids and unused metrics.
    #[cfg(test)]
    pub fn for_test(
        actor_id: impl Into<ActorId>,
        inputs: Vec<super::exchange::permit::Receiver>,
        local_barrier_manager: crate::task::LocalBarrierManager,
        schema: Schema,
        chunk_size: usize,
        barrier_rx: Option<mpsc::UnboundedReceiver<Barrier>>,
    ) -> Self {
        let actor_id = actor_id.into();
        use super::exchange::input::LocalInput;
        use crate::executor::exchange::input::ActorInput;

        // If the caller didn't supply a barrier channel, subscribe a fresh one.
        let barrier_rx =
            barrier_rx.unwrap_or_else(|| local_barrier_manager.subscribe_barrier(actor_id));

        let metrics = StreamingMetrics::unused();
        let actor_ctx = ActorContext::for_test(actor_id);
        // Wrap each permit receiver as a local input, keyed by its index as a fake actor id.
        let upstream = Self::new_select_receiver(
            inputs
                .into_iter()
                .enumerate()
                .map(|(idx, input)| LocalInput::new(input, ActorId::new(idx as u32)).boxed_input())
                .collect(),
            &metrics,
            &actor_ctx,
        );

        Self::new(
            actor_ctx,
            // Placeholder fragment ids — not meaningful in tests.
            514.into(),
            1919.into(),
            upstream,
            local_barrier_manager,
            metrics.into(),
            barrier_rx,
            chunk_size,
            schema,
        )
    }

    /// Build the `SelectReceivers` over the given upstream inputs. The barrier-alignment
    /// duration metric is only collected when the metric level is at least `Debug`.
    pub(crate) fn new_select_receiver(
        upstreams: Vec<BoxedActorInput>,
        metrics: &StreamingMetrics,
        actor_context: &ActorContext,
    ) -> SelectReceivers {
        let merge_barrier_align_duration = if metrics.level >= MetricLevel::Debug {
            Some(
                metrics
                    .merge_barrier_align_duration
                    .with_guarded_label_values(&[
                        &actor_context.id.to_string(),
                        &actor_context.fragment_id.to_string(),
                    ]),
            )
        } else {
            None
        };

        // Futures of all active upstreams.
        SelectReceivers::new(upstreams, None, merge_barrier_align_duration)
    }

    /// Main loop: forward watermarks and chunks from the (chunk-compacting) upstream
    /// selection, and on each barrier apply any merge-update mutation (add/remove
    /// upstream inputs, possibly switching to a new upstream fragment) before yielding it.
    /// Terminates after yielding a barrier that stops this actor.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    pub(crate) async fn execute_inner(mut self: Box<Self>) {
        let select_all = self.upstreams;
        // Compact small consecutive chunks up to `chunk_size` before emitting.
        let select_all = BufferChunks::new(select_all, self.chunk_size, self.schema);
        let actor_id = self.actor_context.id;

        let mut metrics = self.metrics.new_actor_input_metrics(
            actor_id,
            self.fragment_id,
            self.upstream_fragment_id,
        );

        // Channels that're blocked by the barrier to align.
        pin_mut!(select_all);

        let mut barrier_buffer = DispatchBarrierBuffer::new(
            self.barrier_rx,
            actor_id,
            self.upstream_fragment_id,
            self.local_barrier_manager,
            self.metrics.clone(),
            self.fragment_id,
        );

        loop {
            let msg = barrier_buffer
                .await_next_message(&mut select_all, &metrics)
                .await?;
            let msg = match msg {
                DispatcherMessage::Watermark(watermark) => Message::Watermark(watermark),
                DispatcherMessage::Chunk(chunk) => {
                    metrics.actor_in_record_cnt.inc_by(chunk.cardinality() as _);
                    Message::Chunk(chunk)
                }
                DispatcherMessage::Barrier(barrier) => {
                    tracing::debug!(
                        target: "events::stream::barrier::path",
                        actor_id = %actor_id,
                        "receiver receives barrier from path: {:?}",
                        barrier.passed_actors
                    );
                    // Resolve the buffered barrier; this may also carry newly-created
                    // inputs for upstream actors added by the mutation.
                    let (mut barrier, new_inputs) =
                        barrier_buffer.pop_barrier_with_inputs(barrier).await?;
                    barrier.passed_actors.push(actor_id);

                    if let Some(Mutation::Update(UpdateMutation { dispatchers, .. })) =
                        barrier.mutation.as_deref()
                        && select_all
                            .upstream_input_ids()
                            .any(|actor_id| dispatchers.contains_key(&actor_id))
                    {
                        // `Watermark` of upstream may become stale after downstream scaling.
                        select_all.flush_buffered_watermarks();
                    }

                    if let Some(update) =
                        barrier.as_update_merge(self.actor_context.id, self.upstream_fragment_id)
                    {
                        let new_upstream_fragment_id = update
                            .new_upstream_fragment_id
                            .unwrap_or(self.upstream_fragment_id);
                        // Switching to a new upstream fragment drops ALL current inputs;
                        // otherwise only the explicitly listed actors are removed.
                        let removed_upstream_actor_id: HashSet<_> =
                            if update.new_upstream_fragment_id.is_some() {
                                select_all.upstream_input_ids().collect()
                            } else {
                                update.removed_upstream_actor_id.iter().copied().collect()
                            };

                        // `Watermark` of upstream may become stale after upstream scaling.
                        select_all.flush_buffered_watermarks();

                        if let Some(new_upstreams) = new_inputs {
                            // Add the new upstreams to select.
                            select_all.add_upstreams_from(new_upstreams);
                        }

                        if !removed_upstream_actor_id.is_empty() {
                            // Remove upstreams.
                            select_all.remove_upstreams(&removed_upstream_actor_id);
                        }

                        self.upstream_fragment_id = new_upstream_fragment_id;
                        // Re-create input metrics since the upstream fragment may have changed.
                        metrics = self.metrics.new_actor_input_metrics(
                            actor_id,
                            self.fragment_id,
                            self.upstream_fragment_id,
                        );
                    }

                    let is_stop = barrier.is_stop(actor_id);
                    let msg = Message::Barrier(barrier);
                    if is_stop {
                        // Yield the final barrier, then terminate the stream.
                        yield msg;
                        break;
                    }

                    msg
                }
            };

            yield msg;
        }
    }
}
329
impl Execute for MergeExecutor {
    /// Execute by boxing the `try_stream` produced by `execute_inner`.
    fn execute(self: Box<Self>) -> BoxedMessageStream {
        self.execute_inner().boxed()
    }
}
335
/// A wrapper that buffers the `StreamChunk`s from upstream until no more ready items are available.
/// Besides, any message other than `StreamChunk` will trigger the buffered `StreamChunk`s
/// to be emitted immediately along with the message itself.
struct BufferChunks<S: Stream> {
    /// The underlying message stream being compacted.
    inner: S,
    /// Accumulates rows until `chunk_size` is reached or a flush is forced.
    chunk_builder: StreamChunkBuilder,

    /// The items to be emitted. Whenever there's something here, we should return a `Poll::Ready` immediately.
    pending_items: VecDeque<S::Item>,
}
346
347impl<S: Stream> BufferChunks<S> {
348    pub(super) fn new(inner: S, chunk_size: usize, schema: Schema) -> Self {
349        assert!(chunk_size > 0);
350        let chunk_builder = StreamChunkBuilder::new(chunk_size, schema.data_types());
351        Self {
352            inner,
353            chunk_builder,
354            pending_items: VecDeque::new(),
355        }
356    }
357}
358
// Deref to the inner stream so callers (e.g. `execute_inner`) can invoke upstream
// methods like `upstream_input_ids` directly through the wrapper.
impl<S: Stream> std::ops::Deref for BufferChunks<S> {
    type Target = S;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
366
// Mutable counterpart, used for e.g. `flush_buffered_watermarks` and upstream updates.
impl<S: Stream> std::ops::DerefMut for BufferChunks<S> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
372
impl<S: Stream> Stream for BufferChunks<S>
where
    S: Stream<Item = DispatcherMessageStreamItem> + Unpin,
{
    type Item = S::Item;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Drain already-compacted items first, one per poll.
            if let Some(item) = self.pending_items.pop_front() {
                return Poll::Ready(Some(item));
            }

            match self.inner.poll_next_unpin(cx) {
                Poll::Pending => {
                    // Upstream has nothing ready: flush the partially-built chunk (if any)
                    // rather than delaying the buffered rows until more data arrives.
                    return if let Some(chunk_out) = self.chunk_builder.take() {
                        Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
                    } else {
                        Poll::Pending
                    };
                }

                Poll::Ready(Some(result)) => {
                    if let Ok(MessageInner::Chunk(chunk)) = result {
                        // Fold the rows into the builder; each full chunk it yields is
                        // queued, and we loop to emit it on the next iteration.
                        for row in chunk.records() {
                            if let Some(chunk_out) = self.chunk_builder.append_record(row) {
                                self.pending_items
                                    .push_back(Ok(MessageInner::Chunk(chunk_out)));
                            }
                        }
                    } else {
                        // Non-chunk message (barrier/watermark) or error: flush buffered
                        // rows first so they are emitted before the message itself.
                        return if let Some(chunk_out) = self.chunk_builder.take() {
                            self.pending_items.push_back(result);
                            Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
                        } else {
                            Poll::Ready(Some(result))
                        };
                    }
                }

                Poll::Ready(None) => {
                    // See also the comments in `DynamicReceivers::poll_next`.
                    unreachable!("Merge should always have upstream inputs");
                }
            }
        }
    }
}
420
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    use assert_matches::assert_matches;
    use futures::FutureExt;
    use futures::future::try_join_all;
    use risingwave_common::array::Op;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_pb::task_service::exchange_service_server::{
        ExchangeService, ExchangeServiceServer,
    };
    use risingwave_pb::task_service::{
        GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits,
    };
    use tokio::time::sleep;
    use tokio_stream::wrappers::ReceiverStream;
    use tonic::{Request, Response, Status, Streaming};

    use super::*;
    use crate::executor::exchange::input::{ActorInput, LocalInput, RemoteInput};
    use crate::executor::exchange::permit::channel_for_test;
    use crate::executor::{BarrierInner as Barrier, MessageInner as Message};
    use crate::task::NewOutputRequest;
    use crate::task::barrier_test_utils::LocalBarrierTestEnv;
    use crate::task::test_utils::helper_make_local_actor;

    /// Build a column-less chunk with `size` `Insert` ops — only cardinality matters here.
    fn build_test_chunk(size: u64) -> StreamChunk {
        let ops = vec![Op::Insert; size as usize];
        StreamChunk::new(ops, vec![])
    }

    /// `BufferChunks` should merge consecutive chunks up to the chunk size and flush the
    /// buffer whenever a non-chunk message (watermark/barrier) arrives.
    #[tokio::test]
    async fn test_buffer_chunks() {
        let test_env = LocalBarrierTestEnv::for_test().await;

        let (tx, rx) = channel_for_test();
        let input = LocalInput::new(rx, 1.into()).boxed_input();
        let mut buffer = BufferChunks::new(input, 100, Schema::new(vec![]));

        // Send a chunk
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 10);
        });

        // Send 2 chunks and expect them to be merged.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });

        // Send a watermark.
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Send 2 chunks before a watermark. Expect the 2 chunks to be merged and the watermark to be emitted.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Send a barrier.
        let barrier = Barrier::new_test_barrier(test_epoch(1));
        test_env.inject_barrier(&barrier, [2.into()]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });

        // Send 2 chunks before a barrier. Expect the 2 chunks to be merged and the barrier to be emitted.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        let barrier = Barrier::new_test_barrier(test_epoch(2));
        test_env.inject_barrier(&barrier, [2.into()]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(2));
        });
    }

    /// End-to-end merge over 10 local channels: chunks are compacted to 100 rows,
    /// watermarks are aligned across channels, barriers are aligned per epoch, and the
    /// stream terminates on the stop barrier.
    #[tokio::test]
    async fn test_merger() {
        const CHANNEL_NUMBER: usize = 10;
        let mut txs = Vec::with_capacity(CHANNEL_NUMBER);
        let mut rxs = Vec::with_capacity(CHANNEL_NUMBER);
        for _i in 0..CHANNEL_NUMBER {
            let (tx, rx) = channel_for_test();
            txs.push(tx);
            rxs.push(rx);
        }
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let actor_id = 233.into();
        let mut handles = Vec::with_capacity(CHANNEL_NUMBER);

        // Epochs 10, 20, ..., 990; inject one barrier per epoch, then a final stop barrier.
        let epochs = (10..1000u64)
            .step_by(10)
            .map(|idx| (idx, test_epoch(idx)))
            .collect_vec();
        let mut prev_epoch = 0;
        let prev_epoch = &mut prev_epoch;
        let barriers: HashMap<_, _> = epochs
            .iter()
            .map(|(_, epoch)| {
                let barrier = Barrier::with_prev_epoch_for_test(*epoch, *prev_epoch);
                *prev_epoch = *epoch;
                barrier_test_env.inject_barrier(&barrier, [actor_id]);
                (*epoch, barrier)
            })
            .collect();
        let b2 = Barrier::with_prev_epoch_for_test(test_epoch(1000), *prev_epoch)
            .with_mutation(Mutation::Stop(StopMutation::default()));
        barrier_test_env.inject_barrier(&b2, [actor_id]);
        barrier_test_env.flush_all_events().await;

        // One producer task per channel: alternates chunks and watermarks, with a
        // barrier after each, then the final stop barrier.
        for (tx_id, tx) in txs.into_iter().enumerate() {
            let epochs = epochs.clone();
            let barriers = barriers.clone();
            let b2 = b2.clone();
            let handle = tokio::spawn(async move {
                for (idx, epoch) in epochs {
                    if idx % 20 == 0 {
                        tx.send(Message::Chunk(build_test_chunk(10)).into())
                            .await
                            .unwrap();
                    } else {
                        tx.send(
                            Message::Watermark(Watermark {
                                col_idx: (idx as usize / 20 + tx_id) % CHANNEL_NUMBER,
                                data_type: DataType::Int64,
                                val: ScalarImpl::Int64(idx as i64),
                            })
                            .into(),
                        )
                        .await
                        .unwrap();
                    }
                    tx.send(Message::Barrier(barriers[&epoch].clone().into_dispatcher()).into())
                        .await
                        .unwrap();
                    sleep(Duration::from_millis(1)).await;
                }
                tx.send(Message::Barrier(b2.clone().into_dispatcher()).into())
                    .await
                    .unwrap();
            });
            handles.push(handle);
        }

        let merger = MergeExecutor::for_test(
            actor_id,
            rxs,
            barrier_test_env.local_barrier_manager.clone(),
            Schema::new(vec![]),
            100,
            None,
        );
        let mut merger = merger.boxed().execute();
        for (idx, epoch) in epochs {
            if idx % 20 == 0 {
                // expect 1 or more chunks with 100 rows in total
                let mut count = 0usize;
                while count < 100 {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
                        count += chunk.ops().len();
                    });
                }
                assert_eq!(count, 100);
            } else if idx as usize / 20 >= CHANNEL_NUMBER - 1 {
                // expect n watermarks
                for _ in 0..CHANNEL_NUMBER {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
                        assert_eq!(watermark.val, ScalarImpl::Int64((idx - 20 * (CHANNEL_NUMBER as u64 - 1)) as i64));
                    });
                }
            }
            // expect a barrier
            assert_matches!(merger.next().await.unwrap().unwrap(), Message::Barrier(Barrier{epoch:barrier_epoch,mutation:_,..}) => {
                assert_eq!(barrier_epoch.curr, epoch);
            });
        }
        assert_matches!(
            merger.next().await.unwrap().unwrap(),
            Message::Barrier(Barrier {
                mutation,
                ..
            }) if mutation.as_deref().unwrap().is_stop()
        );

        for handle in handles {
            handle.await.unwrap();
        }
    }

    /// Scale-out/scale-in via an `Update` mutation: the merger must wait for the newly
    /// added upstream's barrier before emitting, and drop the removed upstream afterwards.
    #[tokio::test]
    async fn test_configuration_change() {
        let actor_id = 233.into();
        let (untouched, old, new) = (234.into(), 235.into(), 238.into()); // upstream actors
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let metrics = Arc::new(StreamingMetrics::unused());

        // untouched -> actor_id
        // old -> actor_id
        // new -> actor_id

        let (upstream_fragment_id, fragment_id) = (10.into(), 18.into());

        let inputs: Vec<_> =
            try_join_all([untouched, old].into_iter().map(async |upstream_actor_id| {
                new_input(
                    &barrier_test_env.local_barrier_manager,
                    metrics.clone(),
                    actor_id,
                    fragment_id,
                    &helper_make_local_actor(upstream_actor_id),
                    upstream_fragment_id,
                )
                .await
            }))
            .await
            .unwrap();

        let merge_updates = maplit::hashmap! {
            (actor_id, upstream_fragment_id) => MergeUpdate {
                actor_id,
                upstream_fragment_id,
                new_upstream_fragment_id: None,
                added_upstream_actors: vec![helper_make_local_actor(new)],
                removed_upstream_actor_id: vec![old],
            }
        };

        let b1 = Barrier::new_test_barrier(test_epoch(1)).with_mutation(Mutation::Update(
            UpdateMutation {
                merges: merge_updates,
                ..Default::default()
            },
        ));
        barrier_test_env.inject_barrier(&b1, [actor_id]);
        barrier_test_env.flush_all_events().await;

        let barrier_rx = barrier_test_env
            .local_barrier_manager
            .subscribe_barrier(actor_id);
        let actor_ctx = ActorContext::for_test(actor_id);
        let upstream = MergeExecutor::new_select_receiver(inputs, &metrics, &actor_ctx);

        let mut merge = MergeExecutor::new(
            actor_ctx,
            fragment_id,
            upstream_fragment_id,
            upstream,
            barrier_test_env.local_barrier_manager.clone(),
            metrics.clone(),
            barrier_rx,
            100,
            Schema::new(vec![]),
        )
        .boxed()
        .execute();

        let mut txs = HashMap::new();
        macro_rules! send {
            ($actors:expr, $msg:expr) => {
                for actor in $actors {
                    txs.get(&actor).unwrap().send($msg).await.unwrap();
                }
            };
        }

        macro_rules! assert_recv_pending {
            () => {
                assert!(
                    merge
                        .next()
                        .now_or_never()
                        .flatten()
                        .transpose()
                        .unwrap()
                        .is_none()
                );
            };
        }
        macro_rules! recv {
            () => {
                merge.next().await.transpose().unwrap()
            };
        }

        macro_rules! collect_upstream_tx {
            ($actors:expr) => {
                for upstream_id in $actors {
                    let mut output_requests = barrier_test_env
                        .take_pending_new_output_requests(upstream_id)
                        .await;
                    assert_eq!(output_requests.len(), 1);
                    let (downstream_actor_id, request) = output_requests.pop().unwrap();
                    assert_eq!(downstream_actor_id, actor_id);
                    let NewOutputRequest::Local(tx) = request else {
                        unreachable!()
                    };
                    txs.insert(upstream_id, tx);
                }
            };
        }

        assert_recv_pending!();
        barrier_test_env.flush_all_events().await;

        // 2. Take downstream receivers.
        collect_upstream_tx!([untouched, old]);

        // 3. Send a chunk.
        send!([untouched, old], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality()); // We should be able to receive the chunk twice.
        assert_recv_pending!();

        send!(
            [untouched, old],
            Message::Barrier(b1.clone().into_dispatcher()).into()
        );
        assert_recv_pending!(); // We should not receive the barrier, since merger is waiting for the new upstream new.

        collect_upstream_tx!([new]);

        send!([new], Message::Barrier(b1.clone().into_dispatcher()).into());
        recv!().unwrap().as_barrier().unwrap(); // We should now receive the barrier.

        // 5. Send a chunk.
        send!([untouched, new], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality()); // We should be able to receive the chunk twice.
        assert_recv_pending!();
    }

    /// Fake gRPC exchange server that serves one empty chunk and one barrier,
    /// recording that `get_stream` was called.
    struct FakeExchangeService {
        rpc_called: Arc<AtomicBool>,
    }

    fn exchange_client_test_barrier() -> crate::executor::Barrier {
        Barrier::new_test_barrier(test_epoch(1))
    }

    #[async_trait::async_trait]
    impl ExchangeService for FakeExchangeService {
        type GetDataStream = ReceiverStream<std::result::Result<GetDataResponse, Status>>;
        type GetStreamStream = ReceiverStream<std::result::Result<GetStreamResponse, Status>>;

        async fn get_data(
            &self,
            _: Request<GetDataRequest>,
        ) -> std::result::Result<Response<Self::GetDataStream>, Status> {
            unimplemented!()
        }

        async fn get_stream(
            &self,
            _request: Request<Streaming<GetStreamRequest>>,
        ) -> std::result::Result<Response<Self::GetStreamStream>, Status> {
            let (tx, rx) = tokio::sync::mpsc::channel(10);
            self.rpc_called.store(true, Ordering::SeqCst);
            // send stream_chunk
            let stream_chunk = StreamChunk::default().to_protobuf();
            tx.send(Ok(GetStreamResponse {
                message: Some(PbStreamMessageBatch {
                    stream_message_batch: Some(
                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::StreamChunk(
                            stream_chunk,
                        ),
                    ),
                }),
                permits: Some(PbPermits::default()),
            }))
            .await
            .unwrap();
            // send barrier
            let barrier = exchange_client_test_barrier();
            tx.send(Ok(GetStreamResponse {
                message: Some(PbStreamMessageBatch {
                    stream_message_batch: Some(
                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::BarrierBatch(
                            BarrierBatch {
                                barriers: vec![barrier.to_protobuf()],
                            },
                        ),
                    ),
                }),
                permits: Some(PbPermits::default()),
            }))
            .await
            .unwrap();
            Ok(Response::new(ReceiverStream::new(rx)))
        }
    }

    /// `RemoteInput` against the fake gRPC server above: expects the empty chunk then
    /// the barrier, and verifies the RPC was actually made.
    #[tokio::test]
    async fn test_stream_exchange_client() {
        let rpc_called = Arc::new(AtomicBool::new(false));
        let server_run = Arc::new(AtomicBool::new(false));
        let addr = "127.0.0.1:12348".parse().unwrap();

        // Start a server.
        let (shutdown_send, shutdown_recv) = tokio::sync::oneshot::channel();
        let exchange_svc = ExchangeServiceServer::new(FakeExchangeService {
            rpc_called: rpc_called.clone(),
        });
        let cp_server_run = server_run.clone();
        let join_handle = tokio::spawn(async move {
            cp_server_run.store(true, Ordering::SeqCst);
            tonic::transport::Server::builder()
                .add_service(exchange_svc)
                .serve_with_shutdown(addr, async move {
                    shutdown_recv.await.unwrap();
                })
                .await
                .unwrap();
        });

        sleep(Duration::from_secs(1)).await;
        assert!(server_run.load(Ordering::SeqCst));

        let test_env = LocalBarrierTestEnv::for_test().await;

        let remote_input = {
            RemoteInput::new(
                &test_env.local_barrier_manager,
                addr.into(),
                (0.into(), 0.into()),
                (0.into(), 0.into()),
                Arc::new(StreamingMetrics::unused()),
            )
            .await
            .unwrap()
        };

        test_env.inject_barrier(&exchange_client_test_barrier(), [remote_input.id()]);

        pin_mut!(remote_input);

        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            let (ops, columns, visibility) = chunk.into_inner();
            assert!(ops.is_empty());
            assert!(columns.is_empty());
            assert!(visibility.is_empty());
        });
        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });
        assert!(rpc_called.load(Ordering::SeqCst));

        shutdown_send.send(()).unwrap();
        join_handle.await.unwrap();
    }
}