risingwave_meta/controller/
scale.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
16use std::num::NonZeroUsize;
17use std::sync::atomic::{AtomicU32, Ordering};
18
19use anyhow::anyhow;
20use itertools::Itertools;
21use risingwave_common::bitmap::Bitmap;
22use risingwave_common::catalog::{FragmentTypeFlag, FragmentTypeMask, TableId};
23use risingwave_common::id::JobId;
24use risingwave_common::system_param::AdaptiveParallelismStrategy;
25use risingwave_common::util::worker_util::DEFAULT_RESOURCE_GROUP;
26use risingwave_connector::source::{SplitImpl, SplitMetaData};
27use risingwave_meta_model::fragment::DistributionType;
28use risingwave_meta_model::prelude::{
29    Database, Fragment, FragmentRelation, FragmentSplits, Sink, Source, StreamingJob, Table,
30};
31use risingwave_meta_model::{
32    CreateType, DatabaseId, DispatcherType, FragmentId, JobStatus, SourceId, StreamingParallelism,
33    WorkerId, database, fragment, fragment_relation, fragment_splits, object, sink, source,
34    streaming_job, table,
35};
36use risingwave_meta_model_migration::Condition;
37use risingwave_pb::common::WorkerNode;
38use risingwave_pb::stream_plan::PbStreamNode;
39use sea_orm::{
40    ColumnTrait, ConnectionTrait, EntityTrait, JoinType, QueryFilter, QuerySelect, QueryTrait,
41    RelationTrait,
42};
43
44use crate::MetaResult;
45use crate::controller::fragment::{InflightActorInfo, InflightFragmentInfo};
46use crate::manager::ActiveStreamingWorkerNodes;
47use crate::model::{ActorId, StreamActor, StreamingJobModelContextExt};
48use crate::stream::{AssignerBuilder, SplitDiffOptions};
49
50pub(crate) async fn resolve_streaming_job_definition<C>(
51    txn: &C,
52    job_ids: &HashSet<JobId>,
53) -> MetaResult<HashMap<JobId, String>>
54where
55    C: ConnectionTrait,
56{
57    let job_ids = job_ids.iter().cloned().collect_vec();
58
59    // including table, materialized view, index
60    let common_job_definitions: Vec<(JobId, String)> = Table::find()
61        .select_only()
62        .columns([
63            table::Column::TableId,
64            #[cfg(not(debug_assertions))]
65            table::Column::Name,
66            #[cfg(debug_assertions)]
67            table::Column::Definition,
68        ])
69        .filter(table::Column::TableId.is_in(job_ids.clone()))
70        .into_tuple()
71        .all(txn)
72        .await?;
73
74    let sink_definitions: Vec<(JobId, String)> = Sink::find()
75        .select_only()
76        .columns([
77            sink::Column::SinkId,
78            #[cfg(not(debug_assertions))]
79            sink::Column::Name,
80            #[cfg(debug_assertions)]
81            sink::Column::Definition,
82        ])
83        .filter(sink::Column::SinkId.is_in(job_ids.clone()))
84        .into_tuple()
85        .all(txn)
86        .await?;
87
88    let source_definitions: Vec<(JobId, String)> = Source::find()
89        .select_only()
90        .columns([
91            source::Column::SourceId,
92            #[cfg(not(debug_assertions))]
93            source::Column::Name,
94            #[cfg(debug_assertions)]
95            source::Column::Definition,
96        ])
97        .filter(source::Column::SourceId.is_in(job_ids.clone()))
98        .into_tuple()
99        .all(txn)
100        .await?;
101
102    let definitions: HashMap<JobId, String> = common_job_definitions
103        .into_iter()
104        .chain(sink_definitions.into_iter())
105        .chain(source_definitions.into_iter())
106        .collect();
107
108    Ok(definitions)
109}
110
111pub async fn load_fragment_info<C>(
112    txn: &C,
113    actor_id_counter: &AtomicU32,
114    database_id: Option<DatabaseId>,
115    worker_nodes: &ActiveStreamingWorkerNodes,
116    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
117) -> MetaResult<FragmentRenderMap>
118where
119    C: ConnectionTrait,
120{
121    let mut query = StreamingJob::find()
122        .select_only()
123        .column(streaming_job::Column::JobId);
124
125    if let Some(database_id) = database_id {
126        query = query
127            .join(JoinType::InnerJoin, streaming_job::Relation::Object.def())
128            .filter(object::Column::DatabaseId.eq(database_id));
129    }
130
131    let jobs: Vec<JobId> = query.into_tuple().all(txn).await?;
132
133    if jobs.is_empty() {
134        return Ok(HashMap::new());
135    }
136
137    let jobs: HashSet<JobId> = jobs.into_iter().collect();
138
139    let loaded = load_fragment_context_for_jobs(txn, jobs).await?;
140
141    if loaded.is_empty() {
142        return Ok(HashMap::new());
143    }
144
145    let RenderedGraph { fragments, .. } = render_actor_assignments(
146        actor_id_counter,
147        worker_nodes.current(),
148        adaptive_parallelism_strategy,
149        &loaded,
150    )?;
151
152    Ok(fragments)
153}
154
/// Desired placement for a streaming job: an optional resource-group override
/// and the requested parallelism.
// NOTE(review): no usage visible in this chunk; presumably `None` means
// "inherit the database's resource group" as in `render_actors` — confirm.
#[derive(Debug)]
pub struct TargetResourcePolicy {
    pub resource_group: Option<String>,
    pub parallelism: StreamingParallelism,
}
160
/// Per-worker scheduling info: its parallelism (always at least 1) and the
/// resource group it belongs to (`None` = default group).
#[derive(Debug, Clone)]
pub struct WorkerInfo {
    pub parallelism: NonZeroUsize,
    pub resource_group: Option<String>,
}
166
/// Rendered fragment info, grouped by database, then streaming job, then fragment.
pub type FragmentRenderMap =
    HashMap<DatabaseId, HashMap<JobId, HashMap<FragmentId, InflightFragmentInfo>>>;
169
/// Output of the rendering pipeline: the rendered fragments plus the
/// no-shuffle ensembles they were rendered from.
#[derive(Default)]
pub struct RenderedGraph {
    pub fragments: FragmentRenderMap,
    pub ensembles: Vec<NoShuffleEnsemble>,
}
175
176impl RenderedGraph {
177    pub fn empty() -> Self {
178        Self::default()
179    }
180}
181
/// A single fragment row loaded from the catalog with its stream node decoded,
/// ready for the synchronous rendering stage (see [`LoadedFragmentContext`]).
#[derive(Clone, Debug)]
pub struct LoadedFragment {
    pub fragment_id: FragmentId,
    /// Owning streaming job.
    pub job_id: JobId,
    pub fragment_type_mask: FragmentTypeMask,
    pub distribution_type: DistributionType,
    /// Number of vnodes; caps the effective parallelism during rendering.
    pub vnode_count: usize,
    /// Decoded stream node plan for this fragment.
    pub nodes: PbStreamNode,
    pub state_table_ids: HashSet<TableId>,
    /// Fragment-level parallelism override; `None` falls back to the job's.
    pub parallelism: Option<StreamingParallelism>,
}
196
197impl From<fragment::Model> for LoadedFragment {
198    fn from(model: fragment::Model) -> Self {
199        Self {
200            fragment_id: model.fragment_id,
201            job_id: model.job_id,
202            fragment_type_mask: FragmentTypeMask::from(model.fragment_type_mask),
203            distribution_type: model.distribution_type,
204            vnode_count: model.vnode_count as usize,
205            nodes: model.stream_node.to_protobuf(),
206            state_table_ids: model.state_table_ids.into_inner().into_iter().collect(),
207            parallelism: model.parallelism,
208        }
209    }
210}
211
/// Context loaded asynchronously from the database, containing all metadata
/// required to render actor assignments. This separates async I/O from the
/// sync rendering logic.
#[derive(Default, Debug, Clone)]
pub struct LoadedFragmentContext {
    /// All no-shuffle ensembles covered by this context.
    pub ensembles: Vec<NoShuffleEnsemble>,
    /// Loaded fragments grouped by owning job.
    pub job_fragments: HashMap<JobId, HashMap<FragmentId, LoadedFragment>>,
    pub job_map: HashMap<JobId, streaming_job::Model>,
    /// Owning database of each streaming job.
    pub streaming_job_databases: HashMap<JobId, DatabaseId>,
    pub database_map: HashMap<DatabaseId, database::Model>,
    /// Source id per source fragment (used for split assignment).
    pub fragment_source_ids: HashMap<FragmentId, SourceId>,
    /// Currently registered splits per source fragment.
    pub fragment_splits: HashMap<FragmentId, Vec<SplitImpl>>,
}
222
impl LoadedFragmentContext {
    /// Whether this context carries no ensembles (hence nothing to render).
    /// Asserts the invariant that an empty ensemble list implies empty job
    /// fragments.
    pub fn is_empty(&self) -> bool {
        if self.ensembles.is_empty() {
            assert!(
                self.job_fragments.is_empty(),
                "non-empty job fragments for empty ensembles: {:?}",
                self.job_fragments
            );
            true
        } else {
            false
        }
    }

