1use std::collections::VecDeque;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use anyhow::Context as _;
20use futures::future::try_join_all;
21use risingwave_common::array::StreamChunkBuilder;
22use risingwave_common::config::MetricLevel;
23use tokio::sync::mpsc;
24use tokio::time::Instant;
25
26use super::exchange::input::BoxedActorInput;
27use super::*;
28use crate::executor::exchange::input::{
29 assert_equal_dispatcher_barrier, new_input, process_dispatcher_msg,
30};
31use crate::executor::prelude::*;
32use crate::task::LocalBarrierManager;
33
/// The set of upstream inputs a merge polls from: [`DynamicReceivers`] instantiated with
/// [`ActorId`] as its first type parameter and `()` as its second.
///
/// NOTE(review): the ids exposed via `upstream_input_ids()` are compared against upstream actor
/// ids elsewhere in this file, so the first parameter is presumably the upstream actor id.
pub type SelectReceivers = DynamicReceivers<ActorId, ()>;
35
/// The upstream side of a [`MergeExecutorInput`].
///
/// The variant decides which executor `MergeExecutorInput::into_executor` builds.
pub(crate) enum MergeExecutorUpstream {
    /// A single boxed actor input; `into_executor` wraps it in a `ReceiverExecutor`.
    Singleton(BoxedActorInput),
    /// A set of inputs polled through [`SelectReceivers`]; `into_executor` wraps it in a
    /// [`MergeExecutor`].
    Merge(SelectReceivers),
}
40
/// An upstream plus everything required to later turn it into an [`Executor`].
///
/// The value can either be polled directly as a `Stream` of dispatcher messages, or be
/// converted into a full executor with `into_executor`.
pub(crate) struct MergeExecutorInput {
    /// The upstream(s) messages are received from.
    upstream: MergeExecutorUpstream,
    /// Context of the owning actor; its `fragment_id` is forwarded to the built executor.
    actor_context: ActorContextRef,
    /// Id of the upstream fragment, forwarded to the built executor.
    upstream_fragment_id: UpstreamFragmentId,
    /// Barrier manager handle, forwarded to the built executor.
    local_barrier_manager: LocalBarrierManager,
    /// Streaming metrics, forwarded to the built executor.
    executor_stats: Arc<StreamingMetrics>,
    /// Executor info paired with the built executor; its schema is also handed to the
    /// [`MergeExecutor`].
    pub(crate) info: ExecutorInfo,
    /// Chunk size handed to the [`MergeExecutor`] (unused for a singleton upstream).
    chunk_size: usize,
}
50
51impl MergeExecutorInput {
52 pub(crate) fn new(
53 upstream: MergeExecutorUpstream,
54 actor_context: ActorContextRef,
55 upstream_fragment_id: UpstreamFragmentId,
56 local_barrier_manager: LocalBarrierManager,
57 executor_stats: Arc<StreamingMetrics>,
58 info: ExecutorInfo,
59 chunk_size: usize,
60 ) -> Self {
61 Self {
62 upstream,
63 actor_context,
64 upstream_fragment_id,
65 local_barrier_manager,
66 executor_stats,
67 info,
68 chunk_size,
69 }
70 }
71
72 pub(crate) fn into_executor(self, barrier_rx: mpsc::UnboundedReceiver<Barrier>) -> Executor {
73 let fragment_id = self.actor_context.fragment_id;
74 let executor = match self.upstream {
75 MergeExecutorUpstream::Singleton(input) => ReceiverExecutor::new(
76 self.actor_context,
77 fragment_id,
78 self.upstream_fragment_id,
79 input,
80 self.local_barrier_manager,
81 self.executor_stats,
82 barrier_rx,
83 )
84 .boxed(),
85 MergeExecutorUpstream::Merge(inputs) => MergeExecutor::new(
86 self.actor_context,
87 fragment_id,
88 self.upstream_fragment_id,
89 inputs,
90 self.local_barrier_manager,
91 self.executor_stats,
92 barrier_rx,
93 self.chunk_size,
94 self.info.schema.clone(),
95 )
96 .boxed(),
97 };
98 (self.info, executor).into()
99 }
100}
101
102impl Stream for MergeExecutorInput {
103 type Item = DispatcherMessageStreamItem;
104
105 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
106 match &mut self.get_mut().upstream {
107 MergeExecutorUpstream::Singleton(input) => input.poll_next_unpin(cx),
108 MergeExecutorUpstream::Merge(inputs) => inputs.poll_next_unpin(cx),
109 }
110 }
111}
112
/// Executor that merges the messages of several upstream inputs into one message stream.
///
/// See `execute_inner` for the message loop, including how barriers that carry a merge update
/// add or remove upstreams.
pub struct MergeExecutor {
    /// Context of the actor this executor belongs to.
    actor_context: ActorContextRef,

