risingwave_stream/executor/
merge.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::VecDeque;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use risingwave_common::array::StreamChunkBuilder;
20use tokio::sync::mpsc;
21
22use super::exchange::input::BoxedActorInput;
23use super::*;
24use crate::executor::prelude::*;
25use crate::task::LocalBarrierManager;
26
/// Receiver set that selects over all upstream actor inputs of a merge.
pub type SelectReceivers = DynamicReceivers<ActorId, ()>;

/// Upstream of a [`MergeExecutor`]: multiple inputs whose small chunks are
/// compacted into larger ones by [`BufferChunks`].
pub type MergeUpstream = BufferChunks<SelectReceivers>;
/// Upstream of a [`ReceiverExecutor`]: exactly one input actor.
pub type SingletonUpstream = BoxedActorInput;

/// Upstream of a merge-like executor, deciding which concrete executor
/// ([`ReceiverExecutor`] or [`MergeExecutor`]) is built in
/// [`MergeExecutorInput::into_executor`].
pub(crate) enum MergeExecutorUpstream {
    /// Exactly one upstream actor.
    Singleton(SingletonUpstream),
    /// Multiple upstream actors merged together.
    Merge(MergeUpstream),
}
36
/// Not-yet-built input of a merge-like executor; turned into a full
/// [`Executor`] via [`MergeExecutorInput::into_executor`].
pub(crate) struct MergeExecutorInput {
    /// Upstream channel(s) to read messages from.
    upstream: MergeExecutorUpstream,
    /// Context of the actor this input belongs to.
    actor_context: ActorContextRef,
    /// Id of the fragment the upstream actors belong to.
    upstream_fragment_id: UpstreamFragmentId,
    /// Passed through to the built executor.
    local_barrier_manager: LocalBarrierManager,
    /// Streaming metrics registry, passed through to the built executor.
    executor_stats: Arc<StreamingMetrics>,
    pub(crate) info: ExecutorInfo,
}
45
46impl MergeExecutorInput {
47    pub(crate) fn new(
48        upstream: MergeExecutorUpstream,
49        actor_context: ActorContextRef,
50        upstream_fragment_id: UpstreamFragmentId,
51        local_barrier_manager: LocalBarrierManager,
52        executor_stats: Arc<StreamingMetrics>,
53        info: ExecutorInfo,
54    ) -> Self {
55        Self {
56            upstream,
57            actor_context,
58            upstream_fragment_id,
59            local_barrier_manager,
60            executor_stats,
61            info,
62        }
63    }
64
65    pub(crate) fn into_executor(self, barrier_rx: mpsc::UnboundedReceiver<Barrier>) -> Executor {
66        let fragment_id = self.actor_context.fragment_id;
67        let executor = match self.upstream {
68            MergeExecutorUpstream::Singleton(input) => ReceiverExecutor::new(
69                self.actor_context,
70                fragment_id,
71                self.upstream_fragment_id,
72                input,
73                self.local_barrier_manager,
74                self.executor_stats,
75                barrier_rx,
76            )
77            .boxed(),
78            MergeExecutorUpstream::Merge(inputs) => MergeExecutor::new(
79                self.actor_context,
80                fragment_id,
81                self.upstream_fragment_id,
82                inputs,
83                self.local_barrier_manager,
84                self.executor_stats,
85                barrier_rx,
86            )
87            .boxed(),
88        };
89        (self.info, executor).into()
90    }
91}
92
93impl Stream for MergeExecutorInput {
94    type Item = DispatcherMessageStreamItem;
95
96    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
97        match &mut self.get_mut().upstream {
98            MergeExecutorUpstream::Singleton(input) => input.poll_next_unpin(cx),
99            MergeExecutorUpstream::Merge(inputs) => inputs.poll_next_unpin(cx),
100        }
101    }
102}
103
104mod upstream {
105    use super::*;
106
107    /// Trait unifying operations on [`MergeUpstream`] and [`SingletonUpstream`], so that we can
108    /// reuse code between [`MergeExecutor`] and [`ReceiverExecutor`].
109    pub trait Upstream:
110        Stream<Item = StreamExecutorResult<DispatcherMessage>> + Unpin + Send + 'static
111    {
112        fn upstream_input_ids(&self) -> impl Iterator<Item = ActorId> + '_;
113        fn flush_buffered_watermarks(&mut self);
114        fn update(&mut self, to_add: Vec<BoxedActorInput>, to_remove: &HashSet<ActorId>);
115    }
116
117    impl Upstream for MergeUpstream {
118        fn upstream_input_ids(&self) -> impl Iterator<Item = ActorId> + '_ {
119            self.inner.upstream_input_ids()
120        }
121
122        fn flush_buffered_watermarks(&mut self) {
123            self.inner.flush_buffered_watermarks();
124        }
125
126        fn update(&mut self, to_add: Vec<BoxedActorInput>, to_remove: &HashSet<ActorId>) {
127            if !to_add.is_empty() {
128                self.inner.add_upstreams_from(to_add);
129            }
130            if !to_remove.is_empty() {
131                self.inner.remove_upstreams(to_remove);
132            }
133        }
134    }
135
136    impl Upstream for SingletonUpstream {
137        fn upstream_input_ids(&self) -> impl Iterator<Item = ActorId> + '_ {
138            std::iter::once(self.id())
139        }
140
141        fn flush_buffered_watermarks(&mut self) {
142            // For single input, we won't buffer watermarks so there's nothing to do.
143        }
144
145        fn update(&mut self, to_add: Vec<BoxedActorInput>, to_remove: &HashSet<ActorId>) {
146            assert_eq!(
147                to_remove,
148                &HashSet::from_iter([self.id()]),
149                "the removed upstream actor should be the same as the current input"
150            );
151
152            // Replace the single input.
153            *self = to_add
154                .into_iter()
155                .exactly_one()
156                .expect("receiver should have exactly one new upstream");
157        }
158    }
159}
160use upstream::Upstream;
161
/// The core of `MergeExecutor` and `ReceiverExecutor`.
pub struct MergeExecutorInner<U> {
    /// The context of the actor.
    actor_context: ActorContextRef,

