1use std::collections::VecDeque;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use risingwave_common::array::StreamChunkBuilder;
20use risingwave_common::config::MetricLevel;
21use tokio::sync::mpsc;
22use tokio::time::Instant;
23
24use super::exchange::input::BoxedActorInput;
25use super::*;
26use crate::executor::prelude::*;
27use crate::task::LocalBarrierManager;
28
/// Multiplexer over the inputs of several upstream actors, keyed by upstream `ActorId`.
pub type SelectReceivers = DynamicReceivers<ActorId, ()>;
30
/// The upstream input(s) feeding an actor, before an executor has been built for them.
pub(crate) enum MergeExecutorUpstream {
    /// Exactly one upstream input; `into_executor` turns it into a `ReceiverExecutor`.
    Singleton(BoxedActorInput),
    /// Several upstream inputs multiplexed by `SelectReceivers`; `into_executor` turns them
    /// into a `MergeExecutor`.
    Merge(SelectReceivers),
}
35
/// Not-yet-built input side of an actor: the upstream(s) plus everything needed to turn
/// them into an [`Executor`] via `into_executor`.
///
/// It can also be polled directly as a stream of dispatcher messages.
pub(crate) struct MergeExecutorInput {
    upstream: MergeExecutorUpstream,
    actor_context: ActorContextRef,
    upstream_fragment_id: UpstreamFragmentId,
    local_barrier_manager: LocalBarrierManager,
    /// Metrics handle handed on to the built executor.
    executor_stats: Arc<StreamingMetrics>,
    pub(crate) info: ExecutorInfo,
    /// Target chunk size for re-batching; only forwarded when `upstream` is `Merge`.
    chunk_size: usize,
}
45
46impl MergeExecutorInput {
47 pub(crate) fn new(
48 upstream: MergeExecutorUpstream,
49 actor_context: ActorContextRef,
50 upstream_fragment_id: UpstreamFragmentId,
51 local_barrier_manager: LocalBarrierManager,
52 executor_stats: Arc<StreamingMetrics>,
53 info: ExecutorInfo,
54 chunk_size: usize,
55 ) -> Self {
56 Self {
57 upstream,
58 actor_context,
59 upstream_fragment_id,
60 local_barrier_manager,
61 executor_stats,
62 info,
63 chunk_size,
64 }
65 }
66
67 pub(crate) fn into_executor(self, barrier_rx: mpsc::UnboundedReceiver<Barrier>) -> Executor {
68 let fragment_id = self.actor_context.fragment_id;
69 let executor = match self.upstream {
70 MergeExecutorUpstream::Singleton(input) => ReceiverExecutor::new(
71 self.actor_context,
72 fragment_id,
73 self.upstream_fragment_id,
74 input,
75 self.local_barrier_manager,
76 self.executor_stats,
77 barrier_rx,
78 )
79 .boxed(),
80 MergeExecutorUpstream::Merge(inputs) => MergeExecutor::new(
81 self.actor_context,
82 fragment_id,
83 self.upstream_fragment_id,
84 inputs,
85 self.local_barrier_manager,
86 self.executor_stats,
87 barrier_rx,
88 self.chunk_size,
89 self.info.schema.clone(),
90 )
91 .boxed(),
92 };
93 (self.info, executor).into()
94 }
95}
96
97impl Stream for MergeExecutorInput {
98 type Item = DispatcherMessageStreamItem;
99
100 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
101 match &mut self.get_mut().upstream {
102 MergeExecutorUpstream::Singleton(input) => input.poll_next_unpin(cx),
103 MergeExecutorUpstream::Merge(inputs) => inputs.poll_next_unpin(cx),
104 }
105 }
106}
107
/// `MergeExecutor` merges the message streams of multiple upstream inputs into one stream
/// of `Message`s, re-batching chunks and handing barriers to a `DispatchBarrierBuffer`.
pub struct MergeExecutor {
    /// Context of the actor this executor belongs to.
    actor_context: ActorContextRef,

    /// The multiplexed upstream inputs.
    upstreams: SelectReceivers,

