risingwave_meta/stream/scale.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashMap, HashSet};
use std::fmt::Debug;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::Duration;

use anyhow::anyhow;
use futures::future;
use itertools::Itertools;
use risingwave_common::bail;
use risingwave_common::catalog::{DatabaseId, FragmentTypeFlag};
use risingwave_common::hash::ActorMapping;
use risingwave_meta_model::{StreamingParallelism, WorkerId, fragment, fragment_relation};
use risingwave_pb::common::{PbWorkerNode, WorkerNode, WorkerType};
use risingwave_pb::stream_plan::{PbDispatchOutputMapping, PbDispatcher};
use sea_orm::{ActiveModelTrait, ConnectionTrait, QuerySelect};
use thiserror_ext::AsReport;
use tokio::sync::oneshot::Receiver;
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, oneshot};
use tokio::task::JoinHandle;
use tokio::time::{Instant, MissedTickBehavior};

use crate::barrier::{Command, Reschedule, SharedFragmentInfo};
use crate::controller::scale::{
    FragmentRenderMap, NoShuffleEnsemble, RenderedGraph, WorkerInfo,
    find_fragment_no_shuffle_dags_detailed, render_fragments, render_jobs,
};
use crate::error::bail_invalid_parameter;
use crate::manager::{LocalNotification, MetaSrvEnv, MetadataManager};
use crate::model::{
    ActorId, DispatcherId, FragmentId, StreamActor, StreamActorWithDispatchers, StreamContext,
};
use crate::stream::{GlobalStreamManager, SourceManagerRef};
use crate::{MetaError, MetaResult};

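/// A per-worker actor count delta for a reschedule: positive values add actors on the worker,
/// negative values remove them.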
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct WorkerReschedule {
    pub worker_actor_diff: BTreeMap<WorkerId, isize>,
}

use risingwave_common::id::JobId;
use risingwave_common::system_param::AdaptiveParallelismStrategy;
use risingwave_common::system_param::reader::SystemParamsRead;
use risingwave_meta_model::DispatcherType;
use risingwave_meta_model::fragment::DistributionType;
use risingwave_meta_model::prelude::{Fragment, FragmentRelation, StreamingJob};
use sea_orm::ActiveValue::Set;
use sea_orm::{ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter, TransactionTrait};

use crate::controller::fragment::{InflightActorInfo, InflightFragmentInfo};
use crate::controller::utils::{
    StreamingJobExtraInfo, compose_dispatchers, get_streaming_job_extra_info,
};
use crate::stream::cdc::assign_cdc_table_snapshot_splits_impl;

pub type ScaleControllerRef = Arc<ScaleController>;

pub struct ScaleController {
    pub metadata_manager: MetadataManager,

    pub source_manager: SourceManagerRef,

    pub env: MetaSrvEnv,

    /// Acquired during DDL to prevent scaling operations on jobs that are still being created,
    /// e.g., an MV cannot be rescheduled during foreground backfill.
    pub reschedule_lock: RwLock<()>,
}

impl ScaleController {
    pub fn new(
        metadata_manager: &MetadataManager,
        source_manager: SourceManagerRef,
        env: MetaSrvEnv,
    ) -> Self {
        Self {
            metadata_manager: metadata_manager.clone(),
            source_manager,
            env,
            reschedule_lock: RwLock::new(()),
        }
    }

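    /// Expands the given job set to all jobs whose fragments belong to the same no-shuffle
    /// ensembles, so that no-shuffle-connected jobs are always considered together.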
    pub async fn resolve_related_no_shuffle_jobs(
        &self,
        jobs: &[JobId],
    ) -> MetaResult<HashSet<JobId>> {
        let inner = self.metadata_manager.catalog_controller.inner.read().await;
        let txn = inner.db.begin().await?;

        let fragment_ids: Vec<_> = Fragment::find()
            .select_only()
            .column(fragment::Column::FragmentId)
            .filter(fragment::Column::JobId.is_in(jobs.to_vec()))
            .into_tuple()
            .all(&txn)
            .await?;
        let ensembles = find_fragment_no_shuffle_dags_detailed(&txn, &fragment_ids).await?;
        let related_fragments = ensembles
            .iter()
            .flat_map(|ensemble| ensemble.fragments())
            .collect_vec();

        let job_ids: Vec<_> = Fragment::find()
            .select_only()
            .column(fragment::Column::JobId)
            .filter(fragment::Column::FragmentId.is_in(related_fragments))
            .into_tuple()
            .all(&txn)
            .await?;

        let job_ids = job_ids.into_iter().collect();

        Ok(job_ids)
    }

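    /// Computes the [`Reschedule`] delta between a fragment's previous actor layout and its
    /// newly rendered actors: actors added/removed per worker, changed vnode bitmaps, the new
    /// hash dispatcher mapping (if any), and the affected upstream/downstream fragments.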
    pub fn diff_fragment(
        &self,
        prev_fragment_info: &SharedFragmentInfo,
        curr_actors: &HashMap<ActorId, InflightActorInfo>,
        upstream_fragments: HashMap<FragmentId, DispatcherType>,
        downstream_fragments: HashMap<FragmentId, DispatcherType>,
        all_actor_dispatchers: HashMap<ActorId, Vec<PbDispatcher>>,
        job_extra_info: Option<&StreamingJobExtraInfo>,
    ) -> MetaResult<Reschedule> {
        let prev_actors: HashMap<_, _> = prev_fragment_info
            .actors
            .iter()
            .map(|(actor_id, actor)| (*actor_id, actor))
            .collect();

        let prev_ids: HashSet<_> = prev_actors.keys().cloned().collect();
        let curr_ids: HashSet<_> = curr_actors.keys().cloned().collect();

        let removed_actors: HashSet<_> = &prev_ids - &curr_ids;
        let added_actor_ids: HashSet<_> = &curr_ids - &prev_ids;
        let kept_ids: HashSet<_> = prev_ids.intersection(&curr_ids).cloned().collect();

        let mut added_actors = HashMap::new();
        for &actor_id in &added_actor_ids {
            let InflightActorInfo { worker_id, .. } = curr_actors
                .get(&actor_id)
                .ok_or_else(|| anyhow!("BUG: Worker not found for new actor {}", actor_id))?;

            added_actors
                .entry(*worker_id)
                .or_insert_with(Vec::new)
                .push(actor_id);
        }

        let mut vnode_bitmap_updates = HashMap::new();
        for actor_id in kept_ids {
            let prev_actor = prev_actors[&actor_id];
            let curr_actor = &curr_actors[&actor_id];

            // Check if the vnode distribution has changed.
            if prev_actor.vnode_bitmap != curr_actor.vnode_bitmap
                && let Some(bitmap) = curr_actor.vnode_bitmap.clone()
            {
                vnode_bitmap_updates.insert(actor_id, bitmap);
            }
        }

        let upstream_dispatcher_mapping =
            if let DistributionType::Hash = prev_fragment_info.distribution_type {
                let actor_mapping = curr_actors
                    .iter()
                    .map(
                        |(
                            actor_id,
                            InflightActorInfo {
                                worker_id: _,
                                vnode_bitmap,
                                ..
                            },
                        )| { (*actor_id, vnode_bitmap.clone().unwrap()) },
                    )
                    .collect();
                Some(ActorMapping::from_bitmaps(&actor_mapping))
            } else {
                None
            };

        let upstream_fragment_dispatcher_ids = upstream_fragments
            .iter()
            .filter(|&(_, dispatcher_type)| *dispatcher_type != DispatcherType::NoShuffle)
            .map(|(upstream_fragment, _)| {
                (
                    *upstream_fragment,
                    prev_fragment_info.fragment_id.as_raw_id() as DispatcherId,
                )
            })
            .collect();

        let downstream_fragment_ids = downstream_fragments
            .iter()
            .filter(|&(_, dispatcher_type)| *dispatcher_type != DispatcherType::NoShuffle)
            .map(|(fragment_id, _)| *fragment_id)
            .collect();

        let extra_info = job_extra_info.cloned().unwrap_or_default();
        let timezone = extra_info.timezone.clone();
        let job_definition = extra_info.job_definition;

        let newly_created_actors: HashMap<ActorId, (StreamActorWithDispatchers, WorkerId)> =
            added_actor_ids
                .iter()
                .map(|actor_id| {
                    let actor = StreamActor {
                        actor_id: *actor_id,
                        fragment_id: prev_fragment_info.fragment_id,
                        vnode_bitmap: curr_actors[actor_id].vnode_bitmap.clone(),
                        mview_definition: job_definition.clone(),
                        expr_context: Some(
                            StreamContext {
                                timezone: timezone.clone(),
                            }
                            .to_expr_context(),
                        ),
                    };
                    (
                        *actor_id,
                        (
                            (
                                actor,
                                all_actor_dispatchers
                                    .get(actor_id)
                                    .cloned()
                                    .unwrap_or_default(),
                            ),
                            curr_actors[actor_id].worker_id,
                        ),
                    )
                })
                .collect();

        let actor_splits = curr_actors
            .iter()
            .map(|(&actor_id, info)| (actor_id, info.splits.clone()))
            .collect();

        let reschedule = Reschedule {
            added_actors,
            removed_actors,
            vnode_bitmap_updates,
            upstream_fragment_dispatcher_ids,
            upstream_dispatcher_mapping,
            downstream_fragment_ids,
            newly_created_actors,
            actor_splits,
            cdc_table_snapshot_split_assignment: Default::default(),
            cdc_table_job_id: None,
        };

        Ok(reschedule)
    }

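    /// Persists the given per-job reschedule policies (parallelism and/or resource group) to
    /// the catalog, then re-renders the affected jobs into per-database reschedule commands.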
    pub async fn reschedule_inplace(
        &self,
        policy: HashMap<JobId, ReschedulePolicy>,
        workers: HashMap<WorkerId, PbWorkerNode>,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        let inner = self.metadata_manager.catalog_controller.inner.read().await;
        let txn = inner.db.begin().await?;

        for (table_id, target) in &policy {
            let streaming_job = StreamingJob::find_by_id(*table_id)
                .one(&txn)
                .await?
                .ok_or_else(|| MetaError::catalog_id_not_found("table", table_id))?;

            let max_parallelism = streaming_job.max_parallelism;

            let mut streaming_job = streaming_job.into_active_model();

            match &target {
                ReschedulePolicy::Parallelism(p) | ReschedulePolicy::Both(p, _) => {
                    if let StreamingParallelism::Fixed(n) = p.parallelism
                        && n > max_parallelism as usize
                    {
                        bail!(format!(
                            "specified parallelism {n} should not exceed max parallelism {max_parallelism}"
                        ));
                    }

                    streaming_job.parallelism = Set(p.parallelism.clone());
                }
                _ => {}
            }

            match &target {
                ReschedulePolicy::ResourceGroup(r) | ReschedulePolicy::Both(_, r) => {
                    streaming_job.specific_resource_group = Set(r.resource_group.clone());
                }
                _ => {}
            }

            streaming_job.update(&txn).await?;
        }

        let jobs = policy.keys().copied().collect();

        let workers = workers
            .into_iter()
            .map(|(id, worker)| {
                (
                    id,
                    WorkerInfo {
                        parallelism: NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
                        resource_group: worker.resource_group(),
                    },
                )
            })
            .collect();

        let command = self.rerender_inner(&txn, jobs, workers).await?;

        txn.commit().await?;

        Ok(command)
    }

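    /// Overrides parallelism at the fragment level. For each no-shuffle ensemble touched by the
    /// request, the entry fragments must be given a single consistent parallelism; only
    /// ensembles whose parallelism actually changes are persisted and re-rendered.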
    pub async fn reschedule_fragment_inplace(
        &self,
        policy: HashMap<risingwave_meta_model::FragmentId, Option<StreamingParallelism>>,
        workers: HashMap<WorkerId, PbWorkerNode>,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        if policy.is_empty() {
            return Ok(HashMap::new());
        }

        let inner = self.metadata_manager.catalog_controller.inner.read().await;
        let txn = inner.db.begin().await?;

        let fragment_id_list = policy.keys().copied().collect_vec();

        let existing_fragment_ids: HashSet<_> = Fragment::find()
            .select_only()
            .column(fragment::Column::FragmentId)
            .filter(fragment::Column::FragmentId.is_in(fragment_id_list.clone()))
            .into_tuple::<risingwave_meta_model::FragmentId>()
            .all(&txn)
            .await?
            .into_iter()
            .collect();

        if let Some(missing_fragment_id) = fragment_id_list
            .iter()
            .find(|fragment_id| !existing_fragment_ids.contains(fragment_id))
        {
            return Err(MetaError::catalog_id_not_found(
                "fragment",
                *missing_fragment_id,
            ));
        }

        let mut target_ensembles = vec![];

        for ensemble in find_fragment_no_shuffle_dags_detailed(&txn, &fragment_id_list).await? {
            let entry_fragment_ids = ensemble.entry_fragments().collect_vec();

            let desired_parallelism = match entry_fragment_ids
                .iter()
                .filter_map(|fragment_id| policy.get(fragment_id).cloned())
                .dedup()
                .collect_vec()
                .as_slice()
            {
                [] => {
                    bail_invalid_parameter!(
                        "none of the entry fragments {:?} were included in the reschedule request; \
                         provide at least one entry fragment id",
                        entry_fragment_ids
                    );
                }
                [parallelism] => parallelism.clone(),
                parallelisms => {
                    bail!(
                        "conflicting reschedule policies for fragments in the same no-shuffle ensemble: {:?}",
                        parallelisms
                    );
                }
            };

            let fragments = Fragment::find()
                .filter(fragment::Column::FragmentId.is_in(entry_fragment_ids))
                .all(&txn)
                .await?;

            debug_assert!(
                fragments
                    .iter()
                    .map(|fragment| fragment.parallelism.as_ref())
                    .all_equal(),
                "entry fragments in the same ensemble should share the same parallelism"
            );

            let current_parallelism = fragments
                .first()
                .and_then(|fragment| fragment.parallelism.clone());

            if current_parallelism == desired_parallelism {
                continue;
            }

            for fragment in fragments {
                let mut fragment = fragment.into_active_model();
                fragment.parallelism = Set(desired_parallelism.clone());
                fragment.update(&txn).await?;
            }

            target_ensembles.push(ensemble);
        }

        let workers = workers
            .into_iter()
            .map(|(id, worker)| {
                (
                    id,
                    WorkerInfo {
                        parallelism: NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
                        resource_group: worker.resource_group(),
                    },
                )
            })
            .collect();

        let command = self
            .rerender_fragment_inner(&txn, target_ensembles, workers)
            .await?;

        txn.commit().await?;

        Ok(command)
    }

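    /// Re-renders the given jobs against the provided worker set using the catalog's database
    /// connection; see [`Self::rerender_inner`].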
    async fn rerender(
        &self,
        jobs: HashSet<JobId>,
        workers: BTreeMap<WorkerId, WorkerInfo>,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        let inner = self.metadata_manager.catalog_controller.inner.read().await;
        self.rerender_inner(&inner.db, jobs, workers).await
    }

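    /// Renders only the given no-shuffle ensembles under the current adaptive parallelism
    /// strategy and converts the result into reschedule commands.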
    async fn rerender_fragment_inner(
        &self,
        txn: &impl ConnectionTrait,
        ensembles: Vec<NoShuffleEnsemble>,
        workers: BTreeMap<WorkerId, WorkerInfo>,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        if ensembles.is_empty() {
            return Ok(HashMap::new());
        }

        let adaptive_parallelism_strategy = {
            let system_params_reader = self.env.system_params_reader().await;
            system_params_reader.adaptive_parallelism_strategy()
        };

        let RenderedGraph { fragments, .. } = render_fragments(
            txn,
            self.env.actor_id_generator(),
            ensembles,
            workers,
            adaptive_parallelism_strategy,
        )
        .await?;

        self.build_reschedule_commands(txn, fragments).await
    }

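    /// Renders all fragments of the given jobs under the current adaptive parallelism strategy
    /// and converts the result into reschedule commands.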
    async fn rerender_inner(
        &self,
        txn: &impl ConnectionTrait,
        jobs: HashSet<JobId>,
        workers: BTreeMap<WorkerId, WorkerInfo>,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        let adaptive_parallelism_strategy = {
            let system_params_reader = self.env.system_params_reader().await;
            system_params_reader.adaptive_parallelism_strategy()
        };

        let RenderedGraph { fragments, .. } = render_jobs(
            txn,
            self.env.actor_id_generator(),
            jobs,
            workers,
            adaptive_parallelism_strategy,
        )
        .await?;

        self.build_reschedule_commands(txn, fragments).await
    }

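    /// Turns a rendered fragment map into per-database [`Command::RescheduleFragment`]s: loads
    /// the fragment relations, diffs each rendered fragment against its previous in-memory
    /// state, recomputes dispatchers towards downstream fragments, and reassigns CDC table
    /// snapshot splits where applicable.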
    async fn build_reschedule_commands(
        &self,
        txn: &impl ConnectionTrait,
        render_result: FragmentRenderMap,
    ) -> MetaResult<HashMap<DatabaseId, Command>> {
        if render_result.is_empty() {
            return Ok(HashMap::new());
        }

        let job_ids = render_result
            .values()
            .flat_map(|jobs| jobs.keys().copied())
            .collect_vec();

        let job_extra_info = get_streaming_job_extra_info(txn, job_ids).await?;

        let fragment_ids = render_result
            .values()
            .flat_map(|jobs| jobs.values())
            .flatten()
            .map(|(fragment_id, _)| *fragment_id)
            .collect_vec();

        let upstreams: Vec<(
            risingwave_meta_model::FragmentId,
            risingwave_meta_model::FragmentId,
            DispatcherType,
        )> = FragmentRelation::find()
            .select_only()
            .columns([
                fragment_relation::Column::TargetFragmentId,
                fragment_relation::Column::SourceFragmentId,
                fragment_relation::Column::DispatcherType,
            ])
            .filter(fragment_relation::Column::TargetFragmentId.is_in(fragment_ids.clone()))
            .into_tuple()
            .all(txn)
            .await?;

        let downstreams = FragmentRelation::find()
            .filter(fragment_relation::Column::SourceFragmentId.is_in(fragment_ids.clone()))
            .all(txn)
            .await?;

        let mut all_upstream_fragments = HashMap::new();

        for (fragment, upstream, dispatcher) in upstreams {
            let fragment_id = fragment as FragmentId;
            let upstream_id = upstream as FragmentId;
            all_upstream_fragments
                .entry(fragment_id)
                .or_insert(HashMap::new())
                .insert(upstream_id, dispatcher);
        }

        let mut all_downstream_fragments = HashMap::new();

        let mut downstream_relations = HashMap::new();
        for relation in downstreams {
            let source_fragment_id = relation.source_fragment_id as FragmentId;
            let target_fragment_id = relation.target_fragment_id as FragmentId;
            all_downstream_fragments
                .entry(source_fragment_id)
                .or_insert(HashMap::new())
                .insert(target_fragment_id, relation.dispatcher_type);

            downstream_relations.insert((source_fragment_id, target_fragment_id), relation);
        }

        let all_related_fragment_ids: HashSet<_> = fragment_ids
            .iter()
            .copied()
            .chain(all_upstream_fragments.values().flatten().map(|(id, _)| *id))
            .chain(
                all_downstream_fragments
                    .values()
                    .flatten()
                    .map(|(id, _)| *id),
            )
            .collect();

        let all_related_fragment_ids = all_related_fragment_ids.into_iter().collect_vec();

        let all_prev_fragments: HashMap<_, _> = {
            let read_guard = self.env.shared_actor_infos().read_guard();
            all_related_fragment_ids
                .iter()
                .map(|&fragment_id| {
                    read_guard
                        .get_fragment(fragment_id as FragmentId)
                        .cloned()
                        .map(|fragment| (fragment_id, fragment))
                        .ok_or_else(|| {
                            MetaError::from(anyhow!(
                                "previous fragment info for {fragment_id} not found"
                            ))
                        })
                })
                .collect::<MetaResult<_>>()?
        };

        let all_rendered_fragments: HashMap<_, _> = render_result
            .values()
            .flat_map(|jobs| jobs.values())
            .flatten()
            .map(|(fragment_id, info)| (*fragment_id, info))
            .collect();

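        // Build one `Command::RescheduleFragment` per database, covering every fragment that
        // was re-rendered for that database.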
        let mut commands = HashMap::new();

        for (database_id, jobs) in &render_result {
            let mut all_fragment_actors = HashMap::new();
            let mut reschedules = HashMap::new();

            for (job_id, fragment_id, fragment_info) in
                jobs.iter().flat_map(|(job_id, fragments)| {
                    fragments
                        .iter()
                        .map(move |(fragment_id, info)| (job_id, fragment_id, info))
                })
            {
                let InflightFragmentInfo {
                    distribution_type,
                    actors,
                    ..
                } = fragment_info;

                let upstream_fragments = all_upstream_fragments
                    .remove(&(*fragment_id as FragmentId))
                    .unwrap_or_default();
                let downstream_fragments = all_downstream_fragments
                    .remove(&(*fragment_id as FragmentId))
                    .unwrap_or_default();

                let fragment_actors: HashMap<_, _> = upstream_fragments
                    .keys()
                    .copied()
                    .chain(downstream_fragments.keys().copied())
                    .map(|fragment_id| {
                        all_prev_fragments
                            .get(&fragment_id)
                            .map(|fragment| {
                                (
                                    fragment_id,
                                    fragment.actors.keys().copied().collect::<HashSet<_>>(),
                                )
                            })
                            .ok_or_else(|| {
                                MetaError::from(anyhow!(
                                    "fragment {} not found in previous state",
                                    fragment_id
                                ))
                            })
                    })
                    .collect::<MetaResult<_>>()?;

                all_fragment_actors.extend(fragment_actors);

                let source_fragment_actors = actors
                    .iter()
                    .map(|(actor_id, info)| (*actor_id, info.vnode_bitmap.clone()))
                    .collect();

                let mut all_actor_dispatchers: HashMap<_, Vec<_>> = HashMap::new();

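                // Recompute the dispatchers from this fragment to each downstream fragment,
                // using the re-rendered actors when the downstream was also re-rendered and
                // its previous actor info otherwise.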
                for downstream_fragment_id in downstream_fragments.keys() {
                    let target_fragment_actors =
                        match all_rendered_fragments.get(downstream_fragment_id) {
                            None => {
                                let external_fragment = all_prev_fragments
                                    .get(downstream_fragment_id)
                                    .ok_or_else(|| {
                                        MetaError::from(anyhow!(
                                            "fragment {} not found in previous state",
                                            downstream_fragment_id
                                        ))
                                    })?;

                                external_fragment
                                    .actors
                                    .iter()
                                    .map(|(actor_id, info)| (*actor_id, info.vnode_bitmap.clone()))
                                    .collect()
                            }
                            Some(downstream_rendered) => downstream_rendered
                                .actors
                                .iter()
                                .map(|(actor_id, info)| (*actor_id, info.vnode_bitmap.clone()))
                                .collect(),
                        };

                    let target_fragment_distribution = *distribution_type;

                    let fragment_relation::Model {
                        source_fragment_id: _,
                        target_fragment_id: _,
                        dispatcher_type,
                        dist_key_indices,
                        output_indices,
                        output_type_mapping,
                    } = downstream_relations
                        .remove(&(
                            *fragment_id as FragmentId,
                            *downstream_fragment_id as FragmentId,
                        ))
                        .ok_or_else(|| {
                            MetaError::from(anyhow!(
                                "downstream relation missing for {} -> {}",
                                fragment_id,
                                downstream_fragment_id
                            ))
                        })?;

                    let pb_mapping = PbDispatchOutputMapping {
                        indices: output_indices.into_u32_array(),
                        types: output_type_mapping.unwrap_or_default().to_protobuf(),
                    };

                    let dispatchers = compose_dispatchers(
                        *distribution_type,
                        &source_fragment_actors,
                        *downstream_fragment_id,
                        target_fragment_distribution,
                        &target_fragment_actors,
                        dispatcher_type,
                        dist_key_indices.into_u32_array(),
                        pb_mapping,
                    );

                    for (actor_id, dispatcher) in dispatchers {
                        all_actor_dispatchers
                            .entry(actor_id)
                            .or_default()
                            .push(dispatcher);
                    }
                }

                let prev_fragment = all_prev_fragments.get(&{ *fragment_id }).ok_or_else(|| {
                    MetaError::from(anyhow!(
                        "fragment {} not found in previous state",
                        fragment_id
                    ))
                })?;

                let mut reschedule = self.diff_fragment(
                    prev_fragment,
                    actors,
                    upstream_fragments,
                    downstream_fragments,
                    all_actor_dispatchers,
                    job_extra_info.get(job_id),
                )?;

                // CDC splits are assigned only at this stage, so the reschedule must not carry any yet.
                debug_assert!(reschedule.cdc_table_job_id.is_none());
                debug_assert!(reschedule.cdc_table_snapshot_split_assignment.is_empty());
                let cdc_info = if fragment_info
                    .fragment_type_mask
                    .contains(FragmentTypeFlag::StreamCdcScan)
                {
                    let assignment = assign_cdc_table_snapshot_splits_impl(
                        *job_id,
                        actors.keys().copied().collect(),
                        self.env.meta_store_ref(),
                        None,
                    )
                    .await?;
                    Some((job_id, assignment))
                } else {
                    None
                };

                if let Some((cdc_table_id, cdc_table_snapshot_split_assignment)) = cdc_info {
                    reschedule.cdc_table_job_id = Some(*cdc_table_id);
                    reschedule.cdc_table_snapshot_split_assignment =
                        cdc_table_snapshot_split_assignment;
                }

                reschedules.insert(*fragment_id as FragmentId, reschedule);
            }

            let command = Command::RescheduleFragment {
                reschedules,
                fragment_actors: all_fragment_actors,
            };

            commands.insert(*database_id, command);
        }

        Ok(commands)
    }
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParallelismPolicy {
    pub parallelism: StreamingParallelism,
}

#[derive(Clone, Debug)]
pub struct ResourceGroupPolicy {
    pub resource_group: Option<String>,
}

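/// The reschedule target for a single streaming job: change its parallelism, move it to a
/// different resource group, or both.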
#[derive(Clone, Debug)]
pub enum ReschedulePolicy {
    Parallelism(ParallelismPolicy),
    ResourceGroup(ResourceGroupPolicy),
    Both(ParallelismPolicy, ResourceGroupPolicy),
}

impl GlobalStreamManager {
    #[await_tree::instrument("acquire_reschedule_read_guard")]
    pub async fn reschedule_lock_read_guard(&self) -> RwLockReadGuard<'_, ()> {
        self.scale_controller.reschedule_lock.read().await
    }

    #[await_tree::instrument("acquire_reschedule_write_guard")]
    pub async fn reschedule_lock_write_guard(&self) -> RwLockWriteGuard<'_, ()> {
        self.scale_controller.reschedule_lock.write().await
    }

    /// When new worker nodes join, or the parallelism of existing worker nodes changes,
    /// examines whether any jobs can be scaled, and scales them if so.
    ///
    /// This method iterates over all `CREATED` jobs and can be called repeatedly.
    ///
    /// Returns
    /// - `Ok(false)` if no jobs can be scaled;
    /// - `Ok(true)` if some jobs were scaled and there may be more jobs that can be scaled.
    async fn trigger_parallelism_control(&self) -> MetaResult<bool> {
        tracing::info!("trigger parallelism control");

        let _reschedule_job_lock = self.reschedule_lock_write_guard().await;

        let background_streaming_jobs = self
            .metadata_manager
            .list_background_creating_jobs()
            .await?;

        let skipped_jobs = if !background_streaming_jobs.is_empty() {
            let jobs = self
                .scale_controller
                .resolve_related_no_shuffle_jobs(&background_streaming_jobs)
                .await?;

            tracing::info!(
                "skipping parallelism control of background jobs {:?} and associated jobs {:?}",
                background_streaming_jobs,
                jobs
            );

            jobs
        } else {
            HashSet::new()
        };

        let database_objects: HashMap<risingwave_meta_model::DatabaseId, Vec<JobId>> = self
            .metadata_manager
            .catalog_controller
            .list_streaming_job_with_database()
            .await?;

        let job_ids = database_objects
            .iter()
            .flat_map(|(database_id, job_ids)| {
                job_ids
                    .iter()
                    .enumerate()
                    .map(move |(idx, job_id)| (idx, database_id, job_id))
            })
            .sorted_by(|(idx_a, database_a, _), (idx_b, database_b, _)| {
                idx_a.cmp(idx_b).then(database_a.cmp(database_b))
            })
            .map(|(_, database_id, job_id)| (*database_id, *job_id))
            .filter(|(_, job_id)| !skipped_jobs.contains(job_id))
            .collect_vec();

        if job_ids.is_empty() {
            tracing::info!("no streaming jobs for scaling, maybe an empty cluster");
            return Ok(false);
        }

        let workers = self
            .metadata_manager
            .cluster_controller
            .list_active_streaming_workers()
            .await?;

        let schedulable_workers: BTreeMap<_, _> = workers
            .iter()
            .filter(|worker| {
                !worker
                    .property
                    .as_ref()
                    .map(|p| p.is_unschedulable)
                    .unwrap_or(false)
            })
            .map(|worker| {
                (
                    worker.id,
                    WorkerInfo {
                        parallelism: NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
                        resource_group: worker.resource_group(),
                    },
                )
            })
            .collect();

        tracing::info!(
            "trigger parallelism control for jobs: {:#?}, workers {:#?}",
            job_ids,
            schedulable_workers
        );

        let batch_size = match self.env.opts.parallelism_control_batch_size {
            0 => job_ids.len(),
            n => n,
        };

        tracing::info!(
            "total {} streaming jobs, batch size {}, schedulable worker ids: {:?}",
            job_ids.len(),
            batch_size,
            schedulable_workers
        );

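        // Process the jobs in batches so that each reschedule pass covers only a bounded number
        // of jobs; a batch size of 0 means "all jobs in a single batch".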
        let batches: Vec<_> = job_ids
            .into_iter()
            .chunks(batch_size)
            .into_iter()
            .map(|chunk| chunk.collect_vec())
            .collect();

        for batch in batches {
            let jobs = batch.iter().map(|(_, job_id)| *job_id).collect();

            let commands = self
                .scale_controller
                .rerender(jobs, schedulable_workers.clone())
                .await?;

            let futures = commands.into_iter().map(|(database_id, command)| {
                let barrier_scheduler = self.barrier_scheduler.clone();
                async move { barrier_scheduler.run_command(database_id, command).await }
            });

            let _results = future::try_join_all(futures).await?;
        }

        Ok(false)
    }

    /// Handles notifications of worker node activation and deletion, and triggers parallelism control.
    async fn run(&self, mut shutdown_rx: Receiver<()>) {
        tracing::info!("starting automatic parallelism control monitor");

        let check_period =
            Duration::from_secs(self.env.opts.parallelism_control_trigger_period_sec);

        let mut ticker = tokio::time::interval_at(
            Instant::now()
                + Duration::from_secs(self.env.opts.parallelism_control_trigger_first_delay_sec),
            check_period,
        );
        ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);

        // Wait for the first tick.
        ticker.tick().await;

        let (local_notification_tx, mut local_notification_rx) =
            tokio::sync::mpsc::unbounded_channel();

        self.env
            .notification_manager()
            .insert_local_sender(local_notification_tx);

        let worker_nodes = self
            .metadata_manager
            .list_active_streaming_compute_nodes()
            .await
            .expect("list active streaming compute nodes");

        let mut worker_cache: BTreeMap<_, _> = worker_nodes
            .into_iter()
            .map(|worker| (worker.id, worker))
            .collect();

        let mut previous_adaptive_parallelism_strategy = AdaptiveParallelismStrategy::default();

        let mut should_trigger = false;

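        // `should_trigger` debounces parallelism control: it is set when the worker set or the
        // adaptive parallelism strategy changes, and cleared once a control round reports that
        // no further scaling is possible.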
        loop {
            tokio::select! {
                biased;

                _ = &mut shutdown_rx => {
                    tracing::info!("Stream manager is stopped");
                    break;
                }

                _ = ticker.tick(), if should_trigger => {
                    let include_workers = worker_cache.keys().copied().collect_vec();

                    if include_workers.is_empty() {
                        tracing::debug!("no available worker nodes");
                        should_trigger = false;
                        continue;
                    }

                    match self.trigger_parallelism_control().await {
                        Ok(cont) => {
                            should_trigger = cont;
                        }
                        Err(e) => {
                            tracing::warn!(error = %e.as_report(), "Failed to trigger scale out, waiting for next tick to retry after {}s", ticker.period().as_secs());
                            ticker.reset();
                        }
                    }
                }

                notification = local_notification_rx.recv() => {
                    let notification = notification.expect("local notification channel closed in loop of stream manager");

                    // Only maintain the cache for streaming compute nodes.
                    let worker_is_streaming_compute = |worker: &WorkerNode| {
                        worker.get_type() == Ok(WorkerType::ComputeNode)
                            && worker.property.as_ref().unwrap().is_streaming
                    };

                    match notification {
                        LocalNotification::SystemParamsChange(reader) => {
                            let new_strategy = reader.adaptive_parallelism_strategy();
                            if new_strategy != previous_adaptive_parallelism_strategy {
                                tracing::info!("adaptive parallelism strategy changed from {:?} to {:?}", previous_adaptive_parallelism_strategy, new_strategy);
                                should_trigger = true;
                                previous_adaptive_parallelism_strategy = new_strategy;
                            }
                        }
                        LocalNotification::WorkerNodeActivated(worker) => {
                            if !worker_is_streaming_compute(&worker) {
                                continue;
                            }

                            tracing::info!(worker = %worker.id, "worker activated notification received");

                            let prev_worker = worker_cache.insert(worker.id, worker.clone());

                            match prev_worker {
                                Some(prev_worker) if prev_worker.compute_node_parallelism() != worker.compute_node_parallelism() => {
                                    tracing::info!(worker = %worker.id, "worker parallelism changed");
                                    should_trigger = true;
                                }
                                Some(prev_worker) if prev_worker.resource_group() != worker.resource_group() => {
                                    tracing::info!(worker = %worker.id, "worker label changed");
                                    should_trigger = true;
                                }
                                None => {
                                    tracing::info!(worker = %worker.id, "new worker joined");
                                    should_trigger = true;
                                }
                                _ => {}
                            }
                        }

                        // Since our logic for handling passive scale-in is within the barrier manager,
                        // there's not much we can do here. All we can do is proactively remove the entries from our cache.
                        LocalNotification::WorkerNodeDeleted(worker) => {
                            if !worker_is_streaming_compute(&worker) {
                                continue;
                            }

                            match worker_cache.remove(&worker.id) {
                                Some(prev_worker) => {
                                    tracing::info!(worker = %prev_worker.id, "worker removed from stream manager cache");
                                }
                                None => {
                                    tracing::warn!(worker = %worker.id, "worker not found in stream manager cache while handling its delete notification");
                                }
                            }
                        }

                        _ => {}
                    }
                }
            }
        }
    }

    pub fn start_auto_parallelism_monitor(
        self: Arc<Self>,
    ) -> (JoinHandle<()>, oneshot::Sender<()>) {
        tracing::info!("Automatic parallelism scale-out is enabled for streaming jobs");
        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
        let join_handle = tokio::spawn(async move {
            self.run(shutdown_rx).await;
        });

        (join_handle, shutdown_tx)
    }
}