risingwave_meta/controller/
scale.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
16use std::num::NonZeroUsize;
17use std::sync::atomic::{AtomicU32, Ordering};
18
19use anyhow::anyhow;
20use itertools::Itertools;
21use risingwave_common::bitmap::Bitmap;
22use risingwave_common::catalog;
23use risingwave_common::catalog::{FragmentTypeFlag, FragmentTypeMask};
24use risingwave_common::system_param::AdaptiveParallelismStrategy;
25use risingwave_common::util::worker_util::DEFAULT_RESOURCE_GROUP;
26use risingwave_connector::source::{SplitImpl, SplitMetaData};
27use risingwave_meta_model::fragment::DistributionType;
28use risingwave_meta_model::prelude::{
29    Database, Fragment, FragmentRelation, FragmentSplits, Sink, Source, StreamingJob, Table,
30};
31use risingwave_meta_model::{
32    DatabaseId, DispatcherType, FragmentId, ObjectId, SourceId, StreamingParallelism, TableId,
33    WorkerId, database, fragment, fragment_relation, fragment_splits, object, sink, source,
34    streaming_job, table,
35};
36use risingwave_meta_model_migration::Condition;
37use sea_orm::{
38    ColumnTrait, ConnectionTrait, EntityTrait, JoinType, QueryFilter, QuerySelect, QueryTrait,
39    RelationTrait,
40};
41
42use crate::MetaResult;
43use crate::controller::fragment::{InflightActorInfo, InflightFragmentInfo};
44use crate::manager::ActiveStreamingWorkerNodes;
45use crate::model::{ActorId, StreamActor};
46use crate::stream::{AssignerBuilder, SplitDiffOptions};
47
48pub(crate) async fn resolve_streaming_job_definition<C>(
49    txn: &C,
50    job_ids: &HashSet<ObjectId>,
51) -> MetaResult<HashMap<ObjectId, String>>
52where
53    C: ConnectionTrait,
54{
55    let job_ids = job_ids.iter().cloned().collect_vec();
56
57    // including table, materialized view, index
58    let common_job_definitions: Vec<(ObjectId, String)> = Table::find()
59        .select_only()
60        .columns([
61            table::Column::TableId,
62            #[cfg(not(debug_assertions))]
63            table::Column::Name,
64            #[cfg(debug_assertions)]
65            table::Column::Definition,
66        ])
67        .filter(table::Column::TableId.is_in(job_ids.clone()))
68        .into_tuple()
69        .all(txn)
70        .await?;
71
72    let sink_definitions: Vec<(ObjectId, String)> = Sink::find()
73        .select_only()
74        .columns([
75            sink::Column::SinkId,
76            #[cfg(not(debug_assertions))]
77            sink::Column::Name,
78            #[cfg(debug_assertions)]
79            sink::Column::Definition,
80        ])
81        .filter(sink::Column::SinkId.is_in(job_ids.clone()))
82        .into_tuple()
83        .all(txn)
84        .await?;
85
86    let source_definitions: Vec<(ObjectId, String)> = Source::find()
87        .select_only()
88        .columns([
89            source::Column::SourceId,
90            #[cfg(not(debug_assertions))]
91            source::Column::Name,
92            #[cfg(debug_assertions)]
93            source::Column::Definition,
94        ])
95        .filter(source::Column::SourceId.is_in(job_ids.clone()))
96        .into_tuple()
97        .all(txn)
98        .await?;
99
100    let definitions: HashMap<ObjectId, String> = common_job_definitions
101        .into_iter()
102        .chain(sink_definitions.into_iter())
103        .chain(source_definitions.into_iter())
104        .collect();
105
106    Ok(definitions)
107}
108
109pub async fn load_fragment_info<C>(
110    txn: &C,
111    actor_id_counter: &AtomicU32,
112    database_id: Option<DatabaseId>,
113    worker_nodes: &ActiveStreamingWorkerNodes,
114    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
115) -> MetaResult<FragmentRenderMap>
116where
117    C: ConnectionTrait,
118{
119    let mut query = StreamingJob::find()
120        .select_only()
121        .column(streaming_job::Column::JobId);
122
123    if let Some(database_id) = database_id {
124        query = query
125            .join(JoinType::InnerJoin, streaming_job::Relation::Object.def())
126            .filter(object::Column::DatabaseId.eq(database_id));
127    }
128
129    let jobs: Vec<ObjectId> = query.into_tuple().all(txn).await?;
130
131    if jobs.is_empty() {
132        return Ok(HashMap::new());
133    }
134
135    let jobs: HashSet<ObjectId> = jobs.into_iter().collect();
136
137    let available_workers: BTreeMap<_, _> = worker_nodes
138        .current()
139        .values()
140        .filter(|worker| worker.is_streaming_schedulable())
141        .map(|worker| {
142            (
143                worker.id as i32,
144                WorkerInfo {
145                    parallelism: NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
146                    resource_group: worker.resource_group(),
147                },
148            )
149        })
150        .collect();
151
152    let RenderedGraph { fragments, .. } = render_jobs(
153        txn,
154        actor_id_counter,
155        jobs,
156        available_workers,
157        adaptive_parallelism_strategy,
158    )
159    .await?;
160
161    Ok(fragments)
162}
163
/// Desired scheduling target for a streaming job: which resource group it should
/// run in and at what parallelism.
#[derive(Debug)]
pub struct TargetResourcePolicy {
    // Target resource group. `None` presumably means "inherit the database's
    // resource group" (mirroring `specific_resource_group` handling in
    // `render_actors`) — confirm at the call sites.
    pub resource_group: Option<String>,
    // Target parallelism setting (fixed / adaptive / custom).
    pub parallelism: StreamingParallelism,
}
169
/// Scheduling-relevant view of a streaming compute worker.
#[derive(Debug, Clone)]
pub struct WorkerInfo {
    // Number of parallel actor slots the worker offers (always at least one).
    pub parallelism: NonZeroUsize,
    // Resource group the worker belongs to; `None` falls back to
    // `DEFAULT_RESOURCE_GROUP` when matching jobs to workers (see `render_actors`).
    pub resource_group: Option<String>,
}
175
/// Rendered fragment info, grouped by database id, then streaming job id,
/// then fragment id.
pub type FragmentRenderMap =
    HashMap<DatabaseId, HashMap<TableId, HashMap<FragmentId, InflightFragmentInfo>>>;
