risingwave_stream/executor/
merge.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::VecDeque;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use risingwave_common::array::StreamChunkBuilder;
20use risingwave_common::config::MetricLevel;
21use tokio::sync::mpsc;
22
23use super::exchange::input::BoxedActorInput;
24use super::*;
25use crate::executor::prelude::*;
26use crate::task::LocalBarrierManager;
27
/// Dynamic set of upstream receivers keyed by [`ActorId`], `select`-ed as one stream.
pub type SelectReceivers = DynamicReceivers<ActorId, ()>;

/// Upstream of [`MergeExecutor`]: all upstream inputs, with small chunks compacted by [`BufferChunks`].
pub type MergeUpstream = BufferChunks<SelectReceivers>;
/// Upstream of [`ReceiverExecutor`]: exactly one upstream input.
pub type SingletonUpstream = BoxedActorInput;
32
/// The upstream side of a merge-like executor, distinguishing the single-input case
/// (handled by [`ReceiverExecutor`]) from the multi-input case (handled by
/// [`MergeExecutor`]). See [`MergeExecutorInput::into_executor`].
pub(crate) enum MergeExecutorUpstream {
    /// Exactly one upstream input.
    Singleton(SingletonUpstream),
    /// Multiple upstream inputs merged together.
    Merge(MergeUpstream),
}
37
/// Bundles a merge-like upstream with everything needed to later build the actual
/// executor via [`MergeExecutorInput::into_executor`]. It also implements `Stream`
/// itself, polling the upstream directly (see the `Stream` impl below).
pub(crate) struct MergeExecutorInput {
    // Either a single input or a merged set of inputs.
    upstream: MergeExecutorUpstream,
    // Context of the owning actor; its `fragment_id` is read in `into_executor`.
    actor_context: ActorContextRef,
    // Id of the fragment the upstream inputs belong to.
    upstream_fragment_id: UpstreamFragmentId,
    // Handed to the constructed executor for barrier coordination.
    local_barrier_manager: LocalBarrierManager,
    // Streaming metrics, forwarded to the constructed executor.
    executor_stats: Arc<StreamingMetrics>,
    // Executor info (schema, identity, ...), combined with the executor on build.
    pub(crate) info: ExecutorInfo,
}
46
impl MergeExecutorInput {
    /// Bundle an upstream with the context needed to construct the final executor later.
    pub(crate) fn new(
        upstream: MergeExecutorUpstream,
        actor_context: ActorContextRef,
        upstream_fragment_id: UpstreamFragmentId,
        local_barrier_manager: LocalBarrierManager,
        executor_stats: Arc<StreamingMetrics>,
        info: ExecutorInfo,
    ) -> Self {
        Self {
            upstream,
            actor_context,
            upstream_fragment_id,
            local_barrier_manager,
            executor_stats,
            info,
        }
    }

    /// Convert into a full [`Executor`]: a [`ReceiverExecutor`] for a singleton
    /// upstream, or a [`MergeExecutor`] for multiple upstreams.
    pub(crate) fn into_executor(self, barrier_rx: mpsc::UnboundedReceiver<Barrier>) -> Executor {
        // Read the fragment id before `self` is partially moved below.
        let fragment_id = self.actor_context.fragment_id;
        let executor = match self.upstream {
            MergeExecutorUpstream::Singleton(input) => ReceiverExecutor::new(
                self.actor_context,
                fragment_id,
                self.upstream_fragment_id,
                input,
                self.local_barrier_manager,
                self.executor_stats,
                barrier_rx,
            )
            .boxed(),
            MergeExecutorUpstream::Merge(inputs) => MergeExecutor::new(
                self.actor_context,
                fragment_id,
                self.upstream_fragment_id,
                inputs,
                self.local_barrier_manager,
                self.executor_stats,
                barrier_rx,
            )
            .boxed(),
        };
        (self.info, executor).into()
    }
}
93
94impl Stream for MergeExecutorInput {
95    type Item = DispatcherMessageStreamItem;
96
97    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
98        match &mut self.get_mut().upstream {
99            MergeExecutorUpstream::Singleton(input) => input.poll_next_unpin(cx),
100            MergeExecutorUpstream::Merge(inputs) => inputs.poll_next_unpin(cx),
101        }
102    }
103}
104
mod upstream {
    use super::*;

