// risingwave_batch/task/broadcast_channel.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::{Debug, Formatter};
use std::sync::Arc;

use anyhow::anyhow;
use risingwave_common::array::DataChunk;
use risingwave_pb::batch_plan::exchange_info::BroadcastInfo;
use risingwave_pb::batch_plan::*;
use tokio::sync::mpsc;

use crate::error::BatchError::{Internal, SenderError};
use crate::error::{BatchError, Result as BatchResult, SharedResult};
use crate::task::channel::{ChanReceiver, ChanReceiverImpl, ChanSender, ChanSenderImpl};
use crate::task::data_chunk_in_channel::DataChunkInChannel;

29/// `BroadcastSender` sends the same chunk to a number of `BroadcastReceiver`s.
30#[derive(Clone)]
31pub struct BroadcastSender {
32    senders: Vec<mpsc::Sender<SharedResult<Option<DataChunkInChannel>>>>,
33    broadcast_info: BroadcastInfo,
34}
35
36impl Debug for BroadcastSender {
37    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
38        f.debug_struct("BroadcastSender")
39            .field("broadcast_info", &self.broadcast_info)
40            .finish()
41    }
42}
43
44impl ChanSender for BroadcastSender {
45    async fn send(&mut self, chunk: DataChunk) -> BatchResult<()> {
46        let broadcast_data_chunk = DataChunkInChannel::new(chunk);
47        for sender in &self.senders {
48            sender
49                .send(Ok(Some(broadcast_data_chunk.clone())))
50                .await
51                .map_err(|_| SenderError)?
52        }
53
54        Ok(())
55    }
56
57    async fn close(self, error: Option<Arc<BatchError>>) -> BatchResult<()> {
58        for sender in self.senders {
59            sender
60                .send(error.clone().map(Err).unwrap_or(Ok(None)))
61                .await
62                .map_err(|_| SenderError)?
63        }
64
65        Ok(())
66    }
67}
68
69/// One or more `BroadcastReceiver`s corresponds to a single `BroadcastReceiver`
70pub struct BroadcastReceiver {
71    receiver: mpsc::Receiver<SharedResult<Option<DataChunkInChannel>>>,
72}
73
74impl ChanReceiver for BroadcastReceiver {
75    async fn recv(&mut self) -> SharedResult<Option<DataChunkInChannel>> {
76        match self.receiver.recv().await {
77            Some(data_chunk) => data_chunk,
78            // Early close should be treated as an error.
79            None => Err(Arc::new(Internal(anyhow!("broken broadcast_channel")))),
80        }
81    }
82}
83
84pub fn new_broadcast_channel(
85    shuffle: &ExchangeInfo,
86    output_channel_size: usize,
87) -> (ChanSenderImpl, Vec<ChanReceiverImpl>) {
88    let broadcast_info = match shuffle.distribution {
89        Some(exchange_info::Distribution::BroadcastInfo(ref v)) => *v,
90        _ => BroadcastInfo::default(),
91    };
92
93    let output_count = broadcast_info.count as usize;
94    let mut senders = Vec::with_capacity(output_count);
95    let mut receivers = Vec::with_capacity(output_count);
96    for _ in 0..output_count {
97        let (s, r) = mpsc::channel(output_channel_size);
98        senders.push(s);
99        receivers.push(r);
100    }
101    let channel_sender = ChanSenderImpl::Broadcast(BroadcastSender {
102        senders,
103        broadcast_info,
104    });
105    let channel_receivers = receivers
106        .into_iter()
107        .map(|receiver| ChanReceiverImpl::Broadcast(BroadcastReceiver { receiver }))
108        .collect::<Vec<_>>();
109    (channel_sender, channel_receivers)
110}