risingwave_meta/stream/scale.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashMap, HashSet};
use std::fmt::Debug;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::Duration;

use anyhow::anyhow;
use futures::future;
use itertools::Itertools;
use risingwave_common::bail;
use risingwave_common::bitmap::Bitmap;
use risingwave_common::catalog::{DatabaseId, FragmentTypeFlag, FragmentTypeMask};
use risingwave_common::hash::ActorMapping;
use risingwave_meta_model::{StreamingParallelism, WorkerId, fragment, fragment_relation};
use risingwave_pb::common::{PbWorkerNode, WorkerNode, WorkerType};
use risingwave_pb::meta::table_fragments::fragment::PbFragmentDistributionType;
use risingwave_pb::stream_plan::{Dispatcher, PbDispatchOutputMapping, PbDispatcher, StreamNode};
use sea_orm::{ActiveModelTrait, ConnectionTrait, QuerySelect};
use thiserror_ext::AsReport;
use tokio::sync::oneshot::Receiver;
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, oneshot};
use tokio::task::JoinHandle;
use tokio::time::{Instant, MissedTickBehavior};

use crate::barrier::{Command, Reschedule, SharedFragmentInfo};
use crate::controller::scale::{
    FragmentRenderMap, NoShuffleEnsemble, RenderedGraph, WorkerInfo,
    find_fragment_no_shuffle_dags_detailed, render_fragments, render_jobs,
};
use crate::error::bail_invalid_parameter;
use crate::manager::{LocalNotification, MetaSrvEnv, MetadataManager};
use crate::model::{
    ActorId, DispatcherId, FragmentId, StreamActor, StreamActorWithDispatchers, StreamContext,
};
use crate::stream::{GlobalStreamManager, SourceManagerRef};
use crate::{MetaError, MetaResult};

#[derive(Debug, Clone, Eq, PartialEq)]
pub struct WorkerReschedule {
    pub worker_actor_diff: BTreeMap<WorkerId, isize>,
}
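
// A minimal sketch (hypothetical worker ids and values, not taken from the original source)
// of how a `WorkerReschedule` encodes a plan as per-worker actor-count deltas: moving two
// actors off worker 1 and onto worker 2 could be written as
//
//     let _plan = WorkerReschedule {
//         worker_actor_diff: BTreeMap::from([(1, -2), (2, 2)]),
//     };
//
// Positive deltas add actors to a worker and negative deltas remove them, so the deltas
// sum to zero when the total parallelism stays the same.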

pub struct CustomFragmentInfo {
    pub job_id: u32,
    pub fragment_id: u32,
    pub fragment_type_mask: FragmentTypeMask,
    pub distribution_type: PbFragmentDistributionType,
    pub state_table_ids: Vec<u32>,
    pub node: StreamNode,
    pub actor_template: StreamActorWithDispatchers,
    pub actors: Vec<CustomActorInfo>,
}

#[derive(Default, Clone)]
pub struct CustomActorInfo {
    pub actor_id: u32,
    pub fragment_id: u32,
    pub dispatcher: Vec<Dispatcher>,
    /// `None` if singleton.
    pub vnode_bitmap: Option<Bitmap>,
}

use risingwave_common::id::JobId;
use risingwave_common::system_param::AdaptiveParallelismStrategy;
use risingwave_common::system_param::reader::SystemParamsRead;
use risingwave_meta_model::DispatcherType;
use risingwave_meta_model::fragment::DistributionType;
use risingwave_meta_model::prelude::{Fragment, FragmentRelation, StreamingJob};
use sea_orm::ActiveValue::Set;
use sea_orm::{ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter, TransactionTrait};

use crate::controller::fragment::{InflightActorInfo, InflightFragmentInfo};
use crate::controller::utils::{
    StreamingJobExtraInfo, compose_dispatchers, get_streaming_job_extra_info,
};
use crate::stream::cdc::assign_cdc_table_snapshot_splits_impl;

pub type ScaleControllerRef = Arc<ScaleController>;

pub struct ScaleController {
    pub metadata_manager: MetadataManager,

    pub source_manager: SourceManagerRef,

    pub env: MetaSrvEnv,

    /// We acquire this lock during DDL to prevent scaling operations on jobs that are still being created,
    /// e.g., an MV cannot be rescheduled during foreground backfill.
    pub reschedule_lock: RwLock<()>,
}

impl ScaleController {
    pub fn new(
        metadata_manager: &MetadataManager,
        source_manager: SourceManagerRef,
        env: MetaSrvEnv,
    ) -> Self {
        Self {
            metadata_manager: metadata_manager.clone(),
            source_manager,
            env,
            reschedule_lock: RwLock::new(()),
        }
    }

    pub async fn resolve_related_no_shuffle_jobs(
        &self,
        jobs: &[JobId],
    ) -> MetaResult<HashSet<JobId>> {
        let inner = self.metadata_manager.catalog_controller.inner.read().await;
        let txn = inner.db.begin().await?;

        let fragment_ids: Vec<_> = Fragment::find()
            .select_only()
            .column(fragment::Column::FragmentId)
            .filter(fragment::Column::JobId.is_in(jobs.to_vec()))
            .into_tuple()
            .all(&txn)
            .await?;
        let ensembles = find_fragment_no_shuffle_dags_detailed(&txn, &fragment_ids).await?;
        let related_fragments = ensembles
            .iter()
            .flat_map(|ensemble| ensemble.fragments())
            .collect_vec();

        let job_ids: Vec<_> = Fragment::find()
            .select_only()
            .column(fragment::Column::JobId)
            .filter(fragment::Column::FragmentId.is_in(related_fragments))
            .into_tuple()
            .all(&txn)
            .await?;

        let job_ids = job_ids.into_iter().collect();

        Ok(job_ids)
    }
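
    // A hedged usage sketch for `resolve_related_no_shuffle_jobs` above (job ids are
    // hypothetical): if an MV job `job_a` feeds a sink job `job_b` through a no-shuffle
    // edge, requesting either one expands to the whole ensemble, e.g.
    //
    //     let related = scale_controller.resolve_related_no_shuffle_jobs(&[job_a]).await?;
    //     assert!(related.contains(&job_b));
    //
    // so that jobs connected by no-shuffle dispatchers are always rescheduled together.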

    pub fn diff_fragment(
        &self,
        prev_fragment_info: &SharedFragmentInfo,
        curr_actors: &HashMap<ActorId, InflightActorInfo>,
        upstream_fragments: HashMap<FragmentId, DispatcherType>,
        downstream_fragments: HashMap<FragmentId, DispatcherType>,
        all_actor_dispatchers: HashMap<ActorId, Vec<PbDispatcher>>,
        job_extra_info: Option<&StreamingJobExtraInfo>,
    ) -> MetaResult<Reschedule> {
        let prev_actors: HashMap<_, _> = prev_fragment_info
            .actors
            .iter()
            .map(|(actor_id, actor)| (*actor_id, actor))
            .collect();

        let prev_ids: HashSet<_> = prev_actors.keys().cloned().collect();
        let curr_ids: HashSet<_> = curr_actors.keys().cloned().collect();

        let removed_actors: HashSet<_> = &prev_ids - &curr_ids;
        let added_actor_ids: HashSet<_> = &curr_ids - &prev_ids;
        let kept_ids: HashSet<_> = prev_ids.intersection(&curr_ids).cloned().collect();

        let mut added_actors = HashMap::new();
        for &actor_id in &added_actor_ids {
            let InflightActorInfo { worker_id, .. } = curr_actors
                .get(&actor_id)
                .ok_or_else(|| anyhow!("BUG: Worker not found for new actor {}", actor_id))?;

            added_actors
                .entry(*worker_id)
                .or_insert_with(Vec::new)
                .push(actor_id);
        }

        let mut vnode_bitmap_updates = HashMap::new();
        for actor_id in kept_ids {
            let prev_actor = prev_actors[&actor_id];
            let curr_actor = &curr_actors[&actor_id];

            // Check if the vnode distribution has changed.
            if prev_actor.vnode_bitmap != curr_actor.vnode_bitmap
                && let Some(bitmap) = curr_actor.vnode_bitmap.clone()
            {
                vnode_bitmap_updates.insert(actor_id, bitmap);
            }
        }

        let upstream_dispatcher_mapping =
            if let DistributionType::Hash = prev_fragment_info.distribution_type {
                let actor_mapping = curr_actors
                    .iter()
                    .map(
                        |(
                            actor_id,
                            InflightActorInfo {
                                worker_id: _,
                                vnode_bitmap,
                                ..
                            },
                        )| { (*actor_id, vnode_bitmap.clone().unwrap()) },
                    )
                    .collect();
                Some(ActorMapping::from_bitmaps(&actor_mapping))
            } else {
                None
            };

        let upstream_fragment_dispatcher_ids = upstream_fragments
            .iter()
            .filter(|&(_, dispatcher_type)| *dispatcher_type != DispatcherType::NoShuffle)
            .map(|(upstream_fragment, _)| {
                (
                    *upstream_fragment,
                    prev_fragment_info.fragment_id as DispatcherId,
                )
            })
            .collect();

        let downstream_fragment_ids = downstream_fragments
            .iter()
            .filter(|&(_, dispatcher_type)| *dispatcher_type != DispatcherType::NoShuffle)
            .map(|(fragment_id, _)| *fragment_id)
            .collect();

        let extra_info = job_extra_info.cloned().unwrap_or_default();
        let timezone = extra_info.timezone.clone();
        let job_definition = extra_info.job_definition;

        let newly_created_actors: HashMap<ActorId, (StreamActorWithDispatchers, WorkerId)> =
            added_actor_ids
                .iter()
                .map(|actor_id| {
                    let actor = StreamActor {
                        actor_id: *actor_id,
                        fragment_id: prev_fragment_info.fragment_id,
                        vnode_bitmap: curr_actors[actor_id].vnode_bitmap.clone(),
                        mview_definition: job_definition.clone(),
                        expr_context: Some(
                            StreamContext {
                                timezone: timezone.clone(),
                            }
                            .to_expr_context(),
                        ),
                    };
                    (
                        *actor_id,
                        (
                            (
                                actor,
                                all_actor_dispatchers
                                    .get(actor_id)
                                    .cloned()
                                    .unwrap_or_default(),
                            ),
                            curr_actors[actor_id].worker_id,
                        ),
                    )
                })
                .collect();

        let actor_splits = curr_actors
            .iter()
            .map(|(&actor_id, info)| (actor_id, info.splits.clone()))
            .collect();

        let reschedule = Reschedule {
            added_actors,
            removed_actors,
            vnode_bitmap_updates,
            upstream_fragment_dispatcher_ids,
            upstream_dispatcher_mapping,
            downstream_fragment_ids,
            newly_created_actors,
            actor_splits,
            cdc_table_snapshot_split_assignment: Default::default(),
            cdc_table_job_id: None,
        };

        Ok(reschedule)
    }
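
    // Illustrative note for `diff_fragment` above (ids and placements are hypothetical):
    // scaling a hash-distributed fragment from actors {1, 2} on worker 10 to actors {2, 3}
    // on workers {10, 11} would yield a `Reschedule` along the lines of
    //
    //     removed_actors       = {1}
    //     added_actors         = {11: [3]}
    //     vnode_bitmap_updates = {2: <rebalanced bitmap>}
    //     newly_created_actors = {3: ((StreamActor { .. }, dispatchers), 11)}
    //
    // i.e. only the delta against the previous fragment state is carried in the command.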

    pub async fn reschedule_inplace(
        &self,
        policy: HashMap<JobId, ReschedulePolicy>,
        workers: HashMap<WorkerId, PbWorkerNode>,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        let inner = self.metadata_manager.catalog_controller.inner.read().await;
        let txn = inner.db.begin().await?;

        for (table_id, target) in &policy {
            let streaming_job = StreamingJob::find_by_id(*table_id)
                .one(&txn)
                .await?
                .ok_or_else(|| MetaError::catalog_id_not_found("table", table_id))?;

            let max_parallelism = streaming_job.max_parallelism;

            let mut streaming_job = streaming_job.into_active_model();

            match &target {
                ReschedulePolicy::Parallelism(p) | ReschedulePolicy::Both(p, _) => {
                    if let StreamingParallelism::Fixed(n) = p.parallelism
                        && n > max_parallelism as usize
                    {
                        bail!(
                            "specified parallelism {n} should not exceed max parallelism {max_parallelism}"
                        );
                    }

                    streaming_job.parallelism = Set(p.parallelism.clone());
                }
                _ => {}
            }

            match &target {
                ReschedulePolicy::ResourceGroup(r) | ReschedulePolicy::Both(_, r) => {
                    streaming_job.specific_resource_group = Set(r.resource_group.clone());
                }
                _ => {}
            }

            streaming_job.update(&txn).await?;
        }

        let jobs = policy.keys().copied().collect();

        let workers = workers
            .into_iter()
            .map(|(id, worker)| {
                (
                    id,
                    WorkerInfo {
                        parallelism: NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
                        resource_group: worker.resource_group(),
                    },
                )
            })
            .collect();

        let command = self.rerender_inner(&txn, jobs, workers).await?;

        txn.commit().await?;

        Ok(command)
    }

    pub async fn reschedule_fragment_inplace(
        &self,
        policy: HashMap<risingwave_meta_model::FragmentId, Option<StreamingParallelism>>,
        workers: HashMap<WorkerId, PbWorkerNode>,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        if policy.is_empty() {
            return Ok(HashMap::new());
        }

        let inner = self.metadata_manager.catalog_controller.inner.read().await;
        let txn = inner.db.begin().await?;

        let fragment_id_list = policy.keys().copied().collect_vec();

        let existing_fragment_ids: HashSet<_> = Fragment::find()
            .select_only()
            .column(fragment::Column::FragmentId)
            .filter(fragment::Column::FragmentId.is_in(fragment_id_list.clone()))
            .into_tuple::<risingwave_meta_model::FragmentId>()
            .all(&txn)
            .await?
            .into_iter()
            .collect();

        if let Some(missing_fragment_id) = fragment_id_list
            .iter()
            .find(|fragment_id| !existing_fragment_ids.contains(fragment_id))
        {
            return Err(MetaError::catalog_id_not_found(
                "fragment",
                *missing_fragment_id,
            ));
        }

        let mut target_ensembles = vec![];

        for ensemble in find_fragment_no_shuffle_dags_detailed(&txn, &fragment_id_list).await? {
            let entry_fragment_ids = ensemble.entry_fragments().collect_vec();

            let desired_parallelism = match entry_fragment_ids
                .iter()
                .filter_map(|fragment_id| policy.get(fragment_id).cloned())
                .dedup()
                .collect_vec()
                .as_slice()
            {
                [] => {
                    bail_invalid_parameter!(
                        "none of the entry fragments {:?} were included in the reschedule request; \
                         provide at least one entry fragment id",
                        entry_fragment_ids
                    );
                }
                [parallelism] => parallelism.clone(),
                parallelisms => {
                    bail!(
                        "conflicting reschedule policies for fragments in the same no-shuffle ensemble: {:?}",
                        parallelisms
                    );
                }
            };

            let fragments = Fragment::find()
                .filter(fragment::Column::FragmentId.is_in(entry_fragment_ids))
                .all(&txn)
                .await?;

            debug_assert!(
                fragments
                    .iter()
                    .map(|fragment| fragment.parallelism.as_ref())
                    .all_equal(),
                "entry fragments in the same ensemble should share the same parallelism"
            );

            let current_parallelism = fragments
                .first()
                .and_then(|fragment| fragment.parallelism.clone());

            if current_parallelism == desired_parallelism {
                continue;
            }

            for fragment in fragments {
                let mut fragment = fragment.into_active_model();
                fragment.parallelism = Set(desired_parallelism.clone());
                fragment.update(&txn).await?;
            }

            target_ensembles.push(ensemble);
        }

        let workers = workers
            .into_iter()
            .map(|(id, worker)| {
                (
                    id,
                    WorkerInfo {
                        parallelism: NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
                        resource_group: worker.resource_group(),
                    },
                )
            })
            .collect();

        let command = self
            .rerender_fragment_inner(&txn, target_ensembles, workers)
            .await?;

        txn.commit().await?;

        Ok(command)
    }
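
    // A hedged usage sketch for `reschedule_fragment_inplace` above (fragment ids are made
    // up): pinning fragment 7 to a fixed parallelism of 2 and clearing the per-fragment
    // setting on fragment 9 could look like
    //
    //     let mut policy = HashMap::new();
    //     policy.insert(7, Some(StreamingParallelism::Fixed(2)));
    //     policy.insert(9, None);
    //     let commands = scale_controller.reschedule_fragment_inplace(policy, workers).await?;
    //
    // Each touched no-shuffle ensemble must have at least one of its entry fragments in the
    // request; the chosen parallelism is then written to all entry fragments of that
    // ensemble and propagated through the ensemble when it is re-rendered.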

    async fn rerender(
        &self,
        jobs: HashSet<JobId>,
        workers: BTreeMap<WorkerId, WorkerInfo>,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        let inner = self.metadata_manager.catalog_controller.inner.read().await;
        self.rerender_inner(&inner.db, jobs, workers).await
    }

    async fn rerender_fragment_inner(
        &self,
        txn: &impl ConnectionTrait,
        ensembles: Vec<NoShuffleEnsemble>,
        workers: BTreeMap<WorkerId, WorkerInfo>,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        if ensembles.is_empty() {
            return Ok(HashMap::new());
        }

        let adaptive_parallelism_strategy = {
            let system_params_reader = self.env.system_params_reader().await;
            system_params_reader.adaptive_parallelism_strategy()
        };

        let RenderedGraph { fragments, .. } = render_fragments(
            txn,
            self.env.actor_id_generator(),
            ensembles,
            workers,
            adaptive_parallelism_strategy,
        )
        .await?;

        self.build_reschedule_commands(txn, fragments).await
    }

    async fn rerender_inner(
        &self,
        txn: &impl ConnectionTrait,
        jobs: HashSet<JobId>,
        workers: BTreeMap<WorkerId, WorkerInfo>,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        let adaptive_parallelism_strategy = {
            let system_params_reader = self.env.system_params_reader().await;
            system_params_reader.adaptive_parallelism_strategy()
        };

        let RenderedGraph { fragments, .. } = render_jobs(
            txn,
            self.env.actor_id_generator(),
            jobs,
            workers,
            adaptive_parallelism_strategy,
        )
        .await?;

        self.build_reschedule_commands(txn, fragments).await
    }

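    // Data-flow note (informal): `render_jobs` / `render_fragments` recompute the desired
    // actor placement for the affected fragments under the current worker set and
    // adaptive-parallelism strategy, returning a `FragmentRenderMap` keyed by database and
    // job. `build_reschedule_commands` below then diffs that desired state against the
    // shared in-memory actor info to produce one `Command::RescheduleFragment` per database.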
    async fn build_reschedule_commands(
        &self,
        txn: &impl ConnectionTrait,
        render_result: FragmentRenderMap,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        if render_result.is_empty() {
            return Ok(HashMap::new());
        }

        let job_ids = render_result
            .values()
            .flat_map(|jobs| jobs.keys().copied())
            .collect_vec();

        let job_extra_info = get_streaming_job_extra_info(txn, job_ids).await?;

        let fragment_ids = render_result
            .values()
            .flat_map(|jobs| jobs.values())
            .flatten()
            .map(|(fragment_id, _)| *fragment_id)
            .collect_vec();

        let upstreams: Vec<(
            risingwave_meta_model::FragmentId,
            risingwave_meta_model::FragmentId,
            DispatcherType,
        )> = FragmentRelation::find()
            .select_only()
            .columns([
                fragment_relation::Column::TargetFragmentId,
                fragment_relation::Column::SourceFragmentId,
                fragment_relation::Column::DispatcherType,
            ])
            .filter(fragment_relation::Column::TargetFragmentId.is_in(fragment_ids.clone()))
            .into_tuple()
            .all(txn)
            .await?;

        let downstreams = FragmentRelation::find()
            .filter(fragment_relation::Column::SourceFragmentId.is_in(fragment_ids.clone()))
            .all(txn)
            .await?;

        let mut all_upstream_fragments = HashMap::new();

        for (fragment, upstream, dispatcher) in upstreams {
            let fragment_id = fragment as FragmentId;
            let upstream_id = upstream as FragmentId;
            all_upstream_fragments
                .entry(fragment_id)
                .or_insert(HashMap::new())
                .insert(upstream_id, dispatcher);
        }

        let mut all_downstream_fragments = HashMap::new();

        let mut downstream_relations = HashMap::new();
        for relation in downstreams {
            let source_fragment_id = relation.source_fragment_id as FragmentId;
            let target_fragment_id = relation.target_fragment_id as FragmentId;
            all_downstream_fragments
                .entry(source_fragment_id)
                .or_insert(HashMap::new())
                .insert(target_fragment_id, relation.dispatcher_type);

            downstream_relations.insert((source_fragment_id, target_fragment_id), relation);
        }

        let all_related_fragment_ids: HashSet<_> = fragment_ids
            .iter()
            .copied()
            .chain(
                all_upstream_fragments
                    .values()
                    .flatten()
                    .map(|(id, _)| *id as i32),
            )
            .chain(
                all_downstream_fragments
                    .values()
                    .flatten()
                    .map(|(id, _)| *id as i32),
            )
            .collect();

        let all_related_fragment_ids = all_related_fragment_ids.into_iter().collect_vec();

        // let all_fragments_from_db: HashMap<_, _> = Fragment::find()
        //     .filter(fragment::Column::FragmentId.is_in(all_related_fragment_ids.clone()))
        //     .all(&txn)
        //     .await?
        //     .into_iter()
        //     .map(|f| (f.fragment_id, f))
        //     .collect();

        let all_prev_fragments: HashMap<_, _> = {
            let read_guard = self.env.shared_actor_infos().read_guard();
            all_related_fragment_ids
                .iter()
                .map(|&fragment_id| {
                    read_guard
                        .get_fragment(fragment_id as FragmentId)
                        .cloned()
                        .map(|fragment| (fragment_id, fragment))
                        .ok_or_else(|| {
                            MetaError::from(anyhow!(
                                "previous fragment info for {fragment_id} not found"
                            ))
                        })
                })
                .collect::<MetaResult<_>>()?
        };

        let all_rendered_fragments: HashMap<_, _> = render_result
            .values()
            .flat_map(|jobs| jobs.values())
            .flatten()
            .map(|(fragment_id, info)| (*fragment_id, info))
            .collect();

        let mut commands = HashMap::new();

        for (database_id, jobs) in &render_result {
            let mut all_fragment_actors = HashMap::new();
            let mut reschedules = HashMap::new();

            for (job_id, fragment_id, fragment_info) in
                jobs.iter().flat_map(|(job_id, fragments)| {
                    fragments
                        .iter()
                        .map(move |(fragment_id, info)| (job_id, fragment_id, info))
                })
            {
                let InflightFragmentInfo {
                    distribution_type,
                    actors,
                    ..
                } = fragment_info;

                let upstream_fragments = all_upstream_fragments
                    .remove(&(*fragment_id as FragmentId))
                    .unwrap_or_default();
                let downstream_fragments = all_downstream_fragments
                    .remove(&(*fragment_id as FragmentId))
                    .unwrap_or_default();

                let fragment_actors: HashMap<_, _> = upstream_fragments
                    .keys()
                    .copied()
                    .chain(downstream_fragments.keys().copied())
                    .map(|fragment_id| {
                        all_prev_fragments
                            .get(&(fragment_id as i32))
                            .map(|fragment| {
                                (
                                    fragment_id,
                                    fragment.actors.keys().copied().collect::<HashSet<_>>(),
                                )
                            })
                            .ok_or_else(|| {
                                MetaError::from(anyhow!(
                                    "fragment {} not found in previous state",
                                    fragment_id
                                ))
                            })
                    })
                    .collect::<MetaResult<_>>()?;

                all_fragment_actors.extend(fragment_actors);

                let source_fragment_actors = actors
                    .iter()
                    .map(|(actor_id, info)| (*actor_id, info.vnode_bitmap.clone()))
                    .collect();

                let mut all_actor_dispatchers: HashMap<_, Vec<_>> = HashMap::new();

                for downstream_fragment_id in downstream_fragments.keys() {
                    let target_fragment_actors =
                        match all_rendered_fragments.get(&(*downstream_fragment_id as i32)) {
                            None => {
                                let external_fragment = all_prev_fragments
                                    .get(&(*downstream_fragment_id as i32))
                                    .ok_or_else(|| {
                                        MetaError::from(anyhow!(
                                            "fragment {} not found in previous state",
                                            downstream_fragment_id
                                        ))
                                    })?;

                                external_fragment
                                    .actors
                                    .iter()
                                    .map(|(actor_id, info)| (*actor_id, info.vnode_bitmap.clone()))
                                    .collect()
                            }
                            Some(downstream_rendered) => downstream_rendered
                                .actors
                                .iter()
                                .map(|(actor_id, info)| (*actor_id, info.vnode_bitmap.clone()))
                                .collect(),
                        };

                    let target_fragment_distribution = *distribution_type;

                    let fragment_relation::Model {
                        source_fragment_id: _,
                        target_fragment_id: _,
                        dispatcher_type,
                        dist_key_indices,
                        output_indices,
                        output_type_mapping,
                    } = downstream_relations
                        .remove(&(
                            *fragment_id as FragmentId,
                            *downstream_fragment_id as FragmentId,
                        ))
                        .ok_or_else(|| {
                            MetaError::from(anyhow!(
                                "downstream relation missing for {} -> {}",
                                fragment_id,
                                downstream_fragment_id
                            ))
                        })?;

                    let pb_mapping = PbDispatchOutputMapping {
                        indices: output_indices.into_u32_array(),
                        types: output_type_mapping.unwrap_or_default().to_protobuf(),
                    };

                    let dispatchers = compose_dispatchers(
                        *distribution_type,
                        &source_fragment_actors,
                        *downstream_fragment_id,
                        target_fragment_distribution,
                        &target_fragment_actors,
                        dispatcher_type,
                        dist_key_indices.into_u32_array(),
                        pb_mapping,
                    );

                    for (actor_id, dispatcher) in dispatchers {
                        all_actor_dispatchers
                            .entry(actor_id)
                            .or_default()
                            .push(dispatcher);
                    }
                }

                let prev_fragment = all_prev_fragments.get(&{ *fragment_id }).ok_or_else(|| {
                    MetaError::from(anyhow!(
                        "fragment {} not found in previous state",
                        fragment_id
                    ))
                })?;

                let mut reschedule = self.diff_fragment(
                    prev_fragment,
                    actors,
                    upstream_fragments,
                    downstream_fragments,
                    all_actor_dispatchers,
                    job_extra_info.get(job_id),
                )?;

                // CDC split assignment is only populated at this stage, so it must still be empty here.
                debug_assert!(reschedule.cdc_table_job_id.is_none());
                debug_assert!(reschedule.cdc_table_snapshot_split_assignment.is_empty());
                let cdc_info = if fragment_info
                    .fragment_type_mask
                    .contains(FragmentTypeFlag::StreamCdcScan)
                {
                    let assignment = assign_cdc_table_snapshot_splits_impl(
                        *job_id,
                        actors.keys().copied().collect(),
                        self.env.meta_store_ref(),
                        None,
                    )
                    .await?;
                    Some((job_id, assignment))
                } else {
                    None
                };

                if let Some((cdc_table_id, cdc_table_snapshot_split_assignment)) = cdc_info {
                    reschedule.cdc_table_job_id = Some(*cdc_table_id);
                    reschedule.cdc_table_snapshot_split_assignment =
                        cdc_table_snapshot_split_assignment;
                }

                reschedules.insert(*fragment_id as FragmentId, reschedule);
            }

            let command = Command::RescheduleFragment {
                reschedules,
                fragment_actors: all_fragment_actors,
            };

            commands.insert(*database_id, command);
        }

        Ok(commands)
    }
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParallelismPolicy {
    pub parallelism: StreamingParallelism,
}

#[derive(Clone, Debug)]
pub struct ResourceGroupPolicy {
    pub resource_group: Option<String>,
}

#[derive(Clone, Debug)]
pub enum ReschedulePolicy {
    Parallelism(ParallelismPolicy),
    ResourceGroup(ResourceGroupPolicy),
    Both(ParallelismPolicy, ResourceGroupPolicy),
}
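
// A minimal construction sketch (hypothetical values): fixing a job's parallelism at 4 and
// moving it to resource group "batch" in one call to `ScaleController::reschedule_inplace`
// could use
//
//     let _policy = ReschedulePolicy::Both(
//         ParallelismPolicy { parallelism: StreamingParallelism::Fixed(4) },
//         ResourceGroupPolicy { resource_group: Some("batch".to_owned()) },
//     );
//
// `Parallelism` and `ResourceGroup` apply only one of the two changes; `Both` applies them
// together within the same transaction.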

impl GlobalStreamManager {
    #[await_tree::instrument("acquire_reschedule_read_guard")]
    pub async fn reschedule_lock_read_guard(&self) -> RwLockReadGuard<'_, ()> {
        self.scale_controller.reschedule_lock.read().await
    }

    #[await_tree::instrument("acquire_reschedule_write_guard")]
    pub async fn reschedule_lock_write_guard(&self) -> RwLockWriteGuard<'_, ()> {
        self.scale_controller.reschedule_lock.write().await
    }

    /// When new worker nodes join, or the parallelism of existing worker nodes changes,
    /// examines whether any jobs can be scaled, and scales them if so.
    ///
    /// This method iterates over all `CREATED` jobs and can be called repeatedly.
    ///
    /// Returns
    /// - `Ok(false)` if no jobs can be scaled;
    /// - `Ok(true)` if some jobs were scaled, and more jobs may still be scalable.
    async fn trigger_parallelism_control(&self) -> MetaResult<bool> {
        tracing::info!("trigger parallelism control");

        let _reschedule_job_lock = self.reschedule_lock_write_guard().await;

        let background_streaming_jobs = self
            .metadata_manager
            .list_background_creating_jobs()
            .await?;

        let skipped_jobs = if !background_streaming_jobs.is_empty() {
            let jobs = self
                .scale_controller
                .resolve_related_no_shuffle_jobs(&background_streaming_jobs)
                .await?;

            tracing::info!(
                "skipping parallelism control of background jobs {:?} and associated jobs {:?}",
                background_streaming_jobs,
                jobs
            );

            jobs
        } else {
            HashSet::new()
        };

        let database_objects: HashMap<risingwave_meta_model::DatabaseId, Vec<JobId>> = self
            .metadata_manager
            .catalog_controller
            .list_streaming_job_with_database()
            .await?;

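        // Interleave jobs across databases: sorting by per-database index first and by
        // database id second yields a round-robin order, so the fixed-size batches built
        // below are spread across databases rather than draining one database at a time.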
        let job_ids = database_objects
            .iter()
            .flat_map(|(database_id, job_ids)| {
                job_ids
                    .iter()
                    .enumerate()
                    .map(move |(idx, job_id)| (idx, database_id, job_id))
            })
            .sorted_by(|(idx_a, database_a, _), (idx_b, database_b, _)| {
                idx_a.cmp(idx_b).then(database_a.cmp(database_b))
            })
            .map(|(_, database_id, job_id)| (*database_id, *job_id))
            .filter(|(_, job_id)| !skipped_jobs.contains(job_id))
            .collect_vec();

        if job_ids.is_empty() {
            tracing::info!("no streaming jobs for scaling, maybe an empty cluster");
            return Ok(false);
        }

        let workers = self
            .metadata_manager
            .cluster_controller
            .list_active_streaming_workers()
            .await?;

        let schedulable_workers: BTreeMap<_, _> = workers
            .iter()
            .filter(|worker| {
                !worker
                    .property
                    .as_ref()
                    .map(|p| p.is_unschedulable)
                    .unwrap_or(false)
            })
            .map(|worker| {
                (
                    worker.id as i32,
                    WorkerInfo {
                        parallelism: NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
                        resource_group: worker.resource_group(),
                    },
                )
            })
            .collect();

        tracing::info!(
            "trigger parallelism control for jobs: {:#?}, workers {:#?}",
            job_ids,
            schedulable_workers
        );

        let batch_size = match self.env.opts.parallelism_control_batch_size {
            0 => job_ids.len(),
            n => n,
        };

        tracing::info!(
            "total {} streaming jobs, batch size {}, schedulable workers: {:?}",
            job_ids.len(),
            batch_size,
            schedulable_workers
        );

        let batches: Vec<_> = job_ids
            .into_iter()
            .chunks(batch_size)
            .into_iter()
            .map(|chunk| chunk.collect_vec())
            .collect();

        for batch in batches {
            let jobs = batch.iter().map(|(_, job_id)| *job_id).collect();

            let commands = self
                .scale_controller
                .rerender(jobs, schedulable_workers.clone())
                .await?;

            let futures = commands.into_iter().map(|(database_id, command)| {
                let barrier_scheduler = self.barrier_scheduler.clone();
                async move { barrier_scheduler.run_command(database_id, command).await }
            });

            let _results = future::try_join_all(futures).await?;
        }

        Ok(false)
    }

    /// Handles notifications of worker node activation and deletion, and triggers parallelism control.
    async fn run(&self, mut shutdown_rx: Receiver<()>) {
        tracing::info!("starting automatic parallelism control monitor");

        let check_period =
            Duration::from_secs(self.env.opts.parallelism_control_trigger_period_sec);

        let mut ticker = tokio::time::interval_at(
            Instant::now()
                + Duration::from_secs(self.env.opts.parallelism_control_trigger_first_delay_sec),
            check_period,
        );
        ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);

        // Wait for the first tick before entering the loop.
        ticker.tick().await;

        let (local_notification_tx, mut local_notification_rx) =
            tokio::sync::mpsc::unbounded_channel();

        self.env
            .notification_manager()
            .insert_local_sender(local_notification_tx);

        let worker_nodes = self
            .metadata_manager
            .list_active_streaming_compute_nodes()
            .await
            .expect("list active streaming compute nodes");

        let mut worker_cache: BTreeMap<_, _> = worker_nodes
            .into_iter()
            .map(|worker| (worker.id, worker))
            .collect();

        let mut previous_adaptive_parallelism_strategy = AdaptiveParallelismStrategy::default();

        let mut should_trigger = false;

        loop {
            tokio::select! {
                biased;

                _ = &mut shutdown_rx => {
                    tracing::info!("Stream manager is stopped");
                    break;
                }

                _ = ticker.tick(), if should_trigger => {
                    let include_workers = worker_cache.keys().copied().collect_vec();

                    if include_workers.is_empty() {
                        tracing::debug!("no available worker nodes");
                        should_trigger = false;
                        continue;
                    }

                    match self.trigger_parallelism_control().await {
                        Ok(cont) => {
                            should_trigger = cont;
                        }
                        Err(e) => {
                            tracing::warn!(error = %e.as_report(), "Failed to trigger scale out, waiting for next tick to retry after {}s", ticker.period().as_secs());
                            ticker.reset();
                        }
                    }
                }

                notification = local_notification_rx.recv() => {
                    let notification = notification.expect("local notification channel closed in loop of stream manager");

                    // Only maintain the cache for streaming compute nodes.
                    let worker_is_streaming_compute = |worker: &WorkerNode| {
                        worker.get_type() == Ok(WorkerType::ComputeNode)
                            && worker.property.as_ref().unwrap().is_streaming
                    };

                    match notification {
                        LocalNotification::SystemParamsChange(reader) => {
                            let new_strategy = reader.adaptive_parallelism_strategy();
                            if new_strategy != previous_adaptive_parallelism_strategy {
                                tracing::info!("adaptive parallelism strategy changed from {:?} to {:?}", previous_adaptive_parallelism_strategy, new_strategy);
                                should_trigger = true;
                                previous_adaptive_parallelism_strategy = new_strategy;
                            }
                        }
                        LocalNotification::WorkerNodeActivated(worker) => {
                            if !worker_is_streaming_compute(&worker) {
                                continue;
                            }

                            tracing::info!(worker = worker.id, "worker activated notification received");

                            let prev_worker = worker_cache.insert(worker.id, worker.clone());

                            match prev_worker {
                                Some(prev_worker) if prev_worker.compute_node_parallelism() != worker.compute_node_parallelism() => {
                                    tracing::info!(worker = worker.id, "worker parallelism changed");
                                    should_trigger = true;
                                }
                                Some(prev_worker) if prev_worker.resource_group() != worker.resource_group() => {
                                    tracing::info!(worker = worker.id, "worker label changed");
                                    should_trigger = true;
                                }
                                None => {
                                    tracing::info!(worker = worker.id, "new worker joined");
                                    should_trigger = true;
                                }
                                _ => {}
                            }
                        }

                        // Since our logic for handling passive scale-in is within the barrier manager,
                        // there’s not much we can do here. All we can do is proactively remove the entries from our cache.
                        LocalNotification::WorkerNodeDeleted(worker) => {
                            if !worker_is_streaming_compute(&worker) {
                                continue;
                            }

                            match worker_cache.remove(&worker.id) {
                                Some(prev_worker) => {
                                    tracing::info!(worker = prev_worker.id, "worker removed from stream manager cache");
                                }
                                None => {
                                    tracing::warn!(worker = worker.id, "received removal notification for a worker not present in the stream manager cache");
                                }
                            }
                        }

                        _ => {}
                    }
                }
            }
        }
    }

    pub fn start_auto_parallelism_monitor(
        self: Arc<Self>,
    ) -> (JoinHandle<()>, oneshot::Sender<()>) {
        tracing::info!("Automatic parallelism scale-out is enabled for streaming jobs");
        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
        let join_handle = tokio::spawn(async move {
            self.run(shutdown_rx).await;
        });

        (join_handle, shutdown_tx)
    }
}