    /// Trait unifying operations on [`MergeUpstream`] and [`SingletonUpstream`], so that we can
    /// reuse code between [`MergeExecutor`] and [`ReceiverExecutor`].
    pub trait Upstream:
        Stream<Item = StreamExecutorResult<DispatcherMessage>> + Unpin + Send + 'static
    {
        /// Ids of all currently connected upstream inputs.
        fn upstream_input_ids(&self) -> impl Iterator<Item = ActorId> + '_;
        /// Flush watermarks buffered for multi-input alignment (no-op for a single input).
        fn flush_buffered_watermarks(&mut self);
        /// Apply a configuration change: connect `to_add` and disconnect `to_remove`.
        fn update(&mut self, to_add: Vec<BoxedActorInput>, to_remove: &HashSet<ActorId>);
    }

    impl Upstream for MergeUpstream {
        fn upstream_input_ids(&self) -> impl Iterator<Item = ActorId> + '_ {
            // `inner` is the `SelectReceivers` wrapped by `BufferChunks`.
            self.inner.upstream_input_ids()
        }

        fn flush_buffered_watermarks(&mut self) {
            self.inner.flush_buffered_watermarks();
        }

        fn update(&mut self, to_add: Vec<BoxedActorInput>, to_remove: &HashSet<ActorId>) {
            // Skip the calls entirely when there's nothing to change.
            if !to_add.is_empty() {
                self.inner.add_upstreams_from(to_add);
            }
            if !to_remove.is_empty() {
                self.inner.remove_upstreams(to_remove);
            }
        }
    }

    impl Upstream for SingletonUpstream {
        fn upstream_input_ids(&self) -> impl Iterator<Item = ActorId> + '_ {
            std::iter::once(self.id())
        }

        fn flush_buffered_watermarks(&mut self) {
            // For single input, we won't buffer watermarks so there's nothing to do.
        }

        fn update(&mut self, to_add: Vec<BoxedActorInput>, to_remove: &HashSet<ActorId>) {
            // Invariant: a singleton upstream can only ever be *replaced* — the one
            // removed input must be the current one, and exactly one must be added.
            assert_eq!(
                to_remove,
                &HashSet::from_iter([self.id()]),
                "the removed upstream actor should be the same as the current input"
            );

            // Replace the single input.
            *self = to_add
                .into_iter()
                .exactly_one()
                .expect("receiver should have exactly one new upstream");
        }
    }
}
161use upstream::Upstream;
162
/// The core of `MergeExecutor` and `ReceiverExecutor`.
pub struct MergeExecutorInner<U> {
    /// The context of the actor.
    actor_context: ActorContextRef,

    /// Upstream channels.
    upstream: U,

    /// Belonged fragment id.
    fragment_id: FragmentId,

    /// Upstream fragment id.
    upstream_fragment_id: FragmentId,

    /// Handle to the local barrier manager; consumed by the barrier buffer in
    /// `execute_inner`.
    local_barrier_manager: LocalBarrierManager,

    /// Streaming metrics.
    metrics: Arc<StreamingMetrics>,

