1#![allow(
16 clippy::collapsible_if,
17 clippy::explicit_iter_loop,
18 reason = "generated by crepe"
19)]
20
21use std::collections::{BTreeMap, HashMap};
22use std::num::NonZeroUsize;
23
24use anyhow::Context;
25use either::Either;
26use enum_as_inner::EnumAsInner;
27use itertools::Itertools;
28use risingwave_common::hash::{
29 ActorAlignmentId, ActorAlignmentMapping, ActorMapping, VnodeCountCompat,
30};
31use risingwave_common::id::JobId;
32use risingwave_common::util::stream_graph_visitor::visit_fragment;
33use risingwave_common::{bail, hash};
34use risingwave_connector::source::cdc::{CDC_BACKFILL_MAX_PARALLELISM, CdcScanOptions};
35use risingwave_meta_model::WorkerId;
36use risingwave_meta_model::fragment::DistributionType;
37use risingwave_pb::common::{ActorInfo, WorkerNode};
38use risingwave_pb::meta::table_fragments::fragment::PbFragmentDistributionType;
39use risingwave_pb::stream_plan::DispatcherType::{self, *};
40
41use crate::MetaResult;
42use crate::barrier::SharedFragmentInfo;
43use crate::model::ActorId;
44use crate::stream::AssignerBuilder;
45use crate::stream::stream_graph::fragment::CompleteStreamFragmentGraph;
46use crate::stream::stream_graph::id::GlobalFragmentId as Id;
47
/// Index of a hash mapping in the list of distinct hash mappings collected from the existing
/// fragments while scheduling (see `Scheduler::schedule`).
type HashMappingId = usize;
49
/// A scheduling requirement on the distribution of a single fragment.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Req {
    /// The fragment must be a singleton placed on exactly this worker.
    Singleton(WorkerId),
    /// The fragment must use exactly this hash mapping, identified by its index among the
    /// distinct mappings of existing fragments.
    Hash(HashMappingId),
    /// The fragment must be hash-distributed over this many vnodes, on any mapping.
    ///
    /// `AnyVnodeCount(1)` doubles as [`Req::AnySingleton`].
    AnyVnodeCount(usize),
}
61
62impl Req {
63 #[allow(non_upper_case_globals)]
65 const AnySingleton: Self = Self::AnyVnodeCount(1);
66
67 fn merge(a: Self, b: Self, mapping_len: impl Fn(HashMappingId) -> usize) -> MetaResult<Self> {
71 let merge = |a, b| match (a, b) {
73 (Self::AnySingleton, Self::Singleton(id)) => Some(Self::Singleton(id)),
74 (Self::AnyVnodeCount(count), Self::Hash(id)) if mapping_len(id) == count => {
75 Some(Self::Hash(id))
76 }
77 _ => None,
78 };
79
80 match merge(a, b).or_else(|| merge(b, a)) {
81 Some(req) => Ok(req),
82 None => bail!("incompatible requirements `{a:?}` and `{b:?}`"),
83 }
84 }
85}
86
/// An input fact for the `crepe` scheduling rules.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Fact {
    /// A dispatch edge between two fragments.
    Edge {
        /// Upstream fragment.
        from: Id,
        /// Downstream fragment.
        to: Id,
        /// Dispatcher type of the edge.
        dt: DispatcherType,
    },
    /// A requirement imposed directly on a fragment.
    Req { id: Id, req: Req },
}
99
crepe::crepe! {
    // Raw facts supplied by the caller.
    @input
    struct Input(Fact);

    // Intermediate relations.
    struct Edge(Id, Id, DispatcherType);
    // NOTE(review): `ExternalReq` is declared but not referenced by any rule below.
    struct ExternalReq(Id, Req);

    // Every requirement derived for a fragment. A fragment may end up with several; they are
    // merged by the caller.
    @output
    struct Requirement(Id, Req);

    // Unpack the input facts into edges and direct requirements.
    Edge(from, to, dt) <- Input(f), let Fact::Edge { from, to, dt } = f;
    Requirement(id, req) <- Input(f), let Fact::Req { id, req } = f;

    // The downstream of a `Simple` edge must be a singleton.
    Requirement(y, Req::AnySingleton) <- Edge(_, y, Simple);
    // Requirements flow in both directions across `NoShuffle` edges, so both ends of such an
    // edge accumulate the same set of requirements.
    Requirement(x, d) <- Edge(x, y, NoShuffle), Requirement(y, d);
    Requirement(y, d) <- Edge(x, y, NoShuffle), Requirement(x, d);
}
120
/// The scheduled distribution of a fragment.
#[derive(Debug, Clone, EnumAsInner)]
pub(super) enum Distribution {
    /// A single actor placed on the given worker.
    Singleton(WorkerId),

