risingwave_meta/stream/source_manager/
split_assignment.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use anyhow::anyhow;
16use itertools::Itertools;
17
18use super::*;
19use crate::model::{FragmentNewNoShuffle, FragmentReplaceUpstream, StreamJobFragments};
20
21impl SourceManager {
22    /// Migrates splits from previous actors to the new actors for a rescheduled fragment.
23    ///
24    /// Very occasionally split removal may happen during scaling, in which case we need to
25    /// use the old splits for reallocation instead of the latest splits (which may be missing),
26    /// so that we can resolve the split removal in the next command.
27    pub async fn migrate_splits_for_source_actors(
28        &self,
29        fragment_id: FragmentId,
30        prev_actor_ids: &[ActorId],
31        curr_actor_ids: &[ActorId],
32    ) -> MetaResult<HashMap<ActorId, Vec<SplitImpl>>> {
33        let core = self.core.lock().await;
34
35        let prev_splits = prev_actor_ids
36            .iter()
37            .flat_map(|actor_id| {
38                // Note: File Source / Iceberg Source doesn't have splits assigned by meta.
39                core.actor_splits.get(actor_id).cloned().unwrap_or_default()
40            })
41            .map(|split| (split.id(), split))
42            .collect();
43
44        let empty_actor_splits = curr_actor_ids
45            .iter()
46            .map(|actor_id| (*actor_id, vec![]))
47            .collect();
48
49        let diff = reassign_splits(
50            fragment_id,
51            empty_actor_splits,
52            &prev_splits,
53            // pre-allocate splits is the first time getting splits and it does not have scale-in scene
54            SplitDiffOptions::default(),
55        )
56        .unwrap_or_default();
57
58        Ok(diff)
59    }
60
61    /// Migrates splits from previous actors to the new actors for a rescheduled fragment.
62    pub fn migrate_splits_for_backfill_actors(
63        &self,
64        fragment_id: FragmentId,
65        upstream_source_fragment_id: FragmentId,
66        curr_actor_ids: &[ActorId],
67        fragment_actor_splits: &HashMap<FragmentId, HashMap<ActorId, Vec<SplitImpl>>>,
68        no_shuffle_upstream_actor_map: &HashMap<ActorId, HashMap<FragmentId, ActorId>>,
69    ) -> MetaResult<HashMap<ActorId, Vec<SplitImpl>>> {
70        // align splits for backfill fragments with its upstream source fragment
71        let actors = no_shuffle_upstream_actor_map
72            .iter()
73            .filter(|(id, _)| curr_actor_ids.contains(id))
74            .map(|(id, upstream_fragment_actors)| {
75                (
76                    *id,
77                    *upstream_fragment_actors
78                        .get(&upstream_source_fragment_id)
79                        .unwrap(),
80                )
81            });
82        let upstream_assignment = fragment_actor_splits
83            .get(&upstream_source_fragment_id)
84            .unwrap();
85        tracing::info!(
86            fragment_id,
87            upstream_source_fragment_id,
88            ?upstream_assignment,
89            "migrate_splits_for_backfill_actors"
90        );
91        Ok(align_splits(
92            actors,
93            upstream_assignment,
94            fragment_id,
95            upstream_source_fragment_id,
96        )?)
97    }
98
99    /// Allocates splits to actors for a newly created source executor.
100    #[await_tree::instrument]
101    pub async fn allocate_splits(
102        &self,
103        table_fragments: &StreamJobFragments,
104    ) -> MetaResult<SplitAssignment> {
105        let core = self.core.lock().await;
106
107        // stream_job_id is used to fetch pre-calculated splits for CDC table.
108        let stream_job_id = table_fragments.stream_job_id;
109        let source_fragments = table_fragments.stream_source_fragments();
110
111        let mut assigned = HashMap::new();
112
113        'loop_source: for (source_id, fragments) in source_fragments {
114            let handle = core
115                .managed_sources
116                .get(&source_id)
117                .with_context(|| format!("could not find source {}", source_id))?;
118
119            if handle.splits.lock().await.splits.is_none() {
120                handle.force_tick().await?;
121            }
122
123            for fragment_id in fragments {
124                let empty_actor_splits: HashMap<u32, Vec<SplitImpl>> = table_fragments
125                    .fragments
126                    .get(&fragment_id)
127                    .unwrap()
128                    .actors
129                    .iter()
130                    .map(|actor| (actor.actor_id, vec![]))
131                    .collect();
132                let actor_hashset: HashSet<u32> = empty_actor_splits.keys().cloned().collect();
133                let splits = handle.discovered_splits(source_id, &actor_hashset).await?;
134                if splits.is_empty() {
135                    tracing::warn!(?stream_job_id, source_id, "no splits detected");
136                    continue 'loop_source;
137                }
138                if let Some(diff) = reassign_splits(
139                    fragment_id,
140                    empty_actor_splits,
141                    &splits,
142                    SplitDiffOptions::default(),
143                ) {
144                    assigned.insert(fragment_id, diff);
145                }
146            }
147        }
148
149        Ok(assigned)
150    }
151
152    /// Allocates splits to actors for replace source job.
153    pub async fn allocate_splits_for_replace_source(
154        &self,
155        table_fragments: &StreamJobFragments,
156        upstream_updates: &FragmentReplaceUpstream,
157        // new_no_shuffle:
158        //     upstream fragment_id ->
159        //     downstream fragment_id ->
160        //     upstream actor_id ->
161        //     downstream actor_id
162        new_no_shuffle: &FragmentNewNoShuffle,
163    ) -> MetaResult<SplitAssignment> {
164        tracing::debug!(?upstream_updates, "allocate_splits_for_replace_source");
165        if upstream_updates.is_empty() {
166            // no existing downstream. We can just re-allocate splits arbitrarily.
167            return self.allocate_splits(table_fragments).await;
168        }
169
170        let core = self.core.lock().await;
171
172        let source_fragments = table_fragments.stream_source_fragments();
173        assert_eq!(
174            source_fragments.len(),
175            1,
176            "replace source job should only have one source"
177        );
178        let (_source_id, fragments) = source_fragments.into_iter().next().unwrap();
179        assert_eq!(
180            fragments.len(),
181            1,
182            "replace source job should only have one fragment"
183        );
184        let fragment_id = fragments.into_iter().next().unwrap();
185
186        debug_assert!(
187            upstream_updates.values().flatten().next().is_some()
188                && upstream_updates
189                    .values()
190                    .flatten()
191                    .all(|(_, new_upstream_fragment_id)| {
192                        *new_upstream_fragment_id == fragment_id
193                    })
194                && upstream_updates
195                    .values()
196                    .flatten()
197                    .map(|(upstream_fragment_id, _)| upstream_fragment_id)
198                    .all_equal(),
199            "upstream update should only replace one fragment: {:?}",
200            upstream_updates
201        );
202        let prev_fragment_id = upstream_updates
203            .values()
204            .flatten()
205            .next()
206            .map(|(upstream_fragment_id, _)| *upstream_fragment_id)
207            .expect("non-empty");
208        // Here we align the new source executor to backfill executors
209        //
210        // old_source => new_source            backfill_1
211        // actor_x1   => actor_y1 -----┬------>actor_a1
212        // actor_x2   => actor_y2 -----┼-┬---->actor_a2
213        //                             │ │
214        //                             │ │     backfill_2
215        //                             └─┼---->actor_b1
216        //                               └---->actor_b2
217        //
218        // Note: we can choose any backfill actor to align here.
219        // We use `HashMap` to dedup.
220        let aligned_actors: HashMap<ActorId, ActorId> = new_no_shuffle
221            .get(&fragment_id)
222            .map(HashMap::values)
223            .into_iter()
224            .flatten()
225            .flatten()
226            .map(|(upstream_actor_id, actor_id)| (*upstream_actor_id, *actor_id))
227            .collect();
228        let assignment = align_splits(
229            aligned_actors.into_iter(),
230            &core.actor_splits,
231            fragment_id,
232            prev_fragment_id,
233        )?;
234        Ok(HashMap::from([(fragment_id, assignment)]))
235    }
236
237    /// Allocates splits to actors for a newly created `SourceBackfill` executor.
238    ///
239    /// Unlike [`Self::allocate_splits`], which creates a new assignment,
240    /// this method aligns the splits for backfill fragments with its upstream source fragment ([`align_splits`]).
241    #[await_tree::instrument]
242    pub async fn allocate_splits_for_backfill(
243        &self,
244        table_fragments: &StreamJobFragments,
245        upstream_new_no_shuffle: &FragmentNewNoShuffle,
246        upstream_actors: &HashMap<FragmentId, HashSet<ActorId>>,
247    ) -> MetaResult<SplitAssignment> {
248        let core = self.core.lock().await;
249
250        let source_backfill_fragments = table_fragments.source_backfill_fragments();
251
252        let mut assigned = HashMap::new();
253
254        for (_source_id, fragments) in source_backfill_fragments {
255            for (fragment_id, upstream_source_fragment_id) in fragments {
256                let upstream_actors = upstream_actors
257                    .get(&upstream_source_fragment_id)
258                    .ok_or_else(|| {
259                        anyhow!(
260                            "no upstream actors found from fragment {} to upstream source fragment {}",
261                            fragment_id,
262                            upstream_source_fragment_id
263                        )
264                    })?;
265                let mut backfill_actors = vec![];
266                let Some(source_new_no_shuffle) = upstream_new_no_shuffle
267                    .get(&upstream_source_fragment_id)
268                    .and_then(|source_upstream_new_no_shuffle| {
269                        source_upstream_new_no_shuffle.get(&fragment_id)
270                    })
271                else {
272                    return Err(anyhow::anyhow!(
273                            "source backfill fragment's upstream fragment should have one-on-one no_shuffle mapping, fragment_id: {fragment_id}, upstream_fragment_id: {upstream_source_fragment_id}, upstream_new_no_shuffle: {upstream_new_no_shuffle:?}",
274                            fragment_id = fragment_id,
275                            upstream_source_fragment_id = upstream_source_fragment_id,
276                            upstream_new_no_shuffle = upstream_new_no_shuffle,
277                        ).into());
278                };
279                for upstream_actor in upstream_actors {
280                    let Some(no_shuffle_backfill_actor) = source_new_no_shuffle.get(upstream_actor)
281                    else {
282                        return Err(anyhow::anyhow!(
283                            "source backfill fragment's upstream fragment should have one-on-one no_shuffle mapping, fragment_id: {fragment_id}, upstream_fragment_id: {upstream_source_fragment_id}, upstream_actor: {upstream_actor}, source_new_no_shuffle: {source_new_no_shuffle:?}",
284                            fragment_id = fragment_id,
285                            upstream_source_fragment_id = upstream_source_fragment_id,
286                            upstream_actor = upstream_actor,
287                            source_new_no_shuffle = source_new_no_shuffle
288                        ).into());
289                    };
290                    backfill_actors.push((*no_shuffle_backfill_actor, *upstream_actor));
291                }
292                assigned.insert(
293                    fragment_id,
294                    align_splits(
295                        backfill_actors,
296                        &core.actor_splits,
297                        fragment_id,
298                        upstream_source_fragment_id,
299                    )?,
300                );
301            }
302        }
303
304        Ok(assigned)
305    }
306}
307
308impl SourceManagerCore {
309    /// Checks whether the external source metadata has changed,
310    /// and re-assigns splits if there's a diff.
311    ///
312    /// `self.actor_splits` will not be updated. It will be updated by `Self::apply_source_change`,
313    /// after the mutation barrier has been collected.
314    pub async fn reassign_splits(&self) -> MetaResult<HashMap<DatabaseId, SplitAssignment>> {
315        let mut split_assignment: SplitAssignment = HashMap::new();
316
317        'loop_source: for (source_id, handle) in &self.managed_sources {
318            let source_fragment_ids = match self.source_fragments.get(source_id) {
319                Some(fragment_ids) if !fragment_ids.is_empty() => fragment_ids,
320                _ => {
321                    continue;
322                }
323            };
324            let backfill_fragment_ids = self.backfill_fragments.get(source_id);
325
326            'loop_fragment: for &fragment_id in source_fragment_ids {
327                let actors = match self
328                    .metadata_manager
329                    .get_running_actors_of_fragment(fragment_id)
330                    .await
331                {
332                    Ok(actors) => {
333                        if actors.is_empty() {
334                            tracing::warn!("No actors found for fragment {}", fragment_id);
335                            continue 'loop_fragment;
336                        }
337                        actors
338                    }
339                    Err(err) => {
340                        tracing::warn!(error = %err.as_report(), "Failed to get the actor of the fragment, maybe the fragment doesn't exist anymore");
341                        continue 'loop_fragment;
342                    }
343                };
344
345                let discovered_splits = handle.discovered_splits(*source_id, &actors).await?;
346                if discovered_splits.is_empty() {
347                    // The discover loop for this source is not ready yet; we'll wait for the next run
348                    continue 'loop_source;
349                }
350
351                let prev_actor_splits: HashMap<_, _> = actors
352                    .into_iter()
353                    .map(|actor_id| {
354                        (
355                            actor_id,
356                            self.actor_splits
357                                .get(&actor_id)
358                                .cloned()
359                                .unwrap_or_default(),
360                        )
361                    })
362                    .collect();
363
364                if let Some(new_assignment) = reassign_splits(
365                    fragment_id,
366                    prev_actor_splits,
367                    &discovered_splits,
368                    SplitDiffOptions {
369                        enable_scale_in: handle.enable_drop_split,
370                        enable_adaptive: handle.enable_adaptive_splits,
371                    },
372                ) {
373                    split_assignment.insert(fragment_id, new_assignment);
374                }
375            }
376
377            if let Some(backfill_fragment_ids) = backfill_fragment_ids {
378                // align splits for backfill fragments with its upstream source fragment
379                for (fragment_id, upstream_fragment_id) in backfill_fragment_ids {
380                    let Some(upstream_assignment) = split_assignment.get(upstream_fragment_id)
381                    else {
382                        // upstream fragment unchanged, do not update backfill fragment too
383                        continue;
384                    };
385                    let actors = match self
386                        .metadata_manager
387                        .get_running_actors_for_source_backfill(*fragment_id, *upstream_fragment_id)
388                        .await
389                    {
390                        Ok(actors) => {
391                            if actors.is_empty() {
392                                tracing::warn!("No actors found for fragment {}", fragment_id);
393                                continue;
394                            }
395                            actors
396                        }
397                        Err(err) => {
398                            tracing::warn!(error = %err.as_report(),"Failed to get the actor of the fragment, maybe the fragment doesn't exist anymore");
399                            continue;
400                        }
401                    };
402                    split_assignment.insert(
403                        *fragment_id,
404                        align_splits(
405                            actors,
406                            upstream_assignment,
407                            *fragment_id,
408                            *upstream_fragment_id,
409                        )?,
410                    );
411                }
412            }
413        }
414
415        self.metadata_manager
416            .split_fragment_map_by_database(split_assignment)
417            .await
418    }
419}
420
421/// Reassigns splits if there are new splits or dropped splits,
422/// i.e., `actor_splits` and `discovered_splits` differ, or actors are rescheduled.
423///
424/// The existing splits will remain unmoved in their currently assigned actor.
425///
426/// If an actor has an upstream actor, it should be a backfill executor,
427/// and its splits should be aligned with the upstream actor. **`reassign_splits` should not be used in this case.
428/// Use [`align_splits`] instead.**
429///
430/// - `fragment_id`: just for logging
431///
432/// ## Different connectors' behavior of split change
433///
434/// ### Kafka and Pulsar
435/// They only support increasing the number of splits via adding new empty splits.
436/// Old data is not moved.
437///
438/// ### Kinesis
439/// It supports *pairwise* shard split and merge.
440///
441/// In both cases, old data remain in the old shard(s) and the old shard is still available.
442/// New data are routed to the new shard(s).
443/// After the retention period has expired, the old shard will become `EXPIRED` and isn't
444/// listed any more. In other words, the total number of shards will first increase and then decrease.
445///
446/// See also:
447/// - [Kinesis resharding doc](https://docs.aws.amazon.com/streams/latest/dev/kinesis-using-sdk-java-after-resharding.html#kinesis-using-sdk-java-resharding-data-routing)
448/// - An example of how the shards can be like: <https://stackoverflow.com/questions/72272034/list-shard-show-more-shards-than-provisioned>
449fn reassign_splits<T>(
450    fragment_id: FragmentId,
451    actor_splits: HashMap<ActorId, Vec<T>>,
452    discovered_splits: &BTreeMap<SplitId, T>,
453    opts: SplitDiffOptions,
454) -> Option<HashMap<ActorId, Vec<T>>>
455where
456    T: SplitMetaData + Clone,
457{
458    // if no actors, return
459    if actor_splits.is_empty() {
460        return None;
461    }
462
463    let prev_split_ids: HashSet<_> = actor_splits
464        .values()
465        .flat_map(|splits| splits.iter().map(SplitMetaData::id))
466        .collect();
467
468    tracing::trace!(fragment_id, prev_split_ids = ?prev_split_ids, "previous splits");
469    tracing::trace!(fragment_id, prev_split_ids = ?discovered_splits.keys(), "discovered splits");
470
471    let discovered_split_ids: HashSet<_> = discovered_splits.keys().cloned().collect();
472
473    let dropped_splits: HashSet<_> = prev_split_ids
474        .difference(&discovered_split_ids)
475        .cloned()
476        .collect();
477
478    if !dropped_splits.is_empty() {
479        if opts.enable_scale_in {
480            tracing::info!(fragment_id, dropped_spltis = ?dropped_splits, "new dropped splits");
481        } else {
482            tracing::warn!(fragment_id, dropped_spltis = ?dropped_splits, "split dropping happened, but it is not allowed");
483        }
484    }
485
486    let new_discovered_splits: BTreeSet<_> = discovered_split_ids
487        .into_iter()
488        .filter(|split_id| !prev_split_ids.contains(split_id))
489        .collect();
490
491    if opts.enable_scale_in || opts.enable_adaptive {
492        // if we support scale in, no more splits are discovered, and no splits are dropped, return
493        // we need to check if discovered_split_ids is empty, because if it is empty, we need to
494        // handle the case of scale in to zero (like deleting all objects from s3)
495        if dropped_splits.is_empty()
496            && new_discovered_splits.is_empty()
497            && !discovered_splits.is_empty()
498        {
499            return None;
500        }
501    } else {
502        // if we do not support scale in, and no more splits are discovered, return
503        if new_discovered_splits.is_empty() && !discovered_splits.is_empty() {
504            return None;
505        }
506    }
507
508    tracing::info!(fragment_id, new_discovered_splits = ?new_discovered_splits, "new discovered splits");
509
510    let mut heap = BinaryHeap::with_capacity(actor_splits.len());
511
512    for (actor_id, mut splits) in actor_splits {
513        if opts.enable_scale_in || opts.enable_adaptive {
514            splits.retain(|split| !dropped_splits.contains(&split.id()));
515        }
516
517        heap.push(ActorSplitsAssignment { actor_id, splits })
518    }
519
520    for split_id in new_discovered_splits {
521        // ActorSplitsAssignment's Ord is reversed, so this is min heap, i.e.,
522        // we get the assignment with the least splits here.
523
524        // Note: If multiple actors have the same number of splits, it will be randomly picked.
525        // When the number of source actors is larger than the number of splits,
526        // It's possible that the assignment is uneven.
527        // e.g., https://github.com/risingwavelabs/risingwave/issues/14324#issuecomment-1875033158
528        // TODO: We should make the assignment rack-aware to make sure it's even.
529        let mut peek_ref = heap.peek_mut().unwrap();
530        peek_ref
531            .splits
532            .push(discovered_splits.get(&split_id).cloned().unwrap());
533    }
534
535    Some(
536        heap.into_iter()
537            .map(|ActorSplitsAssignment { actor_id, splits }| (actor_id, splits))
538            .collect(),
539    )
540}
541
542/// Assign splits to a new set of actors, according to existing assignment.
543///
544/// illustration:
545/// ```text
546/// upstream                               new
547/// actor x1 [split 1, split2]      ->     actor y1 [split 1, split2]
548/// actor x2 [split 3]              ->     actor y2 [split 3]
549/// ...
550/// ```
551fn align_splits(
552    // (actor_id, upstream_actor_id)
553    aligned_actors: impl IntoIterator<Item = (ActorId, ActorId)>,
554    existing_assignment: &HashMap<ActorId, Vec<SplitImpl>>,
555    fragment_id: FragmentId,
556    upstream_source_fragment_id: FragmentId,
557) -> anyhow::Result<HashMap<ActorId, Vec<SplitImpl>>> {
558    aligned_actors
559        .into_iter()
560        .map(|(actor_id, upstream_actor_id)| {
561            let Some(splits) = existing_assignment.get(&upstream_actor_id) else {
562                return Err(anyhow::anyhow!("upstream assignment not found, fragment_id: {fragment_id}, upstream_fragment_id: {upstream_source_fragment_id}, actor_id: {actor_id}, upstream_assignment: {existing_assignment:?}, upstream_actor_id: {upstream_actor_id:?}"));
563            };
564            Ok((
565                actor_id,
566                splits.clone(),
567            ))
568        })
569        .collect()
570}
571
572/// Note: the `PartialEq` and `Ord` impl just compares the number of splits.
573#[derive(Debug)]
574struct ActorSplitsAssignment<T: SplitMetaData> {
575    actor_id: ActorId,
576    splits: Vec<T>,
577}
578
579impl<T: SplitMetaData + Clone> Eq for ActorSplitsAssignment<T> {}
580
581impl<T: SplitMetaData + Clone> PartialEq<Self> for ActorSplitsAssignment<T> {
582    fn eq(&self, other: &Self) -> bool {
583        self.splits.len() == other.splits.len()
584    }
585}
586
587impl<T: SplitMetaData + Clone> PartialOrd<Self> for ActorSplitsAssignment<T> {
588    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
589        Some(self.cmp(other))
590    }
591}
592
593impl<T: SplitMetaData + Clone> Ord for ActorSplitsAssignment<T> {
594    fn cmp(&self, other: &Self) -> Ordering {
595        // Note: this is reversed order, to make BinaryHeap a min heap.
596        other.splits.len().cmp(&self.splits.len())
597    }
598}
599
/// Options controlling how `reassign_splits` reacts to split changes.
///
/// Both flags default to `false`.
#[derive(Debug, Default)]
pub struct SplitDiffOptions {
    /// When set, splits that are no longer discovered are removed from their actors.
    pub enable_scale_in: bool,

