1use std::collections::VecDeque;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use risingwave_common::array::StreamChunkBuilder;
20use risingwave_common::config::MetricLevel;
21use tokio::sync::mpsc;
22
23use super::exchange::input::BoxedActorInput;
24use super::*;
25use crate::executor::prelude::*;
26use crate::task::LocalBarrierManager;
27
/// Dynamic set of upstream actor inputs keyed by [`ActorId`], selected from concurrently.
pub type SelectReceivers = DynamicReceivers<ActorId, ()>;

/// Upstream of a [`MergeExecutor`]: the selected receivers wrapped in [`BufferChunks`] so
/// consecutive chunks from different inputs are coalesced before being forwarded.
pub type MergeUpstream = BufferChunks<SelectReceivers>;
/// Upstream of a single-input receiver executor: exactly one boxed actor input.
pub type SingletonUpstream = BoxedActorInput;
32
/// The upstream input of a [`MergeExecutorInput`].
pub(crate) enum MergeExecutorUpstream {
    /// A single upstream input; `into_executor` builds a `ReceiverExecutor` from it.
    Singleton(SingletonUpstream),
    /// Multiple merged upstream inputs; `into_executor` builds a `MergeExecutor` from it.
    Merge(MergeUpstream),
}
37
/// A merge input that has not yet been turned into an executor.
///
/// It holds the upstream plus everything needed to build an [`Executor`] via
/// `into_executor`. It can also be polled directly as a stream of dispatcher messages.
pub(crate) struct MergeExecutorInput {
    /// Either a single upstream input or a merged set of inputs.
    upstream: MergeExecutorUpstream,
    /// Context of the actor that owns this input.
    actor_context: ActorContextRef,
    /// Fragment id of the upstream actors feeding this input.
    upstream_fragment_id: UpstreamFragmentId,
    local_barrier_manager: LocalBarrierManager,
    /// Metrics handle passed on to the built executor.
    executor_stats: Arc<StreamingMetrics>,
    /// Executor info attached to the built [`Executor`].
    pub(crate) info: ExecutorInfo,
}
46
47impl MergeExecutorInput {
48 pub(crate) fn new(
49 upstream: MergeExecutorUpstream,
50 actor_context: ActorContextRef,
51 upstream_fragment_id: UpstreamFragmentId,
52 local_barrier_manager: LocalBarrierManager,
53 executor_stats: Arc<StreamingMetrics>,
54 info: ExecutorInfo,
55 ) -> Self {
56 Self {
57 upstream,
58 actor_context,
59 upstream_fragment_id,
60 local_barrier_manager,
61 executor_stats,
62 info,
63 }
64 }
65
66 pub(crate) fn into_executor(self, barrier_rx: mpsc::UnboundedReceiver<Barrier>) -> Executor {
67 let fragment_id = self.actor_context.fragment_id;
68 let executor = match self.upstream {
69 MergeExecutorUpstream::Singleton(input) => ReceiverExecutor::new(
70 self.actor_context,
71 fragment_id,
72 self.upstream_fragment_id,
73 input,
74 self.local_barrier_manager,
75 self.executor_stats,
76 barrier_rx,
77 )
78 .boxed(),
79 MergeExecutorUpstream::Merge(inputs) => MergeExecutor::new(
80 self.actor_context,
81 fragment_id,
82 self.upstream_fragment_id,
83 inputs,
84 self.local_barrier_manager,
85 self.executor_stats,
86 barrier_rx,
87 )
88 .boxed(),
89 };
90 (self.info, executor).into()
91 }
92}
93
94impl Stream for MergeExecutorInput {
95 type Item = DispatcherMessageStreamItem;
96
97 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
98 match &mut self.get_mut().upstream {
99 MergeExecutorUpstream::Singleton(input) => input.poll_next_unpin(cx),
100 MergeExecutorUpstream::Merge(inputs) => inputs.poll_next_unpin(cx),
101 }
102 }
103}
104
105mod upstream {
106 use super::*;
107
108 pub trait Upstream:
111 Stream<Item = StreamExecutorResult<DispatcherMessage>> + Unpin + Send + 'static
112 {
113 fn upstream_input_ids(&self) -> impl Iterator<Item = ActorId> + '_;
114 fn flush_buffered_watermarks(&mut self);
115 fn update(&mut self, to_add: Vec<BoxedActorInput>, to_remove: &HashSet<ActorId>);
116 }
117
118 impl Upstream for MergeUpstream {
119 fn upstream_input_ids(&self) -> impl Iterator<Item = ActorId> + '_ {
120 self.inner.upstream_input_ids()
121 }
122
123 fn flush_buffered_watermarks(&mut self) {
124 self.inner.flush_buffered_watermarks();
125 }
126
127 fn update(&mut self, to_add: Vec<BoxedActorInput>, to_remove: &HashSet<ActorId>) {
128 if !to_add.is_empty() {
129 self.inner.add_upstreams_from(to_add);
130 }
131 if !to_remove.is_empty() {
132 self.inner.remove_upstreams(to_remove);
133 }
134 }
135 }
136
137 impl Upstream for SingletonUpstream {
138 fn upstream_input_ids(&self) -> impl Iterator<Item = ActorId> + '_ {
139 std::iter::once(self.id())
140 }
141
142 fn flush_buffered_watermarks(&mut self) {
143 }
145
146 fn update(&mut self, to_add: Vec<BoxedActorInput>, to_remove: &HashSet<ActorId>) {
147 assert_eq!(
148 to_remove,
149 &HashSet::from_iter([self.id()]),
150 "the removed upstream actor should be the same as the current input"
151 );
152
153 *self = to_add
155 .into_iter()
156 .exactly_one()
157 .expect("receiver should have exactly one new upstream");
158 }
159 }
160}
161use upstream::Upstream;
162
/// Executor that reads messages from its upstream `U` and forwards chunks and watermarks.
///
/// Barriers are routed through a `DispatchBarrierBuffer`, and upstream changes carried by
/// barrier mutations are applied to `U`.
pub struct MergeExecutorInner<U> {
    /// Context of the actor running this executor.
    actor_context: ActorContextRef,