    /// Channel from which barriers for this actor are received; fed to the
    /// barrier buffer in `execute_inner` for alignment with upstream messages.
    barrier_rx: mpsc::UnboundedReceiver<Barrier>,
}
184
185impl<U> MergeExecutorInner<U> {
186    pub fn new(
187        ctx: ActorContextRef,
188        fragment_id: FragmentId,
189        upstream_fragment_id: FragmentId,
190        upstream: U,
191        local_barrier_manager: LocalBarrierManager,
192        metrics: Arc<StreamingMetrics>,
193        barrier_rx: mpsc::UnboundedReceiver<Barrier>,
194    ) -> Self {
195        Self {
196            actor_context: ctx,
197            upstream,
198            fragment_id,
199            upstream_fragment_id,
200            local_barrier_manager,
201            metrics,
202            barrier_rx,
203        }
204    }
205}
206
/// `MergeExecutor` merges data from multiple upstream actors and aligns them with barriers.
/// It is [`MergeExecutorInner`] instantiated with a multi-input [`MergeUpstream`].
pub type MergeExecutor = MergeExecutorInner<MergeUpstream>;
209
impl MergeExecutor {
    /// Build a `MergeExecutor` over local test channels. Each input is assigned a
    /// synthetic upstream actor id equal to its index in `inputs`.
    #[cfg(test)]
    pub fn for_test(
        actor_id: impl Into<ActorId>,
        inputs: Vec<super::exchange::permit::Receiver>,
        local_barrier_manager: crate::task::LocalBarrierManager,
        schema: Schema,
        chunk_size: usize,
        barrier_rx: Option<mpsc::UnboundedReceiver<Barrier>>,
    ) -> Self {
        let actor_id = actor_id.into();
        use super::exchange::input::LocalInput;
        use crate::executor::exchange::input::ActorInput;

        // Subscribe to barriers ourselves unless the test supplies a receiver.
        let barrier_rx =
            barrier_rx.unwrap_or_else(|| local_barrier_manager.subscribe_barrier(actor_id));

        let metrics = StreamingMetrics::unused();
        let actor_ctx = ActorContext::for_test(actor_id);
        let upstream = Self::new_merge_upstream(
            inputs
                .into_iter()
                .enumerate()
                .map(|(idx, input)| LocalInput::new(input, ActorId::new(idx as u32)).boxed_input())
                .collect(),
            &metrics,
            &actor_ctx,
            chunk_size,
            schema,
        );

        Self::new(
            actor_ctx,
            // Arbitrary fragment id / upstream fragment id for tests.
            514.into(),
            1919.into(),
            upstream,
            local_barrier_manager,
            metrics.into(),
            barrier_rx,
        )
    }

    /// Wrap raw upstream inputs into a [`MergeUpstream`]: a `select` over all
    /// inputs, with output chunks compacted to at most `chunk_size` rows.
    pub(crate) fn new_merge_upstream(
        upstreams: Vec<BoxedActorInput>,
        metrics: &StreamingMetrics,
        actor_context: &ActorContext,
        chunk_size: usize,
        schema: Schema,
    ) -> MergeUpstream {
        // The barrier-alignment duration metric is only collected at `Debug`
        // metric level or above.
        let merge_barrier_align_duration = if metrics.level >= MetricLevel::Debug {
            Some(
                metrics
                    .merge_barrier_align_duration
                    .with_guarded_label_values(&[
                        &actor_context.id.to_string(),
                        &actor_context.fragment_id.to_string(),
                    ]),
            )
        } else {
            None
        };

        BufferChunks::new(
            // Futures of all active upstreams.
            SelectReceivers::new(upstreams, None, merge_barrier_align_duration),
            chunk_size,
            schema,
        )
    }
}
280
impl<U> MergeExecutorInner<U>
where
    U: Upstream,
{
    /// Main loop: forward messages from `upstream`, aligning barriers received via
    /// `barrier_rx` and applying upstream configuration changes carried by
    /// `Update` mutations. Terminates after yielding a stop barrier.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    pub(crate) async fn execute_inner(mut self: Box<Self>) {
        let mut upstream = self.upstream;
        let actor_id = self.actor_context.id;

        // Input metrics are labelled by (actor, fragment, upstream fragment);
        // rebuilt below whenever the upstream fragment id changes.
        let mut metrics = self.metrics.new_actor_input_metrics(
            actor_id,
            self.fragment_id,
            self.upstream_fragment_id,
        );

        let mut barrier_buffer = DispatchBarrierBuffer::new(
            self.barrier_rx,
            actor_id,
            self.upstream_fragment_id,
            self.local_barrier_manager,
            self.metrics.clone(),
            self.fragment_id,
            self.actor_context.config.clone(),
        );

        loop {
            let msg = barrier_buffer
                .await_next_message(&mut upstream, &metrics)
                .await?;
            let msg = match msg {
                DispatcherMessage::Watermark(watermark) => Message::Watermark(watermark),
                DispatcherMessage::Chunk(chunk) => {
                    // Count rows received from upstream.
                    metrics.actor_in_record_cnt.inc_by(chunk.cardinality() as _);
                    Message::Chunk(chunk)
                }
                DispatcherMessage::Barrier(barrier) => {
                    // Resolve the full barrier (with mutation), along with any inputs
                    // newly created for it by the barrier buffer.
                    let (barrier, new_inputs) =
                        barrier_buffer.pop_barrier_with_inputs(barrier).await?;

                    if let Some(Mutation::Update(UpdateMutation { dispatchers, .. })) =
                        barrier.mutation.as_deref()
                        && upstream
                            .upstream_input_ids()
                            .any(|actor_id| dispatchers.contains_key(&actor_id))
                    {
                        // `Watermark` of upstream may become stale after downstream scaling.
                        upstream.flush_buffered_watermarks();
                    }

                    // Handle a merge update targeting this actor's current upstream fragment.
                    if let Some(update) =
                        barrier.as_update_merge(self.actor_context.id, self.upstream_fragment_id)
                    {
                        let new_upstream_fragment_id = update
                            .new_upstream_fragment_id
                            .unwrap_or(self.upstream_fragment_id);
                        // If the whole upstream fragment is replaced, all current inputs
                        // are removed; otherwise only the explicitly listed actors are.
                        let removed_upstream_actor_id: HashSet<_> =
                            if update.new_upstream_fragment_id.is_some() {
                                upstream.upstream_input_ids().collect()
                            } else {
                                update.removed_upstream_actor_id.iter().copied().collect()
                            };

                        // `Watermark` of upstream may become stale after upstream scaling.
                        upstream.flush_buffered_watermarks();

                        // Add and remove upstreams.
                        upstream.update(new_inputs.unwrap_or_default(), &removed_upstream_actor_id);

                        // Re-label metrics with the (possibly new) upstream fragment id.
                        self.upstream_fragment_id = new_upstream_fragment_id;
                        metrics = self.metrics.new_actor_input_metrics(
                            actor_id,
                            self.fragment_id,
                            self.upstream_fragment_id,
                        );
                    }

                    // A stop barrier is the last message for this actor: yield it and exit.
                    let is_stop = barrier.is_stop(actor_id);
                    let msg = Message::Barrier(barrier);
                    if is_stop {
                        yield msg;
                        break;
                    }

                    msg
                }
            };

            yield msg;
        }
    }
}
372
impl<U> Execute for MergeExecutorInner<U>
where
    U: Upstream,
{
    /// Box the message stream produced by [`MergeExecutorInner::execute_inner`].
    fn execute(self: Box<Self>) -> BoxedMessageStream {
        self.execute_inner().boxed()
    }
}
381
/// A wrapper that buffers the `StreamChunk`s from upstream until no more ready items are available.
/// Besides, any message other than `StreamChunk` will trigger the buffered `StreamChunk`s
/// to be emitted immediately along with the message itself.
pub struct BufferChunks<S: Stream> {
    // The wrapped stream, also reachable via `Deref`/`DerefMut`.
    inner: S,
    // Accumulates rows from incoming chunks, emitting a chunk once full.
    chunk_builder: StreamChunkBuilder,

