#![allow(
    clippy::collapsible_if,
    clippy::explicit_iter_loop,
    reason = "generated by crepe"
)]

use std::collections::{BTreeMap, HashMap};
use std::num::NonZeroUsize;

use anyhow::Context;
use either::Either;
use enum_as_inner::EnumAsInner;
use itertools::Itertools;
use risingwave_common::hash::{
    ActorAlignmentId, ActorAlignmentMapping, ActorMapping, VnodeCountCompat,
};
use risingwave_common::util::stream_graph_visitor::visit_fragment;
use risingwave_common::{bail, hash};
use risingwave_connector::source::cdc::{CDC_BACKFILL_MAX_PARALLELISM, CdcScanOptions};
use risingwave_meta_model::WorkerId;
use risingwave_pb::common::{ActorInfo, WorkerNode};
use risingwave_pb::meta::table_fragments::fragment::{
    FragmentDistributionType, PbFragmentDistributionType,
};
use risingwave_pb::stream_plan::DispatcherType::{self, *};

use crate::MetaResult;
use crate::model::{ActorId, Fragment};
use crate::stream::AssignerBuilder;
use crate::stream::stream_graph::fragment::CompleteStreamFragmentGraph;
use crate::stream::stream_graph::id::GlobalFragmentId as Id;

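/// The index of a deduplicated hash mapping among all existing hash mappings, used by [`Req::Hash`].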
type HashMappingId = usize;

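/// The distribution requirement of a fragment, either declared by the fragment itself or
/// propagated from connected fragments by the rules below.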
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Req {
    /// The fragment must be a singleton scheduled on the given worker.
    Singleton(WorkerId),
    /// The fragment must use the hash mapping identified by the id.
    Hash(HashMappingId),
    /// The fragment can use any distribution, as long as its vnode count matches.
    AnyVnodeCount(usize),
}

impl Req {
    /// The fragment must be a singleton, but it can be scheduled on any worker.
    #[allow(non_upper_case_globals)]
    const AnySingleton: Self = Self::AnyVnodeCount(1);

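    /// Merge two requirements, returning an error if they are incompatible.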
    fn merge(a: Self, b: Self, mapping_len: impl Fn(HashMappingId) -> usize) -> MetaResult<Self> {
        // Note that `a` and `b` are always different, as the requirements come from a set.
        let merge = |a, b| match (a, b) {
            (Self::AnySingleton, Self::Singleton(id)) => Some(Self::Singleton(id)),
            (Self::AnyVnodeCount(count), Self::Hash(id)) if mapping_len(id) == count => {
                Some(Self::Hash(id))
            }
            _ => None,
        };

        match merge(a, b).or_else(|| merge(b, a)) {
            Some(req) => Ok(req),
            None => bail!("incompatible requirements `{a:?}` and `{b:?}`"),
        }
    }
}

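/// A fact that serves as the input of the Datalog program below.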
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Fact {
    /// An edge in the fragment graph.
    Edge {
        from: Id,
        to: Id,
        dt: DispatcherType,
    },
    /// A distribution requirement on a fragment.
    Req { id: Id, req: Req },
}

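// The rules for deriving the distribution requirement of each fragment, written as a Datalog
// program with the `crepe` crate.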
crepe::crepe! {
    @input
    struct Input(Fact);

    struct Edge(Id, Id, DispatcherType);
    struct ExternalReq(Id, Req);

    @output
    struct Requirement(Id, Req);

    // Extract the facts.
    Edge(from, to, dt) <- Input(f), let Fact::Edge { from, to, dt } = f;
    Requirement(id, req) <- Input(f), let Fact::Req { id, req } = f;

    // The downstream fragment of a `Simple` dispatcher must be a singleton.
    Requirement(y, Req::AnySingleton) <- Edge(_, y, Simple);
    // Requirements propagate through `NoShuffle` edges in both directions.
    Requirement(x, d) <- Edge(x, y, NoShuffle), Requirement(y, d);
    Requirement(y, d) <- Edge(x, y, NoShuffle), Requirement(x, d);
}

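/// The distribution (scheduling result) of a fragment.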
#[derive(Debug, Clone, EnumAsInner)]
pub(super) enum Distribution {
    /// The fragment is a singleton and is scheduled on the given worker.
    Singleton(WorkerId),

    /// The fragment is hash-distributed according to the given vnode mapping.
    Hash(ActorAlignmentMapping),
}

impl Distribution {
    /// The parallelism (number of actors) of the distribution.
    pub fn parallelism(&self) -> usize {
        self.actors().count()
    }

    /// The alignment ids of all actors in the distribution.
    pub fn actors(&self) -> impl Iterator<Item = ActorAlignmentId> + '_ {
        match self {
            Distribution::Singleton(p) => {
                Either::Left(std::iter::once(ActorAlignmentId::new(*p as _, 0)))
            }
            Distribution::Hash(mapping) => Either::Right(mapping.iter_unique()),
        }
    }

    /// The vnode count of the distribution. Singletons always have a vnode count of 1.
    pub fn vnode_count(&self) -> usize {
        match self {
            Distribution::Singleton(_) => 1,
            Distribution::Hash(mapping) => mapping.len(),
        }
    }

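    /// Create a distribution from an existing fragment and the locations of its actors.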
    pub fn from_fragment(fragment: &Fragment, actor_location: &HashMap<ActorId, WorkerId>) -> Self {
        match fragment.distribution_type {
            FragmentDistributionType::Unspecified => unreachable!(),
            FragmentDistributionType::Single => {
                let actor_id = fragment.actors.iter().exactly_one().unwrap().actor_id;
                let location = actor_location.get(&actor_id).unwrap();
                Distribution::Singleton(*location)
            }
            FragmentDistributionType::Hash => {
                let actor_bitmaps: HashMap<_, _> = fragment
                    .actors
                    .iter()
                    .map(|actor| {
                        (
                            actor.actor_id as hash::ActorId,
                            actor.vnode_bitmap.clone().unwrap(),
                        )
                    })
                    .collect();

                let actor_mapping = ActorMapping::from_bitmaps(&actor_bitmaps);
                let actor_location = actor_location
                    .iter()
                    .map(|(&k, &v)| (k, v as u32))
                    .collect();
                let mapping = actor_mapping.to_actor_alignment(&actor_location);

                Distribution::Hash(mapping)
            }
        }
    }

    /// Convert the distribution to [`PbFragmentDistributionType`].
    pub fn to_distribution_type(&self) -> PbFragmentDistributionType {
        match self {
            Distribution::Singleton(_) => PbFragmentDistributionType::Single,
            Distribution::Hash(_) => PbFragmentDistributionType::Hash,
        }
    }
}

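/// [`Scheduler`] decides the distribution of the fragments in a streaming job.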
pub(super) struct Scheduler {
    /// The default hash mapping for hash-distributed fragments.
    default_hash_mapping: ActorAlignmentMapping,

    /// The default worker for singleton fragments.
    default_singleton_worker: WorkerId,

    /// Builds a mapping for fragments that require a specific vnode count, optionally with a
    /// forced parallelism.
    dynamic_mapping_fn: Box<dyn Fn(usize, Option<usize>) -> anyhow::Result<ActorAlignmentMapping>>,
}

impl Scheduler {
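    /// Create a new [`Scheduler`] with the given workers and the default parallelism. The id of
    /// the streaming job is used to build the assigner, so that scheduling is deterministic per
    /// job.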
    pub fn new(
        streaming_job_id: u32,
        workers: &HashMap<u32, WorkerNode>,
        default_parallelism: NonZeroUsize,
        expected_vnode_count: usize,
    ) -> MetaResult<Self> {
        let parallelism = default_parallelism.get();
        assert!(
            parallelism <= expected_vnode_count,
            "parallelism should be limited by vnode count in previous steps"
        );

        let assigner = AssignerBuilder::new(streaming_job_id).build();

        let worker_weights = workers
            .iter()
            .map(|(worker_id, worker)| {
                (
                    *worker_id,
                    NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
                )
            })
            .collect();

        let actor_idxes = (0..parallelism).collect_vec();
        let vnodes = (0..expected_vnode_count).collect_vec();

        // Build the default hash mapping from the hierarchical assignment of vnodes to actors
        // and actors to workers.
        let assignment = assigner.assign_hierarchical(&worker_weights, &actor_idxes, &vnodes)?;

        let default_hash_mapping =
            ActorAlignmentMapping::from_assignment(assignment, expected_vnode_count);

        // Pick the default worker for singleton fragments with a single-actor assignment.
        let single_actor_idxes = std::iter::once(0).collect_vec();

        let single_assignment =
            assigner.assign_hierarchical(&worker_weights, &single_actor_idxes, &vnodes)?;

        let default_singleton_worker =
            single_assignment.keys().exactly_one().cloned().unwrap() as _;

        // Builds mappings for fragments that are limited to a specific vnode count, capping the
        // parallelism at that count.
        let dynamic_mapping_fn = Box::new(
            move |limited_count: usize, force_parallelism: Option<usize>| {
                let parallelism = if let Some(force_parallelism) = force_parallelism {
                    force_parallelism.min(limited_count)
                } else {
                    parallelism.min(limited_count)
                };
                let assignment = assigner.assign_hierarchical(
                    &worker_weights,
                    &(0..parallelism).collect_vec(),
                    &(0..limited_count).collect_vec(),
                )?;

                let mapping = ActorAlignmentMapping::from_assignment(assignment, limited_count);
                Ok(mapping)
            },
        );
        Ok(Self {
            default_hash_mapping,
            default_singleton_worker,
            dynamic_mapping_fn,
        })
    }

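    /// Schedule the given complete graph and return the distribution of each building fragment.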
    pub fn schedule(
        &self,
        graph: &CompleteStreamFragmentGraph,
    ) -> MetaResult<HashMap<Id, Distribution>> {
        let existing_distribution = graph.existing_distribution();

        // Build an index of all existing hash mappings so that requirements can refer to them
        // compactly by id.
        let all_hash_mappings = existing_distribution
            .values()
            .flat_map(|dist| dist.as_hash())
            .cloned()
            .unique()
            .collect_vec();
        let hash_mapping_id: HashMap<_, _> = all_hash_mappings
            .iter()
            .enumerate()
            .map(|(i, m)| (m.clone(), i))
            .collect();

        let mut facts = Vec::new();

        // Singletons.
        for (&id, fragment) in graph.building_fragments() {
            if fragment.requires_singleton {
                facts.push(Fact::Req {
                    id,
                    req: Req::AnySingleton,
                });
            }
        }
        let mut force_parallelism_fragment_ids: HashMap<_, _> = HashMap::default();
        // Vnode count requirements: fragments that read an existing table (or run parallelized
        // CDC backfill) must use a matching vnode count.
        for (&id, fragment) in graph.building_fragments() {
            visit_fragment(fragment, |node| {
                use risingwave_pb::stream_plan::stream_node::NodeBody;
                let vnode_count = match node {
                    NodeBody::StreamScan(node) => {
                        if let Some(table) = &node.arrangement_table {
                            table.vnode_count()
                        } else if let Some(table) = &node.table_desc {
                            table.vnode_count()
                        } else {
                            return;
                        }
                    }
                    NodeBody::TemporalJoin(node) => node.get_table_desc().unwrap().vnode_count(),
                    NodeBody::BatchPlan(node) => node.get_table_desc().unwrap().vnode_count(),
                    NodeBody::Lookup(node) => node
                        .get_arrangement_table_info()
                        .unwrap()
                        .get_table_desc()
                        .unwrap()
                        .vnode_count(),
                    NodeBody::StreamCdcScan(node) => {
                        let Some(ref options) = node.options else {
                            return;
                        };
                        let options = CdcScanOptions::from_proto(options);
                        if options.is_parallelized_backfill() {
                            force_parallelism_fragment_ids
                                .insert(id, options.backfill_parallelism as usize);
                            CDC_BACKFILL_MAX_PARALLELISM as usize
                        } else {
                            return;
                        }
                    }
                    _ => return,
                };
                facts.push(Fact::Req {
                    id,
                    req: Req::AnyVnodeCount(vnode_count),
                });
            });
        }
        // Distributions of existing fragments.
        for (id, dist) in existing_distribution {
            let req = match dist {
                Distribution::Singleton(worker_id) => Req::Singleton(worker_id),
                Distribution::Hash(mapping) => Req::Hash(hash_mapping_id[&mapping]),
            };
            facts.push(Fact::Req { id, req });
        }
        // Edges.
        for (from, to, edge) in graph.all_edges() {
            facts.push(Fact::Edge {
                from,
                to,
                dt: edge.dispatch_strategy.r#type(),
            });
        }

        // Run the Datalog program to derive the requirements of each fragment.
        let mut crepe = Crepe::new();
        crepe.extend(facts.into_iter().map(Input));
        let (reqs,) = crepe.run();
        let reqs = reqs
            .into_iter()
            .map(|Requirement(id, req)| (id, req))
            .into_group_map();

        // Derive the distribution of each building fragment from its requirements.
        let mut distributions = HashMap::new();
        for &id in graph.building_fragments().keys() {
            let dist = match reqs.get(&id) {
                // Merge all requirements of this fragment.
                Some(reqs) => {
                    let req = (reqs.iter().copied())
                        .try_reduce(|a, b| Req::merge(a, b, |id| all_hash_mappings[id].len()))
                        .with_context(|| {
                            format!("cannot fulfill scheduling requirements for fragment {id:?}")
                        })?
                        .unwrap();

                    // Map the merged requirement to a distribution.
                    match req {
                        Req::Singleton(worker_id) => Distribution::Singleton(worker_id),
                        Req::Hash(mapping) => {
                            Distribution::Hash(all_hash_mappings[mapping].clone())
                        }
                        Req::AnySingleton => Distribution::Singleton(self.default_singleton_worker),
                        Req::AnyVnodeCount(vnode_count) => {
                            let force_parallelism =
                                force_parallelism_fragment_ids.get(&id).copied();
                            let mapping = (self.dynamic_mapping_fn)(vnode_count, force_parallelism)
                                .with_context(|| {
                                    format!(
                                        "failed to build dynamic mapping for fragment {id:?} with vnode count {vnode_count}"
                                    )
                                })?;

                            Distribution::Hash(mapping)
                        }
                    }
                }
                // No requirement: use the default hash distribution.
                None => Distribution::Hash(self.default_hash_mapping.clone()),
            };

            distributions.insert(id, dist);
        }

        tracing::debug!(?distributions, "schedule fragments");

        Ok(distributions)
    }
}

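/// [`Locations`] records where the actors of a streaming job are scheduled.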
#[cfg_attr(test, derive(Default))]
pub struct Locations {
    /// actor id => alignment id
    pub actor_locations: BTreeMap<ActorId, ActorAlignmentId>,
    /// worker id => worker node
    pub worker_locations: HashMap<WorkerId, WorkerNode>,
}

impl Locations {
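    /// Returns all actors for each worker node.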
    pub fn worker_actors(&self) -> HashMap<WorkerId, Vec<ActorId>> {
        self.actor_locations
            .iter()
            .map(|(actor_id, alignment_id)| (alignment_id.worker_id() as WorkerId, *actor_id))
            .into_group_map()
    }

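    /// Returns an iterator of `ActorInfo`.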
    pub fn actor_infos(&self) -> impl Iterator<Item = ActorInfo> + '_ {
        self.actor_locations
            .iter()
            .map(|(actor_id, alignment_id)| ActorInfo {
                actor_id: *actor_id,
                host: self.worker_locations[&(alignment_id.worker_id() as WorkerId)]
                    .host
                    .clone(),
            })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The expected scheduling result of a fragment in the tests: either the default hash
    /// distribution or a specific requirement.
    #[derive(Debug)]
    enum Result {
        DefaultHash,
        Required(Req),
    }

    impl Result {
        #[allow(non_upper_case_globals)]
        const DefaultSingleton: Self = Self::Required(Req::AnySingleton);
    }

    fn run_and_merge(
        facts: impl IntoIterator<Item = Fact>,
        mapping_len: impl Fn(HashMappingId) -> usize,
    ) -> MetaResult<HashMap<Id, Req>> {
        let mut crepe = Crepe::new();
        crepe.extend(facts.into_iter().map(Input));
        let (reqs,) = crepe.run();

        let reqs = reqs
            .into_iter()
            .map(|Requirement(id, req)| (id, req))
            .into_group_map();

        let mut merged = HashMap::new();
        for (id, reqs) in reqs {
            let req = (reqs.iter().copied())
                .try_reduce(|a, b| Req::merge(a, b, &mapping_len))
                .with_context(|| {
                    format!("cannot fulfill scheduling requirements for fragment {id:?}")
                })?
                .unwrap();
            merged.insert(id, req);
        }

        Ok(merged)
    }

    fn test_success(facts: impl IntoIterator<Item = Fact>, expected: HashMap<Id, Result>) {
        test_success_with_mapping_len(facts, expected, |_| 0);
    }

    fn test_success_with_mapping_len(
        facts: impl IntoIterator<Item = Fact>,
        expected: HashMap<Id, Result>,
        mapping_len: impl Fn(HashMappingId) -> usize,
    ) {
        let reqs = run_and_merge(facts, mapping_len).unwrap();

        for (id, expected) in expected {
            match (reqs.get(&id), expected) {
                (None, Result::DefaultHash) => {}
                (Some(actual), Result::Required(expected)) if *actual == expected => {}
                (actual, expected) => panic!(
                    "unexpected result for fragment {id:?}\nactual: {actual:?}\nexpected: {expected:?}"
                ),
            }
        }
    }

    fn test_failed(facts: impl IntoIterator<Item = Fact>) {
        run_and_merge(facts, |_| 0).unwrap_err();
    }

    #[test]
    fn test_single_fragment_hash() {
        #[rustfmt::skip]
        let facts = [];

        let expected = maplit::hashmap! {
            101.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    #[test]
    fn test_single_fragment_singleton() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 101.into(), req: Req::AnySingleton },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::DefaultSingleton,
        };

        test_success(facts, expected);
    }

    // Fragments 101 and 102 are no-shuffle downstreams of existing fragments 1 (hash) and 2
    // (singleton); 103 joins them with hash dispatchers, and 104 is the simple downstream of 103.
    #[test]
    fn test_scheduling_mv_on_mv() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Singleton(0) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 102.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 103.into(), to: 104.into(), dt: Simple },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::Required(Req::Singleton(0)),
            103.into() => Result::DefaultHash,
            104.into() => Result::DefaultSingleton,
        };

        test_success(facts, expected);
    }

    // A delta-join plan: 101 and 102 inherit distributions from 1 and 2 via no-shuffle edges and
    // propagate them to 103 and 104 respectively, while the hash edges impose no requirement.
    #[test]
    fn test_delta_join() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Hash(2) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 102.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: NoShuffle },
            Fact::Edge { from: 102.into(), to: 104.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 104.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 103.into(), to: 105.into(), dt: Hash },
            Fact::Edge { from: 104.into(), to: 105.into(), dt: Hash },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::Required(Req::Hash(2)),
            103.into() => Result::Required(Req::Hash(1)),
            104.into() => Result::Required(Req::Hash(2)),
            105.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    #[test]
    fn test_singleton_leaf() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            // A leaf fragment that requires to be a singleton, e.g. `Now`.
            Fact::Req { id: 102.into(), req: Req::AnySingleton },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Broadcast },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::DefaultSingleton,
            103.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    // Two upstream fragments with different hash mappings connect to the same fragment through
    // no-shuffle edges, so the requirements conflict.
    #[test]
    fn test_upstream_hash_shard_failed() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Hash(2) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 101.into(), dt: NoShuffle },
        ];

        test_failed(facts);
    }

    #[test]
    fn test_arrangement_backfill_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: Hash },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::AnyVnodeCount(128)),
        };

        test_success(facts, expected);
    }

    #[test]
    fn test_no_shuffle_backfill_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
        };

        test_success_with_mapping_len(facts, expected, |id| {
            assert_eq!(id, 1);
            128
        });
    }

    #[test]
    fn test_no_shuffle_backfill_mismatched_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        // `test_failed` uses a `mapping_len` of 0, which never matches the required vnode count
        // of 128, so merging the requirements fails.
        test_failed(facts);
    }

    #[test]
    fn test_backfill_singleton_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Singleton(0) },
            Fact::Req { id: 101.into(), req: Req::AnySingleton },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Singleton(0)),
        };

        test_success(facts, expected);
    }
}