178
/// Output of the rendering pipeline: the rendered fragments plus the
/// no-shuffle ensembles they were derived from.
#[derive(Default)]
pub struct RenderedGraph {
    // Rendered fragments, grouped by database and streaming job.
    pub fragments: FragmentRenderMap,
    // The no-shuffle ensembles that were rendered.
    pub ensembles: Vec<NoShuffleEnsemble>,
}
184
185impl RenderedGraph {
186    pub fn empty() -> Self {
187        Self::default()
188    }
189}
190
/// Fragment-scoped rendering entry point used by operational tooling.
/// It resolves only the metadata required for the requested ensembles' components,
/// then reuses the shared rendering pipeline to materialize actor assignments.
///
/// Errors if any referenced fragment or its owning streaming job is missing from
/// the catalog. Callers are expected to supply ensembles whose `entries` are the
/// no-shuffle roots (e.g. as produced by [`find_fragment_no_shuffle_dags_detailed`]);
/// root-ness is not re-validated here — TODO confirm upstream validation.
pub async fn render_fragments<C>(
    txn: &C,
    actor_id_counter: &AtomicU32,
    ensembles: Vec<NoShuffleEnsemble>,
    workers: BTreeMap<WorkerId, WorkerInfo>,
    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
) -> MetaResult<RenderedGraph>
where
    C: ConnectionTrait,
{
    if ensembles.is_empty() {
        return Ok(RenderedGraph::empty());
    }

    // Every fragment across all ensembles, deduplicated.
    let required_fragment_ids: HashSet<_> = ensembles
        .iter()
        .flat_map(|ensemble| ensemble.components.iter().copied())
        .collect();

    let fragment_models = Fragment::find()
        .filter(fragment::Column::FragmentId.is_in(required_fragment_ids.iter().copied()))
        .all(txn)
        .await?;

    let found_fragment_ids: HashSet<_> = fragment_models
        .iter()
        .map(|fragment| fragment.fragment_id)
        .collect();

    // Fail fast (with the exact missing ids) rather than rendering a partial graph.
    if found_fragment_ids.len() != required_fragment_ids.len() {
        let missing = required_fragment_ids
            .difference(&found_fragment_ids)
            .copied()
            .collect_vec();
        return Err(anyhow!("fragments {:?} not found", missing).into());
    }

    let fragment_map: HashMap<_, _> = fragment_models
        .into_iter()
        .map(|fragment| (fragment.fragment_id, fragment))
        .collect();

    // Jobs are derived from the fetched fragments, not passed in by the caller.
    let job_ids: HashSet<_> = fragment_map
        .values()
        .map(|fragment| fragment.job_id)
        .collect();

    if job_ids.is_empty() {
        return Ok(RenderedGraph::empty());
    }

    let jobs: HashMap<_, _> = StreamingJob::find()
        .filter(streaming_job::Column::JobId.is_in(job_ids.iter().copied().collect_vec()))
        .all(txn)
        .await?
        .into_iter()
        .map(|job| (job.job_id, job))
        .collect();

    // Same fail-fast treatment for jobs referenced by the fragments.
    let found_job_ids: HashSet<_> = jobs.keys().copied().collect();
    if found_job_ids.len() != job_ids.len() {
        let missing = job_ids.difference(&found_job_ids).copied().collect_vec();
        return Err(anyhow!("streaming jobs {:?} not found", missing).into());
    }

    let fragments = render_no_shuffle_ensembles(
        txn,
        actor_id_counter,
        &ensembles,
        &fragment_map,
        &jobs,
        &workers,
        adaptive_parallelism_strategy,
    )
    .await?;

    Ok(RenderedGraph {
        fragments,
        ensembles,
    })
}
276
/// Job-scoped rendering entry point that walks every no-shuffle root belonging to the
/// provided streaming jobs before delegating to the shared rendering backend.
pub async fn render_jobs<C>(
    txn: &C,
    actor_id_counter: &AtomicU32,
    job_ids: HashSet<ObjectId>,
    workers: BTreeMap<WorkerId, WorkerInfo>,
    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
) -> MetaResult<RenderedGraph>
where
    C: ConnectionTrait,
{
    // Any fragment that is the *target* of a no-shuffle edge is not a root of its
    // ensemble; collect them in a subquery so they can be excluded below.
    let excluded_fragments_query = FragmentRelation::find()
        .select_only()
        .column(fragment_relation::Column::TargetFragmentId)
        .filter(fragment_relation::Column::DispatcherType.eq(DispatcherType::NoShuffle))
        .into_query();

    // Root fragments = fragments of the requested jobs that are not a no-shuffle target.
    let condition = Condition::all()
        .add(fragment::Column::JobId.is_in(job_ids.clone()))
        .add(fragment::Column::FragmentId.not_in_subquery(excluded_fragments_query));

    let fragments: Vec<FragmentId> = Fragment::find()
        .select_only()
        .column(fragment::Column::FragmentId)
        .filter(condition)
        .into_tuple()
        .all(txn)
        .await?;

    // Expand each root into its full no-shuffle ensemble. This traversal follows
    // no-shuffle edges in both directions and may pull in fragments of jobs
    // outside the requested set.
    let ensembles = find_fragment_no_shuffle_dags_detailed(txn, &fragments).await?;

    let fragments = Fragment::find()
        .filter(
            fragment::Column::FragmentId.is_in(
                ensembles
                    .iter()
                    .flat_map(|graph| graph.components.iter())
                    .cloned()
                    .collect_vec(),
            ),
        )
        .all(txn)
        .await?;

    let fragment_map: HashMap<_, _> = fragments
        .into_iter()
        .map(|fragment| (fragment.fragment_id, fragment))
        .collect();

    // Re-derive the job set from the fetched fragments (see note on ensembles
    // spanning extra jobs above); BTreeSet gives a stable, deduplicated order.
    let job_ids = fragment_map
        .values()
        .map(|fragment| fragment.job_id)
        .collect::<BTreeSet<_>>()
        .into_iter()
        .collect_vec();

    let jobs: HashMap<_, _> = StreamingJob::find()
        .filter(streaming_job::Column::JobId.is_in(job_ids))
        .all(txn)
        .await?
        .into_iter()
        .map(|job| (job.job_id, job))
        .collect();

    let fragments = render_no_shuffle_ensembles(
        txn,
        actor_id_counter,
        &ensembles,
        &fragment_map,
        &jobs,
        &workers,
        adaptive_parallelism_strategy,
    )
    .await?;

    Ok(RenderedGraph {
        fragments,
        ensembles,
    })
}
358
/// Core rendering routine that consumes no-shuffle ensembles and produces
/// `InflightFragmentInfo`s by:
///   * determining the eligible worker pools and effective parallelism,
///   * generating actor→worker assignments plus vnode bitmaps and source splits,
///   * grouping the rendered fragments by database and streaming job.
async fn render_no_shuffle_ensembles<C>(
    txn: &C,
    actor_id_counter: &AtomicU32,
    ensembles: &[NoShuffleEnsemble],
    fragment_map: &HashMap<FragmentId, fragment::Model>,
    job_map: &HashMap<ObjectId, streaming_job::Model>,
    worker_map: &BTreeMap<WorkerId, WorkerInfo>,
    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
) -> MetaResult<FragmentRenderMap>
where
    C: ConnectionTrait,
{
    if ensembles.is_empty() {
        return Ok(HashMap::new());
    }

    // Debug builds verify ensembles/fragments/jobs are mutually consistent
    // before any rendering work.
    #[cfg(debug_assertions)]
    {
        debug_sanity_check(ensembles, fragment_map, job_map);
    }

    // Source ids and persisted splits for source / source-backfill fragments.
    let (fragment_source_ids, fragment_splits) =
        resolve_source_fragments(txn, fragment_map).await?;

    let job_ids = job_map.keys().copied().collect_vec();

    // Map each job to its owning database via the `object` table.
    let streaming_job_databases: HashMap<ObjectId, _> = StreamingJob::find()
        .select_only()
        .column(streaming_job::Column::JobId)
        .column(object::Column::DatabaseId)
        .join(JoinType::LeftJoin, streaming_job::Relation::Object.def())
        .filter(streaming_job::Column::JobId.is_in(job_ids))
        .into_tuple()
        .all(txn)
        .await?
        .into_iter()
        .collect();

    // Database rows are needed for the resource-group fallback during rendering.
    let database_map: HashMap<_, _> = Database::find()
        .filter(
            database::Column::DatabaseId
                .is_in(streaming_job_databases.values().copied().collect_vec()),
        )
        .all(txn)
        .await?
        .into_iter()
        .map(|db| (db.database_id, db))
        .collect();

    let context = RenderActorsContext {
        fragment_source_ids: &fragment_source_ids,
        fragment_splits: &fragment_splits,
        streaming_job_databases: &streaming_job_databases,
        database_map: &database_map,
    };

    // All async metadata is resolved at this point; the actual rendering is
    // synchronous (see `RenderActorsContext` rationale below).
    render_actors(
        actor_id_counter,
        ensembles,
        fragment_map,
        job_map,
        worker_map,
        adaptive_parallelism_strategy,
        context,
    )
}
430
// Only metadata resolved asynchronously lives here so the renderer stays synchronous
// and the call site keeps the runtime dependencies (maps, strategy, actor counter, etc.) explicit.
struct RenderActorsContext<'a> {
    // Source id per source/source-backfill fragment.
    fragment_source_ids: &'a HashMap<FragmentId, SourceId>,
    // Persisted source splits per fragment.
    fragment_splits: &'a HashMap<FragmentId, Vec<SplitImpl>>,
    // Owning database per streaming job.
    streaming_job_databases: &'a HashMap<ObjectId, DatabaseId>,
    // Database catalog rows, keyed by database id.
    database_map: &'a HashMap<DatabaseId, database::Model>,
}
439
/// Synchronous core of the rendering pipeline: turns each no-shuffle ensemble into
/// per-fragment `InflightFragmentInfo`s, assigning actor ids, worker placement,
/// vnode bitmaps, and (for source fragments) split assignments. All fragments of
/// one ensemble share the same actor→worker/vnode assignment, as required by
/// no-shuffle exchanges.
fn render_actors(
    actor_id_counter: &AtomicU32,
    ensembles: &[NoShuffleEnsemble],
    fragment_map: &HashMap<FragmentId, fragment::Model>,
    job_map: &HashMap<ObjectId, streaming_job::Model>,
    worker_map: &BTreeMap<WorkerId, WorkerInfo>,
    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
    context: RenderActorsContext<'_>,
) -> MetaResult<FragmentRenderMap> {
    let RenderActorsContext {
        fragment_source_ids,
        fragment_splits: fragment_splits_map,
        streaming_job_databases,
        database_map,
    } = context;

    let mut all_fragments: FragmentRenderMap = HashMap::new();

    for NoShuffleEnsemble {
        entries,
        components,
    } in ensembles
    {
        tracing::debug!("rendering ensemble entries {:?}", entries);

        // Entry fragments were validated against `fragment_map` upstream
        // (debug_sanity_check / render_fragments), hence the unwrap.
        let entry_fragments = entries
            .iter()
            .map(|fragment_id| fragment_map.get(fragment_id).unwrap())
            .collect_vec();

        // All entry fragments must agree on a single parallelism override.
        let entry_fragment_parallelism = entry_fragments
            .iter()
            .map(|fragment| fragment.parallelism.clone())
            .dedup()
            .exactly_one()
            .map_err(|_| {
                anyhow!(
                    "entry fragments {:?} have inconsistent parallelism settings",
                    entries.iter().copied().collect_vec()
                )
            })?;

        // Likewise, they must belong to one job and share one vnode count.
        let (job_id, vnode_count) = entry_fragments
            .iter()
            .map(|f| (f.job_id, f.vnode_count as usize))
            .dedup()
            .exactly_one()
            .map_err(|_| anyhow!("Multiple jobs found in no-shuffle ensemble"))?;

        let job = job_map
            .get(&job_id)
            .ok_or_else(|| anyhow!("streaming job {job_id} not found"))?;

        // Resource group: job-specific override, else the owning database's group.
        let resource_group = match &job.specific_resource_group {
            None => {
                let database = streaming_job_databases
                    .get(&job_id)
                    .and_then(|database_id| database_map.get(database_id))
                    .unwrap();
                database.resource_group.clone()
            }
            Some(resource_group) => resource_group.clone(),
        };

        // Only workers in the matching resource group are eligible; workers
        // without a group count as `DEFAULT_RESOURCE_GROUP`.
        let available_workers: BTreeMap<WorkerId, NonZeroUsize> = worker_map
            .iter()
            .filter_map(|(worker_id, worker)| {
                if worker
                    .resource_group
                    .as_deref()
                    .unwrap_or(DEFAULT_RESOURCE_GROUP)
                    == resource_group.as_str()
                {
                    Some((*worker_id, worker.parallelism))
                } else {
                    None
                }
            })
            .collect();

        let total_parallelism = available_workers.values().map(|w| w.get()).sum::<usize>();

        // Effective parallelism: fragment override beats job setting; the result
        // is clamped by the vnode count and the job's max_parallelism.
        let actual_parallelism = match entry_fragment_parallelism
            .as_ref()
            .unwrap_or(&job.parallelism)
        {
            StreamingParallelism::Adaptive | StreamingParallelism::Custom => {
                adaptive_parallelism_strategy.compute_target_parallelism(total_parallelism)
            }
            StreamingParallelism::Fixed(n) => *n,
        }
        .min(vnode_count)
        .min(job.max_parallelism as usize);

        tracing::debug!(
            "job {}, final {} parallelism {:?} total_parallelism {} job_max {} vnode count {} fragment_override {:?}",
            job_id,
            actual_parallelism,
            job.parallelism,
            total_parallelism,
            job.max_parallelism,
            vnode_count,
            entry_fragment_parallelism
        );

        // Deterministic assigner seeded by the job id; the assignment is computed
        // once and reused for every fragment in this ensemble.
        let assigner = AssignerBuilder::new(job_id).build();

        let actors = (0..actual_parallelism).collect_vec();
        let vnodes = (0..vnode_count).collect_vec();

        let assignment = assigner.assign_hierarchical(&available_workers, &actors, &vnodes)?;

        // Find the (non-DML) source fragment among the entries, if any; a fragment
        // must not be both Source and SourceScan.
        let source_entry_fragment = entry_fragments.iter().find(|f| {
            let mask = FragmentTypeMask::from(f.fragment_type_mask);
            if mask.contains(FragmentTypeFlag::Source) {
                assert!(!mask.contains(FragmentTypeFlag::SourceScan))
            }
            mask.contains(FragmentTypeFlag::Source) && !mask.contains(FragmentTypeFlag::Dml)
        });

        // Distribute the persisted splits of the source entry fragment over the
        // freshly created actor indices.
        let (fragment_splits, shared_source_id) = match source_entry_fragment {
            Some(entry_fragment) => {
                let source_id = fragment_source_ids
                    .get(&entry_fragment.fragment_id)
                    .ok_or_else(|| {
                        anyhow!(
                            "missing source id in source fragment {}",
                            entry_fragment.fragment_id
                        )
                    })?;

                let entry_fragment_id = entry_fragment.fragment_id;

                // Start from "every actor has no splits" so reassign_splits
                // spreads all known splits across them.
                let empty_actor_splits: HashMap<_, _> = actors
                    .iter()
                    .map(|actor_id| (*actor_id as u32, vec![]))
                    .collect();

                let splits = fragment_splits_map
                    .get(&entry_fragment_id)
                    .cloned()
                    .unwrap_or_default();

                let splits: BTreeMap<_, _> = splits.into_iter().map(|s| (s.id(), s)).collect();

                let fragment_splits = crate::stream::source_manager::reassign_splits(
                    entry_fragment_id as u32,
                    empty_actor_splits,
                    &splits,
                    SplitDiffOptions::default(),
                )
                .unwrap_or_default();
                (fragment_splits, Some(*source_id))
            }
            None => (HashMap::new(), None),
        };

        // Render every fragment of the ensemble with the shared assignment.
        for component_fragment_id in components {
            let &fragment::Model {
                fragment_id,
                job_id,
                fragment_type_mask,
                distribution_type,
                ref stream_node,
                ref state_table_ids,
                ..
            } = fragment_map.get(component_fragment_id).unwrap();

            // Reserve a contiguous range of globally unique actor ids for this fragment.
            let actor_count =
                u32::try_from(actors.len()).expect("actor parallelism exceeds u32::MAX");
            let actor_id_base = actor_id_counter.fetch_add(actor_count, Ordering::Relaxed);

            // Note: shadows the outer `actors` index vec with the rendered map.
            let actors: HashMap<ActorId, InflightActorInfo> = assignment
                .iter()
                .flat_map(|(worker_id, actors)| {
                    actors
                        .iter()
                        .map(move |(actor_id, vnodes)| (worker_id, actor_id, vnodes))
                })
                .map(|(&worker_id, &actor_idx, vnodes)| {
                    // Singleton fragments carry no vnode bitmap.
                    let vnode_bitmap = match distribution_type {
                        DistributionType::Single => None,
                        DistributionType::Hash => Some(Bitmap::from_indices(vnode_count, vnodes)),
                    };

                    let actor_id = actor_id_base + actor_idx as u32;

                    // Splits only apply to fragments tied to the ensemble's
                    // shared source; others get an empty assignment.
                    let splits = if let Some(source_id) = fragment_source_ids.get(&fragment_id) {
                        assert_eq!(shared_source_id, Some(*source_id));

                        fragment_splits
                            .get(&(actor_idx as u32))
                            .cloned()
                            .unwrap_or_default()
                    } else {
                        vec![]
                    };

                    (
                        actor_id,
                        InflightActorInfo {
                            worker_id,
                            vnode_bitmap,
                            splits,
                        },
                    )
                })
                .collect();

            let fragment = InflightFragmentInfo {
                fragment_id: fragment_id as u32,
                distribution_type,
                fragment_type_mask: fragment_type_mask.into(),
                vnode_count,
                nodes: stream_node.to_protobuf(),
                actors,
                state_table_ids: state_table_ids
                    .inner_ref()
                    .iter()
                    .map(|id| catalog::TableId::new(*id as _))
                    .collect(),
            };

            let &database_id = streaming_job_databases.get(&job_id).ok_or_else(|| {
                anyhow!("streaming job {job_id} not found in streaming_job_databases")
            })?;

            // Group the result as database -> job -> fragment.
            all_fragments
                .entry(database_id)
                .or_default()
                .entry(job_id)
                .or_default()
                .insert(fragment_id, fragment);
        }
    }

    Ok(all_fragments)
}
678
679#[cfg(debug_assertions)]
680fn debug_sanity_check(
681    ensembles: &[NoShuffleEnsemble],
682    fragment_map: &HashMap<FragmentId, fragment::Model>,
683    jobs: &HashMap<ObjectId, streaming_job::Model>,
684) {
685    // Debug-only assertions to catch inconsistent ensemble metadata early.
686    debug_assert!(
687        ensembles
688            .iter()
689            .all(|ensemble| ensemble.entries.is_subset(&ensemble.components)),
690        "entries must be subset of components"
691    );
692
693    let mut missing_fragments = BTreeSet::new();
694    let mut missing_jobs = BTreeSet::new();
695
696    for fragment_id in ensembles
697        .iter()
698        .flat_map(|ensemble| ensemble.components.iter())
699    {
700        match fragment_map.get(fragment_id) {
701            Some(fragment) => {
702                if !jobs.contains_key(&fragment.job_id) {
703                    missing_jobs.insert(fragment.job_id);
704                }
705            }
706            None => {
707                missing_fragments.insert(*fragment_id);
708            }
709        }
710    }
711
712    debug_assert!(
713        missing_fragments.is_empty(),
714        "missing fragments in fragment_map: {:?}",
715        missing_fragments
716    );
717
718    debug_assert!(
719        missing_jobs.is_empty(),
720        "missing jobs for fragments' job_id: {:?}",
721        missing_jobs
722    );
723
724    for ensemble in ensembles {
725        let unique_vnode_counts: Vec<_> = ensemble
726            .components
727            .iter()
728            .flat_map(|fragment_id| {
729                fragment_map
730                    .get(fragment_id)
731                    .map(|fragment| fragment.vnode_count)
732            })
733            .unique()
734            .collect();
735
736        debug_assert!(
737            unique_vnode_counts.len() <= 1,
738            "components in ensemble must share same vnode_count: ensemble={:?}, vnode_counts={:?}",
739            ensemble.components,
740            unique_vnode_counts
741        );
742    }
743}
744
/// Collects, for every source-related fragment in `fragment_map`:
///   * the source id it reads from (covering both `Source` fragments and
///     `SourceScan`/backfill fragments), and
///   * its persisted split assignment from the `fragment_splits` table.
async fn resolve_source_fragments<C>(
    txn: &C,
    fragment_map: &HashMap<FragmentId, fragment::Model>,
) -> MetaResult<(
    HashMap<FragmentId, SourceId>,
    HashMap<FragmentId, Vec<SplitImpl>>,
)>
where
    C: ConnectionTrait,
{
    // source id -> set of fragment ids reading from it.
    let mut source_fragment_ids = HashMap::new();
    for (fragment_id, fragment) in fragment_map {
        let mask = FragmentTypeMask::from(fragment.fragment_type_mask);
        if mask.contains(FragmentTypeFlag::Source)
            && let Some(source_id) = fragment.stream_node.to_protobuf().find_stream_source()
        {
            source_fragment_ids
                .entry(source_id)
                .or_insert_with(BTreeSet::new)
                .insert(fragment_id);
        }

        if mask.contains(FragmentTypeFlag::SourceScan)
            && let Some((source_id, _)) = fragment.stream_node.to_protobuf().find_source_backfill()
        {
            source_fragment_ids
                .entry(source_id)
                .or_insert_with(BTreeSet::new)
                .insert(fragment_id);
        }
    }

    // Invert the mapping: fragment id -> source id.
    let fragment_source_ids: HashMap<_, _> = source_fragment_ids
        .iter()
        .flat_map(|(source_id, fragment_ids)| {
            fragment_ids
                .iter()
                .map(|fragment_id| (**fragment_id, *source_id as SourceId))
        })
        .collect();

    let fragment_ids = fragment_source_ids.keys().copied().collect_vec();

    let fragment_splits: Vec<_> = FragmentSplits::find()
        .filter(fragment_splits::Column::FragmentId.is_in(fragment_ids))
        .all(txn)
        .await?;

    // Decode persisted splits. Rows whose splits column is NULL are skipped, and
    // individual splits that fail to decode are silently dropped
    // (`flat_map` over `SplitImpl::try_from`).
    let fragment_splits: HashMap<_, _> = fragment_splits
        .into_iter()
        .flat_map(|model| {
            model.splits.map(|splits| {
                (
                    model.fragment_id,
                    splits
                        .to_protobuf()
                        .splits
                        .iter()
                        .flat_map(SplitImpl::try_from)
                        .collect_vec(),
                )
            })
        })
        .collect();

    Ok((fragment_source_ids, fragment_splits))
}
812
// Helper struct to make the function signature cleaner and to properly bundle the required data.
#[derive(Debug)]
pub struct ActorGraph<'a> {
    // Fragment plus its actors, keyed by fragment id.
    pub fragments: &'a HashMap<FragmentId, (Fragment, Vec<StreamActor>)>,
    // Worker placement for each actor.
    pub locations: &'a HashMap<ActorId, WorkerId>,
}
819
/// A weakly-connected component of the no-shuffle fragment graph.
///
/// `entries` are the roots (fragments with no no-shuffle upstream inside the
/// component); `components` is the full fragment set and always a superset of
/// `entries` (checked in `debug_sanity_check`).
#[derive(Debug, Clone)]
pub struct NoShuffleEnsemble {
    // Root fragments: no no-shuffle parent within `components`.
    entries: HashSet<FragmentId>,
    // Every fragment in the component, including the entries.
    components: HashSet<FragmentId>,
}
825
826impl NoShuffleEnsemble {
827    pub fn fragments(&self) -> impl Iterator<Item = FragmentId> + '_ {
828        self.components.iter().cloned()
829    }
830
831    pub fn entry_fragments(&self) -> impl Iterator<Item = FragmentId> + '_ {
832        self.entries.iter().copied()
833    }
834
835    pub fn component_fragments(&self) -> impl Iterator<Item = FragmentId> + '_ {
836        self.components.iter().copied()
837    }
838
839    pub fn contains_entry(&self, fragment_id: &FragmentId) -> bool {
840        self.entries.contains(fragment_id)
841    }
842}
843
844pub async fn find_fragment_no_shuffle_dags_detailed(
845    db: &impl ConnectionTrait,
846    initial_fragment_ids: &[FragmentId],
847) -> MetaResult<Vec<NoShuffleEnsemble>> {
848    let all_no_shuffle_relations: Vec<(_, _)> = FragmentRelation::find()
849        .columns([
850            fragment_relation::Column::SourceFragmentId,
851            fragment_relation::Column::TargetFragmentId,
852        ])
853        .filter(fragment_relation::Column::DispatcherType.eq(DispatcherType::NoShuffle))
854        .into_tuple()
855        .all(db)
856        .await?;
857
858    let mut forward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
859    let mut backward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
860
861    for (src, dst) in all_no_shuffle_relations {
862        forward_edges.entry(src).or_default().push(dst);
863        backward_edges.entry(dst).or_default().push(src);
864    }
865
866    find_no_shuffle_graphs(initial_fragment_ids, &forward_edges, &backward_edges)
867}
868
869fn find_no_shuffle_graphs(
870    initial_fragment_ids: &[FragmentId],
871    forward_edges: &HashMap<FragmentId, Vec<FragmentId>>,
872    backward_edges: &HashMap<FragmentId, Vec<FragmentId>>,
873) -> MetaResult<Vec<NoShuffleEnsemble>> {
874    let mut graphs: Vec<NoShuffleEnsemble> = Vec::new();
875    let mut globally_visited: HashSet<FragmentId> = HashSet::new();
876
877    for &init_id in initial_fragment_ids {
878        if globally_visited.contains(&init_id) {
879            continue;
880        }
881
882        // Found a new component. Traverse it to find all its nodes.
883        let mut components = HashSet::new();
884        let mut queue: VecDeque<FragmentId> = VecDeque::new();
885
886        queue.push_back(init_id);
887        globally_visited.insert(init_id);
888
889        while let Some(current_id) = queue.pop_front() {
890            components.insert(current_id);
891            let neighbors = forward_edges
892                .get(&current_id)
893                .into_iter()
894                .flatten()
895                .chain(backward_edges.get(&current_id).into_iter().flatten());
896
897            for &neighbor_id in neighbors {
898                if globally_visited.insert(neighbor_id) {
899                    queue.push_back(neighbor_id);
900                }
901            }
902        }
903
904        // For the newly found component, identify its roots.
905        let mut entries = HashSet::new();
906        for &node_id in &components {
907            let is_root = match backward_edges.get(&node_id) {
908                Some(parents) => parents.iter().all(|p| !components.contains(p)),
909                None => true,
910            };
911            if is_root {
912                entries.insert(node_id);
913            }
914        }
915
916        // Store the detailed DAG structure (roots, all nodes in this DAG).
917        if !entries.is_empty() {
918            graphs.push(NoShuffleEnsemble {
919                entries,
920                components,
921            });
922        }
923    }
924
925    Ok(graphs)
926}
927
#[cfg(test)]
mod tests {
    use std::collections::{BTreeSet, HashMap, HashSet};
    use std::sync::Arc;

