risingwave_meta/controller/
scale.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
16use std::num::NonZeroUsize;
17use std::sync::atomic::{AtomicU32, Ordering};
18
19use anyhow::anyhow;
20use itertools::Itertools;
21use risingwave_common::bitmap::Bitmap;
22use risingwave_common::catalog::{FragmentTypeFlag, FragmentTypeMask, TableId};
23use risingwave_common::id::JobId;
24use risingwave_common::system_param::AdaptiveParallelismStrategy;
25use risingwave_common::util::worker_util::DEFAULT_RESOURCE_GROUP;
26use risingwave_connector::source::{SplitId, SplitImpl, SplitMetaData};
27use risingwave_meta_model::fragment::DistributionType;
28use risingwave_meta_model::prelude::{
29    Database, Fragment, FragmentRelation, FragmentSplits, Sink, Source, StreamingJob, Table,
30};
31use risingwave_meta_model::{
32    DatabaseId, DispatcherType, FragmentId, JobStatus, SourceId, StreamingParallelism, WorkerId,
33    database, fragment, fragment_relation, fragment_splits, object, sink, source, streaming_job,
34    table,
35};
36use risingwave_meta_model_migration::Condition;
37use risingwave_pb::common::WorkerNode;
38use risingwave_pb::stream_plan::PbStreamNode;
39use sea_orm::{
40    ColumnTrait, ConnectionTrait, EntityTrait, JoinType, QueryFilter, QuerySelect, QueryTrait,
41    RelationTrait,
42};
43
44use crate::MetaResult;
45use crate::controller::fragment::{InflightActorInfo, InflightFragmentInfo};
46use crate::manager::ActiveStreamingWorkerNodes;
47use crate::model::{ActorId, StreamActor, StreamingJobModelContextExt};
48use crate::stream::{AssignerBuilder, SplitDiffOptions};
49
50pub(crate) async fn resolve_streaming_job_definition<C>(
51    txn: &C,
52    job_ids: &HashSet<JobId>,
53) -> MetaResult<HashMap<JobId, String>>
54where
55    C: ConnectionTrait,
56{
57    let job_ids = job_ids.iter().cloned().collect_vec();
58
59    // including table, materialized view, index
60    let common_job_definitions: Vec<(JobId, String)> = Table::find()
61        .select_only()
62        .columns([
63            table::Column::TableId,
64            #[cfg(not(debug_assertions))]
65            table::Column::Name,
66            #[cfg(debug_assertions)]
67            table::Column::Definition,
68        ])
69        .filter(table::Column::TableId.is_in(job_ids.clone()))
70        .into_tuple()
71        .all(txn)
72        .await?;
73
74    let sink_definitions: Vec<(JobId, String)> = Sink::find()
75        .select_only()
76        .columns([
77            sink::Column::SinkId,
78            #[cfg(not(debug_assertions))]
79            sink::Column::Name,
80            #[cfg(debug_assertions)]
81            sink::Column::Definition,
82        ])
83        .filter(sink::Column::SinkId.is_in(job_ids.clone()))
84        .into_tuple()
85        .all(txn)
86        .await?;
87
88    let source_definitions: Vec<(JobId, String)> = Source::find()
89        .select_only()
90        .columns([
91            source::Column::SourceId,
92            #[cfg(not(debug_assertions))]
93            source::Column::Name,
94            #[cfg(debug_assertions)]
95            source::Column::Definition,
96        ])
97        .filter(source::Column::SourceId.is_in(job_ids.clone()))
98        .into_tuple()
99        .all(txn)
100        .await?;
101
102    let definitions: HashMap<JobId, String> = common_job_definitions
103        .into_iter()
104        .chain(sink_definitions.into_iter())
105        .chain(source_definitions.into_iter())
106        .collect();
107
108    Ok(definitions)
109}
110
111pub async fn load_fragment_info<C>(
112    txn: &C,
113    actor_id_counter: &AtomicU32,
114    database_id: Option<DatabaseId>,
115    worker_nodes: &ActiveStreamingWorkerNodes,
116    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
117) -> MetaResult<FragmentRenderMap>
118where
119    C: ConnectionTrait,
120{
121    let mut query = StreamingJob::find()
122        .select_only()
123        .column(streaming_job::Column::JobId);
124
125    if let Some(database_id) = database_id {
126        query = query
127            .join(JoinType::InnerJoin, streaming_job::Relation::Object.def())
128            .filter(object::Column::DatabaseId.eq(database_id));
129    }
130
131    let jobs: Vec<JobId> = query.into_tuple().all(txn).await?;
132
133    if jobs.is_empty() {
134        return Ok(HashMap::new());
135    }
136
137    let jobs: HashSet<JobId> = jobs.into_iter().collect();
138
139    let loaded = load_fragment_context_for_jobs(txn, jobs).await?;
140
141    if loaded.is_empty() {
142        return Ok(HashMap::new());
143    }
144
145    let RenderedGraph { fragments, .. } = render_actor_assignments(
146        actor_id_counter,
147        worker_nodes.current(),
148        adaptive_parallelism_strategy,
149        &loaded,
150    )?;
151
152    Ok(fragments)
153}
154
/// Desired scheduling target for a streaming job: the resource group its
/// actors should run in and the parallelism to apply.
#[derive(Debug)]
pub struct TargetResourcePolicy {
    // Target resource group; `None` presumably means "inherit the default /
    // database-level group" — confirm at the call sites of this policy.
    pub resource_group: Option<String>,
    // Target streaming parallelism setting.
    pub parallelism: StreamingParallelism,
}
160
/// Scheduling-relevant view of a single worker node.
#[derive(Debug, Clone)]
pub struct WorkerInfo {
    // Number of parallel units on the worker (guaranteed non-zero by type).
    pub parallelism: NonZeroUsize,
    // Resource group the worker belongs to; `None` likely means the default
    // group (cf. `DEFAULT_RESOURCE_GROUP` usage elsewhere in this file) — confirm.
    pub resource_group: Option<String>,
}
166
/// Rendered fragment infos keyed by database id → job id → fragment id.
pub type FragmentRenderMap =
    HashMap<DatabaseId, HashMap<JobId, HashMap<FragmentId, InflightFragmentInfo>>>;
169
/// Output of the render stage: the rendered fragment map together with the
/// no-shuffle ensembles it was derived from.
#[derive(Default)]
pub struct RenderedGraph {
    // Rendered fragments, keyed database → job → fragment.
    pub fragments: FragmentRenderMap,
    // The no-shuffle ensembles that produced `fragments`.
    pub ensembles: Vec<NoShuffleEnsemble>,
}

impl RenderedGraph {
    /// Returns a graph with no fragments and no ensembles.
    pub fn empty() -> Self {
        Self::default()
    }
}
181
/// Context loaded asynchronously from database, containing all metadata
/// required to render actor assignments. This separates async I/O from
/// sync rendering logic.
#[derive(Clone, Debug)]
pub struct LoadedFragment {
    pub fragment_id: FragmentId,
    // Streaming job that owns this fragment.
    pub job_id: JobId,
    pub fragment_type_mask: FragmentTypeMask,
    pub distribution_type: DistributionType,
    // Number of virtual nodes for this fragment's distribution.
    pub vnode_count: usize,
    // Decoded stream plan of the fragment.
    pub nodes: PbStreamNode,
    // State tables owned by this fragment.
    pub state_table_ids: HashSet<TableId>,
    // Fragment-level parallelism; `None` presumably falls back to the
    // job-level setting — confirm against `EnsembleActorTemplate::render_new`.
    pub parallelism: Option<StreamingParallelism>,
}
196
197impl From<fragment::Model> for LoadedFragment {
198    fn from(model: fragment::Model) -> Self {
199        Self {
200            fragment_id: model.fragment_id,
201            job_id: model.job_id,
202            fragment_type_mask: FragmentTypeMask::from(model.fragment_type_mask),
203            distribution_type: model.distribution_type,
204            vnode_count: model.vnode_count as usize,
205            nodes: model.stream_node.to_protobuf(),
206            state_table_ids: model.state_table_ids.into_inner().into_iter().collect(),
207            parallelism: model.parallelism,
208        }
209    }
210}
211
/// Everything the synchronous renderer needs, fetched up-front from the
/// database so that rendering itself performs no I/O.
#[derive(Default, Debug, Clone)]
pub struct LoadedFragmentContext {
    // No-shuffle ensembles to render; each references fragment ids in `job_fragments`.
    pub ensembles: Vec<NoShuffleEnsemble>,
    // Fragments grouped by owning streaming job.
    pub job_fragments: HashMap<JobId, HashMap<FragmentId, LoadedFragment>>,
    // Catalog rows of the owning streaming jobs.
    pub job_map: HashMap<JobId, streaming_job::Model>,
    // Owning database of each streaming job.
    pub streaming_job_databases: HashMap<JobId, DatabaseId>,
    // Catalog rows of the involved databases (e.g. for resource groups).
    pub database_map: HashMap<DatabaseId, database::Model>,
    // Source id per source fragment.
    pub fragment_source_ids: HashMap<FragmentId, SourceId>,
    // Currently-assigned connector splits per source fragment.
    pub fragment_splits: HashMap<FragmentId, Vec<SplitImpl>>,
}
222
impl LoadedFragmentContext {
    /// Returns `true` when the context holds no ensembles (and therefore nothing
    /// to render).
    ///
    /// Invariant: an empty ensemble list implies an empty `job_fragments` map;
    /// this is asserted rather than silently tolerated.
    pub fn is_empty(&self) -> bool {
        if self.ensembles.is_empty() {
            assert!(
                self.job_fragments.is_empty(),
                "non-empty job fragments for empty ensembles: {:?}",
                self.job_fragments
            );
            true
        } else {
            false
        }
    }

