risingwave_meta/stream/stream_graph/
assignment.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16use std::fmt::Debug;
17use std::hash::{Hash, Hasher};
18use std::num::NonZeroUsize;
19
20use anyhow::{Context, Result, anyhow, ensure};
21
22/// Assign items to weighted containers with optional capacity scaling and deterministic tie-breaking.
23///
24/// Distributes a slice of items (`&[I]`) across a set of containers (`BTreeMap<C, NonZeroUsize>`)
25/// using a three-phase algorithm:
26///
27/// # Type Parameters
28/// - `C`: Container identifier. Must implement `Ord + Hash + Eq + Clone + Debug`.
29/// - `I`: Item type. Must implement `Hash + Eq + Copy + Debug`.
30/// - `S`: Salt type for tie-breaking. Must implement `Hash + Copy`.
31///
32/// # Parameters
33/// - `containers`: Map of containers to their non-zero weights (`BTreeMap<C, NonZeroUsize>`).
34/// - `items`: Slice of items (`&[I]`) to distribute.
35/// - `salt`: A salt value to vary deterministic tie-breaks between equal remainders.
36/// - `capacity_scale_factor_fn`: Callback `(containers, items) -> Option<ScaleFactor>`:
37///     - `Some(f)`: Scale each container’s base quota by `f` (ceiled, but never below base).
38///     - `None`: Remove upper bound (capacity = `usize::MAX`).
39///
40/// # Returns
41/// A `BTreeMap<C, Vec<I>>` mapping each container to the list of assigned items.
42/// - If either `containers` or `items` is empty, return an empty map.
43///
44/// # Panics
45/// - If the sum of all container weights is zero.
46/// - If, during weighted rendezvous, no eligible container remains (invariant violation).
47///
48/// # Complexity
49/// Runs in **O(N · M)** time, where N = `containers.len()` and M = `items.len()`.
50/// Each item is compared against all containers via a weighted rendezvous hash.
51///
52/// # Example
53/// ```rust
54/// # use std::collections::BTreeMap;
55/// # use std::num::NonZeroUsize;
56/// # use risingwave_meta::stream::{assign_items_weighted_with_scale_fn, weighted_scale};
57///
58/// let mut caps = BTreeMap::new();
59/// caps.insert("fast", NonZeroUsize::new(3).unwrap());
60/// caps.insert("slow", NonZeroUsize::new(1).unwrap());
61///
62/// let tasks = vec!["task1", "task2", "task3", "task4"];
63/// let result = assign_items_weighted_with_scale_fn(&caps, &tasks, 0u8, weighted_scale);
64///
65/// // `fast` should receive roughly 3 tasks, `slow` roughly 1
66/// assert_eq!(result.values().map(Vec::len).sum::<usize>(), tasks.len());
67/// ```
68///
69/// # Algorithm
70///
71/// 1. **Quota Calculation**
72///    - Compute `total_weight = sum(w_i)` as `u128`.
73///    - For each container `i` with weight `w_i`:
74///      ```text
75///      ideal_i     = M * w_i
76///      base_quota_i = floor(ideal_i / total_weight)
77///      rem_i        = ideal_i % total_weight
78///      ```
79///    - Let `rem_count = M - sum(base_quota_i)` and sort containers by `rem_i` (desc),
80///      breaking ties by `stable_hash((container, salt))`.
81///    - Give `+1` slot to the first `rem_count` containers.
82///
83/// 2. **Capacity Scaling**
84///    - If `Some(f)`: For each container,
85///      `quota_i = max(base_quota_i, ceil(base_quota_i as f64 * f))`.
86///    - If `None`: Set `quota_i = usize::MAX`.
87///
88/// 3. **Weighted Rendezvous Assignment**
89///    - For each item `x`, compute for each container `i`:
90///      ```text
91///      h = stable_hash((x, i, salt))
92///      r = (h + 1) / (MAX_HASH + 2)       // 0 < r ≤ 1
93///      key_i = -ln(r) / weight_i
94///      ```
95///    - Assign `x` to the container with the smallest `key_i`.
96pub fn assign_items_weighted_with_scale_fn<C, I, S>(
97    containers: &BTreeMap<C, NonZeroUsize>,
98    items: &[I],
99    salt: S,
100    capacity_scale_factor_fn: impl Fn(&BTreeMap<C, NonZeroUsize>, &[I]) -> Option<ScaleFactor>,
101) -> BTreeMap<C, Vec<I>>
102where
103    C: Ord + Hash + Eq + Clone + Debug,
104    I: Hash + Eq + Copy + Clone + Debug,
105    S: Hash + Copy,
106{
107    // Early exit if there is nothing to assign
108    if containers.is_empty() || items.is_empty() {
109        return BTreeMap::default();
110    }
111
112    // Integer-based quota calculation
113    let total_weight: u128 = containers.values().map(|w| w.get() as u128).sum();
114    assert!(
115        total_weight > 0,
116        "Sum of container weights must be non-zero"
117    );
118
119    struct QuotaInfo<'a, C> {
120        container: &'a C,
121        quota: usize,
122        rem_part: u128,
123    }
124
125    let mut infos: Vec<QuotaInfo<'_, C>> = containers
126        .iter()
127        .map(|(container, &weight)| {
128            // Use saturating multiplication to prevent overflow, even though saturation is highly unlikely in practice.
129            let ideal_num = (items.len() as u128).saturating_mul(weight.get() as u128);
130            QuotaInfo {
131                container,
132                quota: (ideal_num / total_weight) as usize,
133                rem_part: ideal_num % total_weight,
134            }
135        })
136        .collect();
137
138    let used: usize = infos.iter().map(|info| info.quota).sum();
139    let remainder = items.len().saturating_sub(used);
140
141    // Distribute remainder slots
142    infos.sort_by(|a, b| {
143        b.rem_part
144            .cmp(&a.rem_part)
145            .then_with(|| stable_hash(&(b.container, salt)).cmp(&stable_hash(&(a.container, salt))))
146    });
147    for info in infos.iter_mut().take(remainder) {
148        info.quota += 1;
149    }
150
151    // Apply capacity scaling
152    let scale_factor = capacity_scale_factor_fn(containers, items);
153    let quotas: HashMap<&C, usize> = infos
154        .into_iter()
155        .map(|info| match scale_factor {
156            Some(f) => {
157                let scaled_f64 = (info.quota as f64 * f.get()).ceil();
158                let scaled = if scaled_f64 >= usize::MAX as f64 {
159                    usize::MAX
160                } else {
161                    scaled_f64 as usize
162                };
163                (info.container, scaled.max(info.quota))
164            }
165            None => (info.container, usize::MAX),
166        })
167        .collect();
168
169    // Prepare assignment map
170    let mut assignment: BTreeMap<C, Vec<I>> = BTreeMap::new();
171
172    // Assign each item using Weighted Rendezvous
173    for &item in items {
174        let mut best: Option<(&C, f64)> = None;
175        for (container, &weight) in containers {
176            let assigned = assignment.get(container).map(Vec::len).unwrap_or(0);
177
178            debug_assert!(quotas.contains_key(container));
179            let quota = quotas.get(container).copied().unwrap_or(0);
180            if assigned >= quota {
181                continue;
182            }
183
184            // Generate a pseudorandom float `r` in the range (0, 1]:
185            // 1. Compute a stable 64-bit hash for the tuple (item, container).
186            // 2. Normalize: `(raw_hash + 1) / (MAX_HASH + 2)` ensures `0 < r <= 1`.
187            let raw_hash = stable_hash(&(item, container, salt));
188            let r = (raw_hash as f64 + 1.0) / (u64::MAX as f64 + 2.0);
189
190            // Compute weighted rendezvous key:
191            // 1. `-ln(r)` maps the interval (0,1] to [0, ∞).
192            // 2. Dividing by `w` biases selection towards containers with higher weight,
193            //    as smaller keys win in the rendezvous algorithm.
194            let key = -r.ln() / (weight.get() as f64);
195
196            match best {
197                None => best = Some((container, key)),
198                Some((_, best_key)) if key < best_key => best = Some((container, key)),
199                _ => {}
200            }
201        }
202
203        // quotas sum (possibly scaled) always >= items.len(), so best is always Some
204        let (winner, _) = best.expect("Invariant violation: no eligible container");
205        assignment
206            .entry(winner.clone())
207            .and_modify(|v| v.push(item))
208            .or_insert_with(|| vec![item]);
209    }
210
211    assignment
212}
213
214/// Stable hash utility
215fn stable_hash<T: Hash>(t: &T) -> u64 {
216    let mut hasher = twox_hash::XxHash64::with_seed(0);
217    t.hash(&mut hasher);
218    hasher.finish()
219}
220
/// A capacity scale factor that is guaranteed to be finite and not less than zero.
///
/// Obtain one through [`ScaleFactor::new`], which rejects NaN, infinities, and values below
/// zero.
#[derive(Clone, Copy, Debug)]
pub struct ScaleFactor(f64);

impl ScaleFactor {
    /// Validates `value` and wraps it.
    ///
    /// Returns `None` when `value` is NaN, infinite, or less than zero.
    pub fn new(value: f64) -> Option<Self> {
        let acceptable = value.is_finite() && value >= 0.0;
        acceptable.then_some(Self(value))
    }

