risingwave_common/hash/consistent_hash/
mapping.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeSet, HashMap};
16use std::fmt::{Debug, Display, Formatter};
17use std::hash::Hash;
18use std::ops::Index;
19
20use educe::Educe;
21use itertools::Itertools;
22use risingwave_pb::common::PbWorkerSlotMapping;
23use risingwave_pb::stream_plan::ActorMapping as ActorMappingProto;
24
25use super::bitmap::VnodeBitmapExt;
26use crate::bitmap::{Bitmap, BitmapBuilder};
27use crate::hash::VirtualNode;
28use crate::util::compress::compress_data;
29use crate::util::iter_util::ZipEqDebug;
30
// TODO: find a better place for this.
/// Identifier of an actor; the item type selected by [`marker::Actor`] for [`ActorMapping`].
pub type ActorId = u32;
33
/// Identifier of a slot on a worker, packed into a single `u64`.
///
/// The upper 32 bits hold the worker id and the lower 32 bits hold the index of the slot within
/// that worker.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub struct WorkerSlotId(u64);

impl WorkerSlotId {
    /// Returns the worker id, i.e. the upper 32 bits of the packed value.
    pub fn worker_id(&self) -> u32 {
        (self.0 >> 32) as u32
    }

    /// Returns the slot index, i.e. the lower 32 bits of the packed value.
    pub fn slot_idx(&self) -> u32 {
        // The cast intentionally keeps only the low 32 bits.
        self.0 as u32
    }