    /// For most connectors, this should be false. When enabled, RisingWave will not track any progress.
    pub enable_adaptive: bool,
}
617
618#[cfg(test)]
619mod tests {
620    use std::collections::{BTreeMap, HashMap, HashSet};
621
622    use risingwave_common::types::JsonbVal;
623    use risingwave_connector::error::ConnectorResult;
624    use risingwave_connector::source::{SplitId, SplitMetaData};
625    use serde::{Deserialize, Serialize};
626
627    use super::*;
628    use crate::model::{ActorId, FragmentId};
629
630    #[derive(Debug, Copy, Clone, Serialize, Deserialize)]
631    struct TestSplit {
632        id: u32,
633    }
634
635    impl SplitMetaData for TestSplit {
636        fn id(&self) -> SplitId {
637            format!("{}", self.id).into()
638        }
639
640        fn encode_to_json(&self) -> JsonbVal {
641            serde_json::to_value(*self).unwrap().into()
642        }
643
644        fn restore_from_json(value: JsonbVal) -> ConnectorResult<Self> {
645            serde_json::from_value(value.take()).map_err(Into::into)
646        }
647
648        fn update_offset(&mut self, _last_read_offset: String) -> ConnectorResult<()> {
649            Ok(())
650        }
651    }
652
653    fn check_all_splits(
654        discovered_splits: &BTreeMap<SplitId, TestSplit>,
655        diff: &HashMap<ActorId, Vec<TestSplit>>,
656    ) {
657        let mut split_ids: HashSet<_> = discovered_splits.keys().cloned().collect();
658
659        for splits in diff.values() {
660            for split in splits {
661                assert!(split_ids.remove(&split.id()))
662            }
663        }
664
665        assert!(split_ids.is_empty());
666    }
667
668    #[test]
669    fn test_drop_splits() {
670        let mut actor_splits: HashMap<ActorId, _> = HashMap::new();
671        actor_splits.insert(0, vec![TestSplit { id: 0 }, TestSplit { id: 1 }]);
672        actor_splits.insert(1, vec![TestSplit { id: 2 }, TestSplit { id: 3 }]);
673        actor_splits.insert(2, vec![TestSplit { id: 4 }, TestSplit { id: 5 }]);
674
675        let mut prev_split_to_actor = HashMap::new();
676        for (actor_id, splits) in &actor_splits {
677            for split in splits {
678                prev_split_to_actor.insert(split.id(), *actor_id);
679            }
680        }
681
682        let discovered_splits: BTreeMap<SplitId, TestSplit> = (1..5)
683            .map(|i| {
684                let split = TestSplit { id: i };
685                (split.id(), split)
686            })
687            .collect();
688
689        let opts = SplitDiffOptions {
690            enable_scale_in: true,
691            enable_adaptive: false,
692        };
693
694        let prev_split_ids: HashSet<_> = actor_splits
695            .values()
696            .flat_map(|splits| splits.iter().map(|split| split.id()))
697            .collect();
698
699        let diff = reassign_splits(
700            FragmentId::default(),
701            actor_splits,
702            &discovered_splits,
703            opts,
704        )
705        .unwrap();
706        check_all_splits(&discovered_splits, &diff);
707
708        let mut after_split_to_actor = HashMap::new();
709        for (actor_id, splits) in &diff {
710            for split in splits {
711                after_split_to_actor.insert(split.id(), *actor_id);
712            }
713        }
714
715        let discovered_split_ids: HashSet<_> = discovered_splits.keys().cloned().collect();
716
717        let retained_split_ids: HashSet<_> =
718            prev_split_ids.intersection(&discovered_split_ids).collect();
719
720        for retained_split_id in retained_split_ids {
721            assert_eq!(
722                prev_split_to_actor.get(retained_split_id),
723                after_split_to_actor.get(retained_split_id)
724            )
725        }
726    }
727
728    #[test]
729    fn test_drop_splits_to_empty() {
730        let mut actor_splits: HashMap<ActorId, _> = HashMap::new();
731        actor_splits.insert(0, vec![TestSplit { id: 0 }]);
732
733        let discovered_splits: BTreeMap<SplitId, TestSplit> = BTreeMap::new();
734
735        let opts = SplitDiffOptions {
736            enable_scale_in: true,
737            enable_adaptive: false,
738        };
739
740        let diff = reassign_splits(
741            FragmentId::default(),
742            actor_splits,
743            &discovered_splits,
744            opts,
745        )
746        .unwrap();
747
748        assert!(!diff.is_empty())
749    }
750
751    #[test]
752    fn test_reassign_splits() {
753        let actor_splits = HashMap::new();
754        let discovered_splits: BTreeMap<SplitId, TestSplit> = BTreeMap::new();
755        assert!(
756            reassign_splits(
757                FragmentId::default(),
758                actor_splits,
759                &discovered_splits,
760                Default::default()
761            )
762            .is_none()
763        );
764
765        let actor_splits = (0..3).map(|i| (i, vec![])).collect();
766        let discovered_splits: BTreeMap<SplitId, TestSplit> = BTreeMap::new();
767        let diff = reassign_splits(
768            FragmentId::default(),
769            actor_splits,
770            &discovered_splits,
771            Default::default(),
772        )
773        .unwrap();
774        assert_eq!(diff.len(), 3);
775        for splits in diff.values() {
776            assert!(splits.is_empty())
777        }
778
779        let actor_splits = (0..3).map(|i| (i, vec![])).collect();
780        let discovered_splits: BTreeMap<SplitId, TestSplit> = (0..3)
781            .map(|i| {
782                let split = TestSplit { id: i };
783                (split.id(), split)
784            })
785            .collect();
786
787        let diff = reassign_splits(
788            FragmentId::default(),
789            actor_splits,
790            &discovered_splits,
791            Default::default(),
792        )
793        .unwrap();
794        assert_eq!(diff.len(), 3);
795        for splits in diff.values() {
796            assert_eq!(splits.len(), 1);
797        }
798
799        check_all_splits(&discovered_splits, &diff);
800
801        let actor_splits = (0..3).map(|i| (i, vec![TestSplit { id: i }])).collect();
802        let discovered_splits: BTreeMap<SplitId, TestSplit> = (0..5)
803            .map(|i| {
804                let split = TestSplit { id: i };
805                (split.id(), split)
806            })
807            .collect();
808
809        let diff = reassign_splits(
810            FragmentId::default(),
811            actor_splits,
812            &discovered_splits,
813            Default::default(),
814        )
815        .unwrap();
816        assert_eq!(diff.len(), 3);
817        for splits in diff.values() {
818            let len = splits.len();
819            assert!(len == 1 || len == 2);
820        }
821
822        check_all_splits(&discovered_splits, &diff);
823
824        let mut actor_splits: HashMap<ActorId, Vec<TestSplit>> =
825            (0..3).map(|i| (i, vec![TestSplit { id: i }])).collect();
826        actor_splits.insert(3, vec![]);
827        actor_splits.insert(4, vec![]);
828
829        let discovered_splits: BTreeMap<SplitId, TestSplit> = (0..5)
830            .map(|i| {
831                let split = TestSplit { id: i };
832                (split.id(), split)
833            })
834            .collect();
835
836        let diff = reassign_splits(
837            FragmentId::default(),
838            actor_splits,
839            &discovered_splits,
840            Default::default(),
841        )
842        .unwrap();
843        assert_eq!(diff.len(), 5);
844        for splits in diff.values() {
845            assert_eq!(splits.len(), 1);
846        }
847
848        check_all_splits(&discovered_splits, &diff);
849    }
850}