    /// Returns the wrapped `f64`.
    pub fn get(&self) -> f64 {
        let Self(inner) = *self;
        inner
    }
}
242
243/// A no-op capacity scaling function: always returns `None`.
244pub fn unbounded_scale<C, I>(
245    _containers: &BTreeMap<C, NonZeroUsize>,
246    _items: &[I],
247) -> Option<ScaleFactor> {
248    None
249}
250
251/// A unit capacity scaling function: always returns `Some(1.0)`.
252pub fn weighted_scale<C, I>(
253    _containers: &BTreeMap<C, NonZeroUsize>,
254    _items: &[I],
255) -> Option<ScaleFactor> {
256    ScaleFactor::new(1.0)
257}
258
/// Capacity policy used when placing items (for example actors) onto containers (for
/// example workers).
///
/// - `Weighted`: each container is capped at its weight-proportional share of the items.
/// - `Unbounded`: containers have no cap and may receive any number of items.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum CapacityMode {
    /// Cap each container at its weight-proportional share. For actor-to-worker placement
    /// this corresponds to a capacity scale factor of 1.0, i.e. strictly proportional.
    Weighted,

    /// No per-container cap at all; any container may take any number of items.
    Unbounded,
}
274
/// How vnodes are apportioned among workers during hierarchical assignment.
///
/// - `RawWorkerWeights`: split vnodes by the workers' configured weights, after reserving one
///   vnode per hosted actor.
/// - `ActorCounts`: split vnodes in proportion to how many actors each worker hosts.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum BalancedBy {
    /// Weight each worker by its configured (raw) weight when allocating vnodes.
    RawWorkerWeights,

    /// Weight each worker by its actor count, which keeps the number of vnodes per actor
    /// roughly equal across workers.
    ActorCounts,
}
289
290/// Hierarchically distributes virtual nodes to actors in two weighted stages with deterministic tie-breaking.
291///
292/// This function first assigns each actor to a worker, then distributes all virtual nodes among
293/// those active workers, and finally partitions each worker’s vnodes among its actors in a simple
294/// round-robin fashion.
295///
296/// # Type Parameters
297/// - `W`: Worker identifier. Must implement `Ord + Hash + Eq + Clone + Debug`.
298/// - `A`: Actor identifier. Must implement `Ord + Hash + Eq + Copy + Clone + Debug`.
299/// - `V`: Virtual node type. Must implement `Hash + Eq + Copy + Clone + Debug`.
300/// - `S`: Salt type for deterministic tie-breaking. Must implement `Hash + Copy`.
301///
302/// # Parameters
303/// - `workers`: A `BTreeMap<W, NonZeroUsize>` mapping each worker to its positive weight.
304/// - `actors`: A slice of actors (`&[A]`) to place on workers.
305/// - `virtual_nodes`: A slice of vnodes (`&[V]`) to distribute across actors.
306/// - `salt`: A salt value to break ties in hashing, kept constant per invocation for reproducibility.
307/// - `actor_capacity_mode`: A `CapacityMode` deciding how actors are packed onto workers:
308///     - `Weighted`: respect `workers` weights when placing actors.
309///     - `Unbounded`: ignore capacity limits when placing actors.
310/// - `balanced_by`: A `BalancedBy` enum determining vnode distribution strategy:
311///     - `RawWorkerWeights`: prioritize original worker weights (with actor count as lower bound).
312///     - `ActorCounts`: prioritize equal vnode counts per actor (actor-oriented).
313///
314/// # Returns
315/// A `BTreeMap<W, BTreeMap<A, Vec<V>>>` mapping each worker to its map of actors and their assigned vnodes.
316/// - Only workers with at least one actor appear in the result.
317/// - Each actor receives at least one vnode (invariant).
318///
319/// # Errors
320/// - Returns an error if `actors` is empty or `virtual_nodes` is empty.
321/// - Returns an error if `actors.len() > virtual_nodes.len()`, since each actor must receive at least one vnode.
322///
323/// # Complexity
324/// Runs in **O((W + A + V) · log W + V · W)** time:
325/// - Actor → Worker assignment is O(A · W) via weighted rendezvous + O(W + A) map operations.
326/// - VNode → Worker assignment is O(V · W) plus quota computation O(W log W).
327/// - VNode → Actor partition is O(V).
328///
329/// # Example
330/// ```rust
331/// # use std::collections::BTreeMap;
332/// # use std::num::NonZeroUsize;
333/// # use risingwave_meta::stream::{assign_hierarchical, BalancedBy, CapacityMode};
334///
335/// // Define two workers with numeric IDs and weights
336/// let mut workers: BTreeMap<u8, NonZeroUsize> = BTreeMap::new();
337/// workers.insert(1, NonZeroUsize::new(2).unwrap());
338/// workers.insert(2, NonZeroUsize::new(3).unwrap());
339///
340/// // Actors also identified by numbers
341/// let actors: Vec<u16> = vec![10, 20, 30];
342///
343/// // Virtual nodes are simple 0–8
344/// let vnodes: Vec<u16> = (0..9).collect();
345///
346/// let assignment = assign_hierarchical(
347///     &workers,
348///     &actors,
349///     &vnodes,
350///     0u8,                          // salt
351///     CapacityMode::Weighted,       // actor -> worker mode
352///     BalancedBy::RawWorkerWeights, // vnode -> worker mode
353/// )
354/// .unwrap();
355///
356/// for (worker_id, actor_map) in assignment {
357///     println!("Worker {}:", worker_id);
358///     for (actor_id, vn_list) in actor_map {
359///         println!("  Actor {} -> {:?}", actor_id, vn_list);
360///     }
361/// }
362/// ```
363///
364/// # Algorithm
365///
366/// 1. **Actors → Workers**
367///    - Use weighted or unbounded rendezvous hashing to assign each actor to exactly one worker,
368///      based on `actor_capacity_mode` and `workers` weights.
369///    - Build `actor_to_worker: BTreeMap<W, Vec<A>>`.
370///
371/// 2. **`VNodes` → Workers**
372///    - If `RawWorkerWeights`: compute per-worker quotas with `compute_worker_quotas`, ensuring
373///      each active worker’s quota ≥ its actor count and quotas sum = total vnodes.
374///    - If `ActorCounts`: set each worker’s weight = its actor count.
375///    - Run `assign_items_weighted_with_scale_fn` on vnodes vs. the computed weights,
376///      yielding `vnode_to_worker: BTreeMap<W, Vec<V>>`.
377///
378/// 3. **`VNodes` → Actors**
379///    - For each worker, take its vnode list and assign them to actors in simple round-robin:
380///      iterate vnodes in order, dispatching index `% actor_list.len()`.
381///    - Collect into final `BTreeMap<W, BTreeMap<A, Vec<V>>>`.
382pub fn assign_hierarchical<W, A, V, S>(
383    workers: &BTreeMap<W, NonZeroUsize>,
384    actors: &[A],
385    virtual_nodes: &[V],
386    salt: S,
387    actor_capacity_mode: CapacityMode,
388    balanced_by: BalancedBy,
389) -> anyhow::Result<BTreeMap<W, BTreeMap<A, Vec<V>>>>
390where
391    W: Ord + Hash + Eq + Clone + Debug,
392    A: Ord + Hash + Eq + Copy + Clone + Debug,
393    V: Hash + Eq + Copy + Clone + Debug,
394    S: Hash + Copy,
395{
396    if actors.is_empty() {
397        return Err(anyhow!("no actors to assign"));
398    }
399
400    if virtual_nodes.is_empty() {
401        return Err(anyhow!("no vnodes to assign"));
402    }
403
404    // Validate input: ensure vnode count can cover all actors
405    if actors.len() > virtual_nodes.len() {
406        return Err(anyhow!(
407            "actor count ({}) exceeds vnode count ({})",
408            actors.len(),
409            virtual_nodes.len()
410        ));
411    }
412
413    let actor_capacity_fn = match actor_capacity_mode {
414        CapacityMode::Weighted => weighted_scale,
415        CapacityMode::Unbounded => unbounded_scale,
416    };
417
418    // Distribute actors across workers based on their weight
419    let actor_to_worker: BTreeMap<W, Vec<A>> =
420        assign_items_weighted_with_scale_fn(workers, actors, salt, actor_capacity_fn);
421
422    // Build unit-weight map for active workers (those with assigned actors)
423    let mut active_worker_weights: BTreeMap<W, NonZeroUsize> = BTreeMap::new();
424
425    match balanced_by {
426        BalancedBy::RawWorkerWeights => {
427            // Worker oriented: balanced by raw worker weights
428            let mut actor_counts: HashMap<&W, usize> = HashMap::new();
429            for (worker, actor_list) in &actor_to_worker {
430                if !actor_list.is_empty() {
431                    let worker_weight = workers.get(worker).expect("Worker should exist");
432                    active_worker_weights.insert(worker.clone(), *worker_weight);
433                    actor_counts.insert(worker, actor_list.len());
434                }
435            }
436
437            // Recalculate the worker weight to prevent actors from being assigned to vnode.
438            active_worker_weights = compute_worker_quotas(
439                &active_worker_weights,
440                &actor_counts,
441                virtual_nodes.len(),
442                salt,
443            );
444        }
445        BalancedBy::ActorCounts => {
446            // Actor oriented: balanced by actor counts
447            for (worker, actor_list) in &actor_to_worker {
448                debug_assert!(!actor_list.is_empty());
449                if let Some(actor_count) = NonZeroUsize::new(actor_list.len()) {
450                    active_worker_weights.insert(worker.clone(), actor_count);
451                }
452            }
453        }
454    }
455
456    // Distribute vnodes evenly among the active workers
457    let vnode_to_worker: BTreeMap<W, Vec<V>> = assign_items_weighted_with_scale_fn(
458        &active_worker_weights,
459        virtual_nodes,
460        salt,
461        weighted_scale,
462    );
463
464    // Assign each worker's vnodes to its actors in a round-robin fashion
465    let mut assignment = BTreeMap::new();
466    for (worker, actor_list) in actor_to_worker {
467        let assigned_vnodes = vnode_to_worker.get(&worker).cloned().unwrap_or_default();
468
469        // Actors and vnodes can only both be empty at the same time or both be non-empty at the same time.
470        assert_eq!(
471            assigned_vnodes.is_empty(),
472            actor_list.is_empty(),
473            "Invariant violation: empty actor list should have empty vnodes"
474        );
475
476        debug_assert!(
477            assigned_vnodes.len() >= actor_list.len(),
478            "Invariant violation: assigned vnodes should be at least as many as actors"
479        );
480
481        // Within the same worker, use a simple round-robin approach to distribute vnodes relatively evenly among actors.
482        let mut actor_map = BTreeMap::new();
483        for (index, vnode) in assigned_vnodes.into_iter().enumerate() {
484            let actor = actor_list[index % actor_list.len()];
485            actor_map.entry(actor).or_insert(Vec::new()).push(vnode);
486        }
487        assignment.insert(worker, actor_map);
488    }
489
490    Ok(assignment)
491}
492
493/// Computes per-worker VNode quotas based on actor counts and worker weights.
494///
495/// This function allocates virtual nodes to workers such that:
496/// - Each active worker receives at least as many virtual nodes as it has actors (`base_quota`).
497/// - The remaining virtual nodes (`extra_vnodes`) are distributed proportionally to the original worker weights.
498/// - Deterministic tie-breaking on equal remainders uses a hash of (`salt`, `worker_id`).
499///
500/// # Type Parameters
501/// - `W`: Worker identifier type. Must implement `Ord`, `Clone`, `Hash`, `Eq`, and `Debug`.
502/// - `S`: Salt type. Used for deterministic hashing. Must implement `Hash` and `Copy`.
503///
504/// # Parameters
505/// - `workers`: A `BTreeMap` mapping each worker ID to its non-zero weight (`NonZeroUsize`).
506/// - `actor_counts`: A `HashMap` mapping each worker ID to the number of actors assigned.
507/// - `total_vnodes`: The total number of virtual nodes to distribute across all active workers.
508/// - `salt`: A salt value for deterministic tie-breaking in remainder sorting.
509///
510/// # Returns
511/// A `BTreeMap` from worker ID to its allocated quota (`NonZeroUsize`), such that the sum of all quotas equals `total_vnodes`.
512///
513/// # Panics
514/// Panics if any computed quota is zero, which should not occur when `total_vnodes >= sum(actor_counts)`.
515///
516/// # Algorithm
517/// 1. Compute `base_total` as the sum of all actor counts.
518/// 2. Compute `extra_vnodes = total_vnodes - base_total`.
519/// 3. For each active worker:
520///    a. Set `base_quota` equal to its actor count.
521///    b. Compute `ideal_extra = extra_vnodes * weight / total_weight`.
522///    c. Record `extra_floor = floor(ideal_extra)` and `extra_remainder = ideal_extra % total_weight`.
523/// 4. Sort workers by descending `extra_remainder`; tie-break by `stable_hash((salt, worker_id))` ascending.
524/// 5. Distribute the remaining slots (`extra_vnodes - sum(extra_floor)`) by incrementing `extra_floor` for the top workers.
525/// 6. Final quota for each worker is `base_quota + extra_floor`.
526fn compute_worker_quotas<W, S>(
527    workers: &BTreeMap<W, NonZeroUsize>,
528    actor_counts: &HashMap<&W, usize>,
529    total_vnodes: usize,
530    salt: S,
531) -> BTreeMap<W, NonZeroUsize>
532where
533    W: Ord + Clone + Hash + Eq + Debug,
534    S: Hash + Copy,
535{
536    let base_total: usize = actor_counts.values().sum();
537
538    assert!(
539        base_total <= total_vnodes,
540        "Total vnodes ({}) must be greater than or equal to the sum of actor counts ({})",
541        total_vnodes,
542        base_total
543    );
544
545    let extra_vnodes = total_vnodes - base_total;
546
547    // Quota calculation is only performed for Workers with actors.
548    let active_workers: Vec<&W> = actor_counts.keys().cloned().collect();
549    let total_weight: u128 = active_workers
550        .iter()
551        .map(|worker_id| workers[worker_id].get() as u128)
552        .sum();
553
554    assert!(total_weight > 0, "Sum of worker weights must be non-zero");
555    assert!(total_vnodes > 0, "Sum of vnodes must be non-zero");
556
557    // Temporary structure: stores calculation information
558    struct QuotaInfo<W> {
559        worker_id: W,
560        base_quota: usize,
561        extra_floor: usize,
562        extra_remainder: u128,
563    }
564
565    // Preliminary calculation of floor and remainder
566    let mut quota_list: Vec<QuotaInfo<&W>> = active_workers
567        .into_iter()
568        .map(|worker_id| {
569            let base_quota = actor_counts[worker_id];
570            let weight = workers[worker_id].get() as u128;
571            let ideal_extra = extra_vnodes as u128 * weight;
572            let extra_floor = (ideal_extra / total_weight) as usize;
573            let extra_remainder = ideal_extra % total_weight;
574            QuotaInfo {
575                worker_id,
576                base_quota,
577                extra_floor,
578                extra_remainder,
579            }
580        })
581        .collect();
582
583    // Distribute the remaining slots (sorted by remainder, the first N get +1)
584    let used_extra: usize = quota_list.iter().map(|quota| quota.extra_floor).sum();
585    let remaining_slots = extra_vnodes - used_extra;
586    quota_list.sort_by(|a, b| {
587        // First, sort by remainder in descending order.
588        b.extra_remainder
589            .cmp(&a.extra_remainder)
590            // If remainders are the same, then sort by the hash value of (salt, worker_id) in ascending order.
591            .then_with(|| stable_hash(&(salt, a.worker_id)).cmp(&stable_hash(&(salt, b.worker_id))))
592    });
593    for info in quota_list.iter_mut().take(remaining_slots) {
594        info.extra_floor += 1;
595    }
596
597    // Construct the final quotas
598    let mut quotas = BTreeMap::new();
599    for info in quota_list {
600        let total = info.base_quota + info.extra_floor;
601        quotas.insert(info.worker_id.clone(), NonZeroUsize::new(total).unwrap());
602    }
603    quotas
604}
605
/// A lightweight handle for one contiguous chunk of `VNodes` during chunked assignment.
///
/// The wrapped value is the chunk index; callers map it back to a range of vnodes.
/// This is an internal implementation detail.
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
struct VnodeChunk(u32);

impl From<usize> for VnodeChunk {
    /// Wraps a chunk index.
    ///
    /// # Panics
    /// Panics if `id` does not fit in a `u32`. A plain `as` cast would silently wrap and make
    /// two different chunks alias each other.
    fn from(id: usize) -> Self {
        let id = u32::try_from(id)
            .unwrap_or_else(|_| panic!("vnode chunk id {id} exceeds u32::MAX"));
        Self(id)
    }
}