    /// Upstream channels.
    upstream: U,

    /// Belonged fragment id.
    fragment_id: FragmentId,

    /// Upstream fragment id.
    upstream_fragment_id: FragmentId,

    /// Handed to the `DispatchBarrierBuffer` in `execute_inner`.
    local_barrier_manager: LocalBarrierManager,

    /// Streaming metrics.
    metrics: Arc<StreamingMetrics>,

    /// Barrier channel, consumed by the `DispatchBarrierBuffer` in `execute_inner`.
    barrier_rx: mpsc::UnboundedReceiver<Barrier>,
}
183
184impl<U> MergeExecutorInner<U> {
185    pub fn new(
186        ctx: ActorContextRef,
187        fragment_id: FragmentId,
188        upstream_fragment_id: FragmentId,
189        upstream: U,
190        local_barrier_manager: LocalBarrierManager,
191        metrics: Arc<StreamingMetrics>,
192        barrier_rx: mpsc::UnboundedReceiver<Barrier>,
193    ) -> Self {
194        Self {
195            actor_context: ctx,
196            upstream,
197            fragment_id,
198            upstream_fragment_id,
199            local_barrier_manager,
200            metrics,
201            barrier_rx,
202        }
203    }
204}
205
206/// `MergeExecutor` merges data from multiple upstream actors and aligns them with barriers.
207pub type MergeExecutor = MergeExecutorInner<MergeUpstream>;
208
209impl MergeExecutor {
210    #[cfg(test)]
211    pub fn for_test(
212        actor_id: impl Into<ActorId>,
213        inputs: Vec<super::exchange::permit::Receiver>,
214        local_barrier_manager: crate::task::LocalBarrierManager,
215        schema: Schema,
216        chunk_size: usize,
217        barrier_rx: Option<mpsc::UnboundedReceiver<Barrier>>,
218    ) -> Self {
219        let actor_id = actor_id.into();
220        use super::exchange::input::LocalInput;
221        use crate::executor::exchange::input::ActorInput;
222
223        let barrier_rx =
224            barrier_rx.unwrap_or_else(|| local_barrier_manager.subscribe_barrier(actor_id));
225
226        let metrics = StreamingMetrics::unused();
227        let actor_ctx = ActorContext::for_test(actor_id);
228        let upstream = Self::new_merge_upstream(
229            inputs
230                .into_iter()
231                .enumerate()
232                .map(|(idx, input)| LocalInput::new(input, ActorId::new(idx as u32)).boxed_input())
233                .collect(),
234            &metrics,
235            &actor_ctx,
236            chunk_size,
237            schema,
238        );
239
240        Self::new(
241            actor_ctx,
242            514.into(),
243            1919.into(),
244            upstream,
245            local_barrier_manager,
246            metrics.into(),
247            barrier_rx,
248        )
249    }
250
251    pub(crate) fn new_merge_upstream(
252        upstreams: Vec<BoxedActorInput>,
253        metrics: &StreamingMetrics,
254        actor_context: &ActorContext,
255        chunk_size: usize,
256        schema: Schema,
257    ) -> MergeUpstream {
258        let merge_barrier_align_duration = Some(
259            metrics
260                .merge_barrier_align_duration
261                .with_guarded_label_values(&[
262                    &actor_context.id.to_string(),
263                    &actor_context.fragment_id.to_string(),
264                ]),
265        );
266
267        BufferChunks::new(
268            // Futures of all active upstreams.
269            SelectReceivers::new(upstreams, None, merge_barrier_align_duration),
270            chunk_size,
271            schema,
272        )
273    }
274}
275
impl<U> MergeExecutorInner<U>
where
    U: Upstream,
{
    /// Main loop: forwards messages from the upstream(s), handling barriers
    /// specially to apply configuration-change (scaling) mutations to the
    /// upstream set before emitting them.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    pub(crate) async fn execute_inner(mut self: Box<Self>) {
        let mut upstream = self.upstream;
        let actor_id = self.actor_context.id;

        // Input metrics are labeled by (actor, fragment, upstream fragment);
        // they are rebuilt below whenever the upstream fragment id changes.
        let mut metrics = self.metrics.new_actor_input_metrics(
            actor_id,
            self.fragment_id,
            self.upstream_fragment_id,
        );

        let mut barrier_buffer = DispatchBarrierBuffer::new(
            self.barrier_rx,
            actor_id,
            self.upstream_fragment_id,
            self.local_barrier_manager,
            self.metrics.clone(),
            self.fragment_id,
            self.actor_context.config.clone(),
        );

        loop {
            let msg = barrier_buffer
                .await_next_message(&mut upstream, &metrics)
                .await?;
            let msg = match msg {
                DispatcherMessage::Watermark(watermark) => Message::Watermark(watermark),
                DispatcherMessage::Chunk(chunk) => {
                    // Account for input records before forwarding the chunk.
                    metrics.actor_in_record_cnt.inc_by(chunk.cardinality() as _);
                    Message::Chunk(chunk)
                }
                DispatcherMessage::Barrier(barrier) => {
                    let (barrier, new_inputs) =
                        barrier_buffer.pop_barrier_with_inputs(barrier).await?;

                    if let Some(Mutation::Update(UpdateMutation { dispatchers, .. })) =
                        barrier.mutation.as_deref()
                        && upstream
                            .upstream_input_ids()
                            .any(|actor_id| dispatchers.contains_key(&actor_id))
                    {
                        // `Watermark` of upstream may become stale after downstream scaling.
                        upstream.flush_buffered_watermarks();
                    }

                    if let Some(update) =
                        barrier.as_update_merge(self.actor_context.id, self.upstream_fragment_id)
                    {
                        let new_upstream_fragment_id = update
                            .new_upstream_fragment_id
                            .unwrap_or(self.upstream_fragment_id);
                        // If the upstream fragment itself is replaced, drop ALL current
                        // inputs; otherwise drop only the explicitly removed actors.
                        let removed_upstream_actor_id: HashSet<_> =
                            if update.new_upstream_fragment_id.is_some() {
                                upstream.upstream_input_ids().collect()
                            } else {
                                update.removed_upstream_actor_id.iter().copied().collect()
                            };

                        // `Watermark` of upstream may become stale after upstream scaling.
                        upstream.flush_buffered_watermarks();

                        // Add and remove upstreams.
                        upstream.update(new_inputs.unwrap_or_default(), &removed_upstream_actor_id);

                        self.upstream_fragment_id = new_upstream_fragment_id;
                        // Re-label metrics with the (possibly new) upstream fragment id.
                        metrics = self.metrics.new_actor_input_metrics(
                            actor_id,
                            self.fragment_id,
                            self.upstream_fragment_id,
                        );
                    }

                    // Emit the barrier, then terminate if this actor is told to stop.
                    let is_stop = barrier.is_stop(actor_id);
                    let msg = Message::Barrier(barrier);
                    if is_stop {
                        yield msg;
                        break;
                    }

                    msg
                }
            };

            yield msg;
        }
    }
}
367
368impl<U> Execute for MergeExecutorInner<U>
369where
370    U: Upstream,
371{
372    fn execute(self: Box<Self>) -> BoxedMessageStream {
373        self.execute_inner().boxed()
374    }
375}
376
/// A wrapper that buffers the `StreamChunk`s from upstream until no more ready items are available.
/// Besides, any message other than `StreamChunk` will trigger the buffered `StreamChunk`s
/// to be emitted immediately along with the message itself.
pub struct BufferChunks<S: Stream> {
    /// The underlying message stream being wrapped.
    inner: S,
    /// Accumulates rows of incoming chunks until the configured chunk size is reached.
    chunk_builder: StreamChunkBuilder,

