1use std::collections::VecDeque;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use risingwave_common::array::StreamChunkBuilder;
20use tokio::sync::mpsc;
21
22use super::exchange::input::BoxedActorInput;
23use super::*;
24use crate::executor::prelude::*;
25use crate::task::LocalBarrierManager;
26
/// Fan-in over a dynamic set of upstream actor inputs, parameterized by [`ActorId`].
pub type SelectReceivers = DynamicReceivers<ActorId, ()>;

/// Upstream of a multi-input merge: [`SelectReceivers`] wrapped in [`BufferChunks`], which
/// re-batches the chunks it yields.
pub type MergeUpstream = BufferChunks<SelectReceivers>;
/// Upstream of a single-input receiver: one boxed actor input.
pub type SingletonUpstream = BoxedActorInput;
31
/// The input side of a merge-like executor: either one upstream input or a merged set of them.
pub(crate) enum MergeExecutorUpstream {
    /// A single upstream input; `into_executor` turns it into a `ReceiverExecutor`.
    Singleton(SingletonUpstream),
    /// Several upstream inputs; `into_executor` turns them into a [`MergeExecutor`].
    Merge(MergeUpstream),
}
36
/// An upstream together with everything needed to build the executor that reads from it
/// (see `into_executor`). It also implements `Stream` by forwarding to the wrapped upstream.
pub(crate) struct MergeExecutorInput {
    /// The single or merged upstream input(s).
    upstream: MergeExecutorUpstream,
    /// Context of the owning actor; its `fragment_id` is used as the executor's fragment id.
    actor_context: ActorContextRef,
    /// Fragment the upstream actors belong to.
    upstream_fragment_id: UpstreamFragmentId,
    /// Forwarded to the built executor.
    local_barrier_manager: LocalBarrierManager,
    /// Metrics handle forwarded to the built executor.
    executor_stats: Arc<StreamingMetrics>,
    /// Executor info paired with the built executor in `into_executor`.
    pub(crate) info: ExecutorInfo,
}
45
46impl MergeExecutorInput {
47 pub(crate) fn new(
48 upstream: MergeExecutorUpstream,
49 actor_context: ActorContextRef,
50 upstream_fragment_id: UpstreamFragmentId,
51 local_barrier_manager: LocalBarrierManager,
52 executor_stats: Arc<StreamingMetrics>,
53 info: ExecutorInfo,
54 ) -> Self {
55 Self {
56 upstream,
57 actor_context,
58 upstream_fragment_id,
59 local_barrier_manager,
60 executor_stats,
61 info,
62 }
63 }
64
65 pub(crate) fn into_executor(self, barrier_rx: mpsc::UnboundedReceiver<Barrier>) -> Executor {
66 let fragment_id = self.actor_context.fragment_id;
67 let executor = match self.upstream {
68 MergeExecutorUpstream::Singleton(input) => ReceiverExecutor::new(
69 self.actor_context,
70 fragment_id,
71 self.upstream_fragment_id,
72 input,
73 self.local_barrier_manager,
74 self.executor_stats,
75 barrier_rx,
76 )
77 .boxed(),
78 MergeExecutorUpstream::Merge(inputs) => MergeExecutor::new(
79 self.actor_context,
80 fragment_id,
81 self.upstream_fragment_id,
82 inputs,
83 self.local_barrier_manager,
84 self.executor_stats,
85 barrier_rx,
86 )
87 .boxed(),
88 };
89 (self.info, executor).into()
90 }
91}
92
93impl Stream for MergeExecutorInput {
94 type Item = DispatcherMessageStreamItem;
95
96 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
97 match &mut self.get_mut().upstream {
98 MergeExecutorUpstream::Singleton(input) => input.poll_next_unpin(cx),
99 MergeExecutorUpstream::Merge(inputs) => inputs.poll_next_unpin(cx),
100 }
101 }
102}
103
mod upstream {
    use super::*;