impl VnodeChunk {
    /// Returns the chunk index as a `usize`, suitable for slice arithmetic.
    fn id(&self) -> usize {
        // Widening `u32 -> usize` is lossless on every supported (>= 32-bit) target.
        self.0 as usize
    }
}
624
/// How vnodes are grouped before being handed to the hierarchical assigner.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VnodeChunkingStrategy {
    /// Assign every vnode on its own. `AssignerBuilder::new` selects this by default.
    NoChunking,

    /// Group consecutive vnodes into the largest chunks that still leave at least one chunk
    /// per actor, so each actor's vnodes stay as contiguous as possible.
    MaximizeContiguity,
}
635
636/// Core assigner with configurable strategies.
637#[derive(Debug, Clone)]
638pub struct Assigner<S> {
639    salt: S,
640    actor_capacity: CapacityMode,
641    balance_strategy: BalancedBy,
642    vnode_chunking_strategy: VnodeChunkingStrategy,
643}
644
645/// Builder for [`Assigner`].
646#[derive(Debug)]
647pub struct AssignerBuilder<S> {
648    salt: S,
649    actor_capacity: CapacityMode,
650    balance_strategy: BalancedBy,
651    vnode_chunking_strategy: VnodeChunkingStrategy,
652}
653
654impl<S: Hash + Copy> AssignerBuilder<S> {
655    /// Create a new builder with the given salt.
656    pub fn new(salt: S) -> Self {
657        Self {
658            salt,
659            actor_capacity: CapacityMode::Weighted,
660            balance_strategy: BalancedBy::RawWorkerWeights,
661            vnode_chunking_strategy: VnodeChunkingStrategy::NoChunking,
662        }
663    }
664
665    /// Use weighted capacity when assigning actors.
666    pub fn with_capacity_weighted(&mut self) -> &mut Self {
667        self.actor_capacity = CapacityMode::Weighted;
668        self
669    }
670
671    /// Use unbounded capacity when assigning actors.
672    pub fn with_capacity_unbounded(&mut self) -> &mut Self {
673        self.actor_capacity = CapacityMode::Unbounded;
674        self
675    }
676
677    /// Balance vnodes by actor counts (actor‐oriented).
678    pub fn with_actor_oriented_balancing(&mut self) -> &mut Self {
679        self.balance_strategy = BalancedBy::ActorCounts;
680        self
681    }
682
683    /// Balance vnodes by raw worker weights (worker‐oriented).
684    pub fn with_worker_oriented_balancing(&mut self) -> &mut Self {
685        self.balance_strategy = BalancedBy::RawWorkerWeights;
686        self
687    }
688
689    /// Sets the vnode chunking strategy.
690    pub fn with_vnode_chunking_strategy(&mut self, strategy: VnodeChunkingStrategy) -> &mut Self {
691        self.vnode_chunking_strategy = strategy;
692        self
693    }
694
695    /// Finalize and build the [`Assigner`].
696    pub fn build(&self) -> Assigner<S> {
697        Assigner {
698            salt: self.salt,
699            actor_capacity: self.actor_capacity,
700            balance_strategy: self.balance_strategy,
701            vnode_chunking_strategy: self.vnode_chunking_strategy,
702        }
703    }
704}
705
706impl<S: Hash + Copy> Assigner<S> {
707    /// Assigns each actor to a worker according to `CapacityMode`.
708    pub fn assign_actors<C, I>(
709        &self,
710        workers: &BTreeMap<C, NonZeroUsize>,
711        actors: &[I],
712    ) -> BTreeMap<C, Vec<I>>
713    where
714        C: Ord + Hash + Eq + Clone + Debug,
715        I: Hash + Eq + Copy + Debug,
716    {
717        let scale_fn = match self.actor_capacity {
718            CapacityMode::Weighted => weighted_scale,
719            CapacityMode::Unbounded => unbounded_scale,
720        };
721        assign_items_weighted_with_scale_fn(workers, actors, self.salt, scale_fn)
722    }
723
724    /// Returns how many actors each worker would receive for `actor_count` actors.
725    pub fn count_actors_per_worker<C>(
726        &self,
727        workers: &BTreeMap<C, NonZeroUsize>,
728        actor_count: usize,
729    ) -> BTreeMap<C, usize>
730    where
731        C: Ord + Hash + Eq + Clone + Debug,
732    {
733        let synthetic = (0..actor_count).collect::<Vec<_>>();
734        vec_len_map(self.assign_actors(workers, &synthetic))
735    }
736
737    /// Hierarchical assignment: Actors → Workers → `VNodes` → Actors.
738    pub fn assign_hierarchical<W, A, V>(
739        &self,
740        workers: &BTreeMap<W, NonZeroUsize>,
741        actors: &[A],
742        vnodes: &[V],
743    ) -> Result<BTreeMap<W, BTreeMap<A, Vec<V>>>>
744    where
745        W: Ord + Hash + Eq + Clone + Debug,
746        A: Ord + Hash + Eq + Copy + Debug,
747        V: Hash + Eq + Copy + Debug,
748    {
749        ensure!(
750            !workers.is_empty(),
751            "no workers to assign; assignment is meaningless"
752        );
753        ensure!(
754            !actors.is_empty(),
755            "no actors to assign; assignment is meaningless"
756        );
757        ensure!(
758            !vnodes.is_empty(),
759            "no vnodes to assign; assignment is meaningless"
760        );
761        ensure!(
762            vnodes.len() >= actors.len(),
763            "not enough vnodes ({}) for actors ({}); each actor needs at least one vnode",
764            vnodes.len(),
765            actors.len()
766        );
767
768        let chunk_size = match self.vnode_chunking_strategy {
769            VnodeChunkingStrategy::NoChunking => {
770                return assign_hierarchical(
771                    workers,
772                    actors,
773                    vnodes,
774                    self.salt,
775                    self.actor_capacity,
776                    self.balance_strategy,
777                )
778                .context("hierarchical assignment failed");
779            }
780
781            VnodeChunkingStrategy::MaximizeContiguity => {
782                // Automatically calculate chunk size to be as large as possible
783                // while ensuring every actor can receive at least one chunk.
784                // The `.max(1)` ensures the chunk size is at least 1.
785                (vnodes.len() / actors.len()).max(1)
786            }
787        };
788
789        // Calculate the number of chunks using ceiling division.
790        let num_chunks = vnodes.len().div_ceil(chunk_size);
791
792        // The `MaximizeContiguity` strategy inherently ensures `num_chunks >= actors.len()`.
793        // This assertion serves as a sanity check for our logic.
794        debug_assert!(
795            num_chunks >= actors.len(),
796            "Invariant violation: MaximizeContiguity should always produce enough chunks."
797        );
798
799        // Create VNode chunks to be used as the items for assignment.
800        let chunks: Vec<VnodeChunk> = (0..num_chunks).map(VnodeChunk::from).collect();
801
802        // Call the underlying hierarchical assignment function with chunks as items.
803        let chunk_assignment = assign_hierarchical(
804            workers,
805            actors,
806            &chunks,
807            self.salt,
808            self.actor_capacity,
809            self.balance_strategy,
810        )
811        .context("hierarchical assignment of chunks failed")?;
812
813        // Convert the assignment of `VnodeChunk` back to an assignment of the original `V`.
814        let mut final_assignment = BTreeMap::new();
815        for (worker, actor_map) in chunk_assignment {
816            let mut new_actor_map = BTreeMap::new();
817            for (actor, assigned_chunks) in actor_map {
818                // Expand the list of chunks into a flat list of VNodes.
819                let assigned_vnodes: Vec<V> = assigned_chunks
820                    .into_iter()
821                    .flat_map(|chunk| {
822                        let start_idx = chunk.id() * chunk_size;
823                        // Ensure the end index does not go out of bounds.
824                        let end_idx = (start_idx + chunk_size).min(vnodes.len());
825                        // Get the corresponding VNodes from the original slice.
826                        vnodes[start_idx..end_idx].iter().copied()
827                    })
828                    .collect();
829
830                if !assigned_vnodes.is_empty() {
831                    new_actor_map.insert(actor, assigned_vnodes);
832                }
833            }
834            if !new_actor_map.is_empty() {
835                final_assignment.insert(worker, new_actor_map);
836            }
837        }
838
839        Ok(final_assignment)
840    }
841
842    /// Hierarchical counts: how many vnodes each actor gets.
843    pub fn assign_hierarchical_counts<W, A>(
844        &self,
845        workers: &BTreeMap<W, NonZeroUsize>,
846        actor_count: usize,
847        vnode_count: usize,
848    ) -> Result<BTreeMap<W, BTreeMap<A, usize>>>
849    where
850        W: Ord + Hash + Eq + Clone + Debug,
851        A: Ord + Hash + Eq + Copy + Debug + From<usize>,
852    {
853        let actors = (0..actor_count).map(A::from).collect::<Vec<_>>();
854        let vnodes = (0..vnode_count).collect::<Vec<_>>();
855        let full = self.assign_hierarchical(workers, &actors, &vnodes)?;
856        Ok(full
857            .into_iter()
858            .map(|(w, actor_map)| (w, vec_len_map(actor_map)))
859            .collect())
860    }
861}
862
/// Helper: collapses a map of vectors into a map of their lengths.
///
/// Every key of `map` is kept, including keys whose vector is empty (they map to `0`).
fn vec_len_map<K, V>(map: BTreeMap<K, Vec<V>>) -> BTreeMap<K, usize>
where
    K: Ord,
{
    let mut lengths = BTreeMap::new();
    for (key, values) in map {
        lengths.insert(key, values.len());
    }
    lengths
}
870
#[cfg(test)]
mod tests {
    // Unit tests for the assignment primitives of this file:
    // - `assign_items_weighted_with_scale_fn` (weighted placement of items into containers),
    // - `compute_worker_quotas` (per-worker vnode quotas),
    // - the free function `assign_hierarchical` (actors → workers, then vnodes → actors).
    use std::collections::{BTreeMap, HashMap};
    use std::num::NonZeroUsize;

    use super::*;

    // --- Tests for `assign_items_weighted_with_scale_fn` ---

    /// An empty container map or an empty item slice must yield an empty assignment.
    #[test]
    fn empty_containers_or_items_yields_empty_map() {
        // Case 1: no containers.
        let empty_containers: BTreeMap<&str, NonZeroUsize> = BTreeMap::new();
        let items = vec![1, 2, 3];
        let result =
            assign_items_weighted_with_scale_fn(&empty_containers, &items, 0u8, weighted_scale);
        assert!(
            result.is_empty(),
            "Expected empty map when containers empty"
        );

        // Case 2: no items.
        let mut containers = BTreeMap::new();
        containers.insert("c1", NonZeroUsize::new(1).unwrap());
        let empty_items: Vec<i32> = Vec::new();
        let result2 =
            assign_items_weighted_with_scale_fn(&containers, &empty_items, 0u8, weighted_scale);
        assert!(result2.is_empty(), "Expected empty map when items empty");
    }

    /// With a single container, every item lands in it.
    #[test]
    fn single_container_receives_all_items() {
        let mut containers = BTreeMap::new();
        containers.insert("only", NonZeroUsize::new(5).unwrap());
        let items = vec![10, 20, 30];

        let assignment =
            assign_items_weighted_with_scale_fn(&containers, &items, 1u8, weighted_scale);

        assert_eq!(assignment.len(), 1, "Only one container should be present");
        let assigned = &assignment[&"only"];
        // `Vec` equality is order-sensitive, so this also pins that input order is preserved.
        assert_eq!(assigned, &items, "Single container should get all items");
    }

    /// Equal weights with an item count divisible by the container count split exactly.
    #[test]
    fn equal_weights_divisible_split_evenly() {
        let mut containers = BTreeMap::new();
        containers.insert("A", NonZeroUsize::new(1).unwrap());
        containers.insert("B", NonZeroUsize::new(1).unwrap());
        let items = vec![1, 2, 3, 4];

        let result = assign_items_weighted_with_scale_fn(&containers, &items, 2u8, weighted_scale);
        let a_count = result[&"A"].len();
        let b_count = result[&"B"].len();
        assert_eq!(a_count, 2, "Container A should receive 2 items");
        assert_eq!(b_count, 2, "Container B should receive 2 items");
        assert_eq!(a_count + b_count, items.len(), "All items must be assigned");
    }

    /// Equal weights with a remainder: the leftover item goes to one container.
    ///
    /// Which container receives the remainder depends on the salted tie-break; this test pins
    /// the outcome for salt `5` (Y gets the extra item).
    #[test]
    fn equal_weights_non_divisible_split_remainder_assigned() {
        let mut containers = BTreeMap::new();
        containers.insert("X", NonZeroUsize::new(1).unwrap());
        containers.insert("Y", NonZeroUsize::new(1).unwrap());
        let items = vec![1, 2, 3];

        let result = assign_items_weighted_with_scale_fn(&containers, &items, 5u8, weighted_scale);
        let x_count = result.get(&"X").map(Vec::len).unwrap_or(0);
        let y_count = result.get(&"Y").map(Vec::len).unwrap_or(0);
        assert_eq!(x_count + y_count, items.len(), "All items must be assigned");
        assert!(
            x_count == 1 && y_count == 2,
            "Container X should get 1 items, the other 2, but got {} and {}",
            x_count,
            y_count
        );
    }

    /// Unequal weights (1:3) with 4 items split exactly proportionally (1 and 3).
    #[test]
    fn unequal_weights_respect_base_quota() {
        let mut containers = BTreeMap::new();
        containers.insert("low", NonZeroUsize::new(1).unwrap());
        containers.insert("high", NonZeroUsize::new(3).unwrap());
        let items = vec![100, 200, 300, 400];

        let result = assign_items_weighted_with_scale_fn(&containers, &items, 7u8, weighted_scale);
        let low_count = result[&"low"].len();
        let high_count = result[&"high"].len();
        // low weight should get 1, high weight 3
        assert_eq!(low_count, 1, "Low-weight container should get 1 item");
        assert_eq!(high_count, 3, "High-weight container should get 3 items");
    }

    /// Two calls with identical inputs and salt must produce identical assignments.
    #[test]
    fn deterministic_given_same_salt() {
        let mut containers = BTreeMap::new();
        containers.insert("A", NonZeroUsize::new(2).unwrap());
        containers.insert("B", NonZeroUsize::new(1).unwrap());
        let items = vec![5, 6, 7, 8];

        let out1 = assign_items_weighted_with_scale_fn(&containers, &items, 42u8, weighted_scale);
        let out2 = assign_items_weighted_with_scale_fn(&containers, &items, 42u8, weighted_scale);
        assert_eq!(out1, out2, "Same salt should produce identical assignments");
    }

    /// Compares proportional (`weighted_scale`) and unbounded (`unbounded_scale`) capacities
    /// on heavily skewed weights (1:100).
    ///
    /// The weighted run is checked for strong skew toward B. For the unbounded run only
    /// completeness (every item assigned) is checked. NOTE(review): the test name suggests the
    /// proportional quota is ignored, but no assertion pins that.
    #[test]
    fn assign_items_unbounded_scale_ignores_proportional_quota() {
        let mut containers = BTreeMap::new();
        let container_a_id = "A";
        let container_b_id = "B";
        containers.insert(container_a_id, NonZeroUsize::new(1).unwrap()); // Low weight
        containers.insert(container_b_id, NonZeroUsize::new(100).unwrap()); // High weight
        let items: Vec<i32> = (0..100).collect(); // 100 items
        let salt = 123u8;

        // 1. Assignment with `weighted_scale` (quotas proportional to weight).
        let assignment_unit =
            assign_items_weighted_with_scale_fn(&containers, &items, salt, weighted_scale);
        let a_count_unit = assignment_unit.get(container_a_id).map_or(0, Vec::len);
        let b_count_unit = assignment_unit.get(container_b_id).map_or(0, Vec::len);

        assert!(
            a_count_unit < 10,
            "With weighted_scale, A ({}) should have few items, got {}",
            container_a_id,
            a_count_unit
        );
        assert!(
            b_count_unit > 90,
            "With weighted_scale, B ({}) should have many items, got {}",
            container_b_id,
            b_count_unit
        );
        assert_eq!(
            a_count_unit + b_count_unit,
            items.len(),
            "All items must be assigned in weighted_scale"
        );

        // 2. Assignment with `unbounded_scale` (no upper bound on container capacity).
        let assignment_unbounded =
            assign_items_weighted_with_scale_fn(&containers, &items, salt, unbounded_scale);
        let a_count_unbounded = assignment_unbounded.get(container_a_id).map_or(0, Vec::len);
        let b_count_unbounded = assignment_unbounded.get(container_b_id).map_or(0, Vec::len);

        assert_eq!(
            a_count_unbounded + b_count_unbounded,
            items.len(),
            "All items must be assigned in unbounded_scale"
        );
    }

    // --- Tests for `compute_worker_quotas` ---

    /// Equal weights and equal actor counts split the vnodes evenly.
    #[test]
    fn test_compute_worker_quotas_equal_weights() {
        // Three workers, each with weight 1, one actor each, total_vnodes = 6
        let workers: BTreeMap<u8, NonZeroUsize> = vec![
            (1, NonZeroUsize::new(1).unwrap()),
            (2, NonZeroUsize::new(1).unwrap()),
            (3, NonZeroUsize::new(1).unwrap()),
        ]
        .into_iter()
        .collect();
        let mut actor_counts = HashMap::new();
        actor_counts.insert(&1, 1);
        actor_counts.insert(&2, 1);
        actor_counts.insert(&3, 1);
        let total_vnodes = 6;
        let salt = 42u64;

        let quotas = compute_worker_quotas(&workers, &actor_counts, total_vnodes, salt);
        // Each worker should have quota = 2
        for (&worker_id, &quota) in &quotas {
            assert_eq!(quota.get(), 2, "Worker {} expected quota 2", worker_id);
        }
        // Sum of quotas equals total_vnodes
        let sum: usize = quotas.values().map(|q| q.get()).sum();
        assert_eq!(sum, total_vnodes);
    }

