risingwave_meta/stream/scale.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cmp::{Ordering, min};
use std::collections::hash_map::DefaultHasher;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::Duration;

use anyhow::{Context, anyhow};
use itertools::Itertools;
use num_integer::Integer;
use num_traits::abs;
use risingwave_common::bail;
use risingwave_common::bitmap::{Bitmap, BitmapBuilder};
use risingwave_common::catalog::{DatabaseId, TableId};
use risingwave_common::hash::ActorMapping;
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_meta_model::{ObjectId, WorkerId, actor, fragment, streaming_job};
use risingwave_pb::common::{WorkerNode, WorkerType};
use risingwave_pb::meta::FragmentWorkerSlotMappings;
use risingwave_pb::meta::subscribe_response::{Info, Operation};
use risingwave_pb::meta::table_fragments::fragment::{
    FragmentDistributionType, PbFragmentDistributionType,
};
use risingwave_pb::meta::table_fragments::{self, State};
use risingwave_pb::stream_plan::{
    Dispatcher, FragmentTypeFlag, PbDispatcher, PbDispatcherType, StreamNode,
};
use thiserror_ext::AsReport;
use tokio::sync::oneshot::Receiver;
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, oneshot};
use tokio::task::JoinHandle;
use tokio::time::{Instant, MissedTickBehavior};

use crate::barrier::{Command, Reschedule};
use crate::controller::scale::RescheduleWorkingSet;
use crate::manager::{LocalNotification, MetaSrvEnv, MetadataManager};
use crate::model::{
    ActorId, DispatcherId, FragmentId, StreamActor, StreamActorWithDispatchers, TableParallelism,
};
use crate::serving::{
    ServingVnodeMapping, to_deleted_fragment_worker_slot_mapping, to_fragment_worker_slot_mapping,
};
use crate::stream::{GlobalStreamManager, SourceManagerRef};
use crate::{MetaError, MetaResult};

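/// A per-fragment reschedule plan expressed as actor-count deltas per worker:
/// for example, `{ worker 1: -1, worker 2: +1 }` moves one actor from worker 1
/// to worker 2.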
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct WorkerReschedule {
    pub worker_actor_diff: BTreeMap<WorkerId, isize>,
}

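/// An in-memory snapshot of a fragment's metadata, built from the catalog and used during
/// rescheduling.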
pub struct CustomFragmentInfo {
    pub fragment_id: u32,
    pub fragment_type_mask: u32,
    pub distribution_type: PbFragmentDistributionType,
    pub state_table_ids: Vec<u32>,
    pub node: StreamNode,
    pub actor_template: StreamActorWithDispatchers,
    pub actors: Vec<CustomActorInfo>,
}

#[derive(Default, Clone)]
pub struct CustomActorInfo {
    pub actor_id: u32,
    pub fragment_id: u32,
    pub dispatcher: Vec<Dispatcher>,
    /// `None` if singleton.
    pub vnode_bitmap: Option<Bitmap>,
}

impl CustomFragmentInfo {
    pub fn get_fragment_type_mask(&self) -> u32 {
        self.fragment_type_mask
    }

    pub fn distribution_type(&self) -> FragmentDistributionType {
        self.distribution_type
    }
}

use educe::Educe;
use futures::future::try_join_all;
use risingwave_common::system_param::AdaptiveParallelismStrategy;
use risingwave_common::system_param::reader::SystemParamsRead;
use risingwave_common::util::stream_graph_visitor::visit_stream_node_cont;
use risingwave_meta_model::DispatcherType;
use risingwave_pb::stream_plan::stream_node::NodeBody;

use super::SourceChange;
use crate::controller::id::IdCategory;
use crate::controller::utils::filter_workers_by_resource_group;

// The debug implementation is arbitrary. Just used in debug logs.
#[derive(Educe)]
#[educe(Debug)]
pub struct RescheduleContext {
    /// Meta information for all Actors
    #[educe(Debug(ignore))]
    actor_map: HashMap<ActorId, CustomActorInfo>,
    /// Status of all Actors, used to find the location of the `Actor`
    actor_status: BTreeMap<ActorId, WorkerId>,
    /// Meta information of all `Fragment`, used to find the `Fragment`'s `Actor`
    #[educe(Debug(ignore))]
    fragment_map: HashMap<FragmentId, CustomFragmentInfo>,
    /// Fragments with `StreamSource`
    stream_source_fragment_ids: HashSet<FragmentId>,
    /// Fragments with `StreamSourceBackfill` and the corresponding upstream source fragment
    stream_source_backfill_fragment_ids: HashMap<FragmentId, FragmentId>,
    /// Target fragments in `NoShuffle` relation
    no_shuffle_target_fragment_ids: HashSet<FragmentId>,
    /// Source fragments in `NoShuffle` relation
    no_shuffle_source_fragment_ids: HashSet<FragmentId>,
    // index for dispatcher type from upstream fragment to downstream fragment
    fragment_dispatcher_map: HashMap<FragmentId, HashMap<FragmentId, DispatcherType>>,
    fragment_upstreams: HashMap<
        risingwave_meta_model::FragmentId,
        HashMap<risingwave_meta_model::FragmentId, DispatcherType>,
    >,
}

impl RescheduleContext {
    fn actor_id_to_worker_id(&self, actor_id: &ActorId) -> MetaResult<WorkerId> {
        self.actor_status
            .get(actor_id)
            .cloned()
            .ok_or_else(|| anyhow!("could not find worker for actor {}", actor_id).into())
    }
}

/// This function provides a simple balancing method. The specific process is as follows:
///
/// 1. Calculate the number of target actors; divide the vnode count by it to get the average
///    (used as `expected`) and the remainder.
///
/// 2. Partition the actors into those to be removed and those to be retained, and sort each
///    group from largest to smallest by the number of virtual nodes held.
///
/// 3. Calculate their balances: 1) for actors to be removed, the balance is the number of
///    virtual nodes held; 2) for retained actors, it is the number of virtual nodes held minus
///    `expected`; 3) for newly created actors, it is `-expected` (always negative).
///
/// 4. Allocate the remainder, giving priority to newly created actors.
///
/// 5. Merge the removed, retained and created actors into a queue, with the head of the queue
///    as the source, and move virtual nodes to the destination at the tail of the queue.
///
/// This can handle scale-in, scale-out, migration, and simultaneous scaling, preserving as much
/// affinity as possible.
///
/// Note that this function can only rebalance actors whose `vnode_bitmap` is not `None`. In
/// other words, for a `Fragment` of `FragmentDistributionType::Single`, calling this function
/// will fail an assertion, so such fragments must be skipped by the caller.
///
/// The return value is the bitmap distribution after scaling, which covers all virtual node indexes.
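///
/// # Example
///
/// A hypothetical sketch (not a doctest); assumes `actors` holds three actors that currently
/// cover all vnodes:
///
/// ```ignore
/// // Remove actor 1 and create actor 4 in its place.
/// let bitmaps = rebalance_actor_vnode(
///     &actors,
///     &BTreeSet::from([1]), // actors_to_remove
///     &BTreeSet::from([4]), // actors_to_create
/// );
/// // `bitmaps` now maps actors 2, 3 and 4 to bitmaps that together cover all
/// // vnodes, moving as few vnodes as possible away from their previous owners.
/// ```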
pub fn rebalance_actor_vnode(
    actors: &[CustomActorInfo],
    actors_to_remove: &BTreeSet<ActorId>,
    actors_to_create: &BTreeSet<ActorId>,
) -> HashMap<ActorId, Bitmap> {
    let actor_ids: BTreeSet<_> = actors.iter().map(|actor| actor.actor_id).collect();

    assert_eq!(actors_to_remove.difference(&actor_ids).count(), 0);
    assert_eq!(actors_to_create.intersection(&actor_ids).count(), 0);

    assert!(actors.len() >= actors_to_remove.len());

    let target_actor_count = actors.len() - actors_to_remove.len() + actors_to_create.len();
    assert!(target_actor_count > 0);

    // `vnode_bitmap` must be set on distributed fragments.
    let vnode_count = actors[0]
        .vnode_bitmap
        .as_ref()
        .expect("vnode bitmap unset")
        .len();

    // represents the balance of each actor, used to sort later
    #[derive(Debug)]
    struct Balance {
        actor_id: ActorId,
        balance: i32,
        builder: BitmapBuilder,
    }
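    // e.g. with 256 vnodes and 3 target actors: expected = 85, remain = 1.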
    let (expected, mut remain) = vnode_count.div_rem(&target_actor_count);

    tracing::debug!(
        "expected {}, remain {}, prev actors {}, target actors {}",
        expected,
        remain,
        actors.len(),
        target_actor_count,
    );

    let (mut removed, mut rest): (Vec<_>, Vec<_>) = actors
        .iter()
        .map(|actor| {
            (
                actor.actor_id as ActorId,
                actor.vnode_bitmap.clone().expect("vnode bitmap unset"),
            )
        })
        .partition(|(actor_id, _)| actors_to_remove.contains(actor_id));

    let order_by_bitmap_desc =
        |(id_a, bitmap_a): &(ActorId, Bitmap), (id_b, bitmap_b): &(ActorId, Bitmap)| -> Ordering {
            bitmap_a
                .count_ones()
                .cmp(&bitmap_b.count_ones())
                .reverse()
                .then(id_a.cmp(id_b))
        };

    let builder_from_bitmap = |bitmap: &Bitmap| -> BitmapBuilder {
        let mut builder = BitmapBuilder::default();
        builder.append_bitmap(bitmap);
        builder
    };

    let (prev_expected, _) = vnode_count.div_rem(&actors.len());

    let prev_remain = removed
        .iter()
        .map(|(_, bitmap)| {
            assert!(bitmap.count_ones() >= prev_expected);
            bitmap.count_ones() - prev_expected
        })
        .sum::<usize>();

    removed.sort_by(order_by_bitmap_desc);
    rest.sort_by(order_by_bitmap_desc);

    let removed_balances = removed.into_iter().map(|(actor_id, bitmap)| Balance {
        actor_id,
        balance: bitmap.count_ones() as i32,
        builder: builder_from_bitmap(&bitmap),
    });

    let mut rest_balances = rest
        .into_iter()
        .map(|(actor_id, bitmap)| Balance {
            actor_id,
            balance: bitmap.count_ones() as i32 - expected as i32,
            builder: builder_from_bitmap(&bitmap),
        })
        .collect_vec();

    let mut created_balances = actors_to_create
        .iter()
        .map(|actor_id| Balance {
            actor_id: *actor_id,
            balance: -(expected as i32),
            builder: BitmapBuilder::zeroed(vnode_count),
        })
        .collect_vec();

    for balance in created_balances
        .iter_mut()
        .rev()
        .take(prev_remain)
        .chain(rest_balances.iter_mut())
    {
        if remain > 0 {
            balance.balance -= 1;
            remain -= 1;
        }
    }

    // consume the rest `remain`
    for balance in &mut created_balances {
        if remain > 0 {
            balance.balance -= 1;
            remain -= 1;
        }
    }

    assert_eq!(remain, 0);

    let mut v: VecDeque<_> = removed_balances
        .chain(rest_balances)
        .chain(created_balances)
        .collect();

    // We return the full bitmaps after rebalancing; to return only the changed
    // actors, filter out entries with balance = 0 here.
    let mut result = HashMap::with_capacity(target_actor_count);

    for balance in &v {
        tracing::debug!(
            "actor {:5}\tbalance {:5}\tR[{:5}]\tC[{:5}]",
            balance.actor_id,
            balance.balance,
            actors_to_remove.contains(&balance.actor_id),
            actors_to_create.contains(&balance.actor_id)
        );
    }

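    // Repeatedly move vnodes from the head of the queue (largest surplus) to the
    // tail (largest deficit) until every remaining balance reaches zero.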
    while !v.is_empty() {
        if v.len() == 1 {
            let single = v.pop_front().unwrap();
            assert_eq!(single.balance, 0);
            if !actors_to_remove.contains(&single.actor_id) {
                result.insert(single.actor_id, single.builder.finish());
            }

            continue;
        }

        let mut src = v.pop_front().unwrap();
        let mut dst = v.pop_back().unwrap();

        let n = min(abs(src.balance), abs(dst.balance));

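        // Transfer up to `n` vnodes from `src` to `dst`, scanning from the highest
        // vnode index downwards.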
        let mut moved = 0;
        for idx in (0..vnode_count).rev() {
            if moved >= n {
                break;
            }

            if src.builder.is_set(idx) {
                src.builder.set(idx, false);
                assert!(!dst.builder.is_set(idx));
                dst.builder.set(idx, true);
                moved += 1;
            }
        }

        src.balance -= n;
        dst.balance += n;

        if src.balance != 0 {
            v.push_front(src);
        } else if !actors_to_remove.contains(&src.actor_id) {
            result.insert(src.actor_id, src.builder.finish());
        }

        if dst.balance != 0 {
            v.push_back(dst);
        } else {
            result.insert(dst.actor_id, dst.builder.finish());
        }
    }

    result
}

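/// Options controlling how a reschedule request is interpreted.
///
/// A hypothetical construction (illustrative, not a doctest):
///
/// ```ignore
/// let options = RescheduleOptions {
///     resolve_no_shuffle_upstream: true, // rewrite plans to the no-shuffle root
///     skip_create_new_actors: false,     // actually create the scaled-out actors
/// };
/// ```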
#[derive(Debug, Clone, Copy)]
pub struct RescheduleOptions {
    /// Whether to resolve the upstream of `NoShuffle` when scaling. If set, this checks that
    /// all reschedules in the no-shuffle dependency tree are consistent, and rewrites them to
    /// the root of that tree.
    pub resolve_no_shuffle_upstream: bool,

    /// Whether to skip creating new actors. If true, the scaled-out actors will not be created.
    pub skip_create_new_actors: bool,
}

pub type ScaleControllerRef = Arc<ScaleController>;

pub struct ScaleController {
    pub metadata_manager: MetadataManager,

    pub source_manager: SourceManagerRef,

    pub env: MetaSrvEnv,

    /// We acquire this lock during DDL to prevent scaling operations on jobs that are in the
    /// creating state. e.g., an MV cannot be rescheduled during foreground backfill.
    pub reschedule_lock: RwLock<()>,
}

impl ScaleController {
    pub fn new(
        metadata_manager: &MetadataManager,
        source_manager: SourceManagerRef,
        env: MetaSrvEnv,
    ) -> Self {
        Self {
            metadata_manager: metadata_manager.clone(),
            source_manager,
            env,
            reschedule_lock: RwLock::new(()),
        }
    }

    pub async fn integrity_check(&self) -> MetaResult<()> {
        self.metadata_manager
            .catalog_controller
            .integrity_check()
            .await
    }

    /// Build the context for rescheduling and do some validation for the request.
    async fn build_reschedule_context(
        &self,
        reschedule: &mut HashMap<FragmentId, WorkerReschedule>,
        options: RescheduleOptions,
        table_parallelisms: &mut HashMap<TableId, TableParallelism>,
    ) -> MetaResult<RescheduleContext> {
        let worker_nodes: HashMap<WorkerId, WorkerNode> = self
            .metadata_manager
            .list_active_streaming_compute_nodes()
            .await?
            .into_iter()
            .map(|worker_node| (worker_node.id as _, worker_node))
            .collect();

        if worker_nodes.is_empty() {
            bail!("no available compute node in the cluster");
        }

        // Check if we are trying to move a fragment to a node marked as unschedulable
        let unschedulable_worker_ids: HashSet<_> = worker_nodes
            .values()
            .filter(|w| {
                w.property
                    .as_ref()
                    .map(|property| property.is_unschedulable)
                    .unwrap_or(false)
            })
            .map(|worker| worker.id as WorkerId)
            .collect();

        for (fragment_id, reschedule) in &*reschedule {
            for (worker_id, change) in &reschedule.worker_actor_diff {
                if unschedulable_worker_ids.contains(worker_id) && change.is_positive() {
                    bail!(
                        "unable to move fragment {} to unschedulable worker {}",
                        fragment_id,
                        worker_id
                    );
                }
            }
        }

        // FIXME: the same as another place calling `list_table_fragments` in scaling.
        // Index for StreamActor
        let mut actor_map = HashMap::new();
        // Index for Fragment
        let mut fragment_map = HashMap::new();
        // Index for actor status, including actor's worker id
        let mut actor_status = BTreeMap::new();
        let mut fragment_state = HashMap::new();
        let mut fragment_to_table = HashMap::new();

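        // Builds the indexes above from the working set resolved from the catalog
        // (see `resolve_working_set_for_reschedule_fragments` below).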
        fn fulfill_index_by_fragment_ids(
            actor_map: &mut HashMap<u32, CustomActorInfo>,
            fragment_map: &mut HashMap<FragmentId, CustomFragmentInfo>,
            actor_status: &mut BTreeMap<ActorId, WorkerId>,
            fragment_state: &mut HashMap<FragmentId, State>,
            fragment_to_table: &mut HashMap<FragmentId, TableId>,
            fragments: HashMap<risingwave_meta_model::FragmentId, fragment::Model>,
            actors: HashMap<ActorId, actor::Model>,
            mut actor_dispatchers: HashMap<ActorId, Vec<PbDispatcher>>,
            related_jobs: HashMap<ObjectId, (streaming_job::Model, String)>,
        ) {
            let mut fragment_actors: HashMap<
                risingwave_meta_model::FragmentId,
                Vec<CustomActorInfo>,
            > = HashMap::new();

            let mut expr_contexts = HashMap::new();
            for (
                _,
                actor::Model {
                    actor_id,
                    fragment_id,
                    status: _,
                    splits: _,
                    worker_id,
                    vnode_bitmap,
                    expr_context,
                    ..
                },
            ) in actors
            {
                let dispatchers = actor_dispatchers
                    .remove(&(actor_id as _))
                    .unwrap_or_default();

                let actor_info = CustomActorInfo {
                    actor_id: actor_id as _,
                    fragment_id: fragment_id as _,
                    dispatcher: dispatchers,
                    vnode_bitmap: vnode_bitmap.map(|b| Bitmap::from(&b.to_protobuf())),
                };

                actor_map.insert(actor_id as _, actor_info.clone());

                fragment_actors
                    .entry(fragment_id as _)
                    .or_default()
                    .push(actor_info);

                actor_status.insert(actor_id as _, worker_id as WorkerId);

                expr_contexts.insert(actor_id as u32, expr_context);
            }

            for (
                _,
                fragment::Model {
                    fragment_id,
                    job_id,
                    fragment_type_mask,
                    distribution_type,
                    stream_node,
                    state_table_ids,
                    ..
                },
            ) in fragments
            {
                let actors = fragment_actors
                    .remove(&(fragment_id as _))
                    .unwrap_or_default();

                let CustomActorInfo {
                    actor_id,
                    fragment_id,
                    dispatcher,
                    vnode_bitmap,
                } = actors.first().unwrap().clone();

                let (related_job, job_definition) =
                    related_jobs.get(&job_id).expect("job not found");

                let fragment = CustomFragmentInfo {
                    fragment_id: fragment_id as _,
                    fragment_type_mask: fragment_type_mask as _,
                    distribution_type: distribution_type.into(),
                    state_table_ids: state_table_ids.into_u32_array(),
                    node: stream_node.to_protobuf(),
                    actor_template: (
                        StreamActor {
                            actor_id,
                            fragment_id: fragment_id as _,
                            vnode_bitmap,
                            mview_definition: job_definition.to_owned(),
                            expr_context: expr_contexts
                                .get(&actor_id)
                                .cloned()
                                .map(|expr_context| expr_context.to_protobuf()),
                        },
                        dispatcher,
                    ),
                    actors,
                };

                fragment_map.insert(fragment_id as _, fragment);

                fragment_to_table.insert(fragment_id as _, TableId::from(job_id as u32));

                fragment_state.insert(
                    fragment_id,
                    table_fragments::PbState::from(related_job.job_status),
                );
            }
        }
        let fragment_ids = reschedule.keys().map(|id| *id as _).collect();
        let working_set = self
            .metadata_manager
            .catalog_controller
            .resolve_working_set_for_reschedule_fragments(fragment_ids)
            .await?;

        fulfill_index_by_fragment_ids(
            &mut actor_map,
            &mut fragment_map,
            &mut actor_status,
            &mut fragment_state,
            &mut fragment_to_table,
            working_set.fragments,
            working_set.actors,
            working_set.actor_dispatchers,
            working_set.related_jobs,
        );

        // NoShuffle relation index
        let mut no_shuffle_source_fragment_ids = HashSet::new();
        let mut no_shuffle_target_fragment_ids = HashSet::new();

        Self::build_no_shuffle_relation_index(
            &actor_map,
            &mut no_shuffle_source_fragment_ids,
            &mut no_shuffle_target_fragment_ids,
        );

        if options.resolve_no_shuffle_upstream {
            let original_reschedule_keys = reschedule.keys().cloned().collect();

            Self::resolve_no_shuffle_upstream_fragments(
                reschedule,
                &no_shuffle_source_fragment_ids,
                &no_shuffle_target_fragment_ids,
                &working_set.fragment_upstreams,
            )?;

            if !table_parallelisms.is_empty() {
                // Re-walk the NO_SHUFFLE dependencies to work out, for each downstream table,
                // which table its custom modifications were propagated from.
                Self::resolve_no_shuffle_upstream_tables(
                    original_reschedule_keys,
                    &no_shuffle_source_fragment_ids,
                    &no_shuffle_target_fragment_ids,
                    &fragment_to_table,
                    &working_set.fragment_upstreams,
                    table_parallelisms,
                )?;
            }
        }

        let mut fragment_dispatcher_map = HashMap::new();
        Self::build_fragment_dispatcher_index(&actor_map, &mut fragment_dispatcher_map);

        let mut stream_source_fragment_ids = HashSet::new();
        let mut stream_source_backfill_fragment_ids = HashMap::new();
        let mut no_shuffle_reschedule = HashMap::new();
        for (fragment_id, WorkerReschedule { worker_actor_diff }) in &*reschedule {
            let fragment = fragment_map
                .get(fragment_id)
                .ok_or_else(|| anyhow!("fragment {fragment_id} does not exist"))?;

            // Check if the rescheduling is supported.
            match fragment_state[fragment_id] {
                table_fragments::State::Unspecified => unreachable!(),
                state @ table_fragments::State::Initial => {
                    bail!(
                        "the materialized view of fragment {fragment_id} is in state {}",
                        state.as_str_name()
                    )
                }
                state @ table_fragments::State::Creating => {
                    let stream_node = &fragment.node;

                    let mut is_reschedulable = true;
                    visit_stream_node_cont(stream_node, |body| {
                        if let Some(NodeBody::StreamScan(node)) = &body.node_body {
                            if !node.stream_scan_type().is_reschedulable() {
                                is_reschedulable = false;

                                // fail fast
                                return false;
                            }

                            // continue visiting
                            return true;
                        }

                        // continue visiting
                        true
                    });

                    if !is_reschedulable {
                        bail!(
                            "the materialized view of fragment {fragment_id} is in state {}",
                            state.as_str_name()
                        )
                    }
                }
                table_fragments::State::Created => {}
            }

            if no_shuffle_target_fragment_ids.contains(fragment_id) {
                bail!(
                    "rescheduling NoShuffle downstream fragment (maybe Chain fragment) is forbidden, please use NoShuffle upstream fragment (like Materialized fragment) to scale"
                );
            }

            // For the relation of NoShuffle (e.g. Materialize and Chain), we need special
            // treatment because the upstream and downstream of NoShuffle are always in 1-1
            // correspondence, so we need to clone the reschedule plan to the downstream of all
            // cascading relations.
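            // For example, with Materialize -(NoShuffle)-> Chain, a diff applied to the
            // Materialize fragment is copied to the Chain fragment as well.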
            if no_shuffle_source_fragment_ids.contains(fragment_id) {
                // This fragment is a NoShuffle's upstream.
                let mut queue: VecDeque<_> = fragment_dispatcher_map
                    .get(fragment_id)
                    .unwrap()
                    .keys()
                    .cloned()
                    .collect();

                while let Some(downstream_id) = queue.pop_front() {
                    if !no_shuffle_target_fragment_ids.contains(&downstream_id) {
                        continue;
                    }

                    if let Some(downstream_fragments) = fragment_dispatcher_map.get(&downstream_id)
                    {
                        let no_shuffle_downstreams = downstream_fragments
                            .iter()
                            .filter(|(_, ty)| **ty == DispatcherType::NoShuffle)
                            .map(|(fragment_id, _)| fragment_id);

                        queue.extend(no_shuffle_downstreams.copied());
                    }

                    no_shuffle_reschedule.insert(
                        downstream_id,
                        WorkerReschedule {
                            worker_actor_diff: worker_actor_diff.clone(),
                        },
                    );
                }
            }

            if (fragment.fragment_type_mask & FragmentTypeFlag::Source as u32) != 0
                && fragment.node.find_stream_source().is_some()
            {
                stream_source_fragment_ids.insert(*fragment_id);
            }

            // Check if the reschedule plan is valid.
            let current_worker_ids = fragment
                .actors
                .iter()
                .map(|a| actor_status.get(&a.actor_id).cloned().unwrap())
                .collect::<HashSet<_>>();

            for (removed, change) in worker_actor_diff {
                if !current_worker_ids.contains(removed) && change.is_negative() {
                    bail!(
                        "no actor on the worker {} of fragment {}",
                        removed,
                        fragment_id
                    );
                }
            }

            let added_actor_count: usize = worker_actor_diff
                .values()
                .filter(|change| change.is_positive())
                .cloned()
                .map(|change| change as usize)
                .sum();

            let removed_actor_count: usize = worker_actor_diff
                .values()
                .filter(|change| change.is_negative())
                .cloned()
                .map(|v| v.unsigned_abs())
                .sum();

            match fragment.distribution_type() {
                FragmentDistributionType::Hash => {
                    if fragment.actors.len() + added_actor_count <= removed_actor_count {
                        bail!("can't remove all actors from fragment {}", fragment_id);
                    }
                }
                FragmentDistributionType::Single => {
                    if added_actor_count != removed_actor_count {
                        bail!("single distribution fragment only supports migration");
                    }
                }
                FragmentDistributionType::Unspecified => unreachable!(),
            }
        }

        if !no_shuffle_reschedule.is_empty() {
            tracing::info!(
                "reschedule plan rewritten with NoShuffle reschedule {:?}",
                no_shuffle_reschedule
            );

            for noshuffle_downstream in no_shuffle_reschedule.keys() {
                let fragment = fragment_map.get(noshuffle_downstream).unwrap();
                // SourceScan is always a NoShuffle downstream, rescheduled together with the upstream Source.
                if (fragment.fragment_type_mask & FragmentTypeFlag::SourceScan as u32) != 0 {
                    let stream_node = &fragment.node;
                    if let Some((_source_id, upstream_source_fragment_id)) =
                        stream_node.find_source_backfill()
                    {
                        stream_source_backfill_fragment_ids
                            .insert(fragment.fragment_id, upstream_source_fragment_id);
                    }
                }
            }
        }

        // Modifications for NoShuffle downstream.
        reschedule.extend(no_shuffle_reschedule.into_iter());

        Ok(RescheduleContext {
            actor_map,
            actor_status,
            fragment_map,
            stream_source_fragment_ids,
            stream_source_backfill_fragment_ids,
            no_shuffle_target_fragment_ids,
            no_shuffle_source_fragment_ids,
            fragment_dispatcher_map,
            fragment_upstreams: working_set.fragment_upstreams,
        })
    }

    /// From the high-level [`WorkerReschedule`] to the low-level reschedule plan [`Reschedule`].
    ///
    /// Returns `(reschedule_fragment, applied_reschedules)`
    /// - `reschedule_fragment`: the generated reschedule plan
    /// - `applied_reschedules`: the changes that need to be updated to the meta store
    ///   (`pre_apply_reschedules`, only for V1).
    ///
    /// In the [normal process of scaling](`GlobalStreamManager::reschedule_actors`), we use the returned values to
    /// build a [`Command::RescheduleFragment`], which then flows through the barrier mechanism to perform scaling.
    /// The meta store is updated after the barrier is collected.
    ///
    /// During recovery, we don't need the barrier mechanism, and can directly use the returned values to update meta.
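    ///
    /// # Example
    ///
    /// A hypothetical sketch (not a doctest): move one actor of fragment 1 from
    /// worker 1 to worker 2.
    ///
    /// ```ignore
    /// let mut table_parallelisms = HashMap::new();
    /// let reschedules = HashMap::from([(
    ///     1, // fragment id
    ///     WorkerReschedule {
    ///         worker_actor_diff: BTreeMap::from([(1, -1), (2, 1)]),
    ///     },
    /// )]);
    /// let plan = scale_controller
    ///     .analyze_reschedule_plan(reschedules, options, &mut table_parallelisms)
    ///     .await?;
    /// ```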
    pub(crate) async fn analyze_reschedule_plan(
        &self,
        mut reschedules: HashMap<FragmentId, WorkerReschedule>,
        options: RescheduleOptions,
        table_parallelisms: &mut HashMap<TableId, TableParallelism>,
    ) -> MetaResult<HashMap<FragmentId, Reschedule>> {
        tracing::debug!("build_reschedule_context, reschedules: {:#?}", reschedules);
        let ctx = self
            .build_reschedule_context(&mut reschedules, options, table_parallelisms)
            .await?;
        tracing::debug!("reschedule context: {:#?}", ctx);
        let reschedules = reschedules;

        // Here, the plan for both upstream and downstream of the NO_SHUFFLE Fragment should already have been populated.

        // Index of actors to create/remove
        // Fragment Id => ( Actor Id => Worker Id )
        let (fragment_actors_to_remove, fragment_actors_to_create) =
            self.arrange_reschedules(&reschedules, &ctx)?;

        let mut fragment_actor_bitmap = HashMap::new();
        for fragment_id in reschedules.keys() {
            if ctx.no_shuffle_target_fragment_ids.contains(fragment_id) {
                // Skip chain fragments; we will clone the upstream materialize fragment's
                // mapping later.
                continue;
            }

            let actors_to_create = fragment_actors_to_create
                .get(fragment_id)
                .map(|map| map.iter().map(|(actor_id, _)| *actor_id).collect())
                .unwrap_or_default();

            let actors_to_remove = fragment_actors_to_remove
                .get(fragment_id)
                .map(|map| map.iter().map(|(actor_id, _)| *actor_id).collect())
                .unwrap_or_default();

            let fragment = ctx.fragment_map.get(fragment_id).unwrap();

            match fragment.distribution_type() {
                FragmentDistributionType::Single => {
                    // Skip re-balancing action for single distribution (always None)
                    fragment_actor_bitmap
                        .insert(fragment.fragment_id as FragmentId, Default::default());
                }
                FragmentDistributionType::Hash => {
                    let actor_vnode = rebalance_actor_vnode(
                        &fragment.actors,
                        &actors_to_remove,
                        &actors_to_create,
                    );

                    fragment_actor_bitmap.insert(fragment.fragment_id as FragmentId, actor_vnode);
                }

                FragmentDistributionType::Unspecified => unreachable!(),
            }
        }

        // Index for fragment -> { actor -> worker_id } after reschedule.
        // Since we need to organize the upstream and downstream relationships of NoShuffle,
        // we first need the actor distribution after scaling.
        let mut fragment_actors_after_reschedule = HashMap::with_capacity(reschedules.len());
        for fragment_id in reschedules.keys() {
            let fragment = ctx.fragment_map.get(fragment_id).unwrap();
            let mut new_actor_ids = BTreeMap::new();
            for actor in &fragment.actors {
                if let Some(actors_to_remove) = fragment_actors_to_remove.get(fragment_id) {
                    if actors_to_remove.contains_key(&actor.actor_id) {
                        continue;
                    }
                }
                let worker_id = ctx.actor_id_to_worker_id(&actor.actor_id)?;
                new_actor_ids.insert(actor.actor_id as ActorId, worker_id);
            }

            if let Some(actors_to_create) = fragment_actors_to_create.get(fragment_id) {
                for (actor_id, worker_id) in actors_to_create {
                    new_actor_ids.insert(*actor_id, *worker_id);
                }
            }

            assert!(
                !new_actor_ids.is_empty(),
                "should be at least one actor in fragment {} after rescheduling",
                fragment_id
            );

            fragment_actors_after_reschedule.insert(*fragment_id, new_actor_ids);
        }

        let fragment_actors_after_reschedule = fragment_actors_after_reschedule;

        // To maintain consistency with the original structure, the upstream and downstream
        // actors of NoShuffle need to be in the same worker slot and hold the same virtual
        // nodes. Since the actors of each fragment are sorted by id on all workers, we can
        // identify the corresponding upstream actor of each NO_SHUFFLE pair after the upstream
        // re-balancing.
        // NOTE: There should be more asserts here to ensure correctness.
        fn arrange_no_shuffle_relation(
            ctx: &RescheduleContext,
            fragment_id: &FragmentId,
            upstream_fragment_id: &FragmentId,
            fragment_actors_after_reschedule: &HashMap<FragmentId, BTreeMap<ActorId, WorkerId>>,
            actor_group_map: &mut HashMap<ActorId, (FragmentId, ActorId)>,
            fragment_updated_bitmap: &mut HashMap<FragmentId, HashMap<ActorId, Bitmap>>,
            no_shuffle_upstream_actor_map: &mut HashMap<ActorId, HashMap<FragmentId, ActorId>>,
            no_shuffle_downstream_actors_map: &mut HashMap<ActorId, HashMap<FragmentId, ActorId>>,
        ) {
            if !ctx.no_shuffle_target_fragment_ids.contains(fragment_id) {
                return;
            }

            let fragment = &ctx.fragment_map[fragment_id];

            let upstream_fragment = &ctx.fragment_map[upstream_fragment_id];

            // build actor group map
            for upstream_actor in &upstream_fragment.actors {
                for dispatcher in &upstream_actor.dispatcher {
                    if let PbDispatcherType::NoShuffle = dispatcher.get_type().unwrap() {
                        let downstream_actor_id =
                            *dispatcher.downstream_actor_id.iter().exactly_one().unwrap();

                        // upstream is root
                        if !ctx
                            .no_shuffle_target_fragment_ids
                            .contains(upstream_fragment_id)
                        {
                            actor_group_map.insert(
                                upstream_actor.actor_id,
                                (upstream_fragment.fragment_id, upstream_actor.actor_id),
                            );
                            actor_group_map.insert(
                                downstream_actor_id,
                                (upstream_fragment.fragment_id, upstream_actor.actor_id),
                            );
                        } else {
                            let root_actor_id = actor_group_map[&upstream_actor.actor_id];

                            actor_group_map.insert(downstream_actor_id, root_actor_id);
                        }
                    }
                }
            }

            // If the upstream is a singleton fragment, there will be no bitmap changes.
            let upstream_fragment_bitmap = fragment_updated_bitmap
                .get(upstream_fragment_id)
                .cloned()
                .unwrap_or_default();

            // Question: is it possible to have a hash-distributed fragment whose actors' bitmaps remain unchanged?
            if upstream_fragment.distribution_type() == FragmentDistributionType::Single {
                assert!(
                    upstream_fragment_bitmap.is_empty(),
                    "single fragment should have no bitmap updates"
                );
            }

            let upstream_fragment_actor_map = fragment_actors_after_reschedule
                .get(upstream_fragment_id)
                .cloned()
                .unwrap();

            let fragment_actor_map = fragment_actors_after_reschedule
                .get(fragment_id)
                .cloned()
                .unwrap();

            let mut worker_reverse_index: HashMap<WorkerId, BTreeSet<_>> = HashMap::new();

            // First, copy the bitmaps of actors that already belong to a NoShuffle group.
            let mut fragment_bitmap = HashMap::new();

            for (actor_id, worker_id) in &fragment_actor_map {
                if let Some((root_fragment, root_actor_id)) = actor_group_map.get(actor_id) {
                    let root_bitmap = fragment_updated_bitmap
                        .get(root_fragment)
                        .expect("root fragment bitmap not found")
                        .get(root_actor_id)
                        .cloned()
                        .expect("root actor bitmap not found");

                    // Copy the bitmap
                    fragment_bitmap.insert(*actor_id, root_bitmap);

                    no_shuffle_upstream_actor_map
                        .entry(*actor_id as ActorId)
                        .or_default()
                        .insert(*upstream_fragment_id, *root_actor_id);
                    no_shuffle_downstream_actors_map
                        .entry(*root_actor_id)
                        .or_default()
                        .insert(*fragment_id, *actor_id);
                } else {
                    worker_reverse_index
                        .entry(*worker_id)
                        .or_default()
                        .insert(*actor_id);
                }
            }

            let mut upstream_worker_reverse_index: HashMap<WorkerId, BTreeSet<_>> = HashMap::new();

            for (actor_id, worker_id) in &upstream_fragment_actor_map {
                if !actor_group_map.contains_key(actor_id) {
                    upstream_worker_reverse_index
                        .entry(*worker_id)
                        .or_default()
                        .insert(*actor_id);
                }
            }

            // Then, pair the remaining actors with upstream actors on the same worker and copy
            // their bitmaps.
            for (worker_id, actor_ids) in worker_reverse_index {
                let upstream_actor_ids = upstream_worker_reverse_index
                    .get(&worker_id)
                    .unwrap()
                    .clone();

                assert_eq!(actor_ids.len(), upstream_actor_ids.len());

                for (actor_id, upstream_actor_id) in actor_ids
                    .into_iter()
                    .zip_eq_debug(upstream_actor_ids.into_iter())
                {
                    match upstream_fragment_bitmap.get(&upstream_actor_id).cloned() {
                        None => {
                            // single fragment should have no bitmap updates (same as upstream)
                            assert_eq!(
                                upstream_fragment.distribution_type(),
                                FragmentDistributionType::Single
                            );
                        }
                        Some(bitmap) => {
                            // Copy the bitmap
                            fragment_bitmap.insert(actor_id, bitmap);
                        }
                    }

                    no_shuffle_upstream_actor_map
                        .entry(actor_id as ActorId)
                        .or_default()
                        .insert(*upstream_fragment_id, upstream_actor_id);
                    no_shuffle_downstream_actors_map
                        .entry(upstream_actor_id)
                        .or_default()
                        .insert(*fragment_id, actor_id);
                }
            }

            match fragment.distribution_type() {
                FragmentDistributionType::Hash => {}
                FragmentDistributionType::Single => {
                    // single distribution should update nothing
                    assert!(fragment_bitmap.is_empty());
                }
                FragmentDistributionType::Unspecified => unreachable!(),
            }

            if let Err(e) = fragment_updated_bitmap.try_insert(*fragment_id, fragment_bitmap) {
                assert_eq!(
                    e.entry.get(),
                    &e.value,
                    "bitmaps derived from different no-shuffle upstreams mismatch"
                );
            }

            // Visit downstream fragments recursively.
            if let Some(downstream_fragments) = ctx.fragment_dispatcher_map.get(fragment_id) {
                let no_shuffle_downstreams = downstream_fragments
                    .iter()
                    .filter(|(_, ty)| **ty == DispatcherType::NoShuffle)
                    .map(|(fragment_id, _)| fragment_id);

                for downstream_fragment_id in no_shuffle_downstreams {
                    arrange_no_shuffle_relation(
                        ctx,
                        downstream_fragment_id,
                        fragment_id,
                        fragment_actors_after_reschedule,
                        actor_group_map,
                        fragment_updated_bitmap,
                        no_shuffle_upstream_actor_map,
                        no_shuffle_downstream_actors_map,
                    );
                }
            }
        }

        let mut no_shuffle_upstream_actor_map = HashMap::new();
        let mut no_shuffle_downstream_actors_map = HashMap::new();
        let mut actor_group_map = HashMap::new();
        // For all roots in the upstream and downstream dependency trees of NoShuffle, recursively
        // find all correspondences
        for fragment_id in reschedules.keys() {
            if ctx.no_shuffle_source_fragment_ids.contains(fragment_id)
                && !ctx.no_shuffle_target_fragment_ids.contains(fragment_id)
            {
                if let Some(downstream_fragments) = ctx.fragment_dispatcher_map.get(fragment_id) {
                    for downstream_fragment_id in downstream_fragments.keys() {
                        arrange_no_shuffle_relation(
                            &ctx,
                            downstream_fragment_id,
                            fragment_id,
                            &fragment_actors_after_reschedule,
                            &mut actor_group_map,
                            &mut fragment_actor_bitmap,
                            &mut no_shuffle_upstream_actor_map,
                            &mut no_shuffle_downstream_actors_map,
                        );
                    }
                }
            }
        }

        tracing::debug!("actor group map {:?}", actor_group_map);

        let mut new_created_actors = HashMap::new();
        for fragment_id in reschedules.keys() {
            let actors_to_create = fragment_actors_to_create
                .get(fragment_id)
                .cloned()
                .unwrap_or_default();

            let fragment = &ctx.fragment_map[fragment_id];

            assert!(!fragment.actors.is_empty());

            for actor_to_create in &actors_to_create {
                let new_actor_id = actor_to_create.0;
                let (mut new_actor, mut dispatchers) = fragment.actor_template.clone();

                // This should be assigned before the `modify_actor_upstream_and_downstream` call,
                // because we need to use the new actor id to find the upstream and
                // downstream in the NoShuffle relationship
                new_actor.actor_id = *new_actor_id;

                Self::modify_actor_upstream_and_downstream(
                    &ctx,
                    &fragment_actors_to_remove,
                    &fragment_actors_to_create,
                    &fragment_actor_bitmap,
                    &no_shuffle_downstream_actors_map,
                    &mut new_actor,
                    &mut dispatchers,
                )?;

                if let Some(bitmap) = fragment_actor_bitmap
                    .get(fragment_id)
                    .and_then(|actor_bitmaps| actor_bitmaps.get(new_actor_id))
                {
                    new_actor.vnode_bitmap = Some(bitmap.to_protobuf().into());
                }

                new_created_actors.insert(*new_actor_id, (new_actor, dispatchers));
            }
        }

        // For stream source & source backfill fragments, we need to reallocate the splits.
        // Since we are in the paused state, it is safe to reallocate them.
        let mut fragment_actor_splits = HashMap::new();
        for fragment_id in reschedules.keys() {
            let actors_after_reschedule = &fragment_actors_after_reschedule[fragment_id];

            if ctx.stream_source_fragment_ids.contains(fragment_id) {
                let fragment = &ctx.fragment_map[fragment_id];

                let prev_actor_ids = fragment
                    .actors
                    .iter()
                    .map(|actor| actor.actor_id)
                    .collect_vec();

                let curr_actor_ids = actors_after_reschedule.keys().cloned().collect_vec();

                let actor_splits = self
                    .source_manager
                    .migrate_splits_for_source_actors(
                        *fragment_id,
                        &prev_actor_ids,
                        &curr_actor_ids,
                    )
                    .await?;

                tracing::debug!(
                    "source actor splits: {:?}, fragment_id: {}",
                    actor_splits,
                    fragment_id
                );
                fragment_actor_splits.insert(*fragment_id, actor_splits);
            }
        }
        // We use two passes to make sure source actors are migrated first, and backfill actors are then aligned to them.
        if !ctx.stream_source_backfill_fragment_ids.is_empty() {
            for fragment_id in reschedules.keys() {
                let actors_after_reschedule = &fragment_actors_after_reschedule[fragment_id];

                if let Some(upstream_source_fragment_id) =
                    ctx.stream_source_backfill_fragment_ids.get(fragment_id)
                {
                    let curr_actor_ids = actors_after_reschedule.keys().cloned().collect_vec();

                    let actor_splits = self.source_manager.migrate_splits_for_backfill_actors(
                        *fragment_id,
                        *upstream_source_fragment_id,
                        &curr_actor_ids,
                        &fragment_actor_splits,
                        &no_shuffle_upstream_actor_map,
                    )?;
                    tracing::debug!(
                        "source backfill actor splits: {:?}, fragment_id: {}",
                        actor_splits,
                        fragment_id
                    );
                    fragment_actor_splits.insert(*fragment_id, actor_splits);
                }
            }
        }

        // Generate fragment reschedule plan
        let mut reschedule_fragment: HashMap<FragmentId, Reschedule> =
            HashMap::with_capacity(reschedules.len());

        for (fragment_id, _) in reschedules {
            let mut actors_to_create: HashMap<_, Vec<_>> = HashMap::new();

            if let Some(actor_worker_maps) = fragment_actors_to_create.get(&fragment_id).cloned() {
                for (actor_id, worker_id) in actor_worker_maps {
                    actors_to_create
                        .entry(worker_id)
                        .or_default()
                        .push(actor_id);
                }
            }

            let actors_to_remove = fragment_actors_to_remove
                .get(&fragment_id)
                .cloned()
                .unwrap_or_default()
                .into_keys()
                .collect();

            let actors_after_reschedule = &fragment_actors_after_reschedule[&fragment_id];

            assert!(!actors_after_reschedule.is_empty());

            let fragment = &ctx.fragment_map[&fragment_id];

            let in_degree_types: HashSet<_> = ctx
                .fragment_upstreams
                .get(&(fragment_id as _))
                .map(|upstreams| upstreams.values())
                .into_iter()
                .flatten()
                .cloned()
                .collect();

            let upstream_dispatcher_mapping = match fragment.distribution_type() {
                FragmentDistributionType::Hash => {
                    if !in_degree_types.contains(&DispatcherType::Hash) {
                        None
                    } else {
                        // Bitmap changes must occur in the case of hash distribution.
                        Some(ActorMapping::from_bitmaps(
                            &fragment_actor_bitmap[&fragment_id],
                        ))
                    }
                }

                FragmentDistributionType::Single => {
                    assert!(fragment_actor_bitmap.get(&fragment_id).unwrap().is_empty());
                    None
                }
                FragmentDistributionType::Unspecified => unreachable!(),
            };
1294
1295            let mut upstream_fragment_dispatcher_set = BTreeSet::new();
1296
1297            {
1298                if let Some(upstreams) = ctx.fragment_upstreams.get(&(fragment.fragment_id as _)) {
1299                    for (upstream_fragment_id, upstream_dispatcher_type) in upstreams {
1300                        match upstream_dispatcher_type {
1301                            DispatcherType::NoShuffle => {}
1302                            _ => {
1303                                upstream_fragment_dispatcher_set.insert((
1304                                    *upstream_fragment_id as FragmentId,
1305                                    fragment.fragment_id as DispatcherId,
1306                                ));
1307                            }
1308                        }
1309                    }
1310                }
1311            }
1312
1313            let downstream_fragment_ids = if let Some(downstream_fragments) =
1314                ctx.fragment_dispatcher_map.get(&fragment_id)
1315            {
1316                // Skip fragments' no-shuffle downstream, as there's no need to update the merger
1317                // (receiver) of a no-shuffle downstream
1318                downstream_fragments
1319                    .iter()
1320                    .filter(|(_, dispatcher_type)| *dispatcher_type != &DispatcherType::NoShuffle)
1321                    .map(|(fragment_id, _)| *fragment_id)
1322                    .collect_vec()
1323            } else {
1324                vec![]
1325            };
1326
1327            let vnode_bitmap_updates = match fragment.distribution_type() {
1328                FragmentDistributionType::Hash => {
1329                    let mut vnode_bitmap_updates =
1330                        fragment_actor_bitmap.remove(&fragment_id).unwrap();
1331
                    // Keep only the bitmaps of actors whose distribution changed;
                    // otherwise the barrier would grow very large when there are many actors.
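                    // For example (hypothetical ids): if actor 1 ends up with exactly
                    // the same bitmap as before the reschedule, its entry is dropped
                    // below, so only actors whose vnode ownership changed are carried
                    // in the barrier's update.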
1334                    for actor_id in actors_after_reschedule.keys() {
1335                        assert!(vnode_bitmap_updates.contains_key(actor_id));
1336
                        // Retained actor: drop its bitmap update if the bitmap is unchanged.
1338                        if let Some(actor) = ctx.actor_map.get(actor_id) {
1339                            let bitmap = vnode_bitmap_updates.get(actor_id).unwrap();
1340
1341                            if let Some(prev_bitmap) = actor.vnode_bitmap.as_ref() {
1342                                if prev_bitmap.eq(bitmap) {
1343                                    vnode_bitmap_updates.remove(actor_id);
1344                                }
1345                            }
1346                        }
1347                    }
1348
1349                    vnode_bitmap_updates
1350                }
1351                FragmentDistributionType::Single => HashMap::new(),
1352                FragmentDistributionType::Unspecified => unreachable!(),
1353            };
1354
1355            let upstream_fragment_dispatcher_ids =
1356                upstream_fragment_dispatcher_set.into_iter().collect_vec();
1357
1358            let actor_splits = fragment_actor_splits
1359                .get(&fragment_id)
1360                .cloned()
1361                .unwrap_or_default();
1362
1363            reschedule_fragment.insert(
1364                fragment_id,
1365                Reschedule {
1366                    added_actors: actors_to_create,
1367                    removed_actors: actors_to_remove,
1368                    vnode_bitmap_updates,
1369                    upstream_fragment_dispatcher_ids,
1370                    upstream_dispatcher_mapping,
1371                    downstream_fragment_ids,
1372                    actor_splits,
1373                    newly_created_actors: Default::default(),
1374                },
1375            );
1376        }
1377
1378        let mut fragment_created_actors = HashMap::new();
1379        for (fragment_id, actors_to_create) in &fragment_actors_to_create {
1380            let mut created_actors = HashMap::new();
1381            for (actor_id, worker_id) in actors_to_create {
1382                let actor = new_created_actors.get(actor_id).cloned().unwrap();
1383                created_actors.insert(*actor_id, (actor, *worker_id));
1384            }
1385
1386            fragment_created_actors.insert(*fragment_id, created_actors);
1387        }
1388
1389        for (fragment_id, to_create) in fragment_created_actors {
1390            let reschedule = reschedule_fragment.get_mut(&fragment_id).unwrap();
1391            reschedule.newly_created_actors = to_create;
1392        }
1393        tracing::debug!("analyze_reschedule_plan result: {:#?}", reschedule_fragment);
1394
1395        Ok(reschedule_fragment)
1396    }
1397
1398    #[expect(clippy::type_complexity)]
1399    fn arrange_reschedules(
1400        &self,
1401        reschedule: &HashMap<FragmentId, WorkerReschedule>,
1402        ctx: &RescheduleContext,
1403    ) -> MetaResult<(
1404        HashMap<FragmentId, BTreeMap<ActorId, WorkerId>>,
1405        HashMap<FragmentId, BTreeMap<ActorId, WorkerId>>,
1406    )> {
1407        let mut fragment_actors_to_remove = HashMap::with_capacity(reschedule.len());
1408        let mut fragment_actors_to_create = HashMap::with_capacity(reschedule.len());
1409
1410        for (fragment_id, WorkerReschedule { worker_actor_diff }) in reschedule {
1411            let fragment = ctx.fragment_map.get(fragment_id).unwrap();
1412
1413            // Actor Id => Worker Id
1414            let mut actors_to_remove = BTreeMap::new();
1415            let mut actors_to_create = BTreeMap::new();
1416
1417            // NOTE(important): The value needs to be a BTreeSet to ensure that the actors on the worker are sorted in ascending order.
1418            let mut worker_to_actors = HashMap::new();
1419
1420            for actor in &fragment.actors {
1421                let worker_id = ctx.actor_id_to_worker_id(&actor.actor_id).unwrap();
1422                worker_to_actors
1423                    .entry(worker_id)
1424                    .or_insert(BTreeSet::new())
1425                    .insert(actor.actor_id as ActorId);
1426            }
1427
1428            let decreased_actor_count = worker_actor_diff
1429                .iter()
1430                .filter(|(_, change)| change.is_negative())
1431                .map(|(worker_id, change)| (worker_id, change.unsigned_abs()));
1432
1433            for (worker_id, n) in decreased_actor_count {
1434                if let Some(actor_ids) = worker_to_actors.get(worker_id) {
1435                    if actor_ids.len() < n {
                        bail!(
                            "illegal plan: for fragment {}, worker {} has only {} actors, but the plan needs to remove {}",
                            fragment_id,
                            worker_id,
                            actor_ids.len(),
                            n
                        );
1443                    }
1444
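                    // Removal takes the tail of the ordered set: with actors {3, 5, 9}
                    // on this worker and n = 2 (hypothetical ids), `skip(1)` keeps
                    // actor 3 and removes actors 5 and 9.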
1445                    let removed_actors: Vec<_> = actor_ids
1446                        .iter()
1447                        .skip(actor_ids.len().saturating_sub(n))
1448                        .cloned()
1449                        .collect();
1450
1451                    for actor in removed_actors {
1452                        actors_to_remove.insert(actor, *worker_id);
1453                    }
1454                }
1455            }
1456
1457            let increased_actor_count = worker_actor_diff
1458                .iter()
1459                .filter(|(_, change)| change.is_positive());
1460
1461            for (worker, n) in increased_actor_count {
1462                for _ in 0..*n {
1463                    let id = self
1464                        .env
1465                        .id_gen_manager()
1466                        .generate_interval::<{ IdCategory::Actor }>(1)
1467                        as ActorId;
1468                    actors_to_create.insert(id, *worker);
1469                }
1470            }
1471
1472            if !actors_to_remove.is_empty() {
1473                fragment_actors_to_remove.insert(*fragment_id as FragmentId, actors_to_remove);
1474            }
1475
1476            if !actors_to_create.is_empty() {
1477                fragment_actors_to_create.insert(*fragment_id as FragmentId, actors_to_create);
1478            }
1479        }
1480
1481        // sanity checking
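        // A NO_SHUFFLE pair must scale together: if an upstream actor is removed,
        // its single no-shuffle downstream actor must appear in the removing map of
        // the downstream fragment as well, as asserted below.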
1482        for actors_to_remove in fragment_actors_to_remove.values() {
1483            for actor_id in actors_to_remove.keys() {
1484                let actor = ctx.actor_map.get(actor_id).unwrap();
1485                for dispatcher in &actor.dispatcher {
1486                    if PbDispatcherType::NoShuffle == dispatcher.get_type().unwrap() {
1487                        let downstream_actor_id = dispatcher.downstream_actor_id.iter().exactly_one().expect("there should be only one downstream actor id in NO_SHUFFLE dispatcher");
1488
1489                        let _should_exists = fragment_actors_to_remove
1490                            .get(&(dispatcher.dispatcher_id as FragmentId))
1491                            .expect("downstream fragment of NO_SHUFFLE relation should be in the removing map")
1492                            .get(downstream_actor_id)
1493                            .expect("downstream actor of NO_SHUFFLE relation should be in the removing map");
1494                    }
1495                }
1496            }
1497        }
1498
1499        Ok((fragment_actors_to_remove, fragment_actors_to_create))
1500    }
1501
    /// Modifies the upstream and downstream actors of the newly created actor according to the
    /// overall changes; used to handle cascading updates.
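    ///
    /// Illustrative sketch (hypothetical fragment/actor ids): if hash fragment 1
    /// dispatches to fragment 2, and the reschedule adds actor 20 to fragment 2
    /// while removing actor 21, then for each new actor of fragment 1 this filters
    /// actor 21 out of `downstream_actor_id`, appends actor 20, and rebuilds the
    /// dispatcher's hash mapping from fragment 2's updated vnode bitmaps.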
1504    fn modify_actor_upstream_and_downstream(
1505        ctx: &RescheduleContext,
1506        fragment_actors_to_remove: &HashMap<FragmentId, BTreeMap<ActorId, WorkerId>>,
1507        fragment_actors_to_create: &HashMap<FragmentId, BTreeMap<ActorId, WorkerId>>,
1508        fragment_actor_bitmap: &HashMap<FragmentId, HashMap<ActorId, Bitmap>>,
1509        no_shuffle_downstream_actors_map: &HashMap<ActorId, HashMap<FragmentId, ActorId>>,
1510        new_actor: &mut StreamActor,
1511        dispatchers: &mut Vec<PbDispatcher>,
1512    ) -> MetaResult<()> {
1513        // Update downstream actor ids
1514        for dispatcher in dispatchers {
1515            let downstream_fragment_id = dispatcher
1516                .downstream_actor_id
1517                .iter()
1518                .filter_map(|actor_id| ctx.actor_map.get(actor_id).map(|actor| actor.fragment_id))
1519                .dedup()
1520                .exactly_one()
1521                .unwrap() as FragmentId;
1522
1523            let downstream_fragment_actors_to_remove =
1524                fragment_actors_to_remove.get(&downstream_fragment_id);
1525            let downstream_fragment_actors_to_create =
1526                fragment_actors_to_create.get(&downstream_fragment_id);
1527
1528            match dispatcher.r#type() {
1529                d @ (PbDispatcherType::Hash
1530                | PbDispatcherType::Simple
1531                | PbDispatcherType::Broadcast) => {
1532                    if let Some(downstream_actors_to_remove) = downstream_fragment_actors_to_remove
1533                    {
1534                        dispatcher
1535                            .downstream_actor_id
1536                            .retain(|id| !downstream_actors_to_remove.contains_key(id));
1537                    }
1538
1539                    if let Some(downstream_actors_to_create) = downstream_fragment_actors_to_create
1540                    {
1541                        dispatcher
1542                            .downstream_actor_id
1543                            .extend(downstream_actors_to_create.keys().cloned())
1544                    }
1545
                    // There should still be exactly one downstream actor
1547                    if d == PbDispatcherType::Simple {
1548                        assert_eq!(dispatcher.downstream_actor_id.len(), 1);
1549                    }
1550                }
1551                PbDispatcherType::NoShuffle => {
1552                    assert_eq!(dispatcher.downstream_actor_id.len(), 1);
1553                    let downstream_actor_id = no_shuffle_downstream_actors_map
1554                        .get(&new_actor.actor_id)
1555                        .and_then(|map| map.get(&downstream_fragment_id))
1556                        .unwrap();
1557                    dispatcher.downstream_actor_id = vec![*downstream_actor_id as ActorId];
1558                }
1559                PbDispatcherType::Unspecified => unreachable!(),
1560            }
1561
1562            if let Some(mapping) = dispatcher.hash_mapping.as_mut() {
1563                if let Some(downstream_updated_bitmap) =
1564                    fragment_actor_bitmap.get(&downstream_fragment_id)
1565                {
                    // If the downstream fragment scales in or out, rebuild the hash
                    // mapping from its updated vnode bitmaps.
1567                    *mapping = ActorMapping::from_bitmaps(downstream_updated_bitmap).to_protobuf();
1568                }
1569            }
1570        }
1571
1572        Ok(())
1573    }
1574
1575    #[await_tree::instrument]
1576    pub async fn post_apply_reschedule(
1577        &self,
1578        reschedules: &HashMap<FragmentId, Reschedule>,
1579        post_updates: &JobReschedulePostUpdates,
1580    ) -> MetaResult<()> {
1581        // Update fragment info after rescheduling in meta store.
1582        self.metadata_manager
1583            .post_apply_reschedules(reschedules.clone(), post_updates)
1584            .await?;
1585
1586        // Update serving fragment info after rescheduling in meta store.
1587        if !reschedules.is_empty() {
1588            let workers = self
1589                .metadata_manager
1590                .list_active_serving_compute_nodes()
1591                .await?;
1592            let streaming_parallelisms = self
1593                .metadata_manager
1594                .running_fragment_parallelisms(Some(reschedules.keys().cloned().collect()))
1595                .await?;
1596            let serving_worker_slot_mapping = Arc::new(ServingVnodeMapping::default());
1597            let max_serving_parallelism = self
1598                .env
1599                .session_params_manager_impl_ref()
1600                .get_params()
1601                .await
1602                .batch_parallelism()
1603                .map(|p| p.get());
1604            let (upserted, failed) = serving_worker_slot_mapping.upsert(
1605                streaming_parallelisms,
1606                &workers,
1607                max_serving_parallelism,
1608            );
1609            if !upserted.is_empty() {
1610                tracing::debug!(
1611                    "Update serving vnode mapping for fragments {:?}.",
1612                    upserted.keys()
1613                );
1614                self.env
1615                    .notification_manager()
1616                    .notify_frontend_without_version(
1617                        Operation::Update,
1618                        Info::ServingWorkerSlotMappings(FragmentWorkerSlotMappings {
1619                            mappings: to_fragment_worker_slot_mapping(&upserted),
1620                        }),
1621                    );
1622            }
1623            if !failed.is_empty() {
1624                tracing::debug!(
                    "Failed to update serving vnode mapping for fragments {:?}.",
1626                    failed
1627                );
1628                self.env
1629                    .notification_manager()
1630                    .notify_frontend_without_version(
1631                        Operation::Delete,
1632                        Info::ServingWorkerSlotMappings(FragmentWorkerSlotMappings {
1633                            mappings: to_deleted_fragment_worker_slot_mapping(&failed),
1634                        }),
1635                    );
1636            }
1637        }
1638
1639        let mut stream_source_actor_splits = HashMap::new();
1640        let mut stream_source_dropped_actors = HashSet::new();
1641
1642        // todo: handle adaptive splits
1643        for (fragment_id, reschedule) in reschedules {
1644            if !reschedule.actor_splits.is_empty() {
1645                stream_source_actor_splits
1646                    .insert(*fragment_id as FragmentId, reschedule.actor_splits.clone());
1647                stream_source_dropped_actors.extend(reschedule.removed_actors.clone());
1648            }
1649        }
1650
1651        if !stream_source_actor_splits.is_empty() {
1652            self.source_manager
1653                .apply_source_change(SourceChange::Reschedule {
1654                    split_assignment: stream_source_actor_splits,
1655                    dropped_actors: stream_source_dropped_actors,
1656                })
1657                .await;
1658        }
1659
1660        Ok(())
1661    }
1662
1663    pub async fn generate_job_reschedule_plan(
1664        &self,
1665        policy: JobReschedulePolicy,
1666    ) -> MetaResult<JobReschedulePlan> {
1667        type VnodeCount = usize;
1668
1669        let JobReschedulePolicy { targets } = policy;
1670
1671        let workers = self
1672            .metadata_manager
1673            .list_active_streaming_compute_nodes()
1674            .await?;
1675
1676        // The `schedulable` field should eventually be replaced by resource groups like `unschedulable`
1677        let workers: HashMap<_, _> = workers
1678            .into_iter()
1679            .filter(|worker| worker.is_streaming_schedulable())
1680            .map(|worker| (worker.id, worker))
1681            .collect();
1682
1683        #[derive(Debug)]
1684        struct JobUpdate {
1685            filtered_worker_ids: BTreeSet<WorkerId>,
1686            parallelism: TableParallelism,
1687        }
1688
1689        let mut job_parallelism_updates = HashMap::new();
1690
1691        let mut job_reschedule_post_updates = JobReschedulePostUpdates {
1692            parallelism_updates: Default::default(),
1693            resource_group_updates: Default::default(),
1694        };
1695
1696        for (
1697            job_id,
1698            JobRescheduleTarget {
1699                parallelism: parallelism_update,
1700                resource_group: resource_group_update,
1701            },
1702        ) in &targets
1703        {
1704            let parallelism = match parallelism_update {
1705                JobParallelismTarget::Update(parallelism) => *parallelism,
1706                JobParallelismTarget::Refresh => {
1707                    let parallelism = self
1708                        .metadata_manager
1709                        .catalog_controller
1710                        .get_job_streaming_parallelisms(*job_id as _)
1711                        .await?;
1712
1713                    parallelism.into()
1714                }
1715            };
1716
1717            job_reschedule_post_updates
1718                .parallelism_updates
1719                .insert(TableId::from(*job_id), parallelism);
1720
1721            let current_resource_group = match resource_group_update {
1722                JobResourceGroupTarget::Update(Some(specific_resource_group)) => {
1723                    job_reschedule_post_updates.resource_group_updates.insert(
1724                        *job_id as ObjectId,
1725                        Some(specific_resource_group.to_owned()),
1726                    );
1727
1728                    specific_resource_group.to_owned()
1729                }
1730                JobResourceGroupTarget::Update(None) => {
1731                    let database_resource_group = self
1732                        .metadata_manager
1733                        .catalog_controller
1734                        .get_existing_job_database_resource_group(*job_id as _)
1735                        .await?;
1736
1737                    job_reschedule_post_updates
1738                        .resource_group_updates
1739                        .insert(*job_id as ObjectId, None);
1740                    database_resource_group
1741                }
1742                JobResourceGroupTarget::Keep => {
1743                    self.metadata_manager
1744                        .catalog_controller
1745                        .get_existing_job_resource_group(*job_id as _)
1746                        .await?
1747                }
1748            };
1749
1750            let filtered_worker_ids =
1751                filter_workers_by_resource_group(&workers, current_resource_group.as_str());
1752
1753            if filtered_worker_ids.is_empty() {
                bail!("Cannot resize streaming job {job_id} to an empty worker set")
1755            }
1756
1757            job_parallelism_updates.insert(
1758                *job_id,
1759                JobUpdate {
1760                    filtered_worker_ids,
1761                    parallelism,
1762                },
1763            );
1764        }
1765
1766        // index for no shuffle relation
1767        let mut no_shuffle_source_fragment_ids = HashSet::new();
1768        let mut no_shuffle_target_fragment_ids = HashSet::new();
1769
1770        // index for fragment_id -> (distribution_type, vnode_count)
1771        let mut fragment_distribution_map = HashMap::new();
1772        // index for actor -> worker id
1773        let mut actor_location = HashMap::new();
1774        // index for table_id -> [fragment_id]
1775        let mut table_fragment_id_map = HashMap::new();
1776        // index for fragment_id -> [actor_id]
1777        let mut fragment_actor_id_map = HashMap::new();
1778
1779        async fn build_index(
1780            no_shuffle_source_fragment_ids: &mut HashSet<FragmentId>,
1781            no_shuffle_target_fragment_ids: &mut HashSet<FragmentId>,
1782            fragment_distribution_map: &mut HashMap<
1783                FragmentId,
1784                (FragmentDistributionType, VnodeCount),
1785            >,
1786            actor_location: &mut HashMap<ActorId, WorkerId>,
1787            table_fragment_id_map: &mut HashMap<u32, HashSet<FragmentId>>,
1788            fragment_actor_id_map: &mut HashMap<FragmentId, HashSet<u32>>,
1789            mgr: &MetadataManager,
1790            table_ids: Vec<ObjectId>,
1791        ) -> Result<(), MetaError> {
1792            let RescheduleWorkingSet {
1793                fragments,
1794                actors,
1795                actor_dispatchers: _actor_dispatchers,
1796                fragment_downstreams,
1797                fragment_upstreams: _fragment_upstreams,
1798                related_jobs: _related_jobs,
1799                job_resource_groups: _job_resource_groups,
1800            } = mgr
1801                .catalog_controller
1802                .resolve_working_set_for_reschedule_tables(table_ids)
1803                .await?;
1804
1805            for (fragment_id, downstreams) in fragment_downstreams {
1806                for (downstream_fragment_id, dispatcher_type) in downstreams {
1807                    if let risingwave_meta_model::DispatcherType::NoShuffle = dispatcher_type {
1808                        no_shuffle_source_fragment_ids.insert(fragment_id as FragmentId);
1809                        no_shuffle_target_fragment_ids.insert(downstream_fragment_id as FragmentId);
1810                    }
1811                }
1812            }
1813
1814            for (fragment_id, fragment) in fragments {
1815                fragment_distribution_map.insert(
1816                    fragment_id as FragmentId,
1817                    (
1818                        FragmentDistributionType::from(fragment.distribution_type),
1819                        fragment.vnode_count as _,
1820                    ),
1821                );
1822
1823                table_fragment_id_map
1824                    .entry(fragment.job_id as u32)
1825                    .or_default()
1826                    .insert(fragment_id as FragmentId);
1827            }
1828
1829            for (actor_id, actor) in actors {
1830                actor_location.insert(actor_id as ActorId, actor.worker_id as WorkerId);
1831                fragment_actor_id_map
1832                    .entry(actor.fragment_id as FragmentId)
1833                    .or_default()
1834                    .insert(actor_id as ActorId);
1835            }
1836
1837            Ok(())
1838        }
1839
1840        let table_ids = targets.keys().map(|id| *id as ObjectId).collect();
1841
1842        build_index(
1843            &mut no_shuffle_source_fragment_ids,
1844            &mut no_shuffle_target_fragment_ids,
1845            &mut fragment_distribution_map,
1846            &mut actor_location,
1847            &mut table_fragment_id_map,
1848            &mut fragment_actor_id_map,
1849            &self.metadata_manager,
1850            table_ids,
1851        )
1852        .await?;
1853        tracing::debug!(
1854            ?job_reschedule_post_updates,
1855            ?job_parallelism_updates,
1856            ?no_shuffle_source_fragment_ids,
1857            ?no_shuffle_target_fragment_ids,
1858            ?fragment_distribution_map,
1859            ?actor_location,
1860            ?table_fragment_id_map,
1861            ?fragment_actor_id_map,
            "generate_job_reschedule_plan, after build_index"
1863        );
1864
1865        let adaptive_parallelism_strategy = self
1866            .env
1867            .system_params_reader()
1868            .await
1869            .adaptive_parallelism_strategy();
1870
1871        let mut target_plan = HashMap::new();
1872
1873        for (
1874            table_id,
1875            JobUpdate {
1876                filtered_worker_ids,
1877                parallelism,
1878            },
1879        ) in job_parallelism_updates
1880        {
1881            let fragment_map = table_fragment_id_map.remove(&table_id).unwrap();
1882
1883            let available_worker_slots = workers
1884                .iter()
1885                .filter(|(id, _)| filtered_worker_ids.contains(&(**id as WorkerId)))
1886                .map(|(_, worker)| (worker.id as WorkerId, worker.compute_node_parallelism()))
1887                .collect::<BTreeMap<_, _>>();
1888
1889            for fragment_id in fragment_map {
                // Currently, all NO_SHUFFLE relations propagate only from upstream to downstream.
1891                if no_shuffle_target_fragment_ids.contains(&fragment_id) {
1892                    continue;
1893                }
1894
1895                let mut fragment_slots: BTreeMap<WorkerId, usize> = BTreeMap::new();
1896
1897                for actor_id in &fragment_actor_id_map[&fragment_id] {
1898                    let worker_id = actor_location[actor_id];
1899                    *fragment_slots.entry(worker_id).or_default() += 1;
1900                }
1901
1902                let available_slot_count: usize = available_worker_slots.values().cloned().sum();
1903
1904                if available_slot_count == 0 {
1905                    bail!(
1906                        "No schedulable slots available for fragment {}",
1907                        fragment_id
1908                    );
1909                }
1910
1911                let (dist, vnode_count) = fragment_distribution_map[&fragment_id];
1912                let max_parallelism = vnode_count;
1913
1914                match dist {
1915                    FragmentDistributionType::Unspecified => unreachable!(),
1916                    FragmentDistributionType::Single => {
1917                        let (single_worker_id, should_be_one) = fragment_slots
1918                            .iter()
1919                            .exactly_one()
1920                            .expect("single fragment should have only one worker slot");
1921
1922                        assert_eq!(*should_be_one, 1);
1923
1924                        let units = schedule_units_for_slots(&available_worker_slots, 1, table_id)?;
1925
1926                        let (chosen_target_worker_id, should_be_one) =
1927                            units.iter().exactly_one().ok().with_context(|| {
1928                                format!(
1929                                    "Cannot find a single target worker for fragment {fragment_id}"
1930                                )
1931                            })?;
1932
1933                        assert_eq!(*should_be_one, 1);
1934
1935                        if *chosen_target_worker_id == *single_worker_id {
1936                            tracing::debug!(
1937                                "single fragment {fragment_id} already on target worker {chosen_target_worker_id}"
1938                            );
1939                            continue;
1940                        }
1941
1942                        target_plan.insert(
1943                            fragment_id,
1944                            WorkerReschedule {
1945                                worker_actor_diff: BTreeMap::from_iter(vec![
1946                                    (*chosen_target_worker_id, 1),
1947                                    (*single_worker_id, -1),
1948                                ]),
1949                            },
1950                        );
1951                    }
1952                    FragmentDistributionType::Hash => match parallelism {
1953                        TableParallelism::Adaptive => {
1954                            let target_slot_count = adaptive_parallelism_strategy
1955                                .compute_target_parallelism(available_slot_count);
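                            // e.g. (illustrative): with 16 available slots, a strategy
                            // bounding parallelism at 8 yields a target of 8, while a
                            // "full" strategy keeps 16; the target is further capped by
                            // the fragment's vnode count below.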
1956
1957                            if target_slot_count > max_parallelism {
1958                                tracing::warn!(
                                    "available parallelism for table {table_id} is larger than the max parallelism; forcing the limit to {max_parallelism}"
                                );
                                // force the limit to `max_parallelism`
1962                                let target_worker_slots = schedule_units_for_slots(
1963                                    &available_worker_slots,
1964                                    max_parallelism,
1965                                    table_id,
1966                                )?;
1967
1968                                target_plan.insert(
1969                                    fragment_id,
1970                                    Self::diff_worker_slot_changes(
1971                                        &fragment_slots,
1972                                        &target_worker_slots,
1973                                    ),
1974                                );
1975                            } else if available_slot_count != target_slot_count {
1976                                tracing::info!(
                                    "available parallelism for table {table_id} is limited by the adaptive strategy {adaptive_parallelism_strategy}; resetting to {target_slot_count}"
1978                                );
1979                                let target_worker_slots = schedule_units_for_slots(
1980                                    &available_worker_slots,
1981                                    target_slot_count,
1982                                    table_id,
1983                                )?;
1984
1985                                target_plan.insert(
1986                                    fragment_id,
1987                                    Self::diff_worker_slot_changes(
1988                                        &fragment_slots,
1989                                        &target_worker_slots,
1990                                    ),
1991                                );
1992                            } else {
1993                                target_plan.insert(
1994                                    fragment_id,
1995                                    Self::diff_worker_slot_changes(
1996                                        &fragment_slots,
1997                                        &available_worker_slots,
1998                                    ),
1999                                );
2000                            }
2001                        }
2002                        TableParallelism::Fixed(mut n) => {
2003                            if n > max_parallelism {
2004                                tracing::warn!(
                                    "specified parallelism {n} for table {table_id} is larger than the max parallelism; forcing the limit to {max_parallelism}"
2006                                );
2007                                n = max_parallelism
2008                            }
2009
2010                            let target_worker_slots =
2011                                schedule_units_for_slots(&available_worker_slots, n, table_id)?;
2012
2013                            target_plan.insert(
2014                                fragment_id,
2015                                Self::diff_worker_slot_changes(
2016                                    &fragment_slots,
2017                                    &target_worker_slots,
2018                                ),
2019                            );
2020                        }
2021                        TableParallelism::Custom => {
2022                            // skipping for custom
2023                        }
2024                    },
2025                }
2026            }
2027        }
2028
2029        target_plan.retain(|_, plan| !plan.worker_actor_diff.is_empty());
2030        tracing::debug!(
2031            ?target_plan,
            "generate_job_reschedule_plan finished with target_plan"
2033        );
2034
2035        Ok(JobReschedulePlan {
2036            reschedules: target_plan,
2037            post_updates: job_reschedule_post_updates,
2038        })
2039    }
2040
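    /// Computes the per-worker actor-count delta between the current and the target
    /// distribution; workers with a zero delta are omitted.
    ///
    /// A minimal sketch (hypothetical worker ids; not a doctest):
    ///
    /// ```ignore
    /// let current = BTreeMap::from([(1, 2), (2, 3)]); // worker -> actor count
    /// let target = BTreeMap::from([(1, 3), (3, 1)]);
    /// let diff = Self::diff_worker_slot_changes(&current, &target);
    /// assert_eq!(diff.worker_actor_diff, BTreeMap::from([(1, 1), (2, -3), (3, 1)]));
    /// ```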
2041    fn diff_worker_slot_changes(
2042        fragment_worker_slots: &BTreeMap<WorkerId, usize>,
2043        target_worker_slots: &BTreeMap<WorkerId, usize>,
2044    ) -> WorkerReschedule {
2045        let mut increased_actor_count: BTreeMap<WorkerId, usize> = BTreeMap::new();
2046        let mut decreased_actor_count: BTreeMap<WorkerId, usize> = BTreeMap::new();
2047
2048        for (&worker_id, &target_slots) in target_worker_slots {
2049            let &current_slots = fragment_worker_slots.get(&worker_id).unwrap_or(&0);
2050
2051            if target_slots > current_slots {
2052                increased_actor_count.insert(worker_id, target_slots - current_slots);
2053            }
2054        }
2055
2056        for (&worker_id, &current_slots) in fragment_worker_slots {
2057            let &target_slots = target_worker_slots.get(&worker_id).unwrap_or(&0);
2058
2059            if current_slots > target_slots {
2060                decreased_actor_count.insert(worker_id, current_slots - target_slots);
2061            }
2062        }
2063
2064        let worker_ids: HashSet<_> = increased_actor_count
2065            .keys()
2066            .chain(decreased_actor_count.keys())
2067            .cloned()
2068            .collect();
2069
2070        let mut worker_actor_diff = BTreeMap::new();
2071
2072        for worker_id in worker_ids {
2073            let increased = increased_actor_count.remove(&worker_id).unwrap_or(0) as isize;
2074            let decreased = decreased_actor_count.remove(&worker_id).unwrap_or(0) as isize;
2075            let change = increased - decreased;
2076
2077            assert_ne!(change, 0);
2078
2079            worker_actor_diff.insert(worker_id, change);
2080        }
2081
2082        WorkerReschedule { worker_actor_diff }
2083    }
2084
2085    fn build_no_shuffle_relation_index(
2086        actor_map: &HashMap<ActorId, CustomActorInfo>,
2087        no_shuffle_source_fragment_ids: &mut HashSet<FragmentId>,
2088        no_shuffle_target_fragment_ids: &mut HashSet<FragmentId>,
2089    ) {
2090        let mut fragment_cache = HashSet::new();
2091        for actor in actor_map.values() {
2092            if fragment_cache.contains(&actor.fragment_id) {
2093                continue;
2094            }
2095
2096            for dispatcher in &actor.dispatcher {
2097                for downstream_actor_id in &dispatcher.downstream_actor_id {
2098                    if let Some(downstream_actor) = actor_map.get(downstream_actor_id) {
2099                        // Checking for no shuffle dispatchers
2100                        if dispatcher.r#type() == PbDispatcherType::NoShuffle {
2101                            no_shuffle_source_fragment_ids.insert(actor.fragment_id as FragmentId);
2102                            no_shuffle_target_fragment_ids
2103                                .insert(downstream_actor.fragment_id as FragmentId);
2104                        }
2105                    }
2106                }
2107            }
2108
2109            fragment_cache.insert(actor.fragment_id);
2110        }
2111    }
2112
2113    fn build_fragment_dispatcher_index(
2114        actor_map: &HashMap<ActorId, CustomActorInfo>,
2115        fragment_dispatcher_map: &mut HashMap<FragmentId, HashMap<FragmentId, DispatcherType>>,
2116    ) {
2117        for actor in actor_map.values() {
2118            for dispatcher in &actor.dispatcher {
2119                for downstream_actor_id in &dispatcher.downstream_actor_id {
2120                    if let Some(downstream_actor) = actor_map.get(downstream_actor_id) {
2121                        fragment_dispatcher_map
2122                            .entry(actor.fragment_id as FragmentId)
2123                            .or_default()
2124                            .insert(
2125                                downstream_actor.fragment_id as FragmentId,
2126                                dispatcher.r#type().into(),
2127                            );
2128                    }
2129                }
2130            }
2131        }
2132    }
2133
2134    pub fn resolve_no_shuffle_upstream_tables(
2135        fragment_ids: HashSet<FragmentId>,
2136        no_shuffle_source_fragment_ids: &HashSet<FragmentId>,
2137        no_shuffle_target_fragment_ids: &HashSet<FragmentId>,
2138        fragment_to_table: &HashMap<FragmentId, TableId>,
2139        fragment_upstreams: &HashMap<
2140            risingwave_meta_model::FragmentId,
2141            HashMap<risingwave_meta_model::FragmentId, DispatcherType>,
2142        >,
2143        table_parallelisms: &mut HashMap<TableId, TableParallelism>,
2144    ) -> MetaResult<()> {
2145        let mut queue: VecDeque<FragmentId> = fragment_ids.iter().cloned().collect();
2146
2147        let mut fragment_ids = fragment_ids;
2148
        // For every no-shuffle relation, we trace the upstreams of each downstream
        // through the hierarchy until we reach the top.
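        // For example (hypothetical ids): given fragment A --NoShuffle--> B
        // --NoShuffle--> C, starting from C we enqueue B and then A, so the whole
        // chain joins `fragment_ids`; the no-shuffle downstream tables themselves
        // are pruned from `table_parallelisms` at the end.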
2151        while let Some(fragment_id) = queue.pop_front() {
2152            if !no_shuffle_target_fragment_ids.contains(&fragment_id) {
2153                continue;
2154            }
2155
2156            // for upstream
2157            for upstream_fragment_id in fragment_upstreams
2158                .get(&(fragment_id as _))
2159                .map(|upstreams| upstreams.keys())
2160                .into_iter()
2161                .flatten()
2162            {
2163                let upstream_fragment_id = *upstream_fragment_id as FragmentId;
2164                let upstream_fragment_id = &upstream_fragment_id;
2165                if !no_shuffle_source_fragment_ids.contains(upstream_fragment_id) {
2166                    continue;
2167                }
2168
2169                let table_id = &fragment_to_table[&fragment_id];
2170                let upstream_table_id = &fragment_to_table[upstream_fragment_id];
2171
2172                // Only custom parallelism will be propagated to the no shuffle upstream.
2173                if let Some(TableParallelism::Custom) = table_parallelisms.get(table_id) {
2174                    if let Some(upstream_table_parallelism) =
2175                        table_parallelisms.get(upstream_table_id)
2176                    {
2177                        if upstream_table_parallelism != &TableParallelism::Custom {
2178                            bail!(
2179                                "Cannot change upstream table {} from {:?} to {:?}",
2180                                upstream_table_id,
2181                                upstream_table_parallelism,
2182                                TableParallelism::Custom
2183                            )
2184                        }
2185                    } else {
2186                        table_parallelisms.insert(*upstream_table_id, TableParallelism::Custom);
2187                    }
2188                }
2189
2190                fragment_ids.insert(*upstream_fragment_id);
2191                queue.push_back(*upstream_fragment_id);
2192            }
2193        }
2194
2195        let downstream_fragment_ids = fragment_ids
2196            .iter()
2197            .filter(|fragment_id| no_shuffle_target_fragment_ids.contains(fragment_id));
2198
2199        let downstream_table_ids = downstream_fragment_ids
2200            .map(|fragment_id| fragment_to_table.get(fragment_id).unwrap())
2201            .collect::<HashSet<_>>();
2202
2203        table_parallelisms.retain(|table_id, _| !downstream_table_ids.contains(table_id));
2204
2205        Ok(())
2206    }
2207
2208    pub fn resolve_no_shuffle_upstream_fragments<T>(
2209        reschedule: &mut HashMap<FragmentId, T>,
2210        no_shuffle_source_fragment_ids: &HashSet<FragmentId>,
2211        no_shuffle_target_fragment_ids: &HashSet<FragmentId>,
2212        fragment_upstreams: &HashMap<
2213            risingwave_meta_model::FragmentId,
2214            HashMap<risingwave_meta_model::FragmentId, DispatcherType>,
2215        >,
2216    ) -> MetaResult<()>
2217    where
2218        T: Clone + Eq,
2219    {
2220        let mut queue: VecDeque<FragmentId> = reschedule.keys().cloned().collect();
2221
        // For every no-shuffle relation, we trace the upstreams of each downstream
        // through the hierarchy until we reach the top.
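        // For example (hypothetical ids): rescheduling C in A --NoShuffle--> B
        // --NoShuffle--> C copies C's plan onto B and then A, bails out if an
        // upstream already carries a conflicting plan, and finally drops the
        // entries of the no-shuffle targets (B and C) themselves.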
2224        while let Some(fragment_id) = queue.pop_front() {
2225            if !no_shuffle_target_fragment_ids.contains(&fragment_id) {
2226                continue;
2227            }
2228
2229            // for upstream
2230            for upstream_fragment_id in fragment_upstreams
2231                .get(&(fragment_id as _))
2232                .map(|upstreams| upstreams.keys())
2233                .into_iter()
2234                .flatten()
2235            {
2236                let upstream_fragment_id = *upstream_fragment_id as FragmentId;
2237                let upstream_fragment_id = &upstream_fragment_id;
2238                if !no_shuffle_source_fragment_ids.contains(upstream_fragment_id) {
2239                    continue;
2240                }
2241
2242                let reschedule_plan = &reschedule[&fragment_id];
2243
2244                if let Some(upstream_reschedule_plan) = reschedule.get(upstream_fragment_id) {
2245                    if upstream_reschedule_plan != reschedule_plan {
2246                        bail!(
2247                            "Inconsistent NO_SHUFFLE plan, check target worker ids of fragment {} and {}",
2248                            fragment_id,
2249                            upstream_fragment_id
2250                        );
2251                    }
2252
2253                    continue;
2254                }
2255
2256                reschedule.insert(*upstream_fragment_id, reschedule_plan.clone());
2257
2258                queue.push_back(*upstream_fragment_id);
2259            }
2260        }
2261
2262        reschedule.retain(|fragment_id, _| !no_shuffle_target_fragment_ids.contains(fragment_id));
2263
2264        Ok(())
2265    }
2266
2267    pub async fn resolve_related_no_shuffle_jobs(
2268        &self,
2269        jobs: &[TableId],
2270    ) -> MetaResult<HashSet<TableId>> {
2271        let RescheduleWorkingSet { related_jobs, .. } = self
2272            .metadata_manager
2273            .catalog_controller
2274            .resolve_working_set_for_reschedule_tables(
2275                jobs.iter().map(|id| id.table_id as _).collect(),
2276            )
2277            .await?;
2278
2279        Ok(related_jobs
2280            .keys()
2281            .map(|id| TableId::new(*id as _))
2282            .collect())
2283    }
2284}
2285
2286#[derive(Debug, Clone)]
2287pub enum JobParallelismTarget {
2288    Update(TableParallelism),
2289    Refresh,
2290}
2291
2292#[derive(Debug, Clone)]
2293pub enum JobResourceGroupTarget {
2294    Update(Option<String>),
2295    Keep,
2296}
2297
2298#[derive(Debug, Clone)]
2299pub struct JobRescheduleTarget {
2300    pub parallelism: JobParallelismTarget,
2301    pub resource_group: JobResourceGroupTarget,
2302}
2303
2304#[derive(Debug)]
2305pub struct JobReschedulePolicy {
2306    pub(crate) targets: HashMap<u32, JobRescheduleTarget>,
2307}
2308
2309// final updates for `post_collect`
2310#[derive(Debug, Clone)]
2311pub struct JobReschedulePostUpdates {
2312    pub parallelism_updates: HashMap<TableId, TableParallelism>,
2313    pub resource_group_updates: HashMap<ObjectId, Option<String>>,
2314}
2315
2316#[derive(Debug)]
2317pub struct JobReschedulePlan {
2318    pub reschedules: HashMap<FragmentId, WorkerReschedule>,
2319    pub post_updates: JobReschedulePostUpdates,
2320}
2321
2322impl GlobalStreamManager {
2323    #[await_tree::instrument("acquire_reschedule_read_guard")]
2324    pub async fn reschedule_lock_read_guard(&self) -> RwLockReadGuard<'_, ()> {
2325        self.scale_controller.reschedule_lock.read().await
2326    }
2327
2328    #[await_tree::instrument("acquire_reschedule_write_guard")]
2329    pub async fn reschedule_lock_write_guard(&self) -> RwLockWriteGuard<'_, ()> {
2330        self.scale_controller.reschedule_lock.write().await
2331    }
2332
2333    /// The entrypoint of rescheduling actors.
2334    ///
2335    /// Used by:
2336    /// - The directly exposed low-level API `risingwave_meta_service::scale_service::ScaleService` (`risectl meta reschedule`)
2337    /// - High-level parallelism control API
2338    ///     * manual `ALTER [TABLE | INDEX | MATERIALIZED VIEW | SINK] SET PARALLELISM`
2339    ///     * automatic parallelism control for [`TableParallelism::Adaptive`] when worker nodes changed
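    ///
    /// A minimal call sketch (hypothetical plan; mirrors the usage in
    /// `trigger_parallelism_control`):
    ///
    /// ```ignore
    /// self.reschedule_actors(
    ///     database_id,
    ///     JobReschedulePlan { reschedules, post_updates },
    ///     RescheduleOptions {
    ///         resolve_no_shuffle_upstream: false,
    ///         skip_create_new_actors: false,
    ///     },
    /// )
    /// .await?;
    /// ```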
2340    pub async fn reschedule_actors(
2341        &self,
2342        database_id: DatabaseId,
2343        plan: JobReschedulePlan,
2344        options: RescheduleOptions,
2345    ) -> MetaResult<()> {
2346        let JobReschedulePlan {
2347            reschedules,
2348            mut post_updates,
2349        } = plan;
2350
2351        let reschedule_fragment = self
2352            .scale_controller
2353            .analyze_reschedule_plan(reschedules, options, &mut post_updates.parallelism_updates)
2354            .await?;
2355
2356        tracing::debug!("reschedule plan: {:?}", reschedule_fragment);
2357
2358        let up_down_stream_fragment: HashSet<_> = reschedule_fragment
2359            .iter()
2360            .flat_map(|(_, reschedule)| {
2361                reschedule
2362                    .upstream_fragment_dispatcher_ids
2363                    .iter()
2364                    .map(|(fragment_id, _)| *fragment_id)
2365                    .chain(reschedule.downstream_fragment_ids.iter().cloned())
2366            })
2367            .collect();
2368
2369        let fragment_actors =
2370            try_join_all(up_down_stream_fragment.iter().map(|fragment_id| async {
2371                let actor_ids = self
2372                    .metadata_manager
2373                    .get_running_actors_of_fragment(*fragment_id)
2374                    .await?;
2375                Result::<_, MetaError>::Ok((*fragment_id, actor_ids))
2376            }))
2377            .await?
2378            .into_iter()
2379            .collect();
2380
2381        let command = Command::RescheduleFragment {
2382            reschedules: reschedule_fragment,
2383            fragment_actors,
2384            post_updates,
2385        };
2386
2387        let _guard = self.source_manager.pause_tick().await;
2388
2389        self.barrier_scheduler
2390            .run_command(database_id, command)
2391            .await?;
2392
2393        tracing::info!("reschedule done");
2394
2395        Ok(())
2396    }
2397
    /// When new worker nodes join, or the parallelism of existing worker nodes changes,
    /// examines whether any jobs can be scaled, and scales them if so.
    ///
    /// This method iterates over all `CREATED` jobs and can be called repeatedly.
    ///
    /// Returns
    /// - `Ok(false)` if no jobs can be scaled;
    /// - `Ok(true)` if some jobs were scaled, and more jobs may still be scalable.
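    ///
    /// Callers are expected to loop on the result (see `run`): keep triggering on
    /// each tick while this returns `Ok(true)`, and pause once it returns `Ok(false)`.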
2406    async fn trigger_parallelism_control(&self) -> MetaResult<bool> {
2407        tracing::info!("trigger parallelism control");
2408
2409        let _reschedule_job_lock = self.reschedule_lock_write_guard().await;
2410
2411        let background_streaming_jobs = self
2412            .metadata_manager
2413            .list_background_creating_jobs()
2414            .await?;
2415
2416        let skipped_jobs = if !background_streaming_jobs.is_empty() {
2417            let jobs = self
2418                .scale_controller
2419                .resolve_related_no_shuffle_jobs(&background_streaming_jobs)
2420                .await?;
2421
2422            tracing::info!(
2423                "skipping parallelism control of background jobs {:?} and associated jobs {:?}",
2424                background_streaming_jobs,
2425                jobs
2426            );
2427
2428            jobs
2429        } else {
2430            HashSet::new()
2431        };
2432
2433        let job_ids: HashSet<_> = {
2434            let streaming_parallelisms = self
2435                .metadata_manager
2436                .catalog_controller
2437                .get_all_streaming_parallelisms()
2438                .await?;
2439
2440            streaming_parallelisms
2441                .into_iter()
2442                .filter(|(table_id, _)| !skipped_jobs.contains(&TableId::new(*table_id as _)))
2443                .map(|(table_id, _)| table_id)
2444                .collect()
2445        };
2446
2447        let workers = self
2448            .metadata_manager
2449            .cluster_controller
2450            .list_active_streaming_workers()
2451            .await?;
2452
2453        let schedulable_worker_ids: BTreeSet<_> = workers
2454            .iter()
2455            .filter(|worker| {
2456                !worker
2457                    .property
2458                    .as_ref()
2459                    .map(|p| p.is_unschedulable)
2460                    .unwrap_or(false)
2461            })
2462            .map(|worker| worker.id as WorkerId)
2463            .collect();
2464
2465        if job_ids.is_empty() {
            tracing::info!("no streaming jobs to scale; the cluster may be empty");
2467            return Ok(false);
2468        }
2469
2470        let batch_size = match self.env.opts.parallelism_control_batch_size {
2471            0 => job_ids.len(),
2472            n => n,
2473        };
2474
2475        tracing::info!(
2476            "total {} streaming jobs, batch size {}, schedulable worker ids: {:?}",
2477            job_ids.len(),
2478            batch_size,
2479            schedulable_worker_ids
2480        );
2481
2482        let batches: Vec<_> = job_ids
2483            .into_iter()
2484            .chunks(batch_size)
2485            .into_iter()
2486            .map(|chunk| chunk.collect_vec())
2487            .collect();
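
        // e.g. (illustrative): with 10 jobs and a batch size of 3, the batches are
        // of sizes [3, 3, 3, 1]; the loop below stops at the first batch that
        // yields a non-empty reschedule plan.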
2488
2489        let mut reschedules = None;
2490
2491        for batch in batches {
2492            let targets: HashMap<_, _> = batch
2493                .into_iter()
2494                .map(|job_id| {
2495                    (
2496                        job_id as u32,
2497                        JobRescheduleTarget {
2498                            parallelism: JobParallelismTarget::Refresh,
2499                            resource_group: JobResourceGroupTarget::Keep,
2500                        },
2501                    )
2502                })
2503                .collect();
2504
2505            let plan = self
2506                .scale_controller
2507                .generate_job_reschedule_plan(JobReschedulePolicy { targets })
2508                .await?;
2509
2510            if !plan.reschedules.is_empty() {
2511                tracing::info!("reschedule plan generated for streaming jobs {:?}", plan);
2512                reschedules = Some(plan);
2513                break;
2514            }
2515        }
2516
2517        let Some(plan) = reschedules else {
2518            tracing::info!("no reschedule plan generated");
2519            return Ok(false);
2520        };
2521
2522        // todo
2523        for (database_id, reschedules) in self
2524            .metadata_manager
2525            .split_fragment_map_by_database(plan.reschedules)
2526            .await?
2527        {
2528            self.reschedule_actors(
2529                database_id,
2530                JobReschedulePlan {
2531                    reschedules,
2532                    post_updates: plan.post_updates.clone(),
2533                },
2534                RescheduleOptions {
2535                    resolve_no_shuffle_upstream: false,
2536                    skip_create_new_actors: false,
2537                },
2538            )
2539            .await?;
2540        }
2541
2542        Ok(true)
2543    }
2544
2545    /// Handles notification of worker node activation and deletion, and triggers parallelism control.
2546    async fn run(&self, mut shutdown_rx: Receiver<()>) {
2547        tracing::info!("starting automatic parallelism control monitor");
2548
2549        let check_period =
2550            Duration::from_secs(self.env.opts.parallelism_control_trigger_period_sec);
2551
2552        let mut ticker = tokio::time::interval_at(
2553            Instant::now()
2554                + Duration::from_secs(self.env.opts.parallelism_control_trigger_first_delay_sec),
2555            check_period,
2556        );
2557        ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
2558
2559        // waiting for the first tick
2560        ticker.tick().await;
2561
2562        let (local_notification_tx, mut local_notification_rx) =
2563            tokio::sync::mpsc::unbounded_channel();
2564
2565        self.env
2566            .notification_manager()
2567            .insert_local_sender(local_notification_tx)
2568            .await;
2569
2570        let worker_nodes = self
2571            .metadata_manager
2572            .list_active_streaming_compute_nodes()
2573            .await
2574            .expect("list active streaming compute nodes");
2575
2576        let mut worker_cache: BTreeMap<_, _> = worker_nodes
2577            .into_iter()
2578            .map(|worker| (worker.id, worker))
2579            .collect();
2580
2581        let mut previous_adaptive_parallelism_strategy = AdaptiveParallelismStrategy::default();
2582
2583        let mut should_trigger = false;
2584
2585        loop {
2586            tokio::select! {
2587                biased;
2588
2589                _ = &mut shutdown_rx => {
2590                    tracing::info!("Stream manager is stopped");
2591                    break;
2592                }
2593
2594                _ = ticker.tick(), if should_trigger => {
2595                    let include_workers = worker_cache.keys().copied().collect_vec();
2596
2597                    if include_workers.is_empty() {
2598                        tracing::debug!("no available worker nodes");
2599                        should_trigger = false;
2600                        continue;
2601                    }
2602
2603                    match self.trigger_parallelism_control().await {
2604                        Ok(cont) => {
2605                            should_trigger = cont;
2606                        }
2607                        Err(e) => {
2608                            tracing::warn!(error = %e.as_report(), "Failed to trigger scale out, waiting for next tick to retry after {}s", ticker.period().as_secs());
2609                            ticker.reset();
2610                        }
2611                    }
2612                }
2613
2614                notification = local_notification_rx.recv() => {
2615                    let notification = notification.expect("local notification channel closed in loop of stream manager");
2616
2617                    // Only maintain the cache for streaming compute nodes.
2618                    let worker_is_streaming_compute = |worker: &WorkerNode| {
2619                        worker.get_type() == Ok(WorkerType::ComputeNode)
2620                            && worker.property.as_ref().unwrap().is_streaming
2621                    };
2622
2623                    match notification {
2624                        LocalNotification::SystemParamsChange(reader) => {
2625                            let new_strategy = reader.adaptive_parallelism_strategy();
2626                            if new_strategy != previous_adaptive_parallelism_strategy {
2627                                tracing::info!("adaptive parallelism strategy changed from {:?} to {:?}", previous_adaptive_parallelism_strategy, new_strategy);
2628                                should_trigger = true;
2629                                previous_adaptive_parallelism_strategy = new_strategy;
2630                            }
2631                        }
                        LocalNotification::WorkerNodeActivated(worker) => {
                            if !worker_is_streaming_compute(&worker) {
                                continue;
                            }

                            tracing::info!(worker = worker.id, "worker activated notification received");

                            let prev_worker = worker_cache.insert(worker.id, worker.clone());

                            // Re-trigger parallelism control whenever a worker joins or an
                            // existing worker's schedulable resources change.
                            match prev_worker {
                                Some(prev_worker) if prev_worker.compute_node_parallelism() != worker.compute_node_parallelism() => {
                                    tracing::info!(worker = worker.id, "worker parallelism changed");
                                    should_trigger = true;
                                }
                                Some(prev_worker) if prev_worker.resource_group() != worker.resource_group() => {
                                    tracing::info!(worker = worker.id, "worker resource group changed");
                                    should_trigger = true;
                                }
                                None => {
                                    tracing::info!(worker = worker.id, "new worker joined");
                                    should_trigger = true;
                                }
                                _ => {}
                            }
                        }

                        // Since our logic for handling passive scale-in is within the barrier manager,
                        // there's not much we can do here. All we can do is proactively remove the entries from our cache.
                        LocalNotification::WorkerNodeDeleted(worker) => {
                            if !worker_is_streaming_compute(&worker) {
                                continue;
                            }

                            match worker_cache.remove(&worker.id) {
                                Some(prev_worker) => {
                                    tracing::info!(worker = prev_worker.id, "worker removed from stream manager cache");
                                }
                                None => {
                                    tracing::warn!(worker = worker.id, "deleted worker was not present in stream manager cache");
                                }
                            }
                        }

                        _ => {}
                    }
                }
            }
        }
    }

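    /// Spawns the parallelism-control event loop above on a dedicated tokio task.
    /// Returns the task's `JoinHandle` together with a oneshot sender; sending on
    /// (or dropping) the sender completes `shutdown_rx` and stops the loop gracefully.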
    pub fn start_auto_parallelism_monitor(
        self: Arc<Self>,
    ) -> (JoinHandle<()>, oneshot::Sender<()>) {
        tracing::info!("Automatic parallelism scale-out is enabled for streaming jobs");
        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
        let join_handle = tokio::spawn(async move {
            self.run(shutdown_rx).await;
        });

        (join_handle, shutdown_tx)
    }
}

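/// Distributes `total_unit_size` scheduling units across the given worker slots using a
/// salted consistent hash ring, returning how many units each worker should receive.
/// Workers with more slots receive proportionally more units, and the assignment is
/// deterministic for a fixed salt, so repeated scheduling decisions stay stable
/// (see `test_weighted_distribution` below for a concrete distribution).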
pub fn schedule_units_for_slots(
    slots: &BTreeMap<WorkerId, usize>,
    total_unit_size: usize,
    salt: u32,
) -> MetaResult<BTreeMap<WorkerId, usize>> {
    let mut ch = ConsistentHashRing::new(salt);

    for (worker_id, parallelism) in slots {
        ch.add_worker(*worker_id as _, *parallelism as u32);
    }

    let target_distribution = ch.distribute_tasks(total_unit_size as u32)?;

    Ok(target_distribution
        .into_iter()
        .map(|(worker_id, task_count)| (worker_id as WorkerId, task_count as usize))
        .collect())
}

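/// A weighted consistent hash ring: every worker is placed on the ring at many virtual
/// node positions, and each task walks the ring clockwise from its own hash until it
/// finds a worker that still has headroom under its weight-proportional soft limit.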
pub struct ConsistentHashRing {
    /// Virtual node hash -> id of the worker owning that position on the ring.
    ring: BTreeMap<u64, u32>,
    /// Worker id -> weight (its task capacity, e.g. parallelism).
    weights: BTreeMap<u32, u32>,
    /// Number of virtual nodes placed on the ring per worker; more virtual nodes
    /// smooth out the distribution.
    virtual_nodes: u32,
    /// Salt mixed into every hash so that rings with different salts shuffle differently.
    salt: u32,
}

impl ConsistentHashRing {
    fn new(salt: u32) -> Self {
        ConsistentHashRing {
            ring: BTreeMap::new(),
            weights: BTreeMap::new(),
            virtual_nodes: 1024,
            salt,
        }
    }

    /// Hashes `key` together with `salt`. `DefaultHasher::new()` uses fixed keys, so the
    /// result is deterministic within a build (though not guaranteed across Rust releases).
    fn hash<T: Hash, S: Hash>(key: T, salt: S) -> u64 {
        let mut hasher = DefaultHasher::new();
        salt.hash(&mut hasher);
        key.hash(&mut hasher);
        hasher.finish()
    }

    /// Registers a worker with the given `weight`, inserting one ring entry per virtual node.
    fn add_worker(&mut self, id: u32, weight: u32) {
        let virtual_nodes_count = self.virtual_nodes;

        for i in 0..virtual_nodes_count {
            let virtual_node_key = (id, i);
            let hash = Self::hash(virtual_node_key, self.salt);
            self.ring.insert(hash, id);
        }

        self.weights.insert(id, weight);
    }

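    /// Assigns `total_tasks` tasks to workers by walking the ring from each task's hash,
    /// capping every worker at a soft limit proportional to its weight. The ring walk keeps
    /// assignments relatively sticky under membership changes; the soft limits keep them
    /// balanced.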
    fn distribute_tasks(&self, total_tasks: u32) -> MetaResult<BTreeMap<u32, u32>> {
        let total_weight = self.weights.values().sum::<u32>();

        // Each worker may take at most ceil(total_tasks * weight / total_weight) tasks.
        let mut soft_limits = HashMap::new();
        for (worker_id, worker_capacity) in &self.weights {
            soft_limits.insert(
                *worker_id,
                (total_tasks as f64 * (*worker_capacity as f64 / total_weight as f64)).ceil()
                    as u32,
            );
        }

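        // Worked example (illustrative, mirroring `test_weighted_distribution`): with weights
        // {1: 2, 2: 3, 3: 5} and total_tasks = 10, total_weight = 10, so the soft limits are
        // ceil(10 * 2/10) = 2, ceil(10 * 3/10) = 3, and ceil(10 * 5/10) = 5. Taking the
        // ceiling guarantees the limits sum to at least total_tasks, so every task fits.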
        let mut task_distribution: BTreeMap<u32, u32> = BTreeMap::new();
        let mut task_hashes = (0..total_tasks)
            .map(|task_idx| Self::hash(task_idx, self.salt))
            .collect_vec();

        // The hashes are already dispersed around the ring; sorting them processes tasks in
        // ring order, so overflow past a saturated worker spills predictably to its successor.
        task_hashes.sort();

        for task_hash in task_hashes {
            let mut assigned = false;

            // Walk the ring starting from the first virtual node at or after `task_hash`,
            // wrapping around to the beginning of the ring.
            let ring_range = self.ring.range(task_hash..).chain(self.ring.iter());

            for (_, &worker_id) in ring_range {
                let task_limit = soft_limits[&worker_id];

                let worker_task_count = task_distribution.entry(worker_id).or_insert(0);

                if *worker_task_count < task_limit {
                    *worker_task_count += 1;
                    assigned = true;
                    break;
                }
            }

            if !assigned {
                bail!("Could not distribute tasks due to capacity constraints.");
            }
        }

        Ok(task_distribution)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    const DEFAULT_SALT: u32 = 42;

    #[test]
    fn test_single_worker_capacity() {
        let mut ch = ConsistentHashRing::new(DEFAULT_SALT);
        ch.add_worker(1, 10);

        let total_tasks = 5;
        let task_distribution = ch.distribute_tasks(total_tasks).unwrap();

        assert_eq!(task_distribution.get(&1).cloned().unwrap_or(0), 5);
    }

    #[test]
    fn test_multiple_workers_even_distribution() {
        let mut ch = ConsistentHashRing::new(DEFAULT_SALT);

        ch.add_worker(1, 1);
        ch.add_worker(2, 1);
        ch.add_worker(3, 1);

        let total_tasks = 3;
        let task_distribution = ch.distribute_tasks(total_tasks).unwrap();

        for id in 1..=3 {
            assert_eq!(task_distribution.get(&id).cloned().unwrap_or(0), 1);
        }
    }
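
    // A sketch added for illustration (not in the original suite): with a fixed salt the
    // ring is fully deterministic, so rebuilding it with identical inputs must reproduce
    // the exact same task distribution.
    #[test]
    fn test_distribution_is_deterministic() {
        let build = || {
            let mut ch = ConsistentHashRing::new(DEFAULT_SALT);
            ch.add_worker(1, 4);
            ch.add_worker(2, 4);
            ch
        };

        let first = build().distribute_tasks(8).unwrap();
        let second = build().distribute_tasks(8).unwrap();

        assert_eq!(first, second);
    }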

    #[test]
    fn test_weighted_distribution() {
        let mut ch = ConsistentHashRing::new(DEFAULT_SALT);

        ch.add_worker(1, 2);
        ch.add_worker(2, 3);
        ch.add_worker(3, 5);

        let total_tasks = 10;
        let task_distribution = ch.distribute_tasks(total_tasks).unwrap();

        assert_eq!(task_distribution.get(&1).cloned().unwrap_or(0), 2);
        assert_eq!(task_distribution.get(&2).cloned().unwrap_or(0), 3);
        assert_eq!(task_distribution.get(&3).cloned().unwrap_or(0), 5);
    }

    #[test]
    fn test_over_capacity() {
        let mut ch = ConsistentHashRing::new(DEFAULT_SALT);

        ch.add_worker(1, 1);
        ch.add_worker(2, 2);
        ch.add_worker(3, 3);

        // More tasks (10) than the total weight (6); the ceilings in the soft limits
        // still admit every task, so distribution succeeds.
        let total_tasks = 10;
        let task_distribution = ch.distribute_tasks(total_tasks);

        assert!(task_distribution.is_ok());
    }

    #[test]
    fn test_balance_distribution() {
        for worker_capacity in 1..10 {
            for workers in 3..10 {
                let mut ring = ConsistentHashRing::new(DEFAULT_SALT);

                for worker_id in 0..workers {
                    ring.add_worker(worker_id, worker_capacity);
                }

                // Here we simulate a real situation where the actual parallelism cannot fill
                // all the capacity. This is to ensure an average distribution: for example,
                // when three workers with capacity 6 are assigned 9 tasks, they should
                // ideally get an exact distribution of 3, 3, 3 respectively. A local binding
                // is used so the loop variable itself is never mutated.
                let tasks_per_worker = if worker_capacity % 2 == 0 {
                    worker_capacity / 2
                } else {
                    worker_capacity
                };

                let total_tasks = tasks_per_worker * workers;

                let task_distribution = ring.distribute_tasks(total_tasks).unwrap();

                for (_, v) in task_distribution {
                    assert_eq!(v, tasks_per_worker);
                }
            }
        }
    }
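
    // A sketch added for illustration (not in the original suite): exercises the public
    // `schedule_units_for_slots` wrapper end-to-end. The weights and salt mirror
    // `test_weighted_distribution`, where each worker's soft limit equals its weight.
    #[test]
    fn test_schedule_units_for_slots() {
        let slots: BTreeMap<WorkerId, usize> =
            [(1, 2), (2, 3), (3, 5)].into_iter().collect();

        let distribution = schedule_units_for_slots(&slots, 10, DEFAULT_SALT).unwrap();

        // Every unit is assigned exactly once.
        assert_eq!(distribution.values().sum::<usize>(), 10);

        // No worker exceeds its soft limit, ceil(10 * weight / 10) = weight.
        for (worker_id, count) in &distribution {
            assert!(*count <= slots[worker_id]);
        }
    }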
}