    /// The items to be emitted. Whenever there's something here, we should return a `Poll::Ready` immediately.
    pending_items: VecDeque<S::Item>,
}
392
393impl<S: Stream> BufferChunks<S> {
394    pub(super) fn new(inner: S, chunk_size: usize, schema: Schema) -> Self {
395        assert!(chunk_size > 0);
396        let chunk_builder = StreamChunkBuilder::new(chunk_size, schema.data_types());
397        Self {
398            inner,
399            chunk_builder,
400            pending_items: VecDeque::new(),
401        }
402    }
403}
404
// Expose the inner stream so callers can reach its methods (e.g. `SelectReceivers`
// operations on a `MergeUpstream`) through the wrapper.
impl<S: Stream> std::ops::Deref for BufferChunks<S> {
    type Target = S;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

impl<S: Stream> std::ops::DerefMut for BufferChunks<S> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
418
impl<S: Stream> Stream for BufferChunks<S>
where
    S: Stream<Item = DispatcherMessageStreamItem> + Unpin,
{
    type Item = S::Item;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Drain anything already queued before polling the inner stream again.
            if let Some(item) = self.pending_items.pop_front() {
                return Poll::Ready(Some(item));
            }

            match self.inner.poll_next_unpin(cx) {
                Poll::Pending => {
                    // No more ready input: flush whatever has accumulated in the
                    // builder instead of holding rows back until the next wakeup.
                    return if let Some(chunk_out) = self.chunk_builder.take() {
                        Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
                    } else {
                        Poll::Pending
                    };
                }

                Poll::Ready(Some(result)) => {
                    if let Ok(MessageInner::Chunk(chunk)) = result {
                        // Re-pack rows into the builder; each chunk it completes is
                        // queued, and we keep looping to consume more ready input.
                        for row in chunk.records() {
                            if let Some(chunk_out) = self.chunk_builder.append_record(row) {
                                self.pending_items
                                    .push_back(Ok(MessageInner::Chunk(chunk_out)));
                            }
                        }
                    } else {
                        // Any non-chunk message (or error) flushes the builder first,
                        // so buffered rows are emitted before e.g. a barrier.
                        return if let Some(chunk_out) = self.chunk_builder.take() {
                            self.pending_items.push_back(result);
                            Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
                        } else {
                            Poll::Ready(Some(result))
                        };
                    }
                }

                Poll::Ready(None) => {
                    // See also the comments in `DynamicReceivers::poll_next`.
                    unreachable!("Merge should always have upstream inputs");
                }
            }
        }
    }
}
466
467#[cfg(test)]
468mod tests {
469    use std::sync::atomic::{AtomicBool, Ordering};
470    use std::time::Duration;
471
472    use assert_matches::assert_matches;
473    use futures::FutureExt;
474    use futures::future::try_join_all;
475    use risingwave_common::array::Op;
476    use risingwave_common::util::epoch::test_epoch;
477    use risingwave_pb::task_service::stream_exchange_service_server::{
478        StreamExchangeService, StreamExchangeServiceServer,
479    };
480    use risingwave_pb::task_service::{GetStreamRequest, GetStreamResponse, PbPermits};
481    use tokio::time::sleep;
482    use tokio_stream::wrappers::ReceiverStream;
483    use tonic::{Request, Response, Status, Streaming};
484
485    use super::*;
486    use crate::executor::exchange::input::{ActorInput, LocalInput, RemoteInput};
487    use crate::executor::exchange::permit::channel_for_test;
488    use crate::executor::{BarrierInner as Barrier, MessageInner as Message};
489    use crate::task::NewOutputRequest;
490    use crate::task::barrier_test_utils::LocalBarrierTestEnv;
491    use crate::task::test_utils::helper_make_local_actor;
492
493    fn build_test_chunk(size: u64) -> StreamChunk {
494        let ops = vec![Op::Insert; size as usize];
495        StreamChunk::new(ops, vec![])
496    }
497
    /// `BufferChunks` should compact consecutive chunks up to the chunk size, and
    /// flush buffered rows immediately before any non-chunk message.
    #[tokio::test]
    async fn test_buffer_chunks() {
        let test_env = LocalBarrierTestEnv::for_test().await;

        let (tx, rx) = channel_for_test();
        let input = LocalInput::new(rx, 1.into()).boxed_input();
        let mut buffer = BufferChunks::new(input, 100, Schema::new(vec![]));

        // Send a chunk
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 10);
        });

        // Send 2 chunks and expect them to be merged.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });

        // Send a watermark.
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Send 2 chunks before a watermark. Expect the 2 chunks to be merged and the watermark to be emitted.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Send a barrier.
        let barrier = Barrier::new_test_barrier(test_epoch(1));
        test_env.inject_barrier(&barrier, [2.into()]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });

        // Send 2 chunks before a barrier. Expect the 2 chunks to be merged and the barrier to be emitted.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        let barrier = Barrier::new_test_barrier(test_epoch(2));
        test_env.inject_barrier(&barrier, [2.into()]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(2));
        });
    }