    /// Unequal weights (2:1) with one actor each.
    ///
    /// Each worker first gets 1 vnode per actor. The 4 extra vnodes then split 2:1 after
    /// rounding, giving 4 and 2.
    #[test]
    fn test_compute_worker_quotas_unequal_weights() {
        // Two workers: id 1 weight 2, id 2 weight 1, one actor each, total_vnodes = 6
        let workers: BTreeMap<u8, NonZeroUsize> = vec![
            (1, NonZeroUsize::new(2).unwrap()),
            (2, NonZeroUsize::new(1).unwrap()),
        ]
        .into_iter()
        .collect();
        let mut actor_counts = HashMap::new();
        actor_counts.insert(&1, 1);
        actor_counts.insert(&2, 1);
        let total_vnodes = 6;
        let salt = 100u64;

        let quotas = compute_worker_quotas(&workers, &actor_counts, total_vnodes, salt);
        // Worker 1 should get 4, worker 2 should get 2
        assert_eq!(quotas[&1].get(), 4);
        assert_eq!(quotas[&2].get(), 2);
        // Sum of quotas equals total_vnodes
        let sum: usize = quotas.values().map(|q| q.get()).sum();
        assert_eq!(sum, total_vnodes);
    }

    /// Workers without actors receive no quota entry. The only active worker gets everything.
    #[test]
    fn test_compute_worker_quotas_minimum_base() {
        // Worker 2 has no entry in `actor_counts` and must be absent from the result.
        let workers: BTreeMap<u8, NonZeroUsize> = vec![
            (1, NonZeroUsize::new(1).unwrap()),
            (2, NonZeroUsize::new(1).unwrap()),
        ]
        .into_iter()
        .collect();
        let mut actor_counts = HashMap::new();
        actor_counts.insert(&1, 2);
        // worker 2 has zero actors
        let total_vnodes = 5;
        let salt = 7u8;

        let quotas = compute_worker_quotas(&workers, &actor_counts, total_vnodes, salt);
        // Only worker 1 should be present
        assert_eq!(quotas.len(), 1);
        // Its quota should equal total_vnodes
        assert_eq!(quotas[&1].get(), total_vnodes);
    }

    /// Fewer vnodes than actors (3 actors, 2 vnodes) is an invalid input and must panic.
    #[test]
    #[should_panic]
    fn test_compute_worker_quotas_invalid_total() {
        // total_vnodes less than sum of actor_counts should panic
        let workers: BTreeMap<u8, NonZeroUsize> = vec![(1, NonZeroUsize::new(1).unwrap())]
            .into_iter()
            .collect();
        let mut actor_counts = HashMap::new();
        actor_counts.insert(&1, 3);
        let total_vnodes = 2; // less than base_total = 3
        let salt = 0u16;

        // Expected to panic because `total_vnodes` is below the per-actor base total.
        // NOTE(review): presumably this comes from the `extra_vnodes` subtraction underflowing.
        // That only panics when overflow checks are enabled (the default for test builds), or
        // if `compute_worker_quotas` asserts explicitly — confirm which one it relies on.
        let _ = compute_worker_quotas(&workers, &actor_counts, total_vnodes, salt);
    }

    /// This test verifies the core invariant of `compute_worker_quotas`:
    /// the sum of all calculated quotas must exactly equal the total number of vnodes provided.
    /// It runs through multiple scenarios to ensure this property holds under various conditions.
    #[test]
    fn test_compute_worker_quotas_sum_is_preserved() {
        // Helper function to run a single test case and assert the invariant.
        fn run_quota_sum_test<W, S>(
            scenario_name: &str,
            workers: &BTreeMap<W, NonZeroUsize>,
            actor_counts: &HashMap<&W, usize>,
            total_vnodes: usize,
            salt: S,
        ) where
            W: Ord + Clone + Hash + Eq + Debug,
            S: Hash + Copy,
        {
            let quotas = compute_worker_quotas(workers, actor_counts, total_vnodes, salt);
            let sum_of_quotas: usize = quotas.values().map(|q| q.get()).sum();

            assert_eq!(
                sum_of_quotas, total_vnodes,
                "Scenario '{}' failed: Sum of quotas ({}) does not equal total_vnodes ({})",
                scenario_name, sum_of_quotas, total_vnodes
            );
        }

        // --- Scenario 1: Even split with no remainder ---
        let workers1: BTreeMap<_, _> = [
            (1, NonZeroUsize::new(1).unwrap()),
            (2, NonZeroUsize::new(1).unwrap()),
            (3, NonZeroUsize::new(1).unwrap()),
        ]
        .into();
        let actor_counts1: HashMap<_, _> = [(&1, 2), (&2, 2), (&3, 2)].into();
        run_quota_sum_test(
            "Even split, no remainder",
            &workers1,
            &actor_counts1,
            12,
            0u8,
        );

        // --- Scenario 2: Uneven split with remainder ---
        // This is the most common case, testing the remainder distribution logic.
        let workers2: BTreeMap<_, _> = [
            (1, NonZeroUsize::new(1).unwrap()),
            (2, NonZeroUsize::new(2).unwrap()),
            (3, NonZeroUsize::new(3).unwrap()),
        ]
        .into();
        let actor_counts2: HashMap<_, _> = [(&1, 1), (&2, 5), (&3, 2)].into();
        run_quota_sum_test(
            "Uneven split with remainder",
            &workers2,
            &actor_counts2,
            101,
            42u64,
        );

        // --- Scenario 3: No extra vnodes to distribute ---
        // The total vnodes exactly match the sum of base actor counts.
        let workers3: BTreeMap<_, _> = [
            (1, NonZeroUsize::new(10).unwrap()),
            (2, NonZeroUsize::new(20).unwrap()),
        ]
        .into();
        let actor_counts3: HashMap<_, _> = [(&1, 5), (&2, 10)].into();
        run_quota_sum_test("No extra vnodes", &workers3, &actor_counts3, 15, 0u8);

        // --- Scenario 4: Only one active worker ---
        // All vnodes should be assigned to the single worker with actors.
        let workers4: BTreeMap<_, _> = [
            (1, NonZeroUsize::new(1).unwrap()),
            (2, NonZeroUsize::new(1).unwrap()),
        ]
        .into();
        let actor_counts4: HashMap<_, _> = [(&1, 10)].into();
        run_quota_sum_test("Single active worker", &workers4, &actor_counts4, 100, 0u8);

        // --- Scenario 5: Large and complex numbers ---
        // Stress test with larger, non-trivial numbers.
        let workers5: BTreeMap<_, _> = [
            (1, NonZeroUsize::new(7).unwrap()),
            (2, NonZeroUsize::new(13).unwrap()),
            (3, NonZeroUsize::new(19).unwrap()),
            (4, NonZeroUsize::new(23).unwrap()),
        ]
        .into();
        let actor_counts5: HashMap<_, _> = [(&1, 111), (&2, 222), (&3, 33), (&4, 4)].into();
        run_quota_sum_test(
            "Large and complex numbers",
            &workers5,
            &actor_counts5,
            99991,
            12345u64,
        );
    }

    /// When `total_vnodes` equals the total actor count, every worker gets exactly its base
    /// quota (one vnode per actor), regardless of weight.
    #[test]
    fn test_compute_worker_quotas_no_extra_vnodes() {
        let workers: BTreeMap<u8, NonZeroUsize> = vec![
            (1, NonZeroUsize::new(1).unwrap()),
            (2, NonZeroUsize::new(3).unwrap()),
        ]
        .into_iter()
        .collect();
        let mut actor_counts = HashMap::new();
        actor_counts.insert(&1, 2); // Worker 1, 2 actors
        actor_counts.insert(&2, 1); // Worker 2, 1 actor
        let total_vnodes = 3; // base_total = 2 + 1 = 3. So extra_vnodes = 0.
        let salt = 0u8;

        let quotas = compute_worker_quotas(&workers, &actor_counts, total_vnodes, salt);
        assert_eq!(quotas.len(), 2);
        assert_eq!(
            quotas[&1].get(),
            2,
            "Worker 1 quota should be its base_quota"
        );
        assert_eq!(
            quotas[&2].get(),
            1,
            "Worker 2 quota should be its base_quota"
        );
        let sum: usize = quotas.values().map(|q| q.get()).sum();
        assert_eq!(sum, total_vnodes);
    }

    /// Vnodes exist but no worker has actors: an invalid input that must panic.
    // `compute_worker_quotas` returns the quota map directly (not a `Result`), so this
    // invalid input can only be detected as a panic.
    #[test]
    #[should_panic]
    fn test_compute_worker_quotas_empty_actors_with_vnodes() {
        let workers: BTreeMap<u8, NonZeroUsize> = vec![(1, NonZeroUsize::new(1).unwrap())]
            .into_iter()
            .collect();
        let actor_counts: HashMap<&u8, usize> = HashMap::new(); // No active workers
        let total_vnodes = 5; // But vnodes exist
        let salt = 0u8;

        // This scenario: extra_vnodes = 5, active_workers is empty, total_weight = 0.
        // Presumably this panics on a division by zero (`ideal_extra / total_weight`) —
        // NOTE(review): confirm against `compute_worker_quotas`.
        let _ = compute_worker_quotas(&workers, &actor_counts, total_vnodes, salt);
    }

    // --- Tests for `assign_hierarchical` ---

    /// An empty actor slice is rejected with an error (not a panic).
    #[test]
    fn error_on_empty_actors() {
        let workers: BTreeMap<u8, NonZeroUsize> = vec![(1, NonZeroUsize::new(1).unwrap())]
            .into_iter()
            .collect();
        let actors: Vec<u16> = vec![];
        let vnodes: Vec<u16> = vec![1, 2];

        let err = assign_hierarchical(
            &workers,
            &actors,
            &vnodes,
            0u8,
            CapacityMode::Weighted,
            BalancedBy::ActorCounts,
        )
        .unwrap_err();

        assert!(err.to_string().contains("no actors to assign"));
    }

    /// An empty vnode slice is rejected with an error (not a panic).
    #[test]
    fn error_on_empty_vnodes() {
        let workers: BTreeMap<u8, NonZeroUsize> = vec![(1, NonZeroUsize::new(1).unwrap())]
            .into_iter()
            .collect();
        let actors: Vec<u16> = vec![10, 20];
        let vnodes: Vec<u16> = vec![];

        let err = assign_hierarchical(
            &workers,
            &actors,
            &vnodes,
            1u8,
            CapacityMode::Unbounded,
            BalancedBy::RawWorkerWeights,
        )
        .unwrap_err();

        assert!(err.to_string().contains("no vnodes to assign"));
    }

    /// More actors than vnodes is rejected, because each actor needs at least one vnode.
    #[test]
    fn error_when_more_actors_than_vnodes() {
        let workers: BTreeMap<u8, NonZeroUsize> = vec![(1, NonZeroUsize::new(1).unwrap())]
            .into_iter()
            .collect();
        let actors: Vec<u16> = vec![1, 2, 3];
        let vnodes: Vec<u16> = vec![100];

        let err = assign_hierarchical(
            &workers,
            &actors,
            &vnodes,
            7u8,
            CapacityMode::Weighted,
            BalancedBy::ActorCounts,
        )
        .unwrap_err();

        // NOTE(review): this matches the free function's own error text, not the
        // "not enough vnodes" message of the `assign_hierarchical` method wrapper.
        assert!(err.to_string().contains("exceeds vnode count"));
    }

    /// One worker with as many vnodes as actors: every actor gets exactly one vnode.
    #[test]
    fn single_worker_all_actors_and_vnodes() {
        let workers: BTreeMap<u8, NonZeroUsize> = vec![(1, NonZeroUsize::new(5).unwrap())]
            .into_iter()
            .collect();
        let actors: Vec<u16> = vec![10, 20, 30];
        let vnodes: Vec<u16> = vec![100, 200, 300];

        let assignment = assign_hierarchical(
            &workers,
            &actors,
            &vnodes,
            42u8,
            CapacityMode::Weighted,
            BalancedBy::RawWorkerWeights,
        )
        .unwrap();

        // Only one worker should appear
        assert_eq!(assignment.len(), 1);
        let inner = &assignment[&1u8];
        // Each actor must get exactly one vnode
        for &actor in &actors {
            let assigned = inner.get(&actor).unwrap();
            assert_eq!(assigned.len(), 1, "Actor {} should have one vnode", actor);
        }
        // All vnodes assigned
        let total: usize = inner.values().map(Vec::len).sum();
        assert_eq!(total, vnodes.len());
    }

    /// Two equal-weight workers, two actors and two vnodes: one actor with one vnode each.
    #[test]
    fn two_workers_balanced_by_actor_counts() {
        let workers: BTreeMap<u8, NonZeroUsize> = vec![
            (1, NonZeroUsize::new(1).unwrap()),
            (2, NonZeroUsize::new(1).unwrap()),
        ]
        .into_iter()
        .collect();
        let actors: Vec<u16> = vec![10, 20];
        let vnodes: Vec<u16> = vec![0, 1];

        let assignment = assign_hierarchical(
            &workers,
            &actors,
            &vnodes,
            5u8,
            CapacityMode::Weighted,
            BalancedBy::ActorCounts,
        )
        .unwrap();

        // Both workers should appear
        assert_eq!(assignment.len(), 2);
        for (&w, inner) in &assignment {
            // Each worker has exactly one actor
            assert_eq!(inner.len(), 1, "Worker {} should have one actor", w);
            // That actor has exactly one vnode
            let (_, vlist) = inner.iter().next().unwrap();
            assert_eq!(vlist.len(), 1, "Worker {} actor should have one vnode", w);
        }
    }