    /// Hash-distributed across actors according to the given vnode-to-actor mapping.
    Hash(ActorAlignmentMapping),
}
130
131impl Distribution {
132 pub fn parallelism(&self) -> usize {
134 self.actors().count()
135 }
136
137 pub fn actors(&self) -> impl Iterator<Item = ActorAlignmentId> + '_ {
139 match self {
140 Distribution::Singleton(p) => {
141 Either::Left(std::iter::once(ActorAlignmentId::new(*p, 0)))
142 }
143 Distribution::Hash(mapping) => Either::Right(mapping.iter_unique()),
144 }
145 }
146
147 pub fn vnode_count(&self) -> usize {
149 match self {
150 Distribution::Singleton(_) => 1, Distribution::Hash(mapping) => mapping.len(),
152 }
153 }
154
155 pub fn from_fragment(
157 fragment: &SharedFragmentInfo,
158 actor_location: &HashMap<ActorId, WorkerId>,
159 ) -> Self {
160 match fragment.distribution_type {
161 DistributionType::Single => {
162 let (actor_id, _) = fragment.actors.iter().exactly_one().unwrap();
163 let location = actor_location.get(actor_id).unwrap();
164 Distribution::Singleton(*location)
165 }
166 DistributionType::Hash => {
167 let actor_bitmaps: HashMap<_, _> = fragment
168 .actors
169 .iter()
170 .map(|(actor_id, actor_info)| {
171 (
172 *actor_id as hash::ActorId,
173 actor_info.vnode_bitmap.clone().unwrap(),
174 )
175 })
176 .collect();
177
178 let actor_mapping = ActorMapping::from_bitmaps(&actor_bitmaps);
179 let actor_location = actor_location.iter().map(|(&k, &v)| (k, v)).collect();
180 let mapping = actor_mapping.to_actor_alignment(&actor_location);
181
182 Distribution::Hash(mapping)
183 }
184 }
185 }
186
187 pub fn to_distribution_type(&self) -> PbFragmentDistributionType {
189 match self {
190 Distribution::Singleton(_) => PbFragmentDistributionType::Single,
191 Distribution::Hash(_) => PbFragmentDistributionType::Hash,
192 }
193 }
194}
195
/// Decides the [`Distribution`] of the building fragments of a streaming job.
pub(super) struct Scheduler {
    /// Mapping for fragments with no requirement. It is built from the default parallelism and
    /// the expected vnode count.
    default_hash_mapping: ActorAlignmentMapping,

    /// Worker for singleton fragments that are not pinned to a specific worker.
    default_singleton_worker: WorkerId,

