// risingwave_common/vnode_mapping/vnode_placement.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet, LinkedList, VecDeque};
16use std::ops::BitOrAssign;
17
18use itertools::Itertools;
19use num_integer::Integer;
20use risingwave_common::hash::WorkerSlotId;
21use risingwave_pb::common::WorkerNode;
22
23use crate::bitmap::{Bitmap, BitmapBuilder};
24use crate::hash::{VirtualNode, WorkerSlotMapping};
25
/// Calculate a new vnode mapping, keeping locality and balance on a best effort basis.
/// The strategy is similar to `rebalance_actor_vnode` used in meta node, but is modified to
/// consider `max_parallelism` too.
///
/// Overall procedure:
/// 1. Collect serving worker slots and select up to `serving_parallelism` of them round-robin
///    across workers.
/// 2. Give each selected slot a target share of vnodes (`expected`, plus one extra for the
///    first `remain` slots).
/// 3. Seed assignments from `hint_worker_slot_mapping` to preserve locality; vnodes whose
///    hinted slot was not selected are parked in a temporary slot.
/// 4. Move vnodes from over-assigned slots to under-assigned ones until every balance is 0.
///
/// Returns `None` when no serving worker slot can be selected (no serving workers, or
/// `max_parallelism` is 0); otherwise a mapping covering all `vnode_count` vnodes.
///
/// # Panics
/// Panics if `hint_worker_slot_mapping` is provided with a length different from `vnode_count`.
pub fn place_vnode(
    hint_worker_slot_mapping: Option<&WorkerSlotMapping>,
    workers: &[WorkerNode],
    max_parallelism: Option<usize>,
    vnode_count: usize,
) -> Option<WorkerSlotMapping> {
    if let Some(mapping) = hint_worker_slot_mapping {
        assert_eq!(mapping.len(), vnode_count);
    }

    // Get all serving worker slots from all available workers, grouped by worker id and ordered
    // by worker slot id in each group. Each list element is an iterator over one worker's slots.
    let mut worker_slots: LinkedList<_> = workers
        .iter()
        .filter(|w| w.property.as_ref().is_some_and(|p| p.is_serving))
        .sorted_by_key(|w| w.id)
        .map(|w| (0..w.compute_node_parallelism()).map(|idx| WorkerSlotId::new(w.id, idx)))
        .collect();

    // Set serving parallelism to the minimum of total number of worker slots, specified
    // `max_parallelism` and total number of virtual nodes.
    let serving_parallelism = std::cmp::min(
        worker_slots.iter().map(|slots| slots.len()).sum(),
        std::cmp::min(max_parallelism.unwrap_or(usize::MAX), vnode_count),
    );

    // Select `serving_parallelism` worker slots in a round-robin fashion, to distribute workload
    // evenly among workers. Each pass over the list takes one slot from every worker that still
    // has slots left; workers whose iterator is exhausted are removed via `extract_if`.
    let mut selected_slots = Vec::new();
    while !worker_slots.is_empty() {
        worker_slots
            .extract_if(|slots| {
                if let Some(slot) = slots.next() {
                    selected_slots.push(slot);
                    false
                } else {
                    true
                }
            })
            .for_each(drop);
    }
    // All slots were collected above; discard the surplus beyond `serving_parallelism`.
    selected_slots.drain(serving_parallelism..);
    let selected_slots_set: HashSet<WorkerSlotId> = selected_slots.iter().cloned().collect();
    if selected_slots_set.is_empty() {
        return None;
    }

    // Calculate balance for each selected worker slot. Initially, each worker slot is assigned
    // no vnodes. Thus its negative balance means that many vnodes should be assigned to it later.
    // `is_temp` is a mark for a special temporary worker slot, only to simplify implementation.
    #[derive(Debug)]
    struct Balance {
        // The worker slot this balance belongs to.
        slot: WorkerSlotId,
        // Surplus (positive) or deficit (negative) of vnodes relative to the target share.
        balance: i32,
        // Vnodes currently assigned to this slot.
        builder: BitmapBuilder,
        // Marks the temporary slot that holds not-yet-assigned vnodes.
        is_temp: bool,
    }

    // Each slot targets `expected` vnodes; the first `remain` slots take one extra so the
    // shares sum exactly to `vnode_count`.
    let (expected, mut remain) = vnode_count.div_rem(&selected_slots.len());
    let mut balances: HashMap<WorkerSlotId, Balance> = HashMap::default();

    for slot in &selected_slots {
        let mut balance = Balance {
            slot: *slot,
            balance: -(expected as i32),
            builder: BitmapBuilder::zeroed(vnode_count),
            is_temp: false,
        };

        if remain > 0 {
            balance.balance -= 1;
            remain -= 1;
        }
        balances.insert(*slot, balance);
    }

    // Now to maintain affinity, if a hint has been provided via `hint_worker_slot_mapping`, follow
    // that mapping to adjust balances.
    let mut temp_slot = Balance {
        slot: WorkerSlotId::new(0u32.into(), usize::MAX), /* This id doesn't matter for `temp_slot`. It's distinguishable via `is_temp`. */
        balance: 0,
        builder: BitmapBuilder::zeroed(vnode_count),
        is_temp: true,
    };
    match hint_worker_slot_mapping {
        Some(hint_worker_slot_mapping) => {
            for (vnode, worker_slot) in hint_worker_slot_mapping.iter_with_vnode() {
                let b = if selected_slots_set.contains(&worker_slot) {
                    // Assign vnode to the same worker slot as hint.
                    balances.get_mut(&worker_slot).unwrap()
                } else {
                    // Assign vnode that doesn't belong to any worker slot to `temp_slot`
                    // temporarily. They will be reassigned later.
                    &mut temp_slot
                };

                b.balance += 1;
                b.builder.set(vnode.to_index(), true);
            }
        }
        None => {
            // No hint is provided, assign all vnodes to `temp_slot`.
            for vnode in VirtualNode::all(vnode_count) {
                temp_slot.balance += 1;
                temp_slot.builder.set(vnode.to_index(), true);
            }
        }
    }

    // The final step is to move vnodes from worker slots with positive balance to worker slots
    // with negative balance, until all worker slots are of 0 balance.
    // A double-ended queue with worker slots ordered by balance in descending order is consumed:
    // 1. Peek 2 worker slots from front and back.
    // 2. If any of them is of 0 balance, pop it and go to step 1.
    // 3. Otherwise, move vnodes from front to back.
    let mut balances: VecDeque<_> = balances
        .into_values()
        .chain(std::iter::once(temp_slot))
        .sorted_by_key(|b| b.balance)
        .rev()
        .collect();

    let mut results: HashMap<WorkerSlotId, Bitmap> = HashMap::default();

    while !balances.is_empty() {
        if balances.len() == 1 {
            // All balances sum to 0, so the last remaining one must itself be 0.
            let single = balances.pop_front().unwrap();
            assert_eq!(single.balance, 0);
            if !single.is_temp {
                results.insert(single.slot, single.builder.finish());
            }
            break;
        }
        // Front has the largest surplus, back the largest deficit.
        let mut src = balances.pop_front().unwrap();
        let mut dst = balances.pop_back().unwrap();
        // Move as many vnodes as either side can give/take in one step.
        let n = std::cmp::min(src.balance.abs(), dst.balance.abs());
        let mut moved = 0;
        for idx in 0..vnode_count {
            if moved >= n {
                break;
            }
            if src.builder.is_set(idx) {
                src.builder.set(idx, false);
                assert!(!dst.builder.is_set(idx));
                dst.builder.set(idx, true);
                moved += 1;
            }
        }
        src.balance -= n;
        dst.balance += n;
        // Re-queue any slot that is still unbalanced; finished non-temp slots are finalized.
        if src.balance != 0 {
            balances.push_front(src);
        } else if !src.is_temp {
            results.insert(src.slot, src.builder.finish());
        }

        if dst.balance != 0 {
            balances.push_back(dst);
        } else if !dst.is_temp {
            results.insert(dst.slot, dst.builder.finish());
        }
    }

    // Merge per-slot bitmaps into the final per-worker-slot result.
    let mut worker_result = HashMap::new();

    for (worker_slot, bitmap) in results {
        worker_result
            .entry(worker_slot)
            .or_insert(Bitmap::zeros(vnode_count))
            .bitor_assign(&bitmap);
    }

    Some(WorkerSlotMapping::from_bitmaps(&worker_result))
}
203
#[cfg(test)]
mod tests {

