risingwave_batch/task/
consistent_hash_shuffle_channel.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::{Debug, Formatter};
16use std::sync::Arc;
17
18use anyhow::anyhow;
19use itertools::Itertools;
20use risingwave_common::array::DataChunk;
21use risingwave_common::bitmap::Bitmap;
22use risingwave_common::hash::VirtualNode;
23use risingwave_pb::batch_plan::exchange_info::ConsistentHashInfo;
24use risingwave_pb::batch_plan::*;
25use tokio::sync::mpsc;
26
27use crate::error::BatchError::{Internal, SenderError};
28use crate::error::{BatchError, Result as BatchResult, SharedResult};
29use crate::task::channel::{ChanReceiver, ChanReceiverImpl, ChanSender, ChanSenderImpl};
30use crate::task::data_chunk_in_channel::DataChunkInChannel;
31
32#[derive(Clone)]
33pub struct ConsistentHashShuffleSender {
34    senders: Vec<mpsc::Sender<SharedResult<Option<DataChunkInChannel>>>>,
35    consistent_hash_info: ConsistentHashInfo,
36    output_count: usize,
37}
38
39impl Debug for ConsistentHashShuffleSender {
40    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
41        f.debug_struct("ConsistentHashShuffleSender")
42            .field("consistent_hash_info", &self.consistent_hash_info)
43            .finish()
44    }
45}
46
47pub struct ConsistentHashShuffleReceiver {
48    receiver: mpsc::Receiver<SharedResult<Option<DataChunkInChannel>>>,
49}
50
51fn generate_hash_values(
52    chunk: &DataChunk,
53    consistent_hash_info: &ConsistentHashInfo,
54) -> BatchResult<Vec<usize>> {
55    let vnodes = VirtualNode::compute_chunk(
56        chunk,
57        &consistent_hash_info
58            .key
59            .iter()
60            .map(|idx| *idx as usize)
61            .collect::<Vec<_>>(),
62        consistent_hash_info.vmap.len(),
63    );
64
65    let hash_values = vnodes
66        .iter()
67        .map(|vnode| consistent_hash_info.vmap[vnode.to_index()] as usize)
68        .collect::<Vec<_>>();
69
70    Ok(hash_values)
71}
72
73/// The returned chunks must have cardinality > 0.
74fn generate_new_data_chunks(
75    chunk: &DataChunk,
76    output_count: usize,
77    hash_values: &[usize],
78) -> Vec<DataChunk> {
79    let mut vis_maps = vec![vec![]; output_count];
80    hash_values.iter().for_each(|hash| {
81        for (sink_id, vis_map) in vis_maps.iter_mut().enumerate() {
82            if *hash == sink_id {
83                vis_map.push(true);
84            } else {
85                vis_map.push(false);
86            }
87        }
88    });
89    let mut res = Vec::with_capacity(output_count);
90    for (sink_id, vis_map_vec) in vis_maps.into_iter().enumerate() {
91        let vis_map = Bitmap::from_bool_slice(&vis_map_vec) & chunk.visibility();
92        let new_data_chunk = chunk.with_visibility(vis_map);
93        trace!(
94            "send to sink:{}, cardinality:{}",
95            sink_id,
96            new_data_chunk.cardinality()
97        );
98        res.push(new_data_chunk);
99    }
100    res
101}
102
103impl ChanSender for ConsistentHashShuffleSender {
104    async fn send(&mut self, chunk: DataChunk) -> BatchResult<()> {
105        self.send_chunk(chunk).await
106    }
107
108    async fn close(self, error: Option<Arc<BatchError>>) -> BatchResult<()> {
109        self.send_done(error).await
110    }
111}
112
113impl ConsistentHashShuffleSender {
114    async fn send_chunk(&mut self, chunk: DataChunk) -> BatchResult<()> {
115        let hash_values = generate_hash_values(&chunk, &self.consistent_hash_info)?;
116        let new_data_chunks = generate_new_data_chunks(&chunk, self.output_count, &hash_values);
117
118        for (sink_id, new_data_chunk) in new_data_chunks.into_iter().enumerate() {
119            trace!(
120                "send to sink:{}, cardinality:{}",
121                sink_id,
122                new_data_chunk.cardinality()
123            );
124            // The reason we need to add this filter only in HashShuffleSender is that
125            // `generate_new_data_chunks` may generate an empty chunk.
126            if new_data_chunk.cardinality() > 0 {
127                self.senders[sink_id]
128                    .send(Ok(Some(DataChunkInChannel::new(new_data_chunk))))
129                    .await
130                    .map_err(|_| SenderError)?
131            }
132        }
133        Ok(())
134    }
135
136    async fn send_done(self, error: Option<Arc<BatchError>>) -> BatchResult<()> {
137        for sender in self.senders {
138            sender
139                .send(error.clone().map(Err).unwrap_or(Ok(None)))
140                .await
141                .map_err(|_| SenderError)?
142        }
143
144        Ok(())
145    }
146}
147
148impl ChanReceiver for ConsistentHashShuffleReceiver {
149    async fn recv(&mut self) -> SharedResult<Option<DataChunkInChannel>> {
150        match self.receiver.recv().await {
151            Some(data_chunk) => data_chunk,
152            // Early close should be treated as error.
153            None => Err(Arc::new(Internal(anyhow!("broken hash_shuffle_channel")))),
154        }
155    }
156}
157
158pub fn new_consistent_shuffle_channel(
159    shuffle: &ExchangeInfo,
160    output_channel_size: usize,
161) -> (ChanSenderImpl, Vec<ChanReceiverImpl>) {
162    let consistent_hash_info = match shuffle.distribution {
163        Some(exchange_info::Distribution::ConsistentHashInfo(ref v)) => v.clone(),
164        _ => exchange_info::ConsistentHashInfo::default(),
165    };
166
167    let output_count = consistent_hash_info
168        .vmap
169        .iter()
170        .copied()
171        .sorted()
172        .dedup()
173        .count();
174
175    let mut senders = Vec::with_capacity(output_count);
176    let mut receivers = Vec::with_capacity(output_count);
177    for _ in 0..output_count {
178        let (s, r) = mpsc::channel(output_channel_size);
179        senders.push(s);
180        receivers.push(r);
181    }
182    let channel_sender = ChanSenderImpl::ConsistentHashShuffle(ConsistentHashShuffleSender {
183        senders,
184        consistent_hash_info,
185        output_count,
186    });
187    let channel_receivers = receivers
188        .into_iter()
189        .map(|receiver| {
190            ChanReceiverImpl::ConsistentHashShuffle(ConsistentHashShuffleReceiver { receiver })
191        })
192        .collect::<Vec<_>>();
193    (channel_sender, channel_receivers)
194}