    /// The items to be emitted. Whenever there's something here, we should return a `Poll::Ready` immediately.
    pending_items: VecDeque<S::Item>,
}
387
388impl<S: Stream> BufferChunks<S> {
389    pub(super) fn new(inner: S, chunk_size: usize, schema: Schema) -> Self {
390        assert!(chunk_size > 0);
391        let chunk_builder = StreamChunkBuilder::new(chunk_size, schema.data_types());
392        Self {
393            inner,
394            chunk_builder,
395            pending_items: VecDeque::new(),
396        }
397    }
398}
399
// Expose the wrapped stream so callers can reach its inherent methods
// (e.g. the `SelectReceivers` operations) directly on `BufferChunks`.
impl<S: Stream> std::ops::Deref for BufferChunks<S> {
    type Target = S;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

impl<S: Stream> std::ops::DerefMut for BufferChunks<S> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
413
impl<S: Stream> Stream for BufferChunks<S>
where
    S: Stream<Item = DispatcherMessageStreamItem> + Unpin,
{
    type Item = S::Item;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Drain previously produced items before polling the inner stream again.
            if let Some(item) = self.pending_items.pop_front() {
                return Poll::Ready(Some(item));
            }

            match self.inner.poll_next_unpin(cx) {
                Poll::Pending => {
                    // No more ready input: flush whatever has been buffered so far,
                    // if anything; otherwise propagate `Pending` (the inner poll has
                    // already registered the waker).
                    return if let Some(chunk_out) = self.chunk_builder.take() {
                        Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
                    } else {
                        Poll::Pending
                    };
                }

                Poll::Ready(Some(result)) => {
                    if let Ok(MessageInner::Chunk(chunk)) = result {
                        // Append rows to the builder; completed chunks are queued in
                        // `pending_items` and returned on the next loop iteration.
                        for row in chunk.records() {
                            if let Some(chunk_out) = self.chunk_builder.append_record(row) {
                                self.pending_items
                                    .push_back(Ok(MessageInner::Chunk(chunk_out)));
                            }
                        }
                    } else {
                        // A non-chunk message (or an error) flushes the buffered chunk
                        // first, so chunks stay ordered before barriers/watermarks.
                        return if let Some(chunk_out) = self.chunk_builder.take() {
                            self.pending_items.push_back(result);
                            Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
                        } else {
                            Poll::Ready(Some(result))
                        };
                    }
                }

                Poll::Ready(None) => {
                    // See also the comments in `DynamicReceivers::poll_next`.
                    unreachable!("Merge should always have upstream inputs");
                }
            }
        }
    }
}
461
462#[cfg(test)]
463mod tests {
464    use std::sync::atomic::{AtomicBool, Ordering};
465    use std::time::Duration;
466
467    use assert_matches::assert_matches;
468    use futures::FutureExt;
469    use futures::future::try_join_all;
470    use risingwave_common::array::Op;
471    use risingwave_common::util::epoch::test_epoch;
472    use risingwave_pb::task_service::stream_exchange_service_server::{
473        StreamExchangeService, StreamExchangeServiceServer,
474    };
475    use risingwave_pb::task_service::{GetStreamRequest, GetStreamResponse, PbPermits};
476    use tokio::time::sleep;
477    use tokio_stream::wrappers::ReceiverStream;
478    use tonic::{Request, Response, Status, Streaming};
479
480    use super::*;
481    use crate::executor::exchange::input::{ActorInput, LocalInput, RemoteInput};
482    use crate::executor::exchange::permit::channel_for_test;
483    use crate::executor::{BarrierInner as Barrier, MessageInner as Message};
484    use crate::task::NewOutputRequest;
485    use crate::task::barrier_test_utils::LocalBarrierTestEnv;
486    use crate::task::test_utils::helper_make_local_actor;
487
488    fn build_test_chunk(size: u64) -> StreamChunk {
489        let ops = vec![Op::Insert; size as usize];
490        StreamChunk::new(ops, vec![])
491    }
492
    /// `BufferChunks` should compact consecutive chunks up to the chunk size and
    /// flush the buffer whenever a non-chunk message (watermark/barrier) arrives.
    #[tokio::test]
    async fn test_buffer_chunks() {
        let test_env = LocalBarrierTestEnv::for_test().await;

        let (tx, rx) = channel_for_test();
        let input = LocalInput::new(rx, 1.into()).boxed_input();
        let mut buffer = BufferChunks::new(input, 100, Schema::new(vec![]));

        // Send a chunk
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 10);
        });

        // Send 2 chunks and expect them to be merged.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });

        // Send a watermark.
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Send 2 chunks before a watermark. Expect the 2 chunks to be merged and the watermark to be emitted.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Send a barrier.
        let barrier = Barrier::new_test_barrier(test_epoch(1));
        test_env.inject_barrier(&barrier, [2.into()]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });

        // Send 2 chunks before a barrier. Expect the 2 chunks to be merged and the barrier to be emitted.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        let barrier = Barrier::new_test_barrier(test_epoch(2));
        test_env.inject_barrier(&barrier, [2.into()]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(2));
        });
    }