    /// The upstream inputs being merged.
    upstreams: SelectReceivers,

    /// Id of the fragment this executor belongs to.
    fragment_id: FragmentId,

    /// Id of the upstream fragment; replaced when a merge update names a new upstream fragment.
    upstream_fragment_id: FragmentId,

    /// Barrier manager handle, used when creating inputs for newly added upstream actors.
    local_barrier_manager: LocalBarrierManager,

    /// Streaming metrics.
    metrics: Arc<StreamingMetrics>,

    /// Receiver of barriers, passed to `process_dispatcher_msg` for every incoming message.
    barrier_rx: mpsc::UnboundedReceiver<Barrier>,

    /// Chunk size used for the `StreamChunkBuilder` inside `BufferChunks`.
    chunk_size: usize,

    /// Schema used for the `StreamChunkBuilder` inside `BufferChunks`.
    schema: Schema,
}
141
142impl MergeExecutor {
143 #[allow(clippy::too_many_arguments)]
144 pub fn new(
145 ctx: ActorContextRef,
146 fragment_id: FragmentId,
147 upstream_fragment_id: FragmentId,
148 upstreams: SelectReceivers,
149 local_barrier_manager: LocalBarrierManager,
150 metrics: Arc<StreamingMetrics>,
151 barrier_rx: mpsc::UnboundedReceiver<Barrier>,
152 chunk_size: usize,
153 schema: Schema,
154 ) -> Self {
155 Self {
156 actor_context: ctx,
157 upstreams,
158 fragment_id,
159 upstream_fragment_id,
160 local_barrier_manager,
161 metrics,
162 barrier_rx,
163 chunk_size,
164 schema,
165 }
166 }
167
168 #[cfg(test)]
169 pub fn for_test(
170 actor_id: ActorId,
171 inputs: Vec<super::exchange::permit::Receiver>,
172 local_barrier_manager: crate::task::LocalBarrierManager,
173 schema: Schema,
174 chunk_size: usize,
175 barrier_rx: Option<mpsc::UnboundedReceiver<Barrier>>,
176 ) -> Self {
177 use super::exchange::input::LocalInput;
178 use crate::executor::exchange::input::ActorInput;
179
180 let barrier_rx =
181 barrier_rx.unwrap_or_else(|| local_barrier_manager.subscribe_barrier(actor_id));
182
183 let metrics = StreamingMetrics::unused();
184 let actor_ctx = ActorContext::for_test(actor_id);
185 let upstream = Self::new_select_receiver(
186 inputs
187 .into_iter()
188 .enumerate()
189 .map(|(idx, input)| LocalInput::new(input, idx as ActorId).boxed_input())
190 .collect(),
191 &metrics,
192 &actor_ctx,
193 );
194
195 Self::new(
196 actor_ctx,
197 514,
198 1919,
199 upstream,
200 local_barrier_manager,
201 metrics.into(),
202 barrier_rx,
203 chunk_size,
204 schema,
205 )
206 }
207
208 pub(crate) fn new_select_receiver(
209 upstreams: Vec<BoxedActorInput>,
210 metrics: &StreamingMetrics,
211 actor_context: &ActorContext,
212 ) -> SelectReceivers {
213 let merge_barrier_align_duration = if metrics.level >= MetricLevel::Debug {
214 Some(
215 metrics
216 .merge_barrier_align_duration
217 .with_guarded_label_values(&[
218 &actor_context.id.to_string(),
219 &actor_context.fragment_id.to_string(),
220 ]),
221 )
222 } else {
223 None
224 };
225
226 SelectReceivers::new(upstreams, None, merge_barrier_align_duration.clone())
228 }
229
    /// The message loop of the merge executor.
    ///
    /// Messages from all upstreams are pulled through a `BufferChunks` wrapper, converted with
    /// `process_dispatcher_msg`, and yielded downstream. Chunks are counted in the input
    /// metrics. Barriers additionally:
    /// - record this actor in `passed_actors`,
    /// - flush buffered watermarks when the barrier's update mutation touches a dispatcher of
    ///   one of the current upstreams, or when a merge update applies to this actor,
    /// - apply the merge update by adding and/or removing upstreams,
    /// - end the stream after the barrier is yielded if it stops this actor.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    pub(crate) async fn execute_inner(mut self: Box<Self>) {
        // All active upstreams, wrapped so that ready chunks are re-batched before emission.
        let select_all = self.upstreams;
        let select_all = BufferChunks::new(select_all, self.chunk_size, self.schema);
        let actor_id = self.actor_context.id;

        let mut metrics = self.metrics.new_actor_input_metrics(
            actor_id,
            self.fragment_id,
            self.upstream_fragment_id,
        );

        // Start of the current wait for an upstream message; reset after every yield.
        let mut start_time = Instant::now();
        pin_mut!(select_all);
        while let Some(msg) = select_all.next().await {
            // Account the time spent waiting for this message.
            metrics
                .actor_input_buffer_blocking_duration_ns
                .inc_by(start_time.elapsed().as_nanos() as u64);
            let msg: DispatcherMessage = msg?;
            let mut msg: Message = process_dispatcher_msg(msg, &mut self.barrier_rx).await?;

            match &mut msg {
                Message::Watermark(_) => {
                    // Nothing to do; the watermark is forwarded as-is below.
                }
                Message::Chunk(chunk) => {
                    metrics.actor_in_record_cnt.inc_by(chunk.cardinality() as _);
                }
                Message::Barrier(barrier) => {
                    tracing::debug!(
                        target: "events::stream::barrier::path",
                        actor_id = actor_id,
                        "receiver receives barrier from path: {:?}",
                        barrier.passed_actors
                    );
                    barrier.passed_actors.push(actor_id);

                    // The update mutation changes a dispatcher of at least one current upstream.
                    // NOTE(review): presumably buffered watermarks may be stale after such a
                    // change, hence the flush — confirm.
                    if let Some(Mutation::Update(UpdateMutation { dispatchers, .. })) =
                        barrier.mutation.as_deref()
                        && select_all
                            .upstream_input_ids()
                            .any(|actor_id| dispatchers.contains_key(&actor_id))
                    {
                        select_all.flush_buffered_watermarks();
                    }

                    if let Some(update) =
                        barrier.as_update_merge(self.actor_context.id, self.upstream_fragment_id)
                    {
                        let new_upstream_fragment_id = update
                            .new_upstream_fragment_id
                            .unwrap_or(self.upstream_fragment_id);
                        // When the upstream fragment itself is replaced, every current upstream
                        // is removed; otherwise only the explicitly listed actors are.
                        let removed_upstream_actor_id: HashSet<_> =
                            if update.new_upstream_fragment_id.is_some() {
                                select_all.upstream_input_ids().collect()
                            } else {
                                update.removed_upstream_actor_id.iter().copied().collect()
                            };

                        // Flush buffered watermarks before the set of upstreams changes.
                        select_all.flush_buffered_watermarks();

                        if !update.added_upstream_actors.is_empty() {
                            // Create the inputs for all newly added upstream actors.
                            let mut new_upstreams: Vec<_> = try_join_all(
                                update.added_upstream_actors.iter().map(|upstream_actor| {
                                    new_input(
                                        &self.local_barrier_manager,
                                        self.metrics.clone(),
                                        self.actor_context.id,
                                        self.fragment_id,
                                        upstream_actor,
                                        new_upstream_fragment_id,
                                    )
                                }),
                            )
                            .await
                            .context("failed to create upstream receivers")?;

                            // The first message of each new upstream must be a barrier equal to
                            // the one currently being processed.
                            for upstream in &mut new_upstreams {
                                let new_barrier = expect_first_barrier(upstream).await?;
                                assert_equal_dispatcher_barrier(barrier, &new_barrier);
                            }

                            // Start polling the new upstreams.
                            select_all.add_upstreams_from(new_upstreams);
                        }

                        if !removed_upstream_actor_id.is_empty() {
                            // Stop polling the removed upstreams.
                            select_all.remove_upstreams(&removed_upstream_actor_id);
                        }

                        // Metrics are labelled with the upstream fragment id, so recreate them.
                        self.upstream_fragment_id = new_upstream_fragment_id;
                        metrics = self.metrics.new_actor_input_metrics(
                            actor_id,
                            self.fragment_id,
                            self.upstream_fragment_id,
                        );
                    }

                    // A barrier that stops this actor is the last message of the stream.
                    if barrier.is_stop(actor_id) {
                        yield msg;
                        break;
                    }
                }
            }

            yield msg;
            start_time = Instant::now();
        }
    }