    /// Builds a hash mapping on demand. Arguments are `(vnode_count, force_parallelism)`; the
    /// parallelism used is the forced one if given, else the default, capped by `vnode_count`.
    dynamic_mapping_fn: Box<dyn Fn(usize, Option<usize>) -> anyhow::Result<ActorAlignmentMapping>>,
}
207
impl Scheduler {
    /// Creates a scheduler for the streaming job `streaming_job_id`.
    ///
    /// - `workers`: candidate workers, weighted by their compute-node parallelism.
    /// - `default_parallelism`: actor count of the default hash mapping.
    /// - `expected_vnode_count`: vnode count of the default hash mapping.
    ///
    /// # Panics
    /// - `default_parallelism` exceeds `expected_vnode_count`.
    /// - Any worker reports a compute-node parallelism of zero.
    /// - The single-actor assignment does not produce exactly one key.
    ///
    /// # Errors
    /// Propagates errors from the assigner.
    pub fn new(
        streaming_job_id: JobId,
        workers: &HashMap<WorkerId, WorkerNode>,
        default_parallelism: NonZeroUsize,
        expected_vnode_count: usize,
    ) -> MetaResult<Self> {
        let parallelism = default_parallelism.get();
        assert!(
            parallelism <= expected_vnode_count,
            "parallelism should be limited by vnode count in previous steps"
        );

        let assigner = AssignerBuilder::new(streaming_job_id).build();

        // Weight each worker by its compute-node parallelism (a zero value panics here).
        let worker_weights = workers
            .iter()
            .map(|(worker_id, worker)| {
                (
                    *worker_id,
                    NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
                )
            })
            .collect();

        let actor_idxes = (0..parallelism).collect_vec();
        let vnodes = (0..expected_vnode_count).collect_vec();

        let assignment = assigner.assign_hierarchical(&worker_weights, &actor_idxes, &vnodes)?;

        let default_hash_mapping =
            ActorAlignmentMapping::from_assignment(assignment, expected_vnode_count);

        // Assign all vnodes to a single actor; its only key becomes the default singleton
        // worker.
        let single_actor_idxes = std::iter::once(0).collect_vec();

        let single_assignment =
            assigner.assign_hierarchical(&worker_weights, &single_actor_idxes, &vnodes)?;

        let default_singleton_worker =
            single_assignment.keys().exactly_one().cloned().unwrap() as _;

        // Lazily builds mappings for fragments that require a specific vnode count.
        let dynamic_mapping_fn = Box::new(
            move |limited_count: usize, force_parallelism: Option<usize>| {
                // Use the forced parallelism if any, else the default; never exceed the
                // vnode count.
                let parallelism = if let Some(force_parallelism) = force_parallelism {
                    force_parallelism.min(limited_count)
                } else {
                    parallelism.min(limited_count)
                };
                let assignment = assigner.assign_hierarchical(
                    &worker_weights,
                    &(0..parallelism).collect_vec(),
                    &(0..limited_count).collect_vec(),
                )?;

                let mapping = ActorAlignmentMapping::from_assignment(assignment, limited_count);
                Ok(mapping)
            },
        );
        Ok(Self {
            default_hash_mapping,
            default_singleton_worker,
            dynamic_mapping_fn,
        })
    }

