risingwave_meta/stream/source_manager/
split_assignment.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::anyhow;
16use itertools::Itertools;
17
18use super::*;
19use crate::model::{FragmentNewNoShuffle, FragmentReplaceUpstream, StreamJobFragments};
20
21impl SourceManager {
22    /// Migrates splits from previous actors to the new actors for a rescheduled fragment.
23    ///
24    /// Very occasionally split removal may happen during scaling, in which case we need to
25    /// use the old splits for reallocation instead of the latest splits (which may be missing),
26    /// so that we can resolve the split removal in the next command.
27    pub async fn migrate_splits_for_source_actors(
28        &self,
29        fragment_id: FragmentId,
30        prev_actor_ids: &[ActorId],
31        curr_actor_ids: &[ActorId],
32    ) -> MetaResult<HashMap<ActorId, Vec<SplitImpl>>> {
33        let core = self.core.lock().await;
34
35        let prev_splits = prev_actor_ids
36            .iter()
37            .flat_map(|actor_id| {
38                // Note: File Source / Iceberg Source doesn't have splits assigned by meta.
39                core.actor_splits.get(actor_id).cloned().unwrap_or_default()
40            })
41            .map(|split| (split.id(), split))
42            .collect();
43
44        let empty_actor_splits = curr_actor_ids
45            .iter()
46            .map(|actor_id| (*actor_id, vec![]))
47            .collect();
48
49        let diff = reassign_splits(
50            fragment_id,
51            empty_actor_splits,
52            &prev_splits,
53            // pre-allocate splits is the first time getting splits and it does not have scale-in scene
54            SplitDiffOptions::default(),
55        )
56        .unwrap_or_default();
57
58        Ok(diff)
59    }
60
61    /// Migrates splits from previous actors to the new actors for a rescheduled fragment.
62    pub fn migrate_splits_for_backfill_actors(
63        &self,
64        fragment_id: FragmentId,
65        upstream_source_fragment_id: FragmentId,
66        curr_actor_ids: &[ActorId],
67        fragment_actor_splits: &HashMap<FragmentId, HashMap<ActorId, Vec<SplitImpl>>>,
68        no_shuffle_upstream_actor_map: &HashMap<ActorId, HashMap<FragmentId, ActorId>>,
69    ) -> MetaResult<HashMap<ActorId, Vec<SplitImpl>>> {
70        // align splits for backfill fragments with its upstream source fragment
71        let actors = no_shuffle_upstream_actor_map
72            .iter()
73            .filter(|(id, _)| curr_actor_ids.contains(id))
74            .map(|(id, upstream_fragment_actors)| {
75                (
76                    *id,
77                    *upstream_fragment_actors
78                        .get(&upstream_source_fragment_id)
79                        .unwrap(),
80                )
81            });
82        let upstream_assignment = fragment_actor_splits
83            .get(&upstream_source_fragment_id)
84            .unwrap();
85        tracing::info!(
86            fragment_id,
87            upstream_source_fragment_id,
88            ?upstream_assignment,
89            "migrate_splits_for_backfill_actors"
90        );
91        Ok(align_splits(
92            actors,
93            upstream_assignment,
94            fragment_id,
95            upstream_source_fragment_id,
96        )?)
97    }
98
99    /// Allocates splits to actors for a newly created source executor.
100    #[await_tree::instrument]
101    pub async fn allocate_splits(
102        &self,
103        table_fragments: &StreamJobFragments,
104    ) -> MetaResult<SplitAssignment> {
105        let core = self.core.lock().await;
106
107        let source_fragments = table_fragments.stream_source_fragments();
108
109        let mut assigned = HashMap::new();
110
111        'loop_source: for (source_id, fragments) in source_fragments {
112            let handle = core
113                .managed_sources
114                .get(&source_id)
115                .with_context(|| format!("could not find source {}", source_id))?;
116
117            if handle.splits.lock().await.splits.is_none() {
118                handle.force_tick().await?;
119            }
120
121            for fragment_id in fragments {
122                let empty_actor_splits: HashMap<u32, Vec<SplitImpl>> = table_fragments
123                    .fragments
124                    .get(&fragment_id)
125                    .unwrap()
126                    .actors
127                    .iter()
128                    .map(|actor| (actor.actor_id, vec![]))
129                    .collect();
130                let actor_hashset: HashSet<u32> = empty_actor_splits.keys().cloned().collect();
131                let splits = handle.discovered_splits(source_id, &actor_hashset).await?;
132                if splits.is_empty() {
133                    tracing::warn!("no splits detected for source {}", source_id);
134                    continue 'loop_source;
135                }
136
137                if let Some(diff) = reassign_splits(
138                    fragment_id,
139                    empty_actor_splits,
140                    &splits,
141                    SplitDiffOptions::default(),
142                ) {
143                    assigned.insert(fragment_id, diff);
144                }
145            }
146        }
147
148        Ok(assigned)
149    }
150
151    /// Allocates splits to actors for replace source job.
152    pub async fn allocate_splits_for_replace_source(
153        &self,
154        table_fragments: &StreamJobFragments,
155        upstream_updates: &FragmentReplaceUpstream,
156        // new_no_shuffle:
157        //     upstream fragment_id ->
158        //     downstream fragment_id ->
159        //     upstream actor_id ->
160        //     downstream actor_id
161        new_no_shuffle: &FragmentNewNoShuffle,
162    ) -> MetaResult<SplitAssignment> {
163        tracing::debug!(?upstream_updates, "allocate_splits_for_replace_source");
164        if upstream_updates.is_empty() {
165            // no existing downstream. We can just re-allocate splits arbitrarily.
166            return self.allocate_splits(table_fragments).await;
167        }
168
169        let core = self.core.lock().await;
170
171        let source_fragments = table_fragments.stream_source_fragments();
172        assert_eq!(
173            source_fragments.len(),
174            1,
175            "replace source job should only have one source"
176        );
177        let (_source_id, fragments) = source_fragments.into_iter().next().unwrap();
178        assert_eq!(
179            fragments.len(),
180            1,
181            "replace source job should only have one fragment"
182        );
183        let fragment_id = fragments.into_iter().next().unwrap();
184
185        debug_assert!(
186            upstream_updates.values().flatten().next().is_some()
187                && upstream_updates
188                    .values()
189                    .flatten()
190                    .all(|(_, new_upstream_fragment_id)| {
191                        *new_upstream_fragment_id == fragment_id
192                    })
193                && upstream_updates
194                    .values()
195                    .flatten()
196                    .map(|(upstream_fragment_id, _)| upstream_fragment_id)
197                    .all_equal(),
198            "upstream update should only replace one fragment: {:?}",
199            upstream_updates
200        );
201        let prev_fragment_id = upstream_updates
202            .values()
203            .flatten()
204            .next()
205            .map(|(upstream_fragment_id, _)| *upstream_fragment_id)
206            .expect("non-empty");
207        // Here we align the new source executor to backfill executors
208        //
209        // old_source => new_source            backfill_1
210        // actor_x1   => actor_y1 -----┬------>actor_a1
211        // actor_x2   => actor_y2 -----┼-┬---->actor_a2
212        //                             │ │
213        //                             │ │     backfill_2
214        //                             └─┼---->actor_b1
215        //                               └---->actor_b2
216        //
217        // Note: we can choose any backfill actor to align here.
218        // We use `HashMap` to dedup.
219        let aligned_actors: HashMap<ActorId, ActorId> = new_no_shuffle
220            .get(&fragment_id)
221            .map(HashMap::values)
222            .into_iter()
223            .flatten()
224            .flatten()
225            .map(|(upstream_actor_id, actor_id)| (*upstream_actor_id, *actor_id))
226            .collect();
227        let assignment = align_splits(
228            aligned_actors.into_iter(),
229            &core.actor_splits,
230            fragment_id,
231            prev_fragment_id,
232        )?;
233        Ok(HashMap::from([(fragment_id, assignment)]))
234    }
235
236    /// Allocates splits to actors for a newly created `SourceBackfill` executor.
237    ///
238    /// Unlike [`Self::allocate_splits`], which creates a new assignment,
239    /// this method aligns the splits for backfill fragments with its upstream source fragment ([`align_splits`]).
240    #[await_tree::instrument]
241    pub async fn allocate_splits_for_backfill(
242        &self,
243        table_fragments: &StreamJobFragments,
244        upstream_new_no_shuffle: &FragmentNewNoShuffle,
245        upstream_actors: &HashMap<FragmentId, HashSet<ActorId>>,
246    ) -> MetaResult<SplitAssignment> {
247        let core = self.core.lock().await;
248
249        let source_backfill_fragments = table_fragments.source_backfill_fragments()?;
250
251        let mut assigned = HashMap::new();
252
253        for (_source_id, fragments) in source_backfill_fragments {
254            for (fragment_id, upstream_source_fragment_id) in fragments {
255                let upstream_actors = upstream_actors
256                    .get(&upstream_source_fragment_id)
257                    .ok_or_else(|| {
258                        anyhow!(
259                            "no upstream actors found from fragment {} to upstream source fragment {}",
260                            fragment_id,
261                            upstream_source_fragment_id
262                        )
263                    })?;
264                let mut backfill_actors = vec![];
265                let Some(source_new_no_shuffle) = upstream_new_no_shuffle
266                    .get(&upstream_source_fragment_id)
267                    .and_then(|source_upstream_new_no_shuffle| {
268                        source_upstream_new_no_shuffle.get(&fragment_id)
269                    })
270                else {
271                    return Err(anyhow::anyhow!(
272                            "source backfill fragment's upstream fragment should have one-on-one no_shuffle mapping, fragment_id: {fragment_id}, upstream_fragment_id: {upstream_source_fragment_id}, upstream_new_no_shuffle: {upstream_new_no_shuffle:?}",
273                            fragment_id = fragment_id,
274                            upstream_source_fragment_id = upstream_source_fragment_id,
275                            upstream_new_no_shuffle = upstream_new_no_shuffle,
276                        ).into());
277                };
278                for upstream_actor in upstream_actors {
279                    let Some(no_shuffle_backfill_actor) = source_new_no_shuffle.get(upstream_actor)
280                    else {
281                        return Err(anyhow::anyhow!(
282                            "source backfill fragment's upstream fragment should have one-on-one no_shuffle mapping, fragment_id: {fragment_id}, upstream_fragment_id: {upstream_source_fragment_id}, upstream_actor: {upstream_actor}, source_new_no_shuffle: {source_new_no_shuffle:?}",
283                            fragment_id = fragment_id,
284                            upstream_source_fragment_id = upstream_source_fragment_id,
285                            upstream_actor = upstream_actor,
286                            source_new_no_shuffle = source_new_no_shuffle
287                        ).into());
288                    };
289                    backfill_actors.push((*no_shuffle_backfill_actor, *upstream_actor));
290                }
291                assigned.insert(
292                    fragment_id,
293                    align_splits(
294                        backfill_actors,
295                        &core.actor_splits,
296                        fragment_id,
297                        upstream_source_fragment_id,
298                    )?,
299                );
300            }
301        }
302
303        Ok(assigned)
304    }
305}
306
307impl SourceManagerCore {
308    /// Checks whether the external source metadata has changed,
309    /// and re-assigns splits if there's a diff.
310    ///
311    /// `self.actor_splits` will not be updated. It will be updated by `Self::apply_source_change`,
312    /// after the mutation barrier has been collected.
313    pub async fn reassign_splits(&self) -> MetaResult<HashMap<DatabaseId, SplitAssignment>> {
314        let mut split_assignment: SplitAssignment = HashMap::new();
315
316        'loop_source: for (source_id, handle) in &self.managed_sources {
317            let source_fragment_ids = match self.source_fragments.get(source_id) {
318                Some(fragment_ids) if !fragment_ids.is_empty() => fragment_ids,
319                _ => {
320                    continue;
321                }
322            };
323            let backfill_fragment_ids = self.backfill_fragments.get(source_id);
324
325            'loop_fragment: for &fragment_id in source_fragment_ids {
326                let actors = match self
327                    .metadata_manager
328                    .get_running_actors_of_fragment(fragment_id)
329                    .await
330                {
331                    Ok(actors) => {
332                        if actors.is_empty() {
333                            tracing::warn!("No actors found for fragment {}", fragment_id);
334                            continue 'loop_fragment;
335                        }
336                        actors
337                    }
338                    Err(err) => {
339                        tracing::warn!(error = %err.as_report(), "Failed to get the actor of the fragment, maybe the fragment doesn't exist anymore");
340                        continue 'loop_fragment;
341                    }
342                };
343
344                let discovered_splits = handle.discovered_splits(*source_id, &actors).await?;
345                if discovered_splits.is_empty() {
346                    // The discover loop for this source is not ready yet; we'll wait for the next run
347                    continue 'loop_source;
348                }
349
350                let prev_actor_splits: HashMap<_, _> = actors
351                    .into_iter()
352                    .map(|actor_id| {
353                        (
354                            actor_id,
355                            self.actor_splits
356                                .get(&actor_id)
357                                .cloned()
358                                .unwrap_or_default(),
359                        )
360                    })
361                    .collect();
362
363                if let Some(new_assignment) = reassign_splits(
364                    fragment_id,
365                    prev_actor_splits,
366                    &discovered_splits,
367                    SplitDiffOptions {
368                        enable_scale_in: handle.enable_drop_split,
369                        enable_adaptive: handle.enable_adaptive_splits,
370                    },
371                ) {
372                    split_assignment.insert(fragment_id, new_assignment);
373                }
374            }
375
376            if let Some(backfill_fragment_ids) = backfill_fragment_ids {
377                // align splits for backfill fragments with its upstream source fragment
378                for (fragment_id, upstream_fragment_id) in backfill_fragment_ids {
379                    let Some(upstream_assignment) = split_assignment.get(upstream_fragment_id)
380                    else {
381                        // upstream fragment unchanged, do not update backfill fragment too
382                        continue;
383                    };
384                    let actors = match self
385                        .metadata_manager
386                        .get_running_actors_for_source_backfill(*fragment_id, *upstream_fragment_id)
387                        .await
388                    {
389                        Ok(actors) => {
390                            if actors.is_empty() {
391                                tracing::warn!("No actors found for fragment {}", fragment_id);
392                                continue;
393                            }
394                            actors
395                        }
396                        Err(err) => {
397                            tracing::warn!(error = %err.as_report(),"Failed to get the actor of the fragment, maybe the fragment doesn't exist anymore");
398                            continue;
399                        }
400                    };
401                    split_assignment.insert(
402                        *fragment_id,
403                        align_splits(
404                            actors,
405                            upstream_assignment,
406                            *fragment_id,
407                            *upstream_fragment_id,
408                        )?,
409                    );
410                }
411            }
412        }
413
414        self.metadata_manager
415            .split_fragment_map_by_database(split_assignment)
416            .await
417    }
418}
419
420/// Reassigns splits if there are new splits or dropped splits,
421/// i.e., `actor_splits` and `discovered_splits` differ, or actors are rescheduled.
422///
423/// The existing splits will remain unmoved in their currently assigned actor.
424///
425/// If an actor has an upstream actor, it should be a backfill executor,
426/// and its splits should be aligned with the upstream actor. **`reassign_splits` should not be used in this case.
427/// Use [`align_splits`] instead.**
428///
429/// - `fragment_id`: just for logging
430///
431/// ## Different connectors' behavior of split change
432///
433/// ### Kafka and Pulsar
434/// They only support increasing the number of splits via adding new empty splits.
435/// Old data is not moved.
436///
437/// ### Kinesis
438/// It supports *pairwise* shard split and merge.
439///
440/// In both cases, old data remain in the old shard(s) and the old shard is still available.
441/// New data are routed to the new shard(s).
442/// After the retention period has expired, the old shard will become `EXPIRED` and isn't
443/// listed any more. In other words, the total number of shards will first increase and then decrease.
444///
445/// See also:
446/// - [Kinesis resharding doc](https://docs.aws.amazon.com/streams/latest/dev/kinesis-using-sdk-java-after-resharding.html#kinesis-using-sdk-java-resharding-data-routing)
447/// - An example of how the shards can be like: <https://stackoverflow.com/questions/72272034/list-shard-show-more-shards-than-provisioned>
448fn reassign_splits<T>(
449    fragment_id: FragmentId,
450    actor_splits: HashMap<ActorId, Vec<T>>,
451    discovered_splits: &BTreeMap<SplitId, T>,
452    opts: SplitDiffOptions,
453) -> Option<HashMap<ActorId, Vec<T>>>
454where
455    T: SplitMetaData + Clone,
456{
457    // if no actors, return
458    if actor_splits.is_empty() {
459        return None;
460    }
461
462    let prev_split_ids: HashSet<_> = actor_splits
463        .values()
464        .flat_map(|splits| splits.iter().map(SplitMetaData::id))
465        .collect();
466
467    tracing::trace!(fragment_id, prev_split_ids = ?prev_split_ids, "previous splits");
468    tracing::trace!(fragment_id, prev_split_ids = ?discovered_splits.keys(), "discovered splits");
469
470    let discovered_split_ids: HashSet<_> = discovered_splits.keys().cloned().collect();
471
472    let dropped_splits: HashSet<_> = prev_split_ids
473        .difference(&discovered_split_ids)
474        .cloned()
475        .collect();
476
477    if !dropped_splits.is_empty() {
478        if opts.enable_scale_in {
479            tracing::info!(fragment_id, dropped_spltis = ?dropped_splits, "new dropped splits");
480        } else {
481            tracing::warn!(fragment_id, dropped_spltis = ?dropped_splits, "split dropping happened, but it is not allowed");
482        }
483    }
484
485    let new_discovered_splits: BTreeSet<_> = discovered_split_ids
486        .into_iter()
487        .filter(|split_id| !prev_split_ids.contains(split_id))
488        .collect();
489
490    if opts.enable_scale_in || opts.enable_adaptive {
491        // if we support scale in, no more splits are discovered, and no splits are dropped, return
492        // we need to check if discovered_split_ids is empty, because if it is empty, we need to
493        // handle the case of scale in to zero (like deleting all objects from s3)
494        if dropped_splits.is_empty()
495            && new_discovered_splits.is_empty()
496            && !discovered_splits.is_empty()
497        {
498            return None;
499        }
500    } else {
501        // if we do not support scale in, and no more splits are discovered, return
502        if new_discovered_splits.is_empty() && !discovered_splits.is_empty() {
503            return None;
504        }
505    }
506
507    tracing::info!(fragment_id, new_discovered_splits = ?new_discovered_splits, "new discovered splits");
508
509    let mut heap = BinaryHeap::with_capacity(actor_splits.len());
510
511    for (actor_id, mut splits) in actor_splits {
512        if opts.enable_scale_in || opts.enable_adaptive {
513            splits.retain(|split| !dropped_splits.contains(&split.id()));
514        }
515
516        heap.push(ActorSplitsAssignment { actor_id, splits })
517    }
518
519    for split_id in new_discovered_splits {
520        // ActorSplitsAssignment's Ord is reversed, so this is min heap, i.e.,
521        // we get the assignment with the least splits here.
522
523        // Note: If multiple actors have the same number of splits, it will be randomly picked.
524        // When the number of source actors is larger than the number of splits,
525        // It's possible that the assignment is uneven.
526        // e.g., https://github.com/risingwavelabs/risingwave/issues/14324#issuecomment-1875033158
527        // TODO: We should make the assignment rack-aware to make sure it's even.
528        let mut peek_ref = heap.peek_mut().unwrap();
529        peek_ref
530            .splits
531            .push(discovered_splits.get(&split_id).cloned().unwrap());
532    }
533
534    Some(
535        heap.into_iter()
536            .map(|ActorSplitsAssignment { actor_id, splits }| (actor_id, splits))
537            .collect(),
538    )
539}
540
541/// Assign splits to a new set of actors, according to existing assignment.
542///
543/// illustration:
544/// ```text
545/// upstream                               new
546/// actor x1 [split 1, split2]      ->     actor y1 [split 1, split2]
547/// actor x2 [split 3]              ->     actor y2 [split 3]
548/// ...
549/// ```
550fn align_splits(
551    // (actor_id, upstream_actor_id)
552    aligned_actors: impl IntoIterator<Item = (ActorId, ActorId)>,
553    existing_assignment: &HashMap<ActorId, Vec<SplitImpl>>,
554    fragment_id: FragmentId,
555    upstream_source_fragment_id: FragmentId,
556) -> anyhow::Result<HashMap<ActorId, Vec<SplitImpl>>> {
557    aligned_actors
558        .into_iter()
559        .map(|(actor_id, upstream_actor_id)| {
560            let Some(splits) = existing_assignment.get(&upstream_actor_id) else {
561                return Err(anyhow::anyhow!("upstream assignment not found, fragment_id: {fragment_id}, upstream_fragment_id: {upstream_source_fragment_id}, actor_id: {actor_id}, upstream_assignment: {existing_assignment:?}, upstream_actor_id: {upstream_actor_id:?}"));
562            };
563            Ok((
564                actor_id,
565                splits.clone(),
566            ))
567        })
568        .collect()
569}
570
571/// Note: the `PartialEq` and `Ord` impl just compares the number of splits.
572#[derive(Debug)]
573struct ActorSplitsAssignment<T: SplitMetaData> {
574    actor_id: ActorId,
575    splits: Vec<T>,
576}
577
578impl<T: SplitMetaData + Clone> Eq for ActorSplitsAssignment<T> {}
579
580impl<T: SplitMetaData + Clone> PartialEq<Self> for ActorSplitsAssignment<T> {
581    fn eq(&self, other: &Self) -> bool {
582        self.splits.len() == other.splits.len()
583    }
584}
585
586impl<T: SplitMetaData + Clone> PartialOrd<Self> for ActorSplitsAssignment<T> {
587    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
588        Some(self.cmp(other))
589    }
590}
591
592impl<T: SplitMetaData + Clone> Ord for ActorSplitsAssignment<T> {
593    fn cmp(&self, other: &Self) -> Ordering {
594        // Note: this is reversed order, to make BinaryHeap a min heap.
595        other.splits.len().cmp(&self.splits.len())
596    }
597}
598
/// Options for [`reassign_splits`].
///
/// The default disables both options.
#[derive(Debug, Default)]
pub struct SplitDiffOptions {
    /// When enabled, splits that are assigned but no longer discovered are removed from the
    /// assignment. Defaults to `false`.
    pub enable_scale_in: bool,