593
    /// End-to-end merge over `CHANNEL_NUMBER` local inputs: chunks are compacted,
    /// watermarks are aligned across inputs, and barriers are aligned per epoch,
    /// finishing with a stop barrier.
    #[tokio::test]
    async fn test_merger() {
        const CHANNEL_NUMBER: usize = 10;
        let mut txs = Vec::with_capacity(CHANNEL_NUMBER);
        let mut rxs = Vec::with_capacity(CHANNEL_NUMBER);
        for _i in 0..CHANNEL_NUMBER {
            let (tx, rx) = channel_for_test();
            txs.push(tx);
            rxs.push(rx);
        }
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let actor_id = 233.into();
        let mut handles = Vec::with_capacity(CHANNEL_NUMBER);

        // Epochs 10, 20, ..., 990; each gets one injected barrier below.
        let epochs = (10..1000u64)
            .step_by(10)
            .map(|idx| (idx, test_epoch(idx)))
            .collect_vec();
        let mut prev_epoch = 0;
        let prev_epoch = &mut prev_epoch;
        let barriers: HashMap<_, _> = epochs
            .iter()
            .map(|(_, epoch)| {
                let barrier = Barrier::with_prev_epoch_for_test(*epoch, *prev_epoch);
                *prev_epoch = *epoch;
                barrier_test_env.inject_barrier(&barrier, [actor_id]);
                (*epoch, barrier)
            })
            .collect();
        // Final stop barrier terminating the merger.
        let b2 = Barrier::with_prev_epoch_for_test(test_epoch(1000), *prev_epoch)
            .with_mutation(Mutation::Stop(StopMutation::default()));
        barrier_test_env.inject_barrier(&b2, [actor_id]);
        barrier_test_env.flush_all_events().await;

        // Each upstream task alternates between chunks and watermarks, followed by
        // the barrier for that epoch.
        for (tx_id, tx) in txs.into_iter().enumerate() {
            let epochs = epochs.clone();
            let barriers = barriers.clone();
            let b2 = b2.clone();
            let handle = tokio::spawn(async move {
                for (idx, epoch) in epochs {
                    if idx % 20 == 0 {
                        tx.send(Message::Chunk(build_test_chunk(10)).into())
                            .await
                            .unwrap();
                    } else {
                        tx.send(
                            Message::Watermark(Watermark {
                                col_idx: (idx as usize / 20 + tx_id) % CHANNEL_NUMBER,
                                data_type: DataType::Int64,
                                val: ScalarImpl::Int64(idx as i64),
                            })
                            .into(),
                        )
                        .await
                        .unwrap();
                    }
                    tx.send(Message::Barrier(barriers[&epoch].clone().into_dispatcher()).into())
                        .await
                        .unwrap();
                    sleep(Duration::from_millis(1)).await;
                }
                tx.send(Message::Barrier(b2.clone().into_dispatcher()).into())
                    .await
                    .unwrap();
            });
            handles.push(handle);
        }

        let merger = MergeExecutor::for_test(
            actor_id,
            rxs,
            barrier_test_env.local_barrier_manager.clone(),
            Schema::new(vec![]),
            100,
            None,
        );
        let mut merger = merger.boxed().execute();
        for (idx, epoch) in epochs {
            if idx % 20 == 0 {
                // expect 1 or more chunks with 100 rows in total
                let mut count = 0usize;
                while count < 100 {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
                        count += chunk.ops().len();
                    });
                }
                assert_eq!(count, 100);
            } else if idx as usize / 20 >= CHANNEL_NUMBER - 1 {
                // expect n watermarks
                for _ in 0..CHANNEL_NUMBER {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
                        assert_eq!(watermark.val, ScalarImpl::Int64((idx - 20 * (CHANNEL_NUMBER as u64 - 1)) as i64));
                    });
                }
            }
            // expect a barrier
            assert_matches!(merger.next().await.unwrap().unwrap(), Message::Barrier(Barrier{epoch:barrier_epoch,mutation:_,..}) => {
                assert_eq!(barrier_epoch.curr, epoch);
            });
        }
        // The final message is the stop barrier.
        assert_matches!(
            merger.next().await.unwrap().unwrap(),
            Message::Barrier(Barrier {
                mutation,
                ..
            }) if mutation.as_deref().unwrap().is_stop()
        );

        for handle in handles {
            handle.await.unwrap();
        }
    }
