risingwave_meta/controller/
scale.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
16use std::num::NonZeroUsize;
17use std::sync::atomic::{AtomicU32, Ordering};
18
19use anyhow::anyhow;
20use itertools::Itertools;
21use risingwave_common::bitmap::Bitmap;
22use risingwave_common::catalog::{FragmentTypeFlag, FragmentTypeMask};
23use risingwave_common::id::JobId;
24use risingwave_common::system_param::AdaptiveParallelismStrategy;
25use risingwave_common::util::worker_util::DEFAULT_RESOURCE_GROUP;
26use risingwave_connector::source::{SplitImpl, SplitMetaData};
27use risingwave_meta_model::fragment::DistributionType;
28use risingwave_meta_model::prelude::{
29    Database, Fragment, FragmentRelation, FragmentSplits, Sink, Source, StreamingJob, Table,
30};
31use risingwave_meta_model::{
32    CreateType, DatabaseId, DispatcherType, FragmentId, JobStatus, SourceId, StreamingParallelism,
33    WorkerId, database, fragment, fragment_relation, fragment_splits, object, sink, source,
34    streaming_job, table,
35};
36use risingwave_meta_model_migration::Condition;
37use sea_orm::{
38    ColumnTrait, ConnectionTrait, EntityTrait, JoinType, QueryFilter, QuerySelect, QueryTrait,
39    RelationTrait,
40};
41
42use crate::MetaResult;
43use crate::controller::fragment::{InflightActorInfo, InflightFragmentInfo};
44use crate::manager::ActiveStreamingWorkerNodes;
45use crate::model::{ActorId, StreamActor};
46use crate::stream::{AssignerBuilder, SplitDiffOptions};
47
48pub(crate) async fn resolve_streaming_job_definition<C>(
49    txn: &C,
50    job_ids: &HashSet<JobId>,
51) -> MetaResult<HashMap<JobId, String>>
52where
53    C: ConnectionTrait,
54{
55    let job_ids = job_ids.iter().cloned().collect_vec();
56
57    // including table, materialized view, index
58    let common_job_definitions: Vec<(JobId, String)> = Table::find()
59        .select_only()
60        .columns([
61            table::Column::TableId,
62            #[cfg(not(debug_assertions))]
63            table::Column::Name,
64            #[cfg(debug_assertions)]
65            table::Column::Definition,
66        ])
67        .filter(table::Column::TableId.is_in(job_ids.clone()))
68        .into_tuple()
69        .all(txn)
70        .await?;
71
72    let sink_definitions: Vec<(JobId, String)> = Sink::find()
73        .select_only()
74        .columns([
75            sink::Column::SinkId,
76            #[cfg(not(debug_assertions))]
77            sink::Column::Name,
78            #[cfg(debug_assertions)]
79            sink::Column::Definition,
80        ])
81        .filter(sink::Column::SinkId.is_in(job_ids.clone()))
82        .into_tuple()
83        .all(txn)
84        .await?;
85
86    let source_definitions: Vec<(JobId, String)> = Source::find()
87        .select_only()
88        .columns([
89            source::Column::SourceId,
90            #[cfg(not(debug_assertions))]
91            source::Column::Name,
92            #[cfg(debug_assertions)]
93            source::Column::Definition,
94        ])
95        .filter(source::Column::SourceId.is_in(job_ids.clone()))
96        .into_tuple()
97        .all(txn)
98        .await?;
99
100    let definitions: HashMap<JobId, String> = common_job_definitions
101        .into_iter()
102        .chain(sink_definitions.into_iter())
103        .chain(source_definitions.into_iter())
104        .collect();
105
106    Ok(definitions)
107}
108
109pub async fn load_fragment_info<C>(
110    txn: &C,
111    actor_id_counter: &AtomicU32,
112    database_id: Option<DatabaseId>,
113    worker_nodes: &ActiveStreamingWorkerNodes,
114    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
115) -> MetaResult<FragmentRenderMap>
116where
117    C: ConnectionTrait,
118{
119    let mut query = StreamingJob::find()
120        .select_only()
121        .column(streaming_job::Column::JobId);
122
123    if let Some(database_id) = database_id {
124        query = query
125            .join(JoinType::InnerJoin, streaming_job::Relation::Object.def())
126            .filter(object::Column::DatabaseId.eq(database_id));
127    }
128
129    let jobs: Vec<JobId> = query.into_tuple().all(txn).await?;
130
131    if jobs.is_empty() {
132        return Ok(HashMap::new());
133    }
134
135    let jobs: HashSet<JobId> = jobs.into_iter().collect();
136
137    let loaded = load_fragment_context_for_jobs(txn, jobs).await?;
138
139    if loaded.is_empty() {
140        return Ok(HashMap::new());
141    }
142
143    let available_workers: BTreeMap<_, _> = worker_nodes
144        .current()
145        .values()
146        .filter(|worker| worker.is_streaming_schedulable())
147        .map(|worker| {
148            (
149                worker.id,
150                WorkerInfo {
151                    parallelism: NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
152                    resource_group: worker.resource_group(),
153                },
154            )
155        })
156        .collect();
157
158    let RenderedGraph { fragments, .. } = render_actor_assignments(
159        actor_id_counter,
160        &available_workers,
161        adaptive_parallelism_strategy,
162        &loaded,
163    )?;
164
165    Ok(fragments)
166}
167
/// Desired scheduling target for a streaming job: which resource group its
/// actors should run in and at what parallelism.
#[derive(Debug)]
pub struct TargetResourcePolicy {
    /// Target resource group; `None` presumably means "fall back to the
    /// database-level group" (mirrors `specific_resource_group` handling in
    /// the renderer) — confirm at the call sites.
    pub resource_group: Option<String>,
    /// Target parallelism setting (fixed or adaptive).
    pub parallelism: StreamingParallelism,
}
173
/// Scheduling-relevant view of a streaming compute worker.
#[derive(Debug, Clone)]
pub struct WorkerInfo {
    /// Number of parallel slots the worker provides; always non-zero.
    pub parallelism: NonZeroUsize,
    /// Resource group label of the worker; `None` is treated as the default
    /// resource group during rendering.
    pub resource_group: Option<String>,
}
179
/// Rendered fragment info, keyed by database -> streaming job -> fragment.
pub type FragmentRenderMap =
    HashMap<DatabaseId, HashMap<JobId, HashMap<FragmentId, InflightFragmentInfo>>>;
