1use std::collections::HashMap;
16use std::iter;
17use std::pin::Pin;
18use std::task::{Context, Poll};
19
20use anyhow::Context as _;
21use futures::future::try_join_all;
22use pin_project::pin_project;
23use risingwave_common::catalog::Field;
24use risingwave_expr::expr::{EvalErrorReport, NonStrictExpression, build_non_strict_from_prost};
25use risingwave_pb::common::PbActorInfo;
26use risingwave_pb::expr::PbExprNode;
27use risingwave_pb::plan_common::PbField;
28use risingwave_pb::stream_service::inject_barrier_request::build_actor_info::UpstreamActors;
29use rw_futures_util::pending_on_none;
30use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
31
32use crate::executor::exchange::input::{Input, assert_equal_dispatcher_barrier, new_input};
33use crate::executor::prelude::*;
34use crate::executor::project::apply_project_exprs;
35use crate::executor::{
36 BarrierMutationType, BoxedMessageInput, DynamicReceivers, MergeExecutor, Message,
37};
38use crate::task::{ActorEvalErrorReport, FragmentId, LocalBarrierManager};
39
40type ProcessedMessageStream = impl Stream<Item = MessageStreamItem>;
41
42#[pin_project]
45pub struct SinkHandlerInput {
46 upstream_fragment_id: FragmentId,
48
49 #[pin]
51 processed_stream: ProcessedMessageStream,
52}
53
54impl SinkHandlerInput {
55 pub fn new(
56 upstream_fragment_id: FragmentId,
57 merge: Box<MergeExecutor>,
58 project_exprs: Vec<NonStrictExpression>,
59 ) -> Self {
60 let processed_stream = Self::apply_project_exprs_stream(merge, project_exprs);
61 Self {
62 upstream_fragment_id,
63 processed_stream,
64 }
65 }
66
67 #[define_opaque(ProcessedMessageStream)]
68 fn apply_project_exprs_stream(
69 merge: Box<MergeExecutor>,
70 project_exprs: Vec<NonStrictExpression>,
71 ) -> ProcessedMessageStream {
72 Self::apply_project_exprs_stream_impl(merge, project_exprs)
74 }
75
76 #[try_stream(ok = Message, error = StreamExecutorError)]
78 async fn apply_project_exprs_stream_impl(
79 merge: Box<MergeExecutor>,
80 project_exprs: Vec<NonStrictExpression>,
81 ) {
82 let merge_stream = merge.execute_inner();
83 pin_mut!(merge_stream);
84 while let Some(msg) = merge_stream.next().await {
85 let msg = msg?;
86 if let Message::Chunk(chunk) = msg {
87 let new_chunk = apply_project_exprs(&project_exprs, chunk).await?;
89 yield Message::Chunk(new_chunk);
90 } else {
91 yield msg;
92 }
93 }
94 }
95}
96
97impl Input for SinkHandlerInput {
98 type InputId = FragmentId;
99
100 fn id(&self) -> Self::InputId {
101 self.upstream_fragment_id
103 }
104}
105
106impl Stream for SinkHandlerInput {
107 type Item = MessageStreamItem;
108
109 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
110 self.project().processed_stream.poll_next(cx)
111 }
112}
113
114#[derive(Debug)]
116pub struct UpstreamFragmentInfo {
117 pub upstream_fragment_id: FragmentId,
118 pub upstream_actors: Vec<PbActorInfo>,
119 pub merge_schema: Schema,
120 pub project_exprs: Vec<NonStrictExpression>,
121}
122
123impl UpstreamFragmentInfo {
124 pub fn new(
125 upstream_fragment_id: FragmentId,
126 initial_upstream_actors: &HashMap<FragmentId, UpstreamActors>,
127 sink_output_schema: &[PbField],
128 project_exprs: &[PbExprNode],
129 error_report: impl EvalErrorReport + 'static,
130 ) -> StreamResult<Self> {
131 let actors = initial_upstream_actors
132 .get(&upstream_fragment_id)
133 .ok_or_else(|| {
134 anyhow::anyhow!(
135 "upstream fragment {} not found in initial upstream actors",
136 upstream_fragment_id
137 )
138 })?;
139 let merge_schema = sink_output_schema.iter().map(Field::from).collect();
140 let project_exprs = project_exprs
141 .iter()
142 .map(|e| build_non_strict_from_prost(e, error_report.clone()))
143 .try_collect()
144 .map_err(|err| anyhow::anyhow!(err))?;
145 Ok(Self {
146 upstream_fragment_id,
147 upstream_actors: actors.actors.clone(),
148 merge_schema,
149 project_exprs,
150 })
151 }
152}
153
154type BoxedSinkInput = BoxedMessageInput<FragmentId, BarrierMutationType>;
155
156pub struct UpstreamSinkUnionExecutor {
168 actor_context: ActorContextRef,
170
171 local_barrier_manager: LocalBarrierManager,
173
174 executor_stats: Arc<StreamingMetrics>,
176
177 chunk_size: usize,
179
180 initial_upstream_infos: Vec<UpstreamFragmentInfo>,
182
183 eval_error_report: ActorEvalErrorReport,
185
186 barrier_rx: UnboundedReceiver<Barrier>,
188
189 barrier_tx_map: HashMap<FragmentId, UnboundedSender<Barrier>>,
191}
192
193impl Debug for UpstreamSinkUnionExecutor {
194 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
195 f.debug_struct("UpstreamSinkUnionExecutor")
196 .field("initial_upstream_infos", &self.initial_upstream_infos)
197 .finish()
198 }
199}
200
201impl Execute for UpstreamSinkUnionExecutor {
202 fn execute(self: Box<Self>) -> BoxedMessageStream {
203 self.execute_inner().boxed()
204 }
205}
206
207impl UpstreamSinkUnionExecutor {
208 pub fn new(
209 ctx: ActorContextRef,
210 local_barrier_manager: LocalBarrierManager,
211 executor_stats: Arc<StreamingMetrics>,
212 chunk_size: usize,
213 initial_upstream_infos: Vec<UpstreamFragmentInfo>,
214 eval_error_report: ActorEvalErrorReport,
215 ) -> Self {
216 let barrier_rx = local_barrier_manager.subscribe_barrier(ctx.id);
217 Self {
218 actor_context: ctx,
219 local_barrier_manager,
220 executor_stats,
221 chunk_size,
222 initial_upstream_infos,
223 eval_error_report,
224 barrier_rx,
225 barrier_tx_map: Default::default(),
226 }
227 }
228
229 #[cfg(test)]
230 pub fn for_test(
231 actor_id: ActorId,
232 local_barrier_manager: LocalBarrierManager,
233 chunk_size: usize,
234 ) -> Self {
235 let metrics = StreamingMetrics::unused();
236 let actor_ctx = ActorContext::for_test(actor_id);
237 let barrier_rx = local_barrier_manager.subscribe_barrier(actor_id);
238 Self {
239 actor_context: actor_ctx.clone(),
240 local_barrier_manager,
241 executor_stats: metrics.into(),
242 chunk_size,
243 initial_upstream_infos: vec![],
244 eval_error_report: ActorEvalErrorReport {
245 actor_context: actor_ctx,
246 identity: format!("UpstreamSinkUnionExecutor-{}", actor_id).into(),
247 },
248 barrier_rx,
249 barrier_tx_map: Default::default(),
250 }
251 }
252
253 fn subscribe_local_barrier(&mut self, fragment_id: FragmentId) -> UnboundedReceiver<Barrier> {
254 let (tx, rx) = unbounded_channel();
255 self.barrier_tx_map
256 .try_insert(fragment_id, tx)
257 .expect("non-duplicate");
258 rx
259 }
260
261 async fn new_sink_input(
262 &mut self,
263 UpstreamFragmentInfo {
264 upstream_fragment_id,
265 upstream_actors,
266 merge_schema,
267 project_exprs,
268 }: UpstreamFragmentInfo,
269 ) -> StreamExecutorResult<BoxedSinkInput> {
270 let merge_executor = self
271 .new_merge_executor(upstream_fragment_id, upstream_actors, merge_schema)
272 .await?;
273
274 Ok(SinkHandlerInput::new(
275 upstream_fragment_id,
276 Box::new(merge_executor),
277 project_exprs,
278 )
279 .boxed_input())
280 }
281
282 async fn new_merge_executor(
283 &mut self,
284 upstream_fragment_id: FragmentId,
285 upstream_actors: Vec<PbActorInfo>,
286 schema: Schema,
287 ) -> StreamExecutorResult<MergeExecutor> {
288 let barrier_rx = self.subscribe_local_barrier(upstream_fragment_id);
289
290 let inputs = try_join_all(upstream_actors.iter().map(|actor| {
291 new_input(
292 &self.local_barrier_manager,
293 self.executor_stats.clone(),
294 self.actor_context.id,
295 self.actor_context.fragment_id,
296 actor,
297 upstream_fragment_id,
298 )
299 }))
300 .await?;
301
302 let upstreams =
303 MergeExecutor::new_select_receiver(inputs, &self.executor_stats, &self.actor_context);
304
305 Ok(MergeExecutor::new(
306 self.actor_context.clone(),
307 self.actor_context.fragment_id,
308 upstream_fragment_id,
309 upstreams,
310 self.local_barrier_manager.clone(),
311 self.executor_stats.clone(),
312 barrier_rx,
313 self.chunk_size,
314 schema,
315 ))
316 }
317
318 #[try_stream(ok = Message, error = StreamExecutorError)]
319 async fn execute_inner(mut self: Box<Self>) {
320 let inputs: Vec<_> = {
321 let initial_upstream_infos = std::mem::take(&mut self.initial_upstream_infos);
322 let mut inputs = Vec::with_capacity(initial_upstream_infos.len());
323 for UpstreamFragmentInfo {
324 upstream_fragment_id,
325 upstream_actors,
326 merge_schema,
327 project_exprs,
328 } in initial_upstream_infos
329 {
330 let merge_executor = self
331 .new_merge_executor(upstream_fragment_id, upstream_actors, merge_schema)
332 .await?;
333
334 let input = SinkHandlerInput::new(
335 upstream_fragment_id,
336 Box::new(merge_executor),
337 project_exprs,
338 )
339 .boxed_input();
340
341 inputs.push(input);
342 }
343 inputs
344 };
345
346 let execution_stream = self.execute_with_inputs(inputs);
347 pin_mut!(execution_stream);
348 while let Some(msg) = execution_stream.next().await {
349 yield msg?;
350 }
351 }
352
353 async fn handle_update(
354 &mut self,
355 upstreams: &mut DynamicReceivers<FragmentId, BarrierMutationType>,
356 barrier: &Barrier,
357 ) -> StreamExecutorResult<()> {
358 let fragment_id = self.actor_context.fragment_id;
359 if let Some(new_upstream_sink) = barrier.as_new_upstream_sink(fragment_id) {
360 let info = new_upstream_sink.get_info().unwrap();
362 let merge_schema = info
363 .get_sink_output_schema()
364 .iter()
365 .map(Field::from)
366 .collect();
367 let project_exprs = info
368 .get_project_exprs()
369 .iter()
370 .map(|e| build_non_strict_from_prost(e, self.eval_error_report.clone()))
371 .try_collect()
372 .map_err(|err| anyhow::anyhow!(err))?;
373 let mut new_input = self
374 .new_sink_input(UpstreamFragmentInfo {
375 upstream_fragment_id: info.get_upstream_fragment_id(),
376 upstream_actors: new_upstream_sink.get_upstream_actors().clone(),
377 merge_schema,
378 project_exprs,
379 })
380 .await?;
381 self.barrier_tx_map
382 .get(&info.get_upstream_fragment_id())
383 .unwrap()
384 .send(barrier.clone())
385 .map_err(|e| StreamExecutorError::from(anyhow::anyhow!(e)))?;
386
387 let new_barrier = expect_first_barrier(&mut new_input).await?;
388 assert_equal_dispatcher_barrier(barrier, &new_barrier);
389
390 upstreams.add_upstreams_from(iter::once(new_input));
391 }
392
393 if let Some(dropped_upstream_sinks) = barrier.as_dropped_upstream_sinks()
394 && !dropped_upstream_sinks.is_empty()
395 {
396 upstreams.remove_upstreams(dropped_upstream_sinks);
398 for upstream_fragment_id in dropped_upstream_sinks {
399 self.barrier_tx_map.remove(upstream_fragment_id);
400 }
401 }
402
403 Ok(())
404 }
405
406 #[try_stream(ok = Message, error = StreamExecutorError)]
407 async fn execute_with_inputs(mut self: Box<Self>, inputs: Vec<BoxedSinkInput>) {
408 let actor_id = self.actor_context.id;
409 let fragment_id = self.actor_context.fragment_id;
410
411 let barrier_align = self
412 .executor_stats
413 .barrier_align_duration
414 .with_guarded_label_values(&[
415 actor_id.to_string().as_str(),
416 fragment_id.to_string().as_str(),
417 "",
418 "UpstreamSinkUnion",
419 ]);
420
421 let upstreams = DynamicReceivers::new(inputs, Some(barrier_align.clone()), None);
422 pin_mut!(upstreams);
423
424 let mut current_barrier = None;
425
426 let mut select_once = async || -> StreamExecutorResult<Message> {
429 loop {
430 tokio::select! {
431 biased;
432
433 msg = pending_on_none(upstreams.next()) => {
436 let msg = msg?;
437 if let Message::Barrier(barrier) = &msg {
438 let current_barrier = current_barrier.take().unwrap();
439 assert_equal_dispatcher_barrier(¤t_barrier, barrier);
440 self.handle_update(&mut upstreams, barrier).await?;
441 }
442 return Ok(msg);
443 }
444
445 barrier = self.barrier_rx.recv(), if current_barrier.is_none() => {
446 let barrier = barrier.context("Failed to receive barrier from barrier_rx")?;
447 if upstreams.is_empty() {
451 self.handle_update(&mut upstreams, &barrier).await?;
452 return Ok(Message::Barrier(barrier.clone()));
453 } else {
454 for tx in self.barrier_tx_map.values() {
455 tx.send(barrier.clone())
456 .map_err(|e| StreamExecutorError::from(anyhow::anyhow!(e)))?;
457 }
458 current_barrier = Some(barrier);
459 continue;
460 }
461 }
462 }
463 }
464 };
465
466 loop {
467 yield select_once().await?;
468 }
469 }
470}
471
472#[cfg(test)]
473mod tests {
474 use std::collections::HashSet;
475
476 use futures::FutureExt;
477 use risingwave_common::array::{Op, StreamChunkTestExt};
478 use risingwave_common::catalog::Field;
479 use risingwave_common::util::epoch::test_epoch;
480 use risingwave_pb::stream_plan::PbUpstreamSinkInfo;
481 use risingwave_pb::stream_plan::add_mutation::PbNewUpstreamSink;
482
483 use super::*;
484 use crate::executor::exchange::permit::{Sender, channel_for_test};
485 use crate::executor::test_utils::expr::build_from_pretty;
486 use crate::executor::{AddMutation, MessageInner, StopMutation};
487 use crate::task::NewOutputRequest;
488 use crate::task::barrier_test_utils::LocalBarrierTestEnv;
489 use crate::task::test_utils::helper_make_local_actor;
490
491 #[tokio::test]
492 async fn test_sink_input() {
493 let test_env = LocalBarrierTestEnv::for_test().await;
494
495 let actor_id = 2;
496
497 let b1 = Barrier::with_prev_epoch_for_test(2, 1);
498
499 test_env.inject_barrier(&b1, [actor_id]);
500 test_env.flush_all_events().await;
501
502 let schema = Schema {
503 fields: vec![
504 Field::unnamed(DataType::Int64),
505 Field::unnamed(DataType::Int64),
506 ],
507 };
508
509 let (tx1, rx1) = channel_for_test();
510 let (tx2, rx2) = channel_for_test();
511
512 let merge = MergeExecutor::for_test(
513 actor_id,
514 vec![rx1, rx2],
515 test_env.local_barrier_manager.clone(),
516 schema.clone(),
517 5,
518 None,
519 );
520
521 let test_expr = build_from_pretty("$1:int8");
522
523 let mut input = SinkHandlerInput::new(
524 1919, Box::new(merge),
526 vec![test_expr],
527 )
528 .boxed_input();
529
530 let chunk1 = StreamChunk::from_pretty(
531 " I I
532 + 1 4
533 + 2 5
534 + 3 6",
535 );
536 let chunk2 = StreamChunk::from_pretty(
537 " I I
538 + 7 8
539 - 3 6",
540 );
541
542 tx1.send(MessageInner::Chunk(chunk1).into()).await.unwrap();
543 tx2.send(MessageInner::Chunk(chunk2).into()).await.unwrap();
544
545 let msg = input.next().await.unwrap().unwrap();
546 assert_eq!(
547 *msg.as_chunk().unwrap(),
548 StreamChunk::from_pretty(
549 " I
550 + 4
551 + 5
552 + 6
553 + 8
554 - 6"
555 )
556 );
557 }
558
559 fn new_input_for_test(
560 actor_id: ActorId,
561 local_barrier_manager: LocalBarrierManager,
562 ) -> (BoxedSinkInput, Sender, UnboundedSender<Barrier>) {
563 let (tx, rx) = channel_for_test();
564 let (barrier_tx, barrier_rx) = unbounded_channel();
565 let merge = MergeExecutor::for_test(
566 actor_id,
567 vec![rx],
568 local_barrier_manager,
569 Schema::new(vec![]),
570 10,
571 Some(barrier_rx),
572 );
573 let input = SinkHandlerInput::new(actor_id, Box::new(merge), vec![]).boxed_input();
574 (input, tx, barrier_tx)
575 }
576
577 fn build_test_chunk(size: u64) -> StreamChunk {
578 let ops = vec![Op::Insert; size as usize];
579 StreamChunk::new(ops, vec![])
580 }
581
582 #[tokio::test]
583 async fn test_fixed_upstreams() {
584 let test_env = LocalBarrierTestEnv::for_test().await;
585
586 let actor_id = 2;
587
588 let b1 = Barrier::with_prev_epoch_for_test(2, 1);
589
590 test_env.inject_barrier(&b1, [actor_id]);
591 test_env.flush_all_events().await;
592
593 let mut inputs = Vec::with_capacity(3);
594 let mut txs = Vec::with_capacity(3);
595 let mut barrier_txs = Vec::with_capacity(3);
596 for _ in 0..3 {
597 let (input, tx, barrier_tx) =
598 new_input_for_test(actor_id, test_env.local_barrier_manager.clone());
599 inputs.push(input);
600 txs.push(tx);
601 barrier_txs.push(barrier_tx);
602 }
603
604 let sink_union = UpstreamSinkUnionExecutor::for_test(
605 actor_id,
606 test_env.local_barrier_manager.clone(),
607 10,
608 );
609 test_env.flush_all_events().await;
611 let mut sink_union = Box::new(sink_union).execute_with_inputs(inputs).boxed();
612
613 for tx in txs {
614 tx.send(MessageInner::Chunk(build_test_chunk(10)).into())
615 .await
616 .unwrap();
617 tx.send(MessageInner::Chunk(build_test_chunk(10)).into())
618 .await
619 .unwrap();
620 tx.send(MessageInner::Barrier(b1.clone().into_dispatcher()).into())
621 .await
622 .unwrap();
623 }
624
625 for _ in 0..6 {
626 let msg = sink_union.next().await.unwrap().unwrap();
627 assert!(msg.is_chunk());
628 assert_eq!(msg.as_chunk().unwrap().ops().len(), 10);
629 }
630
631 assert!(sink_union.next().now_or_never().is_none());
633
634 for barrier_tx in barrier_txs {
635 barrier_tx.send(b1.clone()).unwrap();
636 }
637
638 let msg = sink_union.next().await.unwrap().unwrap();
639 assert!(msg.is_barrier());
640 let barrier = msg.as_barrier().unwrap();
641 assert_eq!(barrier.epoch, b1.epoch);
642 }
643
644 #[tokio::test]
645 async fn test_dynamic_upstreams() {
646 let test_env = LocalBarrierTestEnv::for_test().await;
647
648 let actor_id = 2;
649 let fragment_id = 0; let upstream_fragment_id = 11;
651 let upstream_actor_id = 101;
652
653 let upstream_actor = helper_make_local_actor(upstream_actor_id);
654
655 let add_upstream = PbNewUpstreamSink {
656 info: Some(PbUpstreamSinkInfo {
657 upstream_fragment_id,
658 sink_output_schema: vec![],
659 project_exprs: vec![],
660 }),
661 upstream_actors: vec![upstream_actor],
662 };
663
664 let b1 = Barrier::new_test_barrier(test_epoch(1));
665 let b2 =
666 Barrier::new_test_barrier(test_epoch(2)).with_mutation(Mutation::Add(AddMutation {
667 new_upstream_sinks: HashMap::from([(fragment_id, add_upstream)]),
668 ..Default::default()
669 }));
670 let b3 = Barrier::new_test_barrier(test_epoch(3));
671 let b4 =
672 Barrier::new_test_barrier(test_epoch(4)).with_mutation(Mutation::Stop(StopMutation {
673 dropped_sink_fragments: HashSet::from([upstream_fragment_id]),
674 ..Default::default()
675 }));
676 for barrier in [&b1, &b2, &b3, &b4] {
677 test_env.inject_barrier(barrier, [actor_id]);
678 }
679 test_env.flush_all_events().await;
680
681 let executor = UpstreamSinkUnionExecutor::for_test(
682 actor_id,
683 test_env.local_barrier_manager.clone(),
684 10,
685 );
686 test_env.flush_all_events().await;
688
689 let mut exec_stream = Box::new(executor).execute_inner().boxed();
691 let msg = exec_stream.next().await.unwrap().unwrap();
692 assert_eq!(msg.as_barrier().unwrap().epoch, b1.epoch);
693
694 assert!(exec_stream.next().now_or_never().is_none());
697
698 let mut output_req = test_env
699 .take_pending_new_output_requests(upstream_actor_id)
700 .await;
701 let (_, req) = output_req.pop().unwrap();
702 let tx = match req {
703 NewOutputRequest::Local(tx) => tx,
704 NewOutputRequest::Remote(_) => unreachable!(),
705 };
706
707 tx.send(MessageInner::Barrier(b2.clone().into_dispatcher()).into())
708 .await
709 .unwrap();
710 let msg = exec_stream.next().await.unwrap().unwrap();
712 assert_eq!(msg.as_barrier().unwrap().epoch, b2.epoch);
713
714 tx.send(MessageInner::Chunk(build_test_chunk(10)).into())
715 .await
716 .unwrap();
717 let msg = exec_stream.next().await.unwrap().unwrap();
718 assert!(msg.is_chunk());
719
720 tx.send(MessageInner::Barrier(b3.clone().into_dispatcher()).into())
721 .await
722 .unwrap();
723 let msg = exec_stream.next().await.unwrap().unwrap();
724 assert_eq!(msg.as_barrier().unwrap().epoch, b3.epoch);
725
726 tx.send(MessageInner::Barrier(b4.clone().into_dispatcher()).into())
728 .await
729 .unwrap();
730 let msg = exec_stream.next().await.unwrap().unwrap();
732 assert_eq!(msg.as_barrier().unwrap().epoch, b4.epoch);
733 }
734}