1use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
16use std::num::NonZeroUsize;
17use std::sync::atomic::{AtomicU32, Ordering};
18
19use anyhow::anyhow;
20use itertools::Itertools;
21use risingwave_common::bitmap::Bitmap;
22use risingwave_common::catalog::{FragmentTypeFlag, FragmentTypeMask};
23use risingwave_common::id::JobId;
24use risingwave_common::system_param::AdaptiveParallelismStrategy;
25use risingwave_common::util::worker_util::DEFAULT_RESOURCE_GROUP;
26use risingwave_connector::source::{SplitImpl, SplitMetaData};
27use risingwave_meta_model::fragment::DistributionType;
28use risingwave_meta_model::prelude::{
29 Database, Fragment, FragmentRelation, FragmentSplits, Sink, Source, StreamingJob, Table,
30};
31use risingwave_meta_model::{
32 DatabaseId, DispatcherType, FragmentId, SourceId, StreamingParallelism, WorkerId, database,
33 fragment, fragment_relation, fragment_splits, object, sink, source, streaming_job, table,
34};
35use risingwave_meta_model_migration::Condition;
36use sea_orm::{
37 ColumnTrait, ConnectionTrait, EntityTrait, JoinType, QueryFilter, QuerySelect, QueryTrait,
38 RelationTrait,
39};
40
41use crate::MetaResult;
42use crate::controller::fragment::{InflightActorInfo, InflightFragmentInfo};
43use crate::manager::ActiveStreamingWorkerNodes;
44use crate::model::{ActorId, StreamActor};
45use crate::stream::{AssignerBuilder, SplitDiffOptions};
46
47pub(crate) async fn resolve_streaming_job_definition<C>(
48 txn: &C,
49 job_ids: &HashSet<JobId>,
50) -> MetaResult<HashMap<JobId, String>>
51where
52 C: ConnectionTrait,
53{
54 let job_ids = job_ids.iter().cloned().collect_vec();
55
56 let common_job_definitions: Vec<(JobId, String)> = Table::find()
58 .select_only()
59 .columns([
60 table::Column::TableId,
61 #[cfg(not(debug_assertions))]
62 table::Column::Name,
63 #[cfg(debug_assertions)]
64 table::Column::Definition,
65 ])
66 .filter(table::Column::TableId.is_in(job_ids.clone()))
67 .into_tuple()
68 .all(txn)
69 .await?;
70
71 let sink_definitions: Vec<(JobId, String)> = Sink::find()
72 .select_only()
73 .columns([
74 sink::Column::SinkId,
75 #[cfg(not(debug_assertions))]
76 sink::Column::Name,
77 #[cfg(debug_assertions)]
78 sink::Column::Definition,
79 ])
80 .filter(sink::Column::SinkId.is_in(job_ids.clone()))
81 .into_tuple()
82 .all(txn)
83 .await?;
84
85 let source_definitions: Vec<(JobId, String)> = Source::find()
86 .select_only()
87 .columns([
88 source::Column::SourceId,
89 #[cfg(not(debug_assertions))]
90 source::Column::Name,
91 #[cfg(debug_assertions)]
92 source::Column::Definition,
93 ])
94 .filter(source::Column::SourceId.is_in(job_ids.clone()))
95 .into_tuple()
96 .all(txn)
97 .await?;
98
99 let definitions: HashMap<JobId, String> = common_job_definitions
100 .into_iter()
101 .chain(sink_definitions.into_iter())
102 .chain(source_definitions.into_iter())
103 .collect();
104
105 Ok(definitions)
106}
107
108pub async fn load_fragment_info<C>(
109 txn: &C,
110 actor_id_counter: &AtomicU32,
111 database_id: Option<DatabaseId>,
112 worker_nodes: &ActiveStreamingWorkerNodes,
113 adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
114) -> MetaResult<FragmentRenderMap>
115where
116 C: ConnectionTrait,
117{
118 let mut query = StreamingJob::find()
119 .select_only()
120 .column(streaming_job::Column::JobId);
121
122 if let Some(database_id) = database_id {
123 query = query
124 .join(JoinType::InnerJoin, streaming_job::Relation::Object.def())
125 .filter(object::Column::DatabaseId.eq(database_id));
126 }
127
128 let jobs: Vec<JobId> = query.into_tuple().all(txn).await?;
129
130 if jobs.is_empty() {
131 return Ok(HashMap::new());
132 }
133
134 let jobs: HashSet<JobId> = jobs.into_iter().collect();
135
136 let available_workers: BTreeMap<_, _> = worker_nodes
137 .current()
138 .values()
139 .filter(|worker| worker.is_streaming_schedulable())
140 .map(|worker| {
141 (
142 worker.id,
143 WorkerInfo {
144 parallelism: NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
145 resource_group: worker.resource_group(),
146 },
147 )
148 })
149 .collect();
150
151 let RenderedGraph { fragments, .. } = render_jobs(
152 txn,
153 actor_id_counter,
154 jobs,
155 available_workers,
156 adaptive_parallelism_strategy,
157 )
158 .await?;
159
160 Ok(fragments)
161}
162
/// Desired placement policy for a streaming job: which resource group its
/// actors should run in and how parallel they should be.
#[derive(Debug)]
pub struct TargetResourcePolicy {
    /// Target resource group; `None` presumably means "inherit the database's
    /// resource group" — TODO confirm against callers.
    pub resource_group: Option<String>,
    /// Target parallelism setting for the job.
    pub parallelism: StreamingParallelism,
}
168
/// Scheduling-relevant view of a streaming compute worker.
#[derive(Debug, Clone)]
pub struct WorkerInfo {
    /// Number of parallel slots the worker offers; guaranteed non-zero.
    pub parallelism: NonZeroUsize,
    /// Resource group label; `None` is treated as the default resource group
    /// during rendering (see `render_actors`).
    pub resource_group: Option<String>,
}
174
/// Rendered fragment info, grouped by database, then streaming job, then
/// fragment id.
pub type FragmentRenderMap =
    HashMap<DatabaseId, HashMap<JobId, HashMap<FragmentId, InflightFragmentInfo>>>;