    /// Input side of a [`MergeExecutorInner`]: a stream of dispatcher messages that can report,
    /// flush, and reconfigure its upstream actor inputs.
    pub trait Upstream:
        Stream<Item = StreamExecutorResult<DispatcherMessage>> + Unpin + Send + 'static
    {
        /// Ids of the upstream actors currently feeding this input.
        fn upstream_input_ids(&self) -> impl Iterator<Item = ActorId> + '_;
        /// Flushes any watermarks held back by this input (may be a no-op).
        fn flush_buffered_watermarks(&mut self);
        /// Adds the inputs in `to_add` and removes the upstreams whose ids are in `to_remove`.
        fn update(&mut self, to_add: Vec<BoxedActorInput>, to_remove: &HashSet<ActorId>);
    }

    /// Multi-input upstream: every operation delegates to the wrapped receivers (`inner`).
    impl Upstream for MergeUpstream {
        fn upstream_input_ids(&self) -> impl Iterator<Item = ActorId> + '_ {
            self.inner.upstream_input_ids()
        }

        fn flush_buffered_watermarks(&mut self) {
            self.inner.flush_buffered_watermarks();
        }

        fn update(&mut self, to_add: Vec<BoxedActorInput>, to_remove: &HashSet<ActorId>) {
            // Additions are applied before removals; empty sets skip the call entirely.
            if !to_add.is_empty() {
                self.inner.add_upstreams_from(to_add);
            }
            if !to_remove.is_empty() {
                self.inner.remove_upstreams(to_remove);
            }
        }
    }

    /// Single-input upstream: exactly one actor input, replaced wholesale on update.
    impl Upstream for SingletonUpstream {
        fn upstream_input_ids(&self) -> impl Iterator<Item = ActorId> + '_ {
            std::iter::once(self.id())
        }

        fn flush_buffered_watermarks(&mut self) {
            // Intentionally a no-op for a single input.
        }

        /// Replaces the current input with the single new one.
        ///
        /// # Panics
        /// Panics unless `to_remove` is exactly `{ self.id() }` and `to_add` holds exactly one
        /// input.
        fn update(&mut self, to_add: Vec<BoxedActorInput>, to_remove: &HashSet<ActorId>) {
            assert_eq!(
                to_remove,
                &HashSet::from_iter([self.id()]),
                "the removed upstream actor should be the same as the current input"
            );

            *self = to_add
                .into_iter()
                .exactly_one()
                .expect("receiver should have exactly one new upstream");
        }
    }
}
160use upstream::Upstream;
161
/// Executor that forwards messages from `upstream` and handles barriers coming through a
/// `DispatchBarrierBuffer`, including reconfiguration of its upstream inputs.
///
/// Generic over the upstream kind `U` (bounded by the `Upstream` trait where it is executed);
/// [`MergeExecutor`] is the multi-input instantiation.
pub struct MergeExecutorInner<U> {
    /// Context of the actor this executor runs in.
    actor_context: ActorContextRef,

    /// Input stream(s); moved out when the executor starts running.
    upstream: U,

    /// Fragment this executor belongs to.
    fragment_id: FragmentId,

    /// Fragment of the upstream actors; replaced when a barrier switches upstream fragments.
    upstream_fragment_id: FragmentId,

    /// Handed to the `DispatchBarrierBuffer` at start-up.
    local_barrier_manager: LocalBarrierManager,

    /// Streaming metrics, used to create per-input metrics.
    metrics: Arc<StreamingMetrics>,