706
    /// Scale-out scenario: a merge with upstreams {untouched, old} processes an
    /// update barrier that removes `old` and adds `new`, then continues merging
    /// from {untouched, new}.
    #[tokio::test]
    async fn test_configuration_change() {
        let actor_id = 233.into();
        let (untouched, old, new) = (234.into(), 235.into(), 238.into()); // upstream actors
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let metrics = Arc::new(StreamingMetrics::unused());

        // untouched -> actor_id
        // old -> actor_id
        // new -> actor_id

        let (upstream_fragment_id, fragment_id) = (10.into(), 18.into());

        let actor_ctx = ActorContext::for_test(actor_id);

        // Initial inputs: `untouched` and `old`.
        let inputs: Vec<_> =
            try_join_all([untouched, old].into_iter().map(async |upstream_actor_id| {
                new_input(
                    &barrier_test_env.local_barrier_manager,
                    metrics.clone(),
                    actor_id,
                    fragment_id,
                    &helper_make_local_actor(upstream_actor_id),
                    upstream_fragment_id,
                    actor_ctx.config.clone(),
                )
                .await
            }))
            .await
            .unwrap();

        // The update: add `new`, remove `old`; the upstream fragment stays the same.
        let merge_updates = maplit::hashmap! {
            (actor_id, upstream_fragment_id) => MergeUpdate {
                actor_id,
                upstream_fragment_id,
                new_upstream_fragment_id: None,
                added_upstream_actors: vec![helper_make_local_actor(new)],
                removed_upstream_actor_id: vec![old],
            }
        };

        let b1 = Barrier::new_test_barrier(test_epoch(1)).with_mutation(Mutation::Update(
            UpdateMutation {
                merges: merge_updates,
                ..Default::default()
            },
        ));
        barrier_test_env.inject_barrier(&b1, [actor_id]);
        barrier_test_env.flush_all_events().await;

        let barrier_rx = barrier_test_env
            .local_barrier_manager
            .subscribe_barrier(actor_id);
        let upstream = MergeExecutor::new_merge_upstream(
            inputs,
            &metrics,
            &actor_ctx,
            100,
            Schema::empty().clone(),
        );

        let mut merge = MergeExecutor::new(
            actor_ctx,
            fragment_id,
            upstream_fragment_id,
            upstream,
            barrier_test_env.local_barrier_manager.clone(),
            metrics.clone(),
            barrier_rx,
        )
        .boxed()
        .execute();

        let mut txs = HashMap::new();
        // Send `$msg` to each of the given upstream actors.
        macro_rules! send {
            ($actors:expr, $msg:expr) => {
                for actor in $actors {
                    txs.get(&actor).unwrap().send($msg).await.unwrap();
                }
            };
        }

        // Assert the merger has nothing ready to yield right now.
        macro_rules! assert_recv_pending {
            () => {
                assert!(
                    merge
                        .next()
                        .now_or_never()
                        .flatten()
                        .transpose()
                        .unwrap()
                        .is_none()
                );
            };
        }
        macro_rules! recv {
            () => {
                merge.next().await.transpose().unwrap()
            };
        }

        // Take the output-channel sender each upstream actor created towards us.
        macro_rules! collect_upstream_tx {
            ($actors:expr) => {
                for upstream_id in $actors {
                    let mut output_requests = barrier_test_env
                        .take_pending_new_output_requests(upstream_id)
                        .await;
                    assert_eq!(output_requests.len(), 1);
                    let (downstream_actor_id, request) = output_requests.pop().unwrap();
                    assert_eq!(downstream_actor_id, actor_id);
                    let NewOutputRequest::Local(tx) = request else {
                        unreachable!()
                    };
                    txs.insert(upstream_id, tx);
                }
            };
        }

        // 1. Nothing has been sent yet.
        assert_recv_pending!();
        barrier_test_env.flush_all_events().await;

        // 2. Take downstream receivers.
        collect_upstream_tx!([untouched, old]);

        // 3. Send a chunk from each original upstream.
        send!([untouched, old], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality()); // The two 1-row chunks are compacted into a single chunk of cardinality 2.
        assert_recv_pending!();

        // 4. Send the update barrier from the original upstreams only.
        send!(
            [untouched, old],
            Message::Barrier(b1.clone().into_dispatcher()).into()
        );
        assert_recv_pending!(); // We should not receive the barrier, since the merger is still waiting for the newly added upstream `new`.

        collect_upstream_tx!([new]);

        send!([new], Message::Barrier(b1.clone().into_dispatcher()).into());
        recv!().unwrap().as_barrier().unwrap(); // We should now receive the barrier.

        // 5. Send a chunk from the post-update upstreams.
        send!([untouched, new], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality()); // Again compacted into a single chunk of cardinality 2.
        assert_recv_pending!();
    }