    /// Project this context onto a single database, cloning only the entries
    /// belonging to it. Returns `None` when the database owns no streaming
    /// jobs in this context. Panics (asserts) if an ensemble straddles
    /// database boundaries, since a no-shuffle ensemble must live entirely in
    /// one database.
    pub fn for_database(&self, database_id: DatabaseId) -> Option<Self> {
        // Jobs belonging to the requested database.
        let job_ids: HashSet<JobId> = self
            .streaming_job_databases
            .iter()
            .filter_map(|(job_id, db_id)| (*db_id == database_id).then_some(*job_id))
            .collect();

        if job_ids.is_empty() {
            return None;
        }

        // Indexing is intentional: every job in `streaming_job_databases`
        // must have loaded fragments, otherwise panicking is correct.
        let job_fragments: HashMap<_, _> = job_ids
            .iter()
            .map(|job_id| (*job_id, self.job_fragments[job_id].clone()))
            .collect();

        let fragment_ids: HashSet<_> = job_fragments
            .values()
            .flat_map(|fragments| fragments.keys().copied())
            .collect();

        assert!(
            !fragment_ids.is_empty(),
            "empty fragments for non-empty database {database_id} with jobs {job_ids:?}"
        );

        // Keep an ensemble iff any of its components belongs to this database,
        // asserting that then *all* of its components do.
        let ensembles: Vec<NoShuffleEnsemble> = self
            .ensembles
            .iter()
            .filter(|ensemble| {
                if ensemble
                    .components
                    .iter()
                    .any(|fragment_id| fragment_ids.contains(fragment_id))
                {
                    assert!(
                        ensemble
                            .components
                            .iter()
                            .all(|fragment_id| fragment_ids.contains(fragment_id)),
                        "ensemble {ensemble:?} partially exists in database {database_id} with fragments {job_fragments:?}"
                    );
                    true
                } else {
                    false
                }
            })
            .cloned()
            .collect();

        assert!(
            !ensembles.is_empty(),
            "empty ensembles for non-empty database {database_id} with jobs {job_fragments:?}"
        );

        let job_map = job_ids
            .iter()
            .filter_map(|job_id| self.job_map.get(job_id).map(|job| (*job_id, job.clone())))
            .collect();

        let streaming_job_databases = job_ids
            .iter()
            .filter_map(|job_id| {
                self.streaming_job_databases
                    .get(job_id)
                    .map(|db_id| (*job_id, *db_id))
            })
            .collect();

        // The projected context keeps only the single requested database.
        let database_model = self.database_map[&database_id].clone();
        let database_map = HashMap::from([(database_id, database_model)]);

        let fragment_source_ids = self
            .fragment_source_ids
            .iter()
            .filter(|(fragment_id, _)| fragment_ids.contains(*fragment_id))
            .map(|(fragment_id, source_id)| (*fragment_id, *source_id))
            .collect();

        let fragment_splits = self
            .fragment_splits
            .iter()
            .filter(|(fragment_id, _)| fragment_ids.contains(*fragment_id))
            .map(|(fragment_id, splits)| (*fragment_id, splits.clone()))
            .collect();

        Some(Self {
            ensembles,
            job_fragments,
            job_map,
            streaming_job_databases,
            database_map,
            fragment_source_ids,
            fragment_splits,
        })
    }