    use risingwave_connector::source::SplitImpl;
    use risingwave_connector::source::test_source::TestSourceSplit;
    use risingwave_meta_model::{CreateType, I32Array, JobStatus, StreamNode};
    use risingwave_pb::stream_plan::StreamNode as PbStreamNode;

    use super::*;

    // Helper type aliases for cleaner test code
    // Using the actual FragmentId type from the module
    // (forward_edges, backward_edges) — src -> [dst] and dst -> [src] adjacency maps.
    type Edges = (
        HashMap<FragmentId, Vec<FragmentId>>,
        HashMap<FragmentId, Vec<FragmentId>>,
    );

    /// A helper function to build forward and backward edge maps from a simple list of tuples.
    /// This reduces boilerplate in each test.
    fn build_edges(relations: &[(FragmentId, FragmentId)]) -> Edges {
        let mut forward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
        let mut backward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
        for &(src, dst) in relations {
            forward_edges.entry(src).or_default().push(dst);
            backward_edges.entry(dst).or_default().push(src);
        }
        (forward_edges, backward_edges)
    }

    /// Helper function to create a `HashSet` from a slice easily.
    fn to_hashset(ids: &[FragmentId]) -> HashSet<FragmentId> {
        ids.iter().cloned().collect()
    }

    /// Builds a minimal `fragment::Model` for rendering tests; fields not under
    /// test (stream node, state tables, upstream ids) are defaulted.
    #[allow(deprecated)]
    fn build_fragment(
        fragment_id: FragmentId,
        job_id: ObjectId,
        fragment_type_mask: i32,
        distribution_type: DistributionType,
        vnode_count: i32,
        parallelism: StreamingParallelism,
    ) -> fragment::Model {
        fragment::Model {
            fragment_id,
            job_id,
            fragment_type_mask,
            distribution_type,
            stream_node: StreamNode::from(&PbStreamNode::default()),
            state_table_ids: I32Array::default(),
            upstream_fragment_id: I32Array::default(),
            vnode_count,
            parallelism: Some(parallelism),
        }
    }