    /// Projects this context down to a single database by cloning only the
    /// entries that belong to `database_id`.
    ///
    /// Returns `None` when the database owns no streaming jobs in this context.
    /// Panics (asserts) if an ensemble only partially overlaps the database's
    /// fragments, since an ensemble must never straddle databases.
    pub fn for_database(&self, database_id: DatabaseId) -> Option<Self> {
        // Jobs that belong to the requested database.
        let job_ids: HashSet<JobId> = self
            .streaming_job_databases
            .iter()
            .filter_map(|(job_id, db_id)| (*db_id == database_id).then_some(*job_id))
            .collect();

        if job_ids.is_empty() {
            return None;
        }

        let job_fragments: HashMap<_, _> = job_ids
            .iter()
            .map(|job_id| (*job_id, self.job_fragments[job_id].clone()))
            .collect();

        // Flat set of fragment ids owned by those jobs; used below to filter
        // ensembles and the fragment-keyed maps.
        let fragment_ids: HashSet<_> = job_fragments
            .values()
            .flat_map(|fragments| fragments.keys().copied())
            .collect();

        assert!(
            !fragment_ids.is_empty(),
            "empty fragments for non-empty database {database_id} with jobs {job_ids:?}"
        );

        // Keep only ensembles fully contained in this database; partial overlap
        // would mean a no-shuffle edge crosses databases, which is a bug.
        let ensembles: Vec<NoShuffleEnsemble> = self
            .ensembles
            .iter()
            .filter(|ensemble| {
                if ensemble
                    .components
                    .iter()
                    .any(|fragment_id| fragment_ids.contains(fragment_id))
                {
                    assert!(
                        ensemble
                            .components
                            .iter()
                            .all(|fragment_id| fragment_ids.contains(fragment_id)),
                        "ensemble {ensemble:?} partially exists in database {database_id} with fragments {job_fragments:?}"
                    );
                    true
                } else {
                    false
                }
            })
            .cloned()
            .collect();

        assert!(
            !ensembles.is_empty(),
            "empty ensembles for non-empty database {database_id} with jobs {job_fragments:?}"
        );

        let job_map = job_ids
            .iter()
            .filter_map(|job_id| self.job_map.get(job_id).map(|job| (*job_id, job.clone())))
            .collect();

        let streaming_job_databases = job_ids
            .iter()
            .filter_map(|job_id| {
                self.streaming_job_databases
                    .get(job_id)
                    .map(|db_id| (*job_id, *db_id))
            })
            .collect();

        // The projected context only needs this one database's catalog row.
        let database_model = self.database_map[&database_id].clone();
        let database_map = HashMap::from([(database_id, database_model)]);

        let fragment_source_ids = self
            .fragment_source_ids
            .iter()
            .filter(|(fragment_id, _)| fragment_ids.contains(*fragment_id))
            .map(|(fragment_id, source_id)| (*fragment_id, *source_id))
            .collect();

        let fragment_splits = self
            .fragment_splits
            .iter()
            .filter(|(fragment_id, _)| fragment_ids.contains(*fragment_id))
            .map(|(fragment_id, splits)| (*fragment_id, splits.clone()))
            .collect();

        Some(Self {
            ensembles,
            job_fragments,
            job_map,
            streaming_job_databases,
            database_map,
            fragment_source_ids,
            fragment_splits,
        })
    }