    /// Barrier receiver, handed to the `DispatchBarrierBuffer` at start-up.
    barrier_rx: mpsc::UnboundedReceiver<Barrier>,
}
183
184impl<U> MergeExecutorInner<U> {
185 pub fn new(
186 ctx: ActorContextRef,
187 fragment_id: FragmentId,
188 upstream_fragment_id: FragmentId,
189 upstream: U,
190 local_barrier_manager: LocalBarrierManager,
191 metrics: Arc<StreamingMetrics>,
192 barrier_rx: mpsc::UnboundedReceiver<Barrier>,
193 ) -> Self {
194 Self {
195 actor_context: ctx,
196 upstream,
197 fragment_id,
198 upstream_fragment_id,
199 local_barrier_manager,
200 metrics,
201 barrier_rx,
202 }
203 }
204}
205
206pub type MergeExecutor = MergeExecutorInner<MergeUpstream>;
208
209impl MergeExecutor {
210 #[cfg(test)]
211 pub fn for_test(
212 actor_id: impl Into<ActorId>,
213 inputs: Vec<super::exchange::permit::Receiver>,
214 local_barrier_manager: crate::task::LocalBarrierManager,
215 schema: Schema,
216 chunk_size: usize,
217 barrier_rx: Option<mpsc::UnboundedReceiver<Barrier>>,
218 ) -> Self {
219 let actor_id = actor_id.into();
220 use super::exchange::input::LocalInput;
221 use crate::executor::exchange::input::ActorInput;
222
223 let barrier_rx =
224 barrier_rx.unwrap_or_else(|| local_barrier_manager.subscribe_barrier(actor_id));
225
226 let metrics = StreamingMetrics::unused();
227 let actor_ctx = ActorContext::for_test(actor_id);
228 let upstream = Self::new_merge_upstream(
229 inputs
230 .into_iter()
231 .enumerate()
232 .map(|(idx, input)| LocalInput::new(input, ActorId::new(idx as u32)).boxed_input())
233 .collect(),
234 &metrics,
235 &actor_ctx,
236 chunk_size,
237 schema,
238 );
239
240 Self::new(
241 actor_ctx,
242 514.into(),
243 1919.into(),
244 upstream,
245 local_barrier_manager,
246 metrics.into(),
247 barrier_rx,
248 )
249 }
250
251 pub(crate) fn new_merge_upstream(
252 upstreams: Vec<BoxedActorInput>,
253 metrics: &StreamingMetrics,
254 actor_context: &ActorContext,
255 chunk_size: usize,
256 schema: Schema,
257 ) -> MergeUpstream {
258 let merge_barrier_align_duration = Some(
259 metrics
260 .merge_barrier_align_duration
261 .with_guarded_label_values(&[
262 &actor_context.id.to_string(),
263 &actor_context.fragment_id.to_string(),
264 ]),
265 );
266
267 BufferChunks::new(
268 SelectReceivers::new(upstreams, None, merge_barrier_align_duration),
270 chunk_size,
271 schema,
272 )
273 }
274}
275
impl<U> MergeExecutorInner<U>
where
    U: Upstream,
{
    /// Message stream of the executor.
    ///
    /// Repeatedly awaits the next message from `upstream` through a `DispatchBarrierBuffer`:
    /// - watermarks are forwarded unchanged;
    /// - chunks are forwarded after adding their cardinality to `actor_in_record_cnt`;
    /// - barriers are resolved with `pop_barrier_with_inputs`, which may also return new
    ///   upstream inputs. An `UpdateMutation` can then flush buffered watermarks and change the
    ///   upstream set and upstream fragment id.
    ///
    /// The stream ends right after yielding a barrier for which `is_stop(actor_id)` is true.
    /// Errors from the barrier buffer are propagated with `?`.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    pub(crate) async fn execute_inner(mut self: Box<Self>) {
        // Take ownership of the upstream. Other fields of `self` are still read below, and
        // `upstream_fragment_id` is also written.
        let mut upstream = self.upstream;
        let actor_id = self.actor_context.id;

        // Per-input metrics, labelled with the current upstream fragment id. They are rebuilt
        // whenever that id changes.
        let mut metrics = self.metrics.new_actor_input_metrics(
            actor_id,
            self.fragment_id,
            self.upstream_fragment_id,
        );

        let mut barrier_buffer = DispatchBarrierBuffer::new(
            self.barrier_rx,
            actor_id,
            self.upstream_fragment_id,
            self.local_barrier_manager,
            self.metrics.clone(),
            self.fragment_id,
            self.actor_context.config.clone(),
        );

        loop {
            let msg = barrier_buffer
                .await_next_message(&mut upstream, &metrics)
                .await?;
            let msg = match msg {
                DispatcherMessage::Watermark(watermark) => Message::Watermark(watermark),
                DispatcherMessage::Chunk(chunk) => {
                    metrics.actor_in_record_cnt.inc_by(chunk.cardinality() as _);
                    Message::Chunk(chunk)
                }
                DispatcherMessage::Barrier(barrier) => {
                    let (barrier, new_inputs) =
                        barrier_buffer.pop_barrier_with_inputs(barrier).await?;

                    // The update touches the dispatchers of one of our current upstream actors:
                    // flush buffered watermarks (presumably they may be stale after that change).
                    if let Some(Mutation::Update(UpdateMutation { dispatchers, .. })) =
                        barrier.mutation.as_deref()
                        && upstream
                            .upstream_input_ids()
                            .any(|actor_id| dispatchers.contains_key(&actor_id))
                    {
                        upstream.flush_buffered_watermarks();
                    }

                    // A merge update addressed to (this actor, current upstream fragment):
                    // reconfigure the set of upstream inputs.
                    if let Some(update) =
                        barrier.as_update_merge(self.actor_context.id, self.upstream_fragment_id)
                    {
                        let new_upstream_fragment_id = update
                            .new_upstream_fragment_id
                            .unwrap_or(self.upstream_fragment_id);
                        // Switching to a new upstream fragment drops every current input;
                        // otherwise only the listed actors are removed.
                        let removed_upstream_actor_id: HashSet<_> =
                            if update.new_upstream_fragment_id.is_some() {
                                upstream.upstream_input_ids().collect()
                            } else {
                                update.removed_upstream_actor_id.iter().copied().collect()
                            };

                        // Flush before the input set changes.
                        upstream.flush_buffered_watermarks();

                        upstream.update(new_inputs.unwrap_or_default(), &removed_upstream_actor_id);

                        self.upstream_fragment_id = new_upstream_fragment_id;
                        metrics = self.metrics.new_actor_input_metrics(
                            actor_id,
                            self.fragment_id,
                            self.upstream_fragment_id,
                        );
                    }

                    let is_stop = barrier.is_stop(actor_id);
                    let msg = Message::Barrier(barrier);
                    if is_stop {
                        // Emit the stop barrier itself, then end the stream.
                        yield msg;
                        break;
                    }

                    msg
                }
            };

            yield msg;
        }
    }
}
367
368impl<U> Execute for MergeExecutorInner<U>
369where
370 U: Upstream,
371{
372 fn execute(self: Box<Self>) -> BoxedMessageStream {
373 self.execute_inner().boxed()
374 }
375}
376
/// Stream adapter that re-batches the chunks yielded by `inner` through a `StreamChunkBuilder`.
///
/// Rows held in the builder are emitted as a chunk whenever `inner` returns `Pending`, or
/// before any non-chunk item (which is delivered right after), so relative order is preserved.
pub struct BufferChunks<S: Stream> {
    /// The wrapped stream.
    inner: S,
    /// Accumulates rows from incoming chunks.
    chunk_builder: StreamChunkBuilder,