    // (actor index relative to the fragment's smallest actor id, assigned worker,
    //  vnode positions set in the actor's bitmap if any, split ids owned by the actor)
    type ActorState = (u32, WorkerId, Option<Vec<usize>>, Vec<String>);

    /// Flattens an `InflightFragmentInfo` into a deterministic, comparable form:
    /// actor ids are normalized to offsets from the fragment's minimum actor id,
    /// vnode bitmaps become sorted position lists, and entries are sorted by offset.
    fn collect_actor_state(fragment: &InflightFragmentInfo) -> Vec<ActorState> {
        let base = fragment.actors.keys().copied().min().unwrap_or_default();

        let mut entries: Vec<_> = fragment
            .actors
            .iter()
            .map(|(&actor_id, info)| {
                let idx = actor_id - base;
                let vnode_indices = info.vnode_bitmap.as_ref().map(|bitmap| {
                    bitmap
                        .iter()
                        .enumerate()
                        .filter_map(|(pos, is_set)| is_set.then_some(pos))
                        .collect::<Vec<_>>()
                });
                let splits = info
                    .splits
                    .iter()
                    .map(|split| split.id().to_string())
                    .collect::<Vec<_>>();
                (idx, info.worker_id, vnode_indices, splits)
            })
            .collect();

        entries.sort_by_key(|(idx, _, _, _)| *idx);
        entries
    }

