// Path: risingwave_common/vnode_mapping/vnode_placement.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet, LinkedList, VecDeque};
16use std::ops::BitOrAssign;
17
18use itertools::Itertools;
19use num_integer::Integer;
20use risingwave_common::hash::WorkerSlotId;
21use risingwave_pb::common::WorkerNode;
22
23use crate::bitmap::{Bitmap, BitmapBuilder};
24use crate::hash::{VirtualNode, WorkerSlotMapping};
25
/// Calculate a new vnode mapping, keeping locality and balance on a best effort basis.
/// The strategy is similar to `rebalance_actor_vnode` used in meta node, but is modified to
/// consider `max_parallelism` too.
///
/// # Arguments
/// * `hint_worker_slot_mapping` - previous mapping to preserve affinity with; if provided, its
///   length must equal `vnode_count` (asserted below).
/// * `workers` - candidate workers; only those whose property marks them as serving are used.
/// * `max_parallelism` - optional upper bound on the number of worker slots actually used.
/// * `vnode_count` - total number of virtual nodes to distribute.
///
/// Returns `None` when no serving worker slot can be selected (no serving workers, or an
/// effective parallelism of 0); otherwise a mapping covering all `vnode_count` vnodes.
pub fn place_vnode(
    hint_worker_slot_mapping: Option<&WorkerSlotMapping>,
    workers: &[WorkerNode],
    max_parallelism: Option<usize>,
    vnode_count: usize,
) -> Option<WorkerSlotMapping> {
    // The hint must describe exactly the same vnode space we are placing.
    if let Some(mapping) = hint_worker_slot_mapping {
        assert_eq!(mapping.len(), vnode_count);
    }

    // Get all serving worker slots from all available workers, grouped by worker id and ordered
    // by worker slot id in each group. Each element of the list is a per-worker iterator that
    // lazily yields that worker's slot ids.
    let mut worker_slots: LinkedList<_> = workers
        .iter()
        .filter(|w| w.property.as_ref().is_some_and(|p| p.is_serving))
        .sorted_by_key(|w| w.id)
        .map(|w| (0..w.compute_node_parallelism()).map(|idx| WorkerSlotId::new(w.id, idx)))
        .collect();

    // Set serving parallelism to the minimum of total number of worker slots, specified
    // `max_parallelism` and total number of virtual nodes.
    let serving_parallelism = std::cmp::min(
        worker_slots.iter().map(|slots| slots.len()).sum(),
        std::cmp::min(max_parallelism.unwrap_or(usize::MAX), vnode_count),
    );

    // Select `serving_parallelism` worker slots in a round-robin fashion, to distribute workload
    // evenly among workers. Each pass of `extract_if` takes one slot from every remaining worker
    // (in worker-id order) and removes workers whose slots are exhausted; the closure's side
    // effect pushes the taken slot, and `.for_each(drop)` forces the lazy iterator to run.
    let mut selected_slots = Vec::new();
    while !worker_slots.is_empty() {
        worker_slots
            .extract_if(|slots| {
                if let Some(slot) = slots.next() {
                    selected_slots.push(slot);
                    false
                } else {
                    true
                }
            })
            .for_each(drop);
    }
    // Keep only the first `serving_parallelism` slots (i.e. truncate the round-robin order).
    selected_slots.drain(serving_parallelism..);
    let selected_slots_set: HashSet<WorkerSlotId> = selected_slots.iter().cloned().collect();
    if selected_slots_set.is_empty() {
        return None;
    }

    // Calculate balance for each selected worker slot. Initially, each worker slot is assigned
    // no vnodes. Thus its negative balance means that many vnodes should be assigned to it later.
    // `is_temp` is a mark for a special temporary worker slot, only to simplify implementation.
    #[derive(Debug)]
    struct Balance {
        // The worker slot this balance entry belongs to.
        slot: WorkerSlotId,
        // Surplus (> 0) or deficit (< 0) of vnodes relative to the slot's target share.
        balance: i32,
        // Bitmap of vnodes currently assigned to this slot.
        builder: BitmapBuilder,
        // Marks the temporary holding slot; its vnodes are never emitted in the result.
        is_temp: bool,
    }

    // Each slot targets `expected` vnodes; the first `remain` slots take one extra so that
    // targets sum exactly to `vnode_count`.
    let (expected, mut remain) = vnode_count.div_rem(&selected_slots.len());
    let mut balances: HashMap<WorkerSlotId, Balance> = HashMap::default();

    for slot in &selected_slots {
        let mut balance = Balance {
            slot: *slot,
            balance: -(expected as i32),
            builder: BitmapBuilder::zeroed(vnode_count),
            is_temp: false,
        };

        if remain > 0 {
            // This slot's target share is one vnode larger, so its deficit is one deeper.
            balance.balance -= 1;
            remain -= 1;
        }
        balances.insert(*slot, balance);
    }

    // Now to maintain affinity, if a hint has been provided via `hint_worker_slot_mapping`, follow
    // that mapping to adjust balances.
    let mut temp_slot = Balance {
        slot: WorkerSlotId::new(0u32, usize::MAX), /* This id doesn't matter for `temp_slot`. It's distinguishable via `is_temp`. */
        balance: 0,
        builder: BitmapBuilder::zeroed(vnode_count),
        is_temp: true,
    };
    match hint_worker_slot_mapping {
        Some(hint_worker_slot_mapping) => {
            for (vnode, worker_slot) in hint_worker_slot_mapping.iter_with_vnode() {
                let b = if selected_slots_set.contains(&worker_slot) {
                    // Assign vnode to the same worker slot as hint.
                    balances.get_mut(&worker_slot).unwrap()
                } else {
                    // Assign vnode that doesn't belong to any worker slot to `temp_slot`
                    // temporarily. They will be reassigned later.
                    &mut temp_slot
                };

                b.balance += 1;
                b.builder.set(vnode.to_index(), true);
            }
        }
        None => {
            // No hint is provided, assign all vnodes to `temp_slot`.
            for vnode in VirtualNode::all(vnode_count) {
                temp_slot.balance += 1;
                temp_slot.builder.set(vnode.to_index(), true);
            }
        }
    }

    // The final step is to move vnodes from worker slots with positive balance to worker slots
    // with negative balance, until all worker slots are of 0 balance.
    // A double-ended queue with worker slots ordered by balance in descending order is consumed:
    // 1. Peek 2 worker slots from front and back.
    // 2. If any of them is of 0 balance, pop it and go to step 1.
    // 3. Otherwise, move vnodes from front to back.
    let mut balances: VecDeque<_> = balances
        .into_values()
        .chain(std::iter::once(temp_slot))
        .sorted_by_key(|b| b.balance)
        .rev()
        .collect();

    let mut results: HashMap<WorkerSlotId, Bitmap> = HashMap::default();

    while !balances.is_empty() {
        if balances.len() == 1 {
            // Balances always sum to 0, so the last remaining entry must itself be 0.
            let single = balances.pop_front().unwrap();
            assert_eq!(single.balance, 0);
            if !single.is_temp {
                results.insert(single.slot, single.builder.finish());
            }
            break;
        }
        // Front has the largest surplus, back the deepest deficit; move as many vnodes as
        // either side can absorb in this round.
        let mut src = balances.pop_front().unwrap();
        let mut dst = balances.pop_back().unwrap();
        let n = std::cmp::min(src.balance.abs(), dst.balance.abs());
        let mut moved = 0;
        for idx in 0..vnode_count {
            if moved >= n {
                break;
            }
            if src.builder.is_set(idx) {
                src.builder.set(idx, false);
                // A vnode is owned by exactly one slot at a time, so `dst` can't already have it.
                assert!(!dst.builder.is_set(idx));
                dst.builder.set(idx, true);
                moved += 1;
            }
        }
        src.balance -= n;
        dst.balance += n;
        // Re-queue entries that still have a surplus/deficit; finished non-temp entries are
        // emitted into `results`, while the temp slot's bitmap is discarded once drained.
        if src.balance != 0 {
            balances.push_front(src);
        } else if !src.is_temp {
            results.insert(src.slot, src.builder.finish());
        }

        if dst.balance != 0 {
            balances.push_back(dst);
        } else if !dst.is_temp {
            results.insert(dst.slot, dst.builder.finish());
        }
    }

    // Merge the per-slot bitmaps into the final per-slot result (OR-ing defensively in case a
    // slot appears more than once).
    let mut worker_result = HashMap::new();

    for (worker_slot, bitmap) in results {
        worker_result
            .entry(worker_slot)
            .or_insert(Bitmap::zeros(vnode_count))
            .bitor_assign(&bitmap);
    }

    Some(WorkerSlotMapping::from_bitmaps(&worker_result))
}
203
#[cfg(test)]
mod tests {