346}
347
348impl Execute for MergeExecutor {
349 fn execute(self: Box<Self>) -> BoxedMessageStream {
350 self.execute_inner().boxed()
351 }
352}
353
/// A stream wrapper that re-batches the chunks of `inner`.
///
/// Rows of incoming chunks are appended to a `StreamChunkBuilder`. A partially filled chunk is
/// emitted as soon as `inner` has no more ready items, or right before any non-chunk item
/// (which is then emitted immediately after it).
struct BufferChunks<S: Stream> {
    /// The wrapped stream; also reachable through `Deref`/`DerefMut`.
    inner: S,
    /// Accumulates rows of incoming chunks into output chunks.
    chunk_builder: StreamChunkBuilder,

    /// Items that are ready to be emitted before `inner` is polled again.
    pending_items: VecDeque<S::Item>,
}
364
365impl<S: Stream> BufferChunks<S> {
366 pub(super) fn new(inner: S, chunk_size: usize, schema: Schema) -> Self {
367 assert!(chunk_size > 0);
368 let chunk_builder = StreamChunkBuilder::new(chunk_size, schema.data_types());
369 Self {
370 inner,
371 chunk_builder,
372 pending_items: VecDeque::new(),
373 }
374 }
375}
376
377impl<S: Stream> std::ops::Deref for BufferChunks<S> {
378 type Target = S;
379
380 fn deref(&self) -> &Self::Target {
381 &self.inner
382 }
383}
384
385impl<S: Stream> std::ops::DerefMut for BufferChunks<S> {
386 fn deref_mut(&mut self) -> &mut Self::Target {
387 &mut self.inner
388 }
389}
390
391impl<S: Stream> Stream for BufferChunks<S>
392where
393 S: Stream<Item = DispatcherMessageStreamItem> + Unpin,
394{
395 type Item = S::Item;
396
397 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
398 loop {
399 if let Some(item) = self.pending_items.pop_front() {
400 return Poll::Ready(Some(item));
401 }
402
403 match self.inner.poll_next_unpin(cx) {
404 Poll::Pending => {
405 return if let Some(chunk_out) = self.chunk_builder.take() {
406 Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
407 } else {
408 Poll::Pending
409 };
410 }
411
412 Poll::Ready(Some(result)) => {
413 if let Ok(MessageInner::Chunk(chunk)) = result {
414 for row in chunk.records() {
415 if let Some(chunk_out) = self.chunk_builder.append_record(row) {
416 self.pending_items
417 .push_back(Ok(MessageInner::Chunk(chunk_out)));
418 }
419 }
420 } else {
421 return if let Some(chunk_out) = self.chunk_builder.take() {
422 self.pending_items.push_back(result);
423 Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
424 } else {
425 Poll::Ready(Some(result))
426 };
427 }
428 }
429
430 Poll::Ready(None) => {
431 unreachable!("Merge should always have upstream inputs");
433 }
434 }
435 }
436 }
437}
438
439#[cfg(test)]
440mod tests {
441 use std::sync::atomic::{AtomicBool, Ordering};
442 use std::time::Duration;
443
444 use assert_matches::assert_matches;
445 use futures::FutureExt;
446 use futures::future::try_join_all;
447 use risingwave_common::array::Op;
448 use risingwave_common::util::epoch::test_epoch;
449 use risingwave_pb::task_service::exchange_service_server::{
450 ExchangeService, ExchangeServiceServer,
451 };
452 use risingwave_pb::task_service::{
453 GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits,
454 };
455 use tokio::time::sleep;
456 use tokio_stream::wrappers::ReceiverStream;
457 use tonic::{Request, Response, Status, Streaming};
458
459 use super::*;
460 use crate::executor::exchange::input::{ActorInput, LocalInput, RemoteInput};
461 use crate::executor::exchange::permit::channel_for_test;
462 use crate::executor::{BarrierInner as Barrier, MessageInner as Message};
463 use crate::task::NewOutputRequest;
464 use crate::task::barrier_test_utils::LocalBarrierTestEnv;
465 use crate::task::test_utils::helper_make_local_actor;
466
467 fn build_test_chunk(size: u64) -> StreamChunk {
468 let ops = vec![Op::Insert; size as usize];
469 StreamChunk::new(ops, vec![])
470 }
471
    /// Checks how `BufferChunks` batches chunks around watermarks and barriers, using a single
    /// local input and a chunk size of 100.
    #[tokio::test]
    async fn test_buffer_chunks() {
        let test_env = LocalBarrierTestEnv::for_test().await;

        let (tx, rx) = channel_for_test();
        let input = LocalInput::new(rx, 1).boxed_input();
        let mut buffer = BufferChunks::new(input, 100, Schema::new(vec![]));

        // A single chunk is passed through.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 10);
        });

        // Two chunks sent back to back come out as one chunk of 20 rows.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });

        // A watermark with nothing buffered is passed through.
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // Two chunks followed by a watermark: the merged chunk comes first, then the watermark.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // A barrier with nothing buffered is passed through.
        let barrier = Barrier::new_test_barrier(test_epoch(1));
        test_env.inject_barrier(&barrier, [2]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });

        // Two chunks followed by a barrier: the merged chunk comes first, then the barrier.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        let barrier = Barrier::new_test_barrier(test_epoch(2));
        test_env.inject_barrier(&barrier, [2]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(2));
        });
    }