    /// Packs `worker_id` and `slot_idx` into a new [`WorkerSlotId`].
    ///
    /// # Panics
    ///
    /// Panics if `slot_idx` does not fit in 32 bits. Without this check, the excess bits would
    /// be OR-ed into the worker id part and silently produce an id for a different worker.
    pub fn new(worker_id: u32, slot_idx: usize) -> Self {
        assert!(
            slot_idx as u64 <= u64::from(u32::MAX),
            "slot index {} does not fit in 32 bits",
            slot_idx
        );
        Self((u64::from(worker_id) << 32) | slot_idx as u64)
    }
}
50
51impl From<WorkerSlotId> for u64 {
52    fn from(id: WorkerSlotId) -> Self {
53        id.0
54    }
55}
56
57impl From<u64> for WorkerSlotId {
58    fn from(id: u64) -> Self {
59        Self(id)
60    }
61}
62
63impl Display for WorkerSlotId {
64    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
65        f.write_fmt(format_args!("[{}:{}]", self.worker_id(), self.slot_idx()))
66    }
67}
68
69impl Debug for WorkerSlotId {
70    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
71        f.write_fmt(format_args!("[{}:{}]", self.worker_id(), self.slot_idx()))
72    }
73}
74
/// Trait for marker types that select the item type stored in a [`VnodeMapping`].
pub trait VnodeMappingItem {
    /// The type of the item that virtual nodes are mapped to.
    ///
    /// Currently, there are two types of items: [`WorkerSlotId`] and [`ActorId`]. They are not
    /// used directly as the generic parameter of [`VnodeMapping`]; a marker type implementing
    /// this trait (see [`marker`]) is used instead.
    // NOTE(review): the original rationale was that both item types were aliases of the same
    // type. `WorkerSlotId` is a distinct newtype in this file, so that rationale may be
    // historical — confirm before relying on it.
    type Item: Copy + Ord + Hash + Debug;
}
83
/// Expanded mapping from virtual nodes to items: essentially a vector of items that can be
/// indexed by virtual nodes.
pub type ExpandedMapping<T> = Vec<<T as VnodeMappingItem>::Item>;
87
/// Generic mapping from virtual nodes to items.
///
/// The representation is compressed as described in [`compress_data`], which is optimized for the
/// mapping with a small number of items and good locality.
#[derive(Educe)]
#[educe(Debug, Clone, PartialEq, Eq, Hash)]
pub struct VnodeMapping<T: VnodeMappingItem> {
    /// For each run of consecutive vnodes mapped to the same item, the index of the last vnode
    /// of that run (inclusive). Lookups binary-search this vector, so it is expected to be in
    /// ascending order; the last entry plus one is the total vnode count.
    original_indices: Vec<u32>,
    /// The item of each run: `data[i]` is the item of the run ending at `original_indices[i]`.
    data: Vec<T::Item>,
}
98
99#[expect(
100    clippy::len_without_is_empty,
101    reason = "empty vnode mapping makes no sense"
102)]
103impl<T: VnodeMappingItem> VnodeMapping<T> {
104    /// Create a uniform vnode mapping with a **set** of items.
105    ///
106    /// For example, if `items` is `[0, 1, 2]`, and the total vnode count is 10, we'll generate
107    /// mapping like `[0, 0, 0, 0, 1, 1, 1, 2, 2, 2]`.
108    pub fn new_uniform(items: impl ExactSizeIterator<Item = T::Item>, vnode_count: usize) -> Self {
109        // If the number of items is greater than the total vnode count, no vnode will be mapped to
110        // some items and the mapping will be invalid.
111        assert!(items.len() <= vnode_count);
112
113        let mut original_indices = Vec::with_capacity(items.len());
114        let mut data = Vec::with_capacity(items.len());
115
116        let hash_shard_size = vnode_count / items.len();
117        let mut one_more_count = vnode_count % items.len();
118        let mut init_bound = 0;
119
120        for item in items {
121            let count = if one_more_count > 0 {
122                one_more_count -= 1;
123                hash_shard_size + 1
124            } else {
125                hash_shard_size
126            };
127            init_bound += count;
128
129            original_indices.push(init_bound as u32 - 1);
130            data.push(item);
131        }
132
133        // Assert that there's no duplicated items.
134        debug_assert_eq!(data.iter().duplicates().count(), 0);
135
136        Self {
137            original_indices,
138            data,
139        }
140    }
141
142    /// Create a vnode mapping with the single item and length of 1.
143    ///
144    /// Should only be used for singletons. If you want a different vnode count, call
145    /// [`VnodeMapping::new_uniform`] with `std::iter::once(item)` and desired length.
146    pub fn new_single(item: T::Item) -> Self {
147        Self::new_uniform(std::iter::once(item), 1)
148    }
149
150    /// The length (or count) of the vnode in this mapping.
151    pub fn len(&self) -> usize {
152        self.original_indices
153            .last()
154            .map(|&i| i as usize + 1)
155            .unwrap_or(0)
156    }
157
158    /// Get the item mapped to the given `vnode` by binary search.
159    ///
160    /// Note: to achieve better mapping performance, one should convert the mapping to the
161    /// [`ExpandedMapping`] first and directly access the item by index.
162    pub fn get(&self, vnode: VirtualNode) -> T::Item {
163        self[vnode]
164    }
165
166    /// Get the item matched by the virtual nodes indicated by high bits in the given `bitmap`.
167    /// Returns `None` if the no virtual node is set in the bitmap.
168    pub fn get_matched(&self, bitmap: &Bitmap) -> Option<T::Item> {
169        bitmap
170            .iter_vnodes()
171            .next() // only need to check the first one
172            .map(|v| self.get(v))
173    }
174
175    /// Iterate over all items in this mapping, in the order of vnodes.
176    pub fn iter(&self) -> impl Iterator<Item = T::Item> + '_ {
177        self.data
178            .iter()
179            .copied()
180            .zip_eq_debug(
181                std::iter::once(0)
182                    .chain(self.original_indices.iter().copied().map(|i| i + 1))
183                    .tuple_windows()
184                    .map(|(a, b)| (b - a) as usize),
185            )
186            .flat_map(|(item, c)| std::iter::repeat_n(item, c))
187    }
188
189    /// Iterate over all vnode-item pairs in this mapping.
190    pub fn iter_with_vnode(&self) -> impl Iterator<Item = (VirtualNode, T::Item)> + '_ {
191        self.iter()
192            .enumerate()
193            .map(|(v, item)| (VirtualNode::from_index(v), item))
194    }
195
196    /// Iterate over all unique items in this mapping. The order is deterministic.
197    pub fn iter_unique(&self) -> impl Iterator<Item = T::Item> + '_ {
198        // Note: we can't ensure there's no duplicated items in the `data` after some scaling.
199        self.data.iter().copied().sorted().dedup()
200    }
201
202    /// Returns the item if it's the only item in this mapping, otherwise returns `None`.
203    pub fn to_single(&self) -> Option<T::Item> {
204        self.data.iter().copied().dedup().exactly_one().ok()
205    }
206
207    /// Convert this vnode mapping to a mapping from items to bitmaps, where each bitmap represents
208    /// the vnodes mapped to the item.
209    pub fn to_bitmaps(&self) -> HashMap<T::Item, Bitmap> {
210        let vnode_count = self.len();
211        let mut vnode_bitmaps = HashMap::new();
212
213        for (vnode, item) in self.iter_with_vnode() {
214            vnode_bitmaps
215                .entry(item)
216                .or_insert_with(|| BitmapBuilder::zeroed(vnode_count))
217                .set(vnode.to_index(), true);
218        }
219
220        vnode_bitmaps
221            .into_iter()
222            .map(|(item, b)| (item, b.finish()))
223            .collect()
224    }
225
226    /// Create a vnode mapping from the given mapping from items to bitmaps, where each bitmap
227    /// represents the vnodes mapped to the item.
228    pub fn from_bitmaps(bitmaps: &HashMap<T::Item, Bitmap>) -> Self {
229        let vnode_count = bitmaps.values().next().expect("empty bitmaps").len();
230        let mut items = vec![None; vnode_count];
231
232        for (&item, bitmap) in bitmaps {
233            assert_eq!(bitmap.len(), vnode_count);
234            for idx in bitmap.iter_ones() {
235                if let Some(prev) = items[idx].replace(item) {
236                    panic!("mapping at index `{idx}` is set to both `{prev:?}` and `{item:?}`");
237                }
238            }
239        }
240
241        let items = items
242            .into_iter()
243            .enumerate()
244            .map(|(i, o)| o.unwrap_or_else(|| panic!("mapping at index `{i}` is not set")))
245            .collect_vec();
246        Self::from_expanded(&items)
247    }
248
249    /// Create a vnode mapping from the expanded slice of items.
250    pub fn from_expanded(items: &[T::Item]) -> Self {
251        let (original_indices, data) = compress_data(items);
252        Self {
253            original_indices,
254            data,
255        }
256    }
257
258    /// Convert this vnode mapping to a expanded vector of items.
259    pub fn to_expanded(&self) -> ExpandedMapping<T> {
260        self.iter().collect()
261    }
262
263    /// Transform this vnode mapping to another type of vnode mapping, with the given mapping from
264    /// items of this mapping to items of the other mapping.
265    pub fn transform<T2, M>(&self, to_map: &M) -> VnodeMapping<T2>
266    where
267        T2: VnodeMappingItem,
268        M: for<'a> Index<&'a T::Item, Output = T2::Item>,
269    {
270        VnodeMapping {
271            original_indices: self.original_indices.clone(),
272            data: self.data.iter().map(|item| to_map[item]).collect(),
273        }
274    }
275}
276
277impl<T: VnodeMappingItem> Index<VirtualNode> for VnodeMapping<T> {
278    type Output = T::Item;
279
280    fn index(&self, vnode: VirtualNode) -> &Self::Output {
281        let index = self
282            .original_indices
283            .partition_point(|&i| i < vnode.to_index() as u32);
284        &self.data[index]
285    }
286}
287
/// Marker types used as the generic parameter of [`VnodeMapping`] to select its item type.
pub mod marker {
    use super::*;