    #[test]
    fn test_single_linear_chain() {
        // Scenario: A simple linear graph 1 -> 2 -> 3.
        // We start from the middle node (2).
        let (forward, backward) = build_edges(&[(1, 2), (2, 3)]);
        let initial_ids = &[2];

        // Act
        let result = find_no_shuffle_graphs(initial_ids, &forward, &backward);

        // Assert
        assert!(result.is_ok());
        let graphs = result.unwrap();

        assert_eq!(graphs.len(), 1);
        let graph = &graphs[0];
        assert_eq!(graph.entries, to_hashset(&[1]));
        assert_eq!(graph.components, to_hashset(&[1, 2, 3]));
    }

    #[test]
    fn test_two_disconnected_graphs() {
        // Scenario: Two separate graphs: 1->2 and 10->11.
        // We start with one node from each graph.
        let (forward, backward) = build_edges(&[(1, 2), (10, 11)]);
        let initial_ids = &[2, 10];

        // Act
        let mut graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();

        // Assert
        assert_eq!(graphs.len(), 2);

        // Sort results to make the test deterministic, as HashMap iteration order is not guaranteed.
        graphs.sort_by_key(|g| *g.components.iter().min().unwrap_or(&0));

        // Graph 1
        assert_eq!(graphs[0].entries, to_hashset(&[1]));
        assert_eq!(graphs[0].components, to_hashset(&[1, 2]));

        // Graph 2
        assert_eq!(graphs[1].entries, to_hashset(&[10]));
        assert_eq!(graphs[1].components, to_hashset(&[10, 11]));
    }

