risingwave_meta/stream/scale.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap, HashSet};
16use std::fmt::Debug;
17use std::num::NonZeroUsize;
18use std::sync::Arc;
19use std::time::Duration;
20
21use anyhow::anyhow;
22use futures::future;
23use itertools::Itertools;
24use risingwave_common::bail;
25use risingwave_common::catalog::DatabaseId;
26use risingwave_common::hash::ActorMapping;
27use risingwave_meta_model::{StreamingParallelism, WorkerId, fragment, fragment_relation};
28use risingwave_pb::common::{PbWorkerNode, WorkerNode, WorkerType};
29use risingwave_pb::stream_plan::{PbDispatchOutputMapping, PbDispatcher};
30use sea_orm::{ActiveModelTrait, ConnectionTrait, QuerySelect};
31use thiserror_ext::AsReport;
32use tokio::sync::oneshot::Receiver;
33use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, oneshot};
34use tokio::task::JoinHandle;
35use tokio::time::{Instant, MissedTickBehavior};
36
37use crate::barrier::{Command, Reschedule, SharedFragmentInfo};
38use crate::controller::scale::{
39    FragmentRenderMap, NoShuffleEnsemble, RenderedGraph, WorkerInfo,
40    find_fragment_no_shuffle_dags_detailed, render_fragments, render_jobs,
41};
42use crate::error::bail_invalid_parameter;
43use crate::manager::{LocalNotification, MetaSrvEnv, MetadataManager};
44use crate::model::{ActorId, FragmentId, StreamActor, StreamActorWithDispatchers};
45use crate::stream::{GlobalStreamManager, SourceManagerRef};
46use crate::{MetaError, MetaResult};
47
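/// A manual reschedule request expressed as per-worker deltas of actor counts.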
48#[derive(Debug, Clone, Eq, PartialEq)]
49pub struct WorkerReschedule {
50    pub worker_actor_diff: BTreeMap<WorkerId, isize>,
51}
52
53use risingwave_common::id::JobId;
54use risingwave_common::system_param::AdaptiveParallelismStrategy;
55use risingwave_common::system_param::reader::SystemParamsRead;
56use risingwave_meta_model::DispatcherType;
57use risingwave_meta_model::fragment::DistributionType;
58use risingwave_meta_model::prelude::{Fragment, FragmentRelation, StreamingJob};
59use sea_orm::ActiveValue::Set;
60use sea_orm::{ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter, TransactionTrait};
61
62use crate::controller::fragment::{InflightActorInfo, InflightFragmentInfo};
63use crate::controller::utils::{
64    StreamingJobExtraInfo, compose_dispatchers, get_streaming_job_extra_info,
65};
66
67pub type ScaleControllerRef = Arc<ScaleController>;
68
69pub struct ScaleController {
70    pub metadata_manager: MetadataManager,
71
72    pub source_manager: SourceManagerRef,
73
74    pub env: MetaSrvEnv,
75
    /// We acquire this lock during DDL to prevent scaling operations on jobs that are still being created,
    /// e.g., an MV cannot be rescheduled during foreground backfill.
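    ///
    /// Automatic parallelism control takes the write guard (see
    /// [`GlobalStreamManager::reschedule_lock_write_guard`]), so any holder of the read
    /// guard blocks concurrent reschedules.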
78    pub reschedule_lock: RwLock<()>,
79}
80
81impl ScaleController {
82    pub fn new(
83        metadata_manager: &MetadataManager,
84        source_manager: SourceManagerRef,
85        env: MetaSrvEnv,
86    ) -> Self {
87        Self {
88            metadata_manager: metadata_manager.clone(),
89            source_manager,
90            env,
91            reschedule_lock: RwLock::new(()),
92        }
93    }
94
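    /// Expands the given jobs to the full set of jobs reachable from them through
    /// no-shuffle dependencies, so that they can be rescheduled together.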
95    pub async fn resolve_related_no_shuffle_jobs(
96        &self,
97        jobs: &[JobId],
98    ) -> MetaResult<HashSet<JobId>> {
99        let inner = self.metadata_manager.catalog_controller.inner.read().await;
100        let txn = inner.db.begin().await?;
101
102        let fragment_ids: Vec<_> = Fragment::find()
103            .select_only()
104            .column(fragment::Column::FragmentId)
105            .filter(fragment::Column::JobId.is_in(jobs.to_vec()))
106            .into_tuple()
107            .all(&txn)
108            .await?;
109        let ensembles = find_fragment_no_shuffle_dags_detailed(&txn, &fragment_ids).await?;
110        let related_fragments = ensembles
111            .iter()
112            .flat_map(|ensemble| ensemble.fragments())
113            .collect_vec();
114
115        let job_ids: Vec<_> = Fragment::find()
116            .select_only()
117            .column(fragment::Column::JobId)
118            .filter(fragment::Column::FragmentId.is_in(related_fragments))
119            .into_tuple()
120            .all(&txn)
121            .await?;
122
123        let job_ids = job_ids.into_iter().collect();
124
125        Ok(job_ids)
126    }
127
128    pub fn diff_fragment(
129        &self,
130        prev_fragment_info: &SharedFragmentInfo,
131        curr_actors: &HashMap<ActorId, InflightActorInfo>,
132        upstream_fragments: HashMap<FragmentId, DispatcherType>,
133        downstream_fragments: HashMap<FragmentId, DispatcherType>,
134        all_actor_dispatchers: HashMap<ActorId, Vec<PbDispatcher>>,
135        job_extra_info: Option<&StreamingJobExtraInfo>,
136    ) -> MetaResult<Reschedule> {
137        let prev_actors: HashMap<_, _> = prev_fragment_info
138            .actors
139            .iter()
140            .map(|(actor_id, actor)| (*actor_id, actor))
141            .collect();
142
143        let prev_ids: HashSet<_> = prev_actors.keys().cloned().collect();
144        let curr_ids: HashSet<_> = curr_actors.keys().cloned().collect();
145
146        let removed_actors: HashSet<_> = &prev_ids - &curr_ids;
147        let added_actor_ids: HashSet<_> = &curr_ids - &prev_ids;
148        let kept_ids: HashSet<_> = prev_ids.intersection(&curr_ids).cloned().collect();
149        debug_assert!(
150            kept_ids.is_empty(),
151            "kept actors found in scale; expected full rebuild, prev={prev_ids:?}, curr={curr_ids:?}, kept={kept_ids:?}"
152        );
153
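        // Group the newly created actors by the worker they are scheduled on.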
154        let mut added_actors = HashMap::new();
155        for &actor_id in &added_actor_ids {
156            let InflightActorInfo { worker_id, .. } = curr_actors
157                .get(&actor_id)
158                .ok_or_else(|| anyhow!("BUG: Worker not found for new actor {}", actor_id))?;
159
160            added_actors
161                .entry(*worker_id)
162                .or_insert_with(Vec::new)
163                .push(actor_id);
164        }
165
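        // For actors kept across the reschedule (none are expected under the current
        // full-rebuild model), record any changes to their vnode bitmaps.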
166        let mut vnode_bitmap_updates = HashMap::new();
167        for actor_id in kept_ids {
168            let prev_actor = prev_actors[&actor_id];
169            let curr_actor = &curr_actors[&actor_id];
170
171            // Check if the vnode distribution has changed.
172            if prev_actor.vnode_bitmap != curr_actor.vnode_bitmap
173                && let Some(bitmap) = curr_actor.vnode_bitmap.clone()
174            {
175                vnode_bitmap_updates.insert(actor_id, bitmap);
176            }
177        }
178
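        // For hash-distributed fragments, rebuild the actor mapping from the new vnode
        // bitmaps so that upstream dispatchers can be updated accordingly.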
179        let upstream_dispatcher_mapping =
180            if let DistributionType::Hash = prev_fragment_info.distribution_type {
181                let actor_mapping = curr_actors
182                    .iter()
183                    .map(
184                        |(
185                            actor_id,
186                            InflightActorInfo {
187                                worker_id: _,
188                                vnode_bitmap,
189                                ..
190                            },
191                        )| { (*actor_id, vnode_bitmap.clone().unwrap()) },
192                    )
193                    .collect();
194                Some(ActorMapping::from_bitmaps(&actor_mapping))
195            } else {
196                None
197            };
198
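        // Only edges with real dispatchers need dispatcher-side updates; no-shuffle
        // edges are excluded here.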
199        let upstream_fragment_dispatcher_ids = upstream_fragments
200            .iter()
201            .filter(|&(_, dispatcher_type)| *dispatcher_type != DispatcherType::NoShuffle)
202            .map(|(upstream_fragment, _)| (*upstream_fragment, prev_fragment_info.fragment_id))
203            .collect();
204
205        let downstream_fragment_ids = downstream_fragments
206            .iter()
207            .filter(|&(_, dispatcher_type)| *dispatcher_type != DispatcherType::NoShuffle)
208            .map(|(fragment_id, _)| *fragment_id)
209            .collect();
210
211        let extra_info = job_extra_info.cloned().unwrap_or_default();
212        let expr_context = extra_info.stream_context().to_expr_context();
213        let job_definition = extra_info.job_definition;
214        let config_override = extra_info.config_override;
215
216        let newly_created_actors: HashMap<ActorId, (StreamActorWithDispatchers, WorkerId)> =
217            added_actor_ids
218                .iter()
219                .map(|actor_id| {
220                    let actor = StreamActor {
221                        actor_id: *actor_id,
222                        fragment_id: prev_fragment_info.fragment_id,
223                        vnode_bitmap: curr_actors[actor_id].vnode_bitmap.clone(),
224                        mview_definition: job_definition.clone(),
225                        expr_context: Some(expr_context.clone()),
226                        config_override: config_override.clone(),
227                    };
228                    (
229                        *actor_id,
230                        (
231                            (
232                                actor,
233                                all_actor_dispatchers
234                                    .get(actor_id)
235                                    .cloned()
236                                    .unwrap_or_default(),
237                            ),
238                            curr_actors[actor_id].worker_id,
239                        ),
240                    )
241                })
242                .collect();
243
244        let actor_splits = curr_actors
245            .iter()
246            .map(|(&actor_id, info)| (actor_id, info.splits.clone()))
247            .collect();
248
249        let reschedule = Reschedule {
250            added_actors,
251            removed_actors,
252            vnode_bitmap_updates,
253            upstream_fragment_dispatcher_ids,
254            upstream_dispatcher_mapping,
255            downstream_fragment_ids,
256            actor_splits,
257            newly_created_actors,
258        };
259
260        Ok(reschedule)
261    }
262
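    /// Persists new parallelism and/or resource-group targets for the given jobs and
    /// re-renders them within a single transaction, returning the per-database
    /// [`Command::RescheduleFragment`] commands to apply.
    ///
    /// A minimal sketch of invoking this method; the `job_id`, `workers`, and
    /// `scale_controller` bindings are assumed to exist and are not defined here:
    ///
    /// ```ignore
    /// let policy = HashMap::from([(
    ///     job_id,
    ///     ReschedulePolicy::Parallelism(ParallelismPolicy {
    ///         parallelism: StreamingParallelism::Fixed(4),
    ///     }),
    /// )]);
    /// let commands = scale_controller.reschedule_inplace(policy, workers).await?;
    /// ```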
263    pub async fn reschedule_inplace(
264        &self,
265        policy: HashMap<JobId, ReschedulePolicy>,
266        workers: HashMap<WorkerId, PbWorkerNode>,
267    ) -> MetaResult<HashMap<DatabaseId, Command>> {
268        let inner = self.metadata_manager.catalog_controller.inner.read().await;
269        let txn = inner.db.begin().await?;
270
271        for (table_id, target) in &policy {
272            let streaming_job = StreamingJob::find_by_id(*table_id)
273                .one(&txn)
274                .await?
275                .ok_or_else(|| MetaError::catalog_id_not_found("table", table_id))?;
276
277            let max_parallelism = streaming_job.max_parallelism;
278
279            let mut streaming_job = streaming_job.into_active_model();
280
281            match &target {
282                ReschedulePolicy::Parallelism(p) | ReschedulePolicy::Both(p, _) => {
283                    if let StreamingParallelism::Fixed(n) = p.parallelism
284                        && n > max_parallelism as usize
285                    {
286                        bail!(format!(
287                            "specified parallelism {n} should not exceed max parallelism {max_parallelism}"
288                        ));
289                    }
290
291                    streaming_job.parallelism = Set(p.parallelism.clone());
292                }
293                _ => {}
294            }
295
296            match &target {
297                ReschedulePolicy::ResourceGroup(r) | ReschedulePolicy::Both(_, r) => {
298                    streaming_job.specific_resource_group = Set(r.resource_group.clone());
299                }
300                _ => {}
301            }
302
303            streaming_job.update(&txn).await?;
304        }
305
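        // With the new targets persisted, re-render the affected jobs in the same
        // transaction and derive the reschedule commands.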
306        let jobs = policy.keys().copied().collect();
307
308        let workers = workers
309            .into_iter()
310            .map(|(id, worker)| {
311                (
312                    id,
313                    WorkerInfo {
314                        parallelism: NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
315                        resource_group: worker.resource_group(),
316                    },
317                )
318            })
319            .collect();
320
321        let command = self.rerender_inner(&txn, jobs, workers).await?;
322
323        txn.commit().await?;
324
325        Ok(command)
326    }
327
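    /// Applies per-fragment parallelism overrides and re-renders the affected
    /// no-shuffle ensembles. The request must include at least one entry fragment of
    /// each affected ensemble, and all included entry fragments of an ensemble must
    /// agree on the desired parallelism.
    ///
    /// A minimal sketch; the fragment id and `workers` are hypothetical:
    ///
    /// ```ignore
    /// let policy = HashMap::from([(42, Some(StreamingParallelism::Fixed(2)))]);
    /// let commands = scale_controller
    ///     .reschedule_fragment_inplace(policy, workers)
    ///     .await?;
    /// ```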
328    pub async fn reschedule_fragment_inplace(
329        &self,
330        policy: HashMap<risingwave_meta_model::FragmentId, Option<StreamingParallelism>>,
331        workers: HashMap<WorkerId, PbWorkerNode>,
332    ) -> MetaResult<HashMap<DatabaseId, Command>> {
333        if policy.is_empty() {
334            return Ok(HashMap::new());
335        }
336
337        let inner = self.metadata_manager.catalog_controller.inner.read().await;
338        let txn = inner.db.begin().await?;
339
340        let fragment_id_list = policy.keys().copied().collect_vec();
341
342        let existing_fragment_ids: HashSet<_> = Fragment::find()
343            .select_only()
344            .column(fragment::Column::FragmentId)
345            .filter(fragment::Column::FragmentId.is_in(fragment_id_list.clone()))
346            .into_tuple::<risingwave_meta_model::FragmentId>()
347            .all(&txn)
348            .await?
349            .into_iter()
350            .collect();
351
352        if let Some(missing_fragment_id) = fragment_id_list
353            .iter()
354            .find(|fragment_id| !existing_fragment_ids.contains(fragment_id))
355        {
356            return Err(MetaError::catalog_id_not_found(
357                "fragment",
358                *missing_fragment_id,
359            ));
360        }
361
362        let mut target_ensembles = vec![];
363
364        for ensemble in find_fragment_no_shuffle_dags_detailed(&txn, &fragment_id_list).await? {
365            let entry_fragment_ids = ensemble.entry_fragments().collect_vec();
366
367            let desired_parallelism = match entry_fragment_ids
368                .iter()
369                .filter_map(|fragment_id| policy.get(fragment_id).cloned())
370                .dedup()
371                .collect_vec()
372                .as_slice()
373            {
374                [] => {
375                    bail_invalid_parameter!(
376                        "none of the entry fragments {:?} were included in the reschedule request; \
377                         provide at least one entry fragment id",
378                        entry_fragment_ids
379                    );
380                }
381                [parallelism] => parallelism.clone(),
382                parallelisms => {
383                    bail!(
384                        "conflicting reschedule policies for fragments in the same no-shuffle ensemble: {:?}",
385                        parallelisms
386                    );
387                }
388            };
389
390            let fragments = Fragment::find()
391                .filter(fragment::Column::FragmentId.is_in(entry_fragment_ids))
392                .all(&txn)
393                .await?;
394
395            debug_assert!(
396                fragments
397                    .iter()
398                    .map(|fragment| fragment.parallelism.as_ref())
399                    .all_equal(),
400                "entry fragments in the same ensemble should share the same parallelism"
401            );
402
403            let current_parallelism = fragments
404                .first()
405                .and_then(|fragment| fragment.parallelism.clone());
406
407            if current_parallelism == desired_parallelism {
408                continue;
409            }
410
411            for fragment in fragments {
412                let mut fragment = fragment.into_active_model();
413                fragment.parallelism = Set(desired_parallelism.clone());
414                fragment.update(&txn).await?;
415            }
416
417            target_ensembles.push(ensemble);
418        }
419
420        let workers = workers
421            .into_iter()
422            .map(|(id, worker)| {
423                (
424                    id,
425                    WorkerInfo {
426                        parallelism: NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
427                        resource_group: worker.resource_group(),
428                    },
429                )
430            })
431            .collect();
432
433        let command = self
434            .rerender_fragment_inner(&txn, target_ensembles, workers)
435            .await?;
436
437        txn.commit().await?;
438
439        Ok(command)
440    }
441
442    async fn rerender(
443        &self,
444        jobs: HashSet<JobId>,
445        workers: BTreeMap<WorkerId, WorkerInfo>,
446    ) -> MetaResult<HashMap<DatabaseId, Command>> {
447        let inner = self.metadata_manager.catalog_controller.inner.read().await;
448        self.rerender_inner(&inner.db, jobs, workers).await
449    }
450
451    async fn rerender_fragment_inner(
452        &self,
453        txn: &impl ConnectionTrait,
454        ensembles: Vec<NoShuffleEnsemble>,
455        workers: BTreeMap<WorkerId, WorkerInfo>,
456    ) -> MetaResult<HashMap<DatabaseId, Command>> {
457        if ensembles.is_empty() {
458            return Ok(HashMap::new());
459        }
460
461        let adaptive_parallelism_strategy = {
462            let system_params_reader = self.env.system_params_reader().await;
463            system_params_reader.adaptive_parallelism_strategy()
464        };
465
466        let RenderedGraph { fragments, .. } = render_fragments(
467            txn,
468            self.env.actor_id_generator(),
469            ensembles,
470            workers,
471            adaptive_parallelism_strategy,
472        )
473        .await?;
474
475        self.build_reschedule_commands(txn, fragments).await
476    }
477
478    async fn rerender_inner(
479        &self,
480        txn: &impl ConnectionTrait,
481        jobs: HashSet<JobId>,
482        workers: BTreeMap<WorkerId, WorkerInfo>,
483    ) -> MetaResult<HashMap<DatabaseId, Command>> {
484        let adaptive_parallelism_strategy = {
485            let system_params_reader = self.env.system_params_reader().await;
486            system_params_reader.adaptive_parallelism_strategy()
487        };
488
489        let RenderedGraph { fragments, .. } = render_jobs(
490            txn,
491            self.env.actor_id_generator(),
492            jobs,
493            workers,
494            adaptive_parallelism_strategy,
495        )
496        .await?;
497
498        self.build_reschedule_commands(txn, fragments).await
499    }
500
501    async fn build_reschedule_commands(
502        &self,
503        txn: &impl ConnectionTrait,
504        render_result: FragmentRenderMap,
505    ) -> MetaResult<HashMap<DatabaseId, Command>> {
506        if render_result.is_empty() {
507            return Ok(HashMap::new());
508        }
509
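        // Per-job extra info (definition, stream context, config override) is needed
        // when materializing the newly created actors.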
510        let job_ids = render_result
511            .values()
512            .flat_map(|jobs| jobs.keys().copied())
513            .collect_vec();
514
515        let job_extra_info = get_streaming_job_extra_info(txn, job_ids).await?;
516
517        let fragment_ids = render_result
518            .values()
519            .flat_map(|jobs| jobs.values())
520            .flatten()
521            .map(|(fragment_id, _)| *fragment_id)
522            .collect_vec();
523
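        // Load the dispatcher edges touching the re-rendered fragments: upstream edges
        // (our fragments as targets) and downstream edges (our fragments as sources).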
524        let upstreams: Vec<(
525            risingwave_meta_model::FragmentId,
526            risingwave_meta_model::FragmentId,
527            DispatcherType,
528        )> = FragmentRelation::find()
529            .select_only()
530            .columns([
531                fragment_relation::Column::TargetFragmentId,
532                fragment_relation::Column::SourceFragmentId,
533                fragment_relation::Column::DispatcherType,
534            ])
535            .filter(fragment_relation::Column::TargetFragmentId.is_in(fragment_ids.clone()))
536            .into_tuple()
537            .all(txn)
538            .await?;
539
540        let downstreams = FragmentRelation::find()
541            .filter(fragment_relation::Column::SourceFragmentId.is_in(fragment_ids.clone()))
542            .all(txn)
543            .await?;
544
545        let mut all_upstream_fragments = HashMap::new();
546
547        for (fragment, upstream, dispatcher) in upstreams {
548            let fragment_id = fragment as FragmentId;
549            let upstream_id = upstream as FragmentId;
550            all_upstream_fragments
551                .entry(fragment_id)
552                .or_insert(HashMap::new())
553                .insert(upstream_id, dispatcher);
554        }
555
556        let mut all_downstream_fragments = HashMap::new();
557
558        let mut downstream_relations = HashMap::new();
559        for relation in downstreams {
560            let source_fragment_id = relation.source_fragment_id as FragmentId;
561            let target_fragment_id = relation.target_fragment_id as FragmentId;
562            all_downstream_fragments
563                .entry(source_fragment_id)
564                .or_insert(HashMap::new())
565                .insert(target_fragment_id, relation.dispatcher_type);
566
567            downstream_relations.insert((source_fragment_id, target_fragment_id), relation);
568        }
569
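        // Snapshot the previous actor layout of every touched fragment and its direct
        // neighbors; both diffing and dispatcher recomputation rely on it.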
570        let all_related_fragment_ids: HashSet<_> = fragment_ids
571            .iter()
572            .copied()
573            .chain(all_upstream_fragments.values().flatten().map(|(id, _)| *id))
574            .chain(
575                all_downstream_fragments
576                    .values()
577                    .flatten()
578                    .map(|(id, _)| *id),
579            )
580            .collect();
581
582        let all_related_fragment_ids = all_related_fragment_ids.into_iter().collect_vec();
583
584        let all_prev_fragments: HashMap<_, _> = {
585            let read_guard = self.env.shared_actor_infos().read_guard();
586            all_related_fragment_ids
587                .iter()
588                .map(|&fragment_id| {
589                    read_guard
590                        .get_fragment(fragment_id as FragmentId)
591                        .cloned()
592                        .map(|fragment| (fragment_id, fragment))
593                        .ok_or_else(|| {
594                            MetaError::from(anyhow!(
595                                "previous fragment info for {fragment_id} not found"
596                            ))
597                        })
598                })
599                .collect::<MetaResult<_>>()?
600        };
601
602        let all_rendered_fragments: HashMap<_, _> = render_result
603            .values()
604            .flat_map(|jobs| jobs.values())
605            .flatten()
606            .map(|(fragment_id, info)| (*fragment_id, info))
607            .collect();
608
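        // Build one `Command::RescheduleFragment` per database, diffing each rendered
        // fragment against its previous state.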
609        let mut commands = HashMap::new();
610
611        for (database_id, jobs) in &render_result {
612            let mut all_fragment_actors = HashMap::new();
613            let mut reschedules = HashMap::new();
614
615            for (job_id, fragment_id, fragment_info) in
616                jobs.iter().flat_map(|(job_id, fragments)| {
617                    fragments
618                        .iter()
619                        .map(move |(fragment_id, info)| (job_id, fragment_id, info))
620                })
621            {
622                let InflightFragmentInfo {
623                    distribution_type,
624                    actors,
625                    ..
626                } = fragment_info;
627
628                let upstream_fragments = all_upstream_fragments
629                    .remove(&(*fragment_id as FragmentId))
630                    .unwrap_or_default();
631                let downstream_fragments = all_downstream_fragments
632                    .remove(&(*fragment_id as FragmentId))
633                    .unwrap_or_default();
634
635                let fragment_actors: HashMap<_, _> = upstream_fragments
636                    .keys()
637                    .copied()
638                    .chain(downstream_fragments.keys().copied())
639                    .map(|fragment_id| {
640                        all_prev_fragments
641                            .get(&fragment_id)
642                            .map(|fragment| {
643                                (
644                                    fragment_id,
645                                    fragment.actors.keys().copied().collect::<HashSet<_>>(),
646                                )
647                            })
648                            .ok_or_else(|| {
649                                MetaError::from(anyhow!(
650                                    "fragment {} not found in previous state",
651                                    fragment_id
652                                ))
653                            })
654                    })
655                    .collect::<MetaResult<_>>()?;
656
657                all_fragment_actors.extend(fragment_actors);
658
659                let source_fragment_actors = actors
660                    .iter()
661                    .map(|(actor_id, info)| (*actor_id, info.vnode_bitmap.clone()))
662                    .collect();
663
664                let mut all_actor_dispatchers: HashMap<_, Vec<_>> = HashMap::new();
665
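                // Recompute the dispatchers from this fragment to each downstream
                // fragment, preferring freshly rendered downstream actors and falling
                // back to the previous snapshot for fragments outside this re-render.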
666                for downstream_fragment_id in downstream_fragments.keys() {
667                    let target_fragment_actors =
668                        match all_rendered_fragments.get(downstream_fragment_id) {
669                            None => {
670                                let external_fragment = all_prev_fragments
671                                    .get(downstream_fragment_id)
672                                    .ok_or_else(|| {
673                                        MetaError::from(anyhow!(
674                                            "fragment {} not found in previous state",
675                                            downstream_fragment_id
676                                        ))
677                                    })?;
678
679                                external_fragment
680                                    .actors
681                                    .iter()
682                                    .map(|(actor_id, info)| (*actor_id, info.vnode_bitmap.clone()))
683                                    .collect()
684                            }
685                            Some(downstream_rendered) => downstream_rendered
686                                .actors
687                                .iter()
688                                .map(|(actor_id, info)| (*actor_id, info.vnode_bitmap.clone()))
689                                .collect(),
690                        };
691
                    // Dispatchers must be composed against the *downstream* fragment's
                    // distribution; take it from the rendered info when available, else
                    // from the previous snapshot of that fragment.
                    let target_fragment_distribution = all_rendered_fragments
                        .get(downstream_fragment_id)
                        .map(|info| info.distribution_type)
                        .or_else(|| {
                            all_prev_fragments
                                .get(downstream_fragment_id)
                                .map(|fragment| fragment.distribution_type)
                        })
                        .unwrap_or(*distribution_type);
693
694                    let fragment_relation::Model {
695                        source_fragment_id: _,
696                        target_fragment_id: _,
697                        dispatcher_type,
698                        dist_key_indices,
699                        output_indices,
700                        output_type_mapping,
701                    } = downstream_relations
702                        .remove(&(
703                            *fragment_id as FragmentId,
704                            *downstream_fragment_id as FragmentId,
705                        ))
706                        .ok_or_else(|| {
707                            MetaError::from(anyhow!(
708                                "downstream relation missing for {} -> {}",
709                                fragment_id,
710                                downstream_fragment_id
711                            ))
712                        })?;
713
714                    let pb_mapping = PbDispatchOutputMapping {
715                        indices: output_indices.into_u32_array(),
716                        types: output_type_mapping.unwrap_or_default().to_protobuf(),
717                    };
718
719                    let dispatchers = compose_dispatchers(
720                        *distribution_type,
721                        &source_fragment_actors,
722                        *downstream_fragment_id,
723                        target_fragment_distribution,
724                        &target_fragment_actors,
725                        dispatcher_type,
726                        dist_key_indices.into_u32_array(),
727                        pb_mapping,
728                    );
729
730                    for (actor_id, dispatcher) in dispatchers {
731                        all_actor_dispatchers
732                            .entry(actor_id)
733                            .or_default()
734                            .push(dispatcher);
735                    }
736                }
737
738                let prev_fragment = all_prev_fragments.get(&{ *fragment_id }).ok_or_else(|| {
739                    MetaError::from(anyhow!(
740                        "fragment {} not found in previous state",
741                        fragment_id
742                    ))
743                })?;
744
745                let reschedule = self.diff_fragment(
746                    prev_fragment,
747                    actors,
748                    upstream_fragments,
749                    downstream_fragments,
750                    all_actor_dispatchers,
751                    job_extra_info.get(job_id),
752                )?;
753
754                reschedules.insert(*fragment_id as FragmentId, reschedule);
755            }
756
757            let command = Command::RescheduleFragment {
758                reschedules,
759                fragment_actors: all_fragment_actors,
760            };
761
762            if let Command::RescheduleFragment { reschedules, .. } = &command {
763                debug_assert!(
764                    reschedules
765                        .values()
766                        .all(|reschedule| reschedule.vnode_bitmap_updates.is_empty()),
767                    "RescheduleFragment command carries vnode_bitmap_updates, expected full rebuild"
768                );
769            }
770
771            commands.insert(*database_id, command);
772        }
773
774        Ok(commands)
775    }
776}
777
778#[derive(Clone, Debug, Eq, PartialEq)]
779pub struct ParallelismPolicy {
780    pub parallelism: StreamingParallelism,
781}
782
783#[derive(Clone, Debug)]
784pub struct ResourceGroupPolicy {
785    pub resource_group: Option<String>,
786}
787
788#[derive(Clone, Debug)]
789pub enum ReschedulePolicy {
790    Parallelism(ParallelismPolicy),
791    ResourceGroup(ResourceGroupPolicy),
792    Both(ParallelismPolicy, ResourceGroupPolicy),
793}
794
795impl GlobalStreamManager {
796    #[await_tree::instrument("acquire_reschedule_read_guard")]
797    pub async fn reschedule_lock_read_guard(&self) -> RwLockReadGuard<'_, ()> {
798        self.scale_controller.reschedule_lock.read().await
799    }
800
801    #[await_tree::instrument("acquire_reschedule_write_guard")]
802    pub async fn reschedule_lock_write_guard(&self) -> RwLockWriteGuard<'_, ()> {
803        self.scale_controller.reschedule_lock.write().await
804    }
805
    /// When new worker nodes join, or the parallelism of existing worker nodes changes,
    /// examines whether any jobs can be scaled, and scales them if so.
    ///
    /// This method iterates over all `CREATED` jobs and can be called repeatedly.
    ///
    /// Returns
    /// - `Ok(false)` if no jobs can be scaled;
    /// - `Ok(true)` if some jobs were scaled and more jobs may still be scalable.
814    async fn trigger_parallelism_control(&self) -> MetaResult<bool> {
815        tracing::info!("trigger parallelism control");
816
817        let _reschedule_job_lock = self.reschedule_lock_write_guard().await;
818
819        let background_streaming_jobs = self
820            .metadata_manager
821            .list_background_creating_jobs()
822            .await?;
823
824        let unreschedulable_jobs = self
825            .metadata_manager
826            .collect_unreschedulable_backfill_jobs(&background_streaming_jobs)
827            .await?;
828
829        let database_objects: HashMap<risingwave_meta_model::DatabaseId, Vec<JobId>> = self
830            .metadata_manager
831            .catalog_controller
832            .list_streaming_job_with_database()
833            .await?;
834
835        let job_ids = database_objects
836            .iter()
837            .flat_map(|(database_id, job_ids)| {
838                job_ids
839                    .iter()
840                    .enumerate()
841                    .map(move |(idx, job_id)| (idx, database_id, job_id))
842            })
843            .sorted_by(|(idx_a, database_a, _), (idx_b, database_b, _)| {
844                idx_a.cmp(idx_b).then(database_a.cmp(database_b))
845            })
846            .map(|(_, database_id, job_id)| (*database_id, *job_id))
847            .filter(|(_, job_id)| !unreschedulable_jobs.contains(job_id))
848            .collect_vec();
849
850        if job_ids.is_empty() {
851            tracing::info!("no streaming jobs for scaling, maybe an empty cluster");
852            return Ok(false);
853        }
854
855        let workers = self
856            .metadata_manager
857            .cluster_controller
858            .list_active_streaming_workers()
859            .await?;
860
861        let schedulable_workers: BTreeMap<_, _> = workers
862            .iter()
863            .filter(|worker| {
864                !worker
865                    .property
866                    .as_ref()
867                    .map(|p| p.is_unschedulable)
868                    .unwrap_or(false)
869            })
870            .map(|worker| {
871                (
872                    worker.id,
873                    WorkerInfo {
874                        parallelism: NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
875                        resource_group: worker.resource_group(),
876                    },
877                )
878            })
879            .collect();
880
        if schedulable_workers.is_empty() {
            tracing::info!("no schedulable workers available, skip parallelism control");
            return Ok(false);
        }
885
886        tracing::info!(
887            "trigger parallelism control for jobs: {:#?}, workers {:#?}",
888            job_ids,
889            schedulable_workers
890        );
891
892        let batch_size = match self.env.opts.parallelism_control_batch_size {
893            0 => job_ids.len(),
894            n => n,
895        };
896
897        tracing::info!(
898            "total {} streaming jobs, batch size {}, schedulable worker ids: {:?}",
899            job_ids.len(),
900            batch_size,
901            schedulable_workers
902        );
903
904        let batches: Vec<_> = job_ids
905            .into_iter()
906            .chunks(batch_size)
907            .into_iter()
908            .map(|chunk| chunk.collect_vec())
909            .collect();
910
911        for batch in batches {
912            let jobs = batch.iter().map(|(_, job_id)| *job_id).collect();
913
914            let commands = self
915                .scale_controller
916                .rerender(jobs, schedulable_workers.clone())
917                .await?;
918
919            let futures = commands.into_iter().map(|(database_id, command)| {
920                let barrier_scheduler = self.barrier_scheduler.clone();
921                async move { barrier_scheduler.run_command(database_id, command).await }
922            });
923
924            let _results = future::try_join_all(futures).await?;
925        }
926
927        Ok(false)
928    }
929
    /// Handles local notifications (worker node activation/deletion, system parameter
    /// changes, backfill completion) and triggers parallelism control when needed.
931    async fn run(&self, mut shutdown_rx: Receiver<()>) {
932        tracing::info!("starting automatic parallelism control monitor");
933
934        let check_period =
935            Duration::from_secs(self.env.opts.parallelism_control_trigger_period_sec);
936
937        let mut ticker = tokio::time::interval_at(
938            Instant::now()
939                + Duration::from_secs(self.env.opts.parallelism_control_trigger_first_delay_sec),
940            check_period,
941        );
942        ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
943
944        let (local_notification_tx, mut local_notification_rx) =
945            tokio::sync::mpsc::unbounded_channel();
946
947        self.env
948            .notification_manager()
949            .insert_local_sender(local_notification_tx);
950
951        // waiting for the first tick
952        ticker.tick().await;
953
954        let worker_nodes = self
955            .metadata_manager
956            .list_active_streaming_compute_nodes()
957            .await
958            .expect("list active streaming compute nodes");
959
960        let mut worker_cache: BTreeMap<_, _> = worker_nodes
961            .into_iter()
962            .map(|worker| (worker.id, worker))
963            .collect();
964
965        let mut previous_adaptive_parallelism_strategy = AdaptiveParallelismStrategy::default();
966
967        let mut should_trigger = false;
968
969        loop {
970            tokio::select! {
971                biased;
972
973                _ = &mut shutdown_rx => {
974                    tracing::info!("Stream manager is stopped");
975                    break;
976                }
977
978                _ = ticker.tick(), if should_trigger => {
979                    let include_workers = worker_cache.keys().copied().collect_vec();
980
981                    if include_workers.is_empty() {
982                        tracing::debug!("no available worker nodes");
983                        should_trigger = false;
984                        continue;
985                    }
986
987                    match self.trigger_parallelism_control().await {
988                        Ok(cont) => {
989                            should_trigger = cont;
990                        }
991                        Err(e) => {
992                            tracing::warn!(error = %e.as_report(), "Failed to trigger scale out, waiting for next tick to retry after {}s", ticker.period().as_secs());
993                            ticker.reset();
994                        }
995                    }
996                }
997
998                notification = local_notification_rx.recv() => {
999                    let notification = notification.expect("local notification channel closed in loop of stream manager");
1000
1001                    // Only maintain the cache for streaming compute nodes.
1002                    let worker_is_streaming_compute = |worker: &WorkerNode| {
1003                        worker.get_type() == Ok(WorkerType::ComputeNode)
1004                            && worker.property.as_ref().unwrap().is_streaming
1005                    };
1006
1007                    match notification {
1008                        LocalNotification::SystemParamsChange(reader) => {
1009                            let new_strategy = reader.adaptive_parallelism_strategy();
1010                            if new_strategy != previous_adaptive_parallelism_strategy {
1011                                tracing::info!("adaptive parallelism strategy changed from {:?} to {:?}", previous_adaptive_parallelism_strategy, new_strategy);
1012                                should_trigger = true;
1013                                previous_adaptive_parallelism_strategy = new_strategy;
1014                            }
1015                        }
1016                        LocalNotification::WorkerNodeActivated(worker) => {
1017                            if !worker_is_streaming_compute(&worker) {
1018                                continue;
1019                            }
1020
1021                            tracing::info!(worker = %worker.id, "worker activated notification received");
1022
1023                            let prev_worker = worker_cache.insert(worker.id, worker.clone());
1024
1025                            match prev_worker {
1026                                Some(prev_worker) if prev_worker.compute_node_parallelism() != worker.compute_node_parallelism()  => {
1027                                    tracing::info!(worker = %worker.id, "worker parallelism changed");
1028                                    should_trigger = true;
1029                                }
1030                                Some(prev_worker) if prev_worker.resource_group() != worker.resource_group()  => {
1031                                    tracing::info!(worker = %worker.id, "worker label changed");
1032                                    should_trigger = true;
1033                                }
1034                                None => {
1035                                    tracing::info!(worker = %worker.id, "new worker joined");
1036                                    should_trigger = true;
1037                                }
1038                                _ => {}
1039                            }
1040                        }
1041
1042                        // Since our logic for handling passive scale-in is within the barrier manager,
1043                        // there’s not much we can do here. All we can do is proactively remove the entries from our cache.
1044                        LocalNotification::WorkerNodeDeleted(worker) => {
1045                            if !worker_is_streaming_compute(&worker) {
1046                                continue;
1047                            }
1048
1049                            match worker_cache.remove(&worker.id) {
1050                                Some(prev_worker) => {
1051                                    tracing::info!(worker = %prev_worker.id, "worker removed from stream manager cache");
1052                                }
1053                                None => {
                                    tracing::warn!(worker = %worker.id, "received deletion notification for a worker not present in the stream manager cache");
1055                                }
1056                            }
1057                        }
1058
1059                        LocalNotification::StreamingJobBackfillFinished(job_id) => {
1060                            tracing::debug!(job_id = %job_id, "received backfill finished notification");
1061                            if let Err(e) = self.apply_post_backfill_parallelism(job_id).await {
1062                                tracing::warn!(job_id = %job_id, error = %e.as_report(), "failed to restore parallelism after backfill");
1063                            }
1064                        }
1065
1066                        _ => {}
1067                    }
1068                }
1069            }
1070        }
1071    }
1072
1073    /// Restores a streaming job's parallelism to its target value after backfill completes.
1074    async fn apply_post_backfill_parallelism(&self, job_id: JobId) -> MetaResult<()> {
1075        // Fetch both the target parallelism (final desired state) and the backfill parallelism
1076        // (temporary parallelism used during backfill phase) from the catalog.
1077        let (target, backfill_parallelism) = self
1078            .metadata_manager
1079            .catalog_controller
1080            .get_job_parallelisms(job_id)
1081            .await?;
1082
1083        // Determine if we need to reschedule based on the backfill configuration.
1084        match backfill_parallelism {
1085            Some(backfill_parallelism) if backfill_parallelism == target => {
1086                // Backfill parallelism matches target - no reschedule needed since the job
1087                // is already running at the desired parallelism.
1088                tracing::debug!(
1089                    job_id = %job_id,
1090                    ?backfill_parallelism,
1091                    ?target,
1092                    "backfill parallelism equals job parallelism, skip reschedule"
1093                );
1094                return Ok(());
1095            }
1096            Some(_) => {
1097                // Backfill parallelism differs from target - proceed to restore target parallelism.
1098            }
1099            None => {
1100                // No backfill parallelism was configured, meaning the job was created without
1101                // a special backfill override. No reschedule is necessary.
1102                tracing::debug!(
1103                    job_id = %job_id,
1104                    ?target,
1105                    "no backfill parallelism configured, skip post-backfill reschedule"
1106                );
1107                return Ok(());
1108            }
1109        }
1110
1111        // Reschedule the job to restore its target parallelism.
1112        tracing::info!(
1113            job_id = %job_id,
1114            ?target,
1115            ?backfill_parallelism,
1116            "restoring parallelism after backfill via reschedule"
1117        );
1118        let policy = ReschedulePolicy::Parallelism(ParallelismPolicy {
1119            parallelism: target,
1120        });
1121        if let Err(e) = self.reschedule_streaming_job(job_id, policy, false).await {
1122            tracing::warn!(job_id = %job_id, error = %e.as_report(), "reschedule after backfill failed");
1123            return Err(e);
1124        }
1125
1126        tracing::info!(job_id = %job_id, "parallelism reschedule after backfill submitted");
1127        Ok(())
1128    }
1129
1130    pub fn start_auto_parallelism_monitor(
1131        self: Arc<Self>,
1132    ) -> (JoinHandle<()>, oneshot::Sender<()>) {
1133        tracing::info!("Automatic parallelism scale-out is enabled for streaming jobs");
1134        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
1135        let join_handle = tokio::spawn(async move {
1136            self.run(shutdown_rx).await;
1137        });
1138
1139        (join_handle, shutdown_tx)
1140    }
1141}