588
    /// Drives a `MergeExecutor` over `CHANNEL_NUMBER` local channels, checking that
    /// chunks are compacted to the chunk size, watermarks are aligned across all
    /// inputs, and barriers are emitted in epoch order, ending with a stop barrier.
    #[tokio::test]
    async fn test_merger() {
        const CHANNEL_NUMBER: usize = 10;
        let mut txs = Vec::with_capacity(CHANNEL_NUMBER);
        let mut rxs = Vec::with_capacity(CHANNEL_NUMBER);
        for _i in 0..CHANNEL_NUMBER {
            let (tx, rx) = channel_for_test();
            txs.push(tx);
            rxs.push(rx);
        }
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let actor_id = 233.into();
        let mut handles = Vec::with_capacity(CHANNEL_NUMBER);

        let epochs = (10..1000u64)
            .step_by(10)
            .map(|idx| (idx, test_epoch(idx)))
            .collect_vec();
        let mut prev_epoch = 0;
        let prev_epoch = &mut prev_epoch;
        let barriers: HashMap<_, _> = epochs
            .iter()
            .map(|(_, epoch)| {
                let barrier = Barrier::with_prev_epoch_for_test(*epoch, *prev_epoch);
                *prev_epoch = *epoch;
                barrier_test_env.inject_barrier(&barrier, [actor_id]);
                (*epoch, barrier)
            })
            .collect();
        let b2 = Barrier::with_prev_epoch_for_test(test_epoch(1000), *prev_epoch)
            .with_mutation(Mutation::Stop(StopMutation::default()));
        barrier_test_env.inject_barrier(&b2, [actor_id]);
        barrier_test_env.flush_all_events().await;

        // Each upstream task alternates chunks and watermarks per epoch, always
        // followed by that epoch's barrier, and finally the stop barrier.
        for (tx_id, tx) in txs.into_iter().enumerate() {
            let epochs = epochs.clone();
            let barriers = barriers.clone();
            let b2 = b2.clone();
            let handle = tokio::spawn(async move {
                for (idx, epoch) in epochs {
                    if idx % 20 == 0 {
                        tx.send(Message::Chunk(build_test_chunk(10)).into())
                            .await
                            .unwrap();
                    } else {
                        tx.send(
                            Message::Watermark(Watermark {
                                col_idx: (idx as usize / 20 + tx_id) % CHANNEL_NUMBER,
                                data_type: DataType::Int64,
                                val: ScalarImpl::Int64(idx as i64),
                            })
                            .into(),
                        )
                        .await
                        .unwrap();
                    }
                    tx.send(Message::Barrier(barriers[&epoch].clone().into_dispatcher()).into())
                        .await
                        .unwrap();
                    sleep(Duration::from_millis(1)).await;
                }
                tx.send(Message::Barrier(b2.clone().into_dispatcher()).into())
                    .await
                    .unwrap();
            });
            handles.push(handle);
        }

        let merger = MergeExecutor::for_test(
            actor_id,
            rxs,
            barrier_test_env.local_barrier_manager.clone(),
            Schema::new(vec![]),
            100,
            None,
        );
        let mut merger = merger.boxed().execute();
        for (idx, epoch) in epochs {
            if idx % 20 == 0 {
                // expect 1 or more chunks with 100 rows in total
                let mut count = 0usize;
                while count < 100 {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
                        count += chunk.ops().len();
                    });
                }
                assert_eq!(count, 100);
            } else if idx as usize / 20 >= CHANNEL_NUMBER - 1 {
                // expect n watermarks
                for _ in 0..CHANNEL_NUMBER {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
                        assert_eq!(watermark.val, ScalarImpl::Int64((idx - 20 * (CHANNEL_NUMBER as u64 - 1)) as i64));
                    });
                }
            }
            // expect a barrier
            assert_matches!(merger.next().await.unwrap().unwrap(), Message::Barrier(Barrier{epoch:barrier_epoch,mutation:_,..}) => {
                assert_eq!(barrier_epoch.curr, epoch);
            });
        }
        // The final barrier must carry the stop mutation.
        assert_matches!(
            merger.next().await.unwrap().unwrap(),
            Message::Barrier(Barrier {
                mutation,
                ..
            }) if mutation.as_deref().unwrap().is_stop()
        );

        for handle in handles {
            handle.await.unwrap();
        }
    }