    /// Upstream input(s) this executor reads from.
    upstream: U,

    /// Fragment id of this executor (the owning actor's fragment when built via
    /// `MergeExecutorInput::into_executor`).
    fragment_id: FragmentId,

    /// Fragment id of the upstream actors; reassigned when a barrier moves the merge to a
    /// new upstream fragment.
    upstream_fragment_id: FragmentId,

    local_barrier_manager: LocalBarrierManager,

    /// Streaming metrics used to build per-actor input metrics.
    metrics: Arc<StreamingMetrics>,

    /// Barrier receiver for this actor, handed to the barrier buffer on execution.
    barrier_rx: mpsc::UnboundedReceiver<Barrier>,
}
184
185impl<U> MergeExecutorInner<U> {
186 pub fn new(
187 ctx: ActorContextRef,
188 fragment_id: FragmentId,
189 upstream_fragment_id: FragmentId,
190 upstream: U,
191 local_barrier_manager: LocalBarrierManager,
192 metrics: Arc<StreamingMetrics>,
193 barrier_rx: mpsc::UnboundedReceiver<Barrier>,
194 ) -> Self {
195 Self {
196 actor_context: ctx,
197 upstream,
198 fragment_id,
199 upstream_fragment_id,
200 local_barrier_manager,
201 metrics,
202 barrier_rx,
203 }
204 }
205}
206
207pub type MergeExecutor = MergeExecutorInner<MergeUpstream>;
209
210impl MergeExecutor {
211 #[cfg(test)]
212 pub fn for_test(
213 actor_id: impl Into<ActorId>,
214 inputs: Vec<super::exchange::permit::Receiver>,
215 local_barrier_manager: crate::task::LocalBarrierManager,
216 schema: Schema,
217 chunk_size: usize,
218 barrier_rx: Option<mpsc::UnboundedReceiver<Barrier>>,
219 ) -> Self {
220 let actor_id = actor_id.into();
221 use super::exchange::input::LocalInput;
222 use crate::executor::exchange::input::ActorInput;
223
224 let barrier_rx =
225 barrier_rx.unwrap_or_else(|| local_barrier_manager.subscribe_barrier(actor_id));
226
227 let metrics = StreamingMetrics::unused();
228 let actor_ctx = ActorContext::for_test(actor_id);
229 let upstream = Self::new_merge_upstream(
230 inputs
231 .into_iter()
232 .enumerate()
233 .map(|(idx, input)| LocalInput::new(input, ActorId::new(idx as u32)).boxed_input())
234 .collect(),
235 &metrics,
236 &actor_ctx,
237 chunk_size,
238 schema,
239 );
240
241 Self::new(
242 actor_ctx,
243 514.into(),
244 1919.into(),
245 upstream,
246 local_barrier_manager,
247 metrics.into(),
248 barrier_rx,
249 )
250 }
251
252 pub(crate) fn new_merge_upstream(
253 upstreams: Vec<BoxedActorInput>,
254 metrics: &StreamingMetrics,
255 actor_context: &ActorContext,
256 chunk_size: usize,
257 schema: Schema,
258 ) -> MergeUpstream {
259 let merge_barrier_align_duration = if metrics.level >= MetricLevel::Debug {
260 Some(
261 metrics
262 .merge_barrier_align_duration
263 .with_guarded_label_values(&[
264 &actor_context.id.to_string(),
265 &actor_context.fragment_id.to_string(),
266 ]),
267 )
268 } else {
269 None
270 };
271
272 BufferChunks::new(
273 SelectReceivers::new(upstreams, None, merge_barrier_align_duration),
275 chunk_size,
276 schema,
277 )
278 }
279}
280
impl<U> MergeExecutorInner<U>
where
    U: Upstream,
{
    /// Main loop of the merge executor.
    ///
    /// Pulls messages from the upstream through the barrier buffer. Watermarks and chunks
    /// are forwarded, and chunk row counts are recorded in the input metrics. On each
    /// barrier, any upstream change it carries for this actor is applied. The loop ends
    /// after yielding a barrier that stops this actor.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    pub(crate) async fn execute_inner(mut self: Box<Self>) {
        // Move the upstream out of the boxed `self` so the loop can borrow it mutably while
        // other fields of `self` are still read (and `upstream_fragment_id` reassigned).
        let mut upstream = self.upstream;
        let actor_id = self.actor_context.id;

        // Input metrics are labelled with the upstream fragment id, so they are rebuilt
        // below whenever that id changes.
        let mut metrics = self.metrics.new_actor_input_metrics(
            actor_id,
            self.fragment_id,
            self.upstream_fragment_id,
        );

        let mut barrier_buffer = DispatchBarrierBuffer::new(
            self.barrier_rx,
            actor_id,
            self.upstream_fragment_id,
            self.local_barrier_manager,
            self.metrics.clone(),
            self.fragment_id,
            self.actor_context.config.clone(),
        );

        loop {
            let msg = barrier_buffer
                .await_next_message(&mut upstream, &metrics)
                .await?;
            let msg = match msg {
                DispatcherMessage::Watermark(watermark) => Message::Watermark(watermark),
                DispatcherMessage::Chunk(chunk) => {
                    metrics.actor_in_record_cnt.inc_by(chunk.cardinality() as _);
                    Message::Chunk(chunk)
                }
                DispatcherMessage::Barrier(barrier) => {
                    // Also yields any new inputs the barrier buffer prepared for this barrier.
                    let (barrier, new_inputs) =
                        barrier_buffer.pop_barrier_with_inputs(barrier).await?;

                    // If this update reconfigures the dispatchers of any of our upstream
                    // actors, drop/flush buffered watermarks. Presumably they may be stale
                    // after the upstream's downstream scaling — NOTE(review): confirm.
                    if let Some(Mutation::Update(UpdateMutation { dispatchers, .. })) =
                        barrier.mutation.as_deref()
                        && upstream
                            .upstream_input_ids()
                            .any(|actor_id| dispatchers.contains_key(&actor_id))
                    {
                        upstream.flush_buffered_watermarks();
                    }

                    if let Some(update) =
                        barrier.as_update_merge(self.actor_context.id, self.upstream_fragment_id)
                    {
                        let new_upstream_fragment_id = update
                            .new_upstream_fragment_id
                            .unwrap_or(self.upstream_fragment_id);
                        // Switching to a new upstream fragment removes every current input;
                        // otherwise only the explicitly listed actors are removed.
                        let removed_upstream_actor_id: HashSet<_> =
                            if update.new_upstream_fragment_id.is_some() {
                                upstream.upstream_input_ids().collect()
                            } else {
                                update.removed_upstream_actor_id.iter().copied().collect()
                            };

                        // Flush before changing the input set.
                        upstream.flush_buffered_watermarks();

                        upstream.update(new_inputs.unwrap_or_default(), &removed_upstream_actor_id);

                        self.upstream_fragment_id = new_upstream_fragment_id;
                        metrics = self.metrics.new_actor_input_metrics(
                            actor_id,
                            self.fragment_id,
                            self.upstream_fragment_id,
                        );
                    }

                    // A stop barrier for this actor is the last message of the stream.
                    let is_stop = barrier.is_stop(actor_id);
                    let msg = Message::Barrier(barrier);
                    if is_stop {
                        yield msg;
                        break;
                    }

                    msg
                }
            };

            yield msg;
        }
    }
}
372
373impl<U> Execute for MergeExecutorInner<U>
374where
375 U: Upstream,
376{
377 fn execute(self: Box<Self>) -> BoxedMessageStream {
378 self.execute_inner().boxed()
379 }
380}
381
/// Stream adapter that coalesces consecutive chunks from `inner` into larger chunks.
///
/// Output chunks are bounded by the `chunk_size` given to the builder. Buffered rows are
/// flushed when the builder fills, when `inner` has nothing ready, or before any non-chunk
/// item is passed through.
pub struct BufferChunks<S: Stream> {
    /// The wrapped stream.
    inner: S,
    /// Accumulates rows from incoming chunks.
    chunk_builder: StreamChunkBuilder,