567
    /// Drives a `MergeExecutor` with `CHANNEL_NUMBER` concurrent senders and checks the merged
    /// output epoch by epoch, ending with a stop barrier.
    #[tokio::test]
    async fn test_merger() {
        const CHANNEL_NUMBER: usize = 10;
        let mut txs = Vec::with_capacity(CHANNEL_NUMBER);
        let mut rxs = Vec::with_capacity(CHANNEL_NUMBER);
        for _i in 0..CHANNEL_NUMBER {
            let (tx, rx) = channel_for_test();
            txs.push(tx);
            rxs.push(rx);
        }
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let actor_id = 233;
        let mut handles = Vec::with_capacity(CHANNEL_NUMBER);

        // (idx, epoch) pairs for idx = 10, 20, ..., 990.
        let epochs = (10..1000u64)
            .step_by(10)
            .map(|idx| (idx, test_epoch(idx)))
            .collect_vec();
        let mut prev_epoch = 0;
        let prev_epoch = &mut prev_epoch;
        // Build one barrier per epoch, chained via `prev_epoch`, and inject each of them.
        let barriers: HashMap<_, _> = epochs
            .iter()
            .map(|(_, epoch)| {
                let barrier = Barrier::with_prev_epoch_for_test(*epoch, *prev_epoch);
                *prev_epoch = *epoch;
                barrier_test_env.inject_barrier(&barrier, [actor_id]);
                (*epoch, barrier)
            })
            .collect();
        // The final barrier carries a stop mutation.
        let b2 = Barrier::with_prev_epoch_for_test(test_epoch(1000), *prev_epoch)
            .with_mutation(Mutation::Stop(StopMutation::default()));
        barrier_test_env.inject_barrier(&b2, [actor_id]);
        barrier_test_env.flush_all_events().await;

        // Every sender emits, per epoch, either a 10-row chunk (idx divisible by 20) or a
        // watermark, followed by that epoch's barrier; the stop barrier comes last.
        for (tx_id, tx) in txs.into_iter().enumerate() {
            let epochs = epochs.clone();
            let barriers = barriers.clone();
            let b2 = b2.clone();
            let handle = tokio::spawn(async move {
                for (idx, epoch) in epochs {
                    if idx % 20 == 0 {
                        tx.send(Message::Chunk(build_test_chunk(10)).into())
                            .await
                            .unwrap();
                    } else {
                        tx.send(
                            Message::Watermark(Watermark {
                                col_idx: (idx as usize / 20 + tx_id) % CHANNEL_NUMBER,
                                data_type: DataType::Int64,
                                val: ScalarImpl::Int64(idx as i64),
                            })
                            .into(),
                        )
                        .await
                        .unwrap();
                    }
                    tx.send(Message::Barrier(barriers[&epoch].clone().into_dispatcher()).into())
                        .await
                        .unwrap();
                    sleep(Duration::from_millis(1)).await;
                }
                tx.send(Message::Barrier(b2.clone().into_dispatcher()).into())
                    .await
                    .unwrap();
            });
            handles.push(handle);
        }

        let merger = MergeExecutor::for_test(
            actor_id,
            rxs,
            barrier_test_env.local_barrier_manager.clone(),
            Schema::new(vec![]),
            100,
            None,
        );
        let mut merger = merger.boxed().execute();
        for (idx, epoch) in epochs {
            if idx % 20 == 0 {
                // Expect chunks adding up to 100 rows (10 rows from each of the 10 senders).
                let mut count = 0usize;
                while count < 100 {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
                        count += chunk.ops().len();
                    });
                }
                assert_eq!(count, 100);
            } else if idx as usize / 20 >= CHANNEL_NUMBER - 1 {
                // From this point on, expect `CHANNEL_NUMBER` watermarks per watermark epoch,
                // each carrying the value sent `CHANNEL_NUMBER - 1` watermark rounds earlier.
                for _ in 0..CHANNEL_NUMBER {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
                        assert_eq!(watermark.val, ScalarImpl::Int64((idx - 20 * (CHANNEL_NUMBER as u64 - 1)) as i64));
                    });
                }
            }
            // Expect the barrier of this epoch.
            assert_matches!(merger.next().await.unwrap().unwrap(), Message::Barrier(Barrier{epoch:barrier_epoch,mutation:_,..}) => {
                assert_eq!(barrier_epoch.curr, epoch);
            });
        }
        // The last message is the stop barrier.
        assert_matches!(
            merger.next().await.unwrap().unwrap(),
            Message::Barrier(Barrier {
                mutation,
                ..
            }) if mutation.as_deref().unwrap().is_stop()
        );

        for handle in handles {
            handle.await.unwrap();
        }
    }