701
    /// Scaling scenario: an update barrier removes upstream `old` and adds `new`.
    /// The merger must hold the barrier until the new upstream delivers it, then
    /// continue merging from the updated upstream set.
    #[tokio::test]
    async fn test_configuration_change() {
        let actor_id = 233.into();
        let (untouched, old, new) = (234.into(), 235.into(), 238.into()); // upstream actors
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let metrics = Arc::new(StreamingMetrics::unused());

        // untouched -> actor_id
        // old -> actor_id
        // new -> actor_id

        let (upstream_fragment_id, fragment_id) = (10.into(), 18.into());

        let actor_ctx = ActorContext::for_test(actor_id);

        let inputs: Vec<_> =
            try_join_all([untouched, old].into_iter().map(async |upstream_actor_id| {
                new_input(
                    &barrier_test_env.local_barrier_manager,
                    metrics.clone(),
                    actor_id,
                    fragment_id,
                    &helper_make_local_actor(upstream_actor_id),
                    upstream_fragment_id,
                    actor_ctx.config.clone(),
                )
                .await
            }))
            .await
            .unwrap();

        // The update barrier replaces `old` with `new` for this merge.
        let merge_updates = maplit::hashmap! {
            (actor_id, upstream_fragment_id) => MergeUpdate {
                actor_id,
                upstream_fragment_id,
                new_upstream_fragment_id: None,
                added_upstream_actors: vec![helper_make_local_actor(new)],
                removed_upstream_actor_id: vec![old],
            }
        };

        let b1 = Barrier::new_test_barrier(test_epoch(1)).with_mutation(Mutation::Update(
            UpdateMutation {
                merges: merge_updates,
                ..Default::default()
            },
        ));
        barrier_test_env.inject_barrier(&b1, [actor_id]);
        barrier_test_env.flush_all_events().await;

        let barrier_rx = barrier_test_env
            .local_barrier_manager
            .subscribe_barrier(actor_id);
        let upstream = MergeExecutor::new_merge_upstream(
            inputs,
            &metrics,
            &actor_ctx,
            100,
            Schema::empty().clone(),
        );

        let mut merge = MergeExecutor::new(
            actor_ctx,
            fragment_id,
            upstream_fragment_id,
            upstream,
            barrier_test_env.local_barrier_manager.clone(),
            metrics.clone(),
            barrier_rx,
        )
        .boxed()
        .execute();

        let mut txs = HashMap::new();
        // Send `$msg` to each actor's sender in `$actors`.
        macro_rules! send {
            ($actors:expr, $msg:expr) => {
                for actor in $actors {
                    txs.get(&actor).unwrap().send($msg).await.unwrap();
                }
            };
        }

        // Assert that the merger has no message ready right now.
        macro_rules! assert_recv_pending {
            () => {
                assert!(
                    merge
                        .next()
                        .now_or_never()
                        .flatten()
                        .transpose()
                        .unwrap()
                        .is_none()
                );
            };
        }
        macro_rules! recv {
            () => {
                merge.next().await.transpose().unwrap()
            };
        }

        // Take the pending output request of each upstream and keep its local sender.
        macro_rules! collect_upstream_tx {
            ($actors:expr) => {
                for upstream_id in $actors {
                    let mut output_requests = barrier_test_env
                        .take_pending_new_output_requests(upstream_id)
                        .await;
                    assert_eq!(output_requests.len(), 1);
                    let (downstream_actor_id, request) = output_requests.pop().unwrap();
                    assert_eq!(downstream_actor_id, actor_id);
                    let NewOutputRequest::Local(tx) = request else {
                        unreachable!()
                    };
                    txs.insert(upstream_id, tx);
                }
            };
        }

        assert_recv_pending!();
        barrier_test_env.flush_all_events().await;

        // 2. Take downstream receivers.
        collect_upstream_tx!([untouched, old]);

        // 3. Send a chunk.
        send!([untouched, old], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality()); // We should be able to receive the chunk twice.
        assert_recv_pending!();

        send!(
            [untouched, old],
            Message::Barrier(b1.clone().into_dispatcher()).into()
        );
        assert_recv_pending!(); // We should not receive the barrier, since merger is waiting for the new upstream new.

        collect_upstream_tx!([new]);

        send!([new], Message::Barrier(b1.clone().into_dispatcher()).into());
        recv!().unwrap().as_barrier().unwrap(); // We should now receive the barrier.

        // 5. Send a chunk.
        send!([untouched, new], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality()); // We should be able to receive the chunk twice.
        assert_recv_pending!();
    }