    #[test]
    fn test_multiple_entries_in_one_graph() {
        // Scenario: A graph with two roots feeding into one node: 1->3, 2->3.
        let (forward, backward) = build_edges(&[(1, 3), (2, 3)]);
        let initial_ids = &[3];

        // Act
        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();

        // Assert
        assert_eq!(graphs.len(), 1);
        let graph = &graphs[0];
        assert_eq!(graph.entries, to_hashset(&[1, 2]));
        assert_eq!(graph.components, to_hashset(&[1, 2, 3]));
    }

    #[test]
    fn test_diamond_shape_graph() {
        // Scenario: A diamond shape: 1->2, 1->3, 2->4, 3->4
        let (forward, backward) = build_edges(&[(1, 2), (1, 3), (2, 4), (3, 4)]);
        let initial_ids = &[4];

        // Act
        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();

        // Assert
        assert_eq!(graphs.len(), 1);
        let graph = &graphs[0];
        assert_eq!(graph.entries, to_hashset(&[1]));
        assert_eq!(graph.components, to_hashset(&[1, 2, 3, 4]));
    }

    #[test]
    fn test_starting_with_multiple_nodes_in_same_graph() {
        // Scenario: Start with two different nodes (2 and 4) from the same component.
        // Should only identify one graph, not two.
        let (forward, backward) = build_edges(&[(1, 2), (2, 3), (3, 4)]);
        let initial_ids = &[2, 4];

        // Act
        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();

        // Assert
        assert_eq!(graphs.len(), 1);
        let graph = &graphs[0];
        assert_eq!(graph.entries, to_hashset(&[1]));
        assert_eq!(graph.components, to_hashset(&[1, 2, 3, 4]));
    }