    /// A marker type for items of [`ActorId`].
    pub struct Actor;
    impl VnodeMappingItem for Actor {
        type Item = ActorId;
    }

    /// A marker type for items of [`WorkerSlotId`].
    pub struct WorkerSlot;
    impl VnodeMappingItem for WorkerSlot {
        type Item = WorkerSlotId;
    }
}
303
/// A mapping from [`VirtualNode`] to [`ActorId`], stored in the compressed [`VnodeMapping`] form.
pub type ActorMapping = VnodeMapping<marker::Actor>;
/// An expanded mapping from [`VirtualNode`] to [`ActorId`]: a vector indexed by vnode.
pub type ExpandedActorMapping = ExpandedMapping<marker::Actor>;

/// A mapping from [`VirtualNode`] to [`WorkerSlotId`], stored in the compressed [`VnodeMapping`]
/// form.
pub type WorkerSlotMapping = VnodeMapping<marker::WorkerSlot>;
/// An expanded mapping from [`VirtualNode`] to [`WorkerSlotId`]: a vector indexed by vnode.
pub type ExpandedWorkerSlotMapping = ExpandedMapping<marker::WorkerSlot>;
313
314impl ActorMapping {
315    /// Transform the actor mapping to the worker slot mapping. Note that the parameter is a mapping from actor to worker.
316    pub fn to_worker_slot(&self, actor_to_worker: &HashMap<ActorId, u32>) -> WorkerSlotMapping {
317        let mut worker_actors = HashMap::new();
318        for actor_id in self.iter_unique() {
319            let worker_id = actor_to_worker
320                .get(&actor_id)
321                .cloned()
322                .unwrap_or_else(|| panic!("location for actor {} not found", actor_id));
323
324            worker_actors
325                .entry(worker_id)
326                .or_insert(BTreeSet::new())
327                .insert(actor_id);
328        }
329
330        let mut actor_location = HashMap::new();
331        for (worker, actors) in worker_actors {
332            for (idx, &actor) in actors.iter().enumerate() {
333                actor_location.insert(actor, WorkerSlotId::new(worker, idx));
334            }
335        }
336
337        self.transform(&actor_location)
338    }
339
340    /// Create an actor mapping from the protobuf representation.
341    pub fn from_protobuf(proto: &ActorMappingProto) -> Self {
342        assert_eq!(proto.original_indices.len(), proto.data.len());
343        Self {
344            original_indices: proto.original_indices.clone(),
345            data: proto.data.clone(),
346        }
347    }
348
349    /// Convert this actor mapping to the protobuf representation.
350    pub fn to_protobuf(&self) -> ActorMappingProto {
351        ActorMappingProto {
352            original_indices: self.original_indices.clone(),
353            data: self.data.clone(),
354        }
355    }
356}
357
358impl WorkerSlotMapping {
359    /// Create a uniform worker mapping from the given worker ids
360    pub fn build_from_ids(worker_slot_ids: &[WorkerSlotId], vnode_count: usize) -> Self {
361        Self::new_uniform(worker_slot_ids.iter().cloned(), vnode_count)
362    }
363
364    /// Create a worker mapping from the protobuf representation.
365    pub fn from_protobuf(proto: &PbWorkerSlotMapping) -> Self {
366        assert_eq!(proto.original_indices.len(), proto.data.len());
367        Self {
368            original_indices: proto.original_indices.clone(),
369            data: proto.data.iter().map(|&id| WorkerSlotId(id)).collect(),
370        }
371    }
372
373    /// Convert this worker mapping to the protobuf representation.
374    pub fn to_protobuf(&self) -> PbWorkerSlotMapping {
375        PbWorkerSlotMapping {
376            original_indices: self.original_indices.clone(),
377            data: self.data.iter().map(|id| id.0).collect(),
378        }
379    }
380}
381
382impl WorkerSlotMapping {
383    /// Transform this worker slot mapping to an actor mapping, essentially `transform`.
384    pub fn to_actor(&self, to_map: &HashMap<WorkerSlotId, ActorId>) -> ActorMapping {
385        self.transform(to_map)
386    }
387}
388
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use rand::Rng;

    use super::*;

    /// Marker for the primary mapping type under test.
    struct Test;
    impl VnodeMappingItem for Test {
        type Item = u32;
    }

    /// Marker for the target mapping type of the `transform` test.
    struct Test2;
    impl VnodeMappingItem for Test2 {
        type Item = u32;
    }

    type TestMapping = VnodeMapping<Test>;
    type Test2Mapping = VnodeMapping<Test2>;

    /// Numbers of distinct items to build the test mappings with.
    const COUNTS: &[usize] = &[1, 3, 12, 42, VirtualNode::COUNT_FOR_TEST];

    /// One uniform mapping per entry of `COUNTS`.
    fn uniforms() -> impl Iterator<Item = TestMapping> {
        COUNTS.iter().map(|&item_count| {
            let items = 0..item_count as u32;
            TestMapping::new_uniform(items, VirtualNode::COUNT_FOR_TEST)
        })
    }

    /// One mapping with randomly assigned items per entry of `COUNTS`.
    fn randoms() -> impl Iterator<Item = TestMapping> {
        COUNTS.iter().map(|&item_count| {
            let expanded = repeat_with(|| rand::rng().random_range(0..item_count as u32))
                .take(VirtualNode::COUNT_FOR_TEST)
                .collect_vec();
            TestMapping::from_expanded(&expanded)
        })
    }

    /// All uniform mappings followed by all random mappings.
    fn mappings() -> impl Iterator<Item = TestMapping> {
        uniforms().chain(randoms())
    }

    #[test]
    fn test_uniform() {
        for mapping in uniforms() {
            assert_eq!(mapping.len(), VirtualNode::COUNT_FOR_TEST);
            let item_count = mapping.iter_unique().count();

            // Collect the vnodes that each item owns.
            let mut vnodes_by_item: HashMap<u32, Vec<_>> = HashMap::new();
            for (vnode, item) in mapping.iter_with_vnode() {
                vnodes_by_item.entry(item).or_default().push(vnode);
            }

            assert_eq!(vnodes_by_item.len(), item_count);

            // The numbers of vnodes owned by any two items must differ by at most one.
            let group_sizes = vnodes_by_item.values().map(|vnodes| vnodes.len());
            let (min, max) = group_sizes.minmax().into_option().unwrap();

            assert!(max - min <= 1);
        }
    }

    #[test]
    fn test_iter_with_get() {
        for mapping in mappings() {
            for (vnode, item) in mapping.iter_with_vnode() {
                assert_eq!(mapping.get(vnode), item);
            }
        }
    }

    #[test]
    fn test_from_to_bitmaps() {
        for mapping in mappings() {
            let bitmaps = mapping.to_bitmaps();
            let rebuilt = TestMapping::from_bitmaps(&bitmaps);

            assert_eq!(mapping, rebuilt);
        }
    }

    #[test]
    fn test_transform() {
        for mapping in mappings() {
            // Shift every item by one ...
            let forward: HashMap<_, _> = mapping
                .iter_unique()
                .map(|item| (item, item + 1))
                .collect();
            let shifted: Test2Mapping = mapping.transform(&forward);

            for (item, shifted_item) in mapping.iter().zip_eq_debug(shifted.iter()) {
                assert_eq!(item + 1, shifted_item);
            }

            // ... and shift back again, which must yield the original mapping.
            let backward: HashMap<_, _> = forward.into_iter().map(|(k, v)| (v, k)).collect();
            let restored: TestMapping = shifted.transform(&backward);

            assert_eq!(mapping, restored);
        }
    }
}