    /// With `BalancedBy::RawWorkerWeights`, the heavier worker (weight 3 vs 1) receives
    /// strictly more vnodes, and every vnode is assigned.
    #[test]
    fn raw_worker_weights_respects_worker_weight() {
        let workers: BTreeMap<u8, NonZeroUsize> = vec![
            (1, NonZeroUsize::new(1).unwrap()),
            (2, NonZeroUsize::new(3).unwrap()),
        ]
        .into_iter()
        .collect();
        let actors: Vec<u16> = vec![10, 20, 30, 40];
        let vnodes: Vec<u16> = vec![0, 1, 2, 3, 4, 5, 6];

        let assignment = assign_hierarchical(
            &workers,
            &actors,
            &vnodes,
            9u8,
            CapacityMode::Weighted,
            BalancedBy::RawWorkerWeights,
        )
        .unwrap();

        // `unwrap` assumes both workers received at least one actor.
        let w1_total: usize = assignment.get(&1).unwrap().values().map(Vec::len).sum();
        let w2_total: usize = assignment.get(&2).unwrap().values().map(Vec::len).sum();
        // Worker 2 has triple weight, so should get roughly 3/4 of vnodes
        assert!(
            w2_total > w1_total,
            "Worker 2 should receive more vnodes than Worker 1"
        );
        assert_eq!(
            w1_total + w2_total,
            vnodes.len(),
            "All vnodes must be assigned"
        );
    }

    /// With `CapacityMode::Unbounded`, skewed weights (1:100) still yield a complete
    /// assignment: every actor and every vnode is placed.
    #[test]
    fn assign_hierarchical_capacity_unbounded() {
        let mut workers: BTreeMap<u8, NonZeroUsize> = BTreeMap::new();
        workers.insert(1, NonZeroUsize::new(1).unwrap()); // Low weight worker
        workers.insert(2, NonZeroUsize::new(100).unwrap()); // High weight worker
        let actors: Vec<u16> = (0..10).collect(); // 10 actors
        let vnodes: Vec<u16> = (0..10).collect(); // 10 vnodes
        let salt = 33u8;

        // With CapacityMode::Weighted, worker 1 would get very few (or 0) actors.
        // With CapacityMode::Unbounded, actor placement has no upper bound per worker.
        let assignment = assign_hierarchical(
            &workers,
            &actors,
            &vnodes,
            salt,
            CapacityMode::Unbounded, // Key change here
            BalancedBy::ActorCounts, // VNode balancing doesn't matter as much if actors are skewed
        )
        .unwrap();

        let actors_on_w1 = assignment.get(&1).map_or(0, |amap| amap.len());
        let actors_on_w2 = assignment.get(&2).map_or(0, |amap| amap.len());

        assert_eq!(
            actors_on_w1 + actors_on_w2,
            actors.len(),
            "All actors must be assigned"
        );

        let total_assigned_vnodes: usize = assignment
            .values()
            .flat_map(|amap| amap.values().map(Vec::len))
            .sum();
        assert_eq!(
            total_assigned_vnodes,
            vnodes.len(),
            "All vnodes must be assigned"
        );
    }

    /// Compares the two vnode-balancing modes on the same skewed (1:9) cluster.
    ///
    /// Both runs place actors with `CapacityMode::Weighted`. Only `BalancedBy` differs:
    /// `RawWorkerWeights` splits the extra vnodes by worker weight, while `ActorCounts`
    /// splits them by the number of actors on each worker.
    #[test]
    fn assign_hierarchical_compare_balanced_by_modes() {
        let mut workers: BTreeMap<u8, NonZeroUsize> = BTreeMap::new();
        workers.insert(1, NonZeroUsize::new(1).unwrap()); // Worker 1, low raw weight
        workers.insert(2, NonZeroUsize::new(9).unwrap()); // Worker 2, high raw weight
        // Actor placement uses `CapacityMode::Weighted` in both scenarios, so it follows the
        // 1:9 worker weights. 10 actors keep the numbers small while still placing some
        // actors on worker 2.
        let actors: Vec<u16> = (0..10).collect(); // 10 actors
        let vnodes: Vec<u16> = (0..100).collect(); // 100 vnodes, plenty to show distribution
        let salt = 77u8;

        // Recompute the actor placement with the same salt and `weighted_scale`, so the checks
        // below know how many actors each worker holds (roughly 1 on W1 and 9 on W2).
        // NOTE(review): this assumes `assign_hierarchical` places actors exactly this way under
        // `CapacityMode::Weighted` — confirm.
        let actor_assignment_for_setup =
            assign_items_weighted_with_scale_fn(&workers, &actors, salt, weighted_scale);
        let actors_on_w1_count = actor_assignment_for_setup.get(&1).map_or(0, Vec::len);
        let actors_on_w2_count = actor_assignment_for_setup.get(&2).map_or(0, Vec::len);

        // Scenario 1: BalancedBy::RawWorkerWeights
        let assignment_raw = assign_hierarchical(
            &workers,
            &actors,
            &vnodes,
            salt,
            CapacityMode::Weighted, // Actors distributed by worker weight (1:9)
            BalancedBy::RawWorkerWeights,
        )
        .unwrap();

        let vnodes_on_w1_raw: usize = assignment_raw
            .get(&1)
            .map_or(0, |amap| amap.values().map(Vec::len).sum());
        let vnodes_on_w2_raw: usize = assignment_raw
            .get(&2)
            .map_or(0, |amap| amap.values().map(Vec::len).sum());
        assert_eq!(vnodes_on_w1_raw + vnodes_on_w2_raw, vnodes.len());
        // Expected arithmetic for RawWorkerWeights (approximate, ignoring rounding):
        // - Each worker first gets one vnode per actor (10 in total).
        // - The remaining 100 - 10 = 90 extra vnodes are split by raw weight 1:9,
        //   so W1 gets about 9 and W2 about 81.
        // - With ~1 actor on W1 and ~9 on W2, the totals are about 10 vs 90, a ratio near 9.
        if vnodes_on_w1_raw > 0 && vnodes_on_w2_raw > 0 {
            // Avoid division by zero if one worker gets no vnodes
            let ratio_raw = vnodes_on_w2_raw as f64 / vnodes_on_w1_raw as f64;
            assert!(
                ratio_raw > 5.0 && ratio_raw < 15.0,
                "Expected RawWorkerWeights ratio around 9, got {}",
                ratio_raw
            ); // Roughly 9x
        } else if vnodes_on_w2_raw > 0 {
            assert!(
                actors_on_w1_count == 0 || vnodes_on_w1_raw >= actors_on_w1_count,
                "W1 raw vnodes check"
            );
        }

        // Scenario 2: BalancedBy::ActorCounts
        let assignment_actors = assign_hierarchical(
            &workers,
            &actors,
            &vnodes,
            salt,
            CapacityMode::Weighted, // Actors distributed by worker weight (1:9)
            BalancedBy::ActorCounts,
        )
        .unwrap();

        let vnodes_on_w1_actors: usize = assignment_actors
            .get(&1)
            .map_or(0, |amap| amap.values().map(Vec::len).sum());
        let vnodes_on_w2_actors: usize = assignment_actors
            .get(&2)
            .map_or(0, |amap| amap.values().map(Vec::len).sum());
        assert_eq!(vnodes_on_w1_actors + vnodes_on_w2_actors, vnodes.len());
        // With ActorCounts, vnode distribution is weighted by number of actors on each worker.
        // Weights for vnode dist: actors_on_w1_count vs actors_on_w2_count.
        // Ratio of vnodes should be actors_on_w2_count / actors_on_w1_count.
        // E.g. if actors are 1 on W1, 9 on W2, then vnodes should also be ~1:9.
        if actors_on_w1_count > 0
            && actors_on_w2_count > 0
            && vnodes_on_w1_actors > 0
            && vnodes_on_w2_actors > 0
        {
            let expected_actor_ratio = actors_on_w2_count as f64 / actors_on_w1_count as f64;
            let actual_vnode_ratio_actors = vnodes_on_w2_actors as f64 / vnodes_on_w1_actors as f64;
            // Check if actual ratio is close to expected actor ratio
            assert!(
                (actual_vnode_ratio_actors - expected_actor_ratio).abs() < 2.0, /* Allow some leeway due to integer division */
                "Expected ActorCounts vnode ratio around {}, got {}",
                expected_actor_ratio,
                actual_vnode_ratio_actors
            );
        } else if vnodes_on_w2_actors > 0 {
            assert!(
                actors_on_w1_count == 0 || vnodes_on_w1_actors >= actors_on_w1_count,
                "W1 actorcount vnodes check"
            );
        }
    }
}
1567
#[cfg(test)]
mod extra_tests_for_scale_factor {
    //! Edge cases for `ScaleFactor` construction and for how
    //! `assign_items_weighted_with_scale_fn` handles scale factors below 1.0 and huge ones.

    use std::collections::BTreeMap;
    use std::num::NonZeroUsize;

    use super::*;

    /// `ScaleFactor::new` must reject NaN, both infinities and negative values, while accepting
    /// an ordinary value such as 1.0.
    #[test]
    fn test_scale_factor_constructor_rejects_invalid_values() {
        assert!(ScaleFactor::new(f64::NAN).is_none(), "Should reject NaN");
        assert!(
            ScaleFactor::new(f64::INFINITY).is_none(),
            "Should reject Infinity"
        );
        assert!(
            ScaleFactor::new(f64::NEG_INFINITY).is_none(),
            "Should reject Negative Infinity"
        );
        assert!(
            ScaleFactor::new(-1.0).is_none(),
            "Should reject negative numbers"
        );
        assert!(ScaleFactor::new(1.0).is_some(), "Should accept valid value");
    }

    /// A scale factor below 1.0 must not shrink quotas below their base value, so the resulting
    /// distribution must be identical to the scale = 1.0 baseline.
    #[test]
    fn test_assign_items_scale_factor_less_than_one() {
        let mut containers = BTreeMap::new();
        containers.insert("A", NonZeroUsize::new(3).unwrap());
        containers.insert("B", NonZeroUsize::new(1).unwrap());
        let items: Vec<i32> = (0..4).collect(); // 4 items

        // Weights 3:1 over 4 items give base quotas of A = 3 and B = 1.
        // `weighted_scale` is used as the scale = 1.0 baseline (see the assertion messages below).
        let result_scale_one =
            assign_items_weighted_with_scale_fn(&containers, &items, 0u8, weighted_scale);

        // Always asks for a 0.5 scale, i.e. a factor that would halve the quotas if it were
        // applied without the "never below base" floor.
        fn custom_scale_fn(_: &BTreeMap<&str, NonZeroUsize>, _: &[i32]) -> Option<ScaleFactor> {
            ScaleFactor::new(0.5)
        }
        let result_scale_half =
            assign_items_weighted_with_scale_fn(&containers, &items, 0u8, custom_scale_fn);

        assert_eq!(
            result_scale_one[&"A"].len(),
            3,
            "With scale=1.0, A should get 3"
        );
        assert_eq!(
            result_scale_half[&"A"].len(),
            3,
            "With scale=0.5, A's quota should not be reduced, still gets 3"
        );
        assert_eq!(
            result_scale_one, result_scale_half,
            "Distributions should be identical"
        );
    }

    /// A scale factor so large that `quota * factor` would exceed `usize::MAX` must be clamped
    /// rather than overflow or panic.
    #[test]
    fn test_assign_items_large_scale_factor_does_not_panic() {
        let mut containers = BTreeMap::new();
        containers.insert("A", NonZeroUsize::new(1).unwrap());
        // A single item (its value happens to be 100); with only one container, that
        // container's base quota is 1.
        let items: Vec<i32> = vec![100];

        // A huge scale factor that would definitely overflow `quota * factor` if not clamped.
        fn huge_scale_fn(_: &BTreeMap<&str, NonZeroUsize>, _: &[i32]) -> Option<ScaleFactor> {
            ScaleFactor::new(f64::MAX)
        }

        // The test passes if this function call does not panic.
        let assignment =
            assign_items_weighted_with_scale_fn(&containers, &items, 0u8, huge_scale_fn);

        // All items should still be assigned correctly.
        assert_eq!(assignment[&"A"].len(), items.len());
    }
}
1649
#[cfg(test)]
mod affinity_tests {
    //! Affinity tests: after a change in worker weights or in the actor count, a large share of
    //! vnodes should stay on the worker they were assigned to before the change.

    use std::collections::{BTreeMap, HashMap};
    use std::num::NonZeroUsize;

    use super::*;

    /// Every (capacity mode, balancing strategy) combination exercised by the tests below.
    const ALL_MODES: [(CapacityMode, BalancedBy); 4] = [
        (CapacityMode::Weighted, BalancedBy::RawWorkerWeights),
        (CapacityMode::Weighted, BalancedBy::ActorCounts),
        (CapacityMode::Unbounded, BalancedBy::RawWorkerWeights),
        (CapacityMode::Unbounded, BalancedBy::ActorCounts),
    ];

    // --- Helper functions to analyze affinity ---

    /// Flattens the hierarchical assignment into a simple VNode -> Worker map.
    ///
    /// `pub(crate)` so the sibling affinity test modules can reuse it.
    pub(crate) fn get_vnode_to_worker_map<W, A, V>(
        assignment: &BTreeMap<W, BTreeMap<A, Vec<V>>>,
    ) -> HashMap<V, W>
    where
        W: Copy + Eq + Hash,
        A: Copy + Eq + Hash,
        V: Copy + Eq + Hash,
    {
        let mut map = HashMap::new();
        for (&worker, actor_map) in assignment {
            for vnode_list in actor_map.values() {
                for &vnode in vnode_list {
                    map.insert(vnode, worker);
                }
            }
        }
        map
    }

    /// Percentage of the vnodes in `initial_map` that `new_map` maps to the same worker,
    /// measured against `total_vnodes`.
    fn stability_pct<V, W>(
        initial_map: &HashMap<V, W>,
        new_map: &HashMap<V, W>,
        total_vnodes: usize,
    ) -> f64
    where
        V: Eq + Hash,
        W: PartialEq,
    {
        let stable_vnodes = initial_map
            .iter()
            .filter(|&(v, w)| new_map.get(v) == Some(w))
            .count();
        (stable_vnodes as f64 / total_vnodes as f64) * 100.0
    }

