risingwave_meta/stream/scale.rs

// Copyright 2022 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;

use anyhow::anyhow;
use futures::future;
use itertools::Itertools;
use risingwave_common::bail;
use risingwave_common::catalog::DatabaseId;
use risingwave_common::hash::ActorMapping;
use risingwave_common::id::JobId;
use risingwave_common::system_param::AdaptiveParallelismStrategy;
use risingwave_common::system_param::reader::SystemParamsRead;
use risingwave_meta_model::fragment::DistributionType;
use risingwave_meta_model::prelude::{Fragment, FragmentRelation, StreamingJob};
use risingwave_meta_model::{
    DispatcherType, StreamingParallelism, WorkerId, fragment, fragment_relation, object,
    streaming_job,
};
use risingwave_pb::common::{WorkerNode, WorkerType};
use risingwave_pb::stream_plan::{PbDispatchOutputMapping, PbDispatcher};
use sea_orm::ActiveValue::Set;
use sea_orm::{
    ActiveModelTrait, ColumnTrait, ConnectionTrait, EntityTrait, IntoActiveModel, JoinType,
    QueryFilter, QuerySelect, RelationTrait, TransactionTrait,
};
use thiserror_ext::AsReport;
use tokio::sync::oneshot::Receiver;
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, oneshot};
use tokio::task::JoinHandle;
use tokio::time::{Instant, MissedTickBehavior};

use crate::barrier::{Command, Reschedule, RescheduleContext, ReschedulePlan};
use crate::controller::fragment::{InflightActorInfo, InflightFragmentInfo};
use crate::controller::scale::{
    FragmentRenderMap, LoadedFragmentContext, NoShuffleEnsemble,
    find_fragment_no_shuffle_dags_detailed, load_fragment_context, load_fragment_context_for_jobs,
};
use crate::controller::utils::{
    StreamingJobExtraInfo, compose_dispatchers, get_streaming_job_extra_info,
};
use crate::error::bail_invalid_parameter;
use crate::manager::{ActiveStreamingWorkerNodes, LocalNotification, MetaSrvEnv, MetadataManager};
use crate::model::{ActorId, FragmentId, StreamActor, StreamActorWithDispatchers};
use crate::stream::{GlobalStreamManager, SourceManagerRef};
use crate::{MetaError, MetaResult};

#[derive(Debug, Clone, Eq, PartialEq)]
pub struct WorkerReschedule {
    pub worker_actor_diff: BTreeMap<WorkerId, isize>,
}

pub type ScaleControllerRef = Arc<ScaleController>;

pub struct ScaleController {
    pub metadata_manager: MetadataManager,

    pub source_manager: SourceManagerRef,

    pub env: MetaSrvEnv,

    /// We acquire this lock during DDL to prevent scaling operations on jobs that are in the
    /// creating state, e.g., an MV cannot be rescheduled during foreground backfill.
    pub reschedule_lock: RwLock<()>,
}

impl ScaleController {
    pub fn new(
        metadata_manager: &MetadataManager,
        source_manager: SourceManagerRef,
        env: MetaSrvEnv,
    ) -> Self {
        Self {
            metadata_manager: metadata_manager.clone(),
            source_manager,
            env,
            reschedule_lock: RwLock::new(()),
        }
    }

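    /// Expands the given jobs to the full set of jobs connected to them through no-shuffle
    /// edges, so that related jobs can be rescheduled together.
    ///
    /// Illustrative sketch (the job ids are hypothetical):
    /// ```ignore
    /// // If `mv2` consumes `mv1` through a no-shuffle exchange, resolving either job
    /// // returns both of them.
    /// let related = scale_controller
    ///     .resolve_related_no_shuffle_jobs(&[mv1_id])
    ///     .await?;
    /// assert!(related.contains(&mv2_id));
    /// ```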
    pub async fn resolve_related_no_shuffle_jobs(
        &self,
        jobs: &[JobId],
    ) -> MetaResult<HashSet<JobId>> {
        let inner = self.metadata_manager.catalog_controller.inner.read().await;
        let txn = inner.db.begin().await?;

        let fragment_ids: Vec<_> = Fragment::find()
            .select_only()
            .column(fragment::Column::FragmentId)
            .filter(fragment::Column::JobId.is_in(jobs.to_vec()))
            .into_tuple()
            .all(&txn)
            .await?;
        let ensembles = find_fragment_no_shuffle_dags_detailed(&txn, &fragment_ids).await?;
        let related_fragments = ensembles
            .iter()
            .flat_map(|ensemble| ensemble.fragments())
            .collect_vec();

        let job_ids: Vec<_> = Fragment::find()
            .select_only()
            .column(fragment::Column::JobId)
            .filter(fragment::Column::FragmentId.is_in(related_fragments))
            .into_tuple()
            .all(&txn)
            .await?;

        let job_ids = job_ids.into_iter().collect();

        Ok(job_ids)
    }

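    /// Persists the requested parallelism and/or resource group changes for the given jobs in
    /// the catalog, then builds one `Command::RescheduleIntent` per affected database.
    ///
    /// Illustrative sketch (the job id and target parallelism are hypothetical):
    /// ```ignore
    /// let mut policy = HashMap::new();
    /// policy.insert(
    ///     job_id,
    ///     ReschedulePolicy::Parallelism(ParallelismPolicy {
    ///         parallelism: StreamingParallelism::Fixed(4),
    ///     }),
    /// );
    /// let commands = scale_controller.reschedule_inplace(policy).await?;
    /// ```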
    pub async fn reschedule_inplace(
        &self,
        policy: HashMap<JobId, ReschedulePolicy>,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        let inner = self.metadata_manager.catalog_controller.inner.write().await;
        let txn = inner.db.begin().await?;

        for (table_id, target) in &policy {
            let streaming_job = StreamingJob::find_by_id(*table_id)
                .one(&txn)
                .await?
                .ok_or_else(|| MetaError::catalog_id_not_found("table", table_id))?;

            let max_parallelism = streaming_job.max_parallelism;

            let mut streaming_job = streaming_job.into_active_model();

            match &target {
                ReschedulePolicy::Parallelism(p) | ReschedulePolicy::Both(p, _) => {
                    if let StreamingParallelism::Fixed(n) = p.parallelism
                        && n > max_parallelism as usize
                    {
                        bail!(
                            "specified parallelism {n} should not exceed max parallelism {max_parallelism}"
                        );
                    }

                    streaming_job.parallelism = Set(p.parallelism.clone());
                }
                _ => {}
            }

            match &target {
                ReschedulePolicy::ResourceGroup(r) | ReschedulePolicy::Both(_, r) => {
                    streaming_job.specific_resource_group = Set(r.resource_group.clone());
                }
                _ => {}
            }

            StreamingJob::update(streaming_job).exec(&txn).await?;
        }

        let job_ids: HashSet<JobId> = policy.keys().copied().collect();
        let commands = build_reschedule_intent_for_jobs(&txn, job_ids).await?;

        txn.commit().await?;

        Ok(commands)
    }

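    /// Persists a backfill parallelism override for the given jobs (`None` clears the override)
    /// and builds the corresponding per-database reschedule-intent commands.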
    pub async fn reschedule_backfill_parallelism_inplace(
        &self,
        policy: HashMap<JobId, Option<StreamingParallelism>>,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        if policy.is_empty() {
            return Ok(HashMap::new());
        }

        let inner = self.metadata_manager.catalog_controller.inner.write().await;
        let txn = inner.db.begin().await?;

        for (table_id, parallelism) in &policy {
            let streaming_job = StreamingJob::find_by_id(*table_id)
                .one(&txn)
                .await?
                .ok_or_else(|| MetaError::catalog_id_not_found("table", table_id))?;

            let max_parallelism = streaming_job.max_parallelism;

            let mut streaming_job = streaming_job.into_active_model();

            if let Some(StreamingParallelism::Fixed(n)) = parallelism
                && *n > max_parallelism as usize
            {
                bail!(
                    "specified backfill parallelism {n} should not exceed max parallelism {max_parallelism}"
                );
            }

            streaming_job.backfill_parallelism = Set(parallelism.clone());
            streaming_job.update(&txn).await?;
        }

        let jobs = policy.keys().copied().collect();

        let commands = build_reschedule_intent_for_jobs(&txn, jobs).await?;

        txn.commit().await?;

        Ok(commands)
    }

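    /// Overrides the parallelism of individual fragments. The request must cover at least one
    /// *entry* fragment of each affected no-shuffle ensemble, and all covered entry fragments of
    /// one ensemble must agree on the target; `None` clears the per-fragment override.
    ///
    /// Illustrative sketch (the fragment id and parallelism are hypothetical):
    /// ```ignore
    /// let commands = scale_controller
    ///     .reschedule_fragment_inplace(HashMap::from([(
    ///         entry_fragment_id,
    ///         Some(StreamingParallelism::Fixed(2)),
    ///     )]))
    ///     .await?;
    /// ```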
    pub async fn reschedule_fragment_inplace(
        &self,
        policy: HashMap<risingwave_meta_model::FragmentId, Option<StreamingParallelism>>,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        if policy.is_empty() {
            return Ok(HashMap::new());
        }

        let inner = self.metadata_manager.catalog_controller.inner.write().await;
        let txn = inner.db.begin().await?;

        let fragment_id_list = policy.keys().copied().collect_vec();

        let existing_fragment_ids: HashSet<_> = Fragment::find()
            .select_only()
            .column(fragment::Column::FragmentId)
            .filter(fragment::Column::FragmentId.is_in(fragment_id_list.clone()))
            .into_tuple::<risingwave_meta_model::FragmentId>()
            .all(&txn)
            .await?
            .into_iter()
            .collect();

        if let Some(missing_fragment_id) = fragment_id_list
            .iter()
            .find(|fragment_id| !existing_fragment_ids.contains(*fragment_id))
        {
            return Err(MetaError::catalog_id_not_found(
                "fragment",
                *missing_fragment_id,
            ));
        }

        let mut target_ensembles = vec![];

        for ensemble in find_fragment_no_shuffle_dags_detailed(&txn, &fragment_id_list).await? {
            let entry_fragment_ids = ensemble.entry_fragments().collect_vec();

            let desired_parallelism = match entry_fragment_ids
                .iter()
                .filter_map(|fragment_id| policy.get(fragment_id).cloned())
                .dedup()
                .collect_vec()
                .as_slice()
            {
                [] => {
                    bail_invalid_parameter!(
                        "none of the entry fragments {:?} were included in the reschedule request; \
                         provide at least one entry fragment id",
                        entry_fragment_ids
                    );
                }
                [parallelism] => parallelism.clone(),
                parallelisms => {
                    bail!(
                        "conflicting reschedule policies for fragments in the same no-shuffle ensemble: {:?}",
                        parallelisms
                    );
                }
            };

            let fragments = Fragment::find()
                .filter(fragment::Column::FragmentId.is_in(entry_fragment_ids))
                .all(&txn)
                .await?;

            debug_assert!(
                fragments
                    .iter()
                    .map(|fragment| fragment.parallelism.as_ref())
                    .all_equal(),
                "entry fragments in the same ensemble should share the same parallelism"
            );

            let current_parallelism = fragments
                .first()
                .and_then(|fragment| fragment.parallelism.clone());

            if current_parallelism == desired_parallelism {
                continue;
            }

            for fragment in fragments {
                let mut fragment = fragment.into_active_model();
                fragment.parallelism = Set(desired_parallelism.clone());
                Fragment::update(fragment).exec(&txn).await?;
            }

            target_ensembles.push(ensemble);
        }

        if target_ensembles.is_empty() {
            txn.commit().await?;
            return Ok(HashMap::new());
        }

        let target_fragment_ids: HashSet<FragmentId> = target_ensembles
            .iter()
            .flat_map(|ensemble| ensemble.component_fragments())
            .collect();
        let commands = build_reschedule_intent_for_fragments(&txn, target_fragment_ids).await?;

        txn.commit().await?;

        Ok(commands)
    }

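    /// Re-renders the actor layout of the given jobs from the current catalog state and wraps
    /// the result into per-database reschedule-intent commands, without changing any metadata.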
    async fn rerender(&self, jobs: HashSet<JobId>) -> MetaResult<HashMap<DatabaseId, Command>> {
        if jobs.is_empty() {
            return Ok(HashMap::new());
        }

        let inner = self.metadata_manager.catalog_controller.inner.read().await;
        let txn = inner.db.begin().await?;
        let commands = build_reschedule_intent_for_jobs(&txn, jobs).await?;
        txn.commit().await?;
        Ok(commands)
    }
}

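/// Resolves the database of every requested job, verifies that all of them exist, and builds a
/// `Command::RescheduleIntent` (with the concrete plan left to be computed later) for each
/// database that contains at least one of the jobs.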
async fn build_reschedule_intent_for_jobs(
    txn: &impl ConnectionTrait,
    job_ids: HashSet<JobId>,
) -> MetaResult<HashMap<DatabaseId, Command>> {
    if job_ids.is_empty() {
        return Ok(HashMap::new());
    }

    let job_id_list = job_ids.iter().copied().collect_vec();
    let database_jobs: Vec<(DatabaseId, JobId)> = StreamingJob::find()
        .select_only()
        .column(object::Column::DatabaseId)
        .column(streaming_job::Column::JobId)
        .join(JoinType::LeftJoin, streaming_job::Relation::Object.def())
        .filter(streaming_job::Column::JobId.is_in(job_id_list))
        .into_tuple()
        .all(txn)
        .await?;

    if database_jobs.len() != job_ids.len() {
        let returned_jobs: HashSet<JobId> =
            database_jobs.iter().map(|(_, job_id)| *job_id).collect();
        let missing = job_ids.difference(&returned_jobs).copied().collect_vec();
        return Err(MetaError::catalog_id_not_found(
            "streaming job",
            format!("{missing:?}"),
        ));
    }

    let reschedule_context = load_reschedule_context_for_jobs(txn, job_ids).await?;
    if reschedule_context.is_empty() {
        return Ok(HashMap::new());
    }

    let commands = reschedule_context
        .into_database_contexts()
        .into_iter()
        .map(|(database_id, context)| {
            (
                database_id,
                Command::RescheduleIntent {
                    context,
                    reschedule_plan: None,
                },
            )
        })
        .collect();

    Ok(commands)
}

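/// Like [`build_reschedule_intent_for_jobs`], but starting from a set of fragments: the
/// fragments are first expanded to their no-shuffle ensembles before the per-database intents
/// are built.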
async fn build_reschedule_intent_for_fragments(
    txn: &impl ConnectionTrait,
    fragment_ids: HashSet<FragmentId>,
) -> MetaResult<HashMap<DatabaseId, Command>> {
    if fragment_ids.is_empty() {
        return Ok(HashMap::new());
    }

    let fragment_id_list = fragment_ids.iter().copied().collect_vec();
    let fragment_databases: Vec<(FragmentId, DatabaseId)> = Fragment::find()
        .select_only()
        .column(fragment::Column::FragmentId)
        .column(object::Column::DatabaseId)
        .join(JoinType::LeftJoin, fragment::Relation::Object.def())
        .filter(fragment::Column::FragmentId.is_in(fragment_id_list.clone()))
        .into_tuple()
        .all(txn)
        .await?;

    if fragment_databases.len() != fragment_ids.len() {
        let returned: HashSet<FragmentId> = fragment_databases
            .iter()
            .map(|(fragment_id, _)| *fragment_id)
            .collect();
        let missing = fragment_ids.difference(&returned).copied().collect_vec();
        return Err(MetaError::catalog_id_not_found(
            "fragment",
            format!("{missing:?}"),
        ));
    }

    let ensembles = find_fragment_no_shuffle_dags_detailed(txn, &fragment_id_list).await?;
    let reschedule_context = load_reschedule_context_for_ensembles(txn, ensembles).await?;
    if reschedule_context.is_empty() {
        return Ok(HashMap::new());
    }

    let commands = reschedule_context
        .into_database_contexts()
        .into_iter()
        .map(|(database_id, context)| {
            (
                database_id,
                Command::RescheduleIntent {
                    context,
                    reschedule_plan: None,
                },
            )
        })
        .collect();

    Ok(commands)
}

async fn load_reschedule_context_for_jobs(
    txn: &impl ConnectionTrait,
    job_ids: HashSet<JobId>,
) -> MetaResult<RescheduleContext> {
    let loaded = load_fragment_context_for_jobs(txn, job_ids).await?;
    build_reschedule_context_from_loaded(txn, loaded).await
}

async fn load_reschedule_context_for_ensembles(
    txn: &impl ConnectionTrait,
    ensembles: Vec<NoShuffleEnsemble>,
) -> MetaResult<RescheduleContext> {
    let loaded = load_fragment_context(txn, ensembles).await?;
    build_reschedule_context_from_loaded(txn, loaded).await
}

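/// Enriches a [`LoadedFragmentContext`] with job-level extra info and with the upstream and
/// downstream fragment relations that are needed to recompose dispatchers during rescheduling.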
async fn build_reschedule_context_from_loaded(
    txn: &impl ConnectionTrait,
    loaded: LoadedFragmentContext,
) -> MetaResult<RescheduleContext> {
    if loaded.is_empty() {
        return Ok(RescheduleContext::empty());
    }

    let job_ids = loaded.job_map.keys().copied().collect_vec();
    let job_extra_info = get_streaming_job_extra_info(txn, job_ids).await?;

    let fragment_ids = loaded
        .job_fragments
        .values()
        .flat_map(|fragments| fragments.keys().copied())
        .collect_vec();

    let upstreams: Vec<(FragmentId, FragmentId, DispatcherType)> = FragmentRelation::find()
        .select_only()
        .columns([
            fragment_relation::Column::TargetFragmentId,
            fragment_relation::Column::SourceFragmentId,
            fragment_relation::Column::DispatcherType,
        ])
        .filter(fragment_relation::Column::TargetFragmentId.is_in(fragment_ids.clone()))
        .into_tuple()
        .all(txn)
        .await?;

    let mut upstream_fragments = HashMap::new();
    for (fragment, upstream, dispatcher) in upstreams {
        upstream_fragments
            .entry(fragment as FragmentId)
            .or_insert_with(HashMap::new)
            .insert(upstream as FragmentId, dispatcher);
    }

    let downstreams = FragmentRelation::find()
        .filter(fragment_relation::Column::SourceFragmentId.is_in(fragment_ids.clone()))
        .all(txn)
        .await?;

    let mut downstream_fragments = HashMap::new();
    let mut downstream_relations = HashMap::new();
    for relation in downstreams {
        let source_fragment_id = relation.source_fragment_id as FragmentId;
        let target_fragment_id = relation.target_fragment_id as FragmentId;
        downstream_fragments
            .entry(source_fragment_id)
            .or_insert_with(HashMap::new)
            .insert(target_fragment_id, relation.dispatcher_type);
        downstream_relations.insert((source_fragment_id, target_fragment_id), relation);
    }

    Ok(RescheduleContext {
        loaded,
        job_extra_info,
        upstream_fragments,
        downstream_fragments,
        downstream_relations,
    })
}

/// Build a `Reschedule` by diffing the previously materialized fragment state against
/// the newly rendered actor layout.
///
/// This function assumes a full rebuild (no kept actors) and produces:
/// - actor additions/removals and vnode bitmap updates
/// - dispatcher updates for upstream/downstream fragments
/// - updated split assignments for source actors
///
/// `upstream_fragments`/`downstream_fragments` describe neighbor fragments and dispatcher types,
/// while `all_actor_dispatchers` contains the new dispatcher list for each actor. `job_extra_info`
/// supplies job-level context for building new actors.
fn diff_fragment(
    prev_fragment_info: &InflightFragmentInfo,
    curr_actors: &HashMap<ActorId, InflightActorInfo>,
    upstream_fragments: HashMap<FragmentId, DispatcherType>,
    downstream_fragments: HashMap<FragmentId, DispatcherType>,
    all_actor_dispatchers: HashMap<ActorId, Vec<PbDispatcher>>,
    job_extra_info: Option<&StreamingJobExtraInfo>,
) -> MetaResult<Reschedule> {
    let prev_ids: HashSet<_> = prev_fragment_info.actors.keys().cloned().collect();
    let curr_ids: HashSet<_> = curr_actors.keys().cloned().collect();

    let removed_actors: HashSet<_> = &prev_ids - &curr_ids;
    let added_actor_ids: HashSet<_> = &curr_ids - &prev_ids;
    let kept_ids: HashSet<_> = prev_ids.intersection(&curr_ids).cloned().collect();
    debug_assert!(
        kept_ids.is_empty(),
        "kept actors found in scale; expected full rebuild, prev={prev_ids:?}, curr={curr_ids:?}, kept={kept_ids:?}"
    );

    let mut added_actors = HashMap::new();
    for &actor_id in &added_actor_ids {
        let InflightActorInfo { worker_id, .. } = curr_actors
            .get(&actor_id)
            .ok_or_else(|| anyhow!("BUG: Worker not found for new actor {}", actor_id))?;

        added_actors
            .entry(*worker_id)
            .or_insert_with(Vec::new)
            .push(actor_id);
    }

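    // Under the full-rebuild assumption asserted above, `kept_ids` is empty and the loop below
    // is a defensive no-op, retained in case a future partial reschedule keeps some actors.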
    let mut vnode_bitmap_updates = HashMap::new();
    for actor_id in kept_ids {
        let prev_actor = &prev_fragment_info.actors[&actor_id];
        let curr_actor = &curr_actors[&actor_id];

        // Check if the vnode distribution has changed.
        if prev_actor.vnode_bitmap != curr_actor.vnode_bitmap
            && let Some(bitmap) = curr_actor.vnode_bitmap.clone()
        {
            vnode_bitmap_updates.insert(actor_id, bitmap);
        }
    }

    let upstream_dispatcher_mapping =
        if let DistributionType::Hash = prev_fragment_info.distribution_type {
            let actor_mapping = curr_actors
                .iter()
                .map(
                    |(
                        actor_id,
                        InflightActorInfo {
                            worker_id: _,
                            vnode_bitmap,
                            ..
                        },
                    )| {
                        (
                            *actor_id,
                            vnode_bitmap
                                .clone()
                                .expect("hash-distributed actor must have a vnode bitmap"),
                        )
                    },
                )
                .collect();
            Some(ActorMapping::from_bitmaps(&actor_mapping))
        } else {
            None
        };

    let upstream_fragment_dispatcher_ids = upstream_fragments
        .iter()
        .filter(|&(_, dispatcher_type)| *dispatcher_type != DispatcherType::NoShuffle)
        .map(|(upstream_fragment, _)| (*upstream_fragment, prev_fragment_info.fragment_id))
        .collect();

    let downstream_fragment_ids = downstream_fragments
        .iter()
        .filter(|&(_, dispatcher_type)| *dispatcher_type != DispatcherType::NoShuffle)
        .map(|(fragment_id, _)| *fragment_id)
        .collect();

    let extra_info = job_extra_info.cloned().unwrap_or_default();
    let expr_context = extra_info.stream_context().to_expr_context();
    let job_definition = extra_info.job_definition;
    let config_override = extra_info.config_override;

    let newly_created_actors: HashMap<ActorId, (StreamActorWithDispatchers, WorkerId)> =
        added_actor_ids
            .iter()
            .map(|actor_id| {
                let actor = StreamActor {
                    actor_id: *actor_id,
                    fragment_id: prev_fragment_info.fragment_id,
                    vnode_bitmap: curr_actors[actor_id].vnode_bitmap.clone(),
                    mview_definition: job_definition.clone(),
                    expr_context: Some(expr_context.clone()),
                    config_override: config_override.clone(),
                };
                (
                    *actor_id,
                    (
                        (
                            actor,
                            all_actor_dispatchers
                                .get(actor_id)
                                .cloned()
                                .unwrap_or_default(),
                        ),
                        curr_actors[actor_id].worker_id,
                    ),
                )
            })
            .collect();

    let actor_splits = curr_actors
        .iter()
        .map(|(&actor_id, info)| (actor_id, info.splits.clone()))
        .collect();

    let reschedule = Reschedule {
        added_actors,
        removed_actors,
        vnode_bitmap_updates,
        upstream_fragment_dispatcher_ids,
        upstream_dispatcher_mapping,
        downstream_fragment_ids,
        actor_splits,
        newly_created_actors,
    };

    Ok(reschedule)
}

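/// Turns a rendered actor layout into one [`ReschedulePlan`] per database by diffing every
/// rendered fragment against its previous in-flight state via [`diff_fragment`] and recomposing
/// the dispatchers towards all of its downstream fragments.
///
/// `all_prev_fragments` must cover every rendered fragment as well as all of its upstream and
/// downstream neighbors; a missing entry is reported as an error.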
pub(crate) fn build_reschedule_commands(
    render_result: FragmentRenderMap,
    context: RescheduleContext,
    all_prev_fragments: HashMap<FragmentId, &InflightFragmentInfo>,
) -> MetaResult<HashMap<DatabaseId, ReschedulePlan>> {
    if render_result.is_empty() {
        return Ok(HashMap::new());
    }

    let RescheduleContext {
        job_extra_info,
        upstream_fragments: mut all_upstream_fragments,
        downstream_fragments: mut all_downstream_fragments,
        mut downstream_relations,
        ..
    } = context;

    let fragment_ids = render_result
        .values()
        .flat_map(|jobs| jobs.values())
        .flatten()
        .map(|(fragment_id, _)| *fragment_id)
        .collect_vec();

    let all_related_fragment_ids: HashSet<_> = fragment_ids
        .iter()
        .copied()
        .chain(all_upstream_fragments.values().flatten().map(|(id, _)| *id))
        .chain(
            all_downstream_fragments
                .values()
                .flatten()
                .map(|(id, _)| *id),
        )
        .collect();

    for fragment_id in all_related_fragment_ids {
        if !all_prev_fragments.contains_key(&fragment_id) {
            return Err(MetaError::from(anyhow!(
                "previous fragment info for {fragment_id} not found"
            )));
        }
    }

    let all_rendered_fragments: HashMap<_, _> = render_result
        .values()
        .flat_map(|jobs| jobs.values())
        .flatten()
        .map(|(fragment_id, info)| (*fragment_id, info))
        .collect();

    let mut commands = HashMap::new();

    for (database_id, jobs) in &render_result {
        let mut all_fragment_actors = HashMap::new();
        let mut reschedules = HashMap::new();

        for (job_id, fragment_id, fragment_info) in jobs.iter().flat_map(|(job_id, fragments)| {
            fragments
                .iter()
                .map(move |(fragment_id, info)| (job_id, fragment_id, info))
        }) {
            let InflightFragmentInfo {
                distribution_type,
                actors,
                ..
            } = fragment_info;

            let upstream_fragments = all_upstream_fragments
                .remove(&(*fragment_id as FragmentId))
                .unwrap_or_default();
            let downstream_fragments = all_downstream_fragments
                .remove(&(*fragment_id as FragmentId))
                .unwrap_or_default();

            let fragment_actors: HashMap<_, _> = upstream_fragments
                .keys()
                .copied()
                .chain(downstream_fragments.keys().copied())
                .map(|fragment_id| {
                    all_prev_fragments
                        .get(&fragment_id)
                        .map(|fragment| {
                            (
                                fragment_id,
                                fragment.actors.keys().copied().collect::<HashSet<_>>(),
                            )
                        })
                        .ok_or_else(|| {
                            MetaError::from(anyhow!(
                                "fragment {} not found in previous state",
                                fragment_id
                            ))
                        })
                })
                .collect::<MetaResult<_>>()?;

            all_fragment_actors.extend(fragment_actors);

            let source_fragment_actors = actors
                .iter()
                .map(|(actor_id, info)| (*actor_id, info.vnode_bitmap.clone()))
                .collect();

            let mut all_actor_dispatchers: HashMap<_, Vec<_>> = HashMap::new();

            for downstream_fragment_id in downstream_fragments.keys() {
                // Take the target fragment's actors together with its own distribution type;
                // the downstream distribution may differ from this fragment's (e.g. a hash
                // fragment dispatching to a singleton fragment).
                let (target_fragment_actors, target_fragment_distribution) =
                    match all_rendered_fragments.get(downstream_fragment_id) {
                        None => {
                            let external_fragment = all_prev_fragments
                                .get(downstream_fragment_id)
                                .ok_or_else(|| {
                                    MetaError::from(anyhow!(
                                        "fragment {} not found in previous state",
                                        downstream_fragment_id
                                    ))
                                })?;

                            (
                                external_fragment
                                    .actors
                                    .iter()
                                    .map(|(actor_id, info)| (*actor_id, info.vnode_bitmap.clone()))
                                    .collect(),
                                external_fragment.distribution_type,
                            )
                        }
                        Some(downstream_rendered) => (
                            downstream_rendered
                                .actors
                                .iter()
                                .map(|(actor_id, info)| (*actor_id, info.vnode_bitmap.clone()))
                                .collect(),
                            downstream_rendered.distribution_type,
                        ),
                    };

                let fragment_relation::Model {
                    source_fragment_id: _,
                    target_fragment_id: _,
                    dispatcher_type,
                    dist_key_indices,
                    output_indices,
                    output_type_mapping,
                } = downstream_relations
                    .remove(&(
                        *fragment_id as FragmentId,
                        *downstream_fragment_id as FragmentId,
                    ))
                    .ok_or_else(|| {
                        MetaError::from(anyhow!(
                            "downstream relation missing for {} -> {}",
                            fragment_id,
                            downstream_fragment_id
                        ))
                    })?;

                let pb_mapping = PbDispatchOutputMapping {
                    indices: output_indices.into_u32_array(),
                    types: output_type_mapping.unwrap_or_default().to_protobuf(),
                };

                let dispatchers = compose_dispatchers(
                    *distribution_type,
                    &source_fragment_actors,
                    *downstream_fragment_id,
                    target_fragment_distribution,
                    &target_fragment_actors,
                    dispatcher_type,
                    dist_key_indices.into_u32_array(),
                    pb_mapping,
                );

                for (actor_id, dispatcher) in dispatchers {
                    all_actor_dispatchers
                        .entry(actor_id)
                        .or_default()
                        .push(dispatcher);
                }
            }

            let prev_fragment = all_prev_fragments.get(fragment_id).ok_or_else(|| {
                MetaError::from(anyhow!(
                    "fragment {} not found in previous state",
                    fragment_id
                ))
            })?;

            let reschedule = diff_fragment(
                prev_fragment,
                actors,
                upstream_fragments,
                downstream_fragments,
                all_actor_dispatchers,
                job_extra_info.get(job_id),
            )?;

            reschedules.insert(*fragment_id as FragmentId, reschedule);
        }

        let command = ReschedulePlan {
            reschedules,
            fragment_actors: all_fragment_actors,
        };

        debug_assert!(
            command
                .reschedules
                .values()
                .all(|reschedule| reschedule.vnode_bitmap_updates.is_empty()),
            "reschedule plan carries vnode_bitmap_updates, expected full rebuild"
        );

        commands.insert(*database_id, command);
    }

    Ok(commands)
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParallelismPolicy {
    pub parallelism: StreamingParallelism,
}

#[derive(Clone, Debug)]
pub struct ResourceGroupPolicy {
    pub resource_group: Option<String>,
}

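/// The catalog-level change requested for a streaming job: a new parallelism, a new resource
/// group, or both at once.
///
/// Illustrative sketch (the resource group name is hypothetical):
/// ```ignore
/// let policy = ReschedulePolicy::Both(
///     ParallelismPolicy { parallelism: StreamingParallelism::Adaptive },
///     ResourceGroupPolicy { resource_group: Some("rg1".to_owned()) },
/// );
/// ```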
#[derive(Clone, Debug)]
pub enum ReschedulePolicy {
    Parallelism(ParallelismPolicy),
    ResourceGroup(ResourceGroupPolicy),
    Both(ParallelismPolicy, ResourceGroupPolicy),
}

impl GlobalStreamManager {
    #[await_tree::instrument("acquire_reschedule_read_guard")]
    pub async fn reschedule_lock_read_guard(&self) -> RwLockReadGuard<'_, ()> {
        self.scale_controller.reschedule_lock.read().await
    }

    #[await_tree::instrument("acquire_reschedule_write_guard")]
    pub async fn reschedule_lock_write_guard(&self) -> RwLockWriteGuard<'_, ()> {
        self.scale_controller.reschedule_lock.write().await
    }

    /// When new worker nodes join, or the parallelism of existing worker nodes changes, examines
    /// whether any jobs can be scaled, and scales them if so.
    ///
    /// This method iterates over all `CREATED` jobs and can be called repeatedly.
    ///
    /// Returns
    /// - `Ok(false)` if all eligible jobs were processed and no retry is needed;
    /// - `Ok(true)` if some jobs were temporarily blocked (e.g. by unreschedulable creating
    ///   backfill jobs), so parallelism control should be triggered again later.
    async fn trigger_parallelism_control(&self) -> MetaResult<bool> {
        tracing::info!("trigger parallelism control");

        let _reschedule_job_lock = self.reschedule_lock_write_guard().await;

        let background_streaming_jobs = self
            .metadata_manager
            .list_background_creating_jobs()
            .await?;

        let blocked_jobs = self
            .metadata_manager
            .collect_reschedule_blocked_jobs_for_creating_jobs(&background_streaming_jobs, true)
            .await?;
        let has_blocked_jobs = !blocked_jobs.is_empty();

        let database_objects: HashMap<risingwave_meta_model::DatabaseId, Vec<JobId>> = self
            .metadata_manager
            .catalog_controller
            .list_streaming_job_with_database()
            .await?;

        // Interleave jobs from different databases (sorting by per-database position first) so
        // that a single database cannot monopolize the head of the batch queue.
        let job_ids = database_objects
            .iter()
            .flat_map(|(database_id, job_ids)| {
                job_ids
                    .iter()
                    .enumerate()
                    .map(move |(idx, job_id)| (idx, database_id, job_id))
            })
            .sorted_by(|(idx_a, database_a, _), (idx_b, database_b, _)| {
                idx_a.cmp(idx_b).then(database_a.cmp(database_b))
            })
            .map(|(_, database_id, job_id)| (*database_id, *job_id))
            .filter(|(_, job_id)| !blocked_jobs.contains(job_id))
            .collect_vec();

        if job_ids.is_empty() {
            tracing::info!("no streaming jobs to scale, maybe an empty cluster");
            // Keep retrying periodically while some jobs are temporarily blocked by creating,
            // unreschedulable backfill jobs, so that they are scaled automatically once the
            // creating jobs finish.
            return Ok(has_blocked_jobs);
        }

        let active_workers =
            ActiveStreamingWorkerNodes::new_snapshot(self.metadata_manager.clone()).await?;

        tracing::info!(
            "trigger parallelism control for jobs: {:#?}, workers {:#?}",
            job_ids,
            active_workers.current()
        );

        let batch_size = match self.env.opts.parallelism_control_batch_size {
            0 => job_ids.len(),
            n => n,
        };

        tracing::info!(
            "total {} streaming jobs, batch size {}, schedulable worker ids: {:?}",
            job_ids.len(),
            batch_size,
            active_workers.current()
        );

        let batches: Vec<_> = job_ids
            .into_iter()
            .chunks(batch_size)
            .into_iter()
            .map(|chunk| chunk.collect_vec())
            .collect();

        for batch in batches {
            let jobs = batch.iter().map(|(_, job_id)| *job_id).collect();

            let commands = self.scale_controller.rerender(jobs).await?;

            let futures = commands.into_iter().map(|(database_id, command)| {
                let barrier_scheduler = self.barrier_scheduler.clone();
                async move { barrier_scheduler.run_command(database_id, command).await }
            });

            let _results = future::try_join_all(futures).await?;
        }

        Ok(has_blocked_jobs)
    }

    /// Handles notifications of worker node activation and deletion, and triggers parallelism control.
    async fn run(&self, mut shutdown_rx: Receiver<()>) {
        tracing::info!("starting automatic parallelism control monitor");

        let check_period =
            Duration::from_secs(self.env.opts.parallelism_control_trigger_period_sec);

        let mut ticker = tokio::time::interval_at(
            Instant::now()
                + Duration::from_secs(self.env.opts.parallelism_control_trigger_first_delay_sec),
            check_period,
        );
        ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);

        let (local_notification_tx, mut local_notification_rx) =
            tokio::sync::mpsc::unbounded_channel();

        self.env
            .notification_manager()
            .insert_local_sender(local_notification_tx);

        // Wait for the first tick.
        ticker.tick().await;

        let worker_nodes = self
            .metadata_manager
            .list_active_streaming_compute_nodes()
            .await
            .expect("list active streaming compute nodes");

        let mut worker_cache: BTreeMap<_, _> = worker_nodes
            .into_iter()
            .map(|worker| (worker.id, worker))
            .collect();

        let mut previous_adaptive_parallelism_strategy = AdaptiveParallelismStrategy::default();

        let mut should_trigger = false;

        loop {
            tokio::select! {
                // Poll the branches in order so that shutdown takes priority.
                biased;

                _ = &mut shutdown_rx => {
                    tracing::info!("Stream manager is stopped");
                    break;
                }

                _ = ticker.tick(), if should_trigger => {
                    let include_workers = worker_cache.keys().copied().collect_vec();

                    if include_workers.is_empty() {
                        tracing::debug!("no available worker nodes");
                        should_trigger = false;
                        continue;
                    }

                    match self.trigger_parallelism_control().await {
                        Ok(cont) => {
                            should_trigger = cont;
                        }
                        Err(e) => {
                            tracing::warn!(error = %e.as_report(), "Failed to trigger scale out, waiting for the next tick to retry after {}s", ticker.period().as_secs());
                            ticker.reset();
                        }
                    }
                }

                notification = local_notification_rx.recv() => {
                    let notification = notification.expect("local notification channel closed in loop of stream manager");

                    // Only maintain the cache for streaming compute nodes.
                    let worker_is_streaming_compute = |worker: &WorkerNode| {
                        worker.get_type() == Ok(WorkerType::ComputeNode)
                            && worker.property.as_ref().unwrap().is_streaming
                    };

                    match notification {
                        LocalNotification::SystemParamsChange(reader) => {
                            let new_strategy = reader.adaptive_parallelism_strategy();
                            if new_strategy != previous_adaptive_parallelism_strategy {
                                tracing::info!("adaptive parallelism strategy changed from {:?} to {:?}", previous_adaptive_parallelism_strategy, new_strategy);
                                should_trigger = true;
                                previous_adaptive_parallelism_strategy = new_strategy;
                            }
                        }
                        LocalNotification::WorkerNodeActivated(worker) => {
                            if !worker_is_streaming_compute(&worker) {
                                continue;
                            }

                            tracing::info!(worker = %worker.id, "worker activated notification received");

                            let prev_worker = worker_cache.insert(worker.id, worker.clone());

                            match prev_worker {
                                Some(prev_worker) if prev_worker.compute_node_parallelism() != worker.compute_node_parallelism() => {
                                    tracing::info!(worker = %worker.id, "worker parallelism changed");
                                    should_trigger = true;
                                }
                                Some(prev_worker) if prev_worker.resource_group() != worker.resource_group() => {
                                    tracing::info!(worker = %worker.id, "worker label changed");
                                    should_trigger = true;
                                }
                                None => {
                                    tracing::info!(worker = %worker.id, "new worker joined");
                                    should_trigger = true;
                                }
                                _ => {}
                            }
                        }

                        // Since our logic for handling passive scale-in is within the barrier manager,
                        // there’s not much we can do here. All we can do is proactively remove the entries from our cache.
                        LocalNotification::WorkerNodeDeleted(worker) => {
                            if !worker_is_streaming_compute(&worker) {
                                continue;
                            }

                            match worker_cache.remove(&worker.id) {
                                Some(prev_worker) => {
                                    tracing::info!(worker = %prev_worker.id, "worker removed from stream manager cache");
                                }
                                None => {
                                    tracing::warn!(worker = %worker.id, "deleted worker not found in stream manager cache");
                                }
                            }
                        }

                        LocalNotification::StreamingJobBackfillFinished(job_id) => {
                            tracing::debug!(job_id = %job_id, "received backfill finished notification");
                            if let Err(e) = self.apply_post_backfill_parallelism(job_id).await {
                                tracing::warn!(job_id = %job_id, error = %e.as_report(), "failed to restore parallelism after backfill");
                                // Retry in the next periodic tick. Triggering only on failure
                                // avoids unnecessary full control passes when the restore succeeds.
                                should_trigger = true;
                            }
                        }

                        _ => {}
                    }
                }
            }
        }
    }

    /// Restores a streaming job's parallelism to its target value after backfill completes.
    async fn apply_post_backfill_parallelism(&self, job_id: JobId) -> MetaResult<()> {
        // Fetch both the target parallelism (final desired state) and the backfill parallelism
        // (temporary parallelism used during the backfill phase) from the catalog.
        let Some((target, backfill_parallelism)) = self
            .metadata_manager
            .catalog_controller
            .get_job_parallelisms(job_id)
            .await?
        else {
            // The job may have been dropped before this notification is processed.
            // Treat it as a benign no-op so we don't trigger unnecessary retries.
            tracing::debug!(
                job_id = %job_id,
                "streaming job not found when applying post-backfill parallelism, skip"
            );
            return Ok(());
        };

        // Determine whether we need to reschedule based on the backfill configuration.
        match backfill_parallelism {
            Some(backfill_parallelism) if backfill_parallelism == target => {
                // The backfill parallelism matches the target: no reschedule is needed since the
                // job is already running at the desired parallelism.
                tracing::debug!(
                    job_id = %job_id,
                    ?backfill_parallelism,
                    ?target,
                    "backfill parallelism equals job parallelism, skip reschedule"
                );
                return Ok(());
            }
            Some(_) => {
                // The backfill parallelism differs from the target: proceed to restore the
                // target parallelism.
            }
            None => {
                // No backfill parallelism was configured, meaning the job was created without
                // a special backfill override. No reschedule is necessary.
                tracing::debug!(
                    job_id = %job_id,
                    ?target,
                    "no backfill parallelism configured, skip post-backfill reschedule"
                );
                return Ok(());
            }
        }

        // Reschedule the job to restore its target parallelism.
        tracing::info!(
            job_id = %job_id,
            ?target,
            ?backfill_parallelism,
            "restoring parallelism after backfill via reschedule"
        );
        let policy = ReschedulePolicy::Parallelism(ParallelismPolicy {
            parallelism: target,
        });
        if let Err(e) = self.reschedule_streaming_job(job_id, policy, false).await {
            tracing::warn!(job_id = %job_id, error = %e.as_report(), "reschedule after backfill failed");
            return Err(e);
        }

        tracing::info!(job_id = %job_id, "parallelism reschedule after backfill submitted");
        Ok(())
    }

    pub fn start_auto_parallelism_monitor(
        self: Arc<Self>,
    ) -> (JoinHandle<()>, oneshot::Sender<()>) {
        tracing::info!("Automatic parallelism scale-out is enabled for streaming jobs");
        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
        let join_handle = tokio::spawn(async move {
            self.run(shutdown_rx).await;
        });

        (join_handle, shutdown_tx)
    }
}