risingwave_meta/stream/source_manager/
split_assignment.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::anyhow;
16use itertools::Itertools;
17
18use super::*;
19use crate::model::{FragmentNewNoShuffle, FragmentReplaceUpstream, StreamJobFragments};
20
21impl SourceManager {
22    /// Migrates splits from previous actors to the new actors for a rescheduled fragment.
23    ///
24    /// Very occasionally split removal may happen during scaling, in which case we need to
25    /// use the old splits for reallocation instead of the latest splits (which may be missing),
26    /// so that we can resolve the split removal in the next command.
27    pub async fn migrate_splits_for_source_actors(
28        &self,
29        fragment_id: FragmentId,
30        prev_actor_ids: &[ActorId],
31        curr_actor_ids: &[ActorId],
32    ) -> MetaResult<HashMap<ActorId, Vec<SplitImpl>>> {
33        let core = self.core.lock().await;
34
35        let prev_splits = prev_actor_ids
36            .iter()
37            .flat_map(|actor_id| {
38                // Note: File Source / Iceberg Source doesn't have splits assigned by meta.
39                core.actor_splits.get(actor_id).cloned().unwrap_or_default()
40            })
41            .map(|split| (split.id(), split))
42            .collect();
43
44        let empty_actor_splits = curr_actor_ids
45            .iter()
46            .map(|actor_id| (*actor_id, vec![]))
47            .collect();
48
49        let diff = reassign_splits(
50            fragment_id,
51            empty_actor_splits,
52            &prev_splits,
53            // pre-allocate splits is the first time getting splits and it does not have scale-in scene
54            SplitDiffOptions::default(),
55        )
56        .unwrap_or_default();
57
58        Ok(diff)
59    }
60
61    /// Migrates splits from previous actors to the new actors for a rescheduled fragment.
62    pub fn migrate_splits_for_backfill_actors(
63        &self,
64        fragment_id: FragmentId,
65        upstream_source_fragment_id: FragmentId,
66        curr_actor_ids: &[ActorId],
67        fragment_actor_splits: &HashMap<FragmentId, HashMap<ActorId, Vec<SplitImpl>>>,
68        no_shuffle_upstream_actor_map: &HashMap<ActorId, HashMap<FragmentId, ActorId>>,
69    ) -> MetaResult<HashMap<ActorId, Vec<SplitImpl>>> {
70        // align splits for backfill fragments with its upstream source fragment
71        let actors = no_shuffle_upstream_actor_map
72            .iter()
73            .filter(|(id, _)| curr_actor_ids.contains(id))
74            .map(|(id, upstream_fragment_actors)| {
75                (
76                    *id,
77                    *upstream_fragment_actors
78                        .get(&upstream_source_fragment_id)
79                        .unwrap(),
80                )
81            });
82        let upstream_assignment = fragment_actor_splits
83            .get(&upstream_source_fragment_id)
84            .unwrap();
85        tracing::info!(
86            fragment_id,
87            upstream_source_fragment_id,
88            ?upstream_assignment,
89            "migrate_splits_for_backfill_actors"
90        );
91        Ok(align_splits(
92            actors,
93            upstream_assignment,
94            fragment_id,
95            upstream_source_fragment_id,
96        )?)
97    }
98
99    /// Allocates splits to actors for a newly created source executor.
100    pub async fn allocate_splits(
101        &self,
102        table_fragments: &StreamJobFragments,
103    ) -> MetaResult<SplitAssignment> {
104        let core = self.core.lock().await;
105
106        let source_fragments = table_fragments.stream_source_fragments();
107
108        let mut assigned = HashMap::new();
109
110        'loop_source: for (source_id, fragments) in source_fragments {
111            let handle = core
112                .managed_sources
113                .get(&source_id)
114                .with_context(|| format!("could not find source {}", source_id))?;
115
116            if handle.splits.lock().await.splits.is_none() {
117                handle.force_tick().await?;
118            }
119
120            for fragment_id in fragments {
121                let empty_actor_splits: HashMap<u32, Vec<SplitImpl>> = table_fragments
122                    .fragments
123                    .get(&fragment_id)
124                    .unwrap()
125                    .actors
126                    .iter()
127                    .map(|actor| (actor.actor_id, vec![]))
128                    .collect();
129                let actor_hashset: HashSet<u32> = empty_actor_splits.keys().cloned().collect();
130                let splits = handle.discovered_splits(source_id, &actor_hashset).await?;
131                if splits.is_empty() {
132                    tracing::warn!("no splits detected for source {}", source_id);
133                    continue 'loop_source;
134                }
135
136                if let Some(diff) = reassign_splits(
137                    fragment_id,
138                    empty_actor_splits,
139                    &splits,
140                    SplitDiffOptions::default(),
141                ) {
142                    assigned.insert(fragment_id, diff);
143                }
144            }
145        }
146
147        Ok(assigned)
148    }
149
150    /// Allocates splits to actors for replace source job.
151    pub async fn allocate_splits_for_replace_source(
152        &self,
153        table_fragments: &StreamJobFragments,
154        upstream_updates: &FragmentReplaceUpstream,
155        // new_no_shuffle:
156        //     upstream fragment_id ->
157        //     downstream fragment_id ->
158        //     upstream actor_id ->
159        //     downstream actor_id
160        new_no_shuffle: &FragmentNewNoShuffle,
161    ) -> MetaResult<SplitAssignment> {
162        tracing::debug!(?upstream_updates, "allocate_splits_for_replace_source");
163        if upstream_updates.is_empty() {
164            // no existing downstream. We can just re-allocate splits arbitrarily.
165            return self.allocate_splits(table_fragments).await;
166        }
167
168        let core = self.core.lock().await;
169
170        let source_fragments = table_fragments.stream_source_fragments();
171        assert_eq!(
172            source_fragments.len(),
173            1,
174            "replace source job should only have one source"
175        );
176        let (_source_id, fragments) = source_fragments.into_iter().next().unwrap();
177        assert_eq!(
178            fragments.len(),
179            1,
180            "replace source job should only have one fragment"
181        );
182        let fragment_id = fragments.into_iter().next().unwrap();
183
184        debug_assert!(
185            upstream_updates.values().flatten().next().is_some()
186                && upstream_updates
187                    .values()
188                    .flatten()
189                    .all(|(_, new_upstream_fragment_id)| {
190                        *new_upstream_fragment_id == fragment_id
191                    })
192                && upstream_updates
193                    .values()
194                    .flatten()
195                    .map(|(upstream_fragment_id, _)| upstream_fragment_id)
196                    .all_equal(),
197            "upstream update should only replace one fragment: {:?}",
198            upstream_updates
199        );
200        let prev_fragment_id = upstream_updates
201            .values()
202            .flatten()
203            .next()
204            .map(|(upstream_fragment_id, _)| *upstream_fragment_id)
205            .expect("non-empty");
206        // Here we align the new source executor to backfill executors
207        //
208        // old_source => new_source            backfill_1
209        // actor_x1   => actor_y1 -----┬------>actor_a1
210        // actor_x2   => actor_y2 -----┼-┬---->actor_a2
211        //                             │ │
212        //                             │ │     backfill_2
213        //                             └─┼---->actor_b1
214        //                               └---->actor_b2
215        //
216        // Note: we can choose any backfill actor to align here.
217        // We use `HashMap` to dedup.
218        let aligned_actors: HashMap<ActorId, ActorId> = new_no_shuffle
219            .get(&fragment_id)
220            .map(HashMap::values)
221            .into_iter()
222            .flatten()
223            .flatten()
224            .map(|(upstream_actor_id, actor_id)| (*upstream_actor_id, *actor_id))
225            .collect();
226        let assignment = align_splits(
227            aligned_actors.into_iter(),
228            &core.actor_splits,
229            fragment_id,
230            prev_fragment_id,
231        )?;
232        Ok(HashMap::from([(fragment_id, assignment)]))
233    }
234
235    /// Allocates splits to actors for a newly created `SourceBackfill` executor.
236    ///
237    /// Unlike [`Self::allocate_splits`], which creates a new assignment,
238    /// this method aligns the splits for backfill fragments with its upstream source fragment ([`align_splits`]).
239    pub async fn allocate_splits_for_backfill(
240        &self,
241        table_fragments: &StreamJobFragments,
242        upstream_new_no_shuffle: &FragmentNewNoShuffle,
243        upstream_actors: &HashMap<FragmentId, HashSet<ActorId>>,
244    ) -> MetaResult<SplitAssignment> {
245        let core = self.core.lock().await;
246
247        let source_backfill_fragments = table_fragments.source_backfill_fragments()?;
248
249        let mut assigned = HashMap::new();
250
251        for (_source_id, fragments) in source_backfill_fragments {
252            for (fragment_id, upstream_source_fragment_id) in fragments {
253                let upstream_actors = upstream_actors
254                    .get(&upstream_source_fragment_id)
255                    .ok_or_else(|| {
256                        anyhow!(
257                            "no upstream actors found from fragment {} to upstream source fragment {}",
258                            fragment_id,
259                            upstream_source_fragment_id
260                        )
261                    })?;
262                let mut backfill_actors = vec![];
263                let Some(source_new_no_shuffle) = upstream_new_no_shuffle
264                    .get(&upstream_source_fragment_id)
265                    .and_then(|source_upstream_new_no_shuffle| {
266                        source_upstream_new_no_shuffle.get(&fragment_id)
267                    })
268                else {
269                    return Err(anyhow::anyhow!(
270                            "source backfill fragment's upstream fragment should have one-on-one no_shuffle mapping, fragment_id: {fragment_id}, upstream_fragment_id: {upstream_source_fragment_id}, upstream_new_no_shuffle: {upstream_new_no_shuffle:?}",
271                            fragment_id = fragment_id,
272                            upstream_source_fragment_id = upstream_source_fragment_id,
273                            upstream_new_no_shuffle = upstream_new_no_shuffle,
274                        ).into());
275                };
276                for upstream_actor in upstream_actors {
277                    let Some(no_shuffle_backfill_actor) = source_new_no_shuffle.get(upstream_actor)
278                    else {
279                        return Err(anyhow::anyhow!(
280                            "source backfill fragment's upstream fragment should have one-on-one no_shuffle mapping, fragment_id: {fragment_id}, upstream_fragment_id: {upstream_source_fragment_id}, upstream_actor: {upstream_actor}, source_new_no_shuffle: {source_new_no_shuffle:?}",
281                            fragment_id = fragment_id,
282                            upstream_source_fragment_id = upstream_source_fragment_id,
283                            upstream_actor = upstream_actor,
284                            source_new_no_shuffle = source_new_no_shuffle
285                        ).into());
286                    };
287                    backfill_actors.push((*no_shuffle_backfill_actor, *upstream_actor));
288                }
289                assigned.insert(
290                    fragment_id,
291                    align_splits(
292                        backfill_actors,
293                        &core.actor_splits,
294                        fragment_id,
295                        upstream_source_fragment_id,
296                    )?,
297                );
298            }
299        }
300
301        Ok(assigned)
302    }
303}
304
305impl SourceManagerCore {
306    /// Checks whether the external source metadata has changed,
307    /// and re-assigns splits if there's a diff.
308    ///
309    /// `self.actor_splits` will not be updated. It will be updated by `Self::apply_source_change`,
310    /// after the mutation barrier has been collected.
311    pub async fn reassign_splits(&self) -> MetaResult<HashMap<DatabaseId, SplitAssignment>> {
312        let mut split_assignment: SplitAssignment = HashMap::new();
313
314        'loop_source: for (source_id, handle) in &self.managed_sources {
315            let source_fragment_ids = match self.source_fragments.get(source_id) {
316                Some(fragment_ids) if !fragment_ids.is_empty() => fragment_ids,
317                _ => {
318                    continue;
319                }
320            };
321            let backfill_fragment_ids = self.backfill_fragments.get(source_id);
322
323            'loop_fragment: for &fragment_id in source_fragment_ids {
324                let actors = match self
325                    .metadata_manager
326                    .get_running_actors_of_fragment(fragment_id)
327                    .await
328                {
329                    Ok(actors) => {
330                        if actors.is_empty() {
331                            tracing::warn!("No actors found for fragment {}", fragment_id);
332                            continue 'loop_fragment;
333                        }
334                        actors
335                    }
336                    Err(err) => {
337                        tracing::warn!(error = %err.as_report(), "Failed to get the actor of the fragment, maybe the fragment doesn't exist anymore");
338                        continue 'loop_fragment;
339                    }
340                };
341
342                let discovered_splits = handle.discovered_splits(*source_id, &actors).await?;
343                if discovered_splits.is_empty() {
344                    // The discover loop for this source is not ready yet; we'll wait for the next run
345                    continue 'loop_source;
346                }
347
348                let prev_actor_splits: HashMap<_, _> = actors
349                    .into_iter()
350                    .map(|actor_id| {
351                        (
352                            actor_id,
353                            self.actor_splits
354                                .get(&actor_id)
355                                .cloned()
356                                .unwrap_or_default(),
357                        )
358                    })
359                    .collect();
360
361                if let Some(new_assignment) = reassign_splits(
362                    fragment_id,
363                    prev_actor_splits,
364                    &discovered_splits,
365                    SplitDiffOptions {
366                        enable_scale_in: handle.enable_drop_split,
367                        enable_adaptive: handle.enable_adaptive_splits,
368                    },
369                ) {
370                    split_assignment.insert(fragment_id, new_assignment);
371                }
372            }
373
374            if let Some(backfill_fragment_ids) = backfill_fragment_ids {
375                // align splits for backfill fragments with its upstream source fragment
376                for (fragment_id, upstream_fragment_id) in backfill_fragment_ids {
377                    let Some(upstream_assignment) = split_assignment.get(upstream_fragment_id)
378                    else {
379                        // upstream fragment unchanged, do not update backfill fragment too
380                        continue;
381                    };
382                    let actors = match self
383                        .metadata_manager
384                        .get_running_actors_for_source_backfill(*fragment_id, *upstream_fragment_id)
385                        .await
386                    {
387                        Ok(actors) => {
388                            if actors.is_empty() {
389                                tracing::warn!("No actors found for fragment {}", fragment_id);
390                                continue;
391                            }
392                            actors
393                        }
394                        Err(err) => {
395                            tracing::warn!(error = %err.as_report(),"Failed to get the actor of the fragment, maybe the fragment doesn't exist anymore");
396                            continue;
397                        }
398                    };
399                    split_assignment.insert(
400                        *fragment_id,
401                        align_splits(
402                            actors,
403                            upstream_assignment,
404                            *fragment_id,
405                            *upstream_fragment_id,
406                        )?,
407                    );
408                }
409            }
410        }
411
412        self.metadata_manager
413            .split_fragment_map_by_database(split_assignment)
414            .await
415    }
416}
417
418/// Reassigns splits if there are new splits or dropped splits,
419/// i.e., `actor_splits` and `discovered_splits` differ, or actors are rescheduled.
420///
421/// The existing splits will remain unmoved in their currently assigned actor.
422///
423/// If an actor has an upstream actor, it should be a backfill executor,
424/// and its splits should be aligned with the upstream actor. **`reassign_splits` should not be used in this case.
425/// Use [`align_splits`] instead.**
426///
427/// - `fragment_id`: just for logging
428///
429/// ## Different connectors' behavior of split change
430///
431/// ### Kafka and Pulsar
432/// They only support increasing the number of splits via adding new empty splits.
433/// Old data is not moved.
434///
435/// ### Kinesis
436/// It supports *pairwise* shard split and merge.
437///
438/// In both cases, old data remain in the old shard(s) and the old shard is still available.
439/// New data are routed to the new shard(s).
440/// After the retention period has expired, the old shard will become `EXPIRED` and isn't
441/// listed any more. In other words, the total number of shards will first increase and then decrease.
442///
443/// See also:
444/// - [Kinesis resharding doc](https://docs.aws.amazon.com/streams/latest/dev/kinesis-using-sdk-java-after-resharding.html#kinesis-using-sdk-java-resharding-data-routing)
445/// - An example of how the shards can be like: <https://stackoverflow.com/questions/72272034/list-shard-show-more-shards-than-provisioned>
446fn reassign_splits<T>(
447    fragment_id: FragmentId,
448    actor_splits: HashMap<ActorId, Vec<T>>,
449    discovered_splits: &BTreeMap<SplitId, T>,
450    opts: SplitDiffOptions,
451) -> Option<HashMap<ActorId, Vec<T>>>
452where
453    T: SplitMetaData + Clone,
454{
455    // if no actors, return
456    if actor_splits.is_empty() {
457        return None;
458    }
459
460    let prev_split_ids: HashSet<_> = actor_splits
461        .values()
462        .flat_map(|splits| splits.iter().map(SplitMetaData::id))
463        .collect();
464
465    tracing::trace!(fragment_id, prev_split_ids = ?prev_split_ids, "previous splits");
466    tracing::trace!(fragment_id, prev_split_ids = ?discovered_splits.keys(), "discovered splits");
467
468    let discovered_split_ids: HashSet<_> = discovered_splits.keys().cloned().collect();
469
470    let dropped_splits: HashSet<_> = prev_split_ids
471        .difference(&discovered_split_ids)
472        .cloned()
473        .collect();
474
475    if !dropped_splits.is_empty() {
476        if opts.enable_scale_in {
477            tracing::info!(fragment_id, dropped_spltis = ?dropped_splits, "new dropped splits");
478        } else {
479            tracing::warn!(fragment_id, dropped_spltis = ?dropped_splits, "split dropping happened, but it is not allowed");
480        }
481    }
482
483    let new_discovered_splits: BTreeSet<_> = discovered_split_ids
484        .into_iter()
485        .filter(|split_id| !prev_split_ids.contains(split_id))
486        .collect();
487
488    if opts.enable_scale_in || opts.enable_adaptive {
489        // if we support scale in, no more splits are discovered, and no splits are dropped, return
490        // we need to check if discovered_split_ids is empty, because if it is empty, we need to
491        // handle the case of scale in to zero (like deleting all objects from s3)
492        if dropped_splits.is_empty()
493            && new_discovered_splits.is_empty()
494            && !discovered_splits.is_empty()
495        {
496            return None;
497        }
498    } else {
499        // if we do not support scale in, and no more splits are discovered, return
500        if new_discovered_splits.is_empty() && !discovered_splits.is_empty() {
501            return None;
502        }
503    }
504
505    tracing::info!(fragment_id, new_discovered_splits = ?new_discovered_splits, "new discovered splits");
506
507    let mut heap = BinaryHeap::with_capacity(actor_splits.len());
508
509    for (actor_id, mut splits) in actor_splits {
510        if opts.enable_scale_in || opts.enable_adaptive {
511            splits.retain(|split| !dropped_splits.contains(&split.id()));
512        }
513
514        heap.push(ActorSplitsAssignment { actor_id, splits })
515    }
516
517    for split_id in new_discovered_splits {
518        // ActorSplitsAssignment's Ord is reversed, so this is min heap, i.e.,
519        // we get the assignment with the least splits here.
520
521        // Note: If multiple actors have the same number of splits, it will be randomly picked.
522        // When the number of source actors is larger than the number of splits,
523        // It's possible that the assignment is uneven.
524        // e.g., https://github.com/risingwavelabs/risingwave/issues/14324#issuecomment-1875033158
525        // TODO: We should make the assignment rack-aware to make sure it's even.
526        let mut peek_ref = heap.peek_mut().unwrap();
527        peek_ref
528            .splits
529            .push(discovered_splits.get(&split_id).cloned().unwrap());
530    }
531
532    Some(
533        heap.into_iter()
534            .map(|ActorSplitsAssignment { actor_id, splits }| (actor_id, splits))
535            .collect(),
536    )
537}
538
539/// Assign splits to a new set of actors, according to existing assignment.
540///
541/// illustration:
542/// ```text
543/// upstream                               new
544/// actor x1 [split 1, split2]      ->     actor y1 [split 1, split2]
545/// actor x2 [split 3]              ->     actor y2 [split 3]
546/// ...
547/// ```
548fn align_splits(
549    // (actor_id, upstream_actor_id)
550    aligned_actors: impl IntoIterator<Item = (ActorId, ActorId)>,
551    existing_assignment: &HashMap<ActorId, Vec<SplitImpl>>,
552    fragment_id: FragmentId,
553    upstream_source_fragment_id: FragmentId,
554) -> anyhow::Result<HashMap<ActorId, Vec<SplitImpl>>> {
555    aligned_actors
556        .into_iter()
557        .map(|(actor_id, upstream_actor_id)| {
558            let Some(splits) = existing_assignment.get(&upstream_actor_id) else {
559                return Err(anyhow::anyhow!("upstream assignment not found, fragment_id: {fragment_id}, upstream_fragment_id: {upstream_source_fragment_id}, actor_id: {actor_id}, upstream_assignment: {existing_assignment:?}, upstream_actor_id: {upstream_actor_id:?}"));
560            };
561            Ok((
562                actor_id,
563                splits.clone(),
564            ))
565        })
566        .collect()
567}
568
569/// Note: the `PartialEq` and `Ord` impl just compares the number of splits.
570#[derive(Debug)]
571struct ActorSplitsAssignment<T: SplitMetaData> {
572    actor_id: ActorId,
573    splits: Vec<T>,
574}
575
576impl<T: SplitMetaData + Clone> Eq for ActorSplitsAssignment<T> {}
577
578impl<T: SplitMetaData + Clone> PartialEq<Self> for ActorSplitsAssignment<T> {
579    fn eq(&self, other: &Self) -> bool {
580        self.splits.len() == other.splits.len()
581    }
582}
583
584impl<T: SplitMetaData + Clone> PartialOrd<Self> for ActorSplitsAssignment<T> {
585    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
586        Some(self.cmp(other))
587    }
588}
589
590impl<T: SplitMetaData + Clone> Ord for ActorSplitsAssignment<T> {
591    fn cmp(&self, other: &Self) -> Ordering {
592        // Note: this is reversed order, to make BinaryHeap a min heap.
593        other.splits.len().cmp(&self.splits.len())
594    }
595}
596
/// Options controlling how split reassignment reacts to splits that are no longer discovered.
///
/// The default disables both options.
#[derive(Debug, Default)]
pub struct SplitDiffOptions {
    /// When enabled, splits that are no longer discovered are removed from their actors
    /// during reassignment. When disabled, such splits are kept.
    pub enable_scale_in: bool,