    /// Checks affinity when worker weights change from 5:5 to 2:8 with the same actors and
    /// vnodes: some vnodes must move, but a majority must stay on their worker.
    fn run_weight_change_affinity_test(capacity_mode: CapacityMode, balanced_by: BalancedBy) {
        let initial_workers: BTreeMap<u8, _> = [
            (1, NonZeroUsize::new(5).unwrap()),
            (2, NonZeroUsize::new(5).unwrap()),
        ]
        .into();
        let actors: Vec<u16> = (0..100).collect();
        let vnodes: Vec<u32> = (0..1000).collect();
        let salt = 123u8;

        let initial_assignment = assign_hierarchical(
            &initial_workers,
            &actors,
            &vnodes,
            salt,
            capacity_mode,
            balanced_by,
        )
        .unwrap();
        let initial_map = get_vnode_to_worker_map(&initial_assignment);

        let changed_workers: BTreeMap<u8, _> = [
            (1, NonZeroUsize::new(2).unwrap()),
            (2, NonZeroUsize::new(8).unwrap()),
        ]
        .into();

        let new_assignment = assign_hierarchical(
            &changed_workers,
            &actors,
            &vnodes,
            salt,
            capacity_mode,
            balanced_by,
        )
        .unwrap();
        let new_map = get_vnode_to_worker_map(&new_assignment);

        let stability_percentage = stability_pct(&initial_map, &new_map, vnodes.len());

        println!(
            "Affinity for {:?}/{:?}: {:.2}% of vnodes remained stable when weights changed from {:?} to {:?}.",
            capacity_mode, balanced_by, stability_percentage, initial_workers, changed_workers
        );

        assert!(
            stability_percentage < 100.0,
            "Expected some vnodes to move for {:?}/{:?}",
            capacity_mode,
            balanced_by
        );
        assert!(
            stability_percentage > 50.0,
            "Expected a majority of vnodes to have affinity for {:?}/{:?}",
            capacity_mode,
            balanced_by
        );
    }

    /// Checks affinity when the actor count grows from 100 to 120 on a fixed 5:5 cluster. The
    /// required stability depends on the balancing strategy (see the match below).
    fn run_actor_count_change_affinity_test(capacity_mode: CapacityMode, balanced_by: BalancedBy) {
        let workers: BTreeMap<u8, _> = [
            (1, NonZeroUsize::new(5).unwrap()),
            (2, NonZeroUsize::new(5).unwrap()),
        ]
        .into();
        let initial_actors: Vec<u16> = (0..100).collect();
        let vnodes: Vec<u32> = (0..1000).collect();
        let salt = 123u8;

        let initial_assignment = assign_hierarchical(
            &workers,
            &initial_actors,
            &vnodes,
            salt,
            capacity_mode,
            balanced_by,
        )
        .unwrap();
        let initial_map = get_vnode_to_worker_map(&initial_assignment);

        let changed_actors: Vec<u16> = (0..120).collect();
        let new_assignment = assign_hierarchical(
            &workers,
            &changed_actors,
            &vnodes,
            salt,
            capacity_mode,
            balanced_by,
        )
        .unwrap();
        let new_map = get_vnode_to_worker_map(&new_assignment);

        let stability_percentage = stability_pct(&initial_map, &new_map, vnodes.len());

        println!(
            "Affinity for {:?}/{:?}: {:.2}% of vnodes remained stable when actor count changed from {} to {}.",
            capacity_mode,
            balanced_by,
            stability_percentage,
            initial_actors.len(),
            changed_actors.len(),
        );

        // The expected affinity depends heavily on the balancing strategy.
        match balanced_by {
            BalancedBy::RawWorkerWeights => {
                // Actor count change has minimal impact on vnode distribution.
                assert!(
                    stability_percentage > 90.0,
                    "Expected very high affinity for RawWorkerWeights"
                );
            }
            BalancedBy::ActorCounts => {
                // Actor count change directly impacts vnode distribution weights, causing more churn.
                assert!(
                    stability_percentage > 50.0,
                    "Expected moderate affinity for ActorCounts"
                );
            }
        }
    }

    #[test]
    fn test_affinity_when_worker_weights_change_all_modes() {
        for (capacity_mode, balanced_by) in ALL_MODES {
            run_weight_change_affinity_test(capacity_mode, balanced_by);
        }
    }

    #[test]
    fn test_affinity_when_actor_count_changes_all_modes() {
        for (capacity_mode, balanced_by) in ALL_MODES {
            run_actor_count_change_affinity_test(capacity_mode, balanced_by);
        }
    }
}
1835
#[cfg(test)]
mod affinity_tests_horizon_scaling {
    //! Affinity tests for horizontal scaling: equally weighted workers are added or removed
    //! while the actor and vnode sets stay the same.

    use std::cmp::Ordering;
    use std::collections::{BTreeMap, HashMap, HashSet};
    use std::num::NonZeroUsize;

    use affinity_tests::get_vnode_to_worker_map;

    use super::*;

    /// Affinity metrics that compare a vnode -> worker map before and after a cluster change.
    #[derive(Debug)]
    struct AffinityAnalysis {
        /// Share (in percent) of the vnodes initially on a surviving worker that stayed there.
        stability_on_survivors_pct: f64,
        /// Share (in percent) of all vnodes that ended up on a newly added worker.
        moved_to_new_workers_pct: f64,
        /// Share (in percent) of all vnodes whose worker did not change.
        overall_stability_pct: f64,
    }

    /// Compares `initial_map` against `new_map` for any cluster change (scale-out, scale-in or
    /// no change).
    ///
    /// Workers present in both `initial_workers` and `new_workers` are survivors; workers present
    /// only in `new_workers` are newcomers. A vnode missing from `new_map` counts as neither
    /// stable nor moved. An empty `initial_map` reports 100% stability and 0% movement, and when
    /// no vnode started on a survivor the survivor stability is reported as 100%.
    fn analyze_cluster_change(
        initial_map: &HashMap<u32, u8>,
        new_map: &HashMap<u32, u8>,
        initial_workers: &BTreeMap<u8, NonZeroUsize>,
        new_workers: &BTreeMap<u8, NonZeroUsize>,
    ) -> AffinityAnalysis {
        let total_vnodes = initial_map.len();
        if total_vnodes == 0 {
            return AffinityAnalysis {
                stability_on_survivors_pct: 100.0,
                moved_to_new_workers_pct: 0.0,
                overall_stability_pct: 100.0,
            };
        }

        let survivors: HashSet<u8> = initial_workers
            .keys()
            .filter(|&&id| new_workers.contains_key(&id))
            .copied()
            .collect();
        let newcomers: HashSet<u8> = new_workers
            .keys()
            .filter(|&&id| !initial_workers.contains_key(&id))
            .copied()
            .collect();

        let mut stable_overall = 0usize;
        let mut moved_to_newcomers = 0usize;
        let mut on_survivors_initially = 0usize;
        let mut stable_on_survivors = 0usize;

        // One pass over the initial placement collects every counter at once.
        for (vnode, &before) in initial_map {
            let survived = survivors.contains(&before);
            if survived {
                on_survivors_initially += 1;
            }
            match new_map.get(vnode) {
                Some(&after) if after == before => {
                    stable_overall += 1;
                    if survived {
                        stable_on_survivors += 1;
                    }
                }
                Some(after) if newcomers.contains(after) => moved_to_newcomers += 1,
                _ => {}
            }
        }

        let pct = |count: usize, total: usize| (count as f64 / total as f64) * 100.0;
        AffinityAnalysis {
            stability_on_survivors_pct: if on_survivors_initially == 0 {
                100.0 // No survivors, so stability is vacuously 100%
            } else {
                pct(stable_on_survivors, on_survivors_initially)
            },
            moved_to_new_workers_pct: pct(moved_to_newcomers, total_vnodes),
            overall_stability_pct: pct(stable_overall, total_vnodes),
        }
    }

    /// Resizes a cluster of equally weighted workers under every mode and checks affinity:
    /// - scale-in keeps more than 90% of the survivors' vnodes in place,
    /// - no change keeps 100% of the vnodes in place,
    /// - scale-out moves roughly `(final - initial) / final` of the vnodes to the new workers
    ///   (within 10 percentage points).
    #[test]
    fn test_generic_cluster_resize_affinity_all_modes() {
        // (name, initial cluster size, final cluster size)
        let test_cases: [(&str, usize, usize); 7] = [
            ("Scale In (3 -> 2)", 3, 2),
            ("Scale Out (2 -> 3)", 2, 3),
            ("Scale In (5 -> 4)", 5, 4),
            ("Scale Out (4 -> 5)", 4, 5),
            ("No Change (3 -> 3)", 3, 3),
            ("Scale Double (4 -> 8)", 4, 8),
            ("Scale Half (8 -> 4)", 8, 4),
        ];

        let modes = [
            (CapacityMode::Weighted, BalancedBy::RawWorkerWeights),
            (CapacityMode::Weighted, BalancedBy::ActorCounts),
            (CapacityMode::Unbounded, BalancedBy::RawWorkerWeights),
            (CapacityMode::Unbounded, BalancedBy::ActorCounts),
        ];

        let actors: Vec<u16> = (0..100).collect();
        let vnodes: Vec<u32> = (0..1000).collect();
        let salt = 123u8;

        // Workers with ids 1..=size, every one of weight 5.
        let uniform_workers = |size: usize| -> BTreeMap<u8, NonZeroUsize> {
            (1..=size as u8)
                .map(|id| (id, NonZeroUsize::new(5).unwrap()))
                .collect()
        };

        for &(name, initial_size, final_size) in &test_cases {
            for (capacity_mode, balanced_by) in modes {
                println!(
                    "--- Running Test: {} with {:?}/{:?} ---",
                    name, capacity_mode, balanced_by
                );

                let initial_workers = uniform_workers(initial_size);
                let final_workers = uniform_workers(final_size);

                let before = assign_hierarchical(
                    &initial_workers,
                    &actors,
                    &vnodes,
                    salt,
                    capacity_mode,
                    balanced_by,
                )
                .unwrap();
                let after = assign_hierarchical(
                    &final_workers,
                    &actors,
                    &vnodes,
                    salt,
                    capacity_mode,
                    balanced_by,
                )
                .unwrap();

                let analysis = analyze_cluster_change(
                    &get_vnode_to_worker_map(&before),
                    &get_vnode_to_worker_map(&after),
                    &initial_workers,
                    &final_workers,
                );

                match final_size.cmp(&initial_size) {
                    Ordering::Less => {
                        println!(
                            "  Result: Stability on survivors = {:.2}%",
                            analysis.stability_on_survivors_pct
                        );
                        assert!(
                            analysis.stability_on_survivors_pct > 90.0,
                            "Expected very high stability on surviving nodes during scale-in"
                        );
                    }
                    Ordering::Equal => {
                        println!(
                            "  Result: Overall stability = {:.2}%",
                            analysis.overall_stability_pct
                        );
                        assert_eq!(
                            analysis.overall_stability_pct, 100.0,
                            "Expected 100% stability when cluster size does not change"
                        );
                    }
                    Ordering::Greater => {
                        // With equal weights each worker ideally owns 1/final of the vnodes, so
                        // the new workers should take (final - initial)/final of them.
                        let expected_move_rate =
                            (final_size - initial_size) as f64 / final_size as f64 * 100.0;
                        let expected_stability_rate =
                            initial_size as f64 / final_size as f64 * 100.0;

                        println!(
                            "  Result: Overall stability = {:.2}% (Expected ~{:.2}%), Moved to new = {:.2}% (Expected ~{:.2}%)",
                            analysis.overall_stability_pct,
                            expected_stability_rate,
                            analysis.moved_to_new_workers_pct,
                            expected_move_rate
                        );

                        assert!(
                            (analysis.moved_to_new_workers_pct - expected_move_rate).abs() < 10.0,
                            "Move rate to new nodes is outside expected tolerance"
                        );
                        assert!(
                            (analysis.overall_stability_pct - expected_stability_rate).abs() < 10.0,
                            "Overall stability is outside expected tolerance"
                        );
                    }
                }
            }
        }
    }
}
2069
#[cfg(test)]
mod affinity_tests_vertical_scaling {
    //! Affinity tests for vertical scaling: the worker set stays the same but worker weights
    //! change. The observed vnode stability is compared with an estimate derived from how much
    //! each worker's target share of vnodes changes.

    use std::collections::BTreeMap;
    use std::num::NonZeroUsize;

    use affinity_tests::get_vnode_to_worker_map;
    use risingwave_common::util::iter_util::ZipEqFast;

    use super::*;

    /// A worker weight change scenario. Worker `i + 1` has weight `initial_weights[i]` before
    /// the change and `final_weights[i]` after it; both vectors must have the same length.
    struct WeightChangeTestCase {
        name: &'static str,
        initial_weights: Vec<usize>,
        final_weights: Vec<usize>,
    }

    /// Builds a worker map from `weights`, using 1-based positions in the slice as worker ids.
    ///
    /// Panics if a weight is zero.
    fn workers_from_weights(weights: &[usize]) -> BTreeMap<u8, NonZeroUsize> {
        weights
            .iter()
            .enumerate()
            .map(|(i, &w)| (i as u8 + 1, NonZeroUsize::new(w).unwrap()))
            .collect()
    }

    /// Estimated percentage of vnodes that keep their worker when, for every
    /// `(before, after)` pair, a worker's share of `total_vnodes` changes from
    /// `before / initial_total` to `after / final_total`.
    ///
    /// Every vnode a worker gains has to come from another worker, so the estimated number of
    /// moved vnodes is the sum of the positive share deltas.
    fn estimate_stability_pct(
        counts: impl Iterator<Item = (usize, usize)>,
        initial_total: usize,
        final_total: usize,
        total_vnodes: usize,
    ) -> f64 {
        let total_vnodes = total_vnodes as f64;
        let expected_moved: f64 = counts
            .map(|(before, after)| {
                let initial_share = total_vnodes * (before as f64 / initial_total as f64);
                let final_share = total_vnodes * (after as f64 / final_total as f64);
                (final_share - initial_share).max(0.0)
            })
            .sum();
        (total_vnodes - expected_moved) / total_vnodes * 100.0
    }