    /// Decides the distribution of every building fragment in `graph`.
    ///
    /// Requirements are gathered as [`Fact`]s from four sources:
    /// - fragments marked `requires_singleton`;
    /// - vnode counts of tables read by scan/join/lookup/gap-fill nodes;
    /// - the current distribution of existing fragments;
    /// - all graph edges.
    ///
    /// The facts are propagated with the `crepe` rules of this module and merged per fragment
    /// with [`Req::merge`]. A fragment with no requirement gets the default hash mapping.
    ///
    /// # Errors
    /// - A fragment's requirements cannot be merged.
    /// - A dynamic mapping cannot be built.
    pub fn schedule(
        &self,
        graph: &CompleteStreamFragmentGraph,
    ) -> MetaResult<HashMap<Id, Distribution>> {
        let existing_distribution = graph.existing_distribution();

        // Distinct hash mappings of existing fragments; `Req::Hash` refers to them by index.
        let all_hash_mappings = existing_distribution
            .values()
            .flat_map(|dist| dist.as_hash())
            .cloned()
            .unique()
            .collect_vec();
        let hash_mapping_id: HashMap<_, _> = all_hash_mappings
            .iter()
            .enumerate()
            .map(|(i, m)| (m.clone(), i))
            .collect();

        let mut facts = Vec::new();

        // Building fragments that must be singletons.
        for (&id, fragment) in graph.building_fragments() {
            if fragment.requires_singleton {
                facts.push(Fact::Req {
                    id,
                    req: Req::AnySingleton,
                });
            }
        }
        // Fragments with a parallelized CDC backfill, mapped to their forced parallelism.
        let mut force_parallelism_fragment_ids: HashMap<_, _> = HashMap::default();
        // Vnode-count requirements derived from the nodes inside each building fragment.
        for (&id, fragment) in graph.building_fragments() {
            visit_fragment(fragment, |node| {
                use risingwave_pb::stream_plan::stream_node::NodeBody;
                let vnode_count = match node {
                    NodeBody::StreamScan(node) => {
                        if let Some(table) = &node.arrangement_table {
                            table.vnode_count()
                        } else if let Some(table) = &node.table_desc {
                            table.vnode_count()
                        } else {
                            return;
                        }
                    }
                    NodeBody::TemporalJoin(node) => node.get_table_desc().unwrap().vnode_count(),
                    NodeBody::BatchPlan(node) => node.get_table_desc().unwrap().vnode_count(),
                    NodeBody::Lookup(node) => node
                        .get_arrangement_table_info()
                        .unwrap()
                        .get_table_desc()
                        .unwrap()
                        .vnode_count(),
                    NodeBody::StreamCdcScan(node) => {
                        let Some(ref options) = node.options else {
                            return;
                        };
                        let options = CdcScanOptions::from_proto(options);
                        // A parallelized CDC backfill records its forced parallelism and uses
                        // `CDC_BACKFILL_MAX_PARALLELISM` as its vnode count.
                        if options.is_parallelized_backfill() {
                            force_parallelism_fragment_ids
                                .insert(id, options.backfill_parallelism as usize);
                            CDC_BACKFILL_MAX_PARALLELISM as usize
                        } else {
                            return;
                        }
                    }
                    NodeBody::GapFill(node) => {
                        let buffer_table = node.get_state_table().unwrap();
                        // Only constrain the fragment when the state table's vnode count is set.
                        if let Some(vnode_count) = buffer_table.vnode_count_inner().value_opt() {
                            vnode_count
                        } else {
                            return;
                        }
                    }
                    _ => return,
                };
                facts.push(Fact::Req {
                    id,
                    req: Req::AnyVnodeCount(vnode_count),
                });
            });
        }
        // Existing fragments are pinned to their current distribution.
        for (id, dist) in existing_distribution {
            let req = match dist {
                Distribution::Singleton(worker_id) => Req::Singleton(worker_id),
                Distribution::Hash(mapping) => Req::Hash(hash_mapping_id[&mapping]),
            };
            facts.push(Fact::Req { id, req });
        }
        // Every edge, so the rules can propagate requirements along them.
        for (from, to, edge) in graph.all_edges() {
            facts.push(Fact::Edge {
                from,
                to,
                dt: edge.dispatch_strategy.r#type(),
            });
        }

        // Run the rules and group the derived requirements by fragment.
        let mut crepe = Crepe::new();
        crepe.extend(facts.into_iter().map(Input));
        let (reqs,) = crepe.run();
        let reqs = reqs
            .into_iter()
            .map(|Requirement(id, req)| (id, req))
            .into_group_map();

        let mut distributions = HashMap::new();
        for &id in graph.building_fragments().keys() {
            let dist = match reqs.get(&id) {
                Some(reqs) => {
                    // A grouped entry is never empty, so `try_reduce` yields `Some` on success.
                    let req = (reqs.iter().copied())
                        .try_reduce(|a, b| Req::merge(a, b, |id| all_hash_mappings[id].len()))
                        .with_context(|| {
                            format!("cannot fulfill scheduling requirements for fragment {id:?}")
                        })?
                        .unwrap();

                    match req {
                        Req::Singleton(worker_id) => Distribution::Singleton(worker_id),
                        Req::Hash(mapping) => {
                            Distribution::Hash(all_hash_mappings[mapping].clone())
                        }
                        // `AnySingleton` is `AnyVnodeCount(1)`: this arm must stay before the
                        // generic `AnyVnodeCount` arm below.
                        Req::AnySingleton => Distribution::Singleton(self.default_singleton_worker),
                        Req::AnyVnodeCount(vnode_count) => {
                            let force_parallelism =
                                force_parallelism_fragment_ids.get(&id).copied();
                            let mapping = (self.dynamic_mapping_fn)(vnode_count, force_parallelism)
                                .with_context(|| {
                                    format!(
                                        "failed to build dynamic mapping for fragment {id:?} with vnode count {vnode_count}"
                                    )
                                })?;

                            Distribution::Hash(mapping)
                        }
                    }
                }
                // No requirement at all: use the default hash mapping.
                None => Distribution::Hash(self.default_hash_mapping.clone()),
            };

            distributions.insert(id, dist);
        }

        tracing::debug!(?distributions, "schedule fragments");

        Ok(distributions)
    }
}
440
/// Placement of actors onto workers.
#[cfg_attr(test, derive(Default))]
pub struct Locations {
    /// The actor alignment id (which carries the worker id) assigned to each actor.
    pub actor_locations: BTreeMap<ActorId, ActorAlignmentId>,
    /// Worker nodes by id. Presumably this should contain every worker referenced by
    /// `actor_locations`; `actor_infos` panics otherwise.
    pub worker_locations: HashMap<WorkerId, WorkerNode>,
}
449
450impl Locations {
451 pub fn worker_actors(&self) -> HashMap<WorkerId, Vec<ActorId>> {
453 self.actor_locations
454 .iter()
455 .map(|(actor_id, alignment_id)| (alignment_id.worker_id(), *actor_id))
456 .into_group_map()
457 }
458
459 pub fn actor_infos(&self) -> impl Iterator<Item = ActorInfo> + '_ {
461 self.actor_locations
462 .iter()
463 .map(|(&actor_id, alignment_id)| ActorInfo {
464 actor_id,
465 host: self.worker_locations[&(alignment_id.worker_id() as WorkerId)]
466 .host
467 .clone(),
468 })
469 }
470}
471
#[cfg(test)]
mod tests {
    use super::*;

