risingwave_meta/stream/source_manager/split_assignment.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use anyhow::anyhow;
use itertools::Itertools;

use super::*;
use crate::model::{FragmentNewNoShuffle, FragmentReplaceUpstream, StreamJobFragments};

#[derive(Debug, Clone)]
pub struct SplitState {
    pub split_assignment: SplitAssignment,
}

impl SourceManager {
    /// Migrates splits from previous actors to the new actors for a rescheduled fragment.
    pub fn migrate_splits_for_backfill_actors(
        &self,
        fragment_id: FragmentId,
        upstream_source_fragment_id: FragmentId,
        curr_actor_ids: &[ActorId],
        fragment_actor_splits: &HashMap<FragmentId, HashMap<ActorId, Vec<SplitImpl>>>,
        no_shuffle_upstream_actor_map: &HashMap<ActorId, HashMap<FragmentId, ActorId>>,
    ) -> MetaResult<HashMap<ActorId, Vec<SplitImpl>>> {
        // Align the splits of the backfill fragment with its upstream source fragment.
        let actors = no_shuffle_upstream_actor_map
            .iter()
            .filter(|(id, _)| curr_actor_ids.contains(id))
            .map(|(id, upstream_fragment_actors)| {
                (
                    *id,
                    *upstream_fragment_actors
                        .get(&upstream_source_fragment_id)
                        .unwrap(),
                )
            });
        let upstream_assignment = fragment_actor_splits
            .get(&upstream_source_fragment_id)
            .unwrap();
        tracing::info!(
            fragment_id,
            upstream_source_fragment_id,
            ?upstream_assignment,
            "migrate_splits_for_backfill_actors"
        );
        Ok(align_splits(
            actors,
            SplitAlignmentContext::Pending(upstream_assignment),
            fragment_id,
            upstream_source_fragment_id,
        )?)
    }

    /// Allocates splits to actors for a newly created source executor.
    #[await_tree::instrument]
    pub async fn allocate_splits(
        &self,
        table_fragments: &StreamJobFragments,
    ) -> MetaResult<SplitAssignment> {
        let core = self.core.lock().await;

        // `stream_job_id` is used to fetch pre-calculated splits for CDC tables.
        let stream_job_id = table_fragments.stream_job_id;
        let source_fragments = table_fragments.stream_source_fragments();

        let mut assigned = HashMap::new();

        'loop_source: for (source_id, fragments) in source_fragments {
            let handle = core
                .managed_sources
                .get(&source_id)
                .with_context(|| format!("could not find source {}", source_id))?;

            if handle.splits.lock().await.splits.is_none() {
                handle.force_tick().await?;
            }

            for fragment_id in fragments {
                let empty_actor_splits: HashMap<u32, Vec<SplitImpl>> = table_fragments
                    .fragments
                    .get(&fragment_id)
                    .unwrap()
                    .actors
                    .iter()
                    .map(|actor| (actor.actor_id, vec![]))
                    .collect();
                let actor_count = empty_actor_splits.keys().len();
                let splits = handle.discovered_splits(source_id, actor_count).await?;
                if splits.is_empty() {
                    tracing::warn!(?stream_job_id, source_id, "no splits detected");
                    continue 'loop_source;
                }
                if let Some(diff) = reassign_splits(
                    fragment_id,
                    empty_actor_splits,
                    &splits,
                    SplitDiffOptions::default(),
                ) {
                    assigned.insert(fragment_id, diff);
                }
            }
        }

        Ok(assigned)
    }

    /// Allocates splits to actors for a replace-source job.
    pub async fn allocate_splits_for_replace_source(
        &self,
        table_fragments: &StreamJobFragments,
        upstream_updates: &FragmentReplaceUpstream,
        // new_no_shuffle:
        //     upstream fragment_id ->
        //     downstream fragment_id ->
        //     upstream actor_id ->
        //     downstream actor_id
        new_no_shuffle: &FragmentNewNoShuffle,
    ) -> MetaResult<SplitAssignment> {
        tracing::debug!(?upstream_updates, "allocate_splits_for_replace_source");
        if upstream_updates.is_empty() {
            // No existing downstream; we can just re-allocate splits arbitrarily.
            return self.allocate_splits(table_fragments).await;
        }

        let core = self.core.lock().await;

        let source_fragments = table_fragments.stream_source_fragments();
        assert_eq!(
            source_fragments.len(),
            1,
            "replace source job should only have one source"
        );
        let (_source_id, fragments) = source_fragments.into_iter().next().unwrap();
        assert_eq!(
            fragments.len(),
            1,
            "replace source job should only have one fragment"
        );
        let fragment_id = fragments.into_iter().next().unwrap();

        debug_assert!(
            upstream_updates.values().flatten().next().is_some()
                && upstream_updates
                    .values()
                    .flatten()
                    .all(|(_, new_upstream_fragment_id)| {
                        *new_upstream_fragment_id == fragment_id
                    })
                && upstream_updates
                    .values()
                    .flatten()
                    .map(|(upstream_fragment_id, _)| upstream_fragment_id)
                    .all_equal(),
            "upstream update should only replace one fragment: {:?}",
            upstream_updates
        );
        let prev_fragment_id = upstream_updates
            .values()
            .flatten()
            .next()
            .map(|(upstream_fragment_id, _)| *upstream_fragment_id)
            .expect("non-empty");
        // Here we align the new source executor to the existing backfill executors.
        //
        // old_source => new_source            backfill_1
        // actor_x1   => actor_y1 -----┬------>actor_a1
        // actor_x2   => actor_y2 -----┼-┬---->actor_a2
        //                             │ │
        //                             │ │     backfill_2
        //                             └─┼---->actor_b1
        //                               └---->actor_b2
        //
        // Note: we can choose any backfill actor to align here.
        // We use `HashMap` to dedup.
        let aligned_actors: HashMap<ActorId, ActorId> = new_no_shuffle
            .get(&fragment_id)
            .map(HashMap::values)
            .into_iter()
            .flatten()
            .flatten()
            .map(|(upstream_actor_id, actor_id)| (*upstream_actor_id, *actor_id))
            .collect();
        let assignment = align_splits(
            aligned_actors.into_iter(),
            SplitAlignmentContext::Stored(core.env.shared_actor_infos()),
            fragment_id,
            prev_fragment_id,
        )?;
        Ok(HashMap::from([(fragment_id, assignment)]))
    }

    /// Allocates splits to actors for a newly created `SourceBackfill` executor.
    ///
    /// Unlike [`Self::allocate_splits`], which creates a new assignment,
    /// this method aligns the splits of each backfill fragment with its upstream source fragment ([`align_splits`]).
    #[await_tree::instrument]
    pub async fn allocate_splits_for_backfill(
        &self,
        table_fragments: &StreamJobFragments,
        upstream_new_no_shuffle: &FragmentNewNoShuffle,
        upstream_actors: &HashMap<FragmentId, HashSet<ActorId>>,
    ) -> MetaResult<SplitAssignment> {
        let core = self.core.lock().await;

        let source_backfill_fragments = table_fragments.source_backfill_fragments();

        let mut assigned = HashMap::new();

        for (_source_id, fragments) in source_backfill_fragments {
            for (fragment_id, upstream_source_fragment_id) in fragments {
                let upstream_actors = upstream_actors
                    .get(&upstream_source_fragment_id)
                    .ok_or_else(|| {
                        anyhow!(
                            "no upstream actors found from fragment {} to upstream source fragment {}",
                            fragment_id,
                            upstream_source_fragment_id
                        )
                    })?;
                let mut backfill_actors = vec![];
                let Some(source_new_no_shuffle) = upstream_new_no_shuffle
                    .get(&upstream_source_fragment_id)
                    .and_then(|source_upstream_new_no_shuffle| {
                        source_upstream_new_no_shuffle.get(&fragment_id)
                    })
                else {
                    return Err(anyhow::anyhow!(
                            "source backfill fragment's upstream fragment should have one-to-one no_shuffle mapping, fragment_id: {fragment_id}, upstream_fragment_id: {upstream_source_fragment_id}, upstream_new_no_shuffle: {upstream_new_no_shuffle:?}",
                            fragment_id = fragment_id,
                            upstream_source_fragment_id = upstream_source_fragment_id,
                            upstream_new_no_shuffle = upstream_new_no_shuffle,
                        ).into());
                };
                for upstream_actor in upstream_actors {
                    let Some(no_shuffle_backfill_actor) = source_new_no_shuffle.get(upstream_actor)
                    else {
                        return Err(anyhow::anyhow!(
                            "source backfill fragment's upstream fragment should have one-to-one no_shuffle mapping, fragment_id: {fragment_id}, upstream_fragment_id: {upstream_source_fragment_id}, upstream_actor: {upstream_actor}, source_new_no_shuffle: {source_new_no_shuffle:?}",
                            fragment_id = fragment_id,
                            upstream_source_fragment_id = upstream_source_fragment_id,
                            upstream_actor = upstream_actor,
                            source_new_no_shuffle = source_new_no_shuffle
                        ).into());
                    };
                    backfill_actors.push((*no_shuffle_backfill_actor, *upstream_actor));
                }
                assigned.insert(
                    fragment_id,
                    align_splits(
                        backfill_actors,
                        SplitAlignmentContext::Stored(core.env.shared_actor_infos()),
                        fragment_id,
                        upstream_source_fragment_id,
                    )?,
                );
            }
        }

        Ok(assigned)
    }
}

impl SourceManagerCore {
    /// Checks whether the external source metadata has changed,
    /// and re-assigns splits if there's a diff.
    ///
    /// `self.actor_splits` will not be updated. It will be updated by `Self::apply_source_change`,
    /// after the mutation barrier has been collected.
    pub async fn reassign_splits(&self) -> MetaResult<HashMap<DatabaseId, SplitState>> {
        let mut split_assignment: SplitAssignment = HashMap::new();

        'loop_source: for (source_id, handle) in &self.managed_sources {
            let source_fragment_ids = match self.source_fragments.get(source_id) {
                Some(fragment_ids) if !fragment_ids.is_empty() => fragment_ids,
                _ => {
                    continue;
                }
            };
            let backfill_fragment_ids = self.backfill_fragments.get(source_id);

            'loop_fragment: for &fragment_id in source_fragment_ids {
                let actors = match self
                    .metadata_manager
                    .get_running_actors_of_fragment(fragment_id)
                {
                    Ok(actors) => {
                        if actors.is_empty() {
                            tracing::warn!("No actors found for fragment {}", fragment_id);
                            continue 'loop_fragment;
                        }
                        actors
                    }
                    Err(err) => {
                        tracing::warn!(error = %err.as_report(), "Failed to get the actors of the fragment; maybe the fragment doesn't exist anymore");
                        continue 'loop_fragment;
                    }
                };

                let discovered_splits = handle.discovered_splits(*source_id, actors.len()).await?;
                if discovered_splits.is_empty() {
                    // The discovery loop for this source is not ready yet; we'll wait for the next run.
                    continue 'loop_source;
                }

                let prev_actor_splits = {
                    let guard = self.env.shared_actor_infos().read_guard();

                    guard
                        .get_fragment(fragment_id)
                        .and_then(|info| {
                            info.actors
                                .iter()
                                .map(|(actor_id, actor_info)| {
                                    (*actor_id, actor_info.splits.clone())
                                })
                                .collect::<HashMap<_, _>>()
                                .into()
                        })
                        .unwrap_or_default()
                };

                if let Some(new_assignment) = reassign_splits(
                    fragment_id,
                    prev_actor_splits,
                    &discovered_splits,
                    SplitDiffOptions {
                        enable_scale_in: handle.enable_drop_split,
                        enable_adaptive: handle.enable_adaptive_splits,
                    },
                ) {
                    split_assignment.insert(fragment_id, new_assignment);
                }
            }

            if let Some(backfill_fragment_ids) = backfill_fragment_ids {
                // Align the splits of each backfill fragment with its upstream source fragment.
                for (fragment_id, upstream_fragment_id) in backfill_fragment_ids {
                    let Some(upstream_assignment): Option<&HashMap<ActorId, Vec<SplitImpl>>> =
                        split_assignment.get(upstream_fragment_id)
                    else {
                        // The upstream fragment is unchanged, so do not update the backfill fragment either.
                        continue;
                    };
                    let actors = match self
                        .metadata_manager
                        .get_running_actors_for_source_backfill(*fragment_id, *upstream_fragment_id)
                        .await
                    {
                        Ok(actors) => {
                            if actors.is_empty() {
                                tracing::warn!("No actors found for fragment {}", fragment_id);
                                continue;
                            }
                            actors
                        }
                        Err(err) => {
                            tracing::warn!(error = %err.as_report(), "Failed to get the actors of the fragment; maybe the fragment doesn't exist anymore");
                            continue;
                        }
                    };
                    split_assignment.insert(
                        *fragment_id,
                        align_splits(
                            actors,
                            SplitAlignmentContext::Pending(upstream_assignment),
                            *fragment_id,
                            *upstream_fragment_id,
                        )?,
                    );
                }
            }
        }

        let assignments = self
            .metadata_manager
            .split_fragment_map_by_database(split_assignment)
            .await?;

        let mut result = HashMap::new();
        for (database_id, assignment) in assignments {
            result.insert(
                database_id,
                SplitState {
                    split_assignment: assignment,
                },
            );
        }

        Ok(result)
    }
}

/// Reassigns splits if there are new splits or dropped splits,
/// i.e., `actor_splits` and `discovered_splits` differ, or actors are rescheduled.
///
/// The existing splits remain with their currently assigned actors.
///
/// If an actor has an upstream actor, it should be a backfill executor,
/// and its splits should be aligned with the upstream actor. **`reassign_splits` should not be used in this case.
/// Use [`align_splits`] instead.**
///
/// - `fragment_id`: just for logging
///
/// ## Different connectors' behavior on split change
///
/// ### Kafka and Pulsar
/// They only support increasing the number of splits by adding new, empty splits.
/// Old data is not moved.
///
/// ### Kinesis
/// It supports *pairwise* shard split and merge.
///
/// In both cases, old data remains in the old shard(s), and the old shards are still available.
/// New data is routed to the new shard(s).
/// After the retention period has expired, the old shards become `EXPIRED` and are no longer
/// listed. In other words, the total number of shards first increases and then decreases.
///
/// See also:
/// - [Kinesis resharding doc](https://docs.aws.amazon.com/streams/latest/dev/kinesis-using-sdk-java-after-resharding.html#kinesis-using-sdk-java-resharding-data-routing)
/// - An example of what the shards can look like: <https://stackoverflow.com/questions/72272034/list-shard-show-more-shards-than-provisioned>
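///
/// ## Assignment sketch
///
/// A rough illustration of the strategy (derived from the implementation below, not an
/// exhaustive spec): existing splits stay with their current actors, and each newly
/// discovered split is handed to whichever actor currently holds the fewest splits.
///
/// ```text
/// before:      actor 1 [a]          actor 2 [b, c]
/// discovered:  {a, b, c, d}
/// after:       actor 1 [a, d]       actor 2 [b, c]
/// ```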
pub fn reassign_splits<T>(
    fragment_id: FragmentId,
    actor_splits: HashMap<ActorId, Vec<T>>,
    discovered_splits: &BTreeMap<SplitId, T>,
    opts: SplitDiffOptions,
) -> Option<HashMap<ActorId, Vec<T>>>
where
    T: SplitMetaData + Clone,
{
    // If there are no actors, there is nothing to assign.
    if actor_splits.is_empty() {
        return None;
    }

    let prev_split_ids: HashSet<_> = actor_splits
        .values()
        .flat_map(|splits| splits.iter().map(SplitMetaData::id))
        .collect();

    tracing::trace!(fragment_id, prev_split_ids = ?prev_split_ids, "previous splits");
    tracing::trace!(fragment_id, discovered_split_ids = ?discovered_splits.keys(), "discovered splits");

    let discovered_split_ids: HashSet<_> = discovered_splits.keys().cloned().collect();

    let dropped_splits: HashSet<_> = prev_split_ids
        .difference(&discovered_split_ids)
        .cloned()
        .collect();

    if !dropped_splits.is_empty() {
        if opts.enable_scale_in {
            tracing::info!(fragment_id, ?dropped_splits, "new dropped splits");
        } else {
            tracing::warn!(fragment_id, ?dropped_splits, "splits were dropped, but dropping splits is not enabled");
        }
    }

    let new_discovered_splits: BTreeSet<_> = discovered_split_ids
        .into_iter()
        .filter(|split_id| !prev_split_ids.contains(split_id))
        .collect();

    if opts.enable_scale_in || opts.enable_adaptive {
        // If scale-in is supported and nothing changed (no new splits and no dropped splits), return.
        // We also require `discovered_splits` to be non-empty: if discovery returned nothing, we still
        // need to emit an assignment to handle scaling in to zero (e.g., all objects deleted from S3).
        if dropped_splits.is_empty()
            && new_discovered_splits.is_empty()
            && !discovered_splits.is_empty()
        {
            return None;
        }
    } else {
        // If scale-in is not supported and no new splits were discovered, return.
        if new_discovered_splits.is_empty() && !discovered_splits.is_empty() {
            return None;
        }
    }

    tracing::info!(fragment_id, new_discovered_splits = ?new_discovered_splits, "new discovered splits");

    let mut heap = BinaryHeap::with_capacity(actor_splits.len());

    for (actor_id, mut splits) in actor_splits {
        if opts.enable_scale_in || opts.enable_adaptive {
            splits.retain(|split| !dropped_splits.contains(&split.id()));
        }

        heap.push(ActorSplitsAssignment { actor_id, splits })
    }

    for split_id in new_discovered_splits {
        // `ActorSplitsAssignment`'s `Ord` is reversed, so this is a min-heap, i.e.,
        // we peek the assignment with the fewest splits here.

        // Note: if multiple actors have the same number of splits, one is picked arbitrarily.
        // When the number of source actors is larger than the number of splits,
        // the resulting assignment may be uneven.
        // e.g., https://github.com/risingwavelabs/risingwave/issues/14324#issuecomment-1875033158
        // TODO: We should make the assignment rack-aware to make sure it's even.
        let mut peek_ref = heap.peek_mut().unwrap();
        peek_ref
            .splits
            .push(discovered_splits.get(&split_id).cloned().unwrap());
    }

    Some(
        heap.into_iter()
            .map(|ActorSplitsAssignment { actor_id, splits }| (actor_id, splits))
            .collect(),
    )
}

/// Where [`align_splits`] should read the upstream fragment's split assignment from.
pub enum SplitAlignmentContext<'a> {
    /// Read the splits currently recorded in the shared actor info.
    Stored(&'a SharedActorInfos),
    /// Use an explicitly provided assignment, e.g. one that has been computed but not yet applied.
    Pending(&'a HashMap<ActorId, Vec<SplitImpl>>),
}

/// Assigns splits to a new set of actors, according to an existing assignment.
///
/// Illustration:
/// ```text
/// upstream                               new
/// actor x1 [split 1, split 2]     ->     actor y1 [split 1, split 2]
/// actor x2 [split 3]              ->     actor y2 [split 3]
/// ...
/// ```
fn align_splits(
    // (actor_id, upstream_actor_id)
    aligned_actors: impl IntoIterator<Item = (ActorId, ActorId)>,
    alignment_context: SplitAlignmentContext<'_>,
    fragment_id: FragmentId,
    upstream_source_fragment_id: FragmentId,
) -> anyhow::Result<HashMap<ActorId, Vec<SplitImpl>>> {
    aligned_actors
        .into_iter()
        .map(|(actor_id, upstream_actor_id)| {
            let Some(splits) = (match alignment_context {
                SplitAlignmentContext::Stored(info) => {
                    let guard = info.read_guard();
                    let actor_info = guard.iter_over_fragments().find_map(|(_, fragment_info)| fragment_info.actors.get(&upstream_actor_id));
                    actor_info.map(|info| info.splits.clone())
                }
                SplitAlignmentContext::Pending(existing_assignment) => {
                    existing_assignment.get(&upstream_actor_id).cloned()
                }
            }) else {
                return Err(anyhow::anyhow!("upstream assignment not found, fragment_id: {fragment_id}, upstream_fragment_id: {upstream_source_fragment_id}, actor_id: {actor_id}, upstream_actor_id: {upstream_actor_id:?}"));
            };

            Ok((actor_id, splits))
        })
        .collect()
}

/// Note: the `PartialEq` and `Ord` impls just compare the number of splits.
#[derive(Debug)]
struct ActorSplitsAssignment<T: SplitMetaData> {
    actor_id: ActorId,
    splits: Vec<T>,
}

impl<T: SplitMetaData + Clone> Eq for ActorSplitsAssignment<T> {}

impl<T: SplitMetaData + Clone> PartialEq<Self> for ActorSplitsAssignment<T> {
    fn eq(&self, other: &Self) -> bool {
        self.splits.len() == other.splits.len()
    }
}

impl<T: SplitMetaData + Clone> PartialOrd<Self> for ActorSplitsAssignment<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<T: SplitMetaData + Clone> Ord for ActorSplitsAssignment<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Note: the order is reversed so that `BinaryHeap` behaves as a min-heap.
        other
            .splits
            .len()
            .cmp(&self.splits.len())
            .then(self.actor_id.cmp(&other.actor_id))
    }
}

#[derive(Debug)]
pub struct SplitDiffOptions {
    pub enable_scale_in: bool,

    /// For most connectors, this should be false. When enabled, RisingWave will not track any progress.
    pub enable_adaptive: bool,
}

#[allow(clippy::derivable_impls)]
impl Default for SplitDiffOptions {
    fn default() -> Self {
        SplitDiffOptions {
            enable_scale_in: false,
            enable_adaptive: false,
        }
    }
}

#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, HashMap, HashSet};

    use risingwave_common::types::JsonbVal;
    use risingwave_connector::error::ConnectorResult;
    use risingwave_connector::source::{SplitId, SplitMetaData};
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::model::{ActorId, FragmentId};

    #[derive(Debug, Copy, Clone, Serialize, Deserialize)]
    struct TestSplit {
        id: u32,
    }

    impl SplitMetaData for TestSplit {
        fn id(&self) -> SplitId {
            format!("{}", self.id).into()
        }

        fn encode_to_json(&self) -> JsonbVal {
            serde_json::to_value(*self).unwrap().into()
        }

        fn restore_from_json(value: JsonbVal) -> ConnectorResult<Self> {
            serde_json::from_value(value.take()).map_err(Into::into)
        }

        fn update_offset(&mut self, _last_read_offset: String) -> ConnectorResult<()> {
            Ok(())
        }
    }

    fn check_all_splits(
        discovered_splits: &BTreeMap<SplitId, TestSplit>,
        diff: &HashMap<ActorId, Vec<TestSplit>>,
    ) {
        let mut split_ids: HashSet<_> = discovered_splits.keys().cloned().collect();

        for splits in diff.values() {
            for split in splits {
                assert!(split_ids.remove(&split.id()))
            }
        }

        assert!(split_ids.is_empty());
    }

    #[test]
    fn test_drop_splits() {
        let mut actor_splits: HashMap<ActorId, _> = HashMap::new();
        actor_splits.insert(0, vec![TestSplit { id: 0 }, TestSplit { id: 1 }]);
        actor_splits.insert(1, vec![TestSplit { id: 2 }, TestSplit { id: 3 }]);
        actor_splits.insert(2, vec![TestSplit { id: 4 }, TestSplit { id: 5 }]);

        let mut prev_split_to_actor = HashMap::new();
        for (actor_id, splits) in &actor_splits {
            for split in splits {
                prev_split_to_actor.insert(split.id(), *actor_id);
            }
        }

        let discovered_splits: BTreeMap<SplitId, TestSplit> = (1..5)
            .map(|i| {
                let split = TestSplit { id: i };
                (split.id(), split)
            })
            .collect();

        let opts = SplitDiffOptions {
            enable_scale_in: true,
            enable_adaptive: false,
        };

        let prev_split_ids: HashSet<_> = actor_splits
            .values()
            .flat_map(|splits| splits.iter().map(|split| split.id()))
            .collect();

        let diff = reassign_splits(
            FragmentId::default(),
            actor_splits,
            &discovered_splits,
            opts,
        )
        .unwrap();
        check_all_splits(&discovered_splits, &diff);

        let mut after_split_to_actor = HashMap::new();
        for (actor_id, splits) in &diff {
            for split in splits {
                after_split_to_actor.insert(split.id(), *actor_id);
            }
        }

        let discovered_split_ids: HashSet<_> = discovered_splits.keys().cloned().collect();

        let retained_split_ids: HashSet<_> =
            prev_split_ids.intersection(&discovered_split_ids).collect();

        for retained_split_id in retained_split_ids {
            assert_eq!(
                prev_split_to_actor.get(retained_split_id),
                after_split_to_actor.get(retained_split_id)
            )
        }
    }

    #[test]
    fn test_drop_splits_to_empty() {
        let mut actor_splits: HashMap<ActorId, _> = HashMap::new();
        actor_splits.insert(0, vec![TestSplit { id: 0 }]);

        let discovered_splits: BTreeMap<SplitId, TestSplit> = BTreeMap::new();

        let opts = SplitDiffOptions {
            enable_scale_in: true,
            enable_adaptive: false,
        };

        let diff = reassign_splits(
            FragmentId::default(),
            actor_splits,
            &discovered_splits,
            opts,
        )
        .unwrap();

        assert!(!diff.is_empty())
    }

    #[test]
    fn test_reassign_splits() {
        let actor_splits = HashMap::new();
        let discovered_splits: BTreeMap<SplitId, TestSplit> = BTreeMap::new();
        assert!(
            reassign_splits(
                FragmentId::default(),
                actor_splits,
                &discovered_splits,
                Default::default()
            )
            .is_none()
        );

        let actor_splits = (0..3).map(|i| (i, vec![])).collect();
        let discovered_splits: BTreeMap<SplitId, TestSplit> = BTreeMap::new();
        let diff = reassign_splits(
            FragmentId::default(),
            actor_splits,
            &discovered_splits,
            Default::default(),
        )
        .unwrap();
        assert_eq!(diff.len(), 3);
        for splits in diff.values() {
            assert!(splits.is_empty())
        }

        let actor_splits = (0..3).map(|i| (i, vec![])).collect();
        let discovered_splits: BTreeMap<SplitId, TestSplit> = (0..3)
            .map(|i| {
                let split = TestSplit { id: i };
                (split.id(), split)
            })
            .collect();

        let diff = reassign_splits(
            FragmentId::default(),
            actor_splits,
            &discovered_splits,
            Default::default(),
        )
        .unwrap();
        assert_eq!(diff.len(), 3);
        for splits in diff.values() {
            assert_eq!(splits.len(), 1);
        }

        check_all_splits(&discovered_splits, &diff);

        let actor_splits = (0..3).map(|i| (i, vec![TestSplit { id: i }])).collect();
        let discovered_splits: BTreeMap<SplitId, TestSplit> = (0..5)
            .map(|i| {
                let split = TestSplit { id: i };
                (split.id(), split)
            })
            .collect();

        let diff = reassign_splits(
            FragmentId::default(),
            actor_splits,
            &discovered_splits,
            Default::default(),
        )
        .unwrap();
        assert_eq!(diff.len(), 3);
        for splits in diff.values() {
            let len = splits.len();
            assert!(len == 1 || len == 2);
        }

        check_all_splits(&discovered_splits, &diff);

        let mut actor_splits: HashMap<ActorId, Vec<TestSplit>> =
            (0..3).map(|i| (i, vec![TestSplit { id: i }])).collect();
        actor_splits.insert(3, vec![]);
        actor_splits.insert(4, vec![]);

        let discovered_splits: BTreeMap<SplitId, TestSplit> = (0..5)
            .map(|i| {
                let split = TestSplit { id: i };
                (split.id(), split)
            })
            .collect();

        let diff = reassign_splits(
            FragmentId::default(),
            actor_splits,
            &discovered_splits,
            Default::default(),
        )
        .unwrap();
        assert_eq!(diff.len(), 5);
        for splits in diff.values() {
            assert_eq!(splits.len(), 1);
        }

        check_all_splits(&discovered_splits, &diff);
    }
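
    // An illustrative sketch, not part of the original test suite: with the default options
    // (scale-in and adaptive disabled) and no newly discovered splits, `reassign_splits`
    // returns `None` even if a previously assigned split is missing from discovery.
    #[test]
    fn test_dropped_split_without_scale_in_yields_no_diff() {
        // Actor 0 holds splits 0 and 1, but discovery now only reports split 1.
        let mut actor_splits: HashMap<ActorId, Vec<TestSplit>> = HashMap::new();
        actor_splits.insert(0, vec![TestSplit { id: 0 }, TestSplit { id: 1 }]);

        let discovered_splits: BTreeMap<SplitId, TestSplit> = std::iter::once(TestSplit { id: 1 })
            .map(|split| (split.id(), split))
            .collect();

        // Nothing new was discovered and dropping is not enabled, so no reassignment is produced.
        assert!(
            reassign_splits(
                FragmentId::default(),
                actor_splits,
                &discovered_splits,
                SplitDiffOptions::default(),
            )
            .is_none()
        );
    }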
}