    /// Runs one weight change scenario under the given mode and asserts that the observed vnode
    /// stability is within 10 percentage points of the estimate for the balancing strategy.
    fn run_weight_change_test_case(
        case: &WeightChangeTestCase,
        capacity_mode: CapacityMode,
        balanced_by: BalancedBy,
    ) {
        // Print the header first so that a panic inside the assignment is attributed to this case.
        println!(
            "--- Running Test: '{}' with {:?}/{:?} ---",
            case.name, capacity_mode, balanced_by
        );

        let actors: Vec<u16> = (0..100).collect();
        let vnodes: Vec<u32> = (0..1000).collect();
        let salt = 123u8;

        let initial_workers = workers_from_weights(&case.initial_weights);
        let final_workers = workers_from_weights(&case.final_weights);

        let initial_assignment = assign_hierarchical(
            &initial_workers,
            &actors,
            &vnodes,
            salt,
            capacity_mode,
            balanced_by,
        )
        .unwrap();
        let initial_map = get_vnode_to_worker_map(&initial_assignment);
        let new_assignment = assign_hierarchical(
            &final_workers,
            &actors,
            &vnodes,
            salt,
            capacity_mode,
            balanced_by,
        )
        .unwrap();
        let new_map = get_vnode_to_worker_map(&new_assignment);

        let stable_vnodes = initial_map
            .iter()
            .filter(|(v, w)| new_map.get(v) == Some(w))
            .count();
        let actual_stability_pct = (stable_vnodes as f64 / vnodes.len() as f64) * 100.0;

        println!(
            "  Result: {:.2}% of vnodes remained stable.",
            actual_stability_pct
        );

        // --- Dynamic Assertions ---
        let total_vnodes = vnodes.len();
        let expected_stability_pct = match balanced_by {
            BalancedBy::RawWorkerWeights => {
                // Churn follows the change in each worker's share of the total weight.
                estimate_stability_pct(
                    case.initial_weights
                        .iter()
                        .zip_eq_fast(case.final_weights.iter())
                        .map(|(&before, &after)| (before, after)),
                    case.initial_weights.iter().sum(),
                    case.final_weights.iter().sum(),
                    total_vnodes,
                )
            }
            BalancedBy::ActorCounts => {
                // Churn follows the change in each worker's share of the actors, which is itself
                // determined by the worker weights (if CapacityMode is Weighted).
                // NOTE(review): the actor distribution below is computed with the builder's
                // default settings, not with `capacity_mode`; confirm this estimate is also
                // adequate for `CapacityMode::Unbounded`.
                let initial_actor_dist = AssignerBuilder::new(salt)
                    .build()
                    .count_actors_per_worker(&initial_workers, actors.len());
                let final_actor_dist = AssignerBuilder::new(salt)
                    .build()
                    .count_actors_per_worker(&final_workers, actors.len());

                estimate_stability_pct(
                    (1..=initial_workers.len() as u8).map(|worker_id| {
                        (
                            *initial_actor_dist.get(&worker_id).unwrap_or(&0),
                            *final_actor_dist.get(&worker_id).unwrap_or(&0),
                        )
                    }),
                    initial_actor_dist.values().sum(),
                    final_actor_dist.values().sum(),
                    total_vnodes,
                )
            }
        };

        println!(
            "  Expectation for this mode: ~{:.2}% stability.",
            expected_stability_pct
        );

        assert!(
            (actual_stability_pct - expected_stability_pct).abs() < 10.0,
            "Stability is outside the expected tolerance for this mode."
        );
    }

    #[test]
    fn test_generic_weight_change_affinity_all_modes() {
        let test_cases = [
            WeightChangeTestCase {
                name: "Uniform Scaling (No relative change) #1",
                initial_weights: vec![5, 5],
                final_weights: vec![10, 10],
            },
            WeightChangeTestCase {
                name: "Uniform Scaling (No relative change) #2",
                initial_weights: vec![8, 8],
                final_weights: vec![4, 4],
            },
            WeightChangeTestCase {
                name: "Single Worker Weight Decrease",
                initial_weights: vec![5, 5, 5],
                final_weights: vec![2, 5, 5],
            },
            WeightChangeTestCase {
                name: "Single Worker Weight Increase",
                initial_weights: vec![5, 5, 5],
                final_weights: vec![8, 5, 5],
            },
            WeightChangeTestCase {
                name: "Complex Rebalance",
                initial_weights: vec![5, 5],
                final_weights: vec![2, 8],
            },
        ];

        let modes = [
            (CapacityMode::Weighted, BalancedBy::RawWorkerWeights),
            (CapacityMode::Weighted, BalancedBy::ActorCounts),
            (CapacityMode::Unbounded, BalancedBy::RawWorkerWeights),
            (CapacityMode::Unbounded, BalancedBy::ActorCounts),
        ];

        for case in &test_cases {
            for (capacity_mode, balanced_by) in modes {
                run_weight_change_test_case(case, capacity_mode, balanced_by);
            }
        }
    }
}
2251
2252#[cfg(test)]
2253mod assigner_test {
2254    use std::collections::BTreeMap;
2255    use std::num::NonZeroUsize;
2256
2257    use super::*;
2258
2259    // Helper function to create a BTreeMap of workers
2260    fn create_workers(weights: &[(u8, usize)]) -> BTreeMap<u8, NonZeroUsize> {
2261        weights
2262            .iter()
2263            .map(|(id, w)| (*id, NonZeroUsize::new(*w).unwrap()))
2264            .collect()
2265    }
2266
2267    #[test]
2268    fn test_maximize_contiguity_basic_assignment() {
2269        // 2 workers, 4 actors, 100 vnodes.
2270        // Expected chunk_size = floor(100 / 4) = 25.
2271        // Expected num_chunks = ceil(100 / 25) = 4.
2272        let workers = create_workers(&[(1, 1), (2, 1)]);
2273        let actors: Vec<u16> = (0..4).collect();
2274        let vnodes: Vec<u32> = (0..100).collect();
2275
2276        let assigner = AssignerBuilder::new(0u8)
2277            .with_vnode_chunking_strategy(VnodeChunkingStrategy::MaximizeContiguity)
2278            .build();
2279
2280        let assignment = assigner
2281            .assign_hierarchical(&workers, &actors, &vnodes)
2282            .unwrap();
2283
2284        let mut total_vnodes_assigned = 0;
2285        let mut all_assigned_vnodes = BTreeMap::new();
2286
2287        for (_, actor_map) in assignment {
2288            for (actor_id, vnodes) in actor_map {
2289                total_vnodes_assigned += vnodes.len();
2290                // Each actor should receive exactly one chunk of 25 vnodes.
2291                assert_eq!(
2292                    vnodes.len(),
2293                    25,
2294                    "Actor {} should get a full chunk",
2295                    actor_id
2296                );
2297                // Verify that the assigned vnodes are contiguous.
2298                for i in 0..(vnodes.len() - 1) {
2299                    assert_eq!(vnodes[i] + 1, vnodes[i + 1], "VNodes should be contiguous");
2300                }
2301                all_assigned_vnodes.insert(vnodes[0], vnodes);
2302            }
2303        }
2304
2305        assert_eq!(
2306            total_vnodes_assigned,
2307            vnodes.len(),
2308            "All vnodes must be assigned"
2309        );
2310        // Check if the chunks are correct: 0-24, 25-49, 50-74, 75-99
2311        assert!(all_assigned_vnodes.contains_key(&0));
2312        assert!(all_assigned_vnodes.contains_key(&25));
2313        assert!(all_assigned_vnodes.contains_key(&50));
2314        assert!(all_assigned_vnodes.contains_key(&75));
2315    }
2316
2317    #[test]
2318    fn test_maximize_contiguity_non_divisible_vnodes() {
2319        // 4 actors, 103 vnodes.
2320        // Expected chunk_size = floor(103 / 4) = 25.
2321        // Expected num_chunks = ceil(103 / 25) = 5.
2322        let workers = create_workers(&[(1, 1)]);
2323        let actors: Vec<u16> = (0..4).collect();
2324        let vnodes: Vec<u32> = (0..103).collect();
2325
2326        let assigner = AssignerBuilder::new(0u8)
2327            .with_vnode_chunking_strategy(VnodeChunkingStrategy::MaximizeContiguity)
2328            .build();
2329
2330        let assignment = assigner
2331            .assign_hierarchical(&workers, &actors, &vnodes)
2332            .unwrap();
2333        let actor_map = assignment.get(&1).unwrap();
2334
2335        // Collect the number of vnodes assigned to each actor.
2336        let vnode_counts: Vec<usize> = actor_map.values().map(Vec::len).collect();
2337
2338        // 1. Verify the total number of assigned vnodes.
2339        let total_assigned: usize = vnode_counts.iter().sum();
2340        assert_eq!(total_assigned, vnodes.len(), "All vnodes must be assigned");
2341        assert_eq!(
2342            vnode_counts.len(),
2343            actors.len(),
2344            "Each actor must have an entry"
2345        );
2346
2347        // 2. Verify the distribution of vnode counts.
2348        // The final counts depend on how the chunks {25, 25, 25, 25, 3} are distributed.
2349        // The actor who gets two chunks could get:
2350        //   25 + 25 = 50
2351        //   25 + 3 = 28
2352        // The other actors get one chunk, so their counts will be 25 or 3.
2353
2354        // We can count how many actors got each possible number of vnodes.
2355        let mut counts_distribution = BTreeMap::new();
2356        for count in vnode_counts {
2357            *counts_distribution.entry(count).or_insert(0) += 1;
2358        }
2359
2360        assert_eq!(
2361            counts_distribution,
2362            BTreeMap::from([(25, 3), (28, 1)]),
2363            "The distribution of vnode counts is unexpected. Got: {:?}",
2364            counts_distribution
2365        );
2366    }
2367
2368    #[test]
2369    fn test_actors_gt_vnodes_fails() {
2370        // 10 actors, 5 vnodes. This should fail at the top level.
2371        let workers = create_workers(&[(1, 1)]);
2372        let actors: Vec<u16> = (0..10).collect();
2373        let vnodes: Vec<u32> = (0..5).collect();
2374
2375        let assigner = AssignerBuilder::new(0u8).build();
2376
2377        let result = assigner.assign_hierarchical(&workers, &actors, &vnodes);
2378        assert!(result.is_err());
2379        // The error comes from the inner `assign_hierarchical` call.
2380        assert!(
2381            result
2382                .unwrap_err()
2383                .to_string()
2384                .contains("not enough vnodes (5) for actors (10)")
2385        );
2386    }
2387
2388    #[test]
2389    fn test_maximize_contiguity_actors_eq_vnodes() {
2390        // 10 actors, 10 vnodes.
2391        // Expected chunk_size = floor(10 / 10) = 1.
2392        // This should behave identically to NoChunking.
2393        let workers = create_workers(&[(1, 1), (2, 1)]);
2394        let actors: Vec<u16> = (0..10).collect();
2395        let vnodes: Vec<u32> = (0..10).collect();
2396
2397        let assigner_contiguity = AssignerBuilder::new(42u8)
2398            .with_vnode_chunking_strategy(VnodeChunkingStrategy::MaximizeContiguity)
2399            .build();
2400        let assignment_contiguity = assigner_contiguity
2401            .assign_hierarchical(&workers, &actors, &vnodes)
2402            .unwrap();
2403
2404        let assigner_no_chunking = AssignerBuilder::new(42u8)
2405            .with_vnode_chunking_strategy(VnodeChunkingStrategy::NoChunking)
2406            .build();
2407        let assignment_no_chunking = assigner_no_chunking
2408            .assign_hierarchical(&workers, &actors, &vnodes)
2409            .unwrap();
2410
2411        // With the same salt, the results should be identical.
2412        assert_eq!(assignment_contiguity, assignment_no_chunking);
2413
2414        // Also verify each actor gets exactly one vnode.
2415        let total_vnodes: usize = assignment_contiguity
2416            .values()
2417            .flat_map(|amap| amap.values().map(Vec::len))
2418            .sum();
2419        assert_eq!(total_vnodes, vnodes.len());
2420        assert!(
2421            assignment_contiguity
2422                .values()
2423                .all(|amap| amap.values().all(|v| v.len() == 1))
2424        );
2425    }
2426
2427    #[test]
2428    fn test_maximize_contiguity_single_actor() {
2429        // 1 actor, 1000 vnodes.
2430        // Expected chunk_size = floor(1000 / 1) = 1000.
2431        // All vnodes should go to the single actor.
2432        let workers = create_workers(&[(1, 1)]);
2433        let actors: Vec<u16> = vec![100];
2434        let vnodes: Vec<u32> = (0..1000).collect();
2435
2436        let assigner = AssignerBuilder::new(0u8)
2437            .with_vnode_chunking_strategy(VnodeChunkingStrategy::MaximizeContiguity)
2438            .build();
2439
2440        let assignment = assigner
2441            .assign_hierarchical(&workers, &actors, &vnodes)
2442            .unwrap();
2443
2444        // Check that only one worker has assignments.
2445        assert_eq!(assignment.len(), 1);
2446        let actor_map = assignment.get(&1).unwrap();
2447
2448        // Check that only the single actor has assignments.
2449        assert_eq!(actor_map.len(), 1);
2450        let assigned_vnodes = actor_map.get(&100).unwrap();
2451
2452        // Check that the actor received all vnodes.
2453        assert_eq!(assigned_vnodes.len(), vnodes.len());
2454        // Check that they are the correct vnodes.
2455        assert_eq!(*assigned_vnodes, vnodes);
2456    }
2457
2458    #[test]
2459    fn test_maximize_contiguity_differs_from_no_chunking() {
2460        // Use a setup where the difference will be clear.
2461        // 2 workers, 2 actors, 4 vnodes.
2462        // MaximizeContiguity: chunk_size = floor(4/2) = 2. Two chunks: [0,1], [2,3].
2463        // Each actor gets one chunk. Actor 0 might get [0,1] and Actor 1 [2,3].
2464        // NoChunking: vnodes 0,1,2,3 are assigned independently. It's highly likely
2465        // they will be distributed between the two actors, not in contiguous blocks.
2466        let workers = create_workers(&[(1, 1)]);
2467        let actors: Vec<u16> = vec![0, 1];
2468        let vnodes: Vec<u32> = vec![0, 1, 2, 3];
2469        let salt = 123u8; // A fixed salt to make it deterministic.
2470
2471        let assigner_contiguity = AssignerBuilder::new(salt)
2472            .with_vnode_chunking_strategy(VnodeChunkingStrategy::MaximizeContiguity)
2473            .build();
2474        let assignment_contiguity = assigner_contiguity
2475            .assign_hierarchical(&workers, &actors, &vnodes)
2476            .unwrap();
2477
2478        let assigner_no_chunking = AssignerBuilder::new(salt)
2479            .with_vnode_chunking_strategy(VnodeChunkingStrategy::NoChunking)
2480            .build();
2481        let assignment_no_chunking = assigner_no_chunking
2482            .assign_hierarchical(&workers, &actors, &vnodes)
2483            .unwrap();
2484
2485        // The assignments should be different.
2486        assert_ne!(
2487            assignment_contiguity, assignment_no_chunking,
2488            "Expected different assignments for the two strategies"
2489        );
2490
2491        // Verify contiguity for the MaximizeContiguity result.
2492        let actor_map_contiguity = assignment_contiguity.get(&1).unwrap();
2493        for vnodes in actor_map_contiguity.values() {
2494            assert_eq!(vnodes.len(), 2, "Each actor should get a chunk of size 2");
2495            assert_eq!(vnodes[0] + 1, vnodes[1], "VNodes must be contiguous");
2496        }
2497    }
2498}
2499
#[cfg(test)]
mod multi_group_cluster_simulation_tests {
    use std::collections::BTreeMap;
    use std::num::NonZeroUsize;