    /// For most connectors, this should be false. When enabled, RisingWave will not track any progress.
    /// Defaults to `false`.
    pub enable_adaptive: bool,
}
616
617#[cfg(test)]
618mod tests {
619    use std::collections::{BTreeMap, HashMap, HashSet};
620
621    use risingwave_common::types::JsonbVal;
622    use risingwave_connector::error::ConnectorResult;
623    use risingwave_connector::source::{SplitId, SplitMetaData};
624    use serde::{Deserialize, Serialize};
625
626    use super::*;
627    use crate::model::{ActorId, FragmentId};
628
629    #[derive(Debug, Copy, Clone, Serialize, Deserialize)]
630    struct TestSplit {
631        id: u32,
632    }
633
634    impl SplitMetaData for TestSplit {
635        fn id(&self) -> SplitId {
636            format!("{}", self.id).into()
637        }
638
639        fn encode_to_json(&self) -> JsonbVal {
640            serde_json::to_value(*self).unwrap().into()
641        }
642
643        fn restore_from_json(value: JsonbVal) -> ConnectorResult<Self> {
644            serde_json::from_value(value.take()).map_err(Into::into)
645        }
646
647        fn update_offset(&mut self, _last_read_offset: String) -> ConnectorResult<()> {
648            Ok(())
649        }
650    }
651
652    fn check_all_splits(
653        discovered_splits: &BTreeMap<SplitId, TestSplit>,
654        diff: &HashMap<ActorId, Vec<TestSplit>>,
655    ) {
656        let mut split_ids: HashSet<_> = discovered_splits.keys().cloned().collect();
657
658        for splits in diff.values() {
659            for split in splits {
660                assert!(split_ids.remove(&split.id()))
661            }
662        }
663
664        assert!(split_ids.is_empty());
665    }
666
667    #[test]
668    fn test_drop_splits() {
669        let mut actor_splits: HashMap<ActorId, _> = HashMap::new();
670        actor_splits.insert(0, vec![TestSplit { id: 0 }, TestSplit { id: 1 }]);
671        actor_splits.insert(1, vec![TestSplit { id: 2 }, TestSplit { id: 3 }]);
672        actor_splits.insert(2, vec![TestSplit { id: 4 }, TestSplit { id: 5 }]);
673
674        let mut prev_split_to_actor = HashMap::new();
675        for (actor_id, splits) in &actor_splits {
676            for split in splits {
677                prev_split_to_actor.insert(split.id(), *actor_id);
678            }
679        }
680
681        let discovered_splits: BTreeMap<SplitId, TestSplit> = (1..5)
682            .map(|i| {
683                let split = TestSplit { id: i };
684                (split.id(), split)
685            })
686            .collect();
687
688        let opts = SplitDiffOptions {
689            enable_scale_in: true,
690            enable_adaptive: false,
691        };
692
693        let prev_split_ids: HashSet<_> = actor_splits
694            .values()
695            .flat_map(|splits| splits.iter().map(|split| split.id()))
696            .collect();
697
698        let diff = reassign_splits(
699            FragmentId::default(),
700            actor_splits,
701            &discovered_splits,
702            opts,
703        )
704        .unwrap();
705        check_all_splits(&discovered_splits, &diff);
706
707        let mut after_split_to_actor = HashMap::new();
708        for (actor_id, splits) in &diff {
709            for split in splits {
710                after_split_to_actor.insert(split.id(), *actor_id);
711            }
712        }
713
714        let discovered_split_ids: HashSet<_> = discovered_splits.keys().cloned().collect();
715
716        let retained_split_ids: HashSet<_> =
717            prev_split_ids.intersection(&discovered_split_ids).collect();
718
719        for retained_split_id in retained_split_ids {
720            assert_eq!(
721                prev_split_to_actor.get(retained_split_id),
722                after_split_to_actor.get(retained_split_id)
723            )
724        }
725    }
726
727    #[test]
728    fn test_drop_splits_to_empty() {
729        let mut actor_splits: HashMap<ActorId, _> = HashMap::new();
730        actor_splits.insert(0, vec![TestSplit { id: 0 }]);
731
732        let discovered_splits: BTreeMap<SplitId, TestSplit> = BTreeMap::new();
733
734        let opts = SplitDiffOptions {
735            enable_scale_in: true,
736            enable_adaptive: false,
737        };
738
739        let diff = reassign_splits(
740            FragmentId::default(),
741            actor_splits,
742            &discovered_splits,
743            opts,
744        )
745        .unwrap();
746
747        assert!(!diff.is_empty())
748    }
749
750    #[test]
751    fn test_reassign_splits() {
752        let actor_splits = HashMap::new();
753        let discovered_splits: BTreeMap<SplitId, TestSplit> = BTreeMap::new();
754        assert!(
755            reassign_splits(
756                FragmentId::default(),
757                actor_splits,
758                &discovered_splits,
759                Default::default()
760            )
761            .is_none()
762        );
763
764        let actor_splits = (0..3).map(|i| (i, vec![])).collect();
765        let discovered_splits: BTreeMap<SplitId, TestSplit> = BTreeMap::new();
766        let diff = reassign_splits(
767            FragmentId::default(),
768            actor_splits,
769            &discovered_splits,
770            Default::default(),
771        )
772        .unwrap();
773        assert_eq!(diff.len(), 3);
774        for splits in diff.values() {
775            assert!(splits.is_empty())
776        }
777
778        let actor_splits = (0..3).map(|i| (i, vec![])).collect();
779        let discovered_splits: BTreeMap<SplitId, TestSplit> = (0..3)
780            .map(|i| {
781                let split = TestSplit { id: i };
782                (split.id(), split)
783            })
784            .collect();
785
786        let diff = reassign_splits(
787            FragmentId::default(),
788            actor_splits,
789            &discovered_splits,
790            Default::default(),
791        )
792        .unwrap();
793        assert_eq!(diff.len(), 3);
794        for splits in diff.values() {
795            assert_eq!(splits.len(), 1);
796        }
797
798        check_all_splits(&discovered_splits, &diff);
799
800        let actor_splits = (0..3).map(|i| (i, vec![TestSplit { id: i }])).collect();
801        let discovered_splits: BTreeMap<SplitId, TestSplit> = (0..5)
802            .map(|i| {
803                let split = TestSplit { id: i };
804                (split.id(), split)
805            })
806            .collect();
807
808        let diff = reassign_splits(
809            FragmentId::default(),
810            actor_splits,
811            &discovered_splits,
812            Default::default(),
813        )
814        .unwrap();
815        assert_eq!(diff.len(), 3);
816        for splits in diff.values() {
817            let len = splits.len();
818            assert!(len == 1 || len == 2);
819        }
820
821        check_all_splits(&discovered_splits, &diff);
822
823        let mut actor_splits: HashMap<ActorId, Vec<TestSplit>> =
824            (0..3).map(|i| (i, vec![TestSplit { id: i }])).collect();
825        actor_splits.insert(3, vec![]);
826        actor_splits.insert(4, vec![]);
827
828        let discovered_splits: BTreeMap<SplitId, TestSplit> = (0..5)
829            .map(|i| {
830                let split = TestSplit { id: i };
831                (split.id(), split)
832            })
833            .collect();
834
835        let diff = reassign_splits(
836            FragmentId::default(),
837            actor_splits,
838            &discovered_splits,
839            Default::default(),
840        )
841        .unwrap();
842        assert_eq!(diff.len(), 5);
843        for splits in diff.values() {
844            assert_eq!(splits.len(), 1);
845        }
846
847        check_all_splits(&discovered_splits, &diff);
848    }
849}