    /// Fragment this executor belongs to.
    fragment_id: FragmentId,

    /// Fragment of the upstream actors; may change on an update-merge barrier.
    upstream_fragment_id: FragmentId,

    local_barrier_manager: LocalBarrierManager,

    /// Streaming metrics.
    metrics: Arc<StreamingMetrics>,

    /// Barrier receiver passed to `DispatchBarrierBuffer`.
    barrier_rx: mpsc::UnboundedReceiver<Barrier>,

    /// Chunk size for the `StreamChunkBuilder` used to re-batch upstream chunks.
    chunk_size: usize,

    /// Schema whose data types configure the `StreamChunkBuilder`.
    schema: Schema,
}
136
impl MergeExecutor {
    /// Creates a `MergeExecutor` from already-built parts.
    ///
    /// `upstreams` are the inputs to merge. `chunk_size` and `schema` configure the
    /// `StreamChunkBuilder` that `execute_inner` uses to re-batch incoming chunks.
    // The argument list mirrors the struct fields one-to-one, hence the lint allowance.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        ctx: ActorContextRef,
        fragment_id: FragmentId,
        upstream_fragment_id: FragmentId,
        upstreams: SelectReceivers,
        local_barrier_manager: LocalBarrierManager,
        metrics: Arc<StreamingMetrics>,
        barrier_rx: mpsc::UnboundedReceiver<Barrier>,
        chunk_size: usize,
        schema: Schema,
    ) -> Self {
        Self {
            actor_context: ctx,
            upstreams,
            fragment_id,
            upstream_fragment_id,
            local_barrier_manager,
            metrics,
            barrier_rx,
            chunk_size,
            schema,
        }
    }

    /// Test-only constructor that wraps each receiver in `inputs` as a `LocalInput`.
    ///
    /// - Upstream actor ids are the receivers' indices in `inputs` (0, 1, 2, ...).
    /// - Fragment ids are fixed placeholders (514 for this fragment, 1919 for upstream).
    /// - If `barrier_rx` is `None`, a barrier receiver is obtained by subscribing
    ///   `actor_id` on `local_barrier_manager`.
    #[cfg(test)]
    pub fn for_test(
        actor_id: impl Into<ActorId>,
        inputs: Vec<super::exchange::permit::Receiver>,
        local_barrier_manager: crate::task::LocalBarrierManager,
        schema: Schema,
        chunk_size: usize,
        barrier_rx: Option<mpsc::UnboundedReceiver<Barrier>>,
    ) -> Self {
        let actor_id = actor_id.into();
        use super::exchange::input::LocalInput;
        use crate::executor::exchange::input::ActorInput;

        let barrier_rx =
            barrier_rx.unwrap_or_else(|| local_barrier_manager.subscribe_barrier(actor_id));

        let metrics = StreamingMetrics::unused();
        let actor_ctx = ActorContext::for_test(actor_id);
        let upstream = Self::new_select_receiver(
            inputs
                .into_iter()
                .enumerate()
                .map(|(idx, input)| LocalInput::new(input, ActorId::new(idx as u32)).boxed_input())
                .collect(),
            &metrics,
            &actor_ctx,
        );

        Self::new(
            actor_ctx,
            514.into(),
            1919.into(),
            upstream,
            local_barrier_manager,
            metrics.into(),
            barrier_rx,
            chunk_size,
            schema,
        )
    }

    /// Builds the `SelectReceivers` that multiplexes `upstreams`.
    ///
    /// The `merge_barrier_align_duration` metric (labelled by actor id and fragment id) is
    /// attached only when the metric level is at least `MetricLevel::Debug`; otherwise
    /// `None` is passed.
    pub(crate) fn new_select_receiver(
        upstreams: Vec<BoxedActorInput>,
        metrics: &StreamingMetrics,
        actor_context: &ActorContext,
    ) -> SelectReceivers {
        let merge_barrier_align_duration = if metrics.level >= MetricLevel::Debug {
            Some(
                metrics
                    .merge_barrier_align_duration
                    .with_guarded_label_values(&[
                        &actor_context.id.to_string(),
                        &actor_context.fragment_id.to_string(),
                    ]),
            )
        } else {
            None
        };

        // NOTE(review): the meaning of the second (`None`) argument is defined by
        // `DynamicReceivers::new`, which is not visible in this file.
        SelectReceivers::new(upstreams, None, merge_barrier_align_duration)
    }

    /// Drives the merged upstreams and yields `Message`s.
    ///
    /// - Chunks are re-batched through `BufferChunks` (sized by `chunk_size`).
    /// - Every barrier is passed to `DispatchBarrierBuffer::pop_barrier_with_inputs`, which
    ///   may also return new upstream inputs.
    /// - On an update-merge barrier addressed to this actor, upstreams are added and removed.
    ///   If the upstream fragment changes, all current upstreams are replaced. The input
    ///   metrics are then re-labelled.
    /// - The stream ends right after yielding a barrier for which `is_stop(actor_id)` holds.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    pub(crate) async fn execute_inner(mut self: Box<Self>) {
        let select_all = self.upstreams;
        let select_all = BufferChunks::new(select_all, self.chunk_size, self.schema);
        let actor_id = self.actor_context.id;

        // Re-created below whenever the upstream fragment id changes.
        let mut metrics = self.metrics.new_actor_input_metrics(
            actor_id,
            self.fragment_id,
            self.upstream_fragment_id,
        );

        // Start of the current wait for the next message; reset after every yield.
        let mut start_time = Instant::now();
        pin_mut!(select_all);

        let mut barrier_buffer = DispatchBarrierBuffer::new(
            self.barrier_rx,
            actor_id,
            self.upstream_fragment_id,
            self.local_barrier_manager,
            self.metrics.clone(),
            self.fragment_id,
        );

        loop {
            let msg = barrier_buffer.await_next_message(&mut select_all).await?;
            metrics
                .actor_input_buffer_blocking_duration_ns
                .inc_by(start_time.elapsed().as_nanos() as u64);

            let msg = match msg {
                DispatcherMessage::Watermark(watermark) => Message::Watermark(watermark),
                DispatcherMessage::Chunk(chunk) => {
                    metrics.actor_in_record_cnt.inc_by(chunk.cardinality() as _);
                    Message::Chunk(chunk)
                }
                DispatcherMessage::Barrier(barrier) => {
                    tracing::debug!(
                        target: "events::stream::barrier::path",
                        actor_id = %actor_id,
                        "receiver receives barrier from path: {:?}",
                        barrier.passed_actors
                    );
                    let (mut barrier, new_inputs) =
                        barrier_buffer.pop_barrier_with_inputs(barrier).await?;
                    barrier.passed_actors.push(actor_id);

                    // If any current upstream actor has its dispatchers changed by this
                    // update, flush buffered watermarks (presumably because they may be
                    // stale after the dispatcher change).
                    if let Some(Mutation::Update(UpdateMutation { dispatchers, .. })) =
                        barrier.mutation.as_deref()
                        && select_all
                            .upstream_input_ids()
                            .any(|actor_id| dispatchers.contains_key(&actor_id))
                    {
                        select_all.flush_buffered_watermarks();
                    }

                    if let Some(update) =
                        barrier.as_update_merge(self.actor_context.id, self.upstream_fragment_id)
                    {
                        let new_upstream_fragment_id = update
                            .new_upstream_fragment_id
                            .unwrap_or(self.upstream_fragment_id);
                        // Computed before adding new inputs, so newly added upstreams are never
                        // part of the removal set. A fragment switch removes every current one.
                        let removed_upstream_actor_id: HashSet<_> =
                            if update.new_upstream_fragment_id.is_some() {
                                select_all.upstream_input_ids().collect()
                            } else {
                                update.removed_upstream_actor_id.iter().copied().collect()
                            };

                        // Buffered watermarks are flushed whenever the upstream set changes.
                        select_all.flush_buffered_watermarks();

                        if let Some(new_upstreams) = new_inputs {
                            select_all.add_upstreams_from(new_upstreams);
                        }

                        if !removed_upstream_actor_id.is_empty() {
                            select_all.remove_upstreams(&removed_upstream_actor_id);
                        }

                        self.upstream_fragment_id = new_upstream_fragment_id;
                        // Re-label input metrics with the (possibly new) upstream fragment.
                        metrics = self.metrics.new_actor_input_metrics(
                            actor_id,
                            self.fragment_id,
                            self.upstream_fragment_id,
                        );
                    }

                    let is_stop = barrier.is_stop(actor_id);
                    let msg = Message::Barrier(barrier);
                    if is_stop {
                        // Forward the stop barrier, then terminate the stream.
                        yield msg;
                        break;
                    }

                    msg
                }
            };

            yield msg;
            start_time = Instant::now();
        }
    }
}
334
335impl Execute for MergeExecutor {
336 fn execute(self: Box<Self>) -> BoxedMessageStream {
337 self.execute_inner().boxed()
338 }
339}
340
/// Wrapper that buffers the rows of upstream `StreamChunk`s into a `StreamChunkBuilder`
/// until the upstream has no ready item, then emits them as one chunk.
///
/// Any item that is not an `Ok` chunk (a watermark, barrier or error) first flushes the
/// buffered rows, then is emitted itself, preserving order.
struct BufferChunks<S: Stream> {
    /// The wrapped upstream stream.
    inner: S,
    /// Accumulates rows; yields a full chunk from `append_record` when it reaches capacity.
    chunk_builder: StreamChunkBuilder,