    use super::affinity_tests::get_vnode_to_worker_map;
    use super::*;

    /// Number of vnodes distributed in every simulated group.
    const VNODE_COUNT: usize = 256;
    /// Number of actors participating in every simulated group.
    const ACTOR_COUNT: usize = 32;

    /// Every (capacity, balancing) strategy combination exercised by these simulations.
    const ALL_MODES: [(CapacityMode, BalancedBy); 4] = [
        (CapacityMode::Weighted, BalancedBy::RawWorkerWeights),
        (CapacityMode::Weighted, BalancedBy::ActorCounts),
        (CapacityMode::Unbounded, BalancedBy::RawWorkerWeights),
        (CapacityMode::Unbounded, BalancedBy::ActorCounts),
    ];

    /// Assigns many groups (one salt per group) onto the same cluster and checks
    /// that the aggregated vnode load stays uniform across workers.
    #[test]
    fn test_multi_group_balance() {
        let actors: Vec<u16> = (0..ACTOR_COUNT as u16).collect();
        let vnodes: Vec<u32> = (0..VNODE_COUNT as u32).collect();

        for num_groups in [1, 5, 10, 20, 50, 100, 500] {
            for worker_count in [3, 4, 5] {
                println!(
                    "\n--- Testing with {} Groups {} Workers ---\n",
                    num_groups, worker_count
                );
                for &(capacity_mode, balanced_by) in &ALL_MODES {
                    println!(
                        "\n--- Testing Mode: Capacity={:?}, BalancedBy={:?} ---\n",
                        capacity_mode, balanced_by
                    );

                    run_uniformity_check(
                        capacity_mode,
                        balanced_by,
                        &actors,
                        &vnodes,
                        worker_count,
                        num_groups,
                    );
                }
            }
        }
    }

    /// Checks that vnode placement stays stable when the cluster is scaled out
    /// (and, symmetrically, scaled back in).
    #[test]
    fn test_multi_group_affinity() {
        let actors: Vec<u16> = (0..ACTOR_COUNT as u16).collect();
        let vnodes: Vec<u32> = (0..VNODE_COUNT as u32).collect();

        for worker_count in [1, 2, 3, 4] {
            for &(capacity_mode, balanced_by) in &ALL_MODES {
                println!(
                    "\n--- Testing Mode: Capacity={:?}, BalancedBy={:?} ---\n",
                    capacity_mode, balanced_by
                );
                run_affinity_check(
                    capacity_mode,
                    balanced_by,
                    &actors,
                    &vnodes,
                    (worker_count, worker_count + 1),
                );
                run_affinity_check(
                    capacity_mode,
                    balanced_by,
                    &actors,
                    &vnodes,
                    (worker_count, worker_count * 2),
                );
            }
        }
    }

    /// Helper function to check the overall VNode distribution uniformity over multiple groups.
    ///
    /// Each group is assigned with its own salt (`group_id`); the per-worker vnode
    /// totals are summed over all groups and compared against a 2-sigma binomial
    /// tolerance around the ideal mean. The Unbounded/ActorCounts mode is skipped.
    fn run_uniformity_check(
        capacity_mode: CapacityMode,
        balanced_by: BalancedBy,
        actors: &[u16],
        vnodes: &[u32],
        worker_count: u8,
        num_groups: usize,
    ) {
        let initial_workers = create_workers(worker_count);
        let mut total_vnode_distribution: BTreeMap<u8, usize> = BTreeMap::new();

        for group_id in 0..num_groups {
            let salt = group_id as u64;
            let assigner =
                AssignerBuilder::new(salt).build_with_strategies(capacity_mode, balanced_by);

            let assignment = assigner
                .assign_hierarchical(&initial_workers, actors, vnodes)
                .unwrap();

            for (worker_id, actor_map) in assignment {
                let vnodes_on_worker: usize = actor_map.values().map(Vec::len).sum();
                *total_vnode_distribution.entry(worker_id).or_insert(0) += vnodes_on_worker;
            }
        }

        let total_vnodes_assigned: usize = total_vnode_distribution.values().sum();
        let expected_total_vnodes = vnodes.len() * num_groups;
        assert_eq!(
            total_vnodes_assigned, expected_total_vnodes,
            "All vnodes from all groups must be assigned"
        );

        let worker_count = initial_workers.len();
        let expected_vnodes_per_worker = (expected_total_vnodes as f64) / worker_count as f64;
        println!(
            "[Uniformity] Overall VNode Distribution: {:?}",
            total_vnode_distribution
        );
        println!(
            "[Uniformity] Expected per worker: {:.2}",
            expected_vnodes_per_worker
        );

        match (capacity_mode, balanced_by) {
            (CapacityMode::Unbounded, BalancedBy::ActorCounts) => {
                println!("[Uniformity] Skipping uniformity check for Unbounded/ActorCounts.");
            }
            _ => {
                println!("[Uniformity] Applying STATISTICAL uniformity check.");
                let n = expected_total_vnodes as f64;
                let p = 1.0 / worker_count as f64;
                let mu = n * p;
                let sigma = (n * p * (1.0 - p)).sqrt();
                const K_SIGMA_TOLERANCE: f64 = 2.0;
                let max_allowed_deviation_abs = K_SIGMA_TOLERANCE * sigma;
                let dynamic_deviation_threshold = max_allowed_deviation_abs / mu;

                println!(
                    "[Uniformity] Dynamic Threshold: {:.2}% (based on {} sigma)",
                    dynamic_deviation_threshold * 100.0,
                    K_SIGMA_TOLERANCE
                );

                // Iterate over every worker in the cluster, not just those that show
                // up in the distribution: a worker that never received a vnode is
                // absent from the map and must be checked as a count of 0.
                for worker_id in initial_workers.keys() {
                    let v_count = total_vnode_distribution
                        .get(worker_id)
                        .copied()
                        .unwrap_or(0);
                    // Relative deviation from the ideal mean, as a fraction.
                    let deviation = (v_count as f64 - mu).abs() / mu;
                    println!(
                        " Worker #{} VNode Count: {}, Deviation: {:.2}%",
                        worker_id,
                        v_count,
                        deviation * 100.0
                    );
                    assert!(
                        deviation < dynamic_deviation_threshold,
                        "[Uniformity] VNode distribution on worker {} is not uniform for {:?}/{:?}. Deviation: {:.2}%, Allowed: {:.2}%",
                        worker_id,
                        capacity_mode,
                        balanced_by,
                        deviation * 100.0,
                        dynamic_deviation_threshold * 100.0
                    );
                }
            }
        }

        println!("[Uniformity] Test passed.");
    }

    /// Helper function to check VNode affinity during cluster scaling.
    ///
    /// Assigns the same actors/vnodes (salt 0) to an initial and a scaled cluster
    /// and compares the measured stability percentages against the ideal ones
    /// computed by [`calculate_ideal_affinity_thresholds`].
    fn run_affinity_check(
        capacity_mode: CapacityMode,
        balanced_by: BalancedBy,
        actors: &[u16],
        vnodes: &[u32],
        (initial_worker_count, scaled_worker_count): (u8, u8),
    ) {
        /// Maximum allowed gap, in percentage points, between measured and ideal stability.
        const AFFINITY_TOLERANCE_PCT: f64 = 20.0;

        let salt_for_affinity_test = 0u64;

        let (expected_ceiling_out, expected_ceiling_in) = calculate_ideal_affinity_thresholds(
            salt_for_affinity_test,
            initial_worker_count,
            scaled_worker_count,
            vnodes,
        );

        println!(
            "[Affinity] Calculated ideal benchmarks -> Scale-Out: {:.2}%, Scale-In: {:.2}%",
            expected_ceiling_out, expected_ceiling_in
        );

        let assigner = AssignerBuilder::new(salt_for_affinity_test)
            .build_with_strategies(capacity_mode, balanced_by);

        let initial_workers = create_workers(initial_worker_count);
        let scaled_workers = create_workers(scaled_worker_count);

        let initial_map = get_vnode_to_worker_map(
            &assigner
                .assign_hierarchical(&initial_workers, actors, vnodes)
                .unwrap(),
        );
        let scaled_map = get_vnode_to_worker_map(
            &assigner
                .assign_hierarchical(&scaled_workers, actors, vnodes)
                .unwrap(),
        );

        // Scale-out: fraction of all vnodes that keep their worker.
        let stable_on_scale_out = initial_map
            .iter()
            .filter(|(v, w)| scaled_map.get(v) == Some(w))
            .count();
        let stability_pct_out = (stable_on_scale_out as f64 / vnodes.len() as f64) * 100.0;

        // Scale-in: among vnodes that sat on surviving (initial) workers in the
        // scaled cluster, the fraction that stay on the same worker afterwards.
        let vnodes_on_survivors_before = scaled_map
            .values()
            .filter(|w| initial_workers.contains_key(w))
            .count();
        let stable_on_scale_in = scaled_map
            .iter()
            .filter(|(v, w)| initial_workers.contains_key(w) && initial_map.get(v) == Some(w))
            .count();
        let stability_pct_in = if vnodes_on_survivors_before > 0 {
            (stable_on_scale_in as f64 / vnodes_on_survivors_before as f64) * 100.0
        } else {
            100.0 // If no survivors, vacuously stable.
        };

        println!(
            "[Affinity] Actual Measured Result ({}->{}) -> Scale-Out: {:.2}%, Scale-In: {:.2}%",
            initial_worker_count, scaled_worker_count, stability_pct_out, stability_pct_in
        );

        if matches!(
            (capacity_mode, balanced_by),
            (CapacityMode::Unbounded, BalancedBy::ActorCounts)
        ) {
            // This mode gives no affinity guarantees, so there is nothing to assert.
            println!("[Affinity] Skipping affinity check for Unbounded/ActorCounts.");
        } else {
            println!("[Affinity] Expecting HIGH VNode affinity for this mode.");
            // Assert that the stability is close to the theoretical ceiling.
            assert!(
                (stability_pct_out - expected_ceiling_out).abs() < AFFINITY_TOLERANCE_PCT,
                "Expected stability on scale-out, close to {:.2}%, but got {:.2}%",
                expected_ceiling_out,
                stability_pct_out
            );
            assert!(
                (stability_pct_in - expected_ceiling_in).abs() < AFFINITY_TOLERANCE_PCT,
                "Expected stability on scale-in, close to {:.2}%, but got {:.2}%",
                expected_ceiling_in,
                stability_pct_in
            );
        }
        println!("[Affinity] Test passed.");
    }

    /// A helper function to calculate the precise, ideal affinity thresholds by
    /// simulating a direct, single-layer assignment.
    ///
    /// Returns `(scale_out_stability_pct, scale_in_stability_pct)`, computed the
    /// same way as the measured values in `run_affinity_check`.
    fn calculate_ideal_affinity_thresholds(
        salt: u64,
        initial_worker_count: u8,
        scaled_worker_count: u8,
        vnodes: &[u32],
    ) -> (f64, f64) {
        let initial_workers = create_workers(initial_worker_count);
        let scaled_workers = create_workers(scaled_worker_count);

        // Simulate direct VNode->Worker assignment to get the "ideal" maps.
        // This uses the core assignment function, which correctly handles remainders.
        let ideal_map_before =
            assign_items_weighted_with_scale_fn(&initial_workers, vnodes, salt, unbounded_scale);
        let ideal_map_after =
            assign_items_weighted_with_scale_fn(&scaled_workers, vnodes, salt, unbounded_scale);

        // Flatten the maps for easy comparison.
        let initial_flat_map: HashMap<u32, u8> = ideal_map_before
            .into_iter()
            .flat_map(|(k, vs)| vs.into_iter().map(move |v| (v, k)))
            .collect();
        let scaled_flat_map: HashMap<u32, u8> = ideal_map_after
            .into_iter()
            .flat_map(|(k, vs)| vs.into_iter().map(move |v| (v, k)))
            .collect();

        // Calculate ideal scale-out stability
        let stable_on_scale_out = initial_flat_map
            .iter()
            .filter(|(v, w)| scaled_flat_map.get(v) == Some(w))
            .count();
        let ideal_stability_out = (stable_on_scale_out as f64 / vnodes.len() as f64) * 100.0;

        // Calculate ideal scale-in stability
        let vnodes_on_survivors_before = scaled_flat_map
            .values()
            .filter(|w| initial_workers.contains_key(w))
            .count();
        let stable_on_scale_in = scaled_flat_map
            .iter()
            .filter(|(v, w)| initial_workers.contains_key(w) && initial_flat_map.get(v) == Some(w))
            .count();
        let ideal_stability_in = if vnodes_on_survivors_before > 0 {
            (stable_on_scale_in as f64 / vnodes_on_survivors_before as f64) * 100.0
        } else {
            100.0
        };

        (ideal_stability_out, ideal_stability_in)
    }

    /// Helper function to create a `BTreeMap` of workers with equal weights.
    ///
    /// Worker ids are `1..=count`, each with weight 1.
    fn create_workers(count: u8) -> BTreeMap<u8, NonZeroUsize> {
        (1..=count)
            .map(|i| (i, NonZeroUsize::new(1).unwrap()))
            .collect()
    }

    /// Helper extension trait to make builder setup cleaner in the test
    trait AssignerBuilderExt<S: Hash + Copy> {
        /// Configures the capacity mode and balancing strategy, then builds the assigner.
        fn build_with_strategies(
            &mut self,
            capacity: CapacityMode,
            balance: BalancedBy,
        ) -> Assigner<S>;
    }

    impl<S: Hash + Copy> AssignerBuilderExt<S> for AssignerBuilder<S> {
        fn build_with_strategies(
            &mut self,
            capacity: CapacityMode,
            balance: BalancedBy,
        ) -> Assigner<S> {
            match capacity {
                CapacityMode::Weighted => self.with_capacity_weighted(),
                CapacityMode::Unbounded => self.with_capacity_unbounded(),
            };
            match balance {
                BalancedBy::RawWorkerWeights => self.with_worker_oriented_balancing(),
                BalancedBy::ActorCounts => self.with_actor_oriented_balancing(),
            };
            self.build()
        }
    }
}