680
    /// Checks that a barrier carrying a merge update swaps the upstream `old` for `new` while
    /// `untouched` keeps flowing, and that the barrier is held back until `new` delivers it too.
    #[tokio::test]
    async fn test_configuration_change() {
        let actor_id = 233;
        // Upstream actor ids.
        let (untouched, old, new) = (234, 235, 238);
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let metrics = Arc::new(StreamingMetrics::unused());

        let (upstream_fragment_id, fragment_id) = (10, 18);

        // Initial upstreams of `actor_id`: `untouched` and `old`.
        let inputs: Vec<_> =
            try_join_all([untouched, old].into_iter().map(async |upstream_actor_id| {
                new_input(
                    &barrier_test_env.local_barrier_manager,
                    metrics.clone(),
                    actor_id,
                    fragment_id,
                    &helper_make_local_actor(upstream_actor_id),
                    upstream_fragment_id,
                )
                .await
            }))
            .await
            .unwrap();

        // The update replaces `old` with `new` and keeps the upstream fragment.
        let merge_updates = maplit::hashmap! {
            (actor_id, upstream_fragment_id) => MergeUpdate {
                actor_id,
                upstream_fragment_id,
                new_upstream_fragment_id: None,
                added_upstream_actors: vec![helper_make_local_actor(new)],
                removed_upstream_actor_id: vec![old],
            }
        };

        let b1 = Barrier::new_test_barrier(test_epoch(1)).with_mutation(Mutation::Update(
            UpdateMutation {
                merges: merge_updates,
                ..Default::default()
            },
        ));
        barrier_test_env.inject_barrier(&b1, [actor_id]);
        barrier_test_env.flush_all_events().await;

        let barrier_rx = barrier_test_env
            .local_barrier_manager
            .subscribe_barrier(actor_id);
        let actor_ctx = ActorContext::for_test(actor_id);
        let upstream = MergeExecutor::new_select_receiver(inputs, &metrics, &actor_ctx);

        let mut merge = MergeExecutor::new(
            actor_ctx,
            fragment_id,
            upstream_fragment_id,
            upstream,
            barrier_test_env.local_barrier_manager.clone(),
            metrics.clone(),
            barrier_rx,
            100,
            Schema::new(vec![]),
        )
        .boxed()
        .execute();

        // Senders towards the merge executor, keyed by upstream actor id.
        let mut txs = HashMap::new();
        // Sends `$msg` from every listed upstream actor.
        macro_rules! send {
            ($actors:expr, $msg:expr) => {
                for actor in $actors {
                    txs.get(&actor).unwrap().send($msg).await.unwrap();
                }
            };
        }

        // Asserts that the merge executor has no message ready right now.
        macro_rules! assert_recv_pending {
            () => {
                assert!(
                    merge
                        .next()
                        .now_or_never()
                        .flatten()
                        .transpose()
                        .unwrap()
                        .is_none()
                );
            };
        }
        // Waits for the next message of the merge executor.
        macro_rules! recv {
            () => {
                merge.next().await.transpose().unwrap()
            };
        }

        // Takes the pending output request of each listed upstream actor and stores its sender.
        macro_rules! collect_upstream_tx {
            ($actors:expr) => {
                for upstream_id in $actors {
                    let mut output_requests = barrier_test_env
                        .take_pending_new_output_requests(upstream_id)
                        .await;
                    assert_eq!(output_requests.len(), 1);
                    let (downstream_actor_id, request) = output_requests.pop().unwrap();
                    assert_eq!(actor_id, downstream_actor_id);
                    let NewOutputRequest::Local(tx) = request else {
                        unreachable!()
                    };
                    txs.insert(upstream_id, tx);
                }
            };
        }

        // Nothing has been sent yet.
        assert_recv_pending!();
        barrier_test_env.flush_all_events().await;

        // Obtain the senders of the two initial upstreams.
        collect_upstream_tx!([untouched, old]);

        // One 1-row chunk from each upstream; the test expects a single chunk of cardinality 2.
        send!([untouched, old], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality());
        assert_recv_pending!();

        send!(
            [untouched, old],
            Message::Barrier(b1.clone().into_dispatcher()).into()
        );
        // The barrier is not emitted yet: the new upstream has not delivered it.
        assert_recv_pending!();
        collect_upstream_tx!([new]);

        send!([new], Message::Barrier(b1.clone().into_dispatcher()).into());
        // Now the barrier comes through.
        recv!().unwrap().as_barrier().unwrap();
        // After the update, `untouched` and `new` are the upstreams.
        send!([untouched, new], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality());
        assert_recv_pending!();
    }