852
    /// A mock exchange service used by `test_stream_exchange_client`: its
    /// `get_stream` implementation serves a fixed sequence (one stream chunk
    /// followed by one barrier) and records that it was invoked.
    struct FakeExchangeService {
        // Set to `true` when `get_stream` is called; the test inspects this
        // flag to verify the client actually reached the server over RPC.
        rpc_called: Arc<AtomicBool>,
    }
856
857    fn exchange_client_test_barrier() -> crate::executor::Barrier {
858        Barrier::new_test_barrier(test_epoch(1))
859    }
860
861    #[async_trait::async_trait]
862    impl StreamExchangeService for FakeExchangeService {
863        type GetStreamStream = ReceiverStream<std::result::Result<GetStreamResponse, Status>>;
864
865        async fn get_stream(
866            &self,
867            _request: Request<Streaming<GetStreamRequest>>,
868        ) -> std::result::Result<Response<Self::GetStreamStream>, Status> {
869            let (tx, rx) = tokio::sync::mpsc::channel(10);
870            self.rpc_called.store(true, Ordering::SeqCst);
871            // send stream_chunk
872            let stream_chunk = StreamChunk::default().to_protobuf();
873            tx.send(Ok(GetStreamResponse {
874                message: Some(PbStreamMessageBatch {
875                    stream_message_batch: Some(
876                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::StreamChunk(
877                            stream_chunk,
878                        ),
879                    ),
880                }),
881                permits: Some(PbPermits::default()),
882            }))
883            .await
884            .unwrap();
885            // send barrier
886            let barrier = exchange_client_test_barrier();
887            tx.send(Ok(GetStreamResponse {
888                message: Some(PbStreamMessageBatch {
889                    stream_message_batch: Some(
890                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::BarrierBatch(
891                            BarrierBatch {
892                                barriers: vec![barrier.to_protobuf()],
893                            },
894                        ),
895                    ),
896                }),
897                permits: Some(PbPermits::default()),
898            }))
899            .await
900            .unwrap();
901            Ok(Response::new(ReceiverStream::new(rx)))
902        }
903    }
904
    /// End-to-end check of the exchange client: spins up a real tonic server
    /// backed by [`FakeExchangeService`], connects a `RemoteInput` to it, and
    /// verifies the canned chunk and barrier arrive in order.
    #[tokio::test]
    async fn test_stream_exchange_client() {
        // Flags flipped by the fake service / server task so the test can
        // verify both actually ran.
        let rpc_called = Arc::new(AtomicBool::new(false));
        let server_run = Arc::new(AtomicBool::new(false));
        // NOTE(review): hard-coded port — could collide with other tests or
        // services bound on the same host.
        let addr = "127.0.0.1:12348".parse().unwrap();

        // Start a server.
        let (shutdown_send, shutdown_recv) = tokio::sync::oneshot::channel();
        let exchange_svc = StreamExchangeServiceServer::new(FakeExchangeService {
            rpc_called: rpc_called.clone(),
        });
        let cp_server_run = server_run.clone();
        let join_handle = tokio::spawn(async move {
            cp_server_run.store(true, Ordering::SeqCst);
            tonic::transport::Server::builder()
                .add_service(exchange_svc)
                .serve_with_shutdown(addr, async move {
                    shutdown_recv.await.unwrap();
                })
                .await
                .unwrap();
        });

        // Give the server task a chance to bind before connecting.
        // NOTE(review): a fixed sleep is a timing assumption; a readiness
        // signal from the server task would be more robust.
        sleep(Duration::from_secs(1)).await;
        assert!(server_run.load(Ordering::SeqCst));

        let test_env = LocalBarrierTestEnv::for_test().await;

        // Connect a remote input to the fake server; ids are all zero since
        // only the wire protocol is under test here.
        let remote_input = {
            RemoteInput::new(
                &test_env.local_barrier_manager,
                addr.into(),
                (0.into(), 0.into()),
                (0.into(), 0.into()),
                Arc::new(StreamingMetrics::unused()),
                Arc::new(StreamingConfig::default()),
            )
            .await
            .unwrap()
        };

        // Register the barrier with the local barrier manager before polling,
        // so the input can collect it when it arrives from the server.
        test_env.inject_barrier(&exchange_client_test_barrier(), [remote_input.id()]);

        pin_mut!(remote_input);

        // The fake service sends a default chunk first: all parts empty.
        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            let (ops, columns, visibility) = chunk.into_inner();
            assert!(ops.is_empty());
            assert!(columns.is_empty());
            assert!(visibility.is_empty());
        });
        // Then the barrier, whose epoch must match the injected test barrier.
        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });
        // The RPC endpoint must have been hit for the messages to have arrived.
        assert!(rpc_called.load(Ordering::SeqCst));

        // Shut the server down cleanly and wait for its task to finish.
        shutdown_send.send(()).unwrap();
        join_handle.await.unwrap();
    }
964}