risingwave_meta/stream/scale.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashMap, HashSet};
use std::fmt::Debug;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::Duration;

use anyhow::anyhow;
use futures::future;
use itertools::Itertools;
use risingwave_common::bail;
use risingwave_common::catalog::{DatabaseId, FragmentTypeFlag};
use risingwave_common::hash::ActorMapping;
use risingwave_meta_model::{StreamingParallelism, WorkerId, fragment, fragment_relation};
use risingwave_pb::common::{PbWorkerNode, WorkerNode, WorkerType};
use risingwave_pb::stream_plan::{PbDispatchOutputMapping, PbDispatcher};
use sea_orm::{ActiveModelTrait, ConnectionTrait, QuerySelect};
use thiserror_ext::AsReport;
use tokio::sync::oneshot::Receiver;
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, oneshot};
use tokio::task::JoinHandle;
use tokio::time::{Instant, MissedTickBehavior};

use crate::barrier::{Command, Reschedule, SharedFragmentInfo};
use crate::controller::scale::{
    FragmentRenderMap, NoShuffleEnsemble, RenderedGraph, WorkerInfo,
    find_fragment_no_shuffle_dags_detailed, render_fragments, render_jobs,
};
use crate::error::bail_invalid_parameter;
use crate::manager::{LocalNotification, MetaSrvEnv, MetadataManager};
use crate::model::{ActorId, DispatcherId, FragmentId, StreamActor, StreamActorWithDispatchers};
use crate::stream::{GlobalStreamManager, SourceManagerRef};
use crate::{MetaError, MetaResult};

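/// A per-worker actor-count delta for a single reschedule, keyed by worker id.
/// For illustration only (hypothetical worker ids): a diff of `{1: 2, 2: -1}`
/// would add two actors on worker 1 and remove one actor from worker 2.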
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct WorkerReschedule {
    pub worker_actor_diff: BTreeMap<WorkerId, isize>,
}

use risingwave_common::id::JobId;
use risingwave_common::system_param::AdaptiveParallelismStrategy;
use risingwave_common::system_param::reader::SystemParamsRead;
use risingwave_meta_model::DispatcherType;
use risingwave_meta_model::fragment::DistributionType;
use risingwave_meta_model::prelude::{Fragment, FragmentRelation, StreamingJob};
use sea_orm::ActiveValue::Set;
use sea_orm::{ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter, TransactionTrait};

use crate::controller::fragment::{InflightActorInfo, InflightFragmentInfo};
use crate::controller::utils::{
    StreamingJobExtraInfo, compose_dispatchers, get_streaming_job_extra_info,
};
use crate::stream::cdc::assign_cdc_table_snapshot_splits_impl;

pub type ScaleControllerRef = Arc<ScaleController>;

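/// Coordinates online rescheduling (scaling) of streaming jobs: it re-renders
/// fragment actor placements against the current set of workers and turns the
/// differences into barrier [`Command`]s.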
pub struct ScaleController {
    pub metadata_manager: MetadataManager,

    pub source_manager: SourceManagerRef,

    pub env: MetaSrvEnv,

    /// We acquire this lock during DDL to prevent scaling operations on jobs that are in the creating state.
    /// E.g., a MV cannot be rescheduled during foreground backfill.
    pub reschedule_lock: RwLock<()>,
}

impl ScaleController {
    pub fn new(
        metadata_manager: &MetadataManager,
        source_manager: SourceManagerRef,
        env: MetaSrvEnv,
    ) -> Self {
        Self {
            metadata_manager: metadata_manager.clone(),
            source_manager,
            env,
            reschedule_lock: RwLock::new(()),
        }
    }

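    /// Resolves the set of jobs whose fragments belong to the same no-shuffle
    /// ensembles as the fragments of the given jobs. Jobs connected by no-shuffle
    /// edges must be rescheduled together, so callers use this to expand (or skip)
    /// whole groups at once.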
    pub async fn resolve_related_no_shuffle_jobs(
        &self,
        jobs: &[JobId],
    ) -> MetaResult<HashSet<JobId>> {
        let inner = self.metadata_manager.catalog_controller.inner.read().await;
        let txn = inner.db.begin().await?;

        let fragment_ids: Vec<_> = Fragment::find()
            .select_only()
            .column(fragment::Column::FragmentId)
            .filter(fragment::Column::JobId.is_in(jobs.to_vec()))
            .into_tuple()
            .all(&txn)
            .await?;
        let ensembles = find_fragment_no_shuffle_dags_detailed(&txn, &fragment_ids).await?;
        let related_fragments = ensembles
            .iter()
            .flat_map(|ensemble| ensemble.fragments())
            .collect_vec();

        let job_ids: Vec<_> = Fragment::find()
            .select_only()
            .column(fragment::Column::JobId)
            .filter(fragment::Column::FragmentId.is_in(related_fragments))
            .into_tuple()
            .all(&txn)
            .await?;

        let job_ids = job_ids.into_iter().collect();

        Ok(job_ids)
    }

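    /// Computes the [`Reschedule`] delta for a single fragment by comparing its
    /// previous actor layout (`prev_fragment_info`) with the newly rendered one
    /// (`curr_actors`): which actors are added or removed per worker, which kept
    /// actors need vnode bitmap updates, the new hash dispatcher mapping (if any),
    /// and the dispatchers attached to newly created actors.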
    pub fn diff_fragment(
        &self,
        prev_fragment_info: &SharedFragmentInfo,
        curr_actors: &HashMap<ActorId, InflightActorInfo>,
        upstream_fragments: HashMap<FragmentId, DispatcherType>,
        downstream_fragments: HashMap<FragmentId, DispatcherType>,
        all_actor_dispatchers: HashMap<ActorId, Vec<PbDispatcher>>,
        job_extra_info: Option<&StreamingJobExtraInfo>,
    ) -> MetaResult<Reschedule> {
        let prev_actors: HashMap<_, _> = prev_fragment_info
            .actors
            .iter()
            .map(|(actor_id, actor)| (*actor_id, actor))
            .collect();

        let prev_ids: HashSet<_> = prev_actors.keys().cloned().collect();
        let curr_ids: HashSet<_> = curr_actors.keys().cloned().collect();

        let removed_actors: HashSet<_> = &prev_ids - &curr_ids;
        let added_actor_ids: HashSet<_> = &curr_ids - &prev_ids;
        let kept_ids: HashSet<_> = prev_ids.intersection(&curr_ids).cloned().collect();

        let mut added_actors = HashMap::new();
        for &actor_id in &added_actor_ids {
            let InflightActorInfo { worker_id, .. } = curr_actors
                .get(&actor_id)
                .ok_or_else(|| anyhow!("BUG: Worker not found for new actor {}", actor_id))?;

            added_actors
                .entry(*worker_id)
                .or_insert_with(Vec::new)
                .push(actor_id);
        }

        let mut vnode_bitmap_updates = HashMap::new();
        for actor_id in kept_ids {
            let prev_actor = prev_actors[&actor_id];
            let curr_actor = &curr_actors[&actor_id];

            // Check if the vnode distribution has changed.
            if prev_actor.vnode_bitmap != curr_actor.vnode_bitmap
                && let Some(bitmap) = curr_actor.vnode_bitmap.clone()
            {
                vnode_bitmap_updates.insert(actor_id, bitmap);
            }
        }

        let upstream_dispatcher_mapping =
            if let DistributionType::Hash = prev_fragment_info.distribution_type {
                let actor_mapping = curr_actors
                    .iter()
                    .map(
                        |(
                            actor_id,
                            InflightActorInfo {
                                worker_id: _,
                                vnode_bitmap,
                                ..
                            },
                        )| { (*actor_id, vnode_bitmap.clone().unwrap()) },
                    )
                    .collect();
                Some(ActorMapping::from_bitmaps(&actor_mapping))
            } else {
                None
            };

        let upstream_fragment_dispatcher_ids = upstream_fragments
            .iter()
            .filter(|&(_, dispatcher_type)| *dispatcher_type != DispatcherType::NoShuffle)
            .map(|(upstream_fragment, _)| {
                (
                    *upstream_fragment,
                    prev_fragment_info.fragment_id.as_raw_id() as DispatcherId,
                )
            })
            .collect();

        let downstream_fragment_ids = downstream_fragments
            .iter()
            .filter(|&(_, dispatcher_type)| *dispatcher_type != DispatcherType::NoShuffle)
            .map(|(fragment_id, _)| *fragment_id)
            .collect();

        let extra_info = job_extra_info.cloned().unwrap_or_default();
        let expr_context = extra_info.stream_context().to_expr_context();
        let job_definition = extra_info.job_definition;
        let config_override = extra_info.config_override;

        let newly_created_actors: HashMap<ActorId, (StreamActorWithDispatchers, WorkerId)> =
            added_actor_ids
                .iter()
                .map(|actor_id| {
                    let actor = StreamActor {
                        actor_id: *actor_id,
                        fragment_id: prev_fragment_info.fragment_id,
                        vnode_bitmap: curr_actors[actor_id].vnode_bitmap.clone(),
                        mview_definition: job_definition.clone(),
                        expr_context: Some(expr_context.clone()),
                        config_override: config_override.clone(),
                    };
                    (
                        *actor_id,
                        (
                            (
                                actor,
                                all_actor_dispatchers
                                    .get(actor_id)
                                    .cloned()
                                    .unwrap_or_default(),
                            ),
                            curr_actors[actor_id].worker_id,
                        ),
                    )
                })
                .collect();

        let actor_splits = curr_actors
            .iter()
            .map(|(&actor_id, info)| (actor_id, info.splits.clone()))
            .collect();

        let reschedule = Reschedule {
            added_actors,
            removed_actors,
            vnode_bitmap_updates,
            upstream_fragment_dispatcher_ids,
            upstream_dispatcher_mapping,
            downstream_fragment_ids,
            newly_created_actors,
            actor_splits,
            cdc_table_snapshot_split_assignment: Default::default(),
            cdc_table_job_id: None,
        };

        Ok(reschedule)
    }

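    /// Applies a [`ReschedulePolicy`] per job: persists the new parallelism and/or
    /// resource group on the `streaming_job` rows, then re-renders the affected jobs
    /// against the given workers and returns the resulting reschedule commands,
    /// grouped by database.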
    pub async fn reschedule_inplace(
        &self,
        policy: HashMap<JobId, ReschedulePolicy>,
        workers: HashMap<WorkerId, PbWorkerNode>,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        let inner = self.metadata_manager.catalog_controller.inner.read().await;
        let txn = inner.db.begin().await?;

        for (table_id, target) in &policy {
            let streaming_job = StreamingJob::find_by_id(*table_id)
                .one(&txn)
                .await?
                .ok_or_else(|| MetaError::catalog_id_not_found("table", table_id))?;

            let max_parallelism = streaming_job.max_parallelism;

            let mut streaming_job = streaming_job.into_active_model();

            match &target {
                ReschedulePolicy::Parallelism(p) | ReschedulePolicy::Both(p, _) => {
                    if let StreamingParallelism::Fixed(n) = p.parallelism
                        && n > max_parallelism as usize
                    {
                        bail!(format!(
                            "specified parallelism {n} should not exceed max parallelism {max_parallelism}"
                        ));
                    }

                    streaming_job.parallelism = Set(p.parallelism.clone());
                }
                _ => {}
            }

            match &target {
                ReschedulePolicy::ResourceGroup(r) | ReschedulePolicy::Both(_, r) => {
                    streaming_job.specific_resource_group = Set(r.resource_group.clone());
                }
                _ => {}
            }

            streaming_job.update(&txn).await?;
        }

        let jobs = policy.keys().copied().collect();

        let workers = workers
            .into_iter()
            .map(|(id, worker)| {
                (
                    id,
                    WorkerInfo {
                        parallelism: NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
                        resource_group: worker.resource_group(),
                    },
                )
            })
            .collect();

        let command = self.rerender_inner(&txn, jobs, workers).await?;

        txn.commit().await?;

        Ok(command)
    }

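    /// Fragment-level variant of [`Self::reschedule_inplace`]: sets a parallelism
    /// override (or clears it with `None`) on the entry fragments of each affected
    /// no-shuffle ensemble, then re-renders only those ensembles.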
    pub async fn reschedule_fragment_inplace(
        &self,
        policy: HashMap<risingwave_meta_model::FragmentId, Option<StreamingParallelism>>,
        workers: HashMap<WorkerId, PbWorkerNode>,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        if policy.is_empty() {
            return Ok(HashMap::new());
        }

        let inner = self.metadata_manager.catalog_controller.inner.read().await;
        let txn = inner.db.begin().await?;

        let fragment_id_list = policy.keys().copied().collect_vec();

        let existing_fragment_ids: HashSet<_> = Fragment::find()
            .select_only()
            .column(fragment::Column::FragmentId)
            .filter(fragment::Column::FragmentId.is_in(fragment_id_list.clone()))
            .into_tuple::<risingwave_meta_model::FragmentId>()
            .all(&txn)
            .await?
            .into_iter()
            .collect();

        if let Some(missing_fragment_id) = fragment_id_list
            .iter()
            .find(|fragment_id| !existing_fragment_ids.contains(fragment_id))
        {
            return Err(MetaError::catalog_id_not_found(
                "fragment",
                *missing_fragment_id,
            ));
        }

        let mut target_ensembles = vec![];

        for ensemble in find_fragment_no_shuffle_dags_detailed(&txn, &fragment_id_list).await? {
            let entry_fragment_ids = ensemble.entry_fragments().collect_vec();

            let desired_parallelism = match entry_fragment_ids
                .iter()
                .filter_map(|fragment_id| policy.get(fragment_id).cloned())
                .dedup()
                .collect_vec()
                .as_slice()
            {
                [] => {
                    bail_invalid_parameter!(
                        "none of the entry fragments {:?} were included in the reschedule request; \
                         provide at least one entry fragment id",
                        entry_fragment_ids
                    );
                }
                [parallelism] => parallelism.clone(),
                parallelisms => {
                    bail!(
                        "conflicting reschedule policies for fragments in the same no-shuffle ensemble: {:?}",
                        parallelisms
                    );
                }
            };

            let fragments = Fragment::find()
                .filter(fragment::Column::FragmentId.is_in(entry_fragment_ids))
                .all(&txn)
                .await?;

            debug_assert!(
                fragments
                    .iter()
                    .map(|fragment| fragment.parallelism.as_ref())
                    .all_equal(),
                "entry fragments in the same ensemble should share the same parallelism"
            );

            let current_parallelism = fragments
                .first()
                .and_then(|fragment| fragment.parallelism.clone());

            if current_parallelism == desired_parallelism {
                continue;
            }

            for fragment in fragments {
                let mut fragment = fragment.into_active_model();
                fragment.parallelism = Set(desired_parallelism.clone());
                fragment.update(&txn).await?;
            }

            target_ensembles.push(ensemble);
        }

        let workers = workers
            .into_iter()
            .map(|(id, worker)| {
                (
                    id,
                    WorkerInfo {
                        parallelism: NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
                        resource_group: worker.resource_group(),
                    },
                )
            })
            .collect();

        let command = self
            .rerender_fragment_inner(&txn, target_ensembles, workers)
            .await?;

        txn.commit().await?;

        Ok(command)
    }

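    /// Re-renders the given jobs against the provided workers using the catalog's
    /// read connection and builds the corresponding reschedule commands.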
    async fn rerender(
        &self,
        jobs: HashSet<JobId>,
        workers: BTreeMap<WorkerId, WorkerInfo>,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        let inner = self.metadata_manager.catalog_controller.inner.read().await;
        self.rerender_inner(&inner.db, jobs, workers).await
    }

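    /// Renders only the fragments of the given no-shuffle ensembles (respecting the
    /// current adaptive parallelism strategy) and turns the result into commands.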
    async fn rerender_fragment_inner(
        &self,
        txn: &impl ConnectionTrait,
        ensembles: Vec<NoShuffleEnsemble>,
        workers: BTreeMap<WorkerId, WorkerInfo>,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        if ensembles.is_empty() {
            return Ok(HashMap::new());
        }

        let adaptive_parallelism_strategy = {
            let system_params_reader = self.env.system_params_reader().await;
            system_params_reader.adaptive_parallelism_strategy()
        };

        let RenderedGraph { fragments, .. } = render_fragments(
            txn,
            self.env.actor_id_generator(),
            ensembles,
            workers,
            adaptive_parallelism_strategy,
        )
        .await?;

        self.build_reschedule_commands(txn, fragments).await
    }

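    /// Job-level counterpart of [`Self::rerender_fragment_inner`]: renders all
    /// fragments of the given jobs and turns the result into commands.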
    async fn rerender_inner(
        &self,
        txn: &impl ConnectionTrait,
        jobs: HashSet<JobId>,
        workers: BTreeMap<WorkerId, WorkerInfo>,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        let adaptive_parallelism_strategy = {
            let system_params_reader = self.env.system_params_reader().await;
            system_params_reader.adaptive_parallelism_strategy()
        };

        let RenderedGraph { fragments, .. } = render_jobs(
            txn,
            self.env.actor_id_generator(),
            jobs,
            workers,
            adaptive_parallelism_strategy,
        )
        .await?;

        self.build_reschedule_commands(txn, fragments).await
    }

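    /// Converts a rendered fragment map into per-database [`Command::RescheduleFragment`]s.
    /// For each rendered fragment it loads the up/downstream relations, recomputes the
    /// dispatchers towards its downstreams, diffs the new layout against the previous
    /// in-memory state, and (for CDC scan fragments) reassigns snapshot splits.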
    async fn build_reschedule_commands(
        &self,
        txn: &impl ConnectionTrait,
        render_result: FragmentRenderMap,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        if render_result.is_empty() {
            return Ok(HashMap::new());
        }

        let job_ids = render_result
            .values()
            .flat_map(|jobs| jobs.keys().copied())
            .collect_vec();

        let job_extra_info = get_streaming_job_extra_info(txn, job_ids).await?;

        let fragment_ids = render_result
            .values()
            .flat_map(|jobs| jobs.values())
            .flatten()
            .map(|(fragment_id, _)| *fragment_id)
            .collect_vec();

        let upstreams: Vec<(
            risingwave_meta_model::FragmentId,
            risingwave_meta_model::FragmentId,
            DispatcherType,
        )> = FragmentRelation::find()
            .select_only()
            .columns([
                fragment_relation::Column::TargetFragmentId,
                fragment_relation::Column::SourceFragmentId,
                fragment_relation::Column::DispatcherType,
            ])
            .filter(fragment_relation::Column::TargetFragmentId.is_in(fragment_ids.clone()))
            .into_tuple()
            .all(txn)
            .await?;

        let downstreams = FragmentRelation::find()
            .filter(fragment_relation::Column::SourceFragmentId.is_in(fragment_ids.clone()))
            .all(txn)
            .await?;

        let mut all_upstream_fragments = HashMap::new();

        for (fragment, upstream, dispatcher) in upstreams {
            let fragment_id = fragment as FragmentId;
            let upstream_id = upstream as FragmentId;
            all_upstream_fragments
                .entry(fragment_id)
                .or_insert(HashMap::new())
                .insert(upstream_id, dispatcher);
        }

        let mut all_downstream_fragments = HashMap::new();

        let mut downstream_relations = HashMap::new();
        for relation in downstreams {
            let source_fragment_id = relation.source_fragment_id as FragmentId;
            let target_fragment_id = relation.target_fragment_id as FragmentId;
            all_downstream_fragments
                .entry(source_fragment_id)
                .or_insert(HashMap::new())
                .insert(target_fragment_id, relation.dispatcher_type);

            downstream_relations.insert((source_fragment_id, target_fragment_id), relation);
        }

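        // Gather every fragment touched by the reschedule, including external up/downstream
        // fragments, so their previous actor layouts can be looked up from the shared state.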
        let all_related_fragment_ids: HashSet<_> = fragment_ids
            .iter()
            .copied()
            .chain(all_upstream_fragments.values().flatten().map(|(id, _)| *id))
            .chain(
                all_downstream_fragments
                    .values()
                    .flatten()
                    .map(|(id, _)| *id),
            )
            .collect();

        let all_related_fragment_ids = all_related_fragment_ids.into_iter().collect_vec();

        // let all_fragments_from_db: HashMap<_, _> = Fragment::find()
        //     .filter(fragment::Column::FragmentId.is_in(all_related_fragment_ids.clone()))
        //     .all(&txn)
        //     .await?
        //     .into_iter()
        //     .map(|f| (f.fragment_id, f))
        //     .collect();

        let all_prev_fragments: HashMap<_, _> = {
            let read_guard = self.env.shared_actor_infos().read_guard();
            all_related_fragment_ids
                .iter()
                .map(|&fragment_id| {
                    read_guard
                        .get_fragment(fragment_id as FragmentId)
                        .cloned()
                        .map(|fragment| (fragment_id, fragment))
                        .ok_or_else(|| {
                            MetaError::from(anyhow!(
                                "previous fragment info for {fragment_id} not found"
                            ))
                        })
                })
                .collect::<MetaResult<_>>()?
        };

        let all_rendered_fragments: HashMap<_, _> = render_result
            .values()
            .flat_map(|jobs| jobs.values())
            .flatten()
            .map(|(fragment_id, info)| (*fragment_id, info))
            .collect();

        let mut commands = HashMap::new();

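        // Build one RescheduleFragment command per database, covering all of its
        // rendered fragments.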
        for (database_id, jobs) in &render_result {
            let mut all_fragment_actors = HashMap::new();
            let mut reschedules = HashMap::new();

            for (job_id, fragment_id, fragment_info) in
                jobs.iter().flat_map(|(job_id, fragments)| {
                    fragments
                        .iter()
                        .map(move |(fragment_id, info)| (job_id, fragment_id, info))
                })
            {
                let InflightFragmentInfo {
                    distribution_type,
                    actors,
                    ..
                } = fragment_info;

                let upstream_fragments = all_upstream_fragments
                    .remove(&(*fragment_id as FragmentId))
                    .unwrap_or_default();
                let downstream_fragments = all_downstream_fragments
                    .remove(&(*fragment_id as FragmentId))
                    .unwrap_or_default();

                let fragment_actors: HashMap<_, _> = upstream_fragments
                    .keys()
                    .copied()
                    .chain(downstream_fragments.keys().copied())
                    .map(|fragment_id| {
                        all_prev_fragments
                            .get(&fragment_id)
                            .map(|fragment| {
                                (
                                    fragment_id,
                                    fragment.actors.keys().copied().collect::<HashSet<_>>(),
                                )
                            })
                            .ok_or_else(|| {
                                MetaError::from(anyhow!(
                                    "fragment {} not found in previous state",
                                    fragment_id
                                ))
                            })
                    })
                    .collect::<MetaResult<_>>()?;

                all_fragment_actors.extend(fragment_actors);

                let source_fragment_actors = actors
                    .iter()
                    .map(|(actor_id, info)| (*actor_id, info.vnode_bitmap.clone()))
                    .collect();

                let mut all_actor_dispatchers: HashMap<_, Vec<_>> = HashMap::new();

                for downstream_fragment_id in downstream_fragments.keys() {
                    let target_fragment_actors =
                        match all_rendered_fragments.get(downstream_fragment_id) {
                            None => {
                                let external_fragment = all_prev_fragments
                                    .get(downstream_fragment_id)
                                    .ok_or_else(|| {
                                        MetaError::from(anyhow!(
                                            "fragment {} not found in previous state",
                                            downstream_fragment_id
                                        ))
                                    })?;

                                external_fragment
                                    .actors
                                    .iter()
                                    .map(|(actor_id, info)| (*actor_id, info.vnode_bitmap.clone()))
                                    .collect()
                            }
                            Some(downstream_rendered) => downstream_rendered
                                .actors
                                .iter()
                                .map(|(actor_id, info)| (*actor_id, info.vnode_bitmap.clone()))
                                .collect(),
                        };

                    let target_fragment_distribution = *distribution_type;

                    let fragment_relation::Model {
                        source_fragment_id: _,
                        target_fragment_id: _,
                        dispatcher_type,
                        dist_key_indices,
                        output_indices,
                        output_type_mapping,
                    } = downstream_relations
                        .remove(&(
                            *fragment_id as FragmentId,
                            *downstream_fragment_id as FragmentId,
                        ))
                        .ok_or_else(|| {
                            MetaError::from(anyhow!(
                                "downstream relation missing for {} -> {}",
                                fragment_id,
                                downstream_fragment_id
                            ))
                        })?;

                    let pb_mapping = PbDispatchOutputMapping {
                        indices: output_indices.into_u32_array(),
                        types: output_type_mapping.unwrap_or_default().to_protobuf(),
                    };

                    let dispatchers = compose_dispatchers(
                        *distribution_type,
                        &source_fragment_actors,
                        *downstream_fragment_id,
                        target_fragment_distribution,
                        &target_fragment_actors,
                        dispatcher_type,
                        dist_key_indices.into_u32_array(),
                        pb_mapping,
                    );

                    for (actor_id, dispatcher) in dispatchers {
                        all_actor_dispatchers
                            .entry(actor_id)
                            .or_default()
                            .push(dispatcher);
                    }
                }

                let prev_fragment = all_prev_fragments.get(&{ *fragment_id }).ok_or_else(|| {
                    MetaError::from(anyhow!(
                        "fragment {} not found in previous state",
                        fragment_id
                    ))
                })?;

                let mut reschedule = self.diff_fragment(
                    prev_fragment,
                    actors,
                    upstream_fragments,
                    downstream_fragments,
                    all_actor_dispatchers,
                    job_extra_info.get(job_id),
                )?;

                // CDC snapshot splits are only assigned at this stage, so the reschedule's CDC fields must still be empty here.
                debug_assert!(reschedule.cdc_table_job_id.is_none());
                debug_assert!(reschedule.cdc_table_snapshot_split_assignment.is_empty());
                let cdc_info = if fragment_info
                    .fragment_type_mask
                    .contains(FragmentTypeFlag::StreamCdcScan)
                {
                    let assignment = assign_cdc_table_snapshot_splits_impl(
                        *job_id,
                        actors.keys().copied().collect(),
                        self.env.meta_store_ref(),
                        None,
                    )
                    .await?;
                    Some((job_id, assignment))
                } else {
                    None
                };

                if let Some((cdc_table_id, cdc_table_snapshot_split_assignment)) = cdc_info {
                    reschedule.cdc_table_job_id = Some(*cdc_table_id);
                    reschedule.cdc_table_snapshot_split_assignment =
                        cdc_table_snapshot_split_assignment;
                }

                reschedules.insert(*fragment_id as FragmentId, reschedule);
            }

            let command = Command::RescheduleFragment {
                reschedules,
                fragment_actors: all_fragment_actors,
            };

            commands.insert(*database_id, command);
        }

        Ok(commands)
    }
}

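/// Target parallelism for a streaming job (e.g. a fixed number of parallel units
/// or adaptive), as encoded by [`StreamingParallelism`].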
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParallelismPolicy {
    pub parallelism: StreamingParallelism,
}

#[derive(Clone, Debug)]
pub struct ResourceGroupPolicy {
    pub resource_group: Option<String>,
}

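/// What to change for a job during a reschedule: its parallelism, its resource
/// group, or both at once. A minimal sketch, with hypothetical values:
///
/// ```ignore
/// let policy = ReschedulePolicy::Parallelism(ParallelismPolicy {
///     parallelism: StreamingParallelism::Fixed(4),
/// });
/// ```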
#[derive(Clone, Debug)]
pub enum ReschedulePolicy {
    Parallelism(ParallelismPolicy),
    ResourceGroup(ResourceGroupPolicy),
    Both(ParallelismPolicy, ResourceGroupPolicy),
}

impl GlobalStreamManager {
    #[await_tree::instrument("acquire_reschedule_read_guard")]
    pub async fn reschedule_lock_read_guard(&self) -> RwLockReadGuard<'_, ()> {
        self.scale_controller.reschedule_lock.read().await
    }

    #[await_tree::instrument("acquire_reschedule_write_guard")]
    pub async fn reschedule_lock_write_guard(&self) -> RwLockWriteGuard<'_, ()> {
        self.scale_controller.reschedule_lock.write().await
    }

    /// When new worker nodes join, or the parallelism of existing worker nodes changes,
    /// examines whether any jobs can be scaled, and scales them if so.
    ///
    /// This method iterates over all `CREATED` jobs and can be called repeatedly.
    ///
    /// Returns
    /// - `Ok(false)` if no jobs can be scaled;
    /// - `Ok(true)` if some jobs were scaled and there may be more jobs that can still be scaled.
    async fn trigger_parallelism_control(&self) -> MetaResult<bool> {
        tracing::info!("trigger parallelism control");

        let _reschedule_job_lock = self.reschedule_lock_write_guard().await;

        let background_streaming_jobs = self
            .metadata_manager
            .list_background_creating_jobs()
            .await?;

        let skipped_jobs = if !background_streaming_jobs.is_empty() {
            let jobs = self
                .scale_controller
                .resolve_related_no_shuffle_jobs(&background_streaming_jobs)
                .await?;

            tracing::info!(
                "skipping parallelism control of background jobs {:?} and associated jobs {:?}",
                background_streaming_jobs,
                jobs
            );

            jobs
        } else {
            HashSet::new()
        };

        let database_objects: HashMap<risingwave_meta_model::DatabaseId, Vec<JobId>> = self
            .metadata_manager
            .catalog_controller
            .list_streaming_job_with_database()
            .await?;

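        // Interleave job ids across databases (sorting by per-database position first),
        // so each batch below touches jobs from multiple databases, and drop jobs that
        // are still being created in the background.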
        let job_ids = database_objects
            .iter()
            .flat_map(|(database_id, job_ids)| {
                job_ids
                    .iter()
                    .enumerate()
                    .map(move |(idx, job_id)| (idx, database_id, job_id))
            })
            .sorted_by(|(idx_a, database_a, _), (idx_b, database_b, _)| {
                idx_a.cmp(idx_b).then(database_a.cmp(database_b))
            })
            .map(|(_, database_id, job_id)| (*database_id, *job_id))
            .filter(|(_, job_id)| !skipped_jobs.contains(job_id))
            .collect_vec();

        if job_ids.is_empty() {
            tracing::info!("no streaming jobs for scaling, maybe an empty cluster");
            return Ok(false);
        }

        let workers = self
            .metadata_manager
            .cluster_controller
            .list_active_streaming_workers()
            .await?;

        let schedulable_workers: BTreeMap<_, _> = workers
            .iter()
            .filter(|worker| {
                !worker
                    .property
                    .as_ref()
                    .map(|p| p.is_unschedulable)
                    .unwrap_or(false)
            })
            .map(|worker| {
                (
                    worker.id,
                    WorkerInfo {
                        parallelism: NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
                        resource_group: worker.resource_group(),
                    },
                )
            })
            .collect();

        tracing::info!(
            "trigger parallelism control for jobs: {:#?}, workers {:#?}",
            job_ids,
            schedulable_workers
        );

        let batch_size = match self.env.opts.parallelism_control_batch_size {
            0 => job_ids.len(),
            n => n,
        };

        tracing::info!(
            "total {} streaming jobs, batch size {}, schedulable worker ids: {:?}",
            job_ids.len(),
            batch_size,
            schedulable_workers
        );

        let batches: Vec<_> = job_ids
            .into_iter()
            .chunks(batch_size)
            .into_iter()
            .map(|chunk| chunk.collect_vec())
            .collect();

        for batch in batches {
            let jobs = batch.iter().map(|(_, job_id)| *job_id).collect();

            let commands = self
                .scale_controller
                .rerender(jobs, schedulable_workers.clone())
                .await?;

            let futures = commands.into_iter().map(|(database_id, command)| {
                let barrier_scheduler = self.barrier_scheduler.clone();
                async move { barrier_scheduler.run_command(database_id, command).await }
            });

            let _results = future::try_join_all(futures).await?;
        }

        Ok(false)
    }

    /// Handles notification of worker node activation and deletion, and triggers parallelism control.
    async fn run(&self, mut shutdown_rx: Receiver<()>) {
        tracing::info!("starting automatic parallelism control monitor");

        let check_period =
            Duration::from_secs(self.env.opts.parallelism_control_trigger_period_sec);

        let mut ticker = tokio::time::interval_at(
            Instant::now()
                + Duration::from_secs(self.env.opts.parallelism_control_trigger_first_delay_sec),
            check_period,
        );
        ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);

        // waiting for the first tick
        ticker.tick().await;

        let (local_notification_tx, mut local_notification_rx) =
            tokio::sync::mpsc::unbounded_channel();

        self.env
            .notification_manager()
            .insert_local_sender(local_notification_tx);

        let worker_nodes = self
            .metadata_manager
            .list_active_streaming_compute_nodes()
            .await
            .expect("list active streaming compute nodes");

        let mut worker_cache: BTreeMap<_, _> = worker_nodes
            .into_iter()
            .map(|worker| (worker.id, worker))
            .collect();

        let mut previous_adaptive_parallelism_strategy = AdaptiveParallelismStrategy::default();

        let mut should_trigger = false;

        loop {
            tokio::select! {
                biased;

                _ = &mut shutdown_rx => {
                    tracing::info!("Stream manager is stopped");
                    break;
                }

                _ = ticker.tick(), if should_trigger => {
                    let include_workers = worker_cache.keys().copied().collect_vec();

                    if include_workers.is_empty() {
                        tracing::debug!("no available worker nodes");
                        should_trigger = false;
                        continue;
                    }

                    match self.trigger_parallelism_control().await {
                        Ok(cont) => {
                            should_trigger = cont;
                        }
                        Err(e) => {
                            tracing::warn!(error = %e.as_report(), "Failed to trigger scale out, waiting for next tick to retry after {}s", ticker.period().as_secs());
                            ticker.reset();
                        }
                    }
                }

                notification = local_notification_rx.recv() => {
                    let notification = notification.expect("local notification channel closed in loop of stream manager");

                    // Only maintain the cache for streaming compute nodes.
                    let worker_is_streaming_compute = |worker: &WorkerNode| {
                        worker.get_type() == Ok(WorkerType::ComputeNode)
                            && worker.property.as_ref().unwrap().is_streaming
                    };

                    match notification {
                        LocalNotification::SystemParamsChange(reader) => {
                            let new_strategy = reader.adaptive_parallelism_strategy();
                            if new_strategy != previous_adaptive_parallelism_strategy {
                                tracing::info!("adaptive parallelism strategy changed from {:?} to {:?}", previous_adaptive_parallelism_strategy, new_strategy);
                                should_trigger = true;
                                previous_adaptive_parallelism_strategy = new_strategy;
                            }
                        }
                        LocalNotification::WorkerNodeActivated(worker) => {
                            if !worker_is_streaming_compute(&worker) {
                                continue;
                            }

                            tracing::info!(worker = %worker.id, "worker activated notification received");

                            let prev_worker = worker_cache.insert(worker.id, worker.clone());

                            match prev_worker {
                                Some(prev_worker) if prev_worker.compute_node_parallelism() != worker.compute_node_parallelism() => {
                                    tracing::info!(worker = %worker.id, "worker parallelism changed");
                                    should_trigger = true;
                                }
                                Some(prev_worker) if prev_worker.resource_group() != worker.resource_group() => {
                                    tracing::info!(worker = %worker.id, "worker label changed");
                                    should_trigger = true;
                                }
                                None => {
                                    tracing::info!(worker = %worker.id, "new worker joined");
                                    should_trigger = true;
                                }
                                _ => {}
                            }
                        }

                        // Since our logic for handling passive scale-in is within the barrier manager,
                        // there’s not much we can do here. All we can do is proactively remove the entries from our cache.
                        LocalNotification::WorkerNodeDeleted(worker) => {
                            if !worker_is_streaming_compute(&worker) {
                                continue;
                            }

                            match worker_cache.remove(&worker.id) {
                                Some(prev_worker) => {
                                    tracing::info!(worker = %prev_worker.id, "worker removed from stream manager cache");
                                }
                                None => {
                                    tracing::warn!(worker = %worker.id, "worker not found in stream manager cache when handling its deletion");
                                }
                            }
                        }

                        _ => {}
                    }
                }
            }
        }
    }

    pub fn start_auto_parallelism_monitor(
        self: Arc<Self>,
    ) -> (JoinHandle<()>, oneshot::Sender<()>) {
        tracing::info!("Automatic parallelism scale-out is enabled for streaming jobs");
        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
        let join_handle = tokio::spawn(async move {
            self.run(shutdown_rx).await;
        });

        (join_handle, shutdown_tx)
    }
}