    /// Split this loaded context by database in a single ownership pass to avoid cloning large
    /// fragment payloads (for example `stream_node` in `LoadedFragment`).
    pub fn into_database_contexts(self) -> HashMap<DatabaseId, Self> {
        let Self {
            ensembles,
            mut job_fragments,
            mut job_map,
            streaming_job_databases,
            mut database_map,
            mut fragment_source_ids,
            mut fragment_splits,
        } = self;

        let mut contexts = HashMap::<DatabaseId, Self>::new();
        // Reverse index: fragment -> database, built while moving fragments.
        let mut fragment_databases = HashMap::<FragmentId, DatabaseId>::new();
        // Ensembles whose fragments resolve to no known database are counted
        // (and one sample kept) for diagnostics instead of panicking in release.
        let mut unresolved_ensembles = 0usize;
        let mut unresolved_ensemble_sample: Option<Vec<FragmentId>> = None;

        // First pass: move each job's payload into its database's context.
        for (job_id, database_id) in streaming_job_databases {
            let context = contexts.entry(database_id).or_insert_with(|| {
                // The database model is moved (not cloned) into the first
                // context created for it.
                let database_model = database_map
                    .remove(&database_id)
                    .expect("database should exist for streaming job");
                Self {
                    ensembles: Vec::new(),
                    job_fragments: HashMap::new(),
                    job_map: HashMap::new(),
                    streaming_job_databases: HashMap::new(),
                    database_map: HashMap::from([(database_id, database_model)]),
                    fragment_source_ids: HashMap::new(),
                    fragment_splits: HashMap::new(),
                }
            });

            let fragments = job_fragments
                .remove(&job_id)
                .expect("job fragments should exist for streaming job");
            for fragment_id in fragments.keys().copied() {
                fragment_databases.insert(fragment_id, database_id);
                // Source metadata is moved along with the owning fragment.
                if let Some(source_id) = fragment_source_ids.remove(&fragment_id) {
                    context.fragment_source_ids.insert(fragment_id, source_id);
                }
                if let Some(splits) = fragment_splits.remove(&fragment_id) {
                    context.fragment_splits.insert(fragment_id, splits);
                }
            }

            assert!(
                context
                    .job_map
                    .insert(
                        job_id,
                        job_map
                            .remove(&job_id)
                            .expect("streaming job should exist for loaded context"),
                    )
                    .is_none(),
                "duplicated streaming job"
            );
            assert!(
                context.job_fragments.insert(job_id, fragments).is_none(),
                "duplicated job fragments"
            );
            assert!(
                context
                    .streaming_job_databases
                    .insert(job_id, database_id)
                    .is_none(),
                "duplicated job database mapping"
            );
        }

        // Second pass: route each ensemble to the database any of its
        // components was moved into.
        for ensemble in ensembles {
            let Some(database_id) = ensemble
                .components
                .iter()
                .find_map(|fragment_id| fragment_databases.get(fragment_id).copied())
            else {
                unresolved_ensembles += 1;
                if unresolved_ensemble_sample.is_none() {
                    unresolved_ensemble_sample =
                        Some(ensemble.components.iter().copied().collect());
                }
                continue;
            };

            debug_assert!(
                ensemble
                    .components
                    .iter()
                    .all(|fragment_id| fragment_databases.get(fragment_id) == Some(&database_id)),
                "ensemble {ensemble:?} should belong to a single database"
            );

            contexts
                .get_mut(&database_id)
                .expect("database context should exist for ensemble")
                .ensembles
                .push(ensemble);
        }

        // In release builds unresolved ensembles are dropped with a warning;
        // debug builds treat this as a broken invariant.
        if unresolved_ensembles > 0 {
            tracing::warn!(
                unresolved_ensembles,
                ?unresolved_ensemble_sample,
                known_fragments = fragment_databases.len(),
                "skip ensembles without resolved database while splitting loaded context"
            );
        }
        debug_assert_eq!(
            unresolved_ensembles, 0,
            "all ensembles should be mappable to a database"
        );

        contexts
    }
}
451
452/// Fragment-scoped rendering entry point used by operational tooling.
453/// It validates that the requested fragments are roots of their no-shuffle ensembles,
454/// resolves only the metadata required for those components, and then reuses the shared
455/// rendering pipeline to materialize actor assignments.
456pub async fn render_fragments<C>(
457    txn: &C,
458    actor_id_counter: &AtomicU32,
459    ensembles: Vec<NoShuffleEnsemble>,
460    workers: &HashMap<WorkerId, WorkerNode>,
461    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
462) -> MetaResult<RenderedGraph>
463where
464    C: ConnectionTrait,
465{
466    let loaded = load_fragment_context(txn, ensembles).await?;
467
468    if loaded.is_empty() {
469        return Ok(RenderedGraph::empty());
470    }
471
472    render_actor_assignments(
473        actor_id_counter,
474        workers,
475        adaptive_parallelism_strategy,
476        &loaded,
477    )
478}
479
480/// Async load stage for fragment-scoped rendering. It resolves all metadata required to later
481/// render actor assignments with arbitrary worker sets.
482pub async fn load_fragment_context<C>(
483    txn: &C,
484    ensembles: Vec<NoShuffleEnsemble>,
485) -> MetaResult<LoadedFragmentContext>
486where
487    C: ConnectionTrait,
488{
489    if ensembles.is_empty() {
490        return Ok(LoadedFragmentContext::default());
491    }
492
493    let required_fragment_ids: HashSet<_> = ensembles
494        .iter()
495        .flat_map(|ensemble| ensemble.components.iter().copied())
496        .collect();
497
498    let fragment_models = Fragment::find()
499        .filter(fragment::Column::FragmentId.is_in(required_fragment_ids.iter().copied()))
500        .all(txn)
501        .await?;
502
503    let found_fragment_ids: HashSet<_> = fragment_models
504        .iter()
505        .map(|fragment| fragment.fragment_id)
506        .collect();
507
508    if found_fragment_ids.len() != required_fragment_ids.len() {
509        let missing = required_fragment_ids
510            .difference(&found_fragment_ids)
511            .copied()
512            .collect_vec();
513        return Err(anyhow!("fragments {:?} not found", missing).into());
514    }
515
516    let fragment_models: HashMap<_, _> = fragment_models
517        .into_iter()
518        .map(|fragment| (fragment.fragment_id, fragment))
519        .collect();
520
521    let job_ids: HashSet<_> = fragment_models
522        .values()
523        .map(|fragment| fragment.job_id)
524        .collect();
525
526    if job_ids.is_empty() {
527        return Ok(LoadedFragmentContext::default());
528    }
529
530    let jobs: HashMap<_, _> = StreamingJob::find()
531        .filter(streaming_job::Column::JobId.is_in(job_ids.iter().copied().collect_vec()))
532        .all(txn)
533        .await?
534        .into_iter()
535        .map(|job| (job.job_id, job))
536        .collect();
537
538    let found_job_ids: HashSet<_> = jobs.keys().copied().collect();
539    if found_job_ids.len() != job_ids.len() {
540        let missing = job_ids.difference(&found_job_ids).copied().collect_vec();
541        return Err(anyhow!("streaming jobs {:?} not found", missing).into());
542    }
543
544    build_loaded_context(txn, ensembles, fragment_models, jobs).await
545}
546
547/// Job-scoped rendering entry point that walks every no-shuffle root belonging to the
548/// provided streaming jobs before delegating to the shared rendering backend.
549pub async fn render_jobs<C>(
550    txn: &C,
551    actor_id_counter: &AtomicU32,
552    job_ids: HashSet<JobId>,
553    workers: &HashMap<WorkerId, WorkerNode>,
554    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
555) -> MetaResult<RenderedGraph>
556where
557    C: ConnectionTrait,
558{
559    let loaded = load_fragment_context_for_jobs(txn, job_ids).await?;
560
561    if loaded.is_empty() {
562        return Ok(RenderedGraph::empty());
563    }
564
565    render_actor_assignments(
566        actor_id_counter,
567        workers,
568        adaptive_parallelism_strategy,
569        &loaded,
570    )
571}
572
573/// Async load stage for job-scoped rendering. It collects all no-shuffle ensembles and the
574/// metadata required to render actor assignments later with a provided worker set.
575pub async fn load_fragment_context_for_jobs<C>(
576    txn: &C,
577    job_ids: HashSet<JobId>,
578) -> MetaResult<LoadedFragmentContext>
579where
580    C: ConnectionTrait,
581{
582    if job_ids.is_empty() {
583        return Ok(LoadedFragmentContext::default());
584    }
585
586    let excluded_fragments_query = FragmentRelation::find()
587        .select_only()
588        .column(fragment_relation::Column::TargetFragmentId)
589        .filter(fragment_relation::Column::DispatcherType.eq(DispatcherType::NoShuffle))
590        .into_query();
591
592    let condition = Condition::all()
593        .add(fragment::Column::JobId.is_in(job_ids.clone()))
594        .add(fragment::Column::FragmentId.not_in_subquery(excluded_fragments_query));
595
596    let fragments: Vec<FragmentId> = Fragment::find()
597        .select_only()
598        .column(fragment::Column::FragmentId)
599        .filter(condition)
600        .into_tuple()
601        .all(txn)
602        .await?;
603
604    let ensembles = find_fragment_no_shuffle_dags_detailed(txn, &fragments).await?;
605
606    let fragments = Fragment::find()
607        .filter(
608            fragment::Column::FragmentId.is_in(
609                ensembles
610                    .iter()
611                    .flat_map(|graph| graph.components.iter())
612                    .cloned()
613                    .collect_vec(),
614            ),
615        )
616        .all(txn)
617        .await?;
618
619    let fragment_map: HashMap<_, _> = fragments
620        .into_iter()
621        .map(|fragment| (fragment.fragment_id, fragment))
622        .collect();
623
624    let job_ids = fragment_map
625        .values()
626        .map(|fragment| fragment.job_id)
627        .collect::<BTreeSet<_>>()
628        .into_iter()
629        .collect_vec();
630
631    let jobs: HashMap<_, _> = StreamingJob::find()
632        .filter(streaming_job::Column::JobId.is_in(job_ids))
633        .all(txn)
634        .await?
635        .into_iter()
636        .map(|job| (job.job_id, job))
637        .collect();
638
639    build_loaded_context(txn, ensembles, fragment_map, jobs).await
640}
641
642/// Sync render stage: uses loaded fragment context and current worker info
643/// to produce actor-to-worker assignments and vnode bitmaps.
644pub(crate) fn render_actor_assignments(
645    actor_id_counter: &AtomicU32,
646    worker_map: &HashMap<WorkerId, WorkerNode>,
647    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
648    loaded: &LoadedFragmentContext,
649) -> MetaResult<RenderedGraph> {
650    if loaded.is_empty() {
651        return Ok(RenderedGraph::empty());
652    }
653
654    let backfill_jobs: HashSet<JobId> = loaded
655        .job_map
656        .iter()
657        .filter(|(_, job)| {
658            job.create_type == CreateType::Background && job.job_status == JobStatus::Creating
659        })
660        .map(|(id, _)| *id)
661        .collect();
662
663    let render_context = RenderActorsContext {
664        fragment_source_ids: &loaded.fragment_source_ids,
665        fragment_splits: &loaded.fragment_splits,
666        streaming_job_databases: &loaded.streaming_job_databases,
667        database_map: &loaded.database_map,
668        backfill_jobs: &backfill_jobs,
669    };
670
671    let fragments = render_actors(
672        actor_id_counter,
673        &loaded.ensembles,
674        &loaded.job_fragments,
675        &loaded.job_map,
676        worker_map,
677        adaptive_parallelism_strategy,
678        render_context,
679    )?;
680
681    Ok(RenderedGraph {
682        fragments,
683        ensembles: loaded.ensembles.clone(),
684    })
685}
686
/// Assemble a [`LoadedFragmentContext`] from pre-fetched ensembles, fragment
/// rows and streaming-job rows, resolving the remaining metadata (source ids,
/// splits, owning databases) from the catalog.
async fn build_loaded_context<C>(
    txn: &C,
    ensembles: Vec<NoShuffleEnsemble>,
    fragment_models: HashMap<FragmentId, fragment::Model>,
    job_map: HashMap<JobId, streaming_job::Model>,
) -> MetaResult<LoadedFragmentContext>
where
    C: ConnectionTrait,
{
    if ensembles.is_empty() {
        return Ok(LoadedFragmentContext::default());
    }

    // Group fragments by owning job; `try_insert` guards against duplicates.
    let mut job_fragments: HashMap<JobId, HashMap<FragmentId, LoadedFragment>> = HashMap::new();
    for (fragment_id, model) in fragment_models {
        job_fragments
            .entry(model.job_id)
            .or_default()
            .try_insert(fragment_id, LoadedFragment::from(model))
            .expect("duplicate fragment id for job");
    }

    #[cfg(debug_assertions)]
    {
        debug_sanity_check(&ensembles, &job_fragments, &job_map);
    }

    let (fragment_source_ids, fragment_splits) =
        resolve_source_fragments(txn, &job_fragments).await?;

    let job_ids = job_map.keys().copied().collect_vec();

    // Map each job to its owning database via the object relation.
    // NOTE(review): this uses a LeftJoin while `load_fragment_info` uses an
    // InnerJoin for the same relation; a job lacking an object row would yield
    // a NULL database id here — confirm the join type is intentional.
    let streaming_job_databases: HashMap<JobId, _> = StreamingJob::find()
        .select_only()
        .column(streaming_job::Column::JobId)
        .column(object::Column::DatabaseId)
        .join(JoinType::LeftJoin, streaming_job::Relation::Object.def())
        .filter(streaming_job::Column::JobId.is_in(job_ids))
        .into_tuple()
        .all(txn)
        .await?
        .into_iter()
        .collect();

    let database_map: HashMap<_, _> = Database::find()
        .filter(
            database::Column::DatabaseId
                .is_in(streaming_job_databases.values().copied().collect_vec()),
        )
        .all(txn)
        .await?
        .into_iter()
        .map(|db| (db.database_id, db))
        .collect();

    Ok(LoadedFragmentContext {
        ensembles,
        job_fragments,
        job_map,
        streaming_job_databases,
        database_map,
        fragment_source_ids,
        fragment_splits,
    })
}
752
// Only metadata resolved asynchronously lives here so the renderer stays synchronous
// and the call site keeps the runtime dependencies (maps, strategy, actor counter, etc.) explicit.
struct RenderActorsContext<'a> {
    // Source id per source fragment, for split (re)assignment.
    fragment_source_ids: &'a HashMap<FragmentId, SourceId>,
    // Registered splits per source fragment.
    fragment_splits: &'a HashMap<FragmentId, Vec<SplitImpl>>,
    // Owning database of each streaming job (for resource-group fallback).
    streaming_job_databases: &'a HashMap<JobId, DatabaseId>,
    database_map: &'a HashMap<DatabaseId, database::Model>,
    // Jobs currently backfilling (background create, still creating).
    backfill_jobs: &'a HashSet<JobId>,
}
762
763fn render_actors(
764    actor_id_counter: &AtomicU32,
765    ensembles: &[NoShuffleEnsemble],
766    job_fragments: &HashMap<JobId, HashMap<FragmentId, LoadedFragment>>,
767    job_map: &HashMap<JobId, streaming_job::Model>,
768    worker_map: &HashMap<WorkerId, WorkerNode>,
769    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
770    context: RenderActorsContext<'_>,
771) -> MetaResult<FragmentRenderMap> {
772    let RenderActorsContext {
773        fragment_source_ids,
774        fragment_splits: fragment_splits_map,
775        streaming_job_databases,
776        database_map,
777        backfill_jobs,
778    } = context;
779
780    let mut all_fragments: FragmentRenderMap = HashMap::new();
781    let fragment_lookup: HashMap<FragmentId, &LoadedFragment> = job_fragments
782        .values()
783        .flat_map(|fragments| fragments.iter())
784        .map(|(fragment_id, fragment)| (*fragment_id, fragment))
785        .collect();
786
787    for NoShuffleEnsemble {
788        entries,
789        components,
790    } in ensembles
791    {
792        tracing::debug!("rendering ensemble entries {:?}", entries);
793
794        let entry_fragments = entries
795            .iter()
796            .map(|fragment_id| fragment_lookup.get(fragment_id).unwrap())
797            .collect_vec();
798
799        let entry_fragment_parallelism = entry_fragments
800            .iter()
801            .map(|fragment| fragment.parallelism.clone())
802            .dedup()
803            .exactly_one()
804            .map_err(|_| {
805                anyhow!(
806                    "entry fragments {:?} have inconsistent parallelism settings",
807                    entries.iter().copied().collect_vec()
808                )
809            })?;
810
811        let (job_id, vnode_count) = entry_fragments
812            .iter()
813            .map(|f| (f.job_id, f.vnode_count))
814            .dedup()
815            .exactly_one()
816            .map_err(|_| anyhow!("Multiple jobs found in no-shuffle ensemble"))?;
817
818        let job = job_map
819            .get(&job_id)
820            .ok_or_else(|| anyhow!("streaming job {job_id} not found"))?;
821
822        let job_strategy = job
823            .stream_context()
824            .adaptive_parallelism_strategy
825            .unwrap_or(adaptive_parallelism_strategy);
826
827        let resource_group = match &job.specific_resource_group {
828            None => {
829                let database = streaming_job_databases
830                    .get(&job_id)
831                    .and_then(|database_id| database_map.get(database_id))
832                    .unwrap();
833                database.resource_group.clone()
834            }
835            Some(resource_group) => resource_group.clone(),
836        };
837
838        let available_workers: BTreeMap<WorkerId, NonZeroUsize> = worker_map
839            .iter()
840            .filter_map(|(worker_id, worker)| {
841                if worker
842                    .resource_group()
843                    .as_deref()
844                    .unwrap_or(DEFAULT_RESOURCE_GROUP)
845                    == resource_group.as_str()
846                {
847                    Some((
848                        *worker_id,
849                        worker
850                            .parallelism()
851                            .expect("should have parallelism for compute node")
852                            .try_into()
853                            .expect("parallelism for compute node"),
854                    ))
855                } else {
856                    None
857                }
858            })
859            .collect();
860
861        let total_parallelism = available_workers.values().map(|w| w.get()).sum::<usize>();
862
863        let effective_job_parallelism = if backfill_jobs.contains(&job_id) {
864            job.backfill_parallelism
865                .as_ref()
866                .unwrap_or(&job.parallelism)
867        } else {
868            &job.parallelism
869        };
870
871        let actual_parallelism = match entry_fragment_parallelism
872            .as_ref()
873            .unwrap_or(effective_job_parallelism)
874        {
875            StreamingParallelism::Adaptive | StreamingParallelism::Custom => {
876                job_strategy.compute_target_parallelism(total_parallelism)
877            }
878            StreamingParallelism::Fixed(n) => *n,
879        }
880        .min(vnode_count)
881        .min(job.max_parallelism as usize);
882
883        tracing::debug!(
884            "job {}, final {} parallelism {:?} total_parallelism {} job_max {} vnode count {} fragment_override {:?}",
885            job_id,
886            actual_parallelism,
887            job.parallelism,
888            total_parallelism,
889            job.max_parallelism,
890            vnode_count,
891            entry_fragment_parallelism
892        );
893
894        let assigner = AssignerBuilder::new(job_id).build();
895
896        let actors = (0..(actual_parallelism as u32))
897            .map_into::<ActorId>()
898            .collect_vec();
899        let vnodes = (0..vnode_count).collect_vec();
900
901        let assignment = assigner.assign_hierarchical(&available_workers, &actors, &vnodes)?;
902
903        let source_entry_fragment = entry_fragments.iter().find(|f| {
904            let mask = f.fragment_type_mask;
905            if mask.contains(FragmentTypeFlag::Source) {
906                assert!(!mask.contains(FragmentTypeFlag::SourceScan))
907            }
908            mask.contains(FragmentTypeFlag::Source) && !mask.contains(FragmentTypeFlag::Dml)
909        });
910
911        let (fragment_splits, shared_source_id) = match source_entry_fragment {
912            Some(entry_fragment) => {
913                let source_id = fragment_source_ids
914                    .get(&entry_fragment.fragment_id)
915                    .ok_or_else(|| {
916                        anyhow!(
917                            "missing source id in source fragment {}",
918                            entry_fragment.fragment_id
919                        )
920                    })?;
921
922                let entry_fragment_id = entry_fragment.fragment_id;
923
924                let empty_actor_splits: HashMap<_, _> =
925                    actors.iter().map(|actor_id| (*actor_id, vec![])).collect();
926
927                let splits = fragment_splits_map
928                    .get(&entry_fragment_id)
929                    .cloned()
930                    .unwrap_or_default();
931
932                let splits: BTreeMap<_, _> = splits.into_iter().map(|s| (s.id(), s)).collect();
933
934                let fragment_splits = crate::stream::source_manager::reassign_splits(
935                    entry_fragment_id,
936                    empty_actor_splits,
937                    &splits,
938                    SplitDiffOptions::default(),
939                )
940                .unwrap_or_default();
941                (fragment_splits, Some(*source_id))
942            }
943            None => (HashMap::new(), None),
944        };
945
946        for component_fragment_id in components {
947            let fragment = fragment_lookup.get(component_fragment_id).unwrap();
948            let fragment_id = fragment.fragment_id;
949            let job_id = fragment.job_id;
950            let fragment_type_mask = fragment.fragment_type_mask;
951            let distribution_type = fragment.distribution_type;
952            let stream_node = &fragment.nodes;
953            let state_table_ids = &fragment.state_table_ids;
954            let vnode_count = fragment.vnode_count;
955
956            let actor_count =
957                u32::try_from(actors.len()).expect("actor parallelism exceeds u32::MAX");
958            let actor_id_base = actor_id_counter.fetch_add(actor_count, Ordering::Relaxed);
959
960            let actors: HashMap<ActorId, InflightActorInfo> = assignment
961                .iter()
962                .flat_map(|(worker_id, actors)| {
963                    actors
964                        .iter()
965                        .map(move |(actor_id, vnodes)| (worker_id, actor_id, vnodes))
966                })
967                .map(|(&worker_id, &actor_idx, vnodes)| {
968                    let vnode_bitmap = match distribution_type {
969                        DistributionType::Single => None,
970                        DistributionType::Hash => Some(Bitmap::from_indices(vnode_count, vnodes)),
971                    };
972
973                    let actor_id = actor_idx + actor_id_base;
974
975                    let splits = if let Some(source_id) = fragment_source_ids.get(&fragment_id) {
976                        assert_eq!(shared_source_id, Some(*source_id));
977
978                        fragment_splits
979                            .get(&(actor_idx))
980                            .cloned()
981                            .unwrap_or_default()
982                    } else {
983                        vec![]
984                    };
985
986                    (
987                        actor_id,
988                        InflightActorInfo {
989                            worker_id,
990                            vnode_bitmap,
991                            splits,
992                        },
993                    )
994                })
995                .collect();
996
997            let fragment = InflightFragmentInfo {
998                fragment_id,
999                distribution_type,
1000                fragment_type_mask,
1001                vnode_count,
1002                nodes: stream_node.clone(),
1003                actors,
1004                state_table_ids: state_table_ids.clone(),
1005            };
1006
1007            let &database_id = streaming_job_databases.get(&job_id).ok_or_else(|| {
1008                anyhow!("streaming job {job_id} not found in streaming_job_databases")
1009            })?;
1010
1011            all_fragments
1012                .entry(database_id)
1013                .or_default()
1014                .entry(job_id)
1015                .or_default()
1016                .insert(fragment_id, fragment);
1017        }
1018    }
1019
1020    Ok(all_fragments)
1021}
1022
1023#[cfg(debug_assertions)]
1024fn debug_sanity_check(
1025    ensembles: &[NoShuffleEnsemble],
1026    job_fragments: &HashMap<JobId, HashMap<FragmentId, LoadedFragment>>,
1027    jobs: &HashMap<JobId, streaming_job::Model>,
1028) {
1029    let fragment_lookup: HashMap<FragmentId, (&LoadedFragment, JobId)> = job_fragments
1030        .iter()
1031        .flat_map(|(job_id, fragments)| {
1032            fragments
1033                .iter()
1034                .map(move |(fragment_id, fragment)| (*fragment_id, (fragment, *job_id)))
1035        })
1036        .collect();
1037
1038    // Debug-only assertions to catch inconsistent ensemble metadata early.
1039    debug_assert!(
1040        ensembles
1041            .iter()
1042            .all(|ensemble| ensemble.entries.is_subset(&ensemble.components)),
1043        "entries must be subset of components"
1044    );
1045
1046    let mut missing_fragments = BTreeSet::new();
1047    let mut missing_jobs = BTreeSet::new();
1048
1049    for fragment_id in ensembles
1050        .iter()
1051        .flat_map(|ensemble| ensemble.components.iter())
1052    {
1053        match fragment_lookup.get(fragment_id) {
1054            Some((fragment, job_id)) => {
1055                if !jobs.contains_key(&fragment.job_id) {
1056                    missing_jobs.insert(*job_id);
1057                }
1058            }
1059            None => {
1060                missing_fragments.insert(*fragment_id);
1061            }
1062        }
1063    }
1064
1065    debug_assert!(
1066        missing_fragments.is_empty(),
1067        "missing fragments in fragment_map: {:?}",
1068        missing_fragments
1069    );
1070
1071    debug_assert!(
1072        missing_jobs.is_empty(),
1073        "missing jobs for fragments' job_id: {:?}",
1074        missing_jobs
1075    );
1076
1077    for ensemble in ensembles {
1078        let unique_vnode_counts: Vec<_> = ensemble
1079            .components
1080            .iter()
1081            .flat_map(|fragment_id| {
1082                fragment_lookup
1083                    .get(fragment_id)
1084                    .map(|(fragment, _)| fragment.vnode_count)
1085            })
1086            .unique()
1087            .collect();
1088
1089        debug_assert!(
1090            unique_vnode_counts.len() <= 1,
1091            "components in ensemble must share same vnode_count: ensemble={:?}, vnode_counts={:?}",
1092            ensemble.components,
1093            unique_vnode_counts
1094        );
1095    }
1096}
1097
1098async fn resolve_source_fragments<C>(
1099    txn: &C,
1100    job_fragments: &HashMap<JobId, HashMap<FragmentId, LoadedFragment>>,
1101) -> MetaResult<(
1102    HashMap<FragmentId, SourceId>,
1103    HashMap<FragmentId, Vec<SplitImpl>>,
1104)>
1105where
1106    C: ConnectionTrait,
1107{
1108    let mut source_fragment_ids: HashMap<SourceId, BTreeSet<FragmentId>> = HashMap::new();
1109    for (fragment_id, fragment) in job_fragments.values().flatten() {
1110        let mask = fragment.fragment_type_mask;
1111        if mask.contains(FragmentTypeFlag::Source)
1112            && let Some(source_id) = fragment.nodes.find_stream_source()
1113        {
1114            source_fragment_ids
1115                .entry(source_id)
1116                .or_default()
1117                .insert(*fragment_id);
1118        }
1119
1120        if mask.contains(FragmentTypeFlag::SourceScan)
1121            && let Some((source_id, _)) = fragment.nodes.find_source_backfill()
1122        {
1123            source_fragment_ids
1124                .entry(source_id)
1125                .or_default()
1126                .insert(*fragment_id);
1127        }
1128    }
1129
1130    let fragment_source_ids: HashMap<_, _> = source_fragment_ids
1131        .iter()
1132        .flat_map(|(source_id, fragment_ids)| {
1133            fragment_ids
1134                .iter()
1135                .map(|fragment_id| (*fragment_id, *source_id))
1136        })
1137        .collect();
1138
1139    let fragment_ids = fragment_source_ids.keys().copied().collect_vec();
1140
1141    let fragment_splits: Vec<_> = FragmentSplits::find()
1142        .filter(fragment_splits::Column::FragmentId.is_in(fragment_ids))
1143        .all(txn)
1144        .await?;
1145
1146    let fragment_splits: HashMap<_, _> = fragment_splits
1147        .into_iter()
1148        .flat_map(|model| {
1149            model.splits.map(|splits| {
1150                (
1151                    model.fragment_id,
1152                    splits
1153                        .to_protobuf()
1154                        .splits
1155                        .iter()
1156                        .flat_map(SplitImpl::try_from)
1157                        .collect_vec(),
1158                )
1159            })
1160        })
1161        .collect();
1162
1163    Ok((fragment_source_ids, fragment_splits))
1164}
1165
/// Helper struct to make function signatures cleaner by bundling the data
/// describing a rendered actor graph.
#[derive(Debug)]
pub struct ActorGraph<'a> {
    /// Fragment id -> (fragment, the stream actors rendered for it).
    pub fragments: &'a HashMap<FragmentId, (Fragment, Vec<StreamActor>)>,
    /// Actor id -> the worker the actor is placed on.
    pub locations: &'a HashMap<ActorId, WorkerId>,
}
1172
/// A connected component of fragments linked by no-shuffle dispatchers.
#[derive(Debug, Clone)]
pub struct NoShuffleEnsemble {
    // Entry (root) fragments: members with no no-shuffle parent inside this
    // component. Always a subset of `components`.
    entries: HashSet<FragmentId>,
    // All fragments belonging to the connected component.
    components: HashSet<FragmentId>,
}
1178
1179impl NoShuffleEnsemble {
1180    #[cfg(test)]
1181    pub fn for_test(
1182        entries: impl IntoIterator<Item = FragmentId>,
1183        components: impl IntoIterator<Item = FragmentId>,
1184    ) -> Self {
1185        let entries = entries.into_iter().collect();
1186        let components = components.into_iter().collect();
1187        Self {
1188            entries,
1189            components,
1190        }
1191    }
1192
1193    pub fn fragments(&self) -> impl Iterator<Item = FragmentId> + '_ {
1194        self.components.iter().cloned()
1195    }
1196
1197    pub fn entry_fragments(&self) -> impl Iterator<Item = FragmentId> + '_ {
1198        self.entries.iter().copied()
1199    }
1200
1201    pub fn component_fragments(&self) -> impl Iterator<Item = FragmentId> + '_ {
1202        self.components.iter().copied()
1203    }
1204
1205    pub fn contains_entry(&self, fragment_id: &FragmentId) -> bool {
1206        self.entries.contains(fragment_id)
1207    }
1208}
1209
1210pub async fn find_fragment_no_shuffle_dags_detailed(
1211    db: &impl ConnectionTrait,
1212    initial_fragment_ids: &[FragmentId],
1213) -> MetaResult<Vec<NoShuffleEnsemble>> {
1214    let all_no_shuffle_relations: Vec<(_, _)> = FragmentRelation::find()
1215        .columns([
1216            fragment_relation::Column::SourceFragmentId,
1217            fragment_relation::Column::TargetFragmentId,
1218        ])
1219        .filter(fragment_relation::Column::DispatcherType.eq(DispatcherType::NoShuffle))
1220        .into_tuple()
1221        .all(db)
1222        .await?;
1223
1224    let (forward_edges, backward_edges) =
1225        build_no_shuffle_fragment_graph_edges(all_no_shuffle_relations);
1226
1227    find_no_shuffle_graphs(initial_fragment_ids, &forward_edges, &backward_edges)
1228}
1229
1230pub(crate) fn build_no_shuffle_fragment_graph_edges(
1231    relations: impl IntoIterator<Item = (FragmentId, FragmentId)>,
1232) -> (
1233    HashMap<FragmentId, Vec<FragmentId>>,
1234    HashMap<FragmentId, Vec<FragmentId>>,
1235) {
1236    let mut forward_edges: HashMap<FragmentId, HashSet<FragmentId>> = HashMap::new();
1237    let mut backward_edges: HashMap<FragmentId, HashSet<FragmentId>> = HashMap::new();
1238
1239    for (src, dst) in relations {
1240        forward_edges.entry(src).or_default().insert(dst);
1241        backward_edges.entry(dst).or_default().insert(src);
1242    }
1243
1244    let forward_edges = forward_edges
1245        .into_iter()
1246        .map(|(src, dst_set)| (src, dst_set.into_iter().collect()))
1247        .collect();
1248    let backward_edges = backward_edges
1249        .into_iter()
1250        .map(|(dst, src_set)| (dst, src_set.into_iter().collect()))
1251        .collect();
1252
1253    (forward_edges, backward_edges)
1254}
1255
1256pub(crate) fn find_no_shuffle_graphs(
1257    initial_fragment_ids: &[impl Into<FragmentId> + Copy],
1258    forward_edges: &HashMap<FragmentId, Vec<FragmentId>>,
1259    backward_edges: &HashMap<FragmentId, Vec<FragmentId>>,
1260) -> MetaResult<Vec<NoShuffleEnsemble>> {
1261    let mut graphs: Vec<NoShuffleEnsemble> = Vec::new();
1262    let mut globally_visited: HashSet<FragmentId> = HashSet::new();
1263
1264    for &init_id in initial_fragment_ids {
1265        let init_id = init_id.into();
1266        if globally_visited.contains(&init_id) {
1267            continue;
1268        }
1269
1270        // Found a new component. Traverse it to find all its nodes.
1271        let mut components = HashSet::new();
1272        let mut queue: VecDeque<FragmentId> = VecDeque::new();
1273
1274        queue.push_back(init_id);
1275        globally_visited.insert(init_id);
1276
1277        while let Some(current_id) = queue.pop_front() {
1278            components.insert(current_id);
1279            let neighbors = forward_edges
1280                .get(&current_id)
1281                .into_iter()
1282                .flatten()
1283                .chain(backward_edges.get(&current_id).into_iter().flatten());
1284
1285            for &neighbor_id in neighbors {
1286                if globally_visited.insert(neighbor_id) {
1287                    queue.push_back(neighbor_id);
1288                }
1289            }
1290        }
1291
1292        // For the newly found component, identify its roots.
1293        let mut entries = HashSet::new();
1294        for &node_id in &components {
1295            let is_root = match backward_edges.get(&node_id) {
1296                Some(parents) => parents.iter().all(|p| !components.contains(p)),
1297                None => true,
1298            };
1299            if is_root {
1300                entries.insert(node_id);
1301            }
1302        }
1303
1304        // Store the detailed DAG structure (roots, all nodes in this DAG).
1305        if !entries.is_empty() {
1306            graphs.push(NoShuffleEnsemble {
1307                entries,
1308                components,
1309            });
1310        }
1311    }
1312
1313    Ok(graphs)
1314}
1315
1316#[cfg(test)]
1317mod tests {
1318    use std::collections::{BTreeSet, HashMap, HashSet};
1319    use std::sync::Arc;
1320
1321    use risingwave_connector::source::SplitImpl;
1322    use risingwave_connector::source::test_source::TestSourceSplit;
1323    use risingwave_meta_model::{CreateType, JobStatus};
1324    use risingwave_pb::common::WorkerType;
1325    use risingwave_pb::common::worker_node::Property as WorkerProperty;
1326    use risingwave_pb::stream_plan::StreamNode as PbStreamNode;
1327
1328    use super::*;
1329
1330    // Helper type aliases for cleaner test code
1331    // Using the actual FragmentId type from the module
1332    type Edges = (
1333        HashMap<FragmentId, Vec<FragmentId>>,
1334        HashMap<FragmentId, Vec<FragmentId>>,
1335    );
1336
1337    /// A helper function to build forward and backward edge maps from a simple list of tuples.
1338    /// This reduces boilerplate in each test.
1339    fn build_edges(relations: &[(u32, u32)]) -> Edges {
1340        let mut forward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
1341        let mut backward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
1342        for &(src, dst) in relations {
1343            forward_edges
1344                .entry(src.into())
1345                .or_default()
1346                .push(dst.into());
1347            backward_edges
1348                .entry(dst.into())
1349                .or_default()
1350                .push(src.into());
1351        }
1352        (forward_edges, backward_edges)
1353    }
1354
1355    /// Helper function to create a `HashSet` from a slice easily.
1356    fn to_hashset(ids: &[u32]) -> HashSet<FragmentId> {
1357        ids.iter().map(|id| (*id).into()).collect()
1358    }
1359
    /// Construct a `LoadedFragment` with the given identity, distribution and
    /// parallelism, and defaulted stream node / state tables — sufficient for
    /// the rendering tests below.
    fn build_fragment(
        fragment_id: FragmentId,
        job_id: JobId,
        fragment_type_mask: i32,
        distribution_type: DistributionType,
        vnode_count: i32,
        parallelism: StreamingParallelism,
    ) -> LoadedFragment {
        LoadedFragment {
            fragment_id,
            job_id,
            fragment_type_mask: FragmentTypeMask::from(fragment_type_mask),
            distribution_type,
            // Tests only pass small non-negative counts, so the cast is safe.
            vnode_count: vnode_count as usize,
            nodes: PbStreamNode::default(),
            state_table_ids: HashSet::new(),
            parallelism: Some(parallelism),
        }
    }
1379
    // (relative actor index, worker, set vnode positions if any, split ids)
    type ActorState = (ActorId, WorkerId, Option<Vec<usize>>, Vec<String>);

    /// Normalize a rendered fragment's actors into a deterministic, comparable
    /// form: actor ids are rebased to 0 (relative to the smallest id in the
    /// fragment), vnode bitmaps become sorted position lists, and splits become
    /// their string ids. The result is sorted by the relative index.
    fn collect_actor_state(fragment: &InflightFragmentInfo) -> Vec<ActorState> {
        // Smallest actor id acts as the base so tests are independent of the
        // global actor-id counter's starting value.
        let base = fragment.actors.keys().copied().min().unwrap_or_default();

        let mut entries: Vec<_> = fragment
            .actors
            .iter()
            .map(|(&actor_id, info)| {
                let idx = actor_id.as_raw_id() - base.as_raw_id();
                // Bitmap -> indices of set bits; None means no bitmap assigned.
                let vnode_indices = info.vnode_bitmap.as_ref().map(|bitmap| {
                    bitmap
                        .iter()
                        .enumerate()
                        .filter_map(|(pos, is_set)| is_set.then_some(pos))
                        .collect::<Vec<_>>()
                });
                let splits = info
                    .splits
                    .iter()
                    .map(|split| split.id().to_string())
                    .collect::<Vec<_>>();
                (idx.into(), info.worker_id, vnode_indices, splits)
            })
            .collect();

        entries.sort_by_key(|(idx, _, _, _)| *idx);
        entries
    }
1409
    /// Build a streaming compute-node `WorkerNode` with the given parallelism
    /// and resource group; all other properties are defaulted.
    fn build_worker_node(
        id: impl Into<WorkerId>,
        parallelism: usize,
        resource_group: &str,
    ) -> WorkerNode {
        WorkerNode {
            id: id.into(),
            r#type: WorkerType::ComputeNode as i32,
            property: Some(WorkerProperty {
                is_streaming: true,
                parallelism: u32::try_from(parallelism).expect("parallelism fits into u32"),
                resource_group: Some(resource_group.to_owned()),
                ..Default::default()
            }),
            ..Default::default()
        }
    }
1427
1428    #[test]
1429    fn test_single_linear_chain() {
1430        // Scenario: A simple linear graph 1 -> 2 -> 3.
1431        // We start from the middle node (2).
1432        let (forward, backward) = build_edges(&[(1, 2), (2, 3)]);
1433        let initial_ids = &[2];
1434
1435        // Act
1436        let result = find_no_shuffle_graphs(initial_ids, &forward, &backward);
1437
1438        // Assert
1439        assert!(result.is_ok());
1440        let graphs = result.unwrap();
1441
1442        assert_eq!(graphs.len(), 1);
1443        let graph = &graphs[0];
1444        assert_eq!(graph.entries, to_hashset(&[1]));
1445        assert_eq!(graph.components, to_hashset(&[1, 2, 3]));
1446    }
1447
1448    #[test]
1449    fn test_two_disconnected_graphs() {
1450        // Scenario: Two separate graphs: 1->2 and 10->11.
1451        // We start with one node from each graph.
1452        let (forward, backward) = build_edges(&[(1, 2), (10, 11)]);
1453        let initial_ids = &[2, 10];
1454
1455        // Act
1456        let mut graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1457
1458        // Assert
1459        assert_eq!(graphs.len(), 2);
1460
1461        // Sort results to make the test deterministic, as HashMap iteration order is not guaranteed.
1462        graphs.sort_by_key(|g| *g.components.iter().min().unwrap_or(&0.into()));
1463
1464        // Graph 1
1465        assert_eq!(graphs[0].entries, to_hashset(&[1]));
1466        assert_eq!(graphs[0].components, to_hashset(&[1, 2]));
1467
1468        // Graph 2
1469        assert_eq!(graphs[1].entries, to_hashset(&[10]));
1470        assert_eq!(graphs[1].components, to_hashset(&[10, 11]));
1471    }
1472
1473    #[test]
1474    fn test_multiple_entries_in_one_graph() {
1475        // Scenario: A graph with two roots feeding into one node: 1->3, 2->3.
1476        let (forward, backward) = build_edges(&[(1, 3), (2, 3)]);
1477        let initial_ids = &[3];
1478
1479        // Act
1480        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1481
1482        // Assert
1483        assert_eq!(graphs.len(), 1);
1484        let graph = &graphs[0];
1485        assert_eq!(graph.entries, to_hashset(&[1, 2]));
1486        assert_eq!(graph.components, to_hashset(&[1, 2, 3]));
1487    }
1488
1489    #[test]
1490    fn test_diamond_shape_graph() {
1491        // Scenario: A diamond shape: 1->2, 1->3, 2->4, 3->4
1492        let (forward, backward) = build_edges(&[(1, 2), (1, 3), (2, 4), (3, 4)]);
1493        let initial_ids = &[4];
1494
1495        // Act
1496        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1497
1498        // Assert
1499        assert_eq!(graphs.len(), 1);
1500        let graph = &graphs[0];
1501        assert_eq!(graph.entries, to_hashset(&[1]));
1502        assert_eq!(graph.components, to_hashset(&[1, 2, 3, 4]));
1503    }
1504
1505    #[test]
1506    fn test_starting_with_multiple_nodes_in_same_graph() {
1507        // Scenario: Start with two different nodes (2 and 4) from the same component.
1508        // Should only identify one graph, not two.
1509        let (forward, backward) = build_edges(&[(1, 2), (2, 3), (3, 4)]);
1510        let initial_ids = &[2, 4];
1511
1512        // Act
1513        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1514
1515        // Assert
1516        assert_eq!(graphs.len(), 1);
1517        let graph = &graphs[0];
1518        assert_eq!(graph.entries, to_hashset(&[1]));
1519        assert_eq!(graph.components, to_hashset(&[1, 2, 3, 4]));
1520    }
1521
1522    #[test]
1523    fn test_empty_initial_ids() {
1524        // Scenario: The initial ID list is empty.
1525        let (forward, backward) = build_edges(&[(1, 2)]);
1526        let initial_ids: &[u32] = &[];
1527
1528        // Act
1529        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1530
1531        // Assert
1532        assert!(graphs.is_empty());
1533    }
1534
1535    #[test]
1536    fn test_isolated_node_as_input() {
1537        // Scenario: Start with an ID that has no relations.
1538        let (forward, backward) = build_edges(&[(1, 2)]);
1539        let initial_ids = &[100];
1540
1541        // Act
1542        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1543
1544        // Assert
1545        assert_eq!(graphs.len(), 1);
1546        let graph = &graphs[0];
1547        assert_eq!(graph.entries, to_hashset(&[100]));
1548        assert_eq!(graph.components, to_hashset(&[100]));
1549    }
1550
1551    #[test]
1552    fn test_graph_with_a_cycle() {
1553        // Scenario: A graph with a cycle: 1 -> 2 -> 3 -> 1.
1554        // The algorithm should correctly identify all nodes in the component.
1555        // Crucially, NO node is a root because every node has a parent *within the component*.
1556        // Therefore, the `entries` set should be empty, and the graph should not be included in the results.
1557        let (forward, backward) = build_edges(&[(1, 2), (2, 3), (3, 1)]);
1558        let initial_ids = &[2];
1559
1560        // Act
1561        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1562
1563        // Assert
1564        assert!(
1565            graphs.is_empty(),
1566            "A graph with no entries should not be returned"
1567        );
1568    }
1569    #[test]
1570    fn test_custom_complex() {
1571        let (forward, backward) = build_edges(&[(1, 3), (1, 8), (2, 3), (4, 3), (3, 5), (6, 7)]);
1572        let initial_ids = &[1, 2, 4, 6];
1573
1574        // Act
1575        let mut graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1576
1577        // Assert
1578        assert_eq!(graphs.len(), 2);
1579        // Sort results to make the test deterministic, as HashMap iteration order is not guaranteed.
1580        graphs.sort_by_key(|g| *g.components.iter().min().unwrap_or(&0.into()));
1581
1582        // Graph 1
1583        assert_eq!(graphs[0].entries, to_hashset(&[1, 2, 4]));
1584        assert_eq!(graphs[0].components, to_hashset(&[1, 2, 3, 4, 5, 8]));
1585
1586        // Graph 2
1587        assert_eq!(graphs[1].entries, to_hashset(&[6]));
1588        assert_eq!(graphs[1].components, to_hashset(&[6, 7]));
1589    }
1590
    #[test]
    fn render_actors_increments_actor_counter() {
        // Renders one Single-distribution fragment on a single matching worker
        // and verifies that the shared actor-id counter advances by the number
        // of rendered actors (here: one, so 100 -> 101).
        let actor_id_counter = AtomicU32::new(100);
        let fragment_id: FragmentId = 1.into();
        let job_id: JobId = 10.into();
        let database_id: DatabaseId = DatabaseId::new(3);

        // Single distribution, one vnode, fixed parallelism of 1.
        let fragment_model = build_fragment(
            fragment_id,
            job_id,
            0,
            DistributionType::Single,
            1,
            StreamingParallelism::Fixed(1),
        );

        let job_model = streaming_job::Model {
            job_id,
            job_status: JobStatus::Created,
            create_type: CreateType::Foreground,
            timezone: None,
            config_override: None,
            adaptive_parallelism_strategy: None,
            parallelism: StreamingParallelism::Fixed(1),
            backfill_parallelism: None,
            backfill_orders: None,
            max_parallelism: 1,
            specific_resource_group: None,
            is_serverless_backfill: false,
        };

        // Resource group "rg-a" must match the worker built below so the
        // worker is considered available for this database.
        let database_model = database::Model {
            database_id,
            name: "test_db".into(),
            resource_group: "rg-a".into(),
            barrier_interval_ms: None,
            checkpoint_frequency: None,
        };

        // The ensemble is the single fragment, which is also its own entry.
        let ensembles = vec![NoShuffleEnsemble {
            entries: HashSet::from([fragment_id]),
            components: HashSet::from([fragment_id]),
        }];

        let fragment_map = HashMap::from([(fragment_id, fragment_model)]);
        let job_fragments = HashMap::from([(job_id, fragment_map)]);
        let job_map = HashMap::from([(job_id, job_model)]);

        // One compute node with parallelism 1 in the matching resource group.
        let worker_map: HashMap<WorkerId, WorkerNode> =
            HashMap::from([(1.into(), build_worker_node(1, 1, "rg-a"))]);

        // No sources, splits, or backfilling jobs involved in this scenario.
        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
        let database_map = HashMap::from([(database_id, database_model)]);
        let backfill_jobs = HashSet::new();

        let context = RenderActorsContext {
            fragment_source_ids: &fragment_source_ids,
            fragment_splits: &fragment_splits,
            streaming_job_databases: &streaming_job_databases,
            database_map: &database_map,
            backfill_jobs: &backfill_jobs,
        };

        let result = render_actors(
            &actor_id_counter,
            &ensembles,
            &job_fragments,
            &job_map,
            &worker_map,
            AdaptiveParallelismStrategy::Auto,
            context,
        )
        .expect("actor rendering succeeds");

        // Exactly one actor with no vnode bitmap (Single distribution), and
        // the counter consumed exactly one id.
        let state = collect_actor_state(&result[&database_id][&job_id][&fragment_id]);
        assert_eq!(state.len(), 1);
        assert!(
            state[0].2.is_none(),
            "single distribution should not assign vnode bitmaps"
        );
        assert_eq!(actor_id_counter.load(Ordering::Relaxed), 101);
    }
1675
1676    #[test]
1677    fn render_actors_aligns_hash_vnode_bitmaps() {
1678        let actor_id_counter = AtomicU32::new(0);
1679        let entry_fragment_id: FragmentId = 1.into();
1680        let downstream_fragment_id: FragmentId = 2.into();
1681        let job_id: JobId = 20.into();
1682        let database_id: DatabaseId = DatabaseId::new(5);
1683
1684        let entry_fragment = build_fragment(
1685            entry_fragment_id,
1686            job_id,
1687            0,
1688            DistributionType::Hash,
1689            4,
1690            StreamingParallelism::Fixed(2),
1691        );
1692
1693        let downstream_fragment = build_fragment(
1694            downstream_fragment_id,
1695            job_id,
1696            0,
1697            DistributionType::Hash,
1698            4,
1699            StreamingParallelism::Fixed(2),
1700        );
1701
1702        let job_model = streaming_job::Model {
1703            job_id,
1704            job_status: JobStatus::Created,
1705            create_type: CreateType::Background,
1706            timezone: None,
1707            config_override: None,
1708            adaptive_parallelism_strategy: None,
1709            parallelism: StreamingParallelism::Fixed(2),
1710            backfill_parallelism: None,
1711            backfill_orders: None,
1712            max_parallelism: 2,
1713            specific_resource_group: None,
1714            is_serverless_backfill: false,
1715        };
1716
1717        let database_model = database::Model {
1718            database_id,
1719            name: "test_db_hash".into(),
1720            resource_group: "rg-hash".into(),
1721            barrier_interval_ms: None,
1722            checkpoint_frequency: None,
1723        };
1724
1725        let ensembles = vec![NoShuffleEnsemble {
1726            entries: HashSet::from([entry_fragment_id]),
1727            components: HashSet::from([entry_fragment_id, downstream_fragment_id]),
1728        }];
1729
1730        let fragment_map = HashMap::from([
1731            (entry_fragment_id, entry_fragment),
1732            (downstream_fragment_id, downstream_fragment),
1733        ]);
1734        let job_fragments = HashMap::from([(job_id, fragment_map)]);
1735        let job_map = HashMap::from([(job_id, job_model)]);
1736
1737        let worker_map: HashMap<WorkerId, WorkerNode> = HashMap::from([
1738            (1.into(), build_worker_node(1, 1, "rg-hash")),
1739            (2.into(), build_worker_node(2, 1, "rg-hash")),
1740        ]);
1741
1742        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
1743        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
1744        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
1745        let database_map = HashMap::from([(database_id, database_model)]);
1746        let backfill_jobs = HashSet::new();
1747
1748        let context = RenderActorsContext {
1749            fragment_source_ids: &fragment_source_ids,
1750            fragment_splits: &fragment_splits,
1751            streaming_job_databases: &streaming_job_databases,
1752            database_map: &database_map,
1753            backfill_jobs: &backfill_jobs,
1754        };
1755
1756        let result = render_actors(
1757            &actor_id_counter,
1758            &ensembles,
1759            &job_fragments,
1760            &job_map,
1761            &worker_map,
1762            AdaptiveParallelismStrategy::Auto,
1763            context,
1764        )
1765        .expect("actor rendering succeeds");
1766
1767        let entry_state = collect_actor_state(&result[&database_id][&job_id][&entry_fragment_id]);
1768        let downstream_state =
1769            collect_actor_state(&result[&database_id][&job_id][&downstream_fragment_id]);
1770
1771        assert_eq!(entry_state.len(), 2);
1772        assert_eq!(entry_state, downstream_state);
1773
1774        let assigned_vnodes: BTreeSet<_> = entry_state
1775            .iter()
1776            .flat_map(|(_, _, vnodes, _)| {
1777                vnodes
1778                    .as_ref()
1779                    .expect("hash distribution should populate vnode bitmap")
1780                    .iter()
1781                    .copied()
1782            })
1783            .collect();
1784        assert_eq!(assigned_vnodes, BTreeSet::from([0, 1, 2, 3]));
1785        assert_eq!(actor_id_counter.load(Ordering::Relaxed), 4);
1786    }
1787
1788    #[test]
1789    fn render_actors_propagates_source_splits() {
1790        let actor_id_counter = AtomicU32::new(0);
1791        let entry_fragment_id: FragmentId = 11.into();
1792        let downstream_fragment_id: FragmentId = 12.into();
1793        let job_id: JobId = 30.into();
1794        let database_id: DatabaseId = DatabaseId::new(7);
1795        let source_id: SourceId = 99.into();
1796
1797        let source_mask = FragmentTypeFlag::raw_flag([FragmentTypeFlag::Source]) as i32;
1798        let source_scan_mask = FragmentTypeFlag::raw_flag([FragmentTypeFlag::SourceScan]) as i32;
1799
1800        let entry_fragment = build_fragment(
1801            entry_fragment_id,
1802            job_id,
1803            source_mask,
1804            DistributionType::Hash,
1805            4,
1806            StreamingParallelism::Fixed(2),
1807        );
1808
1809        let downstream_fragment = build_fragment(
1810            downstream_fragment_id,
1811            job_id,
1812            source_scan_mask,
1813            DistributionType::Hash,
1814            4,
1815            StreamingParallelism::Fixed(2),
1816        );
1817
1818        let job_model = streaming_job::Model {
1819            job_id,
1820            job_status: JobStatus::Created,
1821            create_type: CreateType::Background,
1822            timezone: None,
1823            config_override: None,
1824            adaptive_parallelism_strategy: None,
1825            parallelism: StreamingParallelism::Fixed(2),
1826            backfill_parallelism: None,
1827            backfill_orders: None,
1828            max_parallelism: 2,
1829            specific_resource_group: None,
1830            is_serverless_backfill: false,
1831        };
1832
1833        let database_model = database::Model {
1834            database_id,
1835            name: "split_db".into(),
1836            resource_group: "rg-source".into(),
1837            barrier_interval_ms: None,
1838            checkpoint_frequency: None,
1839        };
1840
1841        let ensembles = vec![NoShuffleEnsemble {
1842            entries: HashSet::from([entry_fragment_id]),
1843            components: HashSet::from([entry_fragment_id, downstream_fragment_id]),
1844        }];
1845
1846        let fragment_map = HashMap::from([
1847            (entry_fragment_id, entry_fragment),
1848            (downstream_fragment_id, downstream_fragment),
1849        ]);
1850        let job_fragments = HashMap::from([(job_id, fragment_map)]);
1851        let job_map = HashMap::from([(job_id, job_model)]);
1852
1853        let worker_map: HashMap<WorkerId, WorkerNode> = HashMap::from([
1854            (1.into(), build_worker_node(1, 1, "rg-source")),
1855            (2.into(), build_worker_node(2, 1, "rg-source")),
1856        ]);
1857
1858        let split_a = SplitImpl::Test(TestSourceSplit {
1859            id: Arc::<str>::from("split-a"),
1860            properties: HashMap::new(),
1861            offset: "0".into(),
1862        });
1863        let split_b = SplitImpl::Test(TestSourceSplit {
1864            id: Arc::<str>::from("split-b"),
1865            properties: HashMap::new(),
1866            offset: "0".into(),
1867        });
1868
1869        let fragment_source_ids = HashMap::from([
1870            (entry_fragment_id, source_id),
1871            (downstream_fragment_id, source_id),
1872        ]);
1873        let fragment_splits =
1874            HashMap::from([(entry_fragment_id, vec![split_a.clone(), split_b.clone()])]);
1875        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
1876        let database_map = HashMap::from([(database_id, database_model)]);
1877        let backfill_jobs = HashSet::new();
1878
1879        let context = RenderActorsContext {
1880            fragment_source_ids: &fragment_source_ids,
1881            fragment_splits: &fragment_splits,
1882            streaming_job_databases: &streaming_job_databases,
1883            database_map: &database_map,
1884            backfill_jobs: &backfill_jobs,
1885        };
1886
1887        let result = render_actors(
1888            &actor_id_counter,
1889            &ensembles,
1890            &job_fragments,
1891            &job_map,
1892            &worker_map,
1893            AdaptiveParallelismStrategy::Auto,
1894            context,
1895        )
1896        .expect("actor rendering succeeds");
1897
1898        let entry_state = collect_actor_state(&result[&database_id][&job_id][&entry_fragment_id]);
1899        let downstream_state =
1900            collect_actor_state(&result[&database_id][&job_id][&downstream_fragment_id]);
1901
1902        assert_eq!(entry_state, downstream_state);
1903
1904        let split_ids: BTreeSet<_> = entry_state
1905            .iter()
1906            .flat_map(|(_, _, _, splits)| splits.iter().cloned())
1907            .collect();
1908        assert_eq!(
1909            split_ids,
1910            BTreeSet::from([split_a.id().to_string(), split_b.id().to_string()])
1911        );
1912        assert_eq!(actor_id_counter.load(Ordering::Relaxed), 4);
1913    }
1914
1915    /// Test that job-level strategy overrides global strategy for Adaptive parallelism.
1916    #[test]
1917    fn render_actors_job_strategy_overrides_global() {
1918        let actor_id_counter = AtomicU32::new(0);
1919        let fragment_id: FragmentId = 1.into();
1920        let job_id: JobId = 100.into();
1921        let database_id: DatabaseId = DatabaseId::new(10);
1922
1923        // Fragment with Adaptive parallelism, vnode_count = 8
1924        let fragment_model = build_fragment(
1925            fragment_id,
1926            job_id,
1927            0,
1928            DistributionType::Hash,
1929            8,
1930            StreamingParallelism::Adaptive,
1931        );
1932
1933        // Job has custom strategy: BOUNDED(2)
1934        let job_model = streaming_job::Model {
1935            job_id,
1936            job_status: JobStatus::Created,
1937            create_type: CreateType::Foreground,
1938            timezone: None,
1939            config_override: None,
1940            adaptive_parallelism_strategy: Some("BOUNDED(2)".to_owned()),
1941            parallelism: StreamingParallelism::Adaptive,
1942            backfill_parallelism: None,
1943            backfill_orders: None,
1944            max_parallelism: 8,
1945            specific_resource_group: None,
1946            is_serverless_backfill: false,
1947        };
1948
1949        let database_model = database::Model {
1950            database_id,
1951            name: "test_db".into(),
1952            resource_group: "default".into(),
1953            barrier_interval_ms: None,
1954            checkpoint_frequency: None,
1955        };
1956
1957        let ensembles = vec![NoShuffleEnsemble {
1958            entries: HashSet::from([fragment_id]),
1959            components: HashSet::from([fragment_id]),
1960        }];
1961
1962        let fragment_map =
1963            HashMap::from([(job_id, HashMap::from([(fragment_id, fragment_model)]))]);
1964        let job_map = HashMap::from([(job_id, job_model)]);
1965
1966        // 4 workers with 1 parallelism each = total 4 parallelism
1967        let worker_map = HashMap::from([
1968            (1.into(), build_worker_node(1, 1, "default")),
1969            (2.into(), build_worker_node(2, 1, "default")),
1970            (3.into(), build_worker_node(3, 1, "default")),
1971            (4.into(), build_worker_node(4, 1, "default")),
1972        ]);
1973
1974        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
1975        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
1976        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
1977        let database_map = HashMap::from([(database_id, database_model)]);
1978        let backfill_jobs = HashSet::new();
1979
1980        let context = RenderActorsContext {
1981            fragment_source_ids: &fragment_source_ids,
1982            fragment_splits: &fragment_splits,
1983            streaming_job_databases: &streaming_job_databases,
1984            database_map: &database_map,
1985            backfill_jobs: &backfill_jobs,
1986        };
1987
1988        // Global strategy is FULL (would give 4 actors), but job strategy is BOUNDED(2)
1989        let result = render_actors(
1990            &actor_id_counter,
1991            &ensembles,
1992            &fragment_map,
1993            &job_map,
1994            &worker_map,
1995            AdaptiveParallelismStrategy::Full,
1996            context,
1997        )
1998        .expect("actor rendering succeeds");
1999
2000        let state = collect_actor_state(&result[&database_id][&job_id][&fragment_id]);
2001        // Job strategy BOUNDED(2) should limit to 2 actors, not 4 (global FULL)
2002        assert_eq!(
2003            state.len(),
2004            2,
2005            "Job strategy BOUNDED(2) should override global FULL"
2006        );
2007    }
2008
2009    /// Test that global strategy is used when job has no custom strategy.
2010    #[test]
2011    fn render_actors_uses_global_strategy_when_job_has_none() {
2012        let actor_id_counter = AtomicU32::new(0);
2013        let fragment_id: FragmentId = 1.into();
2014        let job_id: JobId = 101.into();
2015        let database_id: DatabaseId = DatabaseId::new(11);
2016
2017        let fragment_model = build_fragment(
2018            fragment_id,
2019            job_id,
2020            0,
2021            DistributionType::Hash,
2022            8,
2023            StreamingParallelism::Adaptive,
2024        );
2025
2026        // Job has NO custom strategy (None)
2027        let job_model = streaming_job::Model {
2028            job_id,
2029            job_status: JobStatus::Created,
2030            create_type: CreateType::Foreground,
2031            timezone: None,
2032            config_override: None,
2033            adaptive_parallelism_strategy: None, // No custom strategy
2034            parallelism: StreamingParallelism::Adaptive,
2035            backfill_parallelism: None,
2036            backfill_orders: None,
2037            max_parallelism: 8,
2038            specific_resource_group: None,
2039            is_serverless_backfill: false,
2040        };
2041
2042        let database_model = database::Model {
2043            database_id,
2044            name: "test_db".into(),
2045            resource_group: "default".into(),
2046            barrier_interval_ms: None,
2047            checkpoint_frequency: None,
2048        };
2049
2050        let ensembles = vec![NoShuffleEnsemble {
2051            entries: HashSet::from([fragment_id]),
2052            components: HashSet::from([fragment_id]),
2053        }];
2054
2055        let fragment_map =
2056            HashMap::from([(job_id, HashMap::from([(fragment_id, fragment_model)]))]);
2057        let job_map = HashMap::from([(job_id, job_model)]);
2058
2059        // 4 workers = total 4 parallelism
2060        let worker_map = HashMap::from([
2061            (1.into(), build_worker_node(1, 1, "default")),
2062            (2.into(), build_worker_node(2, 1, "default")),
2063            (3.into(), build_worker_node(3, 1, "default")),
2064            (4.into(), build_worker_node(4, 1, "default")),
2065        ]);
2066
2067        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
2068        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
2069        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
2070        let database_map = HashMap::from([(database_id, database_model)]);
2071        let backfill_jobs = HashSet::new();
2072
2073        let context = RenderActorsContext {
2074            fragment_source_ids: &fragment_source_ids,
2075            fragment_splits: &fragment_splits,
2076            streaming_job_databases: &streaming_job_databases,
2077            database_map: &database_map,
2078            backfill_jobs: &backfill_jobs,
2079        };
2080
2081        // Global strategy is BOUNDED(3)
2082        let result = render_actors(
2083            &actor_id_counter,
2084            &ensembles,
2085            &fragment_map,
2086            &job_map,
2087            &worker_map,
2088            AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(3).unwrap()),
2089            context,
2090        )
2091        .expect("actor rendering succeeds");
2092
2093        let state = collect_actor_state(&result[&database_id][&job_id][&fragment_id]);
2094        // Should use global strategy BOUNDED(3)
2095        assert_eq!(
2096            state.len(),
2097            3,
2098            "Should use global strategy BOUNDED(3) when job has no custom strategy"
2099        );
2100    }
2101
2102    /// Test that Fixed parallelism ignores strategy entirely.
2103    #[test]
2104    fn render_actors_fixed_parallelism_ignores_strategy() {
2105        let actor_id_counter = AtomicU32::new(0);
2106        let fragment_id: FragmentId = 1.into();
2107        let job_id: JobId = 102.into();
2108        let database_id: DatabaseId = DatabaseId::new(12);
2109
2110        // Fragment with FIXED parallelism
2111        let fragment_model = build_fragment(
2112            fragment_id,
2113            job_id,
2114            0,
2115            DistributionType::Hash,
2116            8,
2117            StreamingParallelism::Fixed(5),
2118        );
2119
2120        // Job has custom strategy, but it should be ignored for Fixed parallelism
2121        let job_model = streaming_job::Model {
2122            job_id,
2123            job_status: JobStatus::Created,
2124            create_type: CreateType::Foreground,
2125            timezone: None,
2126            config_override: None,
2127            adaptive_parallelism_strategy: Some("BOUNDED(2)".to_owned()),
2128            parallelism: StreamingParallelism::Fixed(5),
2129            backfill_parallelism: None,
2130            backfill_orders: None,
2131            max_parallelism: 8,
2132            specific_resource_group: None,
2133            is_serverless_backfill: false,
2134        };
2135
2136        let database_model = database::Model {
2137            database_id,
2138            name: "test_db".into(),
2139            resource_group: "default".into(),
2140            barrier_interval_ms: None,
2141            checkpoint_frequency: None,
2142        };
2143
2144        let ensembles = vec![NoShuffleEnsemble {
2145            entries: HashSet::from([fragment_id]),
2146            components: HashSet::from([fragment_id]),
2147        }];
2148
2149        let fragment_map =
2150            HashMap::from([(job_id, HashMap::from([(fragment_id, fragment_model)]))]);
2151        let job_map = HashMap::from([(job_id, job_model)]);
2152
2153        // 6 workers = total 6 parallelism
2154        let worker_map = HashMap::from([
2155            (1.into(), build_worker_node(1, 1, "default")),
2156            (2.into(), build_worker_node(2, 1, "default")),
2157            (3.into(), build_worker_node(3, 1, "default")),
2158            (4.into(), build_worker_node(4, 1, "default")),
2159            (5.into(), build_worker_node(5, 1, "default")),
2160            (6.into(), build_worker_node(6, 1, "default")),
2161        ]);
2162
2163        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
2164        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
2165        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
2166        let database_map = HashMap::from([(database_id, database_model)]);
2167        let backfill_jobs = HashSet::new();
2168
2169        let context = RenderActorsContext {
2170            fragment_source_ids: &fragment_source_ids,
2171            fragment_splits: &fragment_splits,
2172            streaming_job_databases: &streaming_job_databases,
2173            database_map: &database_map,
2174            backfill_jobs: &backfill_jobs,
2175        };
2176
2177        let result = render_actors(
2178            &actor_id_counter,
2179            &ensembles,
2180            &fragment_map,
2181            &job_map,
2182            &worker_map,
2183            AdaptiveParallelismStrategy::Full,
2184            context,
2185        )
2186        .expect("actor rendering succeeds");
2187
2188        let state = collect_actor_state(&result[&database_id][&job_id][&fragment_id]);
2189        // Fixed(5) should be used, ignoring both job strategy BOUNDED(2) and global FULL
2190        assert_eq!(
2191            state.len(),
2192            5,
2193            "Fixed parallelism should ignore all strategies"
2194        );
2195    }
2196
2197    /// Test RATIO strategy calculation.
2198    #[test]
2199    fn render_actors_ratio_strategy() {
2200        let actor_id_counter = AtomicU32::new(0);
2201        let fragment_id: FragmentId = 1.into();
2202        let job_id: JobId = 103.into();
2203        let database_id: DatabaseId = DatabaseId::new(13);
2204
2205        let fragment_model = build_fragment(
2206            fragment_id,
2207            job_id,
2208            0,
2209            DistributionType::Hash,
2210            16,
2211            StreamingParallelism::Adaptive,
2212        );
2213
2214        // Job has RATIO(0.5) strategy
2215        let job_model = streaming_job::Model {
2216            job_id,
2217            job_status: JobStatus::Created,
2218            create_type: CreateType::Foreground,
2219            timezone: None,
2220            config_override: None,
2221            adaptive_parallelism_strategy: Some("RATIO(0.5)".to_owned()),
2222            parallelism: StreamingParallelism::Adaptive,
2223            backfill_parallelism: None,
2224            backfill_orders: None,
2225            max_parallelism: 16,
2226            specific_resource_group: None,
2227            is_serverless_backfill: false,
2228        };
2229
2230        let database_model = database::Model {
2231            database_id,
2232            name: "test_db".into(),
2233            resource_group: "default".into(),
2234            barrier_interval_ms: None,
2235            checkpoint_frequency: None,
2236        };
2237
2238        let ensembles = vec![NoShuffleEnsemble {
2239            entries: HashSet::from([fragment_id]),
2240            components: HashSet::from([fragment_id]),
2241        }];
2242
2243        let fragment_map =
2244            HashMap::from([(job_id, HashMap::from([(fragment_id, fragment_model)]))]);
2245        let job_map = HashMap::from([(job_id, job_model)]);
2246
2247        // 8 workers = total 8 parallelism
2248        let worker_map = HashMap::from([
2249            (1.into(), build_worker_node(1, 1, "default")),
2250            (2.into(), build_worker_node(2, 1, "default")),
2251            (3.into(), build_worker_node(3, 1, "default")),
2252            (4.into(), build_worker_node(4, 1, "default")),
2253            (5.into(), build_worker_node(5, 1, "default")),
2254            (6.into(), build_worker_node(6, 1, "default")),
2255            (7.into(), build_worker_node(7, 1, "default")),
2256            (8.into(), build_worker_node(8, 1, "default")),
2257        ]);
2258
2259        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
2260        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
2261        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
2262        let database_map = HashMap::from([(database_id, database_model)]);
2263        let backfill_jobs = HashSet::new();
2264
2265        let context = RenderActorsContext {
2266            fragment_source_ids: &fragment_source_ids,
2267            fragment_splits: &fragment_splits,
2268            streaming_job_databases: &streaming_job_databases,
2269            database_map: &database_map,
2270            backfill_jobs: &backfill_jobs,
2271        };
2272
2273        let result = render_actors(
2274            &actor_id_counter,
2275            &ensembles,
2276            &fragment_map,
2277            &job_map,
2278            &worker_map,
2279            AdaptiveParallelismStrategy::Full,
2280            context,
2281        )
2282        .expect("actor rendering succeeds");
2283
2284        let state = collect_actor_state(&result[&database_id][&job_id][&fragment_id]);
2285        // RATIO(0.5) of 8 = 4
2286        assert_eq!(
2287            state.len(),
2288            4,
2289            "RATIO(0.5) of 8 workers should give 4 actors"
2290        );
2291    }
2292}