risingwave_batch/task/
broadcast_channel.rs1use std::fmt::{Debug, Formatter};
16use std::sync::Arc;
17
18use anyhow::anyhow;
19use risingwave_common::array::DataChunk;
20use risingwave_pb::batch_plan::exchange_info::BroadcastInfo;
21use risingwave_pb::batch_plan::*;
22use tokio::sync::mpsc;
23
24use crate::error::BatchError::{Internal, SenderError};
25use crate::error::{BatchError, Result as BatchResult, SharedResult};
26use crate::task::channel::{ChanReceiver, ChanReceiverImpl, ChanSender, ChanSenderImpl};
27use crate::task::data_chunk_in_channel::DataChunkInChannel;
28
29#[derive(Clone)]
31pub struct BroadcastSender {
32 senders: Vec<mpsc::Sender<SharedResult<Option<DataChunkInChannel>>>>,
33 broadcast_info: BroadcastInfo,
34}
35
36impl Debug for BroadcastSender {
37 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
38 f.debug_struct("BroadcastSender")
39 .field("broadcast_info", &self.broadcast_info)
40 .finish()
41 }
42}
43
44impl ChanSender for BroadcastSender {
45 async fn send(&mut self, chunk: DataChunk) -> BatchResult<()> {
46 let broadcast_data_chunk = DataChunkInChannel::new(chunk);
47 for sender in &self.senders {
48 sender
49 .send(Ok(Some(broadcast_data_chunk.clone())))
50 .await
51 .map_err(|_| SenderError)?
52 }
53
54 Ok(())
55 }
56
57 async fn close(self, error: Option<Arc<BatchError>>) -> BatchResult<()> {
58 for sender in self.senders {
59 sender
60 .send(error.clone().map(Err).unwrap_or(Ok(None)))
61 .await
62 .map_err(|_| SenderError)?
63 }
64
65 Ok(())
66 }
67}
68
69pub struct BroadcastReceiver {
71 receiver: mpsc::Receiver<SharedResult<Option<DataChunkInChannel>>>,
72}
73
74impl ChanReceiver for BroadcastReceiver {
75 async fn recv(&mut self) -> SharedResult<Option<DataChunkInChannel>> {
76 match self.receiver.recv().await {
77 Some(data_chunk) => data_chunk,
78 None => Err(Arc::new(Internal(anyhow!("broken broadcast_channel")))),
80 }
81 }
82}
83
84pub fn new_broadcast_channel(
85 shuffle: &ExchangeInfo,
86 output_channel_size: usize,
87) -> (ChanSenderImpl, Vec<ChanReceiverImpl>) {
88 let broadcast_info = match shuffle.distribution {
89 Some(exchange_info::Distribution::BroadcastInfo(ref v)) => *v,
90 _ => BroadcastInfo::default(),
91 };
92
93 let output_count = broadcast_info.count as usize;
94 let mut senders = Vec::with_capacity(output_count);
95 let mut receivers = Vec::with_capacity(output_count);
96 for _ in 0..output_count {
97 let (s, r) = mpsc::channel(output_channel_size);
98 senders.push(s);
99 receivers.push(r);
100 }
101 let channel_sender = ChanSenderImpl::Broadcast(BroadcastSender {
102 senders,
103 broadcast_info,
104 });
105 let channel_receivers = receivers
106 .into_iter()
107 .map(|receiver| ChanReceiverImpl::Broadcast(BroadcastReceiver { receiver }))
108 .collect::<Vec<_>>();
109 (channel_sender, channel_receivers)
110}