    // Edge notation used in the diagrams below:
    //   `-|->` NoShuffle, `-->` Hash, `-s->` Simple, `-b->` Broadcast.
    // Fragments 1, 2 are existing; 101+ are building.

    /// Expected outcome for one fragment.
    #[derive(Debug)]
    enum Result {
        /// No requirement is derived; the scheduler would fall back to the default hash mapping.
        DefaultHash,
        /// The merged requirement equals the given one.
        Required(Req),
    }

    impl Result {
        /// Shorthand for an expected merged requirement of `Req::AnySingleton`.
        #[allow(non_upper_case_globals)]
        const DefaultSingleton: Self = Self::Required(Req::AnySingleton);
    }

    /// Runs the `crepe` rules over `facts` and merges the requirements per fragment, in the
    /// same way as the merging step of `Scheduler::schedule`.
    fn run_and_merge(
        facts: impl IntoIterator<Item = Fact>,
        mapping_len: impl Fn(HashMappingId) -> usize,
    ) -> MetaResult<HashMap<Id, Req>> {
        let mut crepe = Crepe::new();
        crepe.extend(facts.into_iter().map(Input));
        let (reqs,) = crepe.run();

        let reqs = reqs
            .into_iter()
            .map(|Requirement(id, req)| (id, req))
            .into_group_map();

        let mut merged = HashMap::new();
        for (id, reqs) in reqs {
            let req = (reqs.iter().copied())
                .try_reduce(|a, b| Req::merge(a, b, &mapping_len))
                .with_context(|| {
                    format!("cannot fulfill scheduling requirements for fragment {id:?}")
                })?
                .unwrap();
            merged.insert(id, req);
        }

        Ok(merged)
    }

    /// Asserts success with every hash mapping reporting length 0.
    fn test_success(facts: impl IntoIterator<Item = Fact>, expected: HashMap<Id, Result>) {
        test_success_with_mapping_len(facts, expected, |_| 0);
    }