    #[test]
    fn test_empty_initial_ids() {
        // Scenario: The initial ID list is empty.
        let (forward, backward) = build_edges(&[(1, 2)]);
        let initial_ids = &[];

        // Act
        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();

        // Assert
        assert!(graphs.is_empty());
    }

    #[test]
    fn test_isolated_node_as_input() {
        // Scenario: Start with an ID that has no relations.
        let (forward, backward) = build_edges(&[(1, 2)]);
        let initial_ids = &[100];

        // Act
        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();

        // Assert: an isolated node forms a singleton component and is its own entry.
        assert_eq!(graphs.len(), 1);
        let graph = &graphs[0];
        assert_eq!(graph.entries, to_hashset(&[100]));
        assert_eq!(graph.components, to_hashset(&[100]));
    }

    #[test]
    fn test_graph_with_a_cycle() {
        // Scenario: A graph with a cycle: 1 -> 2 -> 3 -> 1.
        // The algorithm should correctly identify all nodes in the component.
        // Crucially, NO node is a root because every node has a parent *within the component*.
        // Therefore, the `entries` set should be empty, and the graph should not be included in the results.
        let (forward, backward) = build_edges(&[(1, 2), (2, 3), (3, 1)]);
        let initial_ids = &[2];

        // Act
        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();

        // Assert
        assert!(
            graphs.is_empty(),
            "A graph with no entries should not be returned"
        );
    }
    // Scenario: one component with a fan-in (1,2,4 -> 3 -> 5) plus a side branch
    // (1 -> 8), and a second independent component (6 -> 7).
    #[test]
    fn test_custom_complex() {
        let (forward, backward) = build_edges(&[(1, 3), (1, 8), (2, 3), (4, 3), (3, 5), (6, 7)]);
        let initial_ids = &[1, 2, 4, 6];

        // Act
        let mut graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();

        // Assert
        assert_eq!(graphs.len(), 2);
        // Sort results to make the test deterministic, as HashMap iteration order is not guaranteed.
        graphs.sort_by_key(|g| *g.components.iter().min().unwrap_or(&0));

        // Graph 1
        assert_eq!(graphs[0].entries, to_hashset(&[1, 2, 4]));
        assert_eq!(graphs[0].components, to_hashset(&[1, 2, 3, 4, 5, 8]));

        // Graph 2
        assert_eq!(graphs[1].entries, to_hashset(&[6]));
        assert_eq!(graphs[1].components, to_hashset(&[6, 7]));
    }

    /// Renders a single-parallelism `Single` fragment: exactly one actor with no
    /// vnode bitmap, and the shared actor-id counter advances by one (100 -> 101).
    #[test]
    fn render_actors_increments_actor_counter() {
        let actor_id_counter = AtomicU32::new(100);
        let fragment_id: FragmentId = 1;
        let job_id: ObjectId = 10;
        let database_id: DatabaseId = 3;

        let fragment_model = build_fragment(
            fragment_id,
            job_id,
            0,
            DistributionType::Single,
            1,
            StreamingParallelism::Fixed(1),
        );

        let job_model = streaming_job::Model {
            job_id,
            job_status: JobStatus::Created,
            create_type: CreateType::Foreground,
            timezone: None,
            parallelism: StreamingParallelism::Fixed(1),
            max_parallelism: 1,
            specific_resource_group: None,
        };

        // Database resource group matches the worker's, so the worker is eligible.
        let database_model = database::Model {
            database_id,
            name: "test_db".into(),
            resource_group: "rg-a".into(),
            barrier_interval_ms: None,
            checkpoint_frequency: None,
        };

        let ensembles = vec![NoShuffleEnsemble {
            entries: HashSet::from([fragment_id]),
            components: HashSet::from([fragment_id]),
        }];

        let fragment_map = HashMap::from([(fragment_id, fragment_model)]);
        let job_map = HashMap::from([(job_id, job_model)]);

        let worker_map = BTreeMap::from([(
            1,
            WorkerInfo {
                parallelism: NonZeroUsize::new(1).unwrap(),
                resource_group: Some("rg-a".into()),
            },
        )]);

        // No source fragments and no splits in this scenario.
        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
        let database_map = HashMap::from([(database_id, database_model)]);

        let context = RenderActorsContext {
            fragment_source_ids: &fragment_source_ids,
            fragment_splits: &fragment_splits,
            streaming_job_databases: &streaming_job_databases,
            database_map: &database_map,
        };

        let result = render_actors(
            &actor_id_counter,
            &ensembles,
            &fragment_map,
            &job_map,
            &worker_map,
            AdaptiveParallelismStrategy::Auto,
            context,
        )
        .expect("actor rendering succeeds");

        let state = collect_actor_state(&result[&database_id][&job_id][&fragment_id]);
        assert_eq!(state.len(), 1);
        assert!(
            state[0].2.is_none(),
            "single distribution should not assign vnode bitmaps"
        );
        // One actor rendered => counter moved from 100 to 101.
        assert_eq!(actor_id_counter.load(Ordering::Relaxed), 101);
    }