    use risingwave_common::hash::WorkerSlotMapping;
    use risingwave_pb::common::worker_node::Property;
    use risingwave_pb::common::{WorkerNode, WorkerType};

    use crate::hash::VirtualNode;

    /// [`super::place_vnode`] with [`VirtualNode::COUNT_FOR_TEST`] as the vnode count.
    fn place_vnode(
        hint_worker_slot_mapping: Option<&WorkerSlotMapping>,
        workers: &[WorkerNode],
        max_parallelism: Option<usize>,
    ) -> Option<WorkerSlotMapping> {
        super::place_vnode(
            hint_worker_slot_mapping,
            workers,
            max_parallelism,
            VirtualNode::COUNT_FOR_TEST,
        )
    }

    #[test]
    fn test_place_vnode() {
        assert_eq!(VirtualNode::COUNT_FOR_TEST, 256);

        let serving_property = Property {
            is_unschedulable: false,
            is_serving: true,
            is_streaming: false,
            ..Default::default()
        };

        // Counts how many vnodes are mapped to the same worker slot in both mappings,
        // i.e. how much affinity a re-placement preserved.
        let count_same_vnode_mapping = |wm1: &WorkerSlotMapping, wm2: &WorkerSlotMapping| {
            assert_eq!(wm1.len(), 256);
            assert_eq!(wm2.len(), 256);
            let mut count: usize = 0;
            for idx in 0..VirtualNode::COUNT_FOR_TEST {
                let vnode = VirtualNode::from_index(idx);
                if wm1.get(vnode) == wm2.get(vnode) {
                    count += 1;
                }
            }
            count
        };

        let mut property = serving_property.clone();
        property.parallelism = 1;
        let worker_1 = WorkerNode {
            id: 1,
            r#type: WorkerType::ComputeNode.into(),
            property: Some(property),
            ..Default::default()
        };

        // A `max_parallelism` of 0 selects no worker slot, so no mapping can be produced.
        assert!(
            place_vnode(None, &[worker_1.clone()], Some(0)).is_none(),
            "place_vnode should return None when max_parallelism is 0"
        );

        let re_worker_mapping_2 = place_vnode(None, &[worker_1.clone()], None).unwrap();
        assert_eq!(re_worker_mapping_2.iter_unique().count(), 1);

        let mut property = serving_property.clone();
        property.parallelism = 50;
        let worker_2 = WorkerNode {
            id: 2,
            property: Some(property),
            r#type: WorkerType::ComputeNode.into(),
            ..Default::default()
        };

        let re_worker_mapping = place_vnode(
            Some(&re_worker_mapping_2),
            &[worker_1.clone(), worker_2.clone()],
            None,
        )
        .unwrap();

        assert_eq!(re_worker_mapping.iter_unique().count(), 51);
        // 1 * 256 + 0 -> 51 * 5 + 1
        let score = count_same_vnode_mapping(&re_worker_mapping_2, &re_worker_mapping);
        assert!(score >= 5);

        let mut property = serving_property.clone();
        property.parallelism = 60;
        let worker_3 = WorkerNode {
            id: 3,
            r#type: WorkerType::ComputeNode.into(),
            property: Some(property),
            ..Default::default()
        };
        let re_pu_mapping_2 = place_vnode(
            Some(&re_worker_mapping),
            &[worker_1.clone(), worker_2.clone(), worker_3.clone()],
            None,
        )
        .unwrap();

        // limited by total pu number
        assert_eq!(re_pu_mapping_2.iter_unique().count(), 111);
        // 51 * 5 + 1 -> 111 * 2 + 34
        let score = count_same_vnode_mapping(&re_pu_mapping_2, &re_worker_mapping);
        assert!(score >= (2 + 50 * 2));
        let re_pu_mapping = place_vnode(
            Some(&re_pu_mapping_2),
            &[worker_1.clone(), worker_2.clone(), worker_3.clone()],
            Some(50),
        )
        .unwrap();
        // limited by max_parallelism
        assert_eq!(re_pu_mapping.iter_unique().count(), 50);
        // 111 * 2 + 34 -> 50 * 5 + 6
        let score = count_same_vnode_mapping(&re_pu_mapping, &re_pu_mapping_2);
        assert!(score >= 50 * 2);
        let re_pu_mapping_2 = place_vnode(
            Some(&re_pu_mapping),
            &[worker_1.clone(), worker_2, worker_3.clone()],
            None,
        )
        .unwrap();
        assert_eq!(re_pu_mapping_2.iter_unique().count(), 111);
        // 50 * 5 + 6 -> 111 * 2 + 34
        let score = count_same_vnode_mapping(&re_pu_mapping_2, &re_pu_mapping);
        assert!(score >= 50 * 2);
        let re_pu_mapping =
            place_vnode(Some(&re_pu_mapping_2), &[worker_1, worker_3.clone()], None).unwrap();
        // limited by total pu number
        assert_eq!(re_pu_mapping.iter_unique().count(), 61);
        // 111 * 2 + 34 -> 61 * 4 + 12
        let score = count_same_vnode_mapping(&re_pu_mapping, &re_pu_mapping_2);
        assert!(score >= 61 * 2);
        // No serving workers at all -> no mapping.
        assert!(place_vnode(Some(&re_pu_mapping), &[], None).is_none());
        let re_pu_mapping = place_vnode(Some(&re_pu_mapping), &[worker_3], None).unwrap();
        assert_eq!(re_pu_mapping.iter_unique().count(), 60);
        assert!(place_vnode(Some(&re_pu_mapping), &[], None).is_none());
    }
}