    /// For most connectors, this should be false. When enabled, RisingWave will not track any progress.
    pub enable_adaptive: bool,
}
614
615#[cfg(test)]
616mod tests {
617    use std::collections::{BTreeMap, HashMap, HashSet};
618
619    use risingwave_common::types::JsonbVal;
620    use risingwave_connector::error::ConnectorResult;
621    use risingwave_connector::source::{SplitId, SplitMetaData};
622    use serde::{Deserialize, Serialize};
623
624    use super::*;
625    use crate::model::{ActorId, FragmentId};
626
627    #[derive(Debug, Copy, Clone, Serialize, Deserialize)]
628    struct TestSplit {
629        id: u32,
630    }
631
632    impl SplitMetaData for TestSplit {
633        fn id(&self) -> SplitId {
634            format!("{}", self.id).into()
635        }
636
637        fn encode_to_json(&self) -> JsonbVal {
638            serde_json::to_value(*self).unwrap().into()
639        }
640
641        fn restore_from_json(value: JsonbVal) -> ConnectorResult<Self> {
642            serde_json::from_value(value.take()).map_err(Into::into)
643        }
644
645        fn update_offset(&mut self, _last_read_offset: String) -> ConnectorResult<()> {
646            Ok(())
647        }
648    }
649
650    fn check_all_splits(
651        discovered_splits: &BTreeMap<SplitId, TestSplit>,
652        diff: &HashMap<ActorId, Vec<TestSplit>>,
653    ) {
654        let mut split_ids: HashSet<_> = discovered_splits.keys().cloned().collect();
655
656        for splits in diff.values() {
657            for split in splits {
658                assert!(split_ids.remove(&split.id()))
659            }
660        }
661
662        assert!(split_ids.is_empty());
663    }
664
665    #[test]
666    fn test_drop_splits() {
667        let mut actor_splits: HashMap<ActorId, _> = HashMap::new();
668        actor_splits.insert(0, vec![TestSplit { id: 0 }, TestSplit { id: 1 }]);
669        actor_splits.insert(1, vec![TestSplit { id: 2 }, TestSplit { id: 3 }]);
670        actor_splits.insert(2, vec![TestSplit { id: 4 }, TestSplit { id: 5 }]);
671
672        let mut prev_split_to_actor = HashMap::new();
673        for (actor_id, splits) in &actor_splits {
674            for split in splits {
675                prev_split_to_actor.insert(split.id(), *actor_id);
676            }
677        }
678
679        let discovered_splits: BTreeMap<SplitId, TestSplit> = (1..5)
680            .map(|i| {
681                let split = TestSplit { id: i };
682                (split.id(), split)
683            })
684            .collect();
685
686        let opts = SplitDiffOptions {
687            enable_scale_in: true,
688            enable_adaptive: false,
689        };
690
691        let prev_split_ids: HashSet<_> = actor_splits
692            .values()
693            .flat_map(|splits| splits.iter().map(|split| split.id()))
694            .collect();
695
696        let diff = reassign_splits(
697            FragmentId::default(),
698            actor_splits,
699            &discovered_splits,
700            opts,
701        )
702        .unwrap();
703        check_all_splits(&discovered_splits, &diff);
704
705        let mut after_split_to_actor = HashMap::new();
706        for (actor_id, splits) in &diff {
707            for split in splits {
708                after_split_to_actor.insert(split.id(), *actor_id);
709            }
710        }
711
712        let discovered_split_ids: HashSet<_> = discovered_splits.keys().cloned().collect();
713
714        let retained_split_ids: HashSet<_> =
715            prev_split_ids.intersection(&discovered_split_ids).collect();
716
717        for retained_split_id in retained_split_ids {
718            assert_eq!(
719                prev_split_to_actor.get(retained_split_id),
720                after_split_to_actor.get(retained_split_id)
721            )
722        }
723    }
724
725    #[test]
726    fn test_drop_splits_to_empty() {
727        let mut actor_splits: HashMap<ActorId, _> = HashMap::new();
728        actor_splits.insert(0, vec![TestSplit { id: 0 }]);
729
730        let discovered_splits: BTreeMap<SplitId, TestSplit> = BTreeMap::new();
731
732        let opts = SplitDiffOptions {
733            enable_scale_in: true,
734            enable_adaptive: false,
735        };
736
737        let diff = reassign_splits(
738            FragmentId::default(),
739            actor_splits,
740            &discovered_splits,
741            opts,
742        )
743        .unwrap();
744
745        assert!(!diff.is_empty())
746    }
747
748    #[test]
749    fn test_reassign_splits() {
750        let actor_splits = HashMap::new();
751        let discovered_splits: BTreeMap<SplitId, TestSplit> = BTreeMap::new();
752        assert!(
753            reassign_splits(
754                FragmentId::default(),
755                actor_splits,
756                &discovered_splits,
757                Default::default()
758            )
759            .is_none()
760        );
761
762        let actor_splits = (0..3).map(|i| (i, vec![])).collect();
763        let discovered_splits: BTreeMap<SplitId, TestSplit> = BTreeMap::new();
764        let diff = reassign_splits(
765            FragmentId::default(),
766            actor_splits,
767            &discovered_splits,
768            Default::default(),
769        )
770        .unwrap();
771        assert_eq!(diff.len(), 3);
772        for splits in diff.values() {
773            assert!(splits.is_empty())
774        }
775
776        let actor_splits = (0..3).map(|i| (i, vec![])).collect();
777        let discovered_splits: BTreeMap<SplitId, TestSplit> = (0..3)
778            .map(|i| {
779                let split = TestSplit { id: i };
780                (split.id(), split)
781            })
782            .collect();
783
784        let diff = reassign_splits(
785            FragmentId::default(),
786            actor_splits,
787            &discovered_splits,
788            Default::default(),
789        )
790        .unwrap();
791        assert_eq!(diff.len(), 3);
792        for splits in diff.values() {
793            assert_eq!(splits.len(), 1);
794        }
795
796        check_all_splits(&discovered_splits, &diff);
797
798        let actor_splits = (0..3).map(|i| (i, vec![TestSplit { id: i }])).collect();
799        let discovered_splits: BTreeMap<SplitId, TestSplit> = (0..5)
800            .map(|i| {
801                let split = TestSplit { id: i };
802                (split.id(), split)
803            })
804            .collect();
805
806        let diff = reassign_splits(
807            FragmentId::default(),
808            actor_splits,
809            &discovered_splits,
810            Default::default(),
811        )
812        .unwrap();
813        assert_eq!(diff.len(), 3);
814        for splits in diff.values() {
815            let len = splits.len();
816            assert!(len == 1 || len == 2);
817        }
818
819        check_all_splits(&discovered_splits, &diff);
820
821        let mut actor_splits: HashMap<ActorId, Vec<TestSplit>> =
822            (0..3).map(|i| (i, vec![TestSplit { id: i }])).collect();
823        actor_splits.insert(3, vec![]);
824        actor_splits.insert(4, vec![]);
825
826        let discovered_splits: BTreeMap<SplitId, TestSplit> = (0..5)
827            .map(|i| {
828                let split = TestSplit { id: i };
829                (split.id(), split)
830            })
831            .collect();
832
833        let diff = reassign_splits(
834            FragmentId::default(),
835            actor_splits,
836            &discovered_splits,
837            Default::default(),
838        )
839        .unwrap();
840        assert_eq!(diff.len(), 5);
841        for splits in diff.values() {
842            assert_eq!(splits.len(), 1);
843        }
844
845        check_all_splits(&discovered_splits, &diff);
846    }
847}