177
/// Output of a rendering pass: the materialized fragment infos together with
/// the no-shuffle ensembles the rendering was based on.
#[derive(Default)]
pub struct RenderedGraph {
    /// Rendered fragments, grouped by database / job / fragment id.
    pub fragments: FragmentRenderMap,
    /// The no-shuffle ensembles that were rendered.
    pub ensembles: Vec<NoShuffleEnsemble>,
}
183
184impl RenderedGraph {
185 pub fn empty() -> Self {
186 Self::default()
187 }
188}
189
190pub async fn render_fragments<C>(
195 txn: &C,
196 actor_id_counter: &AtomicU32,
197 ensembles: Vec<NoShuffleEnsemble>,
198 workers: BTreeMap<WorkerId, WorkerInfo>,
199 adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
200) -> MetaResult<RenderedGraph>
201where
202 C: ConnectionTrait,
203{
204 if ensembles.is_empty() {
205 return Ok(RenderedGraph::empty());
206 }
207
208 let required_fragment_ids: HashSet<_> = ensembles
209 .iter()
210 .flat_map(|ensemble| ensemble.components.iter().copied())
211 .collect();
212
213 let fragment_models = Fragment::find()
214 .filter(fragment::Column::FragmentId.is_in(required_fragment_ids.iter().copied()))
215 .all(txn)
216 .await?;
217
218 let found_fragment_ids: HashSet<_> = fragment_models
219 .iter()
220 .map(|fragment| fragment.fragment_id)
221 .collect();
222
223 if found_fragment_ids.len() != required_fragment_ids.len() {
224 let missing = required_fragment_ids
225 .difference(&found_fragment_ids)
226 .copied()
227 .collect_vec();
228 return Err(anyhow!("fragments {:?} not found", missing).into());
229 }
230
231 let fragment_map: HashMap<_, _> = fragment_models
232 .into_iter()
233 .map(|fragment| (fragment.fragment_id, fragment))
234 .collect();
235
236 let job_ids: HashSet<_> = fragment_map
237 .values()
238 .map(|fragment| fragment.job_id)
239 .collect();
240
241 if job_ids.is_empty() {
242 return Ok(RenderedGraph::empty());
243 }
244
245 let jobs: HashMap<_, _> = StreamingJob::find()
246 .filter(streaming_job::Column::JobId.is_in(job_ids.iter().copied().collect_vec()))
247 .all(txn)
248 .await?
249 .into_iter()
250 .map(|job| (job.job_id, job))
251 .collect();
252
253 let found_job_ids: HashSet<_> = jobs.keys().copied().collect();
254 if found_job_ids.len() != job_ids.len() {
255 let missing = job_ids.difference(&found_job_ids).copied().collect_vec();
256 return Err(anyhow!("streaming jobs {:?} not found", missing).into());
257 }
258
259 let fragments = render_no_shuffle_ensembles(
260 txn,
261 actor_id_counter,
262 &ensembles,
263 &fragment_map,
264 &jobs,
265 &workers,
266 adaptive_parallelism_strategy,
267 )
268 .await?;
269
270 Ok(RenderedGraph {
271 fragments,
272 ensembles,
273 })
274}
275
/// Render all fragments belonging to the given streaming jobs.
///
/// Starts from fragments that are *not* the target of any no-shuffle edge
/// (such targets inherit placement from their upstream), expands each into
/// its full no-shuffle ensemble, and renders the resulting ensembles. Note
/// that ensembles may pull in fragments of jobs outside `job_ids` when they
/// are connected via no-shuffle edges.
pub async fn render_jobs<C>(
    txn: &C,
    actor_id_counter: &AtomicU32,
    job_ids: HashSet<JobId>,
    workers: BTreeMap<WorkerId, WorkerInfo>,
    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
) -> MetaResult<RenderedGraph>
where
    C: ConnectionTrait,
{
    // Subquery: every fragment that is downstream of a no-shuffle dispatcher.
    let excluded_fragments_query = FragmentRelation::find()
        .select_only()
        .column(fragment_relation::Column::TargetFragmentId)
        .filter(fragment_relation::Column::DispatcherType.eq(DispatcherType::NoShuffle))
        .into_query();

    // Keep only fragments of the requested jobs that are not no-shuffle targets.
    let condition = Condition::all()
        .add(fragment::Column::JobId.is_in(job_ids.clone()))
        .add(fragment::Column::FragmentId.not_in_subquery(excluded_fragments_query));

    let fragments: Vec<FragmentId> = Fragment::find()
        .select_only()
        .column(fragment::Column::FragmentId)
        .filter(condition)
        .into_tuple()
        .all(txn)
        .await?;

    // Grow each remaining fragment into its connected no-shuffle ensemble.
    let ensembles = find_fragment_no_shuffle_dags_detailed(txn, &fragments).await?;

    // Load the full fragment rows for every ensemble component.
    let fragments = Fragment::find()
        .filter(
            fragment::Column::FragmentId.is_in(
                ensembles
                    .iter()
                    .flat_map(|graph| graph.components.iter())
                    .cloned()
                    .collect_vec(),
            ),
        )
        .all(txn)
        .await?;

    let fragment_map: HashMap<_, _> = fragments
        .into_iter()
        .map(|fragment| (fragment.fragment_id, fragment))
        .collect();

    // Owning jobs of all loaded fragments (BTreeSet gives a stable order).
    let job_ids = fragment_map
        .values()
        .map(|fragment| fragment.job_id)
        .collect::<BTreeSet<_>>()
        .into_iter()
        .collect_vec();

    let jobs: HashMap<_, _> = StreamingJob::find()
        .filter(streaming_job::Column::JobId.is_in(job_ids))
        .all(txn)
        .await?
        .into_iter()
        .map(|job| (job.job_id, job))
        .collect();

    let fragments = render_no_shuffle_ensembles(
        txn,
        actor_id_counter,
        &ensembles,
        &fragment_map,
        &jobs,
        &workers,
        adaptive_parallelism_strategy,
    )
    .await?;

    Ok(RenderedGraph {
        fragments,
        ensembles,
    })
}
357
/// Render the given ensembles: resolve source/split metadata and the owning
/// database of every job, then delegate the actual actor materialization to
/// [`render_actors`].
async fn render_no_shuffle_ensembles<C>(
    txn: &C,
    actor_id_counter: &AtomicU32,
    ensembles: &[NoShuffleEnsemble],
    fragment_map: &HashMap<FragmentId, fragment::Model>,
    job_map: &HashMap<JobId, streaming_job::Model>,
    worker_map: &BTreeMap<WorkerId, WorkerInfo>,
    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
) -> MetaResult<FragmentRenderMap>
where
    C: ConnectionTrait,
{
    if ensembles.is_empty() {
        return Ok(HashMap::new());
    }

    // Debug-only invariant checks on the inputs (entries ⊆ components, all
    // referenced fragments/jobs present, consistent vnode counts).
    #[cfg(debug_assertions)]
    {
        debug_sanity_check(ensembles, fragment_map, job_map);
    }

    // Per-fragment source ids and persisted split assignments.
    let (fragment_source_ids, fragment_splits) =
        resolve_source_fragments(txn, fragment_map).await?;

    let job_ids = job_map.keys().copied().collect_vec();

    // Map each job to its owning database via the object catalog.
    let streaming_job_databases: HashMap<JobId, _> = StreamingJob::find()
        .select_only()
        .column(streaming_job::Column::JobId)
        .column(object::Column::DatabaseId)
        .join(JoinType::LeftJoin, streaming_job::Relation::Object.def())
        .filter(streaming_job::Column::JobId.is_in(job_ids))
        .into_tuple()
        .all(txn)
        .await?
        .into_iter()
        .collect();

    // Database rows are needed for resource-group fallback during rendering.
    let database_map: HashMap<_, _> = Database::find()
        .filter(
            database::Column::DatabaseId
                .is_in(streaming_job_databases.values().copied().collect_vec()),
        )
        .all(txn)
        .await?
        .into_iter()
        .map(|db| (db.database_id, db))
        .collect();

    let context = RenderActorsContext {
        fragment_source_ids: &fragment_source_ids,
        fragment_splits: &fragment_splits,
        streaming_job_databases: &streaming_job_databases,
        database_map: &database_map,
    };

    render_actors(
        actor_id_counter,
        ensembles,
        fragment_map,
        job_map,
        worker_map,
        adaptive_parallelism_strategy,
        context,
    )
}
429
/// Read-only lookup tables needed by `render_actors`, bundled to keep its
/// signature manageable.
struct RenderActorsContext<'a> {
    /// Source id consumed by each source / source-backfill fragment.
    fragment_source_ids: &'a HashMap<FragmentId, SourceId>,
    /// Persisted split assignments per source fragment.
    fragment_splits: &'a HashMap<FragmentId, Vec<SplitImpl>>,
    /// Owning database of each streaming job.
    streaming_job_databases: &'a HashMap<JobId, DatabaseId>,
    /// Database catalog rows, keyed by database id.
    database_map: &'a HashMap<DatabaseId, database::Model>,
}
438
/// Materialize actors for every no-shuffle ensemble.
///
/// For each ensemble: derive the effective parallelism from the entry
/// fragments' (consistent) parallelism override or the job's setting, select
/// workers in the job's resource group, assign actors and vnodes, compute
/// split assignments for a shared-source entry fragment (if any), and emit an
/// `InflightFragmentInfo` per component fragment into the returned map.
///
/// Fresh actor ids are allocated from `actor_id_counter`, one contiguous
/// range per component fragment.
fn render_actors(
    actor_id_counter: &AtomicU32,
    ensembles: &[NoShuffleEnsemble],
    fragment_map: &HashMap<FragmentId, fragment::Model>,
    job_map: &HashMap<JobId, streaming_job::Model>,
    worker_map: &BTreeMap<WorkerId, WorkerInfo>,
    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
    context: RenderActorsContext<'_>,
) -> MetaResult<FragmentRenderMap> {
    let RenderActorsContext {
        fragment_source_ids,
        fragment_splits: fragment_splits_map,
        streaming_job_databases,
        database_map,
    } = context;

    let mut all_fragments: FragmentRenderMap = HashMap::new();

    for NoShuffleEnsemble {
        entries,
        components,
    } in ensembles
    {
        tracing::debug!("rendering ensemble entries {:?}", entries);

        // Entry fragments drive the scheduling decisions for the whole
        // ensemble; they are guaranteed present by the debug sanity check.
        let entry_fragments = entries
            .iter()
            .map(|fragment_id| fragment_map.get(fragment_id).unwrap())
            .collect_vec();

        // All entry fragments must agree on a single parallelism override.
        let entry_fragment_parallelism = entry_fragments
            .iter()
            .map(|fragment| fragment.parallelism.clone())
            .dedup()
            .exactly_one()
            .map_err(|_| {
                anyhow!(
                    "entry fragments {:?} have inconsistent parallelism settings",
                    entries.iter().copied().collect_vec()
                )
            })?;

        // Similarly, a no-shuffle ensemble must belong to exactly one job and
        // share one vnode count.
        let (job_id, vnode_count) = entry_fragments
            .iter()
            .map(|f| (f.job_id, f.vnode_count as usize))
            .dedup()
            .exactly_one()
            .map_err(|_| anyhow!("Multiple jobs found in no-shuffle ensemble"))?;

        let job = job_map
            .get(&job_id)
            .ok_or_else(|| anyhow!("streaming job {job_id} not found"))?;

        // Resource group: the job's own setting wins, otherwise fall back to
        // the owning database's resource group.
        let resource_group = match &job.specific_resource_group {
            None => {
                let database = streaming_job_databases
                    .get(&job_id)
                    .and_then(|database_id| database_map.get(database_id))
                    .unwrap();
                database.resource_group.clone()
            }
            Some(resource_group) => resource_group.clone(),
        };

        // Only workers in the chosen resource group participate; a worker
        // with no group is treated as belonging to the default group.
        let available_workers: BTreeMap<WorkerId, NonZeroUsize> = worker_map
            .iter()
            .filter_map(|(worker_id, worker)| {
                if worker
                    .resource_group
                    .as_deref()
                    .unwrap_or(DEFAULT_RESOURCE_GROUP)
                    == resource_group.as_str()
                {
                    Some((*worker_id, worker.parallelism))
                } else {
                    None
                }
            })
            .collect();

        let total_parallelism = available_workers.values().map(|w| w.get()).sum::<usize>();

        // Effective parallelism: fragment override (or job setting), resolved
        // through the adaptive strategy when not fixed, then capped by the
        // vnode count and the job's max parallelism.
        let actual_parallelism = match entry_fragment_parallelism
            .as_ref()
            .unwrap_or(&job.parallelism)
        {
            StreamingParallelism::Adaptive | StreamingParallelism::Custom => {
                adaptive_parallelism_strategy.compute_target_parallelism(total_parallelism)
            }
            StreamingParallelism::Fixed(n) => *n,
        }
        .min(vnode_count)
        .min(job.max_parallelism as usize);

        tracing::debug!(
            "job {}, final {} parallelism {:?} total_parallelism {} job_max {} vnode count {} fragment_override {:?}",
            job_id,
            actual_parallelism,
            job.parallelism,
            total_parallelism,
            job.max_parallelism,
            vnode_count,
            entry_fragment_parallelism
        );

        let assigner = AssignerBuilder::new(job_id).build();

        // Actor "ids" here are relative indices; the real ids are offset by
        // the counter below, per component fragment.
        let actors = (0..(actual_parallelism as u32))
            .map_into::<ActorId>()
            .collect_vec();
        let vnodes = (0..vnode_count).collect_vec();

        let assignment = assigner.assign_hierarchical(&available_workers, &actors, &vnodes)?;

        // A shared-source entry fragment (Source but not Dml) determines the
        // split assignment shared by the whole ensemble.
        let source_entry_fragment = entry_fragments.iter().find(|f| {
            let mask = FragmentTypeMask::from(f.fragment_type_mask);
            if mask.contains(FragmentTypeFlag::Source) {
                assert!(!mask.contains(FragmentTypeFlag::SourceScan))
            }
            mask.contains(FragmentTypeFlag::Source) && !mask.contains(FragmentTypeFlag::Dml)
        });

        let (fragment_splits, shared_source_id) = match source_entry_fragment {
            Some(entry_fragment) => {
                let source_id = fragment_source_ids
                    .get(&entry_fragment.fragment_id)
                    .ok_or_else(|| {
                        anyhow!(
                            "missing source id in source fragment {}",
                            entry_fragment.fragment_id
                        )
                    })?;

                let entry_fragment_id = entry_fragment.fragment_id;

                // Start from an empty assignment and let `reassign_splits`
                // distribute the persisted splits over the new actors.
                let empty_actor_splits: HashMap<_, _> =
                    actors.iter().map(|actor_id| (*actor_id, vec![])).collect();

                let splits = fragment_splits_map
                    .get(&entry_fragment_id)
                    .cloned()
                    .unwrap_or_default();

                let splits: BTreeMap<_, _> = splits.into_iter().map(|s| (s.id(), s)).collect();

                let fragment_splits = crate::stream::source_manager::reassign_splits(
                    entry_fragment_id,
                    empty_actor_splits,
                    &splits,
                    SplitDiffOptions::default(),
                )
                .unwrap_or_default();
                (fragment_splits, Some(*source_id))
            }
            None => (HashMap::new(), None),
        };

        // Emit one InflightFragmentInfo per component fragment, all sharing
        // the same worker/vnode assignment.
        for component_fragment_id in components {
            let &fragment::Model {
                fragment_id,
                job_id,
                fragment_type_mask,
                distribution_type,
                ref stream_node,
                ref state_table_ids,
                ..
            } = fragment_map.get(component_fragment_id).unwrap();

            // Reserve a contiguous block of actor ids for this fragment.
            let actor_count =
                u32::try_from(actors.len()).expect("actor parallelism exceeds u32::MAX");
            let actor_id_base = actor_id_counter.fetch_add(actor_count, Ordering::Relaxed);

            let actors: HashMap<ActorId, InflightActorInfo> = assignment
                .iter()
                .flat_map(|(worker_id, actors)| {
                    actors
                        .iter()
                        .map(move |(actor_id, vnodes)| (worker_id, actor_id, vnodes))
                })
                .map(|(&worker_id, &actor_idx, vnodes)| {
                    // Single-distribution fragments carry no vnode bitmap.
                    let vnode_bitmap = match distribution_type {
                        DistributionType::Single => None,
                        DistributionType::Hash => Some(Bitmap::from_indices(vnode_count, vnodes)),
                    };

                    let actor_id = actor_idx + actor_id_base;

                    // Source-related fragments inherit the ensemble-wide
                    // split assignment computed above.
                    let splits = if let Some(source_id) = fragment_source_ids.get(&fragment_id) {
                        assert_eq!(shared_source_id, Some(*source_id));

                        fragment_splits
                            .get(&(actor_idx))
                            .cloned()
                            .unwrap_or_default()
                    } else {
                        vec![]
                    };

                    (
                        actor_id,
                        InflightActorInfo {
                            worker_id,
                            vnode_bitmap,
                            splits,
                        },
                    )
                })
                .collect();

            let fragment = InflightFragmentInfo {
                fragment_id,
                distribution_type,
                fragment_type_mask: fragment_type_mask.into(),
                vnode_count,
                nodes: stream_node.to_protobuf(),
                actors,
                state_table_ids: state_table_ids.inner_ref().iter().copied().collect(),
            };

            let &database_id = streaming_job_databases.get(&job_id).ok_or_else(|| {
                anyhow!("streaming job {job_id} not found in streaming_job_databases")
            })?;

            all_fragments
                .entry(database_id)
                .or_default()
                .entry(job_id)
                .or_default()
                .insert(fragment_id, fragment);
        }
    }

    Ok(all_fragments)
}
673
674#[cfg(debug_assertions)]
675fn debug_sanity_check(
676 ensembles: &[NoShuffleEnsemble],
677 fragment_map: &HashMap<FragmentId, fragment::Model>,
678 jobs: &HashMap<JobId, streaming_job::Model>,
679) {
680 debug_assert!(
682 ensembles
683 .iter()
684 .all(|ensemble| ensemble.entries.is_subset(&ensemble.components)),
685 "entries must be subset of components"
686 );
687
688 let mut missing_fragments = BTreeSet::new();
689 let mut missing_jobs = BTreeSet::new();
690
691 for fragment_id in ensembles
692 .iter()
693 .flat_map(|ensemble| ensemble.components.iter())
694 {
695 match fragment_map.get(fragment_id) {
696 Some(fragment) => {
697 if !jobs.contains_key(&fragment.job_id) {
698 missing_jobs.insert(fragment.job_id);
699 }
700 }
701 None => {
702 missing_fragments.insert(*fragment_id);
703 }
704 }
705 }
706
707 debug_assert!(
708 missing_fragments.is_empty(),
709 "missing fragments in fragment_map: {:?}",
710 missing_fragments
711 );
712
713 debug_assert!(
714 missing_jobs.is_empty(),
715 "missing jobs for fragments' job_id: {:?}",
716 missing_jobs
717 );
718
719 for ensemble in ensembles {
720 let unique_vnode_counts: Vec<_> = ensemble
721 .components
722 .iter()
723 .flat_map(|fragment_id| {
724 fragment_map
725 .get(fragment_id)
726 .map(|fragment| fragment.vnode_count)
727 })
728 .unique()
729 .collect();
730
731 debug_assert!(
732 unique_vnode_counts.len() <= 1,
733 "components in ensemble must share same vnode_count: ensemble={:?}, vnode_counts={:?}",
734 ensemble.components,
735 unique_vnode_counts
736 );
737 }
738}
739
740async fn resolve_source_fragments<C>(
741 txn: &C,
742 fragment_map: &HashMap<FragmentId, fragment::Model>,
743) -> MetaResult<(
744 HashMap<FragmentId, SourceId>,
745 HashMap<FragmentId, Vec<SplitImpl>>,
746)>
747where
748 C: ConnectionTrait,
749{
750 let mut source_fragment_ids: HashMap<SourceId, _> = HashMap::new();
751 for (fragment_id, fragment) in fragment_map {
752 let mask = FragmentTypeMask::from(fragment.fragment_type_mask);
753 if mask.contains(FragmentTypeFlag::Source)
754 && let Some(source_id) = fragment.stream_node.to_protobuf().find_stream_source()
755 {
756 source_fragment_ids
757 .entry(source_id)
758 .or_insert_with(BTreeSet::new)
759 .insert(fragment_id);
760 }
761
762 if mask.contains(FragmentTypeFlag::SourceScan)
763 && let Some((source_id, _)) = fragment.stream_node.to_protobuf().find_source_backfill()
764 {
765 source_fragment_ids
766 .entry(source_id)
767 .or_insert_with(BTreeSet::new)
768 .insert(fragment_id);
769 }
770 }
771
772 let fragment_source_ids: HashMap<_, _> = source_fragment_ids
773 .iter()
774 .flat_map(|(source_id, fragment_ids)| {
775 fragment_ids
776 .iter()
777 .map(|fragment_id| (**fragment_id, *source_id as SourceId))
778 })
779 .collect();
780
781 let fragment_ids = fragment_source_ids.keys().copied().collect_vec();
782
783 let fragment_splits: Vec<_> = FragmentSplits::find()
784 .filter(fragment_splits::Column::FragmentId.is_in(fragment_ids))
785 .all(txn)
786 .await?;
787
788 let fragment_splits: HashMap<_, _> = fragment_splits
789 .into_iter()
790 .flat_map(|model| {
791 model.splits.map(|splits| {
792 (
793 model.fragment_id,
794 splits
795 .to_protobuf()
796 .splits
797 .iter()
798 .flat_map(SplitImpl::try_from)
799 .collect_vec(),
800 )
801 })
802 })
803 .collect();
804
805 Ok((fragment_source_ids, fragment_splits))
806}
807
/// Borrowed view of a rendered actor graph: fragments together with the
/// stream actors rendered for them, plus each actor's worker placement.
#[derive(Debug)]
pub struct ActorGraph<'a> {
    /// Fragment (entity) and its rendered stream actors, keyed by fragment id.
    pub fragments: &'a HashMap<FragmentId, (Fragment, Vec<StreamActor>)>,
    /// Worker each actor is placed on.
    pub locations: &'a HashMap<ActorId, WorkerId>,
}
814
/// A connected component of fragments linked by no-shuffle dispatchers.
#[derive(Debug, Clone)]
pub struct NoShuffleEnsemble {
    // Fragments with no no-shuffle upstream inside the component; always a
    // subset of `components` (see `debug_sanity_check`).
    entries: HashSet<FragmentId>,
    // Every fragment in the connected component, entries included.
    components: HashSet<FragmentId>,
}
820
821impl NoShuffleEnsemble {
822 pub fn fragments(&self) -> impl Iterator<Item = FragmentId> + '_ {
823 self.components.iter().cloned()
824 }
825
826 pub fn entry_fragments(&self) -> impl Iterator<Item = FragmentId> + '_ {
827 self.entries.iter().copied()
828 }
829
830 pub fn component_fragments(&self) -> impl Iterator<Item = FragmentId> + '_ {
831 self.components.iter().copied()
832 }
833
834 pub fn contains_entry(&self, fragment_id: &FragmentId) -> bool {
835 self.entries.contains(fragment_id)
836 }
837}
838
839pub async fn find_fragment_no_shuffle_dags_detailed(
840 db: &impl ConnectionTrait,
841 initial_fragment_ids: &[FragmentId],
842) -> MetaResult<Vec<NoShuffleEnsemble>> {
843 let all_no_shuffle_relations: Vec<(_, _)> = FragmentRelation::find()
844 .columns([
845 fragment_relation::Column::SourceFragmentId,
846 fragment_relation::Column::TargetFragmentId,
847 ])
848 .filter(fragment_relation::Column::DispatcherType.eq(DispatcherType::NoShuffle))
849 .into_tuple()
850 .all(db)
851 .await?;
852
853 let mut forward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
854 let mut backward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
855
856 for (src, dst) in all_no_shuffle_relations {
857 forward_edges.entry(src).or_default().push(dst);
858 backward_edges.entry(dst).or_default().push(src);
859 }
860
861 find_no_shuffle_graphs(initial_fragment_ids, &forward_edges, &backward_edges)
862}
863
864fn find_no_shuffle_graphs(
865 initial_fragment_ids: &[impl Into<FragmentId> + Copy],
866 forward_edges: &HashMap<FragmentId, Vec<FragmentId>>,
867 backward_edges: &HashMap<FragmentId, Vec<FragmentId>>,
868) -> MetaResult<Vec<NoShuffleEnsemble>> {
869 let mut graphs: Vec<NoShuffleEnsemble> = Vec::new();
870 let mut globally_visited: HashSet<FragmentId> = HashSet::new();
871
872 for &init_id in initial_fragment_ids {
873 let init_id = init_id.into();
874 if globally_visited.contains(&init_id) {
875 continue;
876 }
877
878 let mut components = HashSet::new();
880 let mut queue: VecDeque<FragmentId> = VecDeque::new();
881
882 queue.push_back(init_id);
883 globally_visited.insert(init_id);
884
885 while let Some(current_id) = queue.pop_front() {
886 components.insert(current_id);
887 let neighbors = forward_edges
888 .get(¤t_id)
889 .into_iter()
890 .flatten()
891 .chain(backward_edges.get(¤t_id).into_iter().flatten());
892
893 for &neighbor_id in neighbors {
894 if globally_visited.insert(neighbor_id) {
895 queue.push_back(neighbor_id);
896 }
897 }
898 }
899
900 let mut entries = HashSet::new();
902 for &node_id in &components {
903 let is_root = match backward_edges.get(&node_id) {
904 Some(parents) => parents.iter().all(|p| !components.contains(p)),
905 None => true,
906 };
907 if is_root {
908 entries.insert(node_id);
909 }
910 }
911
912 if !entries.is_empty() {
914 graphs.push(NoShuffleEnsemble {
915 entries,
916 components,
917 });
918 }
919 }
920
921 Ok(graphs)
922}
923
924#[cfg(test)]
925mod tests {
926 use std::collections::{BTreeSet, HashMap, HashSet};
927 use std::sync::Arc;
928
929 use risingwave_connector::source::SplitImpl;
930 use risingwave_connector::source::test_source::TestSourceSplit;
931 use risingwave_meta_model::{CreateType, I32Array, JobStatus, StreamNode, TableIdArray};
932 use risingwave_pb::stream_plan::StreamNode as PbStreamNode;
933
934 use super::*;
935
    // (forward, backward) adjacency maps, as consumed by
    // `find_no_shuffle_graphs`.
    type Edges = (
        HashMap<FragmentId, Vec<FragmentId>>,
        HashMap<FragmentId, Vec<FragmentId>>,
    );
942
943 fn build_edges(relations: &[(u32, u32)]) -> Edges {
946 let mut forward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
947 let mut backward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
948 for &(src, dst) in relations {
949 forward_edges
950 .entry(src.into())
951 .or_default()
952 .push(dst.into());
953 backward_edges
954 .entry(dst.into())
955 .or_default()
956 .push(src.into());
957 }
958 (forward_edges, backward_edges)
959 }
960
961 fn to_hashset(ids: &[u32]) -> HashSet<FragmentId> {
963 ids.iter().map(|id| (*id).into()).collect()
964 }
965
    /// Construct a minimal `fragment::Model` for tests: an empty default
    /// stream node, no state tables, and the given distribution/parallelism.
    // `allow(deprecated)`: the model still carries the deprecated
    // `upstream_fragment_id` column, which must be populated here.
    #[allow(deprecated)]
    fn build_fragment(
        fragment_id: FragmentId,
        job_id: JobId,
        fragment_type_mask: i32,
        distribution_type: DistributionType,
        vnode_count: i32,
        parallelism: StreamingParallelism,
    ) -> fragment::Model {
        fragment::Model {
            fragment_id,
            job_id,
            fragment_type_mask,
            distribution_type,
            stream_node: StreamNode::from(&PbStreamNode::default()),
            state_table_ids: TableIdArray::default(),
            upstream_fragment_id: I32Array::default(),
            vnode_count,
            parallelism: Some(parallelism),
        }
    }
987
    // (relative actor index, worker, vnode indices if hash-distributed, split ids).
    type ActorState = (ActorId, WorkerId, Option<Vec<usize>>, Vec<String>);
989
    /// Normalize a fragment's actors into a sorted list of
    /// (relative index, worker, vnode positions, split ids) tuples so tests
    /// can assert on them independently of the allocated actor-id base.
    fn collect_actor_state(fragment: &InflightFragmentInfo) -> Vec<ActorState> {
        // The smallest actor id serves as the base for relative indices.
        let base = fragment.actors.keys().copied().min().unwrap_or_default();

        let mut entries: Vec<_> = fragment
            .actors
            .iter()
            .map(|(&actor_id, info)| {
                let idx = actor_id.as_raw_id() - base.as_raw_id();
                // Positions of set bits, when a vnode bitmap was assigned.
                let vnode_indices = info.vnode_bitmap.as_ref().map(|bitmap| {
                    bitmap
                        .iter()
                        .enumerate()
                        .filter_map(|(pos, is_set)| is_set.then_some(pos))
                        .collect::<Vec<_>>()
                });
                let splits = info
                    .splits
                    .iter()
                    .map(|split| split.id().to_string())
                    .collect::<Vec<_>>();
                (idx.into(), info.worker_id, vnode_indices, splits)
            })
            .collect();

        // Deterministic order for assertions.
        entries.sort_by_key(|(idx, _, _, _)| *idx);
        entries
    }
1017
1018 #[test]
1019 fn test_single_linear_chain() {
1020 let (forward, backward) = build_edges(&[(1, 2), (2, 3)]);
1023 let initial_ids = &[2];
1024
1025 let result = find_no_shuffle_graphs(initial_ids, &forward, &backward);
1027
1028 assert!(result.is_ok());
1030 let graphs = result.unwrap();
1031
1032 assert_eq!(graphs.len(), 1);
1033 let graph = &graphs[0];
1034 assert_eq!(graph.entries, to_hashset(&[1]));
1035 assert_eq!(graph.components, to_hashset(&[1, 2, 3]));
1036 }
1037
1038 #[test]
1039 fn test_two_disconnected_graphs() {
1040 let (forward, backward) = build_edges(&[(1, 2), (10, 11)]);
1043 let initial_ids = &[2, 10];
1044
1045 let mut graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1047
1048 assert_eq!(graphs.len(), 2);
1050
1051 graphs.sort_by_key(|g| *g.components.iter().min().unwrap_or(&0.into()));
1053
1054 assert_eq!(graphs[0].entries, to_hashset(&[1]));
1056 assert_eq!(graphs[0].components, to_hashset(&[1, 2]));
1057
1058 assert_eq!(graphs[1].entries, to_hashset(&[10]));
1060 assert_eq!(graphs[1].components, to_hashset(&[10, 11]));
1061 }
1062
1063 #[test]
1064 fn test_multiple_entries_in_one_graph() {
1065 let (forward, backward) = build_edges(&[(1, 3), (2, 3)]);
1067 let initial_ids = &[3];
1068
1069 let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1071
1072 assert_eq!(graphs.len(), 1);
1074 let graph = &graphs[0];
1075 assert_eq!(graph.entries, to_hashset(&[1, 2]));
1076 assert_eq!(graph.components, to_hashset(&[1, 2, 3]));
1077 }
1078
1079 #[test]
1080 fn test_diamond_shape_graph() {
1081 let (forward, backward) = build_edges(&[(1, 2), (1, 3), (2, 4), (3, 4)]);
1083 let initial_ids = &[4];
1084
1085 let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1087
1088 assert_eq!(graphs.len(), 1);
1090 let graph = &graphs[0];
1091 assert_eq!(graph.entries, to_hashset(&[1]));
1092 assert_eq!(graph.components, to_hashset(&[1, 2, 3, 4]));
1093 }
1094
1095 #[test]
1096 fn test_starting_with_multiple_nodes_in_same_graph() {
1097 let (forward, backward) = build_edges(&[(1, 2), (2, 3), (3, 4)]);
1100 let initial_ids = &[2, 4];
1101
1102 let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1104
1105 assert_eq!(graphs.len(), 1);
1107 let graph = &graphs[0];
1108 assert_eq!(graph.entries, to_hashset(&[1]));
1109 assert_eq!(graph.components, to_hashset(&[1, 2, 3, 4]));
1110 }
1111
1112 #[test]
1113 fn test_empty_initial_ids() {
1114 let (forward, backward) = build_edges(&[(1, 2)]);
1116 let initial_ids: &[u32] = &[];
1117
1118 let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1120
1121 assert!(graphs.is_empty());
1123 }
1124
1125 #[test]
1126 fn test_isolated_node_as_input() {
1127 let (forward, backward) = build_edges(&[(1, 2)]);
1129 let initial_ids = &[100];
1130
1131 let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1133
1134 assert_eq!(graphs.len(), 1);
1136 let graph = &graphs[0];
1137 assert_eq!(graph.entries, to_hashset(&[100]));
1138 assert_eq!(graph.components, to_hashset(&[100]));
1139 }
1140
1141 #[test]
1142 fn test_graph_with_a_cycle() {
1143 let (forward, backward) = build_edges(&[(1, 2), (2, 3), (3, 1)]);
1148 let initial_ids = &[2];
1149
1150 let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1152
1153 assert!(
1155 graphs.is_empty(),
1156 "A graph with no entries should not be returned"
1157 );
1158 }
1159 #[test]
1160 fn test_custom_complex() {
1161 let (forward, backward) = build_edges(&[(1, 3), (1, 8), (2, 3), (4, 3), (3, 5), (6, 7)]);
1162 let initial_ids = &[1, 2, 4, 6];
1163
1164 let mut graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();
1166
1167 assert_eq!(graphs.len(), 2);
1169 graphs.sort_by_key(|g| *g.components.iter().min().unwrap_or(&0.into()));
1171
1172 assert_eq!(graphs[0].entries, to_hashset(&[1, 2, 4]));
1174 assert_eq!(graphs[0].components, to_hashset(&[1, 2, 3, 4, 5, 8]));
1175
1176 assert_eq!(graphs[1].entries, to_hashset(&[6]));
1178 assert_eq!(graphs[1].components, to_hashset(&[6, 7]));
1179 }
1180
    #[test]
    fn render_actors_increments_actor_counter() {
        // Seed the counter at 100 so the exact increment is observable.
        let actor_id_counter = AtomicU32::new(100);
        let fragment_id: FragmentId = 1.into();
        let job_id: JobId = 10.into();
        let database_id: DatabaseId = DatabaseId::new(3);

        // One single-distribution fragment with fixed parallelism 1.
        let fragment_model = build_fragment(
            fragment_id,
            job_id,
            0,
            DistributionType::Single,
            1,
            StreamingParallelism::Fixed(1),
        );

        let job_model = streaming_job::Model {
            job_id,
            job_status: JobStatus::Created,
            create_type: CreateType::Foreground,
            timezone: None,
            config_override: None,
            parallelism: StreamingParallelism::Fixed(1),
            max_parallelism: 1,
            specific_resource_group: None,
        };

        // Job has no specific resource group, so rendering falls back to the
        // database's group ("rg-a").
        let database_model = database::Model {
            database_id,
            name: "test_db".into(),
            resource_group: "rg-a".into(),
            barrier_interval_ms: None,
            checkpoint_frequency: None,
        };

        // Singleton ensemble: the fragment is both entry and sole component.
        let ensembles = vec![NoShuffleEnsemble {
            entries: HashSet::from([fragment_id]),
            components: HashSet::from([fragment_id]),
        }];

        let fragment_map = HashMap::from([(fragment_id, fragment_model)]);
        let job_map = HashMap::from([(job_id, job_model)]);

        // One worker in the matching resource group, with a single slot.
        let worker_map = BTreeMap::from([(
            1.into(),
            WorkerInfo {
                parallelism: NonZeroUsize::new(1).unwrap(),
                resource_group: Some("rg-a".into()),
            },
        )]);

        // No sources involved: empty source/split bookkeeping.
        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
        let database_map = HashMap::from([(database_id, database_model)]);

        let context = RenderActorsContext {
            fragment_source_ids: &fragment_source_ids,
            fragment_splits: &fragment_splits,
            streaming_job_databases: &streaming_job_databases,
            database_map: &database_map,
        };

        let result = render_actors(
            &actor_id_counter,
            &ensembles,
            &fragment_map,
            &job_map,
            &worker_map,
            AdaptiveParallelismStrategy::Auto,
            context,
        )
        .expect("actor rendering succeeds");

        let state = collect_actor_state(&result[&database_id][&job_id][&fragment_id]);
        assert_eq!(state.len(), 1);
        assert!(
            state[0].2.is_none(),
            "single distribution should not assign vnode bitmaps"
        );
        // Exactly one actor was rendered, so the counter advanced by one.
        assert_eq!(actor_id_counter.load(Ordering::Relaxed), 101);
    }
1263
1264 #[test]
1265 fn render_actors_aligns_hash_vnode_bitmaps() {
1266 let actor_id_counter = AtomicU32::new(0);
1267 let entry_fragment_id: FragmentId = 1.into();
1268 let downstream_fragment_id: FragmentId = 2.into();
1269 let job_id: JobId = 20.into();
1270 let database_id: DatabaseId = DatabaseId::new(5);
1271
1272 let entry_fragment = build_fragment(
1273 entry_fragment_id,
1274 job_id,
1275 0,
1276 DistributionType::Hash,
1277 4,
1278 StreamingParallelism::Fixed(2),
1279 );
1280
1281 let downstream_fragment = build_fragment(
1282 downstream_fragment_id,
1283 job_id,
1284 0,
1285 DistributionType::Hash,
1286 4,
1287 StreamingParallelism::Fixed(2),
1288 );
1289
1290 let job_model = streaming_job::Model {
1291 job_id,
1292 job_status: JobStatus::Created,
1293 create_type: CreateType::Background,
1294 timezone: None,
1295 config_override: None,
1296 parallelism: StreamingParallelism::Fixed(2),
1297 max_parallelism: 2,
1298 specific_resource_group: None,
1299 };
1300
1301 let database_model = database::Model {
1302 database_id,
1303 name: "test_db_hash".into(),
1304 resource_group: "rg-hash".into(),
1305 barrier_interval_ms: None,
1306 checkpoint_frequency: None,
1307 };
1308
1309 let ensembles = vec![NoShuffleEnsemble {
1310 entries: HashSet::from([entry_fragment_id]),
1311 components: HashSet::from([entry_fragment_id, downstream_fragment_id]),
1312 }];
1313
1314 let fragment_map = HashMap::from([
1315 (entry_fragment_id, entry_fragment),
1316 (downstream_fragment_id, downstream_fragment),
1317 ]);
1318 let job_map = HashMap::from([(job_id, job_model)]);
1319
1320 let worker_map = BTreeMap::from([
1321 (
1322 1.into(),
1323 WorkerInfo {
1324 parallelism: NonZeroUsize::new(1).unwrap(),
1325 resource_group: Some("rg-hash".into()),
1326 },
1327 ),
1328 (
1329 2.into(),
1330 WorkerInfo {
1331 parallelism: NonZeroUsize::new(1).unwrap(),
1332 resource_group: Some("rg-hash".into()),
1333 },
1334 ),
1335 ]);
1336
1337 let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
1338 let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
1339 let streaming_job_databases = HashMap::from([(job_id, database_id)]);
1340 let database_map = HashMap::from([(database_id, database_model)]);
1341
1342 let context = RenderActorsContext {
1343 fragment_source_ids: &fragment_source_ids,
1344 fragment_splits: &fragment_splits,
1345 streaming_job_databases: &streaming_job_databases,
1346 database_map: &database_map,
1347 };
1348
1349 let result = render_actors(
1350 &actor_id_counter,
1351 &ensembles,
1352 &fragment_map,
1353 &job_map,
1354 &worker_map,
1355 AdaptiveParallelismStrategy::Auto,
1356 context,
1357 )
1358 .expect("actor rendering succeeds");
1359
1360 let entry_state = collect_actor_state(&result[&database_id][&job_id][&entry_fragment_id]);
1361 let downstream_state =
1362 collect_actor_state(&result[&database_id][&job_id][&downstream_fragment_id]);
1363
1364 assert_eq!(entry_state.len(), 2);
1365 assert_eq!(entry_state, downstream_state);
1366
1367 let assigned_vnodes: BTreeSet<_> = entry_state
1368 .iter()
1369 .flat_map(|(_, _, vnodes, _)| {
1370 vnodes
1371 .as_ref()
1372 .expect("hash distribution should populate vnode bitmap")
1373 .iter()
1374 .copied()
1375 })
1376 .collect();
1377 assert_eq!(assigned_vnodes, BTreeSet::from([0, 1, 2, 3]));
1378 assert_eq!(actor_id_counter.load(Ordering::Relaxed), 4);
1379 }
1380
1381 #[test]
1382 fn render_actors_propagates_source_splits() {
1383 let actor_id_counter = AtomicU32::new(0);
1384 let entry_fragment_id: FragmentId = 11.into();
1385 let downstream_fragment_id: FragmentId = 12.into();
1386 let job_id: JobId = 30.into();
1387 let database_id: DatabaseId = DatabaseId::new(7);
1388 let source_id: SourceId = 99.into();
1389
1390 let source_mask = FragmentTypeFlag::raw_flag([FragmentTypeFlag::Source]) as i32;
1391 let source_scan_mask = FragmentTypeFlag::raw_flag([FragmentTypeFlag::SourceScan]) as i32;
1392
1393 let entry_fragment = build_fragment(
1394 entry_fragment_id,
1395 job_id,
1396 source_mask,
1397 DistributionType::Hash,
1398 4,
1399 StreamingParallelism::Fixed(2),
1400 );
1401
1402 let downstream_fragment = build_fragment(
1403 downstream_fragment_id,
1404 job_id,
1405 source_scan_mask,
1406 DistributionType::Hash,
1407 4,
1408 StreamingParallelism::Fixed(2),
1409 );
1410
1411 let job_model = streaming_job::Model {
1412 job_id,
1413 job_status: JobStatus::Created,
1414 create_type: CreateType::Background,
1415 timezone: None,
1416 config_override: None,
1417 parallelism: StreamingParallelism::Fixed(2),
1418 max_parallelism: 2,
1419 specific_resource_group: None,
1420 };
1421
1422 let database_model = database::Model {
1423 database_id,
1424 name: "split_db".into(),
1425 resource_group: "rg-source".into(),
1426 barrier_interval_ms: None,
1427 checkpoint_frequency: None,
1428 };
1429
1430 let ensembles = vec![NoShuffleEnsemble {
1431 entries: HashSet::from([entry_fragment_id]),
1432 components: HashSet::from([entry_fragment_id, downstream_fragment_id]),
1433 }];
1434
1435 let fragment_map = HashMap::from([
1436 (entry_fragment_id, entry_fragment),
1437 (downstream_fragment_id, downstream_fragment),
1438 ]);
1439 let job_map = HashMap::from([(job_id, job_model)]);
1440
1441 let worker_map = BTreeMap::from([
1442 (
1443 1.into(),
1444 WorkerInfo {
1445 parallelism: NonZeroUsize::new(1).unwrap(),
1446 resource_group: Some("rg-source".into()),
1447 },
1448 ),
1449 (
1450 2.into(),
1451 WorkerInfo {
1452 parallelism: NonZeroUsize::new(1).unwrap(),
1453 resource_group: Some("rg-source".into()),
1454 },
1455 ),
1456 ]);
1457
1458 let split_a = SplitImpl::Test(TestSourceSplit {
1459 id: Arc::<str>::from("split-a"),
1460 properties: HashMap::new(),
1461 offset: "0".into(),
1462 });
1463 let split_b = SplitImpl::Test(TestSourceSplit {
1464 id: Arc::<str>::from("split-b"),
1465 properties: HashMap::new(),
1466 offset: "0".into(),
1467 });
1468
1469 let fragment_source_ids = HashMap::from([
1470 (entry_fragment_id, source_id),
1471 (downstream_fragment_id, source_id),
1472 ]);
1473 let fragment_splits =
1474 HashMap::from([(entry_fragment_id, vec![split_a.clone(), split_b.clone()])]);
1475 let streaming_job_databases = HashMap::from([(job_id, database_id)]);
1476 let database_map = HashMap::from([(database_id, database_model)]);
1477
1478 let context = RenderActorsContext {
1479 fragment_source_ids: &fragment_source_ids,
1480 fragment_splits: &fragment_splits,
1481 streaming_job_databases: &streaming_job_databases,
1482 database_map: &database_map,
1483 };
1484
1485 let result = render_actors(
1486 &actor_id_counter,
1487 &ensembles,
1488 &fragment_map,
1489 &job_map,
1490 &worker_map,
1491 AdaptiveParallelismStrategy::Auto,
1492 context,
1493 )
1494 .expect("actor rendering succeeds");
1495
1496 let entry_state = collect_actor_state(&result[&database_id][&job_id][&entry_fragment_id]);
1497 let downstream_state =
1498 collect_actor_state(&result[&database_id][&job_id][&downstream_fragment_id]);
1499
1500 assert_eq!(entry_state, downstream_state);
1501
1502 let split_ids: BTreeSet<_> = entry_state
1503 .iter()
1504 .flat_map(|(_, _, _, splits)| splits.iter().cloned())
1505 .collect();
1506 assert_eq!(
1507 split_ids,
1508 BTreeSet::from([split_a.id().to_string(), split_b.id().to_string()])
1509 );
1510 assert_eq!(actor_id_counter.load(Ordering::Relaxed), 4);
1511 }
1512}