    /// Renders a two-fragment no-shuffle ensemble with `Hash` distribution:
    /// the downstream fragment's actor placement and vnode bitmaps must mirror
    /// the entry fragment's, and together the actors cover all 4 vnodes.
    #[test]
    fn render_actors_aligns_hash_vnode_bitmaps() {
        let actor_id_counter = AtomicU32::new(0);
        let entry_fragment_id: FragmentId = 1;
        let downstream_fragment_id: FragmentId = 2;
        let job_id: ObjectId = 20;
        let database_id: DatabaseId = 5;

        let entry_fragment = build_fragment(
            entry_fragment_id,
            job_id,
            0,
            DistributionType::Hash,
            4,
            StreamingParallelism::Fixed(2),
        );

        let downstream_fragment = build_fragment(
            downstream_fragment_id,
            job_id,
            0,
            DistributionType::Hash,
            4,
            StreamingParallelism::Fixed(2),
        );

        let job_model = streaming_job::Model {
            job_id,
            job_status: JobStatus::Created,
            create_type: CreateType::Background,
            timezone: None,
            parallelism: StreamingParallelism::Fixed(2),
            max_parallelism: 2,
            specific_resource_group: None,
        };

        let database_model = database::Model {
            database_id,
            name: "test_db_hash".into(),
            resource_group: "rg-hash".into(),
            barrier_interval_ms: None,
            checkpoint_frequency: None,
        };

        // Both fragments belong to one no-shuffle ensemble rooted at the entry.
        let ensembles = vec![NoShuffleEnsemble {
            entries: HashSet::from([entry_fragment_id]),
            components: HashSet::from([entry_fragment_id, downstream_fragment_id]),
        }];

        let fragment_map = HashMap::from([
            (entry_fragment_id, entry_fragment),
            (downstream_fragment_id, downstream_fragment),
        ]);
        let job_map = HashMap::from([(job_id, job_model)]);

        // Two single-slot workers in the matching resource group.
        let worker_map = BTreeMap::from([
            (
                1,
                WorkerInfo {
                    parallelism: NonZeroUsize::new(1).unwrap(),
                    resource_group: Some("rg-hash".into()),
                },
            ),
            (
                2,
                WorkerInfo {
                    parallelism: NonZeroUsize::new(1).unwrap(),
                    resource_group: Some("rg-hash".into()),
                },
            ),
        ]);

        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
        let database_map = HashMap::from([(database_id, database_model)]);

        let context = RenderActorsContext {
            fragment_source_ids: &fragment_source_ids,
            fragment_splits: &fragment_splits,
            streaming_job_databases: &streaming_job_databases,
            database_map: &database_map,
        };

        let result = render_actors(
            &actor_id_counter,
            &ensembles,
            &fragment_map,
            &job_map,
            &worker_map,
            AdaptiveParallelismStrategy::Auto,
            context,
        )
        .expect("actor rendering succeeds");

        let entry_state = collect_actor_state(&result[&database_id][&job_id][&entry_fragment_id]);
        let downstream_state =
            collect_actor_state(&result[&database_id][&job_id][&downstream_fragment_id]);

        // No-shuffle downstream must align 1:1 with the entry fragment's actors.
        assert_eq!(entry_state.len(), 2);
        assert_eq!(entry_state, downstream_state);

        let assigned_vnodes: BTreeSet<_> = entry_state
            .iter()
            .flat_map(|(_, _, vnodes, _)| {
                vnodes
                    .as_ref()
                    .expect("hash distribution should populate vnode bitmap")
                    .iter()
                    .copied()
            })
            .collect();
        // The two actors together cover the full vnode range without gaps.
        assert_eq!(assigned_vnodes, BTreeSet::from([0, 1, 2, 3]));
        // 2 actors per fragment x 2 fragments => counter advanced by 4.
        assert_eq!(actor_id_counter.load(Ordering::Relaxed), 4);
    }

    /// Renders a Source fragment with assigned splits and a no-shuffle
    /// SourceScan downstream: the downstream actors must carry the same
    /// split assignment, and all splits must be distributed.
    #[test]
    fn render_actors_propagates_source_splits() {
        let actor_id_counter = AtomicU32::new(0);
        let entry_fragment_id: FragmentId = 11;
        let downstream_fragment_id: FragmentId = 12;
        let job_id: ObjectId = 30;
        let database_id: DatabaseId = 7;
        let source_id: SourceId = 99;

        let source_mask = FragmentTypeFlag::raw_flag([FragmentTypeFlag::Source]) as i32;
        let source_scan_mask = FragmentTypeFlag::raw_flag([FragmentTypeFlag::SourceScan]) as i32;

        let entry_fragment = build_fragment(
            entry_fragment_id,
            job_id,
            source_mask,
            DistributionType::Hash,
            4,
            StreamingParallelism::Fixed(2),
        );

        let downstream_fragment = build_fragment(
            downstream_fragment_id,
            job_id,
            source_scan_mask,
            DistributionType::Hash,
            4,
            StreamingParallelism::Fixed(2),
        );

        let job_model = streaming_job::Model {
            job_id,
            job_status: JobStatus::Created,
            create_type: CreateType::Background,
            timezone: None,
            parallelism: StreamingParallelism::Fixed(2),
            max_parallelism: 2,
            specific_resource_group: None,
        };

        let database_model = database::Model {
            database_id,
            name: "split_db".into(),
            resource_group: "rg-source".into(),
            barrier_interval_ms: None,
            checkpoint_frequency: None,
        };

        let ensembles = vec![NoShuffleEnsemble {
            entries: HashSet::from([entry_fragment_id]),
            components: HashSet::from([entry_fragment_id, downstream_fragment_id]),
        }];

        let fragment_map = HashMap::from([
            (entry_fragment_id, entry_fragment),
            (downstream_fragment_id, downstream_fragment),
        ]);
        let job_map = HashMap::from([(job_id, job_model)]);

        let worker_map = BTreeMap::from([
            (
                1,
                WorkerInfo {
                    parallelism: NonZeroUsize::new(1).unwrap(),
                    resource_group: Some("rg-source".into()),
                },
            ),
            (
                2,
                WorkerInfo {
                    parallelism: NonZeroUsize::new(1).unwrap(),
                    resource_group: Some("rg-source".into()),
                },
            ),
        ]);

        // Two test splits assigned to the entry (Source) fragment only.
        let split_a = SplitImpl::Test(TestSourceSplit {
            id: Arc::<str>::from("split-a"),
            properties: HashMap::new(),
            offset: "0".into(),
        });
        let split_b = SplitImpl::Test(TestSourceSplit {
            id: Arc::<str>::from("split-b"),
            properties: HashMap::new(),
            offset: "0".into(),
        });

        let fragment_source_ids = HashMap::from([
            (entry_fragment_id, source_id),
            (downstream_fragment_id, source_id),
        ]);
        let fragment_splits =
            HashMap::from([(entry_fragment_id, vec![split_a.clone(), split_b.clone()])]);
        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
        let database_map = HashMap::from([(database_id, database_model)]);

        let context = RenderActorsContext {
            fragment_source_ids: &fragment_source_ids,
            fragment_splits: &fragment_splits,
            streaming_job_databases: &streaming_job_databases,
            database_map: &database_map,
        };

        let result = render_actors(
            &actor_id_counter,
            &ensembles,
            &fragment_map,
            &job_map,
            &worker_map,
            AdaptiveParallelismStrategy::Auto,
            context,
        )
        .expect("actor rendering succeeds");

        let entry_state = collect_actor_state(&result[&database_id][&job_id][&entry_fragment_id]);
        let downstream_state =
            collect_actor_state(&result[&database_id][&job_id][&downstream_fragment_id]);

        // Downstream (SourceScan) actors mirror the entry actors, splits included.
        assert_eq!(entry_state, downstream_state);

        let split_ids: BTreeSet<_> = entry_state
            .iter()
            .flat_map(|(_, _, _, splits)| splits.iter().cloned())
            .collect();
        // Every assigned split ends up on exactly the rendered actors.
        assert_eq!(
            split_ids,
            BTreeSet::from([split_a.id().to_string(), split_b.id().to_string()])
        );
        assert_eq!(actor_id_counter.load(Ordering::Relaxed), 4);
    }
}