182
/// Output of the rendering pipeline: the rendered fragment map plus the
/// no-shuffle ensembles it was derived from.
#[derive(Default)]
pub struct RenderedGraph {
    /// Actor assignments per database / job / fragment.
    pub fragments: FragmentRenderMap,
    /// The no-shuffle ensembles that were rendered.
    pub ensembles: Vec<NoShuffleEnsemble>,
}
188
189impl RenderedGraph {
190    pub fn empty() -> Self {
191        Self::default()
192    }
193}
194
/// Context loaded asynchronously from database, containing all metadata
/// required to render actor assignments. This separates async I/O from
/// sync rendering logic.
#[derive(Default)]
pub struct LoadedFragmentContext {
    /// No-shuffle ensembles to render; empty means there is nothing to do.
    pub ensembles: Vec<NoShuffleEnsemble>,
    /// Fragment models for every component fragment of the ensembles.
    pub fragment_map: HashMap<FragmentId, fragment::Model>,
    /// Streaming-job models for the jobs owning those fragments.
    pub job_map: HashMap<JobId, streaming_job::Model>,
    /// Owning database of each streaming job.
    pub streaming_job_databases: HashMap<JobId, DatabaseId>,
    /// Database models, keyed by database id.
    pub database_map: HashMap<DatabaseId, database::Model>,
    /// Source id embedded in each source / source-backfill fragment.
    pub fragment_source_ids: HashMap<FragmentId, SourceId>,
    /// Persisted split assignment of each source fragment.
    pub fragment_splits: HashMap<FragmentId, Vec<SplitImpl>>,
}
208
impl LoadedFragmentContext {
    /// Whether there is anything to render. Emptiness is defined solely by the
    /// ensemble list; the other maps are only consumed when ensembles exist.
    pub fn is_empty(&self) -> bool {
        self.ensembles.is_empty()
    }
}
214
215/// Fragment-scoped rendering entry point used by operational tooling.
216/// It validates that the requested fragments are roots of their no-shuffle ensembles,
217/// resolves only the metadata required for those components, and then reuses the shared
218/// rendering pipeline to materialize actor assignments.
219pub async fn render_fragments<C>(
220    txn: &C,
221    actor_id_counter: &AtomicU32,
222    ensembles: Vec<NoShuffleEnsemble>,
223    workers: BTreeMap<WorkerId, WorkerInfo>,
224    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
225) -> MetaResult<RenderedGraph>
226where
227    C: ConnectionTrait,
228{
229    let loaded = load_fragment_context(txn, ensembles).await?;
230
231    if loaded.is_empty() {
232        return Ok(RenderedGraph::empty());
233    }
234
235    render_actor_assignments(
236        actor_id_counter,
237        &workers,
238        adaptive_parallelism_strategy,
239        &loaded,
240    )
241}
242
243/// Async load stage for fragment-scoped rendering. It resolves all metadata required to later
244/// render actor assignments with arbitrary worker sets.
245pub async fn load_fragment_context<C>(
246    txn: &C,
247    ensembles: Vec<NoShuffleEnsemble>,
248) -> MetaResult<LoadedFragmentContext>
249where
250    C: ConnectionTrait,
251{
252    if ensembles.is_empty() {
253        return Ok(LoadedFragmentContext::default());
254    }
255
256    let required_fragment_ids: HashSet<_> = ensembles
257        .iter()
258        .flat_map(|ensemble| ensemble.components.iter().copied())
259        .collect();
260
261    let fragment_models = Fragment::find()
262        .filter(fragment::Column::FragmentId.is_in(required_fragment_ids.iter().copied()))
263        .all(txn)
264        .await?;
265
266    let found_fragment_ids: HashSet<_> = fragment_models
267        .iter()
268        .map(|fragment| fragment.fragment_id)
269        .collect();
270
271    if found_fragment_ids.len() != required_fragment_ids.len() {
272        let missing = required_fragment_ids
273            .difference(&found_fragment_ids)
274            .copied()
275            .collect_vec();
276        return Err(anyhow!("fragments {:?} not found", missing).into());
277    }
278
279    let fragment_map: HashMap<_, _> = fragment_models
280        .into_iter()
281        .map(|fragment| (fragment.fragment_id, fragment))
282        .collect();
283
284    let job_ids: HashSet<_> = fragment_map
285        .values()
286        .map(|fragment| fragment.job_id)
287        .collect();
288
289    if job_ids.is_empty() {
290        return Ok(LoadedFragmentContext::default());
291    }
292
293    let jobs: HashMap<_, _> = StreamingJob::find()
294        .filter(streaming_job::Column::JobId.is_in(job_ids.iter().copied().collect_vec()))
295        .all(txn)
296        .await?
297        .into_iter()
298        .map(|job| (job.job_id, job))
299        .collect();
300
301    let found_job_ids: HashSet<_> = jobs.keys().copied().collect();
302    if found_job_ids.len() != job_ids.len() {
303        let missing = job_ids.difference(&found_job_ids).copied().collect_vec();
304        return Err(anyhow!("streaming jobs {:?} not found", missing).into());
305    }
306
307    build_loaded_context(txn, ensembles, fragment_map, jobs).await
308}
309
310/// Job-scoped rendering entry point that walks every no-shuffle root belonging to the
311/// provided streaming jobs before delegating to the shared rendering backend.
312pub async fn render_jobs<C>(
313    txn: &C,
314    actor_id_counter: &AtomicU32,
315    job_ids: HashSet<JobId>,
316    workers: BTreeMap<WorkerId, WorkerInfo>,
317    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
318) -> MetaResult<RenderedGraph>
319where
320    C: ConnectionTrait,
321{
322    let loaded = load_fragment_context_for_jobs(txn, job_ids).await?;
323
324    if loaded.is_empty() {
325        return Ok(RenderedGraph::empty());
326    }
327
328    render_actor_assignments(
329        actor_id_counter,
330        &workers,
331        adaptive_parallelism_strategy,
332        &loaded,
333    )
334}
335
336/// Async load stage for job-scoped rendering. It collects all no-shuffle ensembles and the
337/// metadata required to render actor assignments later with a provided worker set.
338pub async fn load_fragment_context_for_jobs<C>(
339    txn: &C,
340    job_ids: HashSet<JobId>,
341) -> MetaResult<LoadedFragmentContext>
342where
343    C: ConnectionTrait,
344{
345    if job_ids.is_empty() {
346        return Ok(LoadedFragmentContext::default());
347    }
348
349    let excluded_fragments_query = FragmentRelation::find()
350        .select_only()
351        .column(fragment_relation::Column::TargetFragmentId)
352        .filter(fragment_relation::Column::DispatcherType.eq(DispatcherType::NoShuffle))
353        .into_query();
354
355    let condition = Condition::all()
356        .add(fragment::Column::JobId.is_in(job_ids.clone()))
357        .add(fragment::Column::FragmentId.not_in_subquery(excluded_fragments_query));
358
359    let fragments: Vec<FragmentId> = Fragment::find()
360        .select_only()
361        .column(fragment::Column::FragmentId)
362        .filter(condition)
363        .into_tuple()
364        .all(txn)
365        .await?;
366
367    let ensembles = find_fragment_no_shuffle_dags_detailed(txn, &fragments).await?;
368
369    let fragments = Fragment::find()
370        .filter(
371            fragment::Column::FragmentId.is_in(
372                ensembles
373                    .iter()
374                    .flat_map(|graph| graph.components.iter())
375                    .cloned()
376                    .collect_vec(),
377            ),
378        )
379        .all(txn)
380        .await?;
381
382    let fragment_map: HashMap<_, _> = fragments
383        .into_iter()
384        .map(|fragment| (fragment.fragment_id, fragment))
385        .collect();
386
387    let job_ids = fragment_map
388        .values()
389        .map(|fragment| fragment.job_id)
390        .collect::<BTreeSet<_>>()
391        .into_iter()
392        .collect_vec();
393
394    let jobs: HashMap<_, _> = StreamingJob::find()
395        .filter(streaming_job::Column::JobId.is_in(job_ids))
396        .all(txn)
397        .await?
398        .into_iter()
399        .map(|job| (job.job_id, job))
400        .collect();
401
402    build_loaded_context(txn, ensembles, fragment_map, jobs).await
403}
404
405/// Sync render stage: uses loaded fragment context and current worker info
406/// to produce actor-to-worker assignments and vnode bitmaps.
407pub(crate) fn render_actor_assignments(
408    actor_id_counter: &AtomicU32,
409    worker_map: &BTreeMap<WorkerId, WorkerInfo>,
410    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
411    loaded: &LoadedFragmentContext,
412) -> MetaResult<RenderedGraph> {
413    if loaded.is_empty() {
414        return Ok(RenderedGraph::empty());
415    }
416
417    let backfill_jobs: HashSet<JobId> = loaded
418        .job_map
419        .iter()
420        .filter(|(_, job)| {
421            job.create_type == CreateType::Background && job.job_status == JobStatus::Creating
422        })
423        .map(|(id, _)| *id)
424        .collect();
425
426    let render_context = RenderActorsContext {
427        fragment_source_ids: &loaded.fragment_source_ids,
428        fragment_splits: &loaded.fragment_splits,
429        streaming_job_databases: &loaded.streaming_job_databases,
430        database_map: &loaded.database_map,
431        backfill_jobs: &backfill_jobs,
432    };
433
434    let fragments = render_actors(
435        actor_id_counter,
436        &loaded.ensembles,
437        &loaded.fragment_map,
438        &loaded.job_map,
439        worker_map,
440        adaptive_parallelism_strategy,
441        render_context,
442    )?;
443
444    Ok(RenderedGraph {
445        fragments,
446        ensembles: loaded.ensembles.clone(),
447    })
448}
449
/// Resolves the remaining metadata (source ids/splits, job-to-database mapping,
/// database models) for an already-validated set of ensembles, fragments, and
/// jobs, and bundles everything into a [`LoadedFragmentContext`].
async fn build_loaded_context<C>(
    txn: &C,
    ensembles: Vec<NoShuffleEnsemble>,
    fragment_map: HashMap<FragmentId, fragment::Model>,
    job_map: HashMap<JobId, streaming_job::Model>,
) -> MetaResult<LoadedFragmentContext>
where
    C: ConnectionTrait,
{
    if ensembles.is_empty() {
        return Ok(LoadedFragmentContext::default());
    }

    // Debug-only consistency checks on the ensemble/fragment/job metadata.
    #[cfg(debug_assertions)]
    {
        debug_sanity_check(&ensembles, &fragment_map, &job_map);
    }

    // Source fragment ids and their currently persisted split assignments.
    let (fragment_source_ids, fragment_splits) =
        resolve_source_fragments(txn, &fragment_map).await?;

    let job_ids = job_map.keys().copied().collect_vec();

    // Map each job to its owning database via the `object` relation.
    // NOTE(review): a LeftJoin row with no object would yield a NULL database
    // id and fail tuple decoding — presumably every job has an object row.
    let streaming_job_databases: HashMap<JobId, _> = StreamingJob::find()
        .select_only()
        .column(streaming_job::Column::JobId)
        .column(object::Column::DatabaseId)
        .join(JoinType::LeftJoin, streaming_job::Relation::Object.def())
        .filter(streaming_job::Column::JobId.is_in(job_ids))
        .into_tuple()
        .all(txn)
        .await?
        .into_iter()
        .collect();

    // Database models, used for database-level resource groups during rendering.
    let database_map: HashMap<_, _> = Database::find()
        .filter(
            database::Column::DatabaseId
                .is_in(streaming_job_databases.values().copied().collect_vec()),
        )
        .all(txn)
        .await?
        .into_iter()
        .map(|db| (db.database_id, db))
        .collect();

    Ok(LoadedFragmentContext {
        ensembles,
        fragment_map,
        job_map,
        streaming_job_databases,
        database_map,
        fragment_source_ids,
        fragment_splits,
    })
}
506
// Only metadata resolved asynchronously lives here so the renderer stays synchronous
// and the call site keeps the runtime dependencies (maps, strategy, actor counter, etc.) explicit.
struct RenderActorsContext<'a> {
    // Source id embedded in each source / source-backfill fragment.
    fragment_source_ids: &'a HashMap<FragmentId, SourceId>,
    // Persisted splits per source fragment.
    fragment_splits: &'a HashMap<FragmentId, Vec<SplitImpl>>,
    // Owning database of each streaming job.
    streaming_job_databases: &'a HashMap<JobId, DatabaseId>,
    // Database models, used for database-level resource groups.
    database_map: &'a HashMap<DatabaseId, database::Model>,
    // Background jobs still creating; may use a dedicated backfill parallelism.
    backfill_jobs: &'a HashSet<JobId>,
}
516
/// Renders every no-shuffle ensemble into per-fragment [`InflightFragmentInfo`],
/// assigning actors and vnode bitmaps to workers.
///
/// For each ensemble, the parallelism and worker set are resolved once from the
/// entry fragments and then shared by every component fragment, so all fragments
/// of an ensemble end up with identical actor/worker/vnode layouts.
fn render_actors(
    actor_id_counter: &AtomicU32,
    ensembles: &[NoShuffleEnsemble],
    fragment_map: &HashMap<FragmentId, fragment::Model>,
    job_map: &HashMap<JobId, streaming_job::Model>,
    worker_map: &BTreeMap<WorkerId, WorkerInfo>,
    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
    context: RenderActorsContext<'_>,
) -> MetaResult<FragmentRenderMap> {
    let RenderActorsContext {
        fragment_source_ids,
        fragment_splits: fragment_splits_map,
        streaming_job_databases,
        database_map,
        backfill_jobs,
    } = context;

    let mut all_fragments: FragmentRenderMap = HashMap::new();

    for NoShuffleEnsemble {
        entries,
        components,
    } in ensembles
    {
        tracing::debug!("rendering ensemble entries {:?}", entries);

        // Callers guarantee every component (hence every entry) is present in
        // `fragment_map`, so the unwrap cannot fire for well-formed input.
        let entry_fragments = entries
            .iter()
            .map(|fragment_id| fragment_map.get(fragment_id).unwrap())
            .collect_vec();

        // All entry fragments must agree on the (optional) fragment-level
        // parallelism override.
        let entry_fragment_parallelism = entry_fragments
            .iter()
            .map(|fragment| fragment.parallelism.clone())
            .dedup()
            .exactly_one()
            .map_err(|_| {
                anyhow!(
                    "entry fragments {:?} have inconsistent parallelism settings",
                    entries.iter().copied().collect_vec()
                )
            })?;

        // Entry fragments must also belong to a single job and share one vnode count.
        let (job_id, vnode_count) = entry_fragments
            .iter()
            .map(|f| (f.job_id, f.vnode_count as usize))
            .dedup()
            .exactly_one()
            .map_err(|_| anyhow!("Multiple jobs found in no-shuffle ensemble"))?;

        let job = job_map
            .get(&job_id)
            .ok_or_else(|| anyhow!("streaming job {job_id} not found"))?;

        // A job-specific resource group wins; otherwise fall back to the
        // resource group of the job's database.
        let resource_group = match &job.specific_resource_group {
            None => {
                let database = streaming_job_databases
                    .get(&job_id)
                    .and_then(|database_id| database_map.get(database_id))
                    .unwrap();
                database.resource_group.clone()
            }
            Some(resource_group) => resource_group.clone(),
        };

        // Only workers in the matching resource group may host this ensemble;
        // a worker without an explicit group belongs to the default group.
        let available_workers: BTreeMap<WorkerId, NonZeroUsize> = worker_map
            .iter()
            .filter_map(|(worker_id, worker)| {
                if worker
                    .resource_group
                    .as_deref()
                    .unwrap_or(DEFAULT_RESOURCE_GROUP)
                    == resource_group.as_str()
                {
                    Some((*worker_id, worker.parallelism))
                } else {
                    None
                }
            })
            .collect();

        let total_parallelism = available_workers.values().map(|w| w.get()).sum::<usize>();

        // Background jobs still creating may use a dedicated backfill parallelism.
        let effective_job_parallelism = if backfill_jobs.contains(&job_id) {
            job.backfill_parallelism
                .as_ref()
                .unwrap_or(&job.parallelism)
        } else {
            &job.parallelism
        };

        // Fragment-level override beats the job setting; the result is capped
        // by both the vnode count and the job's max parallelism.
        let actual_parallelism = match entry_fragment_parallelism
            .as_ref()
            .unwrap_or(effective_job_parallelism)
        {
            StreamingParallelism::Adaptive | StreamingParallelism::Custom => {
                adaptive_parallelism_strategy.compute_target_parallelism(total_parallelism)
            }
            StreamingParallelism::Fixed(n) => *n,
        }
        .min(vnode_count)
        .min(job.max_parallelism as usize);

        tracing::debug!(
            "job {}, final {} parallelism {:?} total_parallelism {} job_max {} vnode count {} fragment_override {:?}",
            job_id,
            actual_parallelism,
            job.parallelism,
            total_parallelism,
            job.max_parallelism,
            vnode_count,
            entry_fragment_parallelism
        );

        // Deterministically distribute actor indices and vnodes over the workers.
        let assigner = AssignerBuilder::new(job_id).build();

        // "Actors" here are still worker-relative indices; globally unique ids
        // are minted per component fragment below.
        let actors = (0..(actual_parallelism as u32))
            .map_into::<ActorId>()
            .collect_vec();
        let vnodes = (0..vnode_count).collect_vec();

        let assignment = assigner.assign_hierarchical(&available_workers, &actors, &vnodes)?;

        // Locate the (non-DML) source entry fragment, if any, to drive split
        // assignment. A Source fragment must never also be a SourceScan.
        let source_entry_fragment = entry_fragments.iter().find(|f| {
            let mask = FragmentTypeMask::from(f.fragment_type_mask);
            if mask.contains(FragmentTypeFlag::Source) {
                assert!(!mask.contains(FragmentTypeFlag::SourceScan))
            }
            mask.contains(FragmentTypeFlag::Source) && !mask.contains(FragmentTypeFlag::Dml)
        });

        // Redistribute the persisted source splits across the rendered actor indices.
        let (fragment_splits, shared_source_id) = match source_entry_fragment {
            Some(entry_fragment) => {
                let source_id = fragment_source_ids
                    .get(&entry_fragment.fragment_id)
                    .ok_or_else(|| {
                        anyhow!(
                            "missing source id in source fragment {}",
                            entry_fragment.fragment_id
                        )
                    })?;

                let entry_fragment_id = entry_fragment.fragment_id;

                // Start from an empty per-actor assignment so every split gets
                // (re)assigned by `reassign_splits`.
                let empty_actor_splits: HashMap<_, _> =
                    actors.iter().map(|actor_id| (*actor_id, vec![])).collect();

                let splits = fragment_splits_map
                    .get(&entry_fragment_id)
                    .cloned()
                    .unwrap_or_default();

                let splits: BTreeMap<_, _> = splits.into_iter().map(|s| (s.id(), s)).collect();

                let fragment_splits = crate::stream::source_manager::reassign_splits(
                    entry_fragment_id,
                    empty_actor_splits,
                    &splits,
                    SplitDiffOptions::default(),
                )
                .unwrap_or_default();
                (fragment_splits, Some(*source_id))
            }
            None => (HashMap::new(), None),
        };

        // Materialize an InflightFragmentInfo for every component fragment,
        // reusing the ensemble-wide worker/vnode assignment computed above.
        for component_fragment_id in components {
            let &fragment::Model {
                fragment_id,
                job_id,
                fragment_type_mask,
                distribution_type,
                ref stream_node,
                ref state_table_ids,
                ..
            } = fragment_map.get(component_fragment_id).unwrap();

            // Reserve a contiguous range of globally unique actor ids for this fragment.
            let actor_count =
                u32::try_from(actors.len()).expect("actor parallelism exceeds u32::MAX");
            let actor_id_base = actor_id_counter.fetch_add(actor_count, Ordering::Relaxed);

            let actors: HashMap<ActorId, InflightActorInfo> = assignment
                .iter()
                .flat_map(|(worker_id, actors)| {
                    actors
                        .iter()
                        .map(move |(actor_id, vnodes)| (worker_id, actor_id, vnodes))
                })
                .map(|(&worker_id, &actor_idx, vnodes)| {
                    // Singleton fragments carry no vnode bitmap.
                    let vnode_bitmap = match distribution_type {
                        DistributionType::Single => None,
                        DistributionType::Hash => Some(Bitmap::from_indices(vnode_count, vnodes)),
                    };

                    let actor_id = actor_idx + actor_id_base;

                    // Fragments bound to the ensemble's shared source reuse its
                    // per-actor-index split assignment.
                    let splits = if let Some(source_id) = fragment_source_ids.get(&fragment_id) {
                        assert_eq!(shared_source_id, Some(*source_id));

                        fragment_splits
                            .get(&(actor_idx))
                            .cloned()
                            .unwrap_or_default()
                    } else {
                        vec![]
                    };

                    (
                        actor_id,
                        InflightActorInfo {
                            worker_id,
                            vnode_bitmap,
                            splits,
                        },
                    )
                })
                .collect();

            let fragment = InflightFragmentInfo {
                fragment_id,
                distribution_type,
                fragment_type_mask: fragment_type_mask.into(),
                vnode_count,
                nodes: stream_node.to_protobuf(),
                actors,
                state_table_ids: state_table_ids.inner_ref().iter().copied().collect(),
            };

            // Uses the component fragment's own `job_id` (shadowing the entry
            // job id above), so components of other jobs land under their job.
            let &database_id = streaming_job_databases.get(&job_id).ok_or_else(|| {
                anyhow!("streaming job {job_id} not found in streaming_job_databases")
            })?;

            all_fragments
                .entry(database_id)
                .or_default()
                .entry(job_id)
                .or_default()
                .insert(fragment_id, fragment);
        }
    }

    Ok(all_fragments)
}
760
761#[cfg(debug_assertions)]
762fn debug_sanity_check(
763    ensembles: &[NoShuffleEnsemble],
764    fragment_map: &HashMap<FragmentId, fragment::Model>,
765    jobs: &HashMap<JobId, streaming_job::Model>,
766) {
767    // Debug-only assertions to catch inconsistent ensemble metadata early.
768    debug_assert!(
769        ensembles
770            .iter()
771            .all(|ensemble| ensemble.entries.is_subset(&ensemble.components)),
772        "entries must be subset of components"
773    );
774
775    let mut missing_fragments = BTreeSet::new();
776    let mut missing_jobs = BTreeSet::new();
777
778    for fragment_id in ensembles
779        .iter()
780        .flat_map(|ensemble| ensemble.components.iter())
781    {
782        match fragment_map.get(fragment_id) {
783            Some(fragment) => {
784                if !jobs.contains_key(&fragment.job_id) {
785                    missing_jobs.insert(fragment.job_id);
786                }
787            }
788            None => {
789                missing_fragments.insert(*fragment_id);
790            }
791        }
792    }
793
794    debug_assert!(
795        missing_fragments.is_empty(),
796        "missing fragments in fragment_map: {:?}",
797        missing_fragments
798    );
799
800    debug_assert!(
801        missing_jobs.is_empty(),
802        "missing jobs for fragments' job_id: {:?}",
803        missing_jobs
804    );
805
806    for ensemble in ensembles {
807        let unique_vnode_counts: Vec<_> = ensemble
808            .components
809            .iter()
810            .flat_map(|fragment_id| {
811                fragment_map
812                    .get(fragment_id)
813                    .map(|fragment| fragment.vnode_count)
814            })
815            .unique()
816            .collect();
817
818        debug_assert!(
819            unique_vnode_counts.len() <= 1,
820            "components in ensemble must share same vnode_count: ensemble={:?}, vnode_counts={:?}",
821            ensemble.components,
822            unique_vnode_counts
823        );
824    }
825}
826
/// Scans the given fragments for source and source-backfill nodes and returns
/// two maps: fragment id -> source id, and fragment id -> persisted splits.
async fn resolve_source_fragments<C>(
    txn: &C,
    fragment_map: &HashMap<FragmentId, fragment::Model>,
) -> MetaResult<(
    HashMap<FragmentId, SourceId>,
    HashMap<FragmentId, Vec<SplitImpl>>,
)>
where
    C: ConnectionTrait,
{
    // Group fragment ids by the source they read from. BTreeSet keeps the
    // per-source fragment set deduplicated and deterministically ordered.
    let mut source_fragment_ids: HashMap<SourceId, _> = HashMap::new();
    for (fragment_id, fragment) in fragment_map {
        let mask = FragmentTypeMask::from(fragment.fragment_type_mask);
        // Fragments embedding a stream-source node.
        if mask.contains(FragmentTypeFlag::Source)
            && let Some(source_id) = fragment.stream_node.to_protobuf().find_stream_source()
        {
            source_fragment_ids
                .entry(source_id)
                .or_insert_with(BTreeSet::new)
                .insert(fragment_id);
        }

        // Fragments embedding a source-backfill node.
        if mask.contains(FragmentTypeFlag::SourceScan)
            && let Some((source_id, _)) = fragment.stream_node.to_protobuf().find_source_backfill()
        {
            source_fragment_ids
                .entry(source_id)
                .or_insert_with(BTreeSet::new)
                .insert(fragment_id);
        }
    }

    // Invert the grouping into fragment id -> source id.
    let fragment_source_ids: HashMap<_, _> = source_fragment_ids
        .iter()
        .flat_map(|(source_id, fragment_ids)| {
            fragment_ids
                .iter()
                .map(|fragment_id| (**fragment_id, *source_id as SourceId))
        })
        .collect();

    let fragment_ids = fragment_source_ids.keys().copied().collect_vec();

    // Load the persisted split assignments for those fragments.
    let fragment_splits: Vec<_> = FragmentSplits::find()
        .filter(fragment_splits::Column::FragmentId.is_in(fragment_ids))
        .all(txn)
        .await?;

    // Decode the protobuf splits. Rows with a NULL `splits` column are skipped
    // entirely, and individual splits that fail conversion are silently
    // dropped (flat_map over `Result` discards errors).
    let fragment_splits: HashMap<_, _> = fragment_splits
        .into_iter()
        .flat_map(|model| {
            model.splits.map(|splits| {
                (
                    model.fragment_id,
                    splits
                        .to_protobuf()
                        .splits
                        .iter()
                        .flat_map(SplitImpl::try_from)
                        .collect_vec(),
                )
            })
        })
        .collect();

    Ok((fragment_source_ids, fragment_splits))
}
894
// Helper struct to make the function signature cleaner and to properly bundle the required data.
#[derive(Debug)]
pub struct ActorGraph<'a> {
    /// Each fragment together with its rendered actors.
    pub fragments: &'a HashMap<FragmentId, (Fragment, Vec<StreamActor>)>,
    /// Worker hosting each actor.
    pub locations: &'a HashMap<ActorId, WorkerId>,
}
901
/// A connected group of fragments linked by no-shuffle dispatchers.
#[derive(Debug, Clone)]
pub struct NoShuffleEnsemble {
    // Root fragments of the no-shuffle DAG; always a subset of `components`.
    entries: HashSet<FragmentId>,
    // Every fragment in the ensemble, entries included.
    components: HashSet<FragmentId>,
}
907
908impl NoShuffleEnsemble {
909    pub fn fragments(&self) -> impl Iterator<Item = FragmentId> + '_ {
910        self.components.iter().cloned()
911    }
912
913    pub fn entry_fragments(&self) -> impl Iterator<Item = FragmentId> + '_ {
914        self.entries.iter().copied()
915    }
916
917    pub fn component_fragments(&self) -> impl Iterator<Item = FragmentId> + '_ {
918        self.components.iter().copied()
919    }
920
921    pub fn contains_entry(&self, fragment_id: &FragmentId) -> bool {
922        self.entries.contains(fragment_id)
923    }
924}
925
926pub async fn find_fragment_no_shuffle_dags_detailed(
927    db: &impl ConnectionTrait,
928    initial_fragment_ids: &[FragmentId],
929) -> MetaResult<Vec<NoShuffleEnsemble>> {
930    let all_no_shuffle_relations: Vec<(_, _)> = FragmentRelation::find()
931        .columns([
932            fragment_relation::Column::SourceFragmentId,
933            fragment_relation::Column::TargetFragmentId,
934        ])
935        .filter(fragment_relation::Column::DispatcherType.eq(DispatcherType::NoShuffle))
936        .into_tuple()
937        .all(db)
938        .await?;
939
940    let mut forward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
941    let mut backward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
942
943    for (src, dst) in all_no_shuffle_relations {
944        forward_edges.entry(src).or_default().push(dst);
945        backward_edges.entry(dst).or_default().push(src);
946    }
947
948    find_no_shuffle_graphs(initial_fragment_ids, &forward_edges, &backward_edges)
949}
950
951fn find_no_shuffle_graphs(
952    initial_fragment_ids: &[impl Into<FragmentId> + Copy],
953    forward_edges: &HashMap<FragmentId, Vec<FragmentId>>,
954    backward_edges: &HashMap<FragmentId, Vec<FragmentId>>,
955) -> MetaResult<Vec<NoShuffleEnsemble>> {
956    let mut graphs: Vec<NoShuffleEnsemble> = Vec::new();
957    let mut globally_visited: HashSet<FragmentId> = HashSet::new();
958
959    for &init_id in initial_fragment_ids {
960        let init_id = init_id.into();
961        if globally_visited.contains(&init_id) {
962            continue;
963        }
964
965        // Found a new component. Traverse it to find all its nodes.
966        let mut components = HashSet::new();
967        let mut queue: VecDeque<FragmentId> = VecDeque::new();
968
969        queue.push_back(init_id);
970        globally_visited.insert(init_id);
971
972        while let Some(current_id) = queue.pop_front() {
973            components.insert(current_id);
974            let neighbors = forward_edges
975                .get(&current_id)
976                .into_iter()
977                .flatten()
978                .chain(backward_edges.get(&current_id).into_iter().flatten());
979
980            for &neighbor_id in neighbors {
981                if globally_visited.insert(neighbor_id) {
982                    queue.push_back(neighbor_id);
983                }
984            }
985        }
986
987        // For the newly found component, identify its roots.
988        let mut entries = HashSet::new();
989        for &node_id in &components {
990            let is_root = match backward_edges.get(&node_id) {
991                Some(parents) => parents.iter().all(|p| !components.contains(p)),
992                None => true,
993            };
994            if is_root {
995                entries.insert(node_id);
996            }
997        }
998
999        // Store the detailed DAG structure (roots, all nodes in this DAG).
1000        if !entries.is_empty() {
1001            graphs.push(NoShuffleEnsemble {
1002                entries,
1003                components,
1004            });
1005        }
1006    }
1007
1008    Ok(graphs)
1009}
1010
1011#[cfg(test)]
1012mod tests {
1013    use std::collections::{BTreeSet, HashMap, HashSet};
1014    use std::sync::Arc;
1015
1016    use risingwave_connector::source::SplitImpl;
1017    use risingwave_connector::source::test_source::TestSourceSplit;
1018    use risingwave_meta_model::{CreateType, I32Array, JobStatus, StreamNode, TableIdArray};
1019    use risingwave_pb::stream_plan::StreamNode as PbStreamNode;
1020
1021    use super::*;
1022
    // Helper type aliases for cleaner test code
    // Using the actual FragmentId type from the module
    // The pair is (forward edges, backward edges): adjacency keyed by source resp. target.
    type Edges = (
        HashMap<FragmentId, Vec<FragmentId>>,
        HashMap<FragmentId, Vec<FragmentId>>,
    );
1029
1030    /// A helper function to build forward and backward edge maps from a simple list of tuples.
1031    /// This reduces boilerplate in each test.
1032    fn build_edges(relations: &[(u32, u32)]) -> Edges {
1033        let mut forward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
1034        let mut backward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
1035        for &(src, dst) in relations {
1036            forward_edges
1037                .entry(src.into())
1038                .or_default()
1039                .push(dst.into());
1040            backward_edges
1041                .entry(dst.into())
1042                .or_default()
1043                .push(src.into());
1044        }
1045        (forward_edges, backward_edges)
1046    }
1047
1048    /// Helper function to create a `HashSet` from a slice easily.
1049    fn to_hashset(ids: &[u32]) -> HashSet<FragmentId> {
1050        ids.iter().map(|id| (*id).into()).collect()
1051    }
1052
    /// Builds a minimal `fragment::Model` for tests: empty stream node and arrays,
    /// with the given distribution, vnode count, and parallelism.
    #[allow(deprecated)]
    fn build_fragment(
        fragment_id: FragmentId,
        job_id: JobId,
        fragment_type_mask: i32,
        distribution_type: DistributionType,
        vnode_count: i32,
        parallelism: StreamingParallelism,
    ) -> fragment::Model {
        fragment::Model {
            fragment_id,
            job_id,
            fragment_type_mask,
            distribution_type,
            // Placeholder plan body; the tests here never inspect the stream graph.
            stream_node: StreamNode::from(&PbStreamNode::default()),
            state_table_ids: TableIdArray::default(),
            // Presumably the deprecated item the `#[allow(deprecated)]` covers; left empty.
            upstream_fragment_id: I32Array::default(),
            vnode_count,
            parallelism: Some(parallelism),
        }
    }
1074
1075    type ActorState = (ActorId, WorkerId, Option<Vec<usize>>, Vec<String>);
1076
1077    fn collect_actor_state(fragment: &InflightFragmentInfo) -> Vec<ActorState> {
1078        let base = fragment.actors.keys().copied().min().unwrap_or_default();
1079
1080        let mut entries: Vec<_> = fragment
1081            .actors
1082            .iter()
1083            .map(|(&actor_id, info)| {
1084                let idx = actor_id.as_raw_id() - base.as_raw_id();
1085                let vnode_indices = info.vnode_bitmap.as_ref().map(|bitmap| {
1086                    bitmap
1087                        .iter()
1088                        .enumerate()
1089                        .filter_map(|(pos, is_set)| is_set.then_some(pos))
1090                        .collect::<Vec<_>>()
1091                });
1092                let splits = info
1093                    .splits
1094                    .iter()
1095                    .map(|split| split.id().to_string())
1096                    .collect::<Vec<_>>();
1097                (idx.into(), info.worker_id, vnode_indices, splits)
1098            })
1099            .collect();
1100
1101        entries.sort_by_key(|(idx, _, _, _)| *idx);
1102        entries
1103    }
1104
1105    #[test]
1106    fn test_single_linear_chain() {
1107        // Scenario: A simple linear graph 1 -> 2 -> 3.
1108        // We start from the middle node (2).
1109        let (forward, backward) = build_edges(&[(1, 2), (2, 3)]);
1110        let initial_ids = &[2];
1111
1112        // Act
1113        let result = find_no_shuffle_graphs(initial_ids, &forward, &backward);
1114
1115        // Assert
1116        assert!(result.is_ok());
1117        let graphs = result.unwrap();
1118
1119        assert_eq!(graphs.len(), 1);
1120        let graph = &graphs[0];
1121        assert_eq!(graph.entries, to_hashset(&[1]));
1122        assert_eq!(graph.components, to_hashset(&[1, 2, 3]));
1123    }
1124
1125    #[test]
1126    fn test_two_disconnected_graphs() {
1127        // Scenario: Two separate graphs: 1->2 and 10->11.
1128        // We start with one node from each graph.
1129        let (forward, backward) = build_edges(&[(1, 2), (10, 11)]);
1130        let initial_ids = &[2, 10];
1131
1132        // Act
1133        let mut graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1134
1135        // Assert
1136        assert_eq!(graphs.len(), 2);
1137
1138        // Sort results to make the test deterministic, as HashMap iteration order is not guaranteed.
1139        graphs.sort_by_key(|g| *g.components.iter().min().unwrap_or(&0.into()));
1140
1141        // Graph 1
1142        assert_eq!(graphs[0].entries, to_hashset(&[1]));
1143        assert_eq!(graphs[0].components, to_hashset(&[1, 2]));
1144
1145        // Graph 2
1146        assert_eq!(graphs[1].entries, to_hashset(&[10]));
1147        assert_eq!(graphs[1].components, to_hashset(&[10, 11]));
1148    }
1149
1150    #[test]
1151    fn test_multiple_entries_in_one_graph() {
1152        // Scenario: A graph with two roots feeding into one node: 1->3, 2->3.
1153        let (forward, backward) = build_edges(&[(1, 3), (2, 3)]);
1154        let initial_ids = &[3];
1155
1156        // Act
1157        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1158
1159        // Assert
1160        assert_eq!(graphs.len(), 1);
1161        let graph = &graphs[0];
1162        assert_eq!(graph.entries, to_hashset(&[1, 2]));
1163        assert_eq!(graph.components, to_hashset(&[1, 2, 3]));
1164    }
1165
1166    #[test]
1167    fn test_diamond_shape_graph() {
1168        // Scenario: A diamond shape: 1->2, 1->3, 2->4, 3->4
1169        let (forward, backward) = build_edges(&[(1, 2), (1, 3), (2, 4), (3, 4)]);
1170        let initial_ids = &[4];
1171
1172        // Act
1173        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1174
1175        // Assert
1176        assert_eq!(graphs.len(), 1);
1177        let graph = &graphs[0];
1178        assert_eq!(graph.entries, to_hashset(&[1]));
1179        assert_eq!(graph.components, to_hashset(&[1, 2, 3, 4]));
1180    }
1181
1182    #[test]
1183    fn test_starting_with_multiple_nodes_in_same_graph() {
1184        // Scenario: Start with two different nodes (2 and 4) from the same component.
1185        // Should only identify one graph, not two.
1186        let (forward, backward) = build_edges(&[(1, 2), (2, 3), (3, 4)]);
1187        let initial_ids = &[2, 4];
1188
1189        // Act
1190        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1191
1192        // Assert
1193        assert_eq!(graphs.len(), 1);
1194        let graph = &graphs[0];
1195        assert_eq!(graph.entries, to_hashset(&[1]));
1196        assert_eq!(graph.components, to_hashset(&[1, 2, 3, 4]));
1197    }
1198
1199    #[test]
1200    fn test_empty_initial_ids() {
1201        // Scenario: The initial ID list is empty.
1202        let (forward, backward) = build_edges(&[(1, 2)]);
1203        let initial_ids: &[u32] = &[];
1204
1205        // Act
1206        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1207
1208        // Assert
1209        assert!(graphs.is_empty());
1210    }
1211
1212    #[test]
1213    fn test_isolated_node_as_input() {
1214        // Scenario: Start with an ID that has no relations.
1215        let (forward, backward) = build_edges(&[(1, 2)]);
1216        let initial_ids = &[100];
1217
1218        // Act
1219        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1220
1221        // Assert
1222        assert_eq!(graphs.len(), 1);
1223        let graph = &graphs[0];
1224        assert_eq!(graph.entries, to_hashset(&[100]));
1225        assert_eq!(graph.components, to_hashset(&[100]));
1226    }
1227
1228    #[test]
1229    fn test_graph_with_a_cycle() {
1230        // Scenario: A graph with a cycle: 1 -> 2 -> 3 -> 1.
1231        // The algorithm should correctly identify all nodes in the component.
1232        // Crucially, NO node is a root because every node has a parent *within the component*.
1233        // Therefore, the `entries` set should be empty, and the graph should not be included in the results.
1234        let (forward, backward) = build_edges(&[(1, 2), (2, 3), (3, 1)]);
1235        let initial_ids = &[2];
1236
1237        // Act
1238        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1239
1240        // Assert
1241        assert!(
1242            graphs.is_empty(),
1243            "A graph with no entries should not be returned"
1244        );
1245    }
1246    #[test]
1247    fn test_custom_complex() {
1248        let (forward, backward) = build_edges(&[(1, 3), (1, 8), (2, 3), (4, 3), (3, 5), (6, 7)]);
1249        let initial_ids = &[1, 2, 4, 6];
1250
1251        // Act
1252        let mut graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1253
1254        // Assert
1255        assert_eq!(graphs.len(), 2);
1256        // Sort results to make the test deterministic, as HashMap iteration order is not guaranteed.
1257        graphs.sort_by_key(|g| *g.components.iter().min().unwrap_or(&0.into()));
1258
1259        // Graph 1
1260        assert_eq!(graphs[0].entries, to_hashset(&[1, 2, 4]));
1261        assert_eq!(graphs[0].components, to_hashset(&[1, 2, 3, 4, 5, 8]));
1262
1263        // Graph 2
1264        assert_eq!(graphs[1].entries, to_hashset(&[6]));
1265        assert_eq!(graphs[1].components, to_hashset(&[6, 7]));
1266    }
1267
    /// A single fragment with `Single` distribution should consume exactly one id
    /// from the shared actor id counter and carry no vnode bitmap.
    #[test]
    fn render_actors_increments_actor_counter() {
        // Start the counter at a non-zero value so the exact consumption is visible.
        let actor_id_counter = AtomicU32::new(100);
        let fragment_id: FragmentId = 1.into();
        let job_id: JobId = 10.into();
        let database_id: DatabaseId = DatabaseId::new(3);

        // One single-distribution fragment with fixed parallelism 1.
        let fragment_model = build_fragment(
            fragment_id,
            job_id,
            0,
            DistributionType::Single,
            1,
            StreamingParallelism::Fixed(1),
        );

        let job_model = streaming_job::Model {
            job_id,
            job_status: JobStatus::Created,
            create_type: CreateType::Foreground,
            timezone: None,
            config_override: None,
            parallelism: StreamingParallelism::Fixed(1),
            backfill_parallelism: None,
            max_parallelism: 1,
            specific_resource_group: None,
        };

        // Database resource group matches the single worker's group below ("rg-a").
        let database_model = database::Model {
            database_id,
            name: "test_db".into(),
            resource_group: "rg-a".into(),
            barrier_interval_ms: None,
            checkpoint_frequency: None,
        };

        // Trivial ensemble: the fragment is both the entry and the only component.
        let ensembles = vec![NoShuffleEnsemble {
            entries: HashSet::from([fragment_id]),
            components: HashSet::from([fragment_id]),
        }];

        let fragment_map = HashMap::from([(fragment_id, fragment_model)]);
        let job_map = HashMap::from([(job_id, job_model)]);

        // A single worker with parallelism 1 in the matching resource group.
        let worker_map = BTreeMap::from([(
            1.into(),
            WorkerInfo {
                parallelism: NonZeroUsize::new(1).unwrap(),
                resource_group: Some("rg-a".into()),
            },
        )]);

        // No sources, splits, or backfill jobs involved in this scenario.
        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
        let database_map = HashMap::from([(database_id, database_model)]);
        let backfill_jobs = HashSet::new();

        let context = RenderActorsContext {
            fragment_source_ids: &fragment_source_ids,
            fragment_splits: &fragment_splits,
            streaming_job_databases: &streaming_job_databases,
            database_map: &database_map,
            backfill_jobs: &backfill_jobs,
        };

        let result = render_actors(
            &actor_id_counter,
            &ensembles,
            &fragment_map,
            &job_map,
            &worker_map,
            AdaptiveParallelismStrategy::Auto,
            context,
        )
        .expect("actor rendering succeeds");

        // Exactly one actor, no vnode bitmap, and exactly one id consumed (100 -> 101).
        let state = collect_actor_state(&result[&database_id][&job_id][&fragment_id]);
        assert_eq!(state.len(), 1);
        assert!(
            state[0].2.is_none(),
            "single distribution should not assign vnode bitmaps"
        );
        assert_eq!(actor_id_counter.load(Ordering::Relaxed), 101);
    }
1353
    /// Two hash-distributed fragments connected in one no-shuffle ensemble should
    /// end up with identical actor placement and vnode bitmaps, covering all vnodes.
    #[test]
    fn render_actors_aligns_hash_vnode_bitmaps() {
        let actor_id_counter = AtomicU32::new(0);
        let entry_fragment_id: FragmentId = 1.into();
        let downstream_fragment_id: FragmentId = 2.into();
        let job_id: JobId = 20.into();
        let database_id: DatabaseId = DatabaseId::new(5);

        // Both fragments: hash distribution, 4 vnodes, fixed parallelism 2.
        let entry_fragment = build_fragment(
            entry_fragment_id,
            job_id,
            0,
            DistributionType::Hash,
            4,
            StreamingParallelism::Fixed(2),
        );

        let downstream_fragment = build_fragment(
            downstream_fragment_id,
            job_id,
            0,
            DistributionType::Hash,
            4,
            StreamingParallelism::Fixed(2),
        );

        let job_model = streaming_job::Model {
            job_id,
            job_status: JobStatus::Created,
            create_type: CreateType::Background,
            timezone: None,
            config_override: None,
            parallelism: StreamingParallelism::Fixed(2),
            backfill_parallelism: None,
            max_parallelism: 2,
            specific_resource_group: None,
        };

        let database_model = database::Model {
            database_id,
            name: "test_db_hash".into(),
            resource_group: "rg-hash".into(),
            barrier_interval_ms: None,
            checkpoint_frequency: None,
        };

        // One ensemble: entry fragment feeds the downstream fragment via no-shuffle.
        let ensembles = vec![NoShuffleEnsemble {
            entries: HashSet::from([entry_fragment_id]),
            components: HashSet::from([entry_fragment_id, downstream_fragment_id]),
        }];

        let fragment_map = HashMap::from([
            (entry_fragment_id, entry_fragment),
            (downstream_fragment_id, downstream_fragment),
        ]);
        let job_map = HashMap::from([(job_id, job_model)]);

        // Two workers, parallelism 1 each, in the database's resource group.
        let worker_map = BTreeMap::from([
            (
                1.into(),
                WorkerInfo {
                    parallelism: NonZeroUsize::new(1).unwrap(),
                    resource_group: Some("rg-hash".into()),
                },
            ),
            (
                2.into(),
                WorkerInfo {
                    parallelism: NonZeroUsize::new(1).unwrap(),
                    resource_group: Some("rg-hash".into()),
                },
            ),
        ]);

        // No sources or splits in this scenario.
        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
        let database_map = HashMap::from([(database_id, database_model)]);
        let backfill_jobs = HashSet::new();

        let context = RenderActorsContext {
            fragment_source_ids: &fragment_source_ids,
            fragment_splits: &fragment_splits,
            streaming_job_databases: &streaming_job_databases,
            database_map: &database_map,
            backfill_jobs: &backfill_jobs,
        };

        let result = render_actors(
            &actor_id_counter,
            &ensembles,
            &fragment_map,
            &job_map,
            &worker_map,
            AdaptiveParallelismStrategy::Auto,
            context,
        )
        .expect("actor rendering succeeds");

        let entry_state = collect_actor_state(&result[&database_id][&job_id][&entry_fragment_id]);
        let downstream_state =
            collect_actor_state(&result[&database_id][&job_id][&downstream_fragment_id]);

        // No-shuffle pairs must mirror each other: same workers, same bitmaps.
        assert_eq!(entry_state.len(), 2);
        assert_eq!(entry_state, downstream_state);

        // Together the actors must cover all 4 vnodes exactly.
        let assigned_vnodes: BTreeSet<_> = entry_state
            .iter()
            .flat_map(|(_, _, vnodes, _)| {
                vnodes
                    .as_ref()
                    .expect("hash distribution should populate vnode bitmap")
                    .iter()
                    .copied()
            })
            .collect();
        assert_eq!(assigned_vnodes, BTreeSet::from([0, 1, 2, 3]));
        // 2 actors per fragment × 2 fragments = 4 ids consumed.
        assert_eq!(actor_id_counter.load(Ordering::Relaxed), 4);
    }
1473
    /// Splits assigned to a source fragment should be propagated to the actors of
    /// its no-shuffle downstream (source-scan) fragment with identical assignment.
    #[test]
    fn render_actors_propagates_source_splits() {
        let actor_id_counter = AtomicU32::new(0);
        let entry_fragment_id: FragmentId = 11.into();
        let downstream_fragment_id: FragmentId = 12.into();
        let job_id: JobId = 30.into();
        let database_id: DatabaseId = DatabaseId::new(7);
        let source_id: SourceId = 99.into();

        // Fragment type masks: entry is a Source fragment, downstream a SourceScan.
        let source_mask = FragmentTypeFlag::raw_flag([FragmentTypeFlag::Source]) as i32;
        let source_scan_mask = FragmentTypeFlag::raw_flag([FragmentTypeFlag::SourceScan]) as i32;

        let entry_fragment = build_fragment(
            entry_fragment_id,
            job_id,
            source_mask,
            DistributionType::Hash,
            4,
            StreamingParallelism::Fixed(2),
        );

        let downstream_fragment = build_fragment(
            downstream_fragment_id,
            job_id,
            source_scan_mask,
            DistributionType::Hash,
            4,
            StreamingParallelism::Fixed(2),
        );

        let job_model = streaming_job::Model {
            job_id,
            job_status: JobStatus::Created,
            create_type: CreateType::Background,
            timezone: None,
            config_override: None,
            parallelism: StreamingParallelism::Fixed(2),
            backfill_parallelism: None,
            max_parallelism: 2,
            specific_resource_group: None,
        };

        let database_model = database::Model {
            database_id,
            name: "split_db".into(),
            resource_group: "rg-source".into(),
            barrier_interval_ms: None,
            checkpoint_frequency: None,
        };

        // One ensemble: source fragment is the entry, source-scan is downstream.
        let ensembles = vec![NoShuffleEnsemble {
            entries: HashSet::from([entry_fragment_id]),
            components: HashSet::from([entry_fragment_id, downstream_fragment_id]),
        }];

        let fragment_map = HashMap::from([
            (entry_fragment_id, entry_fragment),
            (downstream_fragment_id, downstream_fragment),
        ]);
        let job_map = HashMap::from([(job_id, job_model)]);

        // Two workers, parallelism 1 each, in the database's resource group.
        let worker_map = BTreeMap::from([
            (
                1.into(),
                WorkerInfo {
                    parallelism: NonZeroUsize::new(1).unwrap(),
                    resource_group: Some("rg-source".into()),
                },
            ),
            (
                2.into(),
                WorkerInfo {
                    parallelism: NonZeroUsize::new(1).unwrap(),
                    resource_group: Some("rg-source".into()),
                },
            ),
        ]);

        // Two test splits to be distributed across the entry fragment's actors.
        let split_a = SplitImpl::Test(TestSourceSplit {
            id: Arc::<str>::from("split-a"),
            properties: HashMap::new(),
            offset: "0".into(),
        });
        let split_b = SplitImpl::Test(TestSourceSplit {
            id: Arc::<str>::from("split-b"),
            properties: HashMap::new(),
            offset: "0".into(),
        });

        // Both fragments read from the same source, but only the entry fragment
        // has splits assigned — the downstream must inherit them.
        let fragment_source_ids = HashMap::from([
            (entry_fragment_id, source_id),
            (downstream_fragment_id, source_id),
        ]);
        let fragment_splits =
            HashMap::from([(entry_fragment_id, vec![split_a.clone(), split_b.clone()])]);
        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
        let database_map = HashMap::from([(database_id, database_model)]);
        let backfill_jobs = HashSet::new();

        let context = RenderActorsContext {
            fragment_source_ids: &fragment_source_ids,
            fragment_splits: &fragment_splits,
            streaming_job_databases: &streaming_job_databases,
            database_map: &database_map,
            backfill_jobs: &backfill_jobs,
        };

        let result = render_actors(
            &actor_id_counter,
            &ensembles,
            &fragment_map,
            &job_map,
            &worker_map,
            AdaptiveParallelismStrategy::Auto,
            context,
        )
        .expect("actor rendering succeeds");

        let entry_state = collect_actor_state(&result[&database_id][&job_id][&entry_fragment_id]);
        let downstream_state =
            collect_actor_state(&result[&database_id][&job_id][&downstream_fragment_id]);

        // Downstream actors mirror the entry actors, including split assignments.
        assert_eq!(entry_state, downstream_state);

        // All splits must appear across the actors, none lost or duplicated.
        let split_ids: BTreeSet<_> = entry_state
            .iter()
            .flat_map(|(_, _, _, splits)| splits.iter().cloned())
            .collect();
        assert_eq!(
            split_ids,
            BTreeSet::from([split_a.id().to_string(), split_b.id().to_string()])
        );
        // 2 actors per fragment × 2 fragments = 4 ids consumed.
        assert_eq!(actor_id_counter.load(Ordering::Relaxed), 4);
    }
1608}