    /// Split this loaded context by database in a single ownership pass to avoid cloning large
    /// fragment payloads (for example `stream_node` in `LoadedFragment`).
    pub fn into_database_contexts(self) -> HashMap<DatabaseId, Self> {
        let Self {
            ensembles,
            mut job_fragments,
            mut job_map,
            streaming_job_databases,
            mut database_map,
            mut fragment_source_ids,
            mut fragment_splits,
        } = self;

        let mut contexts = HashMap::<DatabaseId, Self>::new();
        // Reverse index built while distributing jobs; used afterwards to route
        // each ensemble to its database.
        let mut fragment_databases = HashMap::<FragmentId, DatabaseId>::new();
        let mut unresolved_ensembles = 0usize;
        let mut unresolved_ensemble_sample: Option<Vec<FragmentId>> = None;

        // Pass 1: move each job (and its fragment-keyed metadata) into the
        // per-database context, creating the context on first use.
        for (job_id, database_id) in streaming_job_databases {
            let context = contexts.entry(database_id).or_insert_with(|| {
                let database_model = database_map
                    .remove(&database_id)
                    .expect("database should exist for streaming job");
                Self {
                    ensembles: Vec::new(),
                    job_fragments: HashMap::new(),
                    job_map: HashMap::new(),
                    streaming_job_databases: HashMap::new(),
                    database_map: HashMap::from([(database_id, database_model)]),
                    fragment_source_ids: HashMap::new(),
                    fragment_splits: HashMap::new(),
                }
            });

            let fragments = job_fragments
                .remove(&job_id)
                .expect("job fragments should exist for streaming job");
            for fragment_id in fragments.keys().copied() {
                fragment_databases.insert(fragment_id, database_id);
                if let Some(source_id) = fragment_source_ids.remove(&fragment_id) {
                    context.fragment_source_ids.insert(fragment_id, source_id);
                }
                if let Some(splits) = fragment_splits.remove(&fragment_id) {
                    context.fragment_splits.insert(fragment_id, splits);
                }
            }

            assert!(
                context
                    .job_map
                    .insert(
                        job_id,
                        job_map
                            .remove(&job_id)
                            .expect("streaming job should exist for loaded context"),
                    )
                    .is_none(),
                "duplicated streaming job"
            );
            assert!(
                context.job_fragments.insert(job_id, fragments).is_none(),
                "duplicated job fragments"
            );
            assert!(
                context
                    .streaming_job_databases
                    .insert(job_id, database_id)
                    .is_none(),
                "duplicated job database mapping"
            );
        }

        // Pass 2: route each ensemble to the database of any of its fragments
        // (they must all agree — checked in debug builds).
        for ensemble in ensembles {
            let Some(database_id) = ensemble
                .components
                .iter()
                .find_map(|fragment_id| fragment_databases.get(fragment_id).copied())
            else {
                // No fragment of this ensemble was seen in pass 1; count and
                // remember one sample for the warning below.
                unresolved_ensembles += 1;
                if unresolved_ensemble_sample.is_none() {
                    unresolved_ensemble_sample =
                        Some(ensemble.components.iter().copied().collect());
                }
                continue;
            };

            debug_assert!(
                ensemble
                    .components
                    .iter()
                    .all(|fragment_id| fragment_databases.get(fragment_id) == Some(&database_id)),
                "ensemble {ensemble:?} should belong to a single database"
            );

            contexts
                .get_mut(&database_id)
                .expect("database context should exist for ensemble")
                .ensembles
                .push(ensemble);
        }

        // Unresolvable ensembles indicate inconsistent input: warn in release,
        // fail fast in debug builds.
        if unresolved_ensembles > 0 {
            tracing::warn!(
                unresolved_ensembles,
                ?unresolved_ensemble_sample,
                known_fragments = fragment_databases.len(),
                "skip ensembles without resolved database while splitting loaded context"
            );
        }
        debug_assert_eq!(
            unresolved_ensembles, 0,
            "all ensembles should be mappable to a database"
        );

        contexts
    }
}
451
452/// Async load stage for fragment-scoped rendering. It resolves all metadata required to later
453/// render actor assignments with arbitrary worker sets.
454pub async fn load_fragment_context<C>(
455    txn: &C,
456    ensembles: Vec<NoShuffleEnsemble>,
457) -> MetaResult<LoadedFragmentContext>
458where
459    C: ConnectionTrait,
460{
461    if ensembles.is_empty() {
462        return Ok(LoadedFragmentContext::default());
463    }
464
465    let required_fragment_ids: HashSet<_> = ensembles
466        .iter()
467        .flat_map(|ensemble| ensemble.components.iter().copied())
468        .collect();
469
470    let fragment_models = Fragment::find()
471        .filter(fragment::Column::FragmentId.is_in(required_fragment_ids.iter().copied()))
472        .all(txn)
473        .await?;
474
475    let found_fragment_ids: HashSet<_> = fragment_models
476        .iter()
477        .map(|fragment| fragment.fragment_id)
478        .collect();
479
480    if found_fragment_ids.len() != required_fragment_ids.len() {
481        let missing = required_fragment_ids
482            .difference(&found_fragment_ids)
483            .copied()
484            .collect_vec();
485        return Err(anyhow!("fragments {:?} not found", missing).into());
486    }
487
488    let fragment_models: HashMap<_, _> = fragment_models
489        .into_iter()
490        .map(|fragment| (fragment.fragment_id, fragment))
491        .collect();
492
493    let job_ids: HashSet<_> = fragment_models
494        .values()
495        .map(|fragment| fragment.job_id)
496        .collect();
497
498    if job_ids.is_empty() {
499        return Ok(LoadedFragmentContext::default());
500    }
501
502    let jobs: HashMap<_, _> = StreamingJob::find()
503        .filter(streaming_job::Column::JobId.is_in(job_ids.iter().copied().collect_vec()))
504        .all(txn)
505        .await?
506        .into_iter()
507        .map(|job| (job.job_id, job))
508        .collect();
509
510    let found_job_ids: HashSet<_> = jobs.keys().copied().collect();
511    if found_job_ids.len() != job_ids.len() {
512        let missing = job_ids.difference(&found_job_ids).copied().collect_vec();
513        return Err(anyhow!("streaming jobs {:?} not found", missing).into());
514    }
515
516    build_loaded_context(txn, ensembles, fragment_models, jobs).await
517}
518
/// Async load stage for job-scoped rendering. It collects all no-shuffle ensembles and the
/// metadata required to render actor assignments later with a provided worker set.
///
/// Flow: find the jobs' fragments that are no-shuffle "roots" (not the target of any
/// no-shuffle edge), expand them into full no-shuffle DAGs, then load every fragment
/// and owning job those DAGs touch. Note that the expanded set may reach fragments
/// of jobs outside the input `job_ids`.
pub async fn load_fragment_context_for_jobs<C>(
    txn: &C,
    job_ids: HashSet<JobId>,
) -> MetaResult<LoadedFragmentContext>
where
    C: ConnectionTrait,
{
    if job_ids.is_empty() {
        return Ok(LoadedFragmentContext::default());
    }

    // Fragments that are the target of a no-shuffle edge are interior nodes of
    // some DAG; exclude them so only entry fragments seed the search.
    let excluded_fragments_query = FragmentRelation::find()
        .select_only()
        .column(fragment_relation::Column::TargetFragmentId)
        .filter(fragment_relation::Column::DispatcherType.eq(DispatcherType::NoShuffle))
        .into_query();

    let condition = Condition::all()
        .add(fragment::Column::JobId.is_in(job_ids.clone()))
        .add(fragment::Column::FragmentId.not_in_subquery(excluded_fragments_query));

    let fragments: Vec<FragmentId> = Fragment::find()
        .select_only()
        .column(fragment::Column::FragmentId)
        .filter(condition)
        .into_tuple()
        .all(txn)
        .await?;

    // Expand each entry fragment into its full no-shuffle ensemble (DAG).
    let ensembles = find_fragment_no_shuffle_dags_detailed(txn, &fragments).await?;

    // Re-fetch the full rows for every fragment touched by any ensemble.
    let fragments = Fragment::find()
        .filter(
            fragment::Column::FragmentId.is_in(
                ensembles
                    .iter()
                    .flat_map(|graph| graph.components.iter())
                    .cloned()
                    .collect_vec(),
            ),
        )
        .all(txn)
        .await?;

    let fragment_map: HashMap<_, _> = fragments
        .into_iter()
        .map(|fragment| (fragment.fragment_id, fragment))
        .collect();

    // Derive the (possibly larger) job set from the loaded fragments; BTreeSet
    // dedupes and gives a deterministic order for the query below.
    let job_ids = fragment_map
        .values()
        .map(|fragment| fragment.job_id)
        .collect::<BTreeSet<_>>()
        .into_iter()
        .collect_vec();

    let jobs: HashMap<_, _> = StreamingJob::find()
        .filter(streaming_job::Column::JobId.is_in(job_ids))
        .all(txn)
        .await?
        .into_iter()
        .map(|job| (job.job_id, job))
        .collect();

    build_loaded_context(txn, ensembles, fragment_map, jobs).await
}
587
588/// Sync render stage: uses loaded fragment context and current worker info
589/// to produce actor-to-worker assignments and vnode bitmaps.
590pub(crate) fn render_actor_assignments(
591    actor_id_counter: &AtomicU32,
592    worker_map: &HashMap<WorkerId, WorkerNode>,
593    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
594    loaded: &LoadedFragmentContext,
595) -> MetaResult<RenderedGraph> {
596    if loaded.is_empty() {
597        return Ok(RenderedGraph::empty());
598    }
599
600    let render_context = RenderActorsContext {
601        fragment_source_ids: &loaded.fragment_source_ids,
602        fragment_splits: &loaded.fragment_splits,
603        streaming_job_databases: &loaded.streaming_job_databases,
604        database_map: &loaded.database_map,
605    };
606
607    let fragments = render_actors(
608        actor_id_counter,
609        &loaded.ensembles,
610        &loaded.job_fragments,
611        &loaded.job_map,
612        worker_map,
613        adaptive_parallelism_strategy,
614        render_context,
615    )?;
616
617    Ok(RenderedGraph {
618        fragments,
619        ensembles: loaded.ensembles.clone(),
620    })
621}
622
/// Assembles a `LoadedFragmentContext` from pre-fetched fragment and job rows,
/// resolving the remaining metadata (source ids/splits, job→database mapping,
/// database catalog rows) from the transaction.
async fn build_loaded_context<C>(
    txn: &C,
    ensembles: Vec<NoShuffleEnsemble>,
    fragment_models: HashMap<FragmentId, fragment::Model>,
    job_map: HashMap<JobId, streaming_job::Model>,
) -> MetaResult<LoadedFragmentContext>
where
    C: ConnectionTrait,
{
    if ensembles.is_empty() {
        return Ok(LoadedFragmentContext::default());
    }

    // Group fragments by owning job; duplicate fragment ids are a bug.
    let mut job_fragments: HashMap<JobId, HashMap<FragmentId, LoadedFragment>> = HashMap::new();
    for (fragment_id, model) in fragment_models {
        job_fragments
            .entry(model.job_id)
            .or_default()
            .try_insert(fragment_id, LoadedFragment::from(model))
            .expect("duplicate fragment id for job");
    }

    #[cfg(debug_assertions)]
    {
        debug_sanity_check(&ensembles, &job_fragments, &job_map);
    }

    let (fragment_source_ids, fragment_splits) =
        resolve_source_fragments(txn, &job_fragments).await?;

    let job_ids = job_map.keys().copied().collect_vec();

    // Map each job to its database via the object catalog.
    let streaming_job_databases: HashMap<JobId, _> = StreamingJob::find()
        .select_only()
        .column(streaming_job::Column::JobId)
        .column(object::Column::DatabaseId)
        .join(JoinType::LeftJoin, streaming_job::Relation::Object.def())
        .filter(streaming_job::Column::JobId.is_in(job_ids))
        .into_tuple()
        .all(txn)
        .await?
        .into_iter()
        .collect();

    // Fetch the catalog rows of every involved database.
    let database_map: HashMap<_, _> = Database::find()
        .filter(
            database::Column::DatabaseId
                .is_in(streaming_job_databases.values().copied().collect_vec()),
        )
        .all(txn)
        .await?
        .into_iter()
        .map(|db| (db.database_id, db))
        .collect();

    Ok(LoadedFragmentContext {
        ensembles,
        job_fragments,
        job_map,
        streaming_job_databases,
        database_map,
        fragment_source_ids,
        fragment_splits,
    })
}
688
// Only metadata resolved asynchronously lives here so the renderer stays synchronous
// and the call site keeps the runtime dependencies (maps, strategy, actor counter, etc.) explicit.
struct RenderActorsContext<'a> {
    // Source id per source fragment.
    fragment_source_ids: &'a HashMap<FragmentId, SourceId>,
    // Currently-assigned connector splits per source fragment.
    fragment_splits: &'a HashMap<FragmentId, Vec<SplitImpl>>,
    // Owning database per streaming job.
    streaming_job_databases: &'a HashMap<JobId, DatabaseId>,
    // Database catalog rows, e.g. for database-level resource groups.
    database_map: &'a HashMap<DatabaseId, database::Model>,
}
697
698fn render_actors(
699    actor_id_counter: &AtomicU32,
700    ensembles: &[NoShuffleEnsemble],
701    job_fragments: &HashMap<JobId, HashMap<FragmentId, LoadedFragment>>,
702    job_map: &HashMap<JobId, streaming_job::Model>,
703    worker_map: &HashMap<WorkerId, WorkerNode>,
704    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
705    context: RenderActorsContext<'_>,
706) -> MetaResult<FragmentRenderMap> {
707    let RenderActorsContext {
708        fragment_source_ids,
709        fragment_splits: fragment_splits_map,
710        streaming_job_databases,
711        database_map,
712    } = context;
713
714    let mut all_fragments: FragmentRenderMap = HashMap::new();
715    let fragment_lookup: HashMap<FragmentId, &LoadedFragment> = job_fragments
716        .values()
717        .flat_map(|fragments| fragments.iter())
718        .map(|(fragment_id, fragment)| (*fragment_id, fragment))
719        .collect();
720
721    for NoShuffleEnsemble {
722        entries,
723        components,
724    } in ensembles
725    {
726        tracing::debug!("rendering ensemble entries {:?}", entries);
727
728        let entry_fragments = entries
729            .iter()
730            .map(|fragment_id| fragment_lookup.get(fragment_id).unwrap())
731            .collect_vec();
732
733        let entry_fragment_parallelism = entry_fragments
734            .iter()
735            .map(|fragment| fragment.parallelism.clone())
736            .dedup()
737            .exactly_one()
738            .map_err(|_| {
739                anyhow!(
740                    "entry fragments {:?} have inconsistent parallelism settings",
741                    entries.iter().copied().collect_vec()
742                )
743            })?;
744
745        let (job_id, vnode_count) = entry_fragments
746            .iter()
747            .map(|f| (f.job_id, f.vnode_count))
748            .dedup()
749            .exactly_one()
750            .map_err(|_| anyhow!("Multiple jobs found in no-shuffle ensemble"))?;
751
752        let job = job_map
753            .get(&job_id)
754            .ok_or_else(|| anyhow!("streaming job {job_id} not found"))?;
755
756        let database_resource_group = streaming_job_databases
757            .get(&job_id)
758            .and_then(|database_id| database_map.get(database_id))
759            .unwrap()
760            .resource_group
761            .clone();
762
763        let source_entry_fragment = entry_fragments.iter().find(|f| {
764            let mask = f.fragment_type_mask;
765            if mask.contains(FragmentTypeFlag::Source) {
766                assert!(!mask.contains(FragmentTypeFlag::SourceScan))
767            }
768            mask.contains(FragmentTypeFlag::Source) && !mask.contains(FragmentTypeFlag::Dml)
769        });
770
771        let actor_template = EnsembleActorTemplate::render_new(
772            job,
773            worker_map,
774            adaptive_parallelism_strategy,
775            entry_fragment_parallelism,
776            database_resource_group,
777            vnode_count,
778        )?;
779
780        let source_splits = match source_entry_fragment {
781            Some(entry_fragment) => {
782                let source_id = fragment_source_ids
783                    .get(&entry_fragment.fragment_id)
784                    .ok_or_else(|| {
785                        anyhow!(
786                            "missing source id in source fragment {}",
787                            entry_fragment.fragment_id
788                        )
789                    })?;
790
791                let entry_fragment_id = entry_fragment.fragment_id;
792
793                let splits = fragment_splits_map
794                    .get(&entry_fragment_id)
795                    .cloned()
796                    .unwrap_or_default();
797
798                let splits: std::collections::BTreeMap<_, _> =
799                    splits.into_iter().map(|s| (s.id(), s)).collect();
800                let splits = actor_template.assign_splits(entry_fragment_id, splits);
801                Some((splits, *source_id))
802            }
803            None => None,
804        };
805
806        for component_fragment_id in components {
807            let fragment = fragment_lookup.get(component_fragment_id).unwrap();
808            let fragment_id = fragment.fragment_id;
809            let job_id = fragment.job_id;
810            let fragment_type_mask = fragment.fragment_type_mask;
811            let distribution_type = fragment.distribution_type;
812            let stream_node = &fragment.nodes;
813            let state_table_ids = &fragment.state_table_ids;
814            let vnode_count = fragment.vnode_count;
815            let source_id = fragment_source_ids.get(&fragment_id).cloned();
816
817            let aligner = ComponentFragmentAligner::new(&actor_template, actor_id_counter);
818            let actors = aligner.align_component_actor(distribution_type);
819            let mut splits = source_id
820                .map(|source_id| {
821                    let (fragment_splits, shared_source_id) = source_splits.as_ref().unwrap();
822                    assert_eq!(*shared_source_id, source_id);
823                    aligner.align_component_splits(fragment_splits)
824                })
825                .unwrap_or_default();
826
827            let actors: HashMap<ActorId, InflightActorInfo> = actors
828                .into_iter()
829                .map(|(actor_id, (worker_id, vnode_bitmap))| {
830                    (
831                        actor_id,
832                        InflightActorInfo {
833                            worker_id,
834                            vnode_bitmap,
835                            splits: splits.remove(&actor_id).unwrap_or_default(),
836                        },
837                    )
838                })
839                .collect();
840
841            let fragment = InflightFragmentInfo {
842                fragment_id,
843                distribution_type,
844                fragment_type_mask,
845                vnode_count,
846                nodes: stream_node.clone(),
847                actors,
848                state_table_ids: state_table_ids.clone(),
849            };
850
851            let &database_id = streaming_job_databases.get(&job_id).ok_or_else(|| {
852                anyhow!("streaming job {job_id} not found in streaming_job_databases")
853            })?;
854
855            all_fragments
856                .entry(database_id)
857                .or_default()
858                .entry(job_id)
859                .or_default()
860                .insert(fragment_id, fragment);
861        }
862    }
863
864    Ok(all_fragments)
865}
866
/// Shared actor layout for one no-shuffle ensemble: every component fragment is
/// rendered against the same worker/actor/vnode assignment so no-shuffle edges
/// stay one-to-one.
pub(crate) struct EnsembleActorTemplate {
    // NOTE(review): inferred from the types — worker id → actor index → vnode
    // indices assigned to that actor; confirm against `render_new` and the
    // `ComponentFragmentAligner` usage.
    assignment: BTreeMap<WorkerId, BTreeMap<u32, Vec<usize>>>,
    // Total number of actors in the template.
    actor_count: u32,
    // Vnode count of the ensemble's entry fragments.
    vnode_count: usize,
}
872
873impl EnsembleActorTemplate {
874    pub(crate) fn render_new(
875        job: &streaming_job::Model,
876        worker_map: &HashMap<WorkerId, WorkerNode>,
877        adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
878        entry_fragment_parallelism: Option<StreamingParallelism>,
879        database_resource_group: String,
880        vnode_count: usize,
881    ) -> MetaResult<Self> {
882        let job_id = job.job_id;
883        let job_strategy = job
884            .stream_context()
885            .adaptive_parallelism_strategy
886            .unwrap_or(adaptive_parallelism_strategy);
887
888        let resource_group = match &job.specific_resource_group {
889            None => database_resource_group,
890            Some(resource_group) => resource_group.clone(),
891        };
892
893        let available_workers: BTreeMap<WorkerId, NonZeroUsize> = worker_map
894            .iter()
895            .filter_map(|(worker_id, worker)| {
896                if worker
897                    .resource_group()
898                    .as_deref()
899                    .unwrap_or(DEFAULT_RESOURCE_GROUP)
900                    == resource_group.as_str()
901                {
902                    Some((
903                        *worker_id,
904                        worker
905                            .parallelism()
906                            .expect("should have parallelism for compute node")
907                            .try_into()
908                            .expect("parallelism for compute node"),
909                    ))
910                } else {
911                    None
912                }
913            })
914            .collect();
915
916        let total_parallelism = available_workers.values().map(|w| w.get()).sum::<usize>();
917
918        let effective_job_parallelism = if job.job_status != JobStatus::Created {
919            job.backfill_parallelism
920                .as_ref()
921                .unwrap_or(&job.parallelism)
922        } else {
923            &job.parallelism
924        };
925
926        let actual_parallelism = match entry_fragment_parallelism
927            .as_ref()
928            .unwrap_or(effective_job_parallelism)
929        {
930            StreamingParallelism::Adaptive | StreamingParallelism::Custom => {
931                job_strategy.compute_target_parallelism(total_parallelism)
932            }
933            StreamingParallelism::Fixed(n) => *n,
934        }
935        .min(vnode_count)
936        .min(job.max_parallelism as usize);
937
938        tracing::debug!(
939            "job {}, final {} parallelism {:?} total_parallelism {} job_max {} vnode count {} fragment_override {:?}",
940            job_id,
941            actual_parallelism,
942            job.parallelism,
943            total_parallelism,
944            job.max_parallelism,
945            vnode_count,
946            entry_fragment_parallelism
947        );
948
949        let assigner = AssignerBuilder::new(job_id).build();
950
951        let actors = (0..(actual_parallelism as u32)).collect_vec();
952        let vnodes = (0..vnode_count).collect_vec();
953
954        let assignment = assigner.assign_hierarchical(&available_workers, &actors, &vnodes)?;
955
956        let actor_count = u32::try_from(actors.len()).expect("actor parallelism exceeds u32::MAX");
957
958        Ok(Self {
959            assignment,
960            actor_count,
961            vnode_count,
962        })
963    }
964
965    fn assign_splits(
966        &self,
967        entry_fragment_id: FragmentId,
968        splits: BTreeMap<SplitId, SplitImpl>,
969    ) -> HashMap<u32, Vec<SplitImpl>> {
970        {
971            {
972                let empty_actor_splits: HashMap<_, _> = self
973                    .assignment
974                    .values()
975                    .flat_map(|actors| actors.keys())
976                    .map(|actor_id| (*actor_id, vec![]))
977                    .collect();
978
979                crate::stream::source_manager::reassign_splits(
980                    entry_fragment_id,
981                    empty_actor_splits,
982                    &splits,
983                    SplitDiffOptions::default(),
984                )
985                .unwrap_or_default()
986            }
987        }
988    }
989}
990
/// Maps a shared [`EnsembleActorTemplate`] onto one concrete fragment by
/// reserving a dedicated, contiguous actor-id range for it.
pub(crate) struct ComponentFragmentAligner<'a> {
    actor_template: &'a EnsembleActorTemplate,
    // First actor id of the range reserved for this fragment; the template's
    // actor index `i` becomes actor id `actor_id_base + i`.
    actor_id_base: ActorId,
}
995
996impl<'a> ComponentFragmentAligner<'a> {
997    pub(crate) fn new(
998        actor_template: &'a EnsembleActorTemplate,
999        actor_id_counter: &AtomicU32,
1000    ) -> Self {
1001        let actor_id_base = actor_id_counter
1002            .fetch_add(actor_template.actor_count, Ordering::Relaxed)
1003            .into();
1004        Self {
1005            actor_template,
1006            actor_id_base,
1007        }
1008    }
1009
1010    pub(crate) fn align_component_actor(
1011        &self,
1012        distribution_type: DistributionType,
1013    ) -> HashMap<ActorId, (WorkerId, Option<Bitmap>)> {
1014        let EnsembleActorTemplate {
1015            assignment,
1016            vnode_count,
1017            actor_count,
1018        } = &self.actor_template;
1019        let actor_id_base = self.actor_id_base;
1020        let vnode_count = *vnode_count;
1021        {
1022            assignment
1023                .iter()
1024                .flat_map(|(worker_id, actors)| {
1025                    actors
1026                        .iter()
1027                        .map(move |(actor_id, vnodes)| (worker_id, actor_id, vnodes))
1028                })
1029                .map(|(&worker_id, &actor_idx, vnodes)| {
1030                    let vnode_bitmap = match distribution_type {
1031                        DistributionType::Single => {
1032                            assert_eq!(*actor_count, 1);
1033                            None
1034                        }
1035                        DistributionType::Hash => Some(Bitmap::from_indices(vnode_count, vnodes)),
1036                    };
1037
1038                    let actor_id = actor_id_base + actor_idx;
1039
1040                    (actor_id, (worker_id, vnode_bitmap))
1041                })
1042                .collect()
1043        }
1044    }
1045
1046    pub(crate) fn align_component_splits(
1047        &self,
1048        split_assignment: &HashMap<u32, Vec<SplitImpl>>,
1049    ) -> HashMap<ActorId, Vec<SplitImpl>> {
1050        (0..self.actor_template.actor_count)
1051            .filter_map(|actor_idx| {
1052                split_assignment
1053                    .get(&actor_idx)
1054                    .map(|splits| ((self.actor_id_base + actor_idx), splits.clone()))
1055            })
1056            .collect()
1057    }
1058}
1059
1060#[cfg(debug_assertions)]
1061fn debug_sanity_check(
1062    ensembles: &[NoShuffleEnsemble],
1063    job_fragments: &HashMap<JobId, HashMap<FragmentId, LoadedFragment>>,
1064    jobs: &HashMap<JobId, streaming_job::Model>,
1065) {
1066    let fragment_lookup: HashMap<FragmentId, (&LoadedFragment, JobId)> = job_fragments
1067        .iter()
1068        .flat_map(|(job_id, fragments)| {
1069            fragments
1070                .iter()
1071                .map(move |(fragment_id, fragment)| (*fragment_id, (fragment, *job_id)))
1072        })
1073        .collect();
1074
1075    // Debug-only assertions to catch inconsistent ensemble metadata early.
1076    debug_assert!(
1077        ensembles
1078            .iter()
1079            .all(|ensemble| ensemble.entries.is_subset(&ensemble.components)),
1080        "entries must be subset of components"
1081    );
1082
1083    let mut missing_fragments = BTreeSet::new();
1084    let mut missing_jobs = BTreeSet::new();
1085
1086    for fragment_id in ensembles
1087        .iter()
1088        .flat_map(|ensemble| ensemble.components.iter())
1089    {
1090        match fragment_lookup.get(fragment_id) {
1091            Some((fragment, job_id)) => {
1092                if !jobs.contains_key(&fragment.job_id) {
1093                    missing_jobs.insert(*job_id);
1094                }
1095            }
1096            None => {
1097                missing_fragments.insert(*fragment_id);
1098            }
1099        }
1100    }
1101
1102    debug_assert!(
1103        missing_fragments.is_empty(),
1104        "missing fragments in fragment_map: {:?}",
1105        missing_fragments
1106    );
1107
1108    debug_assert!(
1109        missing_jobs.is_empty(),
1110        "missing jobs for fragments' job_id: {:?}",
1111        missing_jobs
1112    );
1113
1114    for ensemble in ensembles {
1115        let unique_vnode_counts: Vec<_> = ensemble
1116            .components
1117            .iter()
1118            .flat_map(|fragment_id| {
1119                fragment_lookup
1120                    .get(fragment_id)
1121                    .map(|(fragment, _)| fragment.vnode_count)
1122            })
1123            .unique()
1124            .collect();
1125
1126        debug_assert!(
1127            unique_vnode_counts.len() <= 1,
1128            "components in ensemble must share same vnode_count: ensemble={:?}, vnode_counts={:?}",
1129            ensemble.components,
1130            unique_vnode_counts
1131        );
1132    }
1133}
1134
/// Collects, for every fragment that reads from a source (either a stream
/// `Source` executor or a `SourceScan` backfill executor), the id of that
/// source and the splits currently persisted for the fragment.
///
/// Returns `(fragment_id -> source_id, fragment_id -> splits)`; fragments
/// without a persisted split row are absent from the second map.
async fn resolve_source_fragments<C>(
    txn: &C,
    job_fragments: &HashMap<JobId, HashMap<FragmentId, LoadedFragment>>,
) -> MetaResult<(
    HashMap<FragmentId, SourceId>,
    HashMap<FragmentId, Vec<SplitImpl>>,
)>
where
    C: ConnectionTrait,
{
    // Group source-reading fragments by the source they consume.
    let mut source_fragment_ids: HashMap<SourceId, BTreeSet<FragmentId>> = HashMap::new();
    for (fragment_id, fragment) in job_fragments.values().flatten() {
        let mask = fragment.fragment_type_mask;
        // Fragments hosting a stream source executor.
        if mask.contains(FragmentTypeFlag::Source)
            && let Some(source_id) = fragment.nodes.find_stream_source()
        {
            source_fragment_ids
                .entry(source_id)
                .or_default()
                .insert(*fragment_id);
        }

        // Fragments hosting a source backfill executor.
        if mask.contains(FragmentTypeFlag::SourceScan)
            && let Some((source_id, _)) = fragment.nodes.find_source_backfill()
        {
            source_fragment_ids
                .entry(source_id)
                .or_default()
                .insert(*fragment_id);
        }
    }

    // Invert the grouping into a fragment -> source lookup.
    let fragment_source_ids: HashMap<_, _> = source_fragment_ids
        .iter()
        .flat_map(|(source_id, fragment_ids)| {
            fragment_ids
                .iter()
                .map(|fragment_id| (*fragment_id, *source_id))
        })
        .collect();

    let fragment_ids = fragment_source_ids.keys().copied().collect_vec();

    // Load the persisted split rows for all source fragments in one query.
    let fragment_splits: Vec<_> = FragmentSplits::find()
        .filter(fragment_splits::Column::FragmentId.is_in(fragment_ids))
        .all(txn)
        .await?;

    // Decode the protobuf split payloads; rows whose `splits` column is NULL
    // are skipped, and individual splits that fail to decode are dropped.
    let fragment_splits: HashMap<_, _> = fragment_splits
        .into_iter()
        .flat_map(|model| {
            model.splits.map(|splits| {
                (
                    model.fragment_id,
                    splits
                        .to_protobuf()
                        .splits
                        .iter()
                        .flat_map(SplitImpl::try_from)
                        .collect_vec(),
                )
            })
        })
        .collect();

    Ok((fragment_source_ids, fragment_splits))
}
1202
// Helper struct to make the function signature cleaner and to properly bundle the required data.
#[derive(Debug)]
pub struct ActorGraph<'a> {
    /// Per-fragment data: the fragment plus its rendered stream actors.
    pub fragments: &'a HashMap<FragmentId, (Fragment, Vec<StreamActor>)>,
    /// Worker placement for every actor appearing in `fragments`.
    pub locations: &'a HashMap<ActorId, WorkerId>,
}
1209
/// A connected component of the fragment graph linked purely by no-shuffle
/// edges, discovered by [`find_no_shuffle_graphs`].
#[derive(Debug, Clone)]
pub struct NoShuffleEnsemble {
    // Root fragments: members with no upstream inside this component.
    entries: HashSet<FragmentId>,
    // All fragments in the component; always a superset of `entries`.
    components: HashSet<FragmentId>,
}
1215
1216impl NoShuffleEnsemble {
1217    #[cfg(test)]
1218    pub fn for_test(
1219        entries: impl IntoIterator<Item = FragmentId>,
1220        components: impl IntoIterator<Item = FragmentId>,
1221    ) -> Self {
1222        let entries = entries.into_iter().collect();
1223        let components = components.into_iter().collect();
1224        Self {
1225            entries,
1226            components,
1227        }
1228    }
1229
1230    pub fn fragments(&self) -> impl Iterator<Item = FragmentId> + '_ {
1231        self.components.iter().cloned()
1232    }
1233
1234    pub fn entry_fragments(&self) -> impl Iterator<Item = FragmentId> + '_ {
1235        self.entries.iter().copied()
1236    }
1237
1238    pub fn component_fragments(&self) -> impl Iterator<Item = FragmentId> + '_ {
1239        self.components.iter().copied()
1240    }
1241
1242    pub fn contains_entry(&self, fragment_id: &FragmentId) -> bool {
1243        self.entries.contains(fragment_id)
1244    }
1245}
1246
/// Loads every `NoShuffle` edge of the fragment graph from the database and
/// returns the no-shuffle connected components ("ensembles") reachable from
/// `initial_fragment_ids`.
///
/// Note: ALL no-shuffle relations are fetched, not only those touching the
/// initial fragments; the traversal itself happens in memory.
pub async fn find_fragment_no_shuffle_dags_detailed(
    db: &impl ConnectionTrait,
    initial_fragment_ids: &[FragmentId],
) -> MetaResult<Vec<NoShuffleEnsemble>> {
    // Fetch (source, target) fragment-id pairs of all NO_SHUFFLE dispatchers.
    let all_no_shuffle_relations: Vec<(_, _)> = FragmentRelation::find()
        .columns([
            fragment_relation::Column::SourceFragmentId,
            fragment_relation::Column::TargetFragmentId,
        ])
        .filter(fragment_relation::Column::DispatcherType.eq(DispatcherType::NoShuffle))
        .into_tuple()
        .all(db)
        .await?;

    let (forward_edges, backward_edges) =
        build_no_shuffle_fragment_graph_edges(all_no_shuffle_relations);

    find_no_shuffle_graphs(initial_fragment_ids, &forward_edges, &backward_edges)
}
1266
1267pub(crate) fn build_no_shuffle_fragment_graph_edges(
1268    relations: impl IntoIterator<Item = (FragmentId, FragmentId)>,
1269) -> (
1270    HashMap<FragmentId, Vec<FragmentId>>,
1271    HashMap<FragmentId, Vec<FragmentId>>,
1272) {
1273    let mut forward_edges: HashMap<FragmentId, HashSet<FragmentId>> = HashMap::new();
1274    let mut backward_edges: HashMap<FragmentId, HashSet<FragmentId>> = HashMap::new();
1275
1276    for (src, dst) in relations {
1277        forward_edges.entry(src).or_default().insert(dst);
1278        backward_edges.entry(dst).or_default().insert(src);
1279    }
1280
1281    let forward_edges = forward_edges
1282        .into_iter()
1283        .map(|(src, dst_set)| (src, dst_set.into_iter().collect()))
1284        .collect();
1285    let backward_edges = backward_edges
1286        .into_iter()
1287        .map(|(dst, src_set)| (dst, src_set.into_iter().collect()))
1288        .collect();
1289
1290    (forward_edges, backward_edges)
1291}
1292
1293pub(crate) fn find_no_shuffle_graphs(
1294    initial_fragment_ids: &[impl Into<FragmentId> + Copy],
1295    forward_edges: &HashMap<FragmentId, Vec<FragmentId>>,
1296    backward_edges: &HashMap<FragmentId, Vec<FragmentId>>,
1297) -> MetaResult<Vec<NoShuffleEnsemble>> {
1298    let mut graphs: Vec<NoShuffleEnsemble> = Vec::new();
1299    let mut globally_visited: HashSet<FragmentId> = HashSet::new();
1300
1301    for &init_id in initial_fragment_ids {
1302        let init_id = init_id.into();
1303        if globally_visited.contains(&init_id) {
1304            continue;
1305        }
1306
1307        // Found a new component. Traverse it to find all its nodes.
1308        let mut components = HashSet::new();
1309        let mut queue: VecDeque<FragmentId> = VecDeque::new();
1310
1311        queue.push_back(init_id);
1312        globally_visited.insert(init_id);
1313
1314        while let Some(current_id) = queue.pop_front() {
1315            components.insert(current_id);
1316            let neighbors = forward_edges
1317                .get(&current_id)
1318                .into_iter()
1319                .flatten()
1320                .chain(backward_edges.get(&current_id).into_iter().flatten());
1321
1322            for &neighbor_id in neighbors {
1323                if globally_visited.insert(neighbor_id) {
1324                    queue.push_back(neighbor_id);
1325                }
1326            }
1327        }
1328
1329        // For the newly found component, identify its roots.
1330        let mut entries = HashSet::new();
1331        for &node_id in &components {
1332            let is_root = match backward_edges.get(&node_id) {
1333                Some(parents) => parents.iter().all(|p| !components.contains(p)),
1334                None => true,
1335            };
1336            if is_root {
1337                entries.insert(node_id);
1338            }
1339        }
1340
1341        // Store the detailed DAG structure (roots, all nodes in this DAG).
1342        if !entries.is_empty() {
1343            graphs.push(NoShuffleEnsemble {
1344                entries,
1345                components,
1346            });
1347        }
1348    }
1349
1350    Ok(graphs)
1351}
1352
1353#[cfg(test)]
1354mod tests {
1355    use std::collections::{BTreeSet, HashMap, HashSet};
1356    use std::sync::Arc;
1357
1358    use risingwave_connector::source::SplitImpl;
1359    use risingwave_connector::source::test_source::TestSourceSplit;
1360    use risingwave_meta_model::{CreateType, JobStatus};
1361    use risingwave_pb::common::WorkerType;
1362    use risingwave_pb::common::worker_node::Property as WorkerProperty;
1363    use risingwave_pb::stream_plan::StreamNode as PbStreamNode;
1364
1365    use super::*;
1366
    // Helper type aliases for cleaner test code
    // Using the actual FragmentId type from the module
    /// `(forward_edges, backward_edges)` adjacency maps, in the shape
    /// consumed by `find_no_shuffle_graphs`.
    type Edges = (
        HashMap<FragmentId, Vec<FragmentId>>,
        HashMap<FragmentId, Vec<FragmentId>>,
    );
1373
1374    /// A helper function to build forward and backward edge maps from a simple list of tuples.
1375    /// This reduces boilerplate in each test.
1376    fn build_edges(relations: &[(u32, u32)]) -> Edges {
1377        let mut forward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
1378        let mut backward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
1379        for &(src, dst) in relations {
1380            forward_edges
1381                .entry(src.into())
1382                .or_default()
1383                .push(dst.into());
1384            backward_edges
1385                .entry(dst.into())
1386                .or_default()
1387                .push(src.into());
1388        }
1389        (forward_edges, backward_edges)
1390    }
1391
1392    /// Helper function to create a `HashSet` from a slice easily.
1393    fn to_hashset(ids: &[u32]) -> HashSet<FragmentId> {
1394        ids.iter().map(|id| (*id).into()).collect()
1395    }
1396
    /// Builds a `LoadedFragment` with the given shape, trimming per-test
    /// fixture boilerplate.
    fn build_fragment(
        fragment_id: FragmentId,
        job_id: JobId,
        fragment_type_mask: i32,
        distribution_type: DistributionType,
        vnode_count: i32,
        parallelism: StreamingParallelism,
    ) -> LoadedFragment {
        LoadedFragment {
            fragment_id,
            job_id,
            fragment_type_mask: FragmentTypeMask::from(fragment_type_mask),
            distribution_type,
            vnode_count: vnode_count as usize,
            // Tests here never inspect the plan, so an empty node suffices.
            nodes: PbStreamNode::default(),
            state_table_ids: HashSet::new(),
            parallelism: Some(parallelism),
        }
    }
1416
1417    type ActorState = (ActorId, WorkerId, Option<Vec<usize>>, Vec<String>);
1418
1419    fn collect_actor_state(fragment: &InflightFragmentInfo) -> Vec<ActorState> {
1420        let base = fragment.actors.keys().copied().min().unwrap_or_default();
1421
1422        let mut entries: Vec<_> = fragment
1423            .actors
1424            .iter()
1425            .map(|(&actor_id, info)| {
1426                let idx = actor_id.as_raw_id() - base.as_raw_id();
1427                let vnode_indices = info.vnode_bitmap.as_ref().map(|bitmap| {
1428                    bitmap
1429                        .iter()
1430                        .enumerate()
1431                        .filter_map(|(pos, is_set)| is_set.then_some(pos))
1432                        .collect::<Vec<_>>()
1433                });
1434                let splits = info
1435                    .splits
1436                    .iter()
1437                    .map(|split| split.id().to_string())
1438                    .collect::<Vec<_>>();
1439                (idx.into(), info.worker_id, vnode_indices, splits)
1440            })
1441            .collect();
1442
1443        entries.sort_by_key(|(idx, _, _, _)| *idx);
1444        entries
1445    }
1446
    /// Builds a streaming compute-node `WorkerNode` with the given
    /// parallelism and resource group for placement tests.
    fn build_worker_node(
        id: impl Into<WorkerId>,
        parallelism: usize,
        resource_group: &str,
    ) -> WorkerNode {
        WorkerNode {
            id: id.into(),
            r#type: WorkerType::ComputeNode as i32,
            property: Some(WorkerProperty {
                is_streaming: true,
                parallelism: u32::try_from(parallelism).expect("parallelism fits into u32"),
                resource_group: Some(resource_group.to_owned()),
                ..Default::default()
            }),
            ..Default::default()
        }
    }
1464
1465    #[test]
1466    fn test_single_linear_chain() {
1467        // Scenario: A simple linear graph 1 -> 2 -> 3.
1468        // We start from the middle node (2).
1469        let (forward, backward) = build_edges(&[(1, 2), (2, 3)]);
1470        let initial_ids = &[2];
1471
1472        // Act
1473        let result = find_no_shuffle_graphs(initial_ids, &forward, &backward);
1474
1475        // Assert
1476        assert!(result.is_ok());
1477        let graphs = result.unwrap();
1478
1479        assert_eq!(graphs.len(), 1);
1480        let graph = &graphs[0];
1481        assert_eq!(graph.entries, to_hashset(&[1]));
1482        assert_eq!(graph.components, to_hashset(&[1, 2, 3]));
1483    }
1484
1485    #[test]
1486    fn test_two_disconnected_graphs() {
1487        // Scenario: Two separate graphs: 1->2 and 10->11.
1488        // We start with one node from each graph.
1489        let (forward, backward) = build_edges(&[(1, 2), (10, 11)]);
1490        let initial_ids = &[2, 10];
1491
1492        // Act
1493        let mut graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1494
1495        // Assert
1496        assert_eq!(graphs.len(), 2);
1497
1498        // Sort results to make the test deterministic, as HashMap iteration order is not guaranteed.
1499        graphs.sort_by_key(|g| *g.components.iter().min().unwrap_or(&0.into()));
1500
1501        // Graph 1
1502        assert_eq!(graphs[0].entries, to_hashset(&[1]));
1503        assert_eq!(graphs[0].components, to_hashset(&[1, 2]));
1504
1505        // Graph 2
1506        assert_eq!(graphs[1].entries, to_hashset(&[10]));
1507        assert_eq!(graphs[1].components, to_hashset(&[10, 11]));
1508    }
1509
1510    #[test]
1511    fn test_multiple_entries_in_one_graph() {
1512        // Scenario: A graph with two roots feeding into one node: 1->3, 2->3.
1513        let (forward, backward) = build_edges(&[(1, 3), (2, 3)]);
1514        let initial_ids = &[3];
1515
1516        // Act
1517        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1518
1519        // Assert
1520        assert_eq!(graphs.len(), 1);
1521        let graph = &graphs[0];
1522        assert_eq!(graph.entries, to_hashset(&[1, 2]));
1523        assert_eq!(graph.components, to_hashset(&[1, 2, 3]));
1524    }
1525
1526    #[test]
1527    fn test_diamond_shape_graph() {
1528        // Scenario: A diamond shape: 1->2, 1->3, 2->4, 3->4
1529        let (forward, backward) = build_edges(&[(1, 2), (1, 3), (2, 4), (3, 4)]);
1530        let initial_ids = &[4];
1531
1532        // Act
1533        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1534
1535        // Assert
1536        assert_eq!(graphs.len(), 1);
1537        let graph = &graphs[0];
1538        assert_eq!(graph.entries, to_hashset(&[1]));
1539        assert_eq!(graph.components, to_hashset(&[1, 2, 3, 4]));
1540    }
1541
1542    #[test]
1543    fn test_starting_with_multiple_nodes_in_same_graph() {
1544        // Scenario: Start with two different nodes (2 and 4) from the same component.
1545        // Should only identify one graph, not two.
1546        let (forward, backward) = build_edges(&[(1, 2), (2, 3), (3, 4)]);
1547        let initial_ids = &[2, 4];
1548
1549        // Act
1550        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1551
1552        // Assert
1553        assert_eq!(graphs.len(), 1);
1554        let graph = &graphs[0];
1555        assert_eq!(graph.entries, to_hashset(&[1]));
1556        assert_eq!(graph.components, to_hashset(&[1, 2, 3, 4]));
1557    }
1558
1559    #[test]
1560    fn test_empty_initial_ids() {
1561        // Scenario: The initial ID list is empty.
1562        let (forward, backward) = build_edges(&[(1, 2)]);
1563        let initial_ids: &[u32] = &[];
1564
1565        // Act
1566        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1567
1568        // Assert
1569        assert!(graphs.is_empty());
1570    }
1571
1572    #[test]
1573    fn test_isolated_node_as_input() {
1574        // Scenario: Start with an ID that has no relations.
1575        let (forward, backward) = build_edges(&[(1, 2)]);
1576        let initial_ids = &[100];
1577
1578        // Act
1579        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1580
1581        // Assert
1582        assert_eq!(graphs.len(), 1);
1583        let graph = &graphs[0];
1584        assert_eq!(graph.entries, to_hashset(&[100]));
1585        assert_eq!(graph.components, to_hashset(&[100]));
1586    }
1587
1588    #[test]
1589    fn test_graph_with_a_cycle() {
1590        // Scenario: A graph with a cycle: 1 -> 2 -> 3 -> 1.
1591        // The algorithm should correctly identify all nodes in the component.
1592        // Crucially, NO node is a root because every node has a parent *within the component*.
1593        // Therefore, the `entries` set should be empty, and the graph should not be included in the results.
1594        let (forward, backward) = build_edges(&[(1, 2), (2, 3), (3, 1)]);
1595        let initial_ids = &[2];
1596
1597        // Act
1598        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1599
1600        // Assert
1601        assert!(
1602            graphs.is_empty(),
1603            "A graph with no entries should not be returned"
1604        );
1605    }
1606    #[test]
1607    fn test_custom_complex() {
1608        let (forward, backward) = build_edges(&[(1, 3), (1, 8), (2, 3), (4, 3), (3, 5), (6, 7)]);
1609        let initial_ids = &[1, 2, 4, 6];
1610
1611        // Act
1612        let mut graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1613
1614        // Assert
1615        assert_eq!(graphs.len(), 2);
1616        // Sort results to make the test deterministic, as HashMap iteration order is not guaranteed.
1617        graphs.sort_by_key(|g| *g.components.iter().min().unwrap_or(&0.into()));
1618
1619        // Graph 1
1620        assert_eq!(graphs[0].entries, to_hashset(&[1, 2, 4]));
1621        assert_eq!(graphs[0].components, to_hashset(&[1, 2, 3, 4, 5, 8]));
1622
1623        // Graph 2
1624        assert_eq!(graphs[1].entries, to_hashset(&[6]));
1625        assert_eq!(graphs[1].components, to_hashset(&[6, 7]));
1626    }
1627
    #[test]
    fn render_actors_increments_actor_counter() {
        // Counter starts at 100; rendering a single-actor fragment should
        // consume exactly one id, leaving the counter at 101.
        let actor_id_counter = AtomicU32::new(100);
        let fragment_id: FragmentId = 1.into();
        let job_id: JobId = 10.into();
        let database_id: DatabaseId = DatabaseId::new(3);

        // One single-distribution fragment with a single vnode.
        let fragment_model = build_fragment(
            fragment_id,
            job_id,
            0,
            DistributionType::Single,
            1,
            StreamingParallelism::Fixed(1),
        );

        let job_model = streaming_job::Model {
            job_id,
            job_status: JobStatus::Created,
            create_type: CreateType::Foreground,
            timezone: None,
            config_override: None,
            adaptive_parallelism_strategy: None,
            parallelism: StreamingParallelism::Fixed(1),
            backfill_parallelism: None,
            backfill_orders: None,
            max_parallelism: 1,
            specific_resource_group: None,
            is_serverless_backfill: false,
        };

        let database_model = database::Model {
            database_id,
            name: "test_db".into(),
            resource_group: "rg-a".into(),
            barrier_interval_ms: None,
            checkpoint_frequency: None,
        };

        // Trivial ensemble: the fragment is both the entry and the only component.
        let ensembles = vec![NoShuffleEnsemble {
            entries: HashSet::from([fragment_id]),
            components: HashSet::from([fragment_id]),
        }];

        let fragment_map = HashMap::from([(fragment_id, fragment_model)]);
        let job_fragments = HashMap::from([(job_id, fragment_map)]);
        let job_map = HashMap::from([(job_id, job_model)]);

        // A single worker in the same resource group as the database.
        let worker_map: HashMap<WorkerId, WorkerNode> =
            HashMap::from([(1.into(), build_worker_node(1, 1, "rg-a"))]);

        // No sources involved, so all source-related inputs stay empty.
        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
        let database_map = HashMap::from([(database_id, database_model)]);

        let context = RenderActorsContext {
            fragment_source_ids: &fragment_source_ids,
            fragment_splits: &fragment_splits,
            streaming_job_databases: &streaming_job_databases,
            database_map: &database_map,
        };

        let result = render_actors(
            &actor_id_counter,
            &ensembles,
            &job_fragments,
            &job_map,
            &worker_map,
            AdaptiveParallelismStrategy::Auto,
            context,
        )
        .expect("actor rendering succeeds");

        let state = collect_actor_state(&result[&database_id][&job_id][&fragment_id]);
        assert_eq!(state.len(), 1);
        assert!(
            state[0].2.is_none(),
            "single distribution should not assign vnode bitmaps"
        );
        // One actor rendered -> counter advanced from 100 to 101.
        assert_eq!(actor_id_counter.load(Ordering::Relaxed), 101);
    }
1710
1711    #[test]
1712    fn render_actors_aligns_hash_vnode_bitmaps() {
1713        let actor_id_counter = AtomicU32::new(0);
1714        let entry_fragment_id: FragmentId = 1.into();
1715        let downstream_fragment_id: FragmentId = 2.into();
1716        let job_id: JobId = 20.into();
1717        let database_id: DatabaseId = DatabaseId::new(5);
1718
1719        let entry_fragment = build_fragment(
1720            entry_fragment_id,
1721            job_id,
1722            0,
1723            DistributionType::Hash,
1724            4,
1725            StreamingParallelism::Fixed(2),
1726        );
1727
1728        let downstream_fragment = build_fragment(
1729            downstream_fragment_id,
1730            job_id,
1731            0,
1732            DistributionType::Hash,
1733            4,
1734            StreamingParallelism::Fixed(2),
1735        );
1736
1737        let job_model = streaming_job::Model {
1738            job_id,
1739            job_status: JobStatus::Created,
1740            create_type: CreateType::Background,
1741            timezone: None,
1742            config_override: None,
1743            adaptive_parallelism_strategy: None,
1744            parallelism: StreamingParallelism::Fixed(2),
1745            backfill_parallelism: None,
1746            backfill_orders: None,
1747            max_parallelism: 2,
1748            specific_resource_group: None,
1749            is_serverless_backfill: false,
1750        };
1751
1752        let database_model = database::Model {
1753            database_id,
1754            name: "test_db_hash".into(),
1755            resource_group: "rg-hash".into(),
1756            barrier_interval_ms: None,
1757            checkpoint_frequency: None,
1758        };
1759
1760        let ensembles = vec![NoShuffleEnsemble {
1761            entries: HashSet::from([entry_fragment_id]),
1762            components: HashSet::from([entry_fragment_id, downstream_fragment_id]),
1763        }];
1764
1765        let fragment_map = HashMap::from([
1766            (entry_fragment_id, entry_fragment),
1767            (downstream_fragment_id, downstream_fragment),
1768        ]);
1769        let job_fragments = HashMap::from([(job_id, fragment_map)]);
1770        let job_map = HashMap::from([(job_id, job_model)]);
1771
1772        let worker_map: HashMap<WorkerId, WorkerNode> = HashMap::from([
1773            (1.into(), build_worker_node(1, 1, "rg-hash")),
1774            (2.into(), build_worker_node(2, 1, "rg-hash")),
1775        ]);
1776
1777        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
1778        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
1779        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
1780        let database_map = HashMap::from([(database_id, database_model)]);
1781
1782        let context = RenderActorsContext {
1783            fragment_source_ids: &fragment_source_ids,
1784            fragment_splits: &fragment_splits,
1785            streaming_job_databases: &streaming_job_databases,
1786            database_map: &database_map,
1787        };
1788
1789        let result = render_actors(
1790            &actor_id_counter,
1791            &ensembles,
1792            &job_fragments,
1793            &job_map,
1794            &worker_map,
1795            AdaptiveParallelismStrategy::Auto,
1796            context,
1797        )
1798        .expect("actor rendering succeeds");
1799
1800        let entry_state = collect_actor_state(&result[&database_id][&job_id][&entry_fragment_id]);
1801        let downstream_state =
1802            collect_actor_state(&result[&database_id][&job_id][&downstream_fragment_id]);
1803
1804        assert_eq!(entry_state.len(), 2);
1805        assert_eq!(entry_state, downstream_state);
1806
1807        let assigned_vnodes: BTreeSet<_> = entry_state
1808            .iter()
1809            .flat_map(|(_, _, vnodes, _)| {
1810                vnodes
1811                    .as_ref()
1812                    .expect("hash distribution should populate vnode bitmap")
1813                    .iter()
1814                    .copied()
1815            })
1816            .collect();
1817        assert_eq!(assigned_vnodes, BTreeSet::from([0, 1, 2, 3]));
1818        assert_eq!(actor_id_counter.load(Ordering::Relaxed), 4);
1819    }
1820
1821    #[test]
1822    fn render_actors_propagates_source_splits() {
1823        let actor_id_counter = AtomicU32::new(0);
1824        let entry_fragment_id: FragmentId = 11.into();
1825        let downstream_fragment_id: FragmentId = 12.into();
1826        let job_id: JobId = 30.into();
1827        let database_id: DatabaseId = DatabaseId::new(7);
1828        let source_id: SourceId = 99.into();
1829
1830        let source_mask = FragmentTypeFlag::raw_flag([FragmentTypeFlag::Source]) as i32;
1831        let source_scan_mask = FragmentTypeFlag::raw_flag([FragmentTypeFlag::SourceScan]) as i32;
1832
1833        let entry_fragment = build_fragment(
1834            entry_fragment_id,
1835            job_id,
1836            source_mask,
1837            DistributionType::Hash,
1838            4,
1839            StreamingParallelism::Fixed(2),
1840        );
1841
1842        let downstream_fragment = build_fragment(
1843            downstream_fragment_id,
1844            job_id,
1845            source_scan_mask,
1846            DistributionType::Hash,
1847            4,
1848            StreamingParallelism::Fixed(2),
1849        );
1850
1851        let job_model = streaming_job::Model {
1852            job_id,
1853            job_status: JobStatus::Created,
1854            create_type: CreateType::Background,
1855            timezone: None,
1856            config_override: None,
1857            adaptive_parallelism_strategy: None,
1858            parallelism: StreamingParallelism::Fixed(2),
1859            backfill_parallelism: None,
1860            backfill_orders: None,
1861            max_parallelism: 2,
1862            specific_resource_group: None,
1863            is_serverless_backfill: false,
1864        };
1865
1866        let database_model = database::Model {
1867            database_id,
1868            name: "split_db".into(),
1869            resource_group: "rg-source".into(),
1870            barrier_interval_ms: None,
1871            checkpoint_frequency: None,
1872        };
1873
1874        let ensembles = vec![NoShuffleEnsemble {
1875            entries: HashSet::from([entry_fragment_id]),
1876            components: HashSet::from([entry_fragment_id, downstream_fragment_id]),
1877        }];
1878
1879        let fragment_map = HashMap::from([
1880            (entry_fragment_id, entry_fragment),
1881            (downstream_fragment_id, downstream_fragment),
1882        ]);
1883        let job_fragments = HashMap::from([(job_id, fragment_map)]);
1884        let job_map = HashMap::from([(job_id, job_model)]);
1885
1886        let worker_map: HashMap<WorkerId, WorkerNode> = HashMap::from([
1887            (1.into(), build_worker_node(1, 1, "rg-source")),
1888            (2.into(), build_worker_node(2, 1, "rg-source")),
1889        ]);
1890
1891        let split_a = SplitImpl::Test(TestSourceSplit {
1892            id: Arc::<str>::from("split-a"),
1893            properties: HashMap::new(),
1894            offset: "0".into(),
1895        });
1896        let split_b = SplitImpl::Test(TestSourceSplit {
1897            id: Arc::<str>::from("split-b"),
1898            properties: HashMap::new(),
1899            offset: "0".into(),
1900        });
1901
1902        let fragment_source_ids = HashMap::from([
1903            (entry_fragment_id, source_id),
1904            (downstream_fragment_id, source_id),
1905        ]);
1906        let fragment_splits =
1907            HashMap::from([(entry_fragment_id, vec![split_a.clone(), split_b.clone()])]);
1908        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
1909        let database_map = HashMap::from([(database_id, database_model)]);
1910
1911        let context = RenderActorsContext {
1912            fragment_source_ids: &fragment_source_ids,
1913            fragment_splits: &fragment_splits,
1914            streaming_job_databases: &streaming_job_databases,
1915            database_map: &database_map,
1916        };
1917
1918        let result = render_actors(
1919            &actor_id_counter,
1920            &ensembles,
1921            &job_fragments,
1922            &job_map,
1923            &worker_map,
1924            AdaptiveParallelismStrategy::Auto,
1925            context,
1926        )
1927        .expect("actor rendering succeeds");
1928
1929        let entry_state = collect_actor_state(&result[&database_id][&job_id][&entry_fragment_id]);
1930        let downstream_state =
1931            collect_actor_state(&result[&database_id][&job_id][&downstream_fragment_id]);
1932
1933        assert_eq!(entry_state, downstream_state);
1934
1935        let split_ids: BTreeSet<_> = entry_state
1936            .iter()
1937            .flat_map(|(_, _, _, splits)| splits.iter().cloned())
1938            .collect();
1939        assert_eq!(
1940            split_ids,
1941            BTreeSet::from([split_a.id().to_string(), split_b.id().to_string()])
1942        );
1943        assert_eq!(actor_id_counter.load(Ordering::Relaxed), 4);
1944    }
1945
1946    /// Test that job-level strategy overrides global strategy for Adaptive parallelism.
1947    #[test]
1948    fn render_actors_job_strategy_overrides_global() {
1949        let actor_id_counter = AtomicU32::new(0);
1950        let fragment_id: FragmentId = 1.into();
1951        let job_id: JobId = 100.into();
1952        let database_id: DatabaseId = DatabaseId::new(10);
1953
1954        // Fragment with Adaptive parallelism, vnode_count = 8
1955        let fragment_model = build_fragment(
1956            fragment_id,
1957            job_id,
1958            0,
1959            DistributionType::Hash,
1960            8,
1961            StreamingParallelism::Adaptive,
1962        );
1963
1964        // Job has custom strategy: BOUNDED(2)
1965        let job_model = streaming_job::Model {
1966            job_id,
1967            job_status: JobStatus::Created,
1968            create_type: CreateType::Foreground,
1969            timezone: None,
1970            config_override: None,
1971            adaptive_parallelism_strategy: Some("BOUNDED(2)".to_owned()),
1972            parallelism: StreamingParallelism::Adaptive,
1973            backfill_parallelism: None,
1974            backfill_orders: None,
1975            max_parallelism: 8,
1976            specific_resource_group: None,
1977            is_serverless_backfill: false,
1978        };
1979
1980        let database_model = database::Model {
1981            database_id,
1982            name: "test_db".into(),
1983            resource_group: "default".into(),
1984            barrier_interval_ms: None,
1985            checkpoint_frequency: None,
1986        };
1987
1988        let ensembles = vec![NoShuffleEnsemble {
1989            entries: HashSet::from([fragment_id]),
1990            components: HashSet::from([fragment_id]),
1991        }];
1992
1993        let fragment_map =
1994            HashMap::from([(job_id, HashMap::from([(fragment_id, fragment_model)]))]);
1995        let job_map = HashMap::from([(job_id, job_model)]);
1996
1997        // 4 workers with 1 parallelism each = total 4 parallelism
1998        let worker_map = HashMap::from([
1999            (1.into(), build_worker_node(1, 1, "default")),
2000            (2.into(), build_worker_node(2, 1, "default")),
2001            (3.into(), build_worker_node(3, 1, "default")),
2002            (4.into(), build_worker_node(4, 1, "default")),
2003        ]);
2004
2005        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
2006        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
2007        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
2008        let database_map = HashMap::from([(database_id, database_model)]);
2009
2010        let context = RenderActorsContext {
2011            fragment_source_ids: &fragment_source_ids,
2012            fragment_splits: &fragment_splits,
2013            streaming_job_databases: &streaming_job_databases,
2014            database_map: &database_map,
2015        };
2016
2017        // Global strategy is FULL (would give 4 actors), but job strategy is BOUNDED(2)
2018        let result = render_actors(
2019            &actor_id_counter,
2020            &ensembles,
2021            &fragment_map,
2022            &job_map,
2023            &worker_map,
2024            AdaptiveParallelismStrategy::Full,
2025            context,
2026        )
2027        .expect("actor rendering succeeds");
2028
2029        let state = collect_actor_state(&result[&database_id][&job_id][&fragment_id]);
2030        // Job strategy BOUNDED(2) should limit to 2 actors, not 4 (global FULL)
2031        assert_eq!(
2032            state.len(),
2033            2,
2034            "Job strategy BOUNDED(2) should override global FULL"
2035        );
2036    }
2037
2038    /// Test that global strategy is used when job has no custom strategy.
2039    #[test]
2040    fn render_actors_uses_global_strategy_when_job_has_none() {
2041        let actor_id_counter = AtomicU32::new(0);
2042        let fragment_id: FragmentId = 1.into();
2043        let job_id: JobId = 101.into();
2044        let database_id: DatabaseId = DatabaseId::new(11);
2045
2046        let fragment_model = build_fragment(
2047            fragment_id,
2048            job_id,
2049            0,
2050            DistributionType::Hash,
2051            8,
2052            StreamingParallelism::Adaptive,
2053        );
2054
2055        // Job has NO custom strategy (None)
2056        let job_model = streaming_job::Model {
2057            job_id,
2058            job_status: JobStatus::Created,
2059            create_type: CreateType::Foreground,
2060            timezone: None,
2061            config_override: None,
2062            adaptive_parallelism_strategy: None, // No custom strategy
2063            parallelism: StreamingParallelism::Adaptive,
2064            backfill_parallelism: None,
2065            backfill_orders: None,
2066            max_parallelism: 8,
2067            specific_resource_group: None,
2068            is_serverless_backfill: false,
2069        };
2070
2071        let database_model = database::Model {
2072            database_id,
2073            name: "test_db".into(),
2074            resource_group: "default".into(),
2075            barrier_interval_ms: None,
2076            checkpoint_frequency: None,
2077        };
2078
2079        let ensembles = vec![NoShuffleEnsemble {
2080            entries: HashSet::from([fragment_id]),
2081            components: HashSet::from([fragment_id]),
2082        }];
2083
2084        let fragment_map =
2085            HashMap::from([(job_id, HashMap::from([(fragment_id, fragment_model)]))]);
2086        let job_map = HashMap::from([(job_id, job_model)]);
2087
2088        // 4 workers = total 4 parallelism
2089        let worker_map = HashMap::from([
2090            (1.into(), build_worker_node(1, 1, "default")),
2091            (2.into(), build_worker_node(2, 1, "default")),
2092            (3.into(), build_worker_node(3, 1, "default")),
2093            (4.into(), build_worker_node(4, 1, "default")),
2094        ]);
2095
2096        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
2097        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
2098        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
2099        let database_map = HashMap::from([(database_id, database_model)]);
2100
2101        let context = RenderActorsContext {
2102            fragment_source_ids: &fragment_source_ids,
2103            fragment_splits: &fragment_splits,
2104            streaming_job_databases: &streaming_job_databases,
2105            database_map: &database_map,
2106        };
2107
2108        // Global strategy is BOUNDED(3)
2109        let result = render_actors(
2110            &actor_id_counter,
2111            &ensembles,
2112            &fragment_map,
2113            &job_map,
2114            &worker_map,
2115            AdaptiveParallelismStrategy::Bounded(NonZeroUsize::new(3).unwrap()),
2116            context,
2117        )
2118        .expect("actor rendering succeeds");
2119
2120        let state = collect_actor_state(&result[&database_id][&job_id][&fragment_id]);
2121        // Should use global strategy BOUNDED(3)
2122        assert_eq!(
2123            state.len(),
2124            3,
2125            "Should use global strategy BOUNDED(3) when job has no custom strategy"
2126        );
2127    }
2128
2129    /// Test that Fixed parallelism ignores strategy entirely.
2130    #[test]
2131    fn render_actors_fixed_parallelism_ignores_strategy() {
2132        let actor_id_counter = AtomicU32::new(0);
2133        let fragment_id: FragmentId = 1.into();
2134        let job_id: JobId = 102.into();
2135        let database_id: DatabaseId = DatabaseId::new(12);
2136
2137        // Fragment with FIXED parallelism
2138        let fragment_model = build_fragment(
2139            fragment_id,
2140            job_id,
2141            0,
2142            DistributionType::Hash,
2143            8,
2144            StreamingParallelism::Fixed(5),
2145        );
2146
2147        // Job has custom strategy, but it should be ignored for Fixed parallelism
2148        let job_model = streaming_job::Model {
2149            job_id,
2150            job_status: JobStatus::Created,
2151            create_type: CreateType::Foreground,
2152            timezone: None,
2153            config_override: None,
2154            adaptive_parallelism_strategy: Some("BOUNDED(2)".to_owned()),
2155            parallelism: StreamingParallelism::Fixed(5),
2156            backfill_parallelism: None,
2157            backfill_orders: None,
2158            max_parallelism: 8,
2159            specific_resource_group: None,
2160            is_serverless_backfill: false,
2161        };
2162
2163        let database_model = database::Model {
2164            database_id,
2165            name: "test_db".into(),
2166            resource_group: "default".into(),
2167            barrier_interval_ms: None,
2168            checkpoint_frequency: None,
2169        };
2170
2171        let ensembles = vec![NoShuffleEnsemble {
2172            entries: HashSet::from([fragment_id]),
2173            components: HashSet::from([fragment_id]),
2174        }];
2175
2176        let fragment_map =
2177            HashMap::from([(job_id, HashMap::from([(fragment_id, fragment_model)]))]);
2178        let job_map = HashMap::from([(job_id, job_model)]);
2179
2180        // 6 workers = total 6 parallelism
2181        let worker_map = HashMap::from([
2182            (1.into(), build_worker_node(1, 1, "default")),
2183            (2.into(), build_worker_node(2, 1, "default")),
2184            (3.into(), build_worker_node(3, 1, "default")),
2185            (4.into(), build_worker_node(4, 1, "default")),
2186            (5.into(), build_worker_node(5, 1, "default")),
2187            (6.into(), build_worker_node(6, 1, "default")),
2188        ]);
2189
2190        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
2191        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
2192        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
2193        let database_map = HashMap::from([(database_id, database_model)]);
2194
2195        let context = RenderActorsContext {
2196            fragment_source_ids: &fragment_source_ids,
2197            fragment_splits: &fragment_splits,
2198            streaming_job_databases: &streaming_job_databases,
2199            database_map: &database_map,
2200        };
2201
2202        let result = render_actors(
2203            &actor_id_counter,
2204            &ensembles,
2205            &fragment_map,
2206            &job_map,
2207            &worker_map,
2208            AdaptiveParallelismStrategy::Full,
2209            context,
2210        )
2211        .expect("actor rendering succeeds");
2212
2213        let state = collect_actor_state(&result[&database_id][&job_id][&fragment_id]);
2214        // Fixed(5) should be used, ignoring both job strategy BOUNDED(2) and global FULL
2215        assert_eq!(
2216            state.len(),
2217            5,
2218            "Fixed parallelism should ignore all strategies"
2219        );
2220    }
2221
2222    /// Test RATIO strategy calculation.
2223    #[test]
2224    fn render_actors_ratio_strategy() {
2225        let actor_id_counter = AtomicU32::new(0);
2226        let fragment_id: FragmentId = 1.into();
2227        let job_id: JobId = 103.into();
2228        let database_id: DatabaseId = DatabaseId::new(13);
2229
2230        let fragment_model = build_fragment(
2231            fragment_id,
2232            job_id,
2233            0,
2234            DistributionType::Hash,
2235            16,
2236            StreamingParallelism::Adaptive,
2237        );
2238
2239        // Job has RATIO(0.5) strategy
2240        let job_model = streaming_job::Model {
2241            job_id,
2242            job_status: JobStatus::Created,
2243            create_type: CreateType::Foreground,
2244            timezone: None,
2245            config_override: None,
2246            adaptive_parallelism_strategy: Some("RATIO(0.5)".to_owned()),
2247            parallelism: StreamingParallelism::Adaptive,
2248            backfill_parallelism: None,
2249            backfill_orders: None,
2250            max_parallelism: 16,
2251            specific_resource_group: None,
2252            is_serverless_backfill: false,
2253        };
2254
2255        let database_model = database::Model {
2256            database_id,
2257            name: "test_db".into(),
2258            resource_group: "default".into(),
2259            barrier_interval_ms: None,
2260            checkpoint_frequency: None,
2261        };
2262
2263        let ensembles = vec![NoShuffleEnsemble {
2264            entries: HashSet::from([fragment_id]),
2265            components: HashSet::from([fragment_id]),
2266        }];
2267
2268        let fragment_map =
2269            HashMap::from([(job_id, HashMap::from([(fragment_id, fragment_model)]))]);
2270        let job_map = HashMap::from([(job_id, job_model)]);
2271
2272        // 8 workers = total 8 parallelism
2273        let worker_map = HashMap::from([
2274            (1.into(), build_worker_node(1, 1, "default")),
2275            (2.into(), build_worker_node(2, 1, "default")),
2276            (3.into(), build_worker_node(3, 1, "default")),
2277            (4.into(), build_worker_node(4, 1, "default")),
2278            (5.into(), build_worker_node(5, 1, "default")),
2279            (6.into(), build_worker_node(6, 1, "default")),
2280            (7.into(), build_worker_node(7, 1, "default")),
2281            (8.into(), build_worker_node(8, 1, "default")),
2282        ]);
2283
2284        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
2285        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
2286        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
2287        let database_map = HashMap::from([(database_id, database_model)]);
2288
2289        let context = RenderActorsContext {
2290            fragment_source_ids: &fragment_source_ids,
2291            fragment_splits: &fragment_splits,
2292            streaming_job_databases: &streaming_job_databases,
2293            database_map: &database_map,
2294        };
2295
2296        let result = render_actors(
2297            &actor_id_counter,
2298            &ensembles,
2299            &fragment_map,
2300            &job_map,
2301            &worker_map,
2302            AdaptiveParallelismStrategy::Full,
2303            context,
2304        )
2305        .expect("actor rendering succeeds");
2306
2307        let state = collect_actor_state(&result[&database_id][&job_id][&fragment_id]);
2308        // RATIO(0.5) of 8 = 4
2309        assert_eq!(
2310            state.len(),
2311            4,
2312            "RATIO(0.5) of 8 workers should give 4 actors"
2313        );
2314    }
2315}