    /// Items ready to be emitted before the upstream is polled again.
    pending_items: VecDeque<S::Item>,
}
351
352impl<S: Stream> BufferChunks<S> {
353 pub(super) fn new(inner: S, chunk_size: usize, schema: Schema) -> Self {
354 assert!(chunk_size > 0);
355 let chunk_builder = StreamChunkBuilder::new(chunk_size, schema.data_types());
356 Self {
357 inner,
358 chunk_builder,
359 pending_items: VecDeque::new(),
360 }
361 }
362}
363
364impl<S: Stream> std::ops::Deref for BufferChunks<S> {
365 type Target = S;
366
367 fn deref(&self) -> &Self::Target {
368 &self.inner
369 }
370}
371
372impl<S: Stream> std::ops::DerefMut for BufferChunks<S> {
373 fn deref_mut(&mut self) -> &mut Self::Target {
374 &mut self.inner
375 }
376}
377
378impl<S: Stream> Stream for BufferChunks<S>
379where
380 S: Stream<Item = DispatcherMessageStreamItem> + Unpin,
381{
382 type Item = S::Item;
383
384 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
385 loop {
386 if let Some(item) = self.pending_items.pop_front() {
387 return Poll::Ready(Some(item));
388 }
389
390 match self.inner.poll_next_unpin(cx) {
391 Poll::Pending => {
392 return if let Some(chunk_out) = self.chunk_builder.take() {
393 Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
394 } else {
395 Poll::Pending
396 };
397 }
398
399 Poll::Ready(Some(result)) => {
400 if let Ok(MessageInner::Chunk(chunk)) = result {
401 for row in chunk.records() {
402 if let Some(chunk_out) = self.chunk_builder.append_record(row) {
403 self.pending_items
404 .push_back(Ok(MessageInner::Chunk(chunk_out)));
405 }
406 }
407 } else {
408 return if let Some(chunk_out) = self.chunk_builder.take() {
409 self.pending_items.push_back(result);
410 Poll::Ready(Some(Ok(MessageInner::Chunk(chunk_out))))
411 } else {
412 Poll::Ready(Some(result))
413 };
414 }
415 }
416
417 Poll::Ready(None) => {
418 unreachable!("Merge should always have upstream inputs");
420 }
421 }
422 }
423 }
424}
425
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    use assert_matches::assert_matches;
    use futures::FutureExt;
    use futures::future::try_join_all;
    use risingwave_common::array::Op;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_pb::task_service::exchange_service_server::{
        ExchangeService, ExchangeServiceServer,
    };
    use risingwave_pb::task_service::{
        GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits,
    };
    use tokio::time::sleep;
    use tokio_stream::wrappers::ReceiverStream;
    use tonic::{Request, Response, Status, Streaming};

    use super::*;
    use crate::executor::exchange::input::{ActorInput, LocalInput, RemoteInput};
    use crate::executor::exchange::permit::channel_for_test;
    use crate::executor::{BarrierInner as Barrier, MessageInner as Message};
    use crate::task::NewOutputRequest;
    use crate::task::barrier_test_utils::LocalBarrierTestEnv;
    use crate::task::test_utils::helper_make_local_actor;

    /// Builds a chunk of `size` `Insert` ops with no columns.
    fn build_test_chunk(size: u64) -> StreamChunk {
        let ops = vec![Op::Insert; size as usize];
        StreamChunk::new(ops, vec![])
    }

    /// Checks how `BufferChunks` coalesces chunks and orders them around watermarks/barriers.
    #[tokio::test]
    async fn test_buffer_chunks() {
        let test_env = LocalBarrierTestEnv::for_test().await;

        let (tx, rx) = channel_for_test();
        let input = LocalInput::new(rx, 1.into()).boxed_input();
        let mut buffer = BufferChunks::new(input, 100, Schema::new(vec![]));

        // A lone chunk is emitted once the upstream has nothing more ready.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 10);
        });

        // Two chunks queued before polling are coalesced into one.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });

        // With nothing buffered, a watermark passes straight through.
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // A watermark flushes the buffered chunks first, then is emitted itself.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(
            Message::Watermark(Watermark {
                col_idx: 0,
                data_type: DataType::Int64,
                val: ScalarImpl::Int64(233),
            })
            .into(),
        )
        .await
        .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
            assert_eq!(watermark.val, ScalarImpl::Int64(233));
        });

        // With nothing buffered, a barrier passes straight through...
        let barrier = Barrier::new_test_barrier(test_epoch(1));
        test_env.inject_barrier(&barrier, [2.into()]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });

        // ...otherwise it is emitted after the flushed chunks.
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        tx.send(Message::Chunk(build_test_chunk(10)).into())
            .await
            .unwrap();
        let barrier = Barrier::new_test_barrier(test_epoch(2));
        test_env.inject_barrier(&barrier, [2.into()]);
        tx.send(Message::Barrier(barrier.clone().into_dispatcher()).into())
            .await
            .unwrap();
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            assert_eq!(chunk.ops().len() as u64, 20);
        });
        assert_matches!(buffer.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(2));
        });
    }

    /// Merges `CHANNEL_NUMBER` local channels that each send, per epoch, either a 10-row chunk
    /// (every 20th epoch index) or a watermark on a rotating column, followed by the barrier.
    /// The stream finishes with a stop barrier.
    #[tokio::test]
    async fn test_merger() {
        const CHANNEL_NUMBER: usize = 10;
        let mut txs = Vec::with_capacity(CHANNEL_NUMBER);
        let mut rxs = Vec::with_capacity(CHANNEL_NUMBER);
        for _i in 0..CHANNEL_NUMBER {
            let (tx, rx) = channel_for_test();
            txs.push(tx);
            rxs.push(rx);
        }
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let actor_id = 233.into();
        let mut handles = Vec::with_capacity(CHANNEL_NUMBER);

        let epochs = (10..1000u64)
            .step_by(10)
            .map(|idx| (idx, test_epoch(idx)))
            .collect_vec();
        // Chain each barrier's prev epoch to the previous barrier's epoch.
        let mut prev_epoch = 0;
        let prev_epoch = &mut prev_epoch;
        let barriers: HashMap<_, _> = epochs
            .iter()
            .map(|(_, epoch)| {
                let barrier = Barrier::with_prev_epoch_for_test(*epoch, *prev_epoch);
                *prev_epoch = *epoch;
                barrier_test_env.inject_barrier(&barrier, [actor_id]);
                (*epoch, barrier)
            })
            .collect();
        let b2 = Barrier::with_prev_epoch_for_test(test_epoch(1000), *prev_epoch)
            .with_mutation(Mutation::Stop(StopMutation::default()));
        barrier_test_env.inject_barrier(&b2, [actor_id]);
        barrier_test_env.flush_all_events().await;

        // One producer task per channel.
        for (tx_id, tx) in txs.into_iter().enumerate() {
            let epochs = epochs.clone();
            let barriers = barriers.clone();
            let b2 = b2.clone();
            let handle = tokio::spawn(async move {
                for (idx, epoch) in epochs {
                    if idx % 20 == 0 {
                        tx.send(Message::Chunk(build_test_chunk(10)).into())
                            .await
                            .unwrap();
                    } else {
                        tx.send(
                            Message::Watermark(Watermark {
                                col_idx: (idx as usize / 20 + tx_id) % CHANNEL_NUMBER,
                                data_type: DataType::Int64,
                                val: ScalarImpl::Int64(idx as i64),
                            })
                            .into(),
                        )
                        .await
                        .unwrap();
                    }
                    tx.send(Message::Barrier(barriers[&epoch].clone().into_dispatcher()).into())
                        .await
                        .unwrap();
                    sleep(Duration::from_millis(1)).await;
                }
                tx.send(Message::Barrier(b2.clone().into_dispatcher()).into())
                    .await
                    .unwrap();
            });
            handles.push(handle);
        }

        let merger = MergeExecutor::for_test(
            actor_id,
            rxs,
            barrier_test_env.local_barrier_manager.clone(),
            Schema::new(vec![]),
            100,
            None,
        );
        let mut merger = merger.boxed().execute();
        for (idx, epoch) in epochs {
            if idx % 20 == 0 {
                // Expect 10 channels x 10 rows = 100 rows, possibly split across chunks.
                let mut count = 0usize;
                while count < 100 {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
                        count += chunk.ops().len();
                    });
                }
                assert_eq!(count, 100);
            } else if idx as usize / 20 >= CHANNEL_NUMBER - 1 {
                // From this point on, CHANNEL_NUMBER watermarks are expected per such epoch,
                // with the value encoded below (presumably the minimum across channels for
                // each column — confirm against the watermark-buffer implementation).
                for _ in 0..CHANNEL_NUMBER {
                    assert_matches!(merger.next().await.unwrap().unwrap(), Message::Watermark(watermark) => {
                        assert_eq!(watermark.val, ScalarImpl::Int64((idx - 20 * (CHANNEL_NUMBER as u64 - 1)) as i64));
                    });
                }
            }
            // Each epoch ends with exactly one merged barrier.
            assert_matches!(merger.next().await.unwrap().unwrap(), Message::Barrier(Barrier{epoch:barrier_epoch,mutation:_,..}) => {
                assert_eq!(barrier_epoch.curr, epoch);
            });
        }
        assert_matches!(
            merger.next().await.unwrap().unwrap(),
            Message::Barrier(Barrier {
                mutation,
                ..
            }) if mutation.as_deref().unwrap().is_stop()
        );

        for handle in handles {
            handle.await.unwrap();
        }
    }

    /// An update-merge barrier that removes upstream `old` and adds `new` is only released
    /// once `new` delivers it too. Afterwards, data flows from `untouched` and `new`.
    #[tokio::test]
    async fn test_configuration_change() {
        let actor_id = 233.into();
        let (untouched, old, new) = (234.into(), 235.into(), 238.into());
        let barrier_test_env = LocalBarrierTestEnv::for_test().await;
        let metrics = Arc::new(StreamingMetrics::unused());

        let (upstream_fragment_id, fragment_id) = (10.into(), 18.into());

        // Initial inputs: `untouched` and `old`.
        let inputs: Vec<_> =
            try_join_all([untouched, old].into_iter().map(async |upstream_actor_id| {
                new_input(
                    &barrier_test_env.local_barrier_manager,
                    metrics.clone(),
                    actor_id,
                    fragment_id,
                    &helper_make_local_actor(upstream_actor_id),
                    upstream_fragment_id,
                )
                .await
            }))
            .await
            .unwrap();

        let merge_updates = maplit::hashmap! {
            (actor_id, upstream_fragment_id) => MergeUpdate {
                actor_id,
                upstream_fragment_id,
                new_upstream_fragment_id: None,
                added_upstream_actors: vec![helper_make_local_actor(new)],
                removed_upstream_actor_id: vec![old],
            }
        };

        let b1 = Barrier::new_test_barrier(test_epoch(1)).with_mutation(Mutation::Update(
            UpdateMutation {
                merges: merge_updates,
                ..Default::default()
            },
        ));
        barrier_test_env.inject_barrier(&b1, [actor_id]);
        barrier_test_env.flush_all_events().await;

        let barrier_rx = barrier_test_env
            .local_barrier_manager
            .subscribe_barrier(actor_id);
        let actor_ctx = ActorContext::for_test(actor_id);
        let upstream = MergeExecutor::new_select_receiver(inputs, &metrics, &actor_ctx);

        let mut merge = MergeExecutor::new(
            actor_ctx,
            fragment_id,
            upstream_fragment_id,
            upstream,
            barrier_test_env.local_barrier_manager.clone(),
            metrics.clone(),
            barrier_rx,
            100,
            Schema::new(vec![]),
        )
        .boxed()
        .execute();

        // Upstream actor id -> sender feeding the merge executor from that upstream.
        let mut txs = HashMap::new();
        macro_rules! send {
            ($actors:expr, $msg:expr) => {
                for actor in $actors {
                    txs.get(&actor).unwrap().send($msg).await.unwrap();
                }
            };
        }

        // Asserts the merge stream has no message ready right now.
        macro_rules! assert_recv_pending {
            () => {
                assert!(
                    merge
                        .next()
                        .now_or_never()
                        .flatten()
                        .transpose()
                        .unwrap()
                        .is_none()
                );
            };
        }
        macro_rules! recv {
            () => {
                merge.next().await.transpose().unwrap()
            };
        }

        // Takes the single pending local output request of each upstream and keeps its sender.
        macro_rules! collect_upstream_tx {
            ($actors:expr) => {
                for upstream_id in $actors {
                    let mut output_requests = barrier_test_env
                        .take_pending_new_output_requests(upstream_id)
                        .await;
                    assert_eq!(output_requests.len(), 1);
                    let (downstream_actor_id, request) = output_requests.pop().unwrap();
                    assert_eq!(downstream_actor_id, actor_id);
                    let NewOutputRequest::Local(tx) = request else {
                        unreachable!()
                    };
                    txs.insert(upstream_id, tx);
                }
            };
        }

        assert_recv_pending!();
        barrier_test_env.flush_all_events().await;

        collect_upstream_tx!([untouched, old]);

        // One row from each of the two upstreams, coalesced into a 2-row chunk.
        send!([untouched, old], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality());
        assert_recv_pending!();

        // The barrier is expected to stay pending until the new upstream `new` delivers it.
        send!(
            [untouched, old],
            Message::Barrier(b1.clone().into_dispatcher()).into()
        );
        assert_recv_pending!();
        collect_upstream_tx!([new]);

        send!([new], Message::Barrier(b1.clone().into_dispatcher()).into());
        recv!().unwrap().as_barrier().unwrap();

        // After the update, data flows from `untouched` and `new`.
        send!([untouched, new], Message::Chunk(build_test_chunk(1)).into());
        assert_eq!(2, recv!().unwrap().as_chunk().unwrap().cardinality());
        assert_recv_pending!();
    }

    /// Fake exchange server that answers `get_stream` with one empty chunk and one barrier.
    struct FakeExchangeService {
        rpc_called: Arc<AtomicBool>,
    }

    fn exchange_client_test_barrier() -> crate::executor::Barrier {
        Barrier::new_test_barrier(test_epoch(1))
    }

    #[async_trait::async_trait]
    impl ExchangeService for FakeExchangeService {
        type GetDataStream = ReceiverStream<std::result::Result<GetDataResponse, Status>>;
        type GetStreamStream = ReceiverStream<std::result::Result<GetStreamResponse, Status>>;

        async fn get_data(
            &self,
            _: Request<GetDataRequest>,
        ) -> std::result::Result<Response<Self::GetDataStream>, Status> {
            unimplemented!()
        }

        async fn get_stream(
            &self,
            _request: Request<Streaming<GetStreamRequest>>,
        ) -> std::result::Result<Response<Self::GetStreamStream>, Status> {
            let (tx, rx) = tokio::sync::mpsc::channel(10);
            self.rpc_called.store(true, Ordering::SeqCst);
            // First message: an empty stream chunk.
            let stream_chunk = StreamChunk::default().to_protobuf();
            tx.send(Ok(GetStreamResponse {
                message: Some(PbStreamMessageBatch {
                    stream_message_batch: Some(
                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::StreamChunk(
                            stream_chunk,
                        ),
                    ),
                }),
                permits: Some(PbPermits::default()),
            }))
            .await
            .unwrap();
            // Second message: a batch holding the test barrier.
            let barrier = exchange_client_test_barrier();
            tx.send(Ok(GetStreamResponse {
                message: Some(PbStreamMessageBatch {
                    stream_message_batch: Some(
                        risingwave_pb::stream_plan::stream_message_batch::StreamMessageBatch::BarrierBatch(
                            BarrierBatch {
                                barriers: vec![barrier.to_protobuf()],
                            },
                        ),
                    ),
                }),
                permits: Some(PbPermits::default()),
            }))
            .await
            .unwrap();
            Ok(Response::new(ReceiverStream::new(rx)))
        }
    }

    /// A `RemoteInput` against the fake server yields the empty chunk, then the barrier.
    #[tokio::test]
    async fn test_stream_exchange_client() {
        let rpc_called = Arc::new(AtomicBool::new(false));
        let server_run = Arc::new(AtomicBool::new(false));
        // NOTE(review): fixed port; may collide with other processes or tests on the host.
        let addr = "127.0.0.1:12348".parse().unwrap();

        // Start the fake server; it shuts down when `shutdown_send` fires.
        let (shutdown_send, shutdown_recv) = tokio::sync::oneshot::channel();
        let exchange_svc = ExchangeServiceServer::new(FakeExchangeService {
            rpc_called: rpc_called.clone(),
        });
        let cp_server_run = server_run.clone();
        let join_handle = tokio::spawn(async move {
            cp_server_run.store(true, Ordering::SeqCst);
            tonic::transport::Server::builder()
                .add_service(exchange_svc)
                .serve_with_shutdown(addr, async move {
                    shutdown_recv.await.unwrap();
                })
                .await
                .unwrap();
        });

        // Timing-based wait for the server to start accepting connections.
        sleep(Duration::from_secs(1)).await;
        assert!(server_run.load(Ordering::SeqCst));

        let test_env = LocalBarrierTestEnv::for_test().await;

        let remote_input = {
            RemoteInput::new(
                &test_env.local_barrier_manager,
                addr.into(),
                (0.into(), 0.into()),
                (0.into(), 0.into()),
                Arc::new(StreamingMetrics::unused()),
            )
            .await
            .unwrap()
        };

        test_env.inject_barrier(&exchange_client_test_barrier(), [remote_input.id()]);

        pin_mut!(remote_input);

        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Chunk(chunk) => {
            let (ops, columns, visibility) = chunk.into_inner();
            assert!(ops.is_empty());
            assert!(columns.is_empty());
            assert!(visibility.is_empty());
        });
        assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Barrier(Barrier { epoch: barrier_epoch, mutation: _, .. }) => {
            assert_eq!(barrier_epoch.curr, test_epoch(1));
        });
        assert!(rpc_called.load(Ordering::SeqCst));

        shutdown_send.send(()).unwrap();
        join_handle.await.unwrap();
    }
}