    use risingwave_common::hash::WorkerSlotMapping;
    use risingwave_pb::common::worker_node::Property;
    use risingwave_pb::common::{WorkerNode, WorkerType};

    use crate::hash::VirtualNode;

    /// [`super::place_vnode`] with [`VirtualNode::COUNT_FOR_TEST`] as the vnode count.
    fn place_vnode(
        hint_worker_slot_mapping: Option<&WorkerSlotMapping>,
        workers: &[WorkerNode],
        max_parallelism: Option<usize>,
    ) -> Option<WorkerSlotMapping> {
        super::place_vnode(
            hint_worker_slot_mapping,
            workers,
            max_parallelism,
            VirtualNode::COUNT_FOR_TEST,
        )
    }

    // Exercises a sequence of placements while workers are added/removed and
    // `max_parallelism` changes, checking both the resulting parallelism and
    // how many vnodes keep their previous assignment (locality).
    //
    // Comments of the form `a * b + c -> x * y + z` describe the vnode distribution
    // before and after a step: `a` slots holding `b` vnodes each, with `c` slots
    // holding one extra vnode.
    #[test]
    fn test_place_vnode() {
        assert_eq!(VirtualNode::COUNT_FOR_TEST, 256);

        let serving_property = Property {
            is_serving: true,
            is_streaming: false,
            ..Default::default()
        };

        // Counts how many vnodes map to the same worker slot in both mappings,
        // i.e. the locality score between two placements.
        let count_same_vnode_mapping = |wm1: &WorkerSlotMapping, wm2: &WorkerSlotMapping| {
            assert_eq!(wm1.len(), 256);
            assert_eq!(wm2.len(), 256);
            let mut count: usize = 0;
            for idx in 0..VirtualNode::COUNT_FOR_TEST {
                let vnode = VirtualNode::from_index(idx);
                if wm1.get(vnode) == wm2.get(vnode) {
                    count += 1;
                }
            }
            count
        };

        let mut property = serving_property.clone();
        property.parallelism = 1;
        let worker_1 = WorkerNode {
            id: 1.into(),
            r#type: WorkerType::ComputeNode.into(),
            property: Some(property),
            ..Default::default()
        };

        // `max_parallelism` of 0 selects no slots, so no mapping can be produced.
        assert!(
            place_vnode(None, std::slice::from_ref(&worker_1), Some(0)).is_none(),
            "max_parallelism should >= 0"
        );

        // A single 1-slot worker: all 256 vnodes land on that one slot.
        let re_worker_mapping_2 = place_vnode(None, std::slice::from_ref(&worker_1), None).unwrap();
        assert_eq!(re_worker_mapping_2.iter_unique().count(), 1);

        let mut property = serving_property.clone();
        property.parallelism = 50;
        let worker_2 = WorkerNode {
            id: 2.into(),
            property: Some(property),
            r#type: WorkerType::ComputeNode.into(),
            ..Default::default()
        };

        // Scale out to 51 slots, hinted by the previous 1-slot mapping.
        let re_worker_mapping = place_vnode(
            Some(&re_worker_mapping_2),
            &[worker_1.clone(), worker_2.clone()],
            None,
        )
        .unwrap();

        assert_eq!(re_worker_mapping.iter_unique().count(), 51);
        // 1 * 256 + 0 -> 51 * 5 + 1
        // Worker 1's slot keeps at least its fair share of vnodes.
        let score = count_same_vnode_mapping(&re_worker_mapping_2, &re_worker_mapping);
        assert!(score >= 5);

        let mut property = serving_property;
        property.parallelism = 60;
        let worker_3 = WorkerNode {
            id: 3.into(),
            r#type: WorkerType::ComputeNode.into(),
            property: Some(property),
            ..Default::default()
        };
        // Scale out further to 111 slots.
        let re_pu_mapping_2 = place_vnode(
            Some(&re_worker_mapping),
            &[worker_1.clone(), worker_2.clone(), worker_3.clone()],
            None,
        )
        .unwrap();

        // limited by total pu number
        assert_eq!(re_pu_mapping_2.iter_unique().count(), 111);
        // 51 * 5 + 1 -> 111 * 2 + 34
        let score = count_same_vnode_mapping(&re_pu_mapping_2, &re_worker_mapping);
        assert!(score >= (2 + 50 * 2));
        // Scale in to 50 slots by capping `max_parallelism`.
        let re_pu_mapping = place_vnode(
            Some(&re_pu_mapping_2),
            &[worker_1.clone(), worker_2.clone(), worker_3.clone()],
            Some(50),
        )
        .unwrap();
        // limited by max_parallelism
        assert_eq!(re_pu_mapping.iter_unique().count(), 50);
        // 111 * 2 + 34 -> 50 * 5 + 6
        let score = count_same_vnode_mapping(&re_pu_mapping, &re_pu_mapping_2);
        assert!(score >= 50 * 2);
        // Lift the cap again: back to all 111 slots.
        let re_pu_mapping_2 = place_vnode(
            Some(&re_pu_mapping),
            &[worker_1.clone(), worker_2, worker_3.clone()],
            None,
        )
        .unwrap();
        assert_eq!(re_pu_mapping_2.iter_unique().count(), 111);
        // 50 * 5 + 6 -> 111 * 2 + 34
        let score = count_same_vnode_mapping(&re_pu_mapping_2, &re_pu_mapping);
        assert!(score >= 50 * 2);
        // Remove worker 2: only workers 1 and 3 remain (61 slots).
        let re_pu_mapping =
            place_vnode(Some(&re_pu_mapping_2), &[worker_1, worker_3.clone()], None).unwrap();
        // limited by total pu number
        assert_eq!(re_pu_mapping.iter_unique().count(), 61);
        // 111 * 2 + 34 -> 61 * 4 + 12
        let score = count_same_vnode_mapping(&re_pu_mapping, &re_pu_mapping_2);
        assert!(score >= 61 * 2);
        // No workers at all -> no mapping.
        assert!(place_vnode(Some(&re_pu_mapping), &[], None).is_none());
        // Only worker 3 remains (60 slots).
        let re_pu_mapping = place_vnode(Some(&re_pu_mapping), &[worker_3], None).unwrap();
        assert_eq!(re_pu_mapping.iter_unique().count(), 60);
        assert!(place_vnode(Some(&re_pu_mapping), &[], None).is_none());
    }
}