820
    /// Minimal `ExchangeService` used by `test_stream_exchange_client`.
    struct FakeExchangeService {
        /// Set to `true` once `get_stream` has been called.
        rpc_called: Arc<AtomicBool>,
    }
824
825 fn exchange_client_test_barrier() -> crate::executor::Barrier {
826 Barrier::new_test_barrier(test_epoch(1))
827 }
828
829 #[async_trait::async_trait]
830 impl ExchangeService for FakeExchangeService {
831 type GetDataStream = ReceiverStream<std::result::Result<GetDataResponse, Status>>;
832 type GetStreamStream = ReceiverStream<std::result::Result<GetStreamResponse, Status>>;
833
834 async fn get_data(
835 &self,
836 _: Request<GetDataRequest>,
837 ) -> std::result::Result<Response<Self::GetDataStream>, Status> {
838 unimplemented!()
839 }
840
841 async fn get_stream(
842 &self,
843 _request: Request<Streaming<GetStreamRequest>>,
844 ) -> std::result::Result<Response<Self::GetStreamStream>, Status> {
845 let (tx, rx) = tokio::sync::mpsc::channel(10);
846 self.rpc_called.store(true, Ordering::SeqCst);
847 let stream_chunk = StreamChunk::default().to_protobuf();
849 tx.send(Ok(GetStreamResponse {
850 message: Some(PbStreamMessageBatch {
851 stream_message_batch: Some(
852 risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::StreamChunk(
853 stream_chunk,
854 ),
855 ),
856 }),
857 permits: Some(PbPermits::default()),
858 }))
859 .await
860 .unwrap();
861 let barrier = exchange_client_test_barrier();
863 tx.send(Ok(GetStreamResponse {
864 message: Some(PbStreamMessageBatch {
865 stream_message_batch: Some(
866 risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::BarrierBatch(
867 BarrierBatch {
868 barriers: vec![barrier.to_protobuf()],
869 },
870 ),
871 ),
872 }),
873 permits: Some(PbPermits::default()),
874 }))
875 .await
876 .unwrap();
877 Ok(Response::new(ReceiverStream::new(rx)))
878 }
879 }
880
    /// Runs a `RemoteInput` against a local gRPC server backed by `FakeExchangeService` and
    /// checks that the chunk and the barrier sent by the server are received in order.
    ///
    /// NOTE(review): the server binds the fixed address `127.0.0.1:12348` and the test waits a
    /// fixed second for it to start, so it depends on that port being free.
    #[tokio::test]
    async fn test_stream_exchange_client() {
        let rpc_called = Arc::new(AtomicBool::new(false));
        let server_run = Arc::new(AtomicBool::new(false));
        let addr = "127.0.0.1:12348".parse().unwrap();

        // Start the server in a background task; it shuts down when `shutdown_send` fires.
        let (shutdown_send, shutdown_recv) = tokio::sync::oneshot::channel();
        let exchange_svc = ExchangeServiceServer::new(FakeExchangeService {
            rpc_called: rpc_called.clone(),
        });
        let cp_server_run = server_run.clone();
        let join_handle = tokio::spawn(async move {
            cp_server_run.store(true, Ordering::SeqCst);
            tonic::transport::Server::builder()
                .add_service(exchange_svc)
                .serve_with_shutdown(addr, async move {
                    shutdown_recv.await.unwrap();
                })
                .await
                .unwrap();
        });

        // Give the server task time to start.
        sleep(Duration::from_secs(1)).await;
        assert!(server_run.load(Ordering::SeqCst));

        let test_env = LocalBarrierTestEnv::for_test().await;

        let remote_input = {
            RemoteInput::new(
                &test_env.local_barrier_manager,
                addr.into(),
                (0, 0),
                (0, 0),
                Arc::new(StreamingMetrics::unused()),
            )
            .await
            .unwrap()
        };

        // Inject the same barrier the fake service is going to send.
        test_env.inject_barrier(&exchange_client_test_barrier(), [remote_input.id()]);

        pin_mut!(remote_input);

        // First the empty default chunk ...
        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            let (ops, columns, visibility) = chunk.into_inner();
            assert!(ops.is_empty());
            assert!(columns.is_empty());
            assert!(visibility.is_empty());
        });
        // ... then the barrier of test epoch 1.
        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });
        assert!(rpc_called.load(Ordering::SeqCst));

        // Stop the server and wait for its task to finish.
        shutdown_send.send(()).unwrap();
        join_handle.await.unwrap();
    }
939}