847
    /// Mock implementation of `StreamExchangeService` (see the `impl` below)
    /// that serves a canned chunk + barrier stream.
    struct FakeExchangeService {
        // Set to `true` once `get_stream` has been invoked, so the test can
        // verify the RPC endpoint was actually reached.
        rpc_called: Arc<AtomicBool>,
    }
851
852    fn exchange_client_test_barrier() -> crate::executor::Barrier {
853        Barrier::new_test_barrier(test_epoch(1))
854    }
855
856    #[async_trait::async_trait]
857    impl StreamExchangeService for FakeExchangeService {
858        type GetStreamStream = ReceiverStream<std::result::Result<GetStreamResponse, Status>>;
859
860        async fn get_stream(
861            &self,
862            _request: Request<Streaming<GetStreamRequest>>,
863        ) -> std::result::Result<Response<Self::GetStreamStream>, Status> {
864            let (tx, rx) = tokio::sync::mpsc::channel(10);
865            self.rpc_called.store(true, Ordering::SeqCst);
866            // send stream_chunk
867            let stream_chunk = StreamChunk::default().to_protobuf();
868            tx.send(Ok(GetStreamResponse {
869                message: Some(PbStreamMessageBatch {
870                    stream_message_batch: Some(
871                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::StreamChunk(
872                            stream_chunk,
873                        ),
874                    ),
875                }),
876                permits: Some(PbPermits::default()),
877            }))
878            .await
879            .unwrap();
880            // send barrier
881            let barrier = exchange_client_test_barrier();
882            tx.send(Ok(GetStreamResponse {
883                message: Some(PbStreamMessageBatch {
884                    stream_message_batch: Some(
885                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::BarrierBatch(
886                            BarrierBatch {
887                                barriers: vec![barrier.to_protobuf()],
888                            },
889                        ),
890                    ),
891                }),
892                permits: Some(PbPermits::default()),
893            }))
894            .await
895            .unwrap();
896            Ok(Response::new(ReceiverStream::new(rx)))
897        }
898    }
899
900    #[tokio::test]
901    async fn test_stream_exchange_client() {
902        let rpc_called = Arc::new(AtomicBool::new(false));
903        let server_run = Arc::new(AtomicBool::new(false));
904        let addr = "127.0.0.1:12348".parse().unwrap();
905
906        // Start a server.
907        let (shutdown_send, shutdown_recv) = tokio::sync::oneshot::channel();
908        let exchange_svc = StreamExchangeServiceServer::new(FakeExchangeService {
909            rpc_called: rpc_called.clone(),
910        });
911        let cp_server_run = server_run.clone();
912        let join_handle = tokio::spawn(async move {
913            cp_server_run.store(true, Ordering::SeqCst);
914            tonic::transport::Server::builder()
915                .add_service(exchange_svc)
916                .serve_with_shutdown(addr, async move {
917                    shutdown_recv.await.unwrap();
918                })
919                .await
920                .unwrap();
921        });
922
923        sleep(Duration::from_secs(1)).await;
924        assert!(server_run.load(Ordering::SeqCst));
925
926        let test_env = LocalBarrierTestEnv::for_test().await;
927
928        let remote_input = {
929            RemoteInput::new(
930                &test_env.local_barrier_manager,
931                addr.into(),
932                (0.into(), 0.into()),
933                (0.into(), 0.into()),
934                Arc::new(StreamingMetrics::unused()),
935                Arc::new(StreamingConfig::default()),
936            )
937            .await
938            .unwrap()
939        };
940
941        test_env.inject_barrier(&exchange_client_test_barrier(), [remote_input.id()]);
942
943        pin_mut!(remote_input);
944
945        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
946            let (ops, columns, visibility) = chunk.into_inner();
947            assert!(ops.is_empty());
948            assert!(columns.is_empty());
949            assert!(visibility.is_empty());
950        });
951        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
952            assert_eq!(barrier_epoch.curr, test_epoch(1));
953        });
954        assert!(rpc_called.load(Ordering::SeqCst));
955
956        shutdown_send.send(()).unwrap();
957        join_handle.await.unwrap();
958    }
959}