    /// Items ready to be returned before `inner` is polled again.
    pending_items: VecDeque<S::Item>,
}
387
388impl<S: Stream> BufferChunks<S> {
389 pub(super) fn new(inner: S, chunk_size: usize, schema: Schema) -> Self {
390 assert!(chunk_size > 0);
391 let chunk_builder = StreamChunkBuilder::new(chunk_size, schema.data_types());
392 Self {
393 inner,
394 chunk_builder,
395 pending_items: VecDeque::new(),
396 }
397 }
398}
399
400impl<S: Stream> std::ops::Deref for BufferChunks<S> {
401 type Target = S;
402
403 fn deref(&self) -> &Self::Target {
404 &self.inner
405 }
406}
407
408impl<S: Stream> std::ops::DerefMut for BufferChunks<S> {
409 fn deref_mut(&mut self) -> &mut Self::Target {
410 &mut self.inner
411 }
412}
413
414impl<S: Stream> Stream for BufferChunks<S>
415where
416 S: Stream<Item = DispatcherMessageStreamItem> + Unpin,
417{
418 type Item = S::Item;
419
420 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
421 loop {
422 if let Some(item) = self.pending_items.pop_front() {
423 return Poll::Ready(Some(item));
424 }
425
426 match self.inner.poll_next_unpin(cx) {
427 Poll::Pending => {
428 return if let Some(chunk_out) = self.chunk_builder.take() {
429 Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
430 } else {
431 Poll::Pending
432 };
433 }
434
435 Poll::Ready(Some(result)) => {
436 if let Ok(MessageInner::Chunk(chunk)) = result {
437 for row in chunk.records() {
438 if let Some(chunk_out) = self.chunk_builder.append_record(row) {
439 self.pending_items
440 .push_back(Ok(MessageInner::Chunk(chunk_out)));
441 }
442 }
443 } else {
444 return if let Some(chunk_out) = self.chunk_builder.take() {
445 self.pending_items.push_back(result);
446 Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
447 } else {
448 Poll::Ready(Some(result))
449 };
450 }
451 }
452
453 Poll::Ready(None) => {
454 unreachable!("Merge should always have upstream inputs");
456 }
457 }
458 }
459 }
460}
461
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    use assert_matches::assert_matches;
    use futures::FutureExt;
    use futures::future::try_join_all;
    use risingwave_common::array::Op;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_pb::task_service::stream_exchange_service_server::{
        StreamExchangeService, StreamExchangeServiceServer,
    };
    use risingwave_pb::task_service::{GetStreamRequest, GetStreamResponse, PbPermits};
    use tokio::time::sleep;
    use tokio_stream::wrappers::ReceiverStream;
    use tonic::{Request, Response, Status, Streaming};

    use super::*;
    use crate::executor::exchange::input::{ActorInput, LocalInput, RemoteInput};
    use crate::executor::exchange::permit::channel_for_test;
    use crate::executor::{BarrierInner as Barrier, MessageInner as Message};
    use crate::task::NewOutputRequest;
    use crate::task::barrier_test_utils::LocalBarrierTestEnv;
    use crate::task::test_utils::helper_make_local_actor;

    /// Builds an insert-only chunk with `size` rows and no columns.
    fn build_test_chunk(size: u64) -> StreamChunk {
        let ops = vec![Op::Insert; size as usize];
        StreamChunk::new(ops, vec![])
    }

    #[tokio::test]
    async fn test_buffer_chunks() {
        let test_env = LocalBarrierTestEnv::for_test().await;

        let (tx, rx) = channel_for_test();
        let input = LocalInput::new(rx, 1.into()).boxed_input();
        let mut buffer = BufferChunks::new(input, 100, Schema::new(vec![]));

        // A lone chunk is emitted once the input has nothing more to offer.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 10);
        });

        // Chunks already queued in the channel are coalesced into one output chunk.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });

        // A watermark with nothing buffered passes straight through.
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Buffered rows are flushed before a watermark.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // A barrier with nothing buffered passes straight through.
        let barrier = Barrier::new_test_barrier(test_epoch(1));
        test_env.inject_barrier(&barrier, [2.into()]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });

        // Buffered rows are flushed before a barrier.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        let barrier = Barrier::new_test_barrier(test_epoch(2));
        test_env.inject_barrier(&barrier, [2.into()]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(2));
        });
    }

    #[tokio::test]
    async fn test_merger() {
        const CHANNEL_NUMBER: usize = 10;
        let mut txs = Vec::with_capacity(CHANNEL_NUMBER);
        let mut rxs = Vec::with_capacity(CHANNEL_NUMBER);
        for _i in 0..CHANNEL_NUMBER {
            let (tx, rx) = channel_for_test();
            txs.push(tx);
            rxs.push(rx);
        }
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let actor_id = 233.into();
        let mut handles = Vec::with_capacity(CHANNEL_NUMBER);

        let epochs = (10..1000u64)
            .step_by(10)
            .map(|idx| (idx, test_epoch(idx)))
            .collect_vec();
        let mut prev_epoch = 0;
        let prev_epoch = &mut prev_epoch;
        let barriers: HashMap<_, _> = epochs
            .iter()
            .map(|(_, epoch)| {
                let barrier = Barrier::with_prev_epoch_for_test(*epoch, *prev_epoch);
                *prev_epoch = *epoch;
                barrier_test_env.inject_barrier(&barrier, [actor_id]);
                (*epoch, barrier)
            })
            .collect();
        let b2 = Barrier::with_prev_epoch_for_test(test_epoch(1000), *prev_epoch)
            .with_mutation(Mutation::Stop(StopMutation::default()));
        barrier_test_env.inject_barrier(&b2, [actor_id]);
        barrier_test_env.flush_all_events().await;

        // Each sender emits, for every epoch, either a 10-row chunk (when `idx % 20 == 0`) or a
        // watermark on a rotating column, then that epoch's barrier. A final stop barrier
        // (`b2`) closes each sender.
        for (tx_id, tx) in txs.into_iter().enumerate() {
            let epochs = epochs.clone();
            let barriers = barriers.clone();
            let b2 = b2.clone();
            let handle = tokio::spawn(async move {
                for (idx, epoch) in epochs {
                    if idx % 20 == 0 {
                        tx.send(Message::Chunk(build_test_chunk(10)).into())
                            .await
                            .unwrap();
                    } else {
                        tx.send(
                            Message::Watermark(Watermark {
                                col_idx: (idx as usize / 20 + tx_id) % CHANNEL_NUMBER,
                                data_type: DataType::Int64,
                                val: ScalarImpl::Int64(idx as i64),
                            })
                            .into(),
                        )
                        .await
                        .unwrap();
                    }
                    tx.send(Message::Barrier(barriers[&epoch].clone().into_dispatcher()).into())
                        .await
                        .unwrap();
                    sleep(Duration::from_millis(1)).await;
                }
                tx.send(Message::Barrier(b2.clone().into_dispatcher()).into())
                    .await
                    .unwrap();
            });
            handles.push(handle);
        }

        let merger = MergeExecutor::for_test(
            actor_id,
            rxs,
            barrier_test_env.local_barrier_manager.clone(),
            Schema::new(vec![]),
            100,
            None,
        );
        let mut merger = merger.boxed().execute();
        for (idx, epoch) in epochs {
            if idx % 20 == 0 {
                // The 10 x 10 rows of this epoch may be re-batched into any number of chunks,
                // but all of them arrive before the barrier.
                let mut count = 0usize;
                while count < 100 {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
                        count += chunk.ops().len();
                    });
                }
                assert_eq!(count, 100);
            } else if idx as usize / 20 >= CHANNEL_NUMBER - 1 {
                // From here on, one watermark per column is expected in this epoch.
                for _ in 0..CHANNEL_NUMBER {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
                        assert_eq!(watermark.val, ScalarImpl::Int64((idx - 20 * (CHANNEL_NUMBER as u64 - 1)) as i64));
                    });
                }
            }
            // Every epoch ends with its barrier.
            assert_matches!(merger.next().await.unwrap().unwrap(), Message::Barrier(Barrier{epoch:barrier_epoch,mutation:_,..}) => {
                assert_eq!(barrier_epoch.curr, epoch);
            });
        }
        assert_matches!(
            merger.next().await.unwrap().unwrap(),
            Message::Barrier(Barrier {
                mutation,
                ..
            }) if mutation.as_deref().unwrap().is_stop()
        );

        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_configuration_change() {
        let actor_id = 233.into();
        // Upstream actors: `untouched` stays, while barrier `b1` removes `old` and adds `new`.
        let (untouched, old, new) = (234.into(), 235.into(), 238.into());
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let metrics = Arc::new(StreamingMetrics::unused());

        let (upstream_fragment_id, fragment_id) = (10.into(), 18.into());

        let actor_ctx = ActorContext::for_test(actor_id);

        let inputs: Vec<_> =
            try_join_all([untouched, old].into_iter().map(async |upstream_actor_id| {
                new_input(
                    &barrier_test_env.local_barrier_manager,
                    metrics.clone(),
                    actor_id,
                    fragment_id,
                    &helper_make_local_actor(upstream_actor_id),
                    upstream_fragment_id,
                    actor_ctx.config.clone(),
                )
                .await
            }))
            .await
            .unwrap();

        let merge_updates = maplit::hashmap! {
            (actor_id, upstream_fragment_id) => MergeUpdate {
                actor_id,
                upstream_fragment_id,
                new_upstream_fragment_id: None,
                added_upstream_actors: vec![helper_make_local_actor(new)],
                removed_upstream_actor_id: vec![old],
            }
        };

        let b1 = Barrier::new_test_barrier(test_epoch(1)).with_mutation(Mutation::Update(
            UpdateMutation {
                merges: merge_updates,
                ..Default::default()
            },
        ));
        barrier_test_env.inject_barrier(&b1, [actor_id]);
        barrier_test_env.flush_all_events().await;

        let barrier_rx = barrier_test_env
            .local_barrier_manager
            .subscribe_barrier(actor_id);
        let upstream = MergeExecutor::new_merge_upstream(
            inputs,
            &metrics,
            &actor_ctx,
            100,
            Schema::empty().clone(),
        );

        let mut merge = MergeExecutor::new(
            actor_ctx,
            fragment_id,
            upstream_fragment_id,
            upstream,
            barrier_test_env.local_barrier_manager.clone(),
            metrics.clone(),
            barrier_rx,
        )
        .boxed()
        .execute();

        let mut txs = HashMap::new();
        macro_rules! send {
            ($actors:expr, $msg:expr) => {
                for actor in $actors {
                    txs.get(&actor).unwrap().send($msg).await.unwrap();
                }
            };
        }

        macro_rules! assert_recv_pending {
            () => {
                assert!(
                    merge
                        .next()
                        .now_or_never()
                        .flatten()
                        .transpose()
                        .unwrap()
                        .is_none()
                );
            };
        }
        macro_rules! recv {
            () => {
                merge.next().await.transpose().unwrap()
            };
        }

        macro_rules! collect_upstream_tx {
            ($actors:expr) => {
                for upstream_id in $actors {
                    let mut output_requests = barrier_test_env
                        .take_pending_new_output_requests(upstream_id)
                        .await;
                    assert_eq!(output_requests.len(), 1);
                    let (downstream_actor_id, request) = output_requests.pop().unwrap();
                    assert_eq!(downstream_actor_id, actor_id);
                    let NewOutputRequest::Local(tx) = request else {
                        unreachable!()
                    };
                    txs.insert(upstream_id, tx);
                }
            };
        }

        assert_recv_pending!();
        barrier_test_env.flush_all_events().await;

        // Take the output channels requested for the initial upstreams.
        collect_upstream_tx!([untouched, old]);

        // One row from each upstream comes out merged into a single chunk.
        send!([untouched, old], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality());
        assert_recv_pending!();

        send!(
            [untouched, old],
            Message::Barrier(b1.clone().into_dispatcher()).into()
        );
        // `b1` is not emitted yet: the new upstream has not delivered it.
        assert_recv_pending!();
        collect_upstream_tx!([new]);

        send!([new], Message::Barrier(b1.clone().into_dispatcher()).into());
        // Once `new` delivers `b1` as well, the barrier comes out.
        recv!().unwrap().as_barrier().unwrap();

        // After reconfiguration, rows from `untouched` and `new` are merged.
        send!([untouched, new], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality());
        assert_recv_pending!();
    }

    /// Fake exchange server that records whether `get_stream` was called.
    struct FakeExchangeService {
        rpc_called: Arc<AtomicBool>,
    }

    fn exchange_client_test_barrier() -> crate::executor::Barrier {
        Barrier::new_test_barrier(test_epoch(1))
    }

    #[async_trait::async_trait]
    impl StreamExchangeService for FakeExchangeService {
        type GetStreamStream = ReceiverStream<std::result::Result<GetStreamResponse, Status>>;

        /// Responds with one empty chunk followed by one barrier batch.
        async fn get_stream(
            &self,
            _request: Request<Streaming<GetStreamRequest>>,
        ) -> std::result::Result<Response<Self::GetStreamStream>, Status> {
            let (tx, rx) = tokio::sync::mpsc::channel(10);
            self.rpc_called.store(true, Ordering::SeqCst);
            // Send an empty chunk.
            let stream_chunk = StreamChunk::default().to_protobuf();
            tx.send(Ok(GetStreamResponse {
                message: Some(PbStreamMessageBatch {
                    stream_message_batch: Some(
                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::StreamChunk(
                            stream_chunk,
                        ),
                    ),
                }),
                permits: Some(PbPermits::default()),
            }))
            .await
            .unwrap();
            // Send a barrier.
            let barrier = exchange_client_test_barrier();
            tx.send(Ok(GetStreamResponse {
                message: Some(PbStreamMessageBatch {
                    stream_message_batch: Some(
                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::BarrierBatch(
                            BarrierBatch {
                                barriers: vec![barrier.to_protobuf()],
                            },
                        ),
                    ),
                }),
                permits: Some(PbPermits::default()),
            }))
            .await
            .unwrap();
            Ok(Response::new(ReceiverStream::new(rx)))
        }
    }

    #[tokio::test]
    async fn test_stream_exchange_client() {
        let rpc_called = Arc::new(AtomicBool::new(false));
        let server_run = Arc::new(AtomicBool::new(false));
        // NOTE(review): fixed port; may collide with other processes on the test host.
        let addr = "127.0.0.1:12348".parse().unwrap();

        // Start the fake exchange server.
        let (shutdown_send, shutdown_recv) = tokio::sync::oneshot::channel();
        let exchange_svc = StreamExchangeServiceServer::new(FakeExchangeService {
            rpc_called: rpc_called.clone(),
        });
        let cp_server_run = server_run.clone();
        let join_handle = tokio::spawn(async move {
            cp_server_run.store(true, Ordering::SeqCst);
            tonic::transport::Server::builder()
                .add_service(exchange_svc)
                .serve_with_shutdown(addr, async move {
                    shutdown_recv.await.unwrap();
                })
                .await
                .unwrap();
        });

        sleep(Duration::from_secs(1)).await;
        assert!(server_run.load(Ordering::SeqCst));

        let test_env = LocalBarrierTestEnv::for_test().await;

        let remote_input = {
            RemoteInput::new(
                &test_env.local_barrier_manager,
                addr.into(),
                (0.into(), 0.into()),
                (0.into(), 0.into()),
                Arc::new(StreamingMetrics::unused()),
                Arc::new(StreamingConfig::default()),
            )
            .await
            .unwrap()
        };

        test_env.inject_barrier(&exchange_client_test_barrier(), [remote_input.id()]);

        pin_mut!(remote_input);

        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            let (ops, columns, visibility) = chunk.into_inner();
            assert!(ops.is_empty());
            assert!(columns.is_empty());
            assert!(visibility.is_empty());
        });
        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });
        assert!(rpc_called.load(Ordering::SeqCst));

        shutdown_send.send(()).unwrap();
        join_handle.await.unwrap();
    }
}