    /// Items that are ready to be returned before `inner` is polled again.
    pending_items: VecDeque<S::Item>,
}
392
393impl<S: Stream> BufferChunks<S> {
394 pub(super) fn new(inner: S, chunk_size: usize, schema: Schema) -> Self {
395 assert!(chunk_size > 0);
396 let chunk_builder = StreamChunkBuilder::new(chunk_size, schema.data_types());
397 Self {
398 inner,
399 chunk_builder,
400 pending_items: VecDeque::new(),
401 }
402 }
403}
404
405impl<S: Stream> std::ops::Deref for BufferChunks<S> {
406 type Target = S;
407
408 fn deref(&self) -> &Self::Target {
409 &self.inner
410 }
411}
412
413impl<S: Stream> std::ops::DerefMut for BufferChunks<S> {
414 fn deref_mut(&mut self) -> &mut Self::Target {
415 &mut self.inner
416 }
417}
418
419impl<S: Stream> Stream for BufferChunks<S>
420where
421 S: Stream<Item = DispatcherMessageStreamItem> + Unpin,
422{
423 type Item = S::Item;
424
425 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
426 loop {
427 if let Some(item) = self.pending_items.pop_front() {
428 return Poll::Ready(Some(item));
429 }
430
431 match self.inner.poll_next_unpin(cx) {
432 Poll::Pending => {
433 return if let Some(chunk_out) = self.chunk_builder.take() {
434 Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
435 } else {
436 Poll::Pending
437 };
438 }
439
440 Poll::Ready(Some(result)) => {
441 if let Ok(MessageInner::Chunk(chunk)) = result {
442 for row in chunk.records() {
443 if let Some(chunk_out) = self.chunk_builder.append_record(row) {
444 self.pending_items
445 .push_back(Ok(MessageInner::Chunk(chunk_out)));
446 }
447 }
448 } else {
449 return if let Some(chunk_out) = self.chunk_builder.take() {
450 self.pending_items.push_back(result);
451 Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
452 } else {
453 Poll::Ready(Some(result))
454 };
455 }
456 }
457
458 Poll::Ready(None) => {
459 unreachable!("Merge should always have upstream inputs");
461 }
462 }
463 }
464 }
465}
466
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    use assert_matches::assert_matches;
    use futures::FutureExt;
    use futures::future::try_join_all;
    use risingwave_common::array::Op;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_pb::task_service::stream_exchange_service_server::{
        StreamExchangeService, StreamExchangeServiceServer,
    };
    use risingwave_pb::task_service::{GetStreamRequest, GetStreamResponse, PbPermits};
    use tokio::time::sleep;
    use tokio_stream::wrappers::ReceiverStream;
    use tonic::{Request, Response, Status, Streaming};

    use super::*;
    use crate::executor::exchange::input::{ActorInput, LocalInput, RemoteInput};
    use crate::executor::exchange::permit::channel_for_test;
    use crate::executor::{BarrierInner as Barrier, MessageInner as Message};
    use crate::task::NewOutputRequest;
    use crate::task::barrier_test_utils::LocalBarrierTestEnv;
    use crate::task::test_utils::helper_make_local_actor;

    /// Builds a chunk with `size` insert ops and no columns.
    fn build_test_chunk(size: u64) -> StreamChunk {
        let ops = vec![Op::Insert; size as usize];
        StreamChunk::new(ops, vec![])
    }

    /// Checks chunk coalescing and flushing in `BufferChunks` over a single local input.
    #[tokio::test]
    async fn test_buffer_chunks() {
        let test_env = LocalBarrierTestEnv::for_test().await;

        let (tx, rx) = channel_for_test();
        let input = LocalInput::new(rx, 1.into()).boxed_input();
        let mut buffer = BufferChunks::new(input, 100, Schema::new(vec![]));

        // A lone chunk is emitted once the input has nothing more ready.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 10);
        });

        // Two chunks queued back-to-back are coalesced into one.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });

        // A watermark with nothing buffered passes straight through.
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Buffered rows are flushed before a following watermark.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // A barrier with nothing buffered passes straight through.
        let barrier = Barrier::new_test_barrier(test_epoch(1));
        test_env.inject_barrier(&barrier, [2.into()]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });

        // Buffered rows are flushed before a following barrier.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        let barrier = Barrier::new_test_barrier(test_epoch(2));
        test_env.inject_barrier(&barrier, [2.into()]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(2));
        });
    }

    /// Merges `CHANNEL_NUMBER` concurrently writing inputs.
    ///
    /// Checks that chunks are combined, that watermarks are emitted with the values the
    /// test expects, that barriers arrive in epoch order, and that the stream ends on a
    /// stop barrier.
    #[tokio::test]
    async fn test_merger() {
        const CHANNEL_NUMBER: usize = 10;
        let mut txs = Vec::with_capacity(CHANNEL_NUMBER);
        let mut rxs = Vec::with_capacity(CHANNEL_NUMBER);
        for _i in 0..CHANNEL_NUMBER {
            let (tx, rx) = channel_for_test();
            txs.push(tx);
            rxs.push(rx);
        }
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let actor_id = 233.into();
        let mut handles = Vec::with_capacity(CHANNEL_NUMBER);

        // Epochs 10, 20, ..., 990, each chained to the previous one.
        let epochs = (10..1000u64)
            .step_by(10)
            .map(|idx| (idx, test_epoch(idx)))
            .collect_vec();
        let mut prev_epoch = 0;
        let prev_epoch = &mut prev_epoch;
        let barriers: HashMap<_, _> = epochs
            .iter()
            .map(|(_, epoch)| {
                let barrier = Barrier::with_prev_epoch_for_test(*epoch, *prev_epoch);
                *prev_epoch = *epoch;
                barrier_test_env.inject_barrier(&barrier, [actor_id]);
                (*epoch, barrier)
            })
            .collect();
        // Final stop barrier that ends the merge stream.
        let b2 = Barrier::with_prev_epoch_for_test(test_epoch(1000), *prev_epoch)
            .with_mutation(Mutation::Stop(StopMutation::default()));
        barrier_test_env.inject_barrier(&b2, [actor_id]);
        barrier_test_env.flush_all_events().await;

        // Each sender sends a 10-row chunk on epochs divisible by 20, otherwise a watermark
        // on a column that rotates with both the epoch and the sender id, then the barrier.
        for (tx_id, tx) in txs.into_iter().enumerate() {
            let epochs = epochs.clone();
            let barriers = barriers.clone();
            let b2 = b2.clone();
            let handle = tokio::spawn(async move {
                for (idx, epoch) in epochs {
                    if idx % 20 == 0 {
                        tx.send(Message::Chunk(build_test_chunk(10)).into())
                            .await
                            .unwrap();
                    } else {
                        tx.send(
                            Message::Watermark(Watermark {
                                col_idx: (idx as usize / 20 + tx_id) % CHANNEL_NUMBER,
                                data_type: DataType::Int64,
                                val: ScalarImpl::Int64(idx as i64),
                            })
                            .into(),
                        )
                        .await
                        .unwrap();
                    }
                    tx.send(Message::Barrier(barriers[&epoch].clone().into_dispatcher()).into())
                        .await
                        .unwrap();
                    sleep(Duration::from_millis(1)).await;
                }
                tx.send(Message::Barrier(b2.clone().into_dispatcher()).into())
                    .await
                    .unwrap();
            });
            handles.push(handle);
        }

        let merger = MergeExecutor::for_test(
            actor_id,
            rxs,
            barrier_test_env.local_barrier_manager.clone(),
            Schema::new(vec![]),
            100,
            None,
        );
        let mut merger = merger.boxed().execute();
        for (idx, epoch) in epochs {
            if idx % 20 == 0 {
                // All inputs' rows for this epoch (10 inputs x 10 rows), possibly split
                // across several output chunks.
                let mut count = 0usize;
                while count < 100 {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
                        count += chunk.ops().len();
                    });
                }
                assert_eq!(count, 100);
            } else if idx as usize / 20 >= CHANNEL_NUMBER - 1 {
                // Only once every column has seen a watermark from every input does the test
                // expect one watermark per column. The expected value is the oldest of the
                // inputs' latest watermarks for that column.
                for _ in 0..CHANNEL_NUMBER {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
                        assert_eq!(watermark.val, ScalarImpl::Int64((idx - 20 * (CHANNEL_NUMBER as u64 - 1)) as i64));
                    });
                }
            }
            assert_matches!(merger.next().await.unwrap().unwrap(), Message::Barrier(Barrier{epoch:barrier_epoch,mutation:_,..}) => {
                assert_eq!(barrier_epoch.curr, epoch);
            });
        }
        assert_matches!(
            merger.next().await.unwrap().unwrap(),
            Message::Barrier(Barrier {
                mutation,
                ..
            }) if mutation.as_deref().unwrap().is_stop()
        );

        for handle in handles {
            handle.await.unwrap();
        }
    }

    /// Replaces upstream `old` with `new` via an update barrier.
    ///
    /// Checks that the barrier is held back until the new input also delivers it, and that
    /// data afterwards comes from `untouched` and `new`.
    #[tokio::test]
    async fn test_configuration_change() {
        let actor_id = 233.into();
        let (untouched, old, new) = (234.into(), 235.into(), 238.into());
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let metrics = Arc::new(StreamingMetrics::unused());

        let (upstream_fragment_id, fragment_id) = (10.into(), 18.into());

        let actor_ctx = ActorContext::for_test(actor_id);

        let inputs: Vec<_> =
            try_join_all([untouched, old].into_iter().map(async |upstream_actor_id| {
                new_input(
                    &barrier_test_env.local_barrier_manager,
                    metrics.clone(),
                    actor_id,
                    fragment_id,
                    &helper_make_local_actor(upstream_actor_id),
                    upstream_fragment_id,
                    actor_ctx.config.clone(),
                )
                .await
            }))
            .await
            .unwrap();

        // `b1` adds `new` and removes `old` for this actor's merge.
        let merge_updates = maplit::hashmap! {
            (actor_id, upstream_fragment_id) => MergeUpdate {
                actor_id,
                upstream_fragment_id,
                new_upstream_fragment_id: None,
                added_upstream_actors: vec![helper_make_local_actor(new)],
                removed_upstream_actor_id: vec![old],
            }
        };

        let b1 = Barrier::new_test_barrier(test_epoch(1)).with_mutation(Mutation::Update(
            UpdateMutation {
                merges: merge_updates,
                ..Default::default()
            },
        ));
        barrier_test_env.inject_barrier(&b1, [actor_id]);
        barrier_test_env.flush_all_events().await;

        let barrier_rx = barrier_test_env
            .local_barrier_manager
            .subscribe_barrier(actor_id);
        let upstream = MergeExecutor::new_merge_upstream(
            inputs,
            &metrics,
            &actor_ctx,
            100,
            Schema::empty().clone(),
        );

        let mut merge = MergeExecutor::new(
            actor_ctx,
            fragment_id,
            upstream_fragment_id,
            upstream,
            barrier_test_env.local_barrier_manager.clone(),
            metrics.clone(),
            barrier_rx,
        )
        .boxed()
        .execute();

        // Upstream actor id -> sender feeding the merge from that actor.
        let mut txs = HashMap::new();
        macro_rules! send {
            ($actors:expr, $msg:expr) => {
                for actor in $actors {
                    txs.get(&actor).unwrap().send($msg).await.unwrap();
                }
            };
        }

        // Asserts the merge stream has nothing ready without awaiting.
        macro_rules! assert_recv_pending {
            () => {
                assert!(
                    merge
                        .next()
                        .now_or_never()
                        .flatten()
                        .transpose()
                        .unwrap()
                        .is_none()
                );
            };
        }
        macro_rules! recv {
            () => {
                merge.next().await.transpose().unwrap()
            };
        }

        // Takes the single pending local output request of each upstream actor and keeps
        // its sender in `txs`.
        macro_rules! collect_upstream_tx {
            ($actors:expr) => {
                for upstream_id in $actors {
                    let mut output_requests = barrier_test_env
                        .take_pending_new_output_requests(upstream_id)
                        .await;
                    assert_eq!(output_requests.len(), 1);
                    let (downstream_actor_id, request) = output_requests.pop().unwrap();
                    assert_eq!(downstream_actor_id, actor_id);
                    let NewOutputRequest::Local(tx) = request else {
                        unreachable!()
                    };
                    txs.insert(upstream_id, tx);
                }
            };
        }

        assert_recv_pending!();
        barrier_test_env.flush_all_events().await;

        collect_upstream_tx!([untouched, old]);

        // One row from each initial upstream is merged into a single chunk.
        send!([untouched, old], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality());
        assert_recv_pending!();

        // `b1` from the initial upstreams alone is not emitted yet.
        send!(
            [untouched, old],
            Message::Barrier(b1.clone().into_dispatcher()).into()
        );
        assert_recv_pending!();
        collect_upstream_tx!([new]);

        // Once `new` also delivers `b1`, the barrier is emitted; afterwards rows come from
        // `untouched` and `new`.
        send!([new], Message::Barrier(b1.clone().into_dispatcher()).into());
        recv!().unwrap().as_barrier().unwrap();
        send!([untouched, new], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality());
        assert_recv_pending!();
    }

    /// Fake exchange RPC service that records whether `get_stream` was called.
    struct FakeExchangeService {
        rpc_called: Arc<AtomicBool>,
    }

    /// The barrier sent by `FakeExchangeService` and expected by the client test.
    fn exchange_client_test_barrier() -> crate::executor::Barrier {
        Barrier::new_test_barrier(test_epoch(1))
    }

    #[async_trait::async_trait]
    impl StreamExchangeService for FakeExchangeService {
        type GetStreamStream = ReceiverStream<std::result::Result<GetStreamResponse, Status>>;

        /// Responds with an empty stream chunk followed by a single-barrier batch.
        async fn get_stream(
            &self,
            _request: Request<Streaming<GetStreamRequest>>,
        ) -> std::result::Result<Response<Self::GetStreamStream>, Status> {
            let (tx, rx) = tokio::sync::mpsc::channel(10);
            self.rpc_called.store(true, Ordering::SeqCst);
            // First message: an empty stream chunk.
            let stream_chunk = StreamChunk::default().to_protobuf();
            tx.send(Ok(GetStreamResponse {
                message: Some(PbStreamMessageBatch {
                    stream_message_batch: Some(
                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::StreamChunk(
                            stream_chunk,
                        ),
                    ),
                }),
                permits: Some(PbPermits::default()),
            }))
            .await
            .unwrap();
            // Second message: a barrier batch holding the test barrier.
            let barrier = exchange_client_test_barrier();
            tx.send(Ok(GetStreamResponse {
                message: Some(PbStreamMessageBatch {
                    stream_message_batch: Some(
                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::BarrierBatch(
                            BarrierBatch {
                                barriers: vec![barrier.to_protobuf()],
                            },
                        ),
                    ),
                }),
                permits: Some(PbPermits::default()),
            }))
            .await
            .unwrap();
            Ok(Response::new(ReceiverStream::new(rx)))
        }
    }

    /// End-to-end check of `RemoteInput` against the fake exchange service over tonic.
    #[tokio::test]
    async fn test_stream_exchange_client() {
        let rpc_called = Arc::new(AtomicBool::new(false));
        let server_run = Arc::new(AtomicBool::new(false));
        // NOTE(review): fixed port plus a fixed 1s startup sleep; may be flaky if the port
        // is already in use.
        let addr = "127.0.0.1:12348".parse().unwrap();

        let (shutdown_send, shutdown_recv) = tokio::sync::oneshot::channel();
        let exchange_svc = StreamExchangeServiceServer::new(FakeExchangeService {
            rpc_called: rpc_called.clone(),
        });
        let cp_server_run = server_run.clone();
        let join_handle = tokio::spawn(async move {
            cp_server_run.store(true, Ordering::SeqCst);
            tonic::transport::Server::builder()
                .add_service(exchange_svc)
                .serve_with_shutdown(addr, async move {
                    shutdown_recv.await.unwrap();
                })
                .await
                .unwrap();
        });

        sleep(Duration::from_secs(1)).await;
        assert!(server_run.load(Ordering::SeqCst));

        let test_env = LocalBarrierTestEnv::for_test().await;

        let remote_input = {
            RemoteInput::new(
                &test_env.local_barrier_manager,
                addr.into(),
                (0.into(), 0.into()),
                (0.into(), 0.into()),
                Arc::new(StreamingMetrics::unused()),
                Arc::new(StreamingConfig::default()),
            )
            .await
            .unwrap()
        };

        test_env.inject_barrier(&exchange_client_test_barrier(), [remote_input.id()]);

        pin_mut!(remote_input);

        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            let (ops, columns, visibility) = chunk.into_inner();
            assert!(ops.is_empty());
            assert!(columns.is_empty());
            assert!(visibility.is_empty());
        });
        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });
        assert!(rpc_called.load(Ordering::SeqCst));

        shutdown_send.send(()).unwrap();
        join_handle.await.unwrap();
    }
}