use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicU32, Ordering};

use anyhow::anyhow;
use itertools::Itertools;
use risingwave_common::bitmap::Bitmap;
use risingwave_common::catalog;
use risingwave_common::catalog::{FragmentTypeFlag, FragmentTypeMask};
use risingwave_common::system_param::AdaptiveParallelismStrategy;
use risingwave_common::util::worker_util::DEFAULT_RESOURCE_GROUP;
use risingwave_connector::source::{SplitImpl, SplitMetaData};
use risingwave_meta_model::fragment::DistributionType;
use risingwave_meta_model::prelude::{
    Database, Fragment, FragmentRelation, FragmentSplits, Sink, Source, StreamingJob, Table,
};
use risingwave_meta_model::{
    DatabaseId, DispatcherType, FragmentId, ObjectId, SourceId, StreamingParallelism, TableId,
    WorkerId, database, fragment, fragment_relation, fragment_splits, object, sink, source,
    streaming_job, table,
};
use risingwave_meta_model_migration::Condition;
use sea_orm::{
    ColumnTrait, ConnectionTrait, EntityTrait, JoinType, QueryFilter, QuerySelect, QueryTrait,
    RelationTrait,
};

use crate::MetaResult;
use crate::controller::fragment::{InflightActorInfo, InflightFragmentInfo};
use crate::manager::ActiveStreamingWorkerNodes;
use crate::model::{ActorId, StreamActor};
use crate::stream::{AssignerBuilder, SplitDiffOptions};

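/// Resolves the SQL definition for each given streaming job id by probing the `table`,
/// `sink`, and `source` catalogs in turn. Debug builds fetch the full `definition`
/// column, while release builds fetch only the object `name`.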
pub(crate) async fn resolve_streaming_job_definition<C>(
    txn: &C,
    job_ids: &HashSet<ObjectId>,
) -> MetaResult<HashMap<ObjectId, String>>
where
    C: ConnectionTrait,
{
    let job_ids = job_ids.iter().cloned().collect_vec();

    let common_job_definitions: Vec<(ObjectId, String)> = Table::find()
        .select_only()
        .columns([
            table::Column::TableId,
            #[cfg(not(debug_assertions))]
            table::Column::Name,
            #[cfg(debug_assertions)]
            table::Column::Definition,
        ])
        .filter(table::Column::TableId.is_in(job_ids.clone()))
        .into_tuple()
        .all(txn)
        .await?;

    let sink_definitions: Vec<(ObjectId, String)> = Sink::find()
        .select_only()
        .columns([
            sink::Column::SinkId,
            #[cfg(not(debug_assertions))]
            sink::Column::Name,
            #[cfg(debug_assertions)]
            sink::Column::Definition,
        ])
        .filter(sink::Column::SinkId.is_in(job_ids.clone()))
        .into_tuple()
        .all(txn)
        .await?;

    let source_definitions: Vec<(ObjectId, String)> = Source::find()
        .select_only()
        .columns([
            source::Column::SourceId,
            #[cfg(not(debug_assertions))]
            source::Column::Name,
            #[cfg(debug_assertions)]
            source::Column::Definition,
        ])
        .filter(source::Column::SourceId.is_in(job_ids.clone()))
        .into_tuple()
        .all(txn)
        .await?;

    let definitions: HashMap<ObjectId, String> = common_job_definitions
        .into_iter()
        .chain(sink_definitions.into_iter())
        .chain(source_definitions.into_iter())
        .collect();

    Ok(definitions)
}

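/// Loads and renders the in-flight fragment info of all streaming jobs, optionally
/// scoped to a single database. Only workers that are currently schedulable for
/// streaming are offered to the assigner.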
pub async fn load_fragment_info<C>(
    txn: &C,
    actor_id_counter: &AtomicU32,
    database_id: Option<DatabaseId>,
    worker_nodes: &ActiveStreamingWorkerNodes,
    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
) -> MetaResult<FragmentRenderMap>
where
    C: ConnectionTrait,
{
    let mut query = StreamingJob::find()
        .select_only()
        .column(streaming_job::Column::JobId);

    if let Some(database_id) = database_id {
        query = query
            .join(JoinType::InnerJoin, streaming_job::Relation::Object.def())
            .filter(object::Column::DatabaseId.eq(database_id));
    }

    let jobs: Vec<ObjectId> = query.into_tuple().all(txn).await?;

    if jobs.is_empty() {
        return Ok(HashMap::new());
    }

    let jobs: HashSet<ObjectId> = jobs.into_iter().collect();

    let available_workers: BTreeMap<_, _> = worker_nodes
        .current()
        .values()
        .filter(|worker| worker.is_streaming_schedulable())
        .map(|worker| {
            (
                worker.id as i32,
                WorkerInfo {
                    parallelism: NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
                    resource_group: worker.resource_group(),
                },
            )
        })
        .collect();

    let RenderedGraph { fragments, .. } = render_jobs(
        txn,
        actor_id_counter,
        jobs,
        available_workers,
        adaptive_parallelism_strategy,
    )
    .await?;

    Ok(fragments)
}

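/// An optional target resource group paired with the desired streaming parallelism.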
#[derive(Debug)]
pub struct TargetResourcePolicy {
    pub resource_group: Option<String>,
    pub parallelism: StreamingParallelism,
}

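/// Scheduling-relevant view of a worker: its compute parallelism and the resource
/// group it belongs to, if any.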
#[derive(Debug, Clone)]
pub struct WorkerInfo {
    pub parallelism: NonZeroUsize,
    pub resource_group: Option<String>,
}

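/// Rendered fragment info keyed by database id, then job id, then fragment id.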
pub type FragmentRenderMap =
    HashMap<DatabaseId, HashMap<TableId, HashMap<FragmentId, InflightFragmentInfo>>>;

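/// Output of rendering: the fragment map together with the no-shuffle ensembles it
/// was derived from.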
#[derive(Default)]
pub struct RenderedGraph {
    pub fragments: FragmentRenderMap,
    pub ensembles: Vec<NoShuffleEnsemble>,
}

impl RenderedGraph {
    pub fn empty() -> Self {
        Self::default()
    }
}

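/// Renders the given no-shuffle ensembles into fragment info, first validating that
/// every referenced fragment and its owning streaming job exist in the catalog.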
pub async fn render_fragments<C>(
    txn: &C,
    actor_id_counter: &AtomicU32,
    ensembles: Vec<NoShuffleEnsemble>,
    workers: BTreeMap<WorkerId, WorkerInfo>,
    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
) -> MetaResult<RenderedGraph>
where
    C: ConnectionTrait,
{
    if ensembles.is_empty() {
        return Ok(RenderedGraph::empty());
    }

    let required_fragment_ids: HashSet<_> = ensembles
        .iter()
        .flat_map(|ensemble| ensemble.components.iter().copied())
        .collect();

    let fragment_models = Fragment::find()
        .filter(fragment::Column::FragmentId.is_in(required_fragment_ids.iter().copied()))
        .all(txn)
        .await?;

    let found_fragment_ids: HashSet<_> = fragment_models
        .iter()
        .map(|fragment| fragment.fragment_id)
        .collect();

    if found_fragment_ids.len() != required_fragment_ids.len() {
        let missing = required_fragment_ids
            .difference(&found_fragment_ids)
            .copied()
            .collect_vec();
        return Err(anyhow!("fragments {:?} not found", missing).into());
    }

    let fragment_map: HashMap<_, _> = fragment_models
        .into_iter()
        .map(|fragment| (fragment.fragment_id, fragment))
        .collect();

    let job_ids: HashSet<_> = fragment_map
        .values()
        .map(|fragment| fragment.job_id)
        .collect();

    if job_ids.is_empty() {
        return Ok(RenderedGraph::empty());
    }

    let jobs: HashMap<_, _> = StreamingJob::find()
        .filter(streaming_job::Column::JobId.is_in(job_ids.iter().copied().collect_vec()))
        .all(txn)
        .await?
        .into_iter()
        .map(|job| (job.job_id, job))
        .collect();

    let found_job_ids: HashSet<_> = jobs.keys().copied().collect();
    if found_job_ids.len() != job_ids.len() {
        let missing = job_ids.difference(&found_job_ids).copied().collect_vec();
        return Err(anyhow!("streaming jobs {:?} not found", missing).into());
    }

    let fragments = render_no_shuffle_ensembles(
        txn,
        actor_id_counter,
        &ensembles,
        &fragment_map,
        &jobs,
        &workers,
        adaptive_parallelism_strategy,
    )
    .await?;

    Ok(RenderedGraph {
        fragments,
        ensembles,
    })
}

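/// Renders all fragments of the given jobs. Fragments that sit downstream of a
/// no-shuffle dispatcher are excluded from the seed set, since they are rendered as
/// part of the ensemble rooted at their entry fragments.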
pub async fn render_jobs<C>(
    txn: &C,
    actor_id_counter: &AtomicU32,
    job_ids: HashSet<ObjectId>,
    workers: BTreeMap<WorkerId, WorkerInfo>,
    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
) -> MetaResult<RenderedGraph>
where
    C: ConnectionTrait,
{
    let excluded_fragments_query = FragmentRelation::find()
        .select_only()
        .column(fragment_relation::Column::TargetFragmentId)
        .filter(fragment_relation::Column::DispatcherType.eq(DispatcherType::NoShuffle))
        .into_query();

    let condition = Condition::all()
        .add(fragment::Column::JobId.is_in(job_ids.clone()))
        .add(fragment::Column::FragmentId.not_in_subquery(excluded_fragments_query));

    let fragments: Vec<FragmentId> = Fragment::find()
        .select_only()
        .column(fragment::Column::FragmentId)
        .filter(condition)
        .into_tuple()
        .all(txn)
        .await?;

    let ensembles = find_fragment_no_shuffle_dags_detailed(txn, &fragments).await?;

    let fragments = Fragment::find()
        .filter(
            fragment::Column::FragmentId.is_in(
                ensembles
                    .iter()
                    .flat_map(|graph| graph.components.iter())
                    .cloned()
                    .collect_vec(),
            ),
        )
        .all(txn)
        .await?;

    let fragment_map: HashMap<_, _> = fragments
        .into_iter()
        .map(|fragment| (fragment.fragment_id, fragment))
        .collect();

    let job_ids = fragment_map
        .values()
        .map(|fragment| fragment.job_id)
        .collect::<BTreeSet<_>>()
        .into_iter()
        .collect_vec();

    let jobs: HashMap<_, _> = StreamingJob::find()
        .filter(streaming_job::Column::JobId.is_in(job_ids))
        .all(txn)
        .await?
        .into_iter()
        .map(|job| (job.job_id, job))
        .collect();

    let fragments = render_no_shuffle_ensembles(
        txn,
        actor_id_counter,
        &ensembles,
        &fragment_map,
        &jobs,
        &workers,
        adaptive_parallelism_strategy,
    )
    .await?;

    Ok(RenderedGraph {
        fragments,
        ensembles,
    })
}

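/// Resolves source ids, persisted splits, and database ownership for the ensembles,
/// then delegates the actual actor placement to [`render_actors`].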
async fn render_no_shuffle_ensembles<C>(
    txn: &C,
    actor_id_counter: &AtomicU32,
    ensembles: &[NoShuffleEnsemble],
    fragment_map: &HashMap<FragmentId, fragment::Model>,
    job_map: &HashMap<ObjectId, streaming_job::Model>,
    worker_map: &BTreeMap<WorkerId, WorkerInfo>,
    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
) -> MetaResult<FragmentRenderMap>
where
    C: ConnectionTrait,
{
    if ensembles.is_empty() {
        return Ok(HashMap::new());
    }

    #[cfg(debug_assertions)]
    {
        debug_sanity_check(ensembles, fragment_map, job_map);
    }

    let (fragment_source_ids, fragment_splits) =
        resolve_source_fragments(txn, fragment_map).await?;

    let job_ids = job_map.keys().copied().collect_vec();

    let streaming_job_databases: HashMap<ObjectId, _> = StreamingJob::find()
        .select_only()
        .column(streaming_job::Column::JobId)
        .column(object::Column::DatabaseId)
        .join(JoinType::LeftJoin, streaming_job::Relation::Object.def())
        .filter(streaming_job::Column::JobId.is_in(job_ids))
        .into_tuple()
        .all(txn)
        .await?
        .into_iter()
        .collect();

    let database_map: HashMap<_, _> = Database::find()
        .filter(
            database::Column::DatabaseId
                .is_in(streaming_job_databases.values().copied().collect_vec()),
        )
        .all(txn)
        .await?
        .into_iter()
        .map(|db| (db.database_id, db))
        .collect();

    let context = RenderActorsContext {
        fragment_source_ids: &fragment_source_ids,
        fragment_splits: &fragment_splits,
        streaming_job_databases: &streaming_job_databases,
        database_map: &database_map,
    };

    render_actors(
        actor_id_counter,
        ensembles,
        fragment_map,
        job_map,
        worker_map,
        adaptive_parallelism_strategy,
        context,
    )
}

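/// Pre-resolved lookups passed to [`render_actors`], bundled to keep its signature
/// manageable.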
struct RenderActorsContext<'a> {
    fragment_source_ids: &'a HashMap<FragmentId, SourceId>,
    fragment_splits: &'a HashMap<FragmentId, Vec<SplitImpl>>,
    streaming_job_databases: &'a HashMap<ObjectId, DatabaseId>,
    database_map: &'a HashMap<DatabaseId, database::Model>,
}

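/// Renders actors for each ensemble: derives the effective parallelism from the entry
/// fragments' parallelism override (falling back to the job's), capped by the vnode
/// count and the job's `max_parallelism`; assigns actors and vnodes to the workers of
/// the job's resource group; and reassigns source splits on the entry source fragment
/// so that every fragment in the ensemble reading the same source observes the same
/// per-actor split distribution.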
fn render_actors(
    actor_id_counter: &AtomicU32,
    ensembles: &[NoShuffleEnsemble],
    fragment_map: &HashMap<FragmentId, fragment::Model>,
    job_map: &HashMap<ObjectId, streaming_job::Model>,
    worker_map: &BTreeMap<WorkerId, WorkerInfo>,
    adaptive_parallelism_strategy: AdaptiveParallelismStrategy,
    context: RenderActorsContext<'_>,
) -> MetaResult<FragmentRenderMap> {
    let RenderActorsContext {
        fragment_source_ids,
        fragment_splits: fragment_splits_map,
        streaming_job_databases,
        database_map,
    } = context;

    let mut all_fragments: FragmentRenderMap = HashMap::new();

    for NoShuffleEnsemble {
        entries,
        components,
    } in ensembles
    {
        tracing::debug!("rendering ensemble entries {:?}", entries);

        let entry_fragments = entries
            .iter()
            .map(|fragment_id| fragment_map.get(fragment_id).unwrap())
            .collect_vec();

        let entry_fragment_parallelism = entry_fragments
            .iter()
            .map(|fragment| fragment.parallelism.clone())
            .dedup()
            .exactly_one()
            .map_err(|_| {
                anyhow!(
                    "entry fragments {:?} have inconsistent parallelism settings",
                    entries.iter().copied().collect_vec()
                )
            })?;

        let (job_id, vnode_count) = entry_fragments
            .iter()
            .map(|f| (f.job_id, f.vnode_count as usize))
            .dedup()
            .exactly_one()
            .map_err(|_| anyhow!("Multiple jobs found in no-shuffle ensemble"))?;

        let job = job_map
            .get(&job_id)
            .ok_or_else(|| anyhow!("streaming job {job_id} not found"))?;

        let resource_group = match &job.specific_resource_group {
            None => {
                let database = streaming_job_databases
                    .get(&job_id)
                    .and_then(|database_id| database_map.get(database_id))
                    .unwrap();
                database.resource_group.clone()
            }
            Some(resource_group) => resource_group.clone(),
        };

        let available_workers: BTreeMap<WorkerId, NonZeroUsize> = worker_map
            .iter()
            .filter_map(|(worker_id, worker)| {
                if worker
                    .resource_group
                    .as_deref()
                    .unwrap_or(DEFAULT_RESOURCE_GROUP)
                    == resource_group.as_str()
                {
                    Some((*worker_id, worker.parallelism))
                } else {
                    None
                }
            })
            .collect();

        let total_parallelism = available_workers.values().map(|w| w.get()).sum::<usize>();

        let actual_parallelism = match entry_fragment_parallelism
            .as_ref()
            .unwrap_or(&job.parallelism)
        {
            StreamingParallelism::Adaptive | StreamingParallelism::Custom => {
                adaptive_parallelism_strategy.compute_target_parallelism(total_parallelism)
            }
            StreamingParallelism::Fixed(n) => *n,
        }
        .min(vnode_count)
        .min(job.max_parallelism as usize);

        tracing::debug!(
            "job {}, final {} parallelism {:?} total_parallelism {} job_max {} vnode count {} fragment_override {:?}",
            job_id,
            actual_parallelism,
            job.parallelism,
            total_parallelism,
            job.max_parallelism,
            vnode_count,
            entry_fragment_parallelism
        );

        let assigner = AssignerBuilder::new(job_id).build();

        let actors = (0..actual_parallelism).collect_vec();
        let vnodes = (0..vnode_count).collect_vec();

        let assignment = assigner.assign_hierarchical(&available_workers, &actors, &vnodes)?;

        let source_entry_fragment = entry_fragments.iter().find(|f| {
            let mask = FragmentTypeMask::from(f.fragment_type_mask);
            if mask.contains(FragmentTypeFlag::Source) {
                assert!(!mask.contains(FragmentTypeFlag::SourceScan))
            }
            mask.contains(FragmentTypeFlag::Source) && !mask.contains(FragmentTypeFlag::Dml)
        });

        let (fragment_splits, shared_source_id) = match source_entry_fragment {
            Some(entry_fragment) => {
                let source_id = fragment_source_ids
                    .get(&entry_fragment.fragment_id)
                    .ok_or_else(|| {
                        anyhow!(
                            "missing source id in source fragment {}",
                            entry_fragment.fragment_id
                        )
                    })?;

                let entry_fragment_id = entry_fragment.fragment_id;

                let empty_actor_splits: HashMap<_, _> = actors
                    .iter()
                    .map(|actor_id| (*actor_id as u32, vec![]))
                    .collect();

                let splits = fragment_splits_map
                    .get(&entry_fragment_id)
                    .cloned()
                    .unwrap_or_default();

                let splits: BTreeMap<_, _> = splits.into_iter().map(|s| (s.id(), s)).collect();

                let fragment_splits = crate::stream::source_manager::reassign_splits(
                    entry_fragment_id as u32,
                    empty_actor_splits,
                    &splits,
                    SplitDiffOptions::default(),
                )
                .unwrap_or_default();
                (fragment_splits, Some(*source_id))
            }
            None => (HashMap::new(), None),
        };

        for component_fragment_id in components {
            let &fragment::Model {
                fragment_id,
                job_id,
                fragment_type_mask,
                distribution_type,
                ref stream_node,
                ref state_table_ids,
                ..
            } = fragment_map.get(component_fragment_id).unwrap();

            let actor_count =
                u32::try_from(actors.len()).expect("actor parallelism exceeds u32::MAX");
            let actor_id_base = actor_id_counter.fetch_add(actor_count, Ordering::Relaxed);

            let actors: HashMap<ActorId, InflightActorInfo> = assignment
                .iter()
                .flat_map(|(worker_id, actors)| {
                    actors
                        .iter()
                        .map(move |(actor_id, vnodes)| (worker_id, actor_id, vnodes))
                })
                .map(|(&worker_id, &actor_idx, vnodes)| {
                    let vnode_bitmap = match distribution_type {
                        DistributionType::Single => None,
                        DistributionType::Hash => Some(Bitmap::from_indices(vnode_count, vnodes)),
                    };

                    let actor_id = actor_id_base + actor_idx as u32;

                    let splits = if let Some(source_id) = fragment_source_ids.get(&fragment_id) {
                        assert_eq!(shared_source_id, Some(*source_id));

                        fragment_splits
                            .get(&(actor_idx as u32))
                            .cloned()
                            .unwrap_or_default()
                    } else {
                        vec![]
                    };

                    (
                        actor_id,
                        InflightActorInfo {
                            worker_id,
                            vnode_bitmap,
                            splits,
                        },
                    )
                })
                .collect();

            let fragment = InflightFragmentInfo {
                fragment_id: fragment_id as u32,
                distribution_type,
                fragment_type_mask: fragment_type_mask.into(),
                vnode_count,
                nodes: stream_node.to_protobuf(),
                actors,
                state_table_ids: state_table_ids
                    .inner_ref()
                    .iter()
                    .map(|id| catalog::TableId::new(*id as _))
                    .collect(),
            };

            let &database_id = streaming_job_databases.get(&job_id).ok_or_else(|| {
                anyhow!("streaming job {job_id} not found in streaming_job_databases")
            })?;

            all_fragments
                .entry(database_id)
                .or_default()
                .entry(job_id)
                .or_default()
                .insert(fragment_id, fragment);
        }
    }

    Ok(all_fragments)
}

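/// Debug-only invariant checks: entries must be a subset of components, every
/// component fragment and its owning job must be present in the maps, and all
/// components of an ensemble must share the same vnode count.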
#[cfg(debug_assertions)]
fn debug_sanity_check(
    ensembles: &[NoShuffleEnsemble],
    fragment_map: &HashMap<FragmentId, fragment::Model>,
    jobs: &HashMap<ObjectId, streaming_job::Model>,
) {
    debug_assert!(
        ensembles
            .iter()
            .all(|ensemble| ensemble.entries.is_subset(&ensemble.components)),
        "entries must be subset of components"
    );

    let mut missing_fragments = BTreeSet::new();
    let mut missing_jobs = BTreeSet::new();

    for fragment_id in ensembles
        .iter()
        .flat_map(|ensemble| ensemble.components.iter())
    {
        match fragment_map.get(fragment_id) {
            Some(fragment) => {
                if !jobs.contains_key(&fragment.job_id) {
                    missing_jobs.insert(fragment.job_id);
                }
            }
            None => {
                missing_fragments.insert(*fragment_id);
            }
        }
    }

    debug_assert!(
        missing_fragments.is_empty(),
        "missing fragments in fragment_map: {:?}",
        missing_fragments
    );

    debug_assert!(
        missing_jobs.is_empty(),
        "missing jobs for fragments' job_id: {:?}",
        missing_jobs
    );

    for ensemble in ensembles {
        let unique_vnode_counts: Vec<_> = ensemble
            .components
            .iter()
            .flat_map(|fragment_id| {
                fragment_map
                    .get(fragment_id)
                    .map(|fragment| fragment.vnode_count)
            })
            .unique()
            .collect();

        debug_assert!(
            unique_vnode_counts.len() <= 1,
            "components in ensemble must share same vnode_count: ensemble={:?}, vnode_counts={:?}",
            ensemble.components,
            unique_vnode_counts
        );
    }
}

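/// Scans the fragments' stream nodes for `Source` and `SourceScan` operators and
/// returns, per fragment, the source id it reads from and the currently persisted
/// splits of that fragment.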
async fn resolve_source_fragments<C>(
    txn: &C,
    fragment_map: &HashMap<FragmentId, fragment::Model>,
) -> MetaResult<(
    HashMap<FragmentId, SourceId>,
    HashMap<FragmentId, Vec<SplitImpl>>,
)>
where
    C: ConnectionTrait,
{
    let mut source_fragment_ids = HashMap::new();
    for (fragment_id, fragment) in fragment_map {
        let mask = FragmentTypeMask::from(fragment.fragment_type_mask);
        if mask.contains(FragmentTypeFlag::Source)
            && let Some(source_id) = fragment.stream_node.to_protobuf().find_stream_source()
        {
            source_fragment_ids
                .entry(source_id)
                .or_insert_with(BTreeSet::new)
                .insert(fragment_id);
        }

        if mask.contains(FragmentTypeFlag::SourceScan)
            && let Some((source_id, _)) = fragment.stream_node.to_protobuf().find_source_backfill()
        {
            source_fragment_ids
                .entry(source_id)
                .or_insert_with(BTreeSet::new)
                .insert(fragment_id);
        }
    }

    let fragment_source_ids: HashMap<_, _> = source_fragment_ids
        .iter()
        .flat_map(|(source_id, fragment_ids)| {
            fragment_ids
                .iter()
                .map(|fragment_id| (**fragment_id, *source_id as SourceId))
        })
        .collect();

    let fragment_ids = fragment_source_ids.keys().copied().collect_vec();

    let fragment_splits: Vec<_> = FragmentSplits::find()
        .filter(fragment_splits::Column::FragmentId.is_in(fragment_ids))
        .all(txn)
        .await?;

    let fragment_splits: HashMap<_, _> = fragment_splits
        .into_iter()
        .flat_map(|model| {
            model.splits.map(|splits| {
                (
                    model.fragment_id,
                    splits
                        .to_protobuf()
                        .splits
                        .iter()
                        .flat_map(SplitImpl::try_from)
                        .collect_vec(),
                )
            })
        })
        .collect();

    Ok((fragment_source_ids, fragment_splits))
}

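/// Borrowed view of a rendered actor graph: fragments with their actors, plus each
/// actor's worker location.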
#[derive(Debug)]
pub struct ActorGraph<'a> {
    pub fragments: &'a HashMap<FragmentId, (Fragment, Vec<StreamActor>)>,
    pub locations: &'a HashMap<ActorId, WorkerId>,
}

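/// A connected component of the no-shuffle relation graph. `entries` are the roots,
/// i.e. fragments with no no-shuffle upstream inside the component; `components` is
/// the full closure.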
#[derive(Debug, Clone)]
pub struct NoShuffleEnsemble {
    entries: HashSet<FragmentId>,
    components: HashSet<FragmentId>,
}

impl NoShuffleEnsemble {
    pub fn fragments(&self) -> impl Iterator<Item = FragmentId> + '_ {
        self.components.iter().cloned()
    }

    pub fn entry_fragments(&self) -> impl Iterator<Item = FragmentId> + '_ {
        self.entries.iter().copied()
    }

    pub fn component_fragments(&self) -> impl Iterator<Item = FragmentId> + '_ {
        self.components.iter().copied()
    }

    pub fn contains_entry(&self, fragment_id: &FragmentId) -> bool {
        self.entries.contains(fragment_id)
    }
}

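/// Loads every no-shuffle fragment relation and groups the given fragments into the
/// connected components (ensembles) they belong to.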
pub async fn find_fragment_no_shuffle_dags_detailed(
    db: &impl ConnectionTrait,
    initial_fragment_ids: &[FragmentId],
) -> MetaResult<Vec<NoShuffleEnsemble>> {
    let all_no_shuffle_relations: Vec<(_, _)> = FragmentRelation::find()
        .columns([
            fragment_relation::Column::SourceFragmentId,
            fragment_relation::Column::TargetFragmentId,
        ])
        .filter(fragment_relation::Column::DispatcherType.eq(DispatcherType::NoShuffle))
        .into_tuple()
        .all(db)
        .await?;

    let mut forward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
    let mut backward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();

    for (src, dst) in all_no_shuffle_relations {
        forward_edges.entry(src).or_default().push(dst);
        backward_edges.entry(dst).or_default().push(src);
    }

    find_no_shuffle_graphs(initial_fragment_ids, &forward_edges, &backward_edges)
}

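/// BFS over the union of forward and backward no-shuffle edges, starting from each
/// not-yet-visited initial fragment. Components that contain no root entry (e.g. a
/// pure cycle) are dropped from the result.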
fn find_no_shuffle_graphs(
    initial_fragment_ids: &[FragmentId],
    forward_edges: &HashMap<FragmentId, Vec<FragmentId>>,
    backward_edges: &HashMap<FragmentId, Vec<FragmentId>>,
) -> MetaResult<Vec<NoShuffleEnsemble>> {
    let mut graphs: Vec<NoShuffleEnsemble> = Vec::new();
    let mut globally_visited: HashSet<FragmentId> = HashSet::new();

    for &init_id in initial_fragment_ids {
        if globally_visited.contains(&init_id) {
            continue;
        }

        let mut components = HashSet::new();
        let mut queue: VecDeque<FragmentId> = VecDeque::new();

        queue.push_back(init_id);
        globally_visited.insert(init_id);

        while let Some(current_id) = queue.pop_front() {
            components.insert(current_id);
            let neighbors = forward_edges
                .get(&current_id)
                .into_iter()
                .flatten()
                .chain(backward_edges.get(&current_id).into_iter().flatten());

            for &neighbor_id in neighbors {
                if globally_visited.insert(neighbor_id) {
                    queue.push_back(neighbor_id);
                }
            }
        }

        let mut entries = HashSet::new();
        for &node_id in &components {
            let is_root = match backward_edges.get(&node_id) {
                Some(parents) => parents.iter().all(|p| !components.contains(p)),
                None => true,
            };
            if is_root {
                entries.insert(node_id);
            }
        }

        if !entries.is_empty() {
            graphs.push(NoShuffleEnsemble {
                entries,
                components,
            });
        }
    }

    Ok(graphs)
}

#[cfg(test)]
mod tests {
    use std::collections::{BTreeSet, HashMap, HashSet};
    use std::sync::Arc;

    use risingwave_connector::source::SplitImpl;
    use risingwave_connector::source::test_source::TestSourceSplit;
    use risingwave_meta_model::{CreateType, I32Array, JobStatus, StreamNode};
    use risingwave_pb::stream_plan::StreamNode as PbStreamNode;

    use super::*;

    type Edges = (
        HashMap<FragmentId, Vec<FragmentId>>,
        HashMap<FragmentId, Vec<FragmentId>>,
    );

    fn build_edges(relations: &[(FragmentId, FragmentId)]) -> Edges {
        let mut forward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
        let mut backward_edges: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
        for &(src, dst) in relations {
            forward_edges.entry(src).or_default().push(dst);
            backward_edges.entry(dst).or_default().push(src);
        }
        (forward_edges, backward_edges)
    }

    fn to_hashset(ids: &[FragmentId]) -> HashSet<FragmentId> {
        ids.iter().cloned().collect()
    }

    #[allow(deprecated)]
    fn build_fragment(
        fragment_id: FragmentId,
        job_id: ObjectId,
        fragment_type_mask: i32,
        distribution_type: DistributionType,
        vnode_count: i32,
        parallelism: StreamingParallelism,
    ) -> fragment::Model {
        fragment::Model {
            fragment_id,
            job_id,
            fragment_type_mask,
            distribution_type,
            stream_node: StreamNode::from(&PbStreamNode::default()),
            state_table_ids: I32Array::default(),
            upstream_fragment_id: I32Array::default(),
            vnode_count,
            parallelism: Some(parallelism),
        }
    }

    type ActorState = (u32, WorkerId, Option<Vec<usize>>, Vec<String>);

    fn collect_actor_state(fragment: &InflightFragmentInfo) -> Vec<ActorState> {
        let base = fragment.actors.keys().copied().min().unwrap_or_default();

        let mut entries: Vec<_> = fragment
            .actors
            .iter()
            .map(|(&actor_id, info)| {
                let idx = actor_id - base;
                let vnode_indices = info.vnode_bitmap.as_ref().map(|bitmap| {
                    bitmap
                        .iter()
                        .enumerate()
                        .filter_map(|(pos, is_set)| is_set.then_some(pos))
                        .collect::<Vec<_>>()
                });
                let splits = info
                    .splits
                    .iter()
                    .map(|split| split.id().to_string())
                    .collect::<Vec<_>>();
                (idx, info.worker_id, vnode_indices, splits)
            })
            .collect();

        entries.sort_by_key(|(idx, _, _, _)| *idx);
        entries
    }

    #[test]
    fn test_single_linear_chain() {
        let (forward, backward) = build_edges(&[(1, 2), (2, 3)]);
        let initial_ids = &[2];

        let result = find_no_shuffle_graphs(initial_ids, &forward, &backward);

        assert!(result.is_ok());
        let graphs = result.unwrap();

        assert_eq!(graphs.len(), 1);
        let graph = &graphs[0];
        assert_eq!(graph.entries, to_hashset(&[1]));
        assert_eq!(graph.components, to_hashset(&[1, 2, 3]));
    }

    #[test]
    fn test_two_disconnected_graphs() {
        let (forward, backward) = build_edges(&[(1, 2), (10, 11)]);
        let initial_ids = &[2, 10];

        let mut graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();

        assert_eq!(graphs.len(), 2);

        graphs.sort_by_key(|g| *g.components.iter().min().unwrap_or(&0));

        assert_eq!(graphs[0].entries, to_hashset(&[1]));
        assert_eq!(graphs[0].components, to_hashset(&[1, 2]));

        assert_eq!(graphs[1].entries, to_hashset(&[10]));
        assert_eq!(graphs[1].components, to_hashset(&[10, 11]));
    }

    #[test]
    fn test_multiple_entries_in_one_graph() {
        let (forward, backward) = build_edges(&[(1, 3), (2, 3)]);
        let initial_ids = &[3];

        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();

        assert_eq!(graphs.len(), 1);
        let graph = &graphs[0];
        assert_eq!(graph.entries, to_hashset(&[1, 2]));
        assert_eq!(graph.components, to_hashset(&[1, 2, 3]));
    }

    #[test]
    fn test_diamond_shape_graph() {
        let (forward, backward) = build_edges(&[(1, 2), (1, 3), (2, 4), (3, 4)]);
        let initial_ids = &[4];

        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();

        assert_eq!(graphs.len(), 1);
        let graph = &graphs[0];
        assert_eq!(graph.entries, to_hashset(&[1]));
        assert_eq!(graph.components, to_hashset(&[1, 2, 3, 4]));
    }

    #[test]
    fn test_starting_with_multiple_nodes_in_same_graph() {
        let (forward, backward) = build_edges(&[(1, 2), (2, 3), (3, 4)]);
        let initial_ids = &[2, 4];

        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();

        assert_eq!(graphs.len(), 1);
        let graph = &graphs[0];
        assert_eq!(graph.entries, to_hashset(&[1]));
        assert_eq!(graph.components, to_hashset(&[1, 2, 3, 4]));
    }

    #[test]
    fn test_empty_initial_ids() {
        let (forward, backward) = build_edges(&[(1, 2)]);
        let initial_ids = &[];

        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();

        assert!(graphs.is_empty());
    }

    #[test]
    fn test_isolated_node_as_input() {
        let (forward, backward) = build_edges(&[(1, 2)]);
        let initial_ids = &[100];

        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();

        assert_eq!(graphs.len(), 1);
        let graph = &graphs[0];
        assert_eq!(graph.entries, to_hashset(&[100]));
        assert_eq!(graph.components, to_hashset(&[100]));
    }

    #[test]
    fn test_graph_with_a_cycle() {
        let (forward, backward) = build_edges(&[(1, 2), (2, 3), (3, 1)]);
        let initial_ids = &[2];

        let graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();

        assert!(
            graphs.is_empty(),
            "A graph with no entries should not be returned"
        );
    }

    #[test]
    fn test_custom_complex() {
        let (forward, backward) = build_edges(&[(1, 3), (1, 8), (2, 3), (4, 3), (3, 5), (6, 7)]);
        let initial_ids = &[1, 2, 4, 6];

        let mut graphs = find_no_shuffle_graphs(initial_ids, &forward, &backward).unwrap();

        assert_eq!(graphs.len(), 2);
        graphs.sort_by_key(|g| *g.components.iter().min().unwrap_or(&0));

        assert_eq!(graphs[0].entries, to_hashset(&[1, 2, 4]));
        assert_eq!(graphs[0].components, to_hashset(&[1, 2, 3, 4, 5, 8]));

        assert_eq!(graphs[1].entries, to_hashset(&[6]));
        assert_eq!(graphs[1].components, to_hashset(&[6, 7]));
    }

    #[test]
    fn render_actors_increments_actor_counter() {
        let actor_id_counter = AtomicU32::new(100);
        let fragment_id: FragmentId = 1;
        let job_id: ObjectId = 10;
        let database_id: DatabaseId = 3;

        let fragment_model = build_fragment(
            fragment_id,
            job_id,
            0,
            DistributionType::Single,
            1,
            StreamingParallelism::Fixed(1),
        );

        let job_model = streaming_job::Model {
            job_id,
            job_status: JobStatus::Created,
            create_type: CreateType::Foreground,
            timezone: None,
            parallelism: StreamingParallelism::Fixed(1),
            max_parallelism: 1,
            specific_resource_group: None,
        };

        let database_model = database::Model {
            database_id,
            name: "test_db".into(),
            resource_group: "rg-a".into(),
            barrier_interval_ms: None,
            checkpoint_frequency: None,
        };

        let ensembles = vec![NoShuffleEnsemble {
            entries: HashSet::from([fragment_id]),
            components: HashSet::from([fragment_id]),
        }];

        let fragment_map = HashMap::from([(fragment_id, fragment_model)]);
        let job_map = HashMap::from([(job_id, job_model)]);

        let worker_map = BTreeMap::from([(
            1,
            WorkerInfo {
                parallelism: NonZeroUsize::new(1).unwrap(),
                resource_group: Some("rg-a".into()),
            },
        )]);

        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
        let database_map = HashMap::from([(database_id, database_model)]);

        let context = RenderActorsContext {
            fragment_source_ids: &fragment_source_ids,
            fragment_splits: &fragment_splits,
            streaming_job_databases: &streaming_job_databases,
            database_map: &database_map,
        };

        let result = render_actors(
            &actor_id_counter,
            &ensembles,
            &fragment_map,
            &job_map,
            &worker_map,
            AdaptiveParallelismStrategy::Auto,
            context,
        )
        .expect("actor rendering succeeds");

        let state = collect_actor_state(&result[&database_id][&job_id][&fragment_id]);
        assert_eq!(state.len(), 1);
        assert!(
            state[0].2.is_none(),
            "single distribution should not assign vnode bitmaps"
        );
        assert_eq!(actor_id_counter.load(Ordering::Relaxed), 101);
    }

    #[test]
    fn render_actors_aligns_hash_vnode_bitmaps() {
        let actor_id_counter = AtomicU32::new(0);
        let entry_fragment_id: FragmentId = 1;
        let downstream_fragment_id: FragmentId = 2;
        let job_id: ObjectId = 20;
        let database_id: DatabaseId = 5;

        let entry_fragment = build_fragment(
            entry_fragment_id,
            job_id,
            0,
            DistributionType::Hash,
            4,
            StreamingParallelism::Fixed(2),
        );

        let downstream_fragment = build_fragment(
            downstream_fragment_id,
            job_id,
            0,
            DistributionType::Hash,
            4,
            StreamingParallelism::Fixed(2),
        );

        let job_model = streaming_job::Model {
            job_id,
            job_status: JobStatus::Created,
            create_type: CreateType::Background,
            timezone: None,
            parallelism: StreamingParallelism::Fixed(2),
            max_parallelism: 2,
            specific_resource_group: None,
        };

        let database_model = database::Model {
            database_id,
            name: "test_db_hash".into(),
            resource_group: "rg-hash".into(),
            barrier_interval_ms: None,
            checkpoint_frequency: None,
        };

        let ensembles = vec![NoShuffleEnsemble {
            entries: HashSet::from([entry_fragment_id]),
            components: HashSet::from([entry_fragment_id, downstream_fragment_id]),
        }];

        let fragment_map = HashMap::from([
            (entry_fragment_id, entry_fragment),
            (downstream_fragment_id, downstream_fragment),
        ]);
        let job_map = HashMap::from([(job_id, job_model)]);

        let worker_map = BTreeMap::from([
            (
                1,
                WorkerInfo {
                    parallelism: NonZeroUsize::new(1).unwrap(),
                    resource_group: Some("rg-hash".into()),
                },
            ),
            (
                2,
                WorkerInfo {
                    parallelism: NonZeroUsize::new(1).unwrap(),
                    resource_group: Some("rg-hash".into()),
                },
            ),
        ]);

        let fragment_source_ids: HashMap<FragmentId, SourceId> = HashMap::new();
        let fragment_splits: HashMap<FragmentId, Vec<SplitImpl>> = HashMap::new();
        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
        let database_map = HashMap::from([(database_id, database_model)]);

        let context = RenderActorsContext {
            fragment_source_ids: &fragment_source_ids,
            fragment_splits: &fragment_splits,
            streaming_job_databases: &streaming_job_databases,
            database_map: &database_map,
        };

        let result = render_actors(
            &actor_id_counter,
            &ensembles,
            &fragment_map,
            &job_map,
            &worker_map,
            AdaptiveParallelismStrategy::Auto,
            context,
        )
        .expect("actor rendering succeeds");

        let entry_state = collect_actor_state(&result[&database_id][&job_id][&entry_fragment_id]);
        let downstream_state =
            collect_actor_state(&result[&database_id][&job_id][&downstream_fragment_id]);

        assert_eq!(entry_state.len(), 2);
        assert_eq!(entry_state, downstream_state);

        let assigned_vnodes: BTreeSet<_> = entry_state
            .iter()
            .flat_map(|(_, _, vnodes, _)| {
                vnodes
                    .as_ref()
                    .expect("hash distribution should populate vnode bitmap")
                    .iter()
                    .copied()
            })
            .collect();
        assert_eq!(assigned_vnodes, BTreeSet::from([0, 1, 2, 3]));
        assert_eq!(actor_id_counter.load(Ordering::Relaxed), 4);
    }

    #[test]
    fn render_actors_propagates_source_splits() {
        let actor_id_counter = AtomicU32::new(0);
        let entry_fragment_id: FragmentId = 11;
        let downstream_fragment_id: FragmentId = 12;
        let job_id: ObjectId = 30;
        let database_id: DatabaseId = 7;
        let source_id: SourceId = 99;

        let source_mask = FragmentTypeFlag::raw_flag([FragmentTypeFlag::Source]) as i32;
        let source_scan_mask = FragmentTypeFlag::raw_flag([FragmentTypeFlag::SourceScan]) as i32;

        let entry_fragment = build_fragment(
            entry_fragment_id,
            job_id,
            source_mask,
            DistributionType::Hash,
            4,
            StreamingParallelism::Fixed(2),
        );

        let downstream_fragment = build_fragment(
            downstream_fragment_id,
            job_id,
            source_scan_mask,
            DistributionType::Hash,
            4,
            StreamingParallelism::Fixed(2),
        );

        let job_model = streaming_job::Model {
            job_id,
            job_status: JobStatus::Created,
            create_type: CreateType::Background,
            timezone: None,
            parallelism: StreamingParallelism::Fixed(2),
            max_parallelism: 2,
            specific_resource_group: None,
        };

        let database_model = database::Model {
            database_id,
            name: "split_db".into(),
            resource_group: "rg-source".into(),
            barrier_interval_ms: None,
            checkpoint_frequency: None,
        };

        let ensembles = vec![NoShuffleEnsemble {
            entries: HashSet::from([entry_fragment_id]),
            components: HashSet::from([entry_fragment_id, downstream_fragment_id]),
        }];

        let fragment_map = HashMap::from([
            (entry_fragment_id, entry_fragment),
            (downstream_fragment_id, downstream_fragment),
        ]);
        let job_map = HashMap::from([(job_id, job_model)]);

        let worker_map = BTreeMap::from([
            (
                1,
                WorkerInfo {
                    parallelism: NonZeroUsize::new(1).unwrap(),
                    resource_group: Some("rg-source".into()),
                },
            ),
            (
                2,
                WorkerInfo {
                    parallelism: NonZeroUsize::new(1).unwrap(),
                    resource_group: Some("rg-source".into()),
                },
            ),
        ]);

        let split_a = SplitImpl::Test(TestSourceSplit {
            id: Arc::<str>::from("split-a"),
            properties: HashMap::new(),
            offset: "0".into(),
        });
        let split_b = SplitImpl::Test(TestSourceSplit {
            id: Arc::<str>::from("split-b"),
            properties: HashMap::new(),
            offset: "0".into(),
        });

        let fragment_source_ids = HashMap::from([
            (entry_fragment_id, source_id),
            (downstream_fragment_id, source_id),
        ]);
        let fragment_splits =
            HashMap::from([(entry_fragment_id, vec![split_a.clone(), split_b.clone()])]);
        let streaming_job_databases = HashMap::from([(job_id, database_id)]);
        let database_map = HashMap::from([(database_id, database_model)]);

        let context = RenderActorsContext {
            fragment_source_ids: &fragment_source_ids,
            fragment_splits: &fragment_splits,
            streaming_job_databases: &streaming_job_databases,
            database_map: &database_map,
        };

        let result = render_actors(
            &actor_id_counter,
            &ensembles,
            &fragment_map,
            &job_map,
            &worker_map,
            AdaptiveParallelismStrategy::Auto,
            context,
        )
        .expect("actor rendering succeeds");

        let entry_state = collect_actor_state(&result[&database_id][&job_id][&entry_fragment_id]);
        let downstream_state =
            collect_actor_state(&result[&database_id][&job_id][&downstream_fragment_id]);

        assert_eq!(entry_state, downstream_state);

        let split_ids: BTreeSet<_> = entry_state
            .iter()
            .flat_map(|(_, _, _, splits)| splits.iter().cloned())
            .collect();
        assert_eq!(
            split_ids,
            BTreeSet::from([split_a.id().to_string(), split_b.id().to_string()])
        );
        assert_eq!(actor_id_counter.load(Ordering::Relaxed), 4);
    }
}