    /// Asserts that merging succeeds and that every fragment in `expected` matches.
    /// Fragments absent from `expected` are not checked.
    fn test_success_with_mapping_len(
        facts: impl IntoIterator<Item = Fact>,
        expected: HashMap<Id, Result>,
        mapping_len: impl Fn(HashMappingId) -> usize,
    ) {
        let reqs = run_and_merge(facts, mapping_len).unwrap();

        for (id, expected) in expected {
            match (reqs.get(&id), expected) {
                (None, Result::DefaultHash) => {}
                (Some(actual), Result::Required(expected)) if *actual == expected => {}
                (actual, expected) => panic!(
                    "unexpected result for fragment {id:?}\nactual: {actual:?}\nexpected: {expected:?}"
                ),
            }
        }
    }

    /// Asserts that some fragment's requirements cannot be merged.
    fn test_failed(facts: impl IntoIterator<Item = Fact>) {
        run_and_merge(facts, |_| 0).unwrap_err();
    }

    // 101
    #[test]
    fn test_single_fragment_hash() {
        #[rustfmt::skip]
        let facts = [];

        let expected = maplit::hashmap! {
            101.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    // 101 (requires singleton)
    #[test]
    fn test_single_fragment_singleton() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 101.into(), req: Req::AnySingleton },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::DefaultSingleton,
        };

        test_success(facts, expected);
    }

    // 1 -|-> 101 -->
    //                103 -s-> 104
    // 2 -|-> 102 -->
    #[test]
    fn test_scheduling_mv_on_mv() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Singleton(0.into()) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 102.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 103.into(), to: 104.into(), dt: Simple },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::Required(Req::Singleton(0.into())),
            103.into() => Result::DefaultHash,
            104.into() => Result::DefaultSingleton,
        };

        test_success(facts, expected);
    }

    // 1 -|-> 101 -|-> 103 -->
    //            X             105
    // 2 -|-> 102 -|-> 104 -->
    // (the cross edges 101 --> 104 and 102 --> 103 are Hash)
    #[test]
    fn test_delta_join() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Hash(2) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 102.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: NoShuffle },
            Fact::Edge { from: 102.into(), to: 104.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 104.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 103.into(), to: 105.into(), dt: Hash },
            Fact::Edge { from: 104.into(), to: 105.into(), dt: Hash },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::Required(Req::Hash(2)),
            103.into() => Result::Required(Req::Hash(1)),
            104.into() => Result::Required(Req::Hash(2)),
            105.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    // 1 -|-> 101 -->
    //                103
    //        102 -b->
    // (102 requires singleton)
    #[test]
    fn test_singleton_leaf() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Req { id: 102.into(), req: Req::AnySingleton },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Broadcast },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::DefaultSingleton,
            103.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    // 1 -|->
    //        101
    // 2 -|->
    // 101 would need two different hash mappings at once, so merging must fail.
    #[test]
    fn test_upstream_hash_shard_failed() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Hash(2) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 101.into(), dt: NoShuffle },
        ];

        test_failed(facts);
    }

    // 1 --> 101 (requires 128 vnodes); a Hash edge does not propagate requirements.
    #[test]
    fn test_arrangement_backfill_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: Hash },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::AnyVnodeCount(128)),
        };

        test_success(facts, expected);
    }

    // 1 -|-> 101 (requires 128 vnodes); mapping 1 has 128 vnodes, so 101 adopts it.
    #[test]
    fn test_no_shuffle_backfill_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
        };

        test_success_with_mapping_len(facts, expected, |id| {
            assert_eq!(id, 1);
            128
        });
    }

    // 1 -|-> 101 (requires 128 vnodes); mapping 1 reports length 0, so merging must fail.
    #[test]
    fn test_no_shuffle_backfill_mismatched_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        test_failed(facts);
    }

    // 1 (singleton on worker 0) -|-> 101 (requires singleton); 101 is pinned to worker 0.
    #[test]
    fn test_backfill_singleton_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Singleton(0.into()) },
            Fact::Req { id: 101.into(), req: Req::AnySingleton },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Singleton(0.into())),
        };

        test_success(facts, expected);
    }
}