#![allow(
    clippy::collapsible_if,
    clippy::explicit_iter_loop,
    reason = "generated by crepe"
)]

use std::collections::{BTreeMap, HashMap};
use std::num::NonZeroUsize;

use anyhow::Context;
use either::Either;
use enum_as_inner::EnumAsInner;
use itertools::Itertools;
use risingwave_common::hash::{
    ActorAlignmentId, ActorAlignmentMapping, ActorMapping, VnodeCountCompat,
};
use risingwave_common::id::JobId;
use risingwave_common::util::stream_graph_visitor::visit_fragment;
use risingwave_common::{bail, hash};
use risingwave_connector::source::cdc::{CDC_BACKFILL_MAX_PARALLELISM, CdcScanOptions};
use risingwave_meta_model::WorkerId;
use risingwave_meta_model::fragment::DistributionType;
use risingwave_pb::common::WorkerNode;
use risingwave_pb::meta::table_fragments::fragment::PbFragmentDistributionType;
use risingwave_pb::stream_plan::DispatcherType::{self, *};

use crate::MetaResult;
use crate::barrier::SharedFragmentInfo;
use crate::model::ActorId;
use crate::stream::AssignerBuilder;
use crate::stream::stream_graph::fragment::CompleteStreamFragmentGraph;
use crate::stream::stream_graph::id::GlobalFragmentId as Id;

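/// Index into the deduplicated list of hash mappings collected from existing fragments,
/// used to compare and reuse them as scheduling requirements.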
type HashMappingId = usize;

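/// A scheduling requirement on a single fragment, to be merged with other requirements
/// on the same fragment before deciding its final [`Distribution`].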
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Req {
    /// The fragment must be a singleton scheduled to the given worker.
    Singleton(WorkerId),
    /// The fragment must reuse the hash mapping identified by the given id.
    Hash(HashMappingId),
    /// The fragment must have exactly this vnode count, but may be scheduled anywhere.
    AnyVnodeCount(usize),
}

impl Req {
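    /// A singleton requirement that is not pinned to a specific worker, expressed as a
    /// requirement for a vnode count of 1.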
    #[allow(non_upper_case_globals)]
    const AnySingleton: Self = Self::AnyVnodeCount(1);

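    /// Merges two requirements on the same fragment, or returns an error if they are
    /// incompatible. Only two combinations can be merged: an unpinned singleton with a
    /// pinned one, and a bare vnode-count requirement with a hash mapping of that length.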
    fn merge(a: Self, b: Self, mapping_len: impl Fn(HashMappingId) -> usize) -> MetaResult<Self> {
        let merge = |a, b| match (a, b) {
            (Self::AnySingleton, Self::Singleton(id)) => Some(Self::Singleton(id)),
            (Self::AnyVnodeCount(count), Self::Hash(id)) if mapping_len(id) == count => {
                Some(Self::Hash(id))
            }
            _ => None,
        };

        // The relation is symmetric, so try both argument orders.
        match merge(a, b).or_else(|| merge(b, a)) {
            Some(req) => Ok(req),
            None => bail!("incompatible requirements `{a:?}` and `{b:?}`"),
        }
    }
}

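/// An input fact for the scheduling rules below.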
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Fact {
    /// An edge in the fragment graph, together with its dispatcher type.
    Edge {
        from: Id,
        to: Id,
        dt: DispatcherType,
    },
    /// An explicitly given requirement on a fragment.
    Req { id: Id, req: Req },
}

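// The datalog program (solved by `crepe`) that derives the complete set of `Requirement`s
// implied by the input facts and the structure of the graph.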
crepe::crepe! {
    @input
    struct Input(Fact);

    struct Edge(Id, Id, DispatcherType);
    struct ExternalReq(Id, Req);

    @output
    struct Requirement(Id, Req);

    // Extract facts.
    Edge(from, to, dt) <- Input(f), let Fact::Edge { from, to, dt } = f;
    Requirement(id, req) <- Input(f), let Fact::Req { id, req } = f;

    // The downstream fragment of a `Simple` edge must be a singleton.
    Requirement(y, Req::AnySingleton) <- Edge(_, y, Simple);
    // Requirements propagate in both directions over `NoShuffle` edges.
    Requirement(x, d) <- Edge(x, y, NoShuffle), Requirement(y, d);
    Requirement(y, d) <- Edge(x, y, NoShuffle), Requirement(x, d);
}

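/// The distribution (scheduling result) of a fragment.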
#[derive(Debug, Clone, EnumAsInner)]
pub(super) enum Distribution {
    /// The fragment is a singleton scheduled to the given worker.
    Singleton(WorkerId),

    /// The fragment is hash-distributed according to the given alignment mapping.
    Hash(ActorAlignmentMapping),
}

impl Distribution {
    /// The parallelism (number of actors) of the distribution.
    pub fn parallelism(&self) -> usize {
        self.actors().count()
    }

    /// The actors of the distribution, as alignment ids.
    pub fn actors(&self) -> impl Iterator<Item = ActorAlignmentId> + '_ {
        match self {
            Distribution::Singleton(p) => {
                Either::Left(std::iter::once(ActorAlignmentId::new(*p, 0)))
            }
            Distribution::Hash(mapping) => Either::Right(mapping.iter_unique()),
        }
    }

    /// The vnode count of the distribution. A singleton always owns exactly one vnode.
    pub fn vnode_count(&self) -> usize {
        match self {
            Distribution::Singleton(_) => 1,
            Distribution::Hash(mapping) => mapping.len(),
        }
    }

    /// Creates a distribution from an existing fragment and the locations of its actors.
    pub fn from_fragment(
        fragment: &SharedFragmentInfo,
        actor_location: &HashMap<ActorId, WorkerId>,
    ) -> Self {
        match fragment.distribution_type {
            DistributionType::Single => {
                let (actor_id, _) = fragment.actors.iter().exactly_one().unwrap();
                let location = actor_location.get(actor_id).unwrap();
                Distribution::Singleton(*location)
            }
            DistributionType::Hash => {
                let actor_bitmaps: HashMap<_, _> = fragment
                    .actors
                    .iter()
                    .map(|(actor_id, actor_info)| {
                        (
                            *actor_id as hash::ActorId,
                            actor_info.vnode_bitmap.clone().unwrap(),
                        )
                    })
                    .collect();

                let actor_mapping = ActorMapping::from_bitmaps(&actor_bitmaps);
                let actor_location = actor_location.iter().map(|(&k, &v)| (k, v)).collect();
                let mapping = actor_mapping.to_actor_alignment(&actor_location);

                Distribution::Hash(mapping)
            }
        }
    }

    /// Converts the distribution to [`PbFragmentDistributionType`].
    pub fn to_distribution_type(&self) -> PbFragmentDistributionType {
        match self {
            Distribution::Singleton(_) => PbFragmentDistributionType::Single,
            Distribution::Hash(_) => PbFragmentDistributionType::Hash,
        }
    }
}

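/// [`Scheduler`] decides the distribution of each fragment in a streaming job, based on the
/// requirements derived from the graph and the defaults computed from the cluster.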
pub(super) struct Scheduler {
    /// The default hash mapping for hash-distributed fragments.
    default_hash_mapping: ActorAlignmentMapping,

    /// The default worker for singleton fragments.
    default_singleton_worker: WorkerId,

    /// Builds a hash mapping for fragments whose vnode count (and optionally parallelism)
    /// is constrained, e.g. by an existing table or a CDC backfill limit.
    dynamic_mapping_fn: Box<dyn Fn(usize, Option<usize>) -> anyhow::Result<ActorAlignmentMapping>>,
}

impl Scheduler {
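    /// Creates a new [`Scheduler`] from the worker nodes and the default parallelism of the
    /// streaming job. Hash-distributed fragments default to `default_parallelism` actors over
    /// `expected_vnode_count` vnodes; singleton fragments default to a deterministically
    /// chosen worker.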
    pub fn new(
        streaming_job_id: JobId,
        workers: &HashMap<WorkerId, WorkerNode>,
        default_parallelism: NonZeroUsize,
        expected_vnode_count: usize,
    ) -> MetaResult<Self> {
        let parallelism = default_parallelism.get();
        assert!(
            parallelism <= expected_vnode_count,
            "parallelism should be limited by vnode count in previous steps"
        );

        let assigner = AssignerBuilder::new(streaming_job_id).build();

        let worker_weights = workers
            .iter()
            .map(|(worker_id, worker)| {
                (
                    *worker_id,
                    NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
                )
            })
            .collect();

        let actor_idxes = (0..parallelism).collect_vec();
        let vnodes = (0..expected_vnode_count).collect_vec();

        let assignment = assigner.assign_hierarchical(&worker_weights, &actor_idxes, &vnodes)?;

        let default_hash_mapping =
            ActorAlignmentMapping::from_assignment(assignment, expected_vnode_count);

        let single_actor_idxes = std::iter::once(0).collect_vec();

        let single_assignment =
            assigner.assign_hierarchical(&worker_weights, &single_actor_idxes, &vnodes)?;

        let default_singleton_worker =
            single_assignment.keys().exactly_one().cloned().unwrap() as _;

        let dynamic_mapping_fn = Box::new(
            move |limited_count: usize, force_parallelism: Option<usize>| {
                let parallelism = if let Some(force_parallelism) = force_parallelism {
                    force_parallelism.min(limited_count)
                } else {
                    parallelism.min(limited_count)
                };
                let assignment = assigner.assign_hierarchical(
                    &worker_weights,
                    &(0..parallelism).collect_vec(),
                    &(0..limited_count).collect_vec(),
                )?;

                let mapping = ActorAlignmentMapping::from_assignment(assignment, limited_count);
                Ok(mapping)
            },
        );
        Ok(Self {
            default_hash_mapping,
            default_singleton_worker,
            dynamic_mapping_fn,
        })
    }

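    /// Schedules a [`Distribution`] for each building fragment in the graph.
    ///
    /// In short: collect requirement and edge facts from the graph, run the datalog program to
    /// propagate requirements along the edges, merge the requirements of each fragment, and map
    /// the merged requirement to a distribution, falling back to the defaults when a fragment
    /// is unconstrained.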
    pub fn schedule(
        &self,
        graph: &CompleteStreamFragmentGraph,
    ) -> MetaResult<HashMap<Id, Distribution>> {
        let existing_distribution = graph.existing_distribution();

        // Deduplicate the hash mappings of existing fragments and assign an id to each of them.
        let all_hash_mappings = existing_distribution
            .values()
            .flat_map(|dist| dist.as_hash())
            .cloned()
            .unique()
            .collect_vec();
        let hash_mapping_id: HashMap<_, _> = all_hash_mappings
            .iter()
            .enumerate()
            .map(|(i, m)| (m.clone(), i))
            .collect();

        let mut facts = Vec::new();

        // Fragments that require a singleton distribution.
        for (&id, fragment) in graph.building_fragments() {
            if fragment.requires_singleton {
                facts.push(Fact::Req {
                    id,
                    req: Req::AnySingleton,
                });
            }
        }
        let mut force_parallelism_fragment_ids: HashMap<_, _> = HashMap::default();
        // Fragments whose vnode count is fixed by the tables they access (scans, joins,
        // lookups, backfills).
        for (&id, fragment) in graph.building_fragments() {
            visit_fragment(fragment, |node| {
                use risingwave_pb::stream_plan::stream_node::NodeBody;
                let vnode_count = match node {
                    NodeBody::StreamScan(node) => {
                        if let Some(table) = &node.arrangement_table {
                            table.vnode_count()
                        } else if let Some(table) = &node.table_desc {
                            table.vnode_count()
                        } else {
                            return;
                        }
                    }
                    NodeBody::TemporalJoin(node) => node.get_table_desc().unwrap().vnode_count(),
                    NodeBody::BatchPlan(node) => node.get_table_desc().unwrap().vnode_count(),
                    NodeBody::Lookup(node) => node
                        .get_arrangement_table_info()
                        .unwrap()
                        .get_table_desc()
                        .unwrap()
                        .vnode_count(),
                    NodeBody::StreamCdcScan(node) => {
                        let Some(ref options) = node.options else {
                            return;
                        };
                        let options = CdcScanOptions::from_proto(options);
                        if options.is_parallelized_backfill() {
                            force_parallelism_fragment_ids
                                .insert(id, options.backfill_parallelism as usize);
                            CDC_BACKFILL_MAX_PARALLELISM as usize
                        } else {
                            return;
                        }
                    }
                    NodeBody::GapFill(node) => {
                        let buffer_table = node.get_state_table().unwrap();
                        if let Some(vnode_count) = buffer_table.vnode_count_inner().value_opt() {
                            vnode_count
                        } else {
                            return;
                        }
                    }
                    _ => return,
                };
                facts.push(Fact::Req {
                    id,
                    req: Req::AnyVnodeCount(vnode_count),
                });
            });
        }
        // Distributions of existing fragments.
        for (id, dist) in existing_distribution {
            let req = match dist {
                Distribution::Singleton(worker_id) => Req::Singleton(worker_id),
                Distribution::Hash(mapping) => Req::Hash(hash_mapping_id[&mapping]),
            };
            facts.push(Fact::Req { id, req });
        }
        // Edges of the graph, with their dispatcher types.
        for (from, to, edge) in graph.all_edges() {
            facts.push(Fact::Edge {
                from,
                to,
                dt: edge.dispatch_strategy.r#type(),
            });
        }

        // Run the datalog program and group the derived requirements by fragment.
        let mut crepe = Crepe::new();
        crepe.extend(facts.into_iter().map(Input));
        let (reqs,) = crepe.run();
        let reqs = reqs
            .into_iter()
            .map(|Requirement(id, req)| (id, req))
            .into_group_map();

        let mut distributions = HashMap::new();
        for &id in graph.building_fragments().keys() {
            let dist = match reqs.get(&id) {
                // Merge all requirements of the fragment.
                Some(reqs) => {
                    let req = (reqs.iter().copied())
                        .try_reduce(|a, b| Req::merge(a, b, |id| all_hash_mappings[id].len()))
                        .with_context(|| {
                            format!("cannot fulfill scheduling requirements for fragment {id:?}")
                        })?
                        .unwrap();

                    // Map the merged requirement to a distribution.
                    match req {
                        Req::Singleton(worker_id) => Distribution::Singleton(worker_id),
                        Req::Hash(mapping) => {
                            Distribution::Hash(all_hash_mappings[mapping].clone())
                        }
                        Req::AnySingleton => Distribution::Singleton(self.default_singleton_worker),
                        Req::AnyVnodeCount(vnode_count) => {
                            let force_parallelism =
                                force_parallelism_fragment_ids.get(&id).copied();
                            let mapping = (self.dynamic_mapping_fn)(vnode_count, force_parallelism)
                                .with_context(|| {
                                    format!(
                                        "failed to build dynamic mapping for fragment {id:?} with vnode count {vnode_count}"
                                    )
                                })?;

                            Distribution::Hash(mapping)
                        }
                    }
                }
                // No requirement at all: fall back to the default hash distribution.
                None => Distribution::Hash(self.default_hash_mapping.clone()),
            };

            distributions.insert(id, dist);
        }

        tracing::debug!(?distributions, "schedule fragments");

        Ok(distributions)
    }
}

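/// The scheduled locations of all actors, together with the worker nodes they refer to.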
#[cfg_attr(test, derive(Default))]
pub struct Locations {
    /// The location (alignment id) of each actor.
    pub actor_locations: BTreeMap<ActorId, ActorAlignmentId>,
    /// The worker nodes that the actors are scheduled to.
    pub worker_locations: HashMap<WorkerId, WorkerNode>,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The expected scheduling result of a fragment in the tests below.
    #[derive(Debug)]
    enum Result {
        DefaultHash,
        Required(Req),
    }

    impl Result {
        #[allow(non_upper_case_globals)]
        const DefaultSingleton: Self = Self::Required(Req::AnySingleton);
    }

    fn run_and_merge(
        facts: impl IntoIterator<Item = Fact>,
        mapping_len: impl Fn(HashMappingId) -> usize,
    ) -> MetaResult<HashMap<Id, Req>> {
        let mut crepe = Crepe::new();
        crepe.extend(facts.into_iter().map(Input));
        let (reqs,) = crepe.run();

        let reqs = reqs
            .into_iter()
            .map(|Requirement(id, req)| (id, req))
            .into_group_map();

        let mut merged = HashMap::new();
        for (id, reqs) in reqs {
            let req = (reqs.iter().copied())
                .try_reduce(|a, b| Req::merge(a, b, &mapping_len))
                .with_context(|| {
                    format!("cannot fulfill scheduling requirements for fragment {id:?}")
                })?
                .unwrap();
            merged.insert(id, req);
        }

        Ok(merged)
    }

    fn test_success(facts: impl IntoIterator<Item = Fact>, expected: HashMap<Id, Result>) {
        test_success_with_mapping_len(facts, expected, |_| 0);
    }

    fn test_success_with_mapping_len(
        facts: impl IntoIterator<Item = Fact>,
        expected: HashMap<Id, Result>,
        mapping_len: impl Fn(HashMappingId) -> usize,
    ) {
        let reqs = run_and_merge(facts, mapping_len).unwrap();

        for (id, expected) in expected {
            match (reqs.get(&id), expected) {
                (None, Result::DefaultHash) => {}
                (Some(actual), Result::Required(expected)) if *actual == expected => {}
                (actual, expected) => panic!(
                    "unexpected result for fragment {id:?}\nactual: {actual:?}\nexpected: {expected:?}"
                ),
            }
        }
    }

    fn test_failed(facts: impl IntoIterator<Item = Fact>) {
        run_and_merge(facts, |_| 0).unwrap_err();
    }

    #[test]
    fn test_single_fragment_hash() {
        #[rustfmt::skip]
        let facts = [];

        let expected = maplit::hashmap! {
            101.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    #[test]
    fn test_single_fragment_singleton() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 101.into(), req: Req::AnySingleton },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::DefaultSingleton,
        };

        test_success(facts, expected);
    }

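    /// An MV-on-MV case: fragments 101 and 102 no-shuffle from the existing fragments 1 and 2
    /// and must inherit their distributions, while 103 and 104 are unconstrained except that
    /// 104 sits downstream of a `Simple` dispatcher and thus must be a singleton.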
    #[test]
    fn test_scheduling_mv_on_mv() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Singleton(0.into()) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 102.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 103.into(), to: 104.into(), dt: Simple },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::Required(Req::Singleton(0.into())),
            103.into() => Result::DefaultHash,
            104.into() => Result::DefaultSingleton,
        };

        test_success(facts, expected);
    }

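    /// A delta-join-like case: fragments 103 and 104 no-shuffle from 101 and 102 respectively,
    /// so the distributions of the existing upstreams 1 and 2 propagate two hops down, while
    /// the join output 105 stays unconstrained.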
    #[test]
    fn test_delta_join() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Hash(2) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 102.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: NoShuffle },
            Fact::Edge { from: 102.into(), to: 104.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 104.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 103.into(), to: 105.into(), dt: Hash },
            Fact::Edge { from: 104.into(), to: 105.into(), dt: Hash },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::Required(Req::Hash(2)),
            103.into() => Result::Required(Req::Hash(1)),
            104.into() => Result::Required(Req::Hash(2)),
            105.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

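    /// A singleton leaf (fragment 102) that broadcasts into the graph: its explicit singleton
    /// requirement is kept, and it does not constrain the hash-distributed downstream 103.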
    #[test]
    fn test_singleton_leaf() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Req { id: 102.into(), req: Req::AnySingleton },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Broadcast },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::DefaultSingleton,
            103.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

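    /// Two existing fragments with different hash mappings both no-shuffle into the same
    /// downstream fragment, which yields conflicting requirements and must fail.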
    #[test]
    fn test_upstream_hash_shard_failed() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Hash(2) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 101.into(), dt: NoShuffle },
        ];

        test_failed(facts);
    }

    #[test]
    fn test_arrangement_backfill_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: Hash },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::AnyVnodeCount(128)),
        };

        test_success(facts, expected);
    }

    #[test]
    fn test_no_shuffle_backfill_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
        };

        test_success_with_mapping_len(facts, expected, |id| {
            assert_eq!(id, 1);
            128
        });
    }

    #[test]
    fn test_no_shuffle_backfill_mismatched_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        // `test_failed` assumes a hash mapping length of 0, which mismatches the required
        // vnode count of 128, so the requirements cannot be merged.
        test_failed(facts);
    }

    #[test]
    fn test_backfill_singleton_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Singleton(0.into()) },
            Fact::Req { id: 101.into(), req: Req::AnySingleton },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Singleton(0.into())),
        };

        test_success(facts, expected);
    }
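
    /// A direct check of `Req::merge` on top of the graph-level tests above: compatible pairs
    /// merge in either argument order, incompatible pairs are rejected.
    #[test]
    fn test_req_merge() {
        let mapping_len = |id: HashMappingId| match id {
            1 => 128,
            _ => 0,
        };

        // An unpinned singleton requirement is refined by a pinned one.
        assert_eq!(
            Req::merge(Req::AnySingleton, Req::Singleton(0.into()), mapping_len).unwrap(),
            Req::Singleton(0.into())
        );
        // A vnode count requirement is satisfied by a hash mapping of the same length,
        // regardless of the argument order.
        assert_eq!(
            Req::merge(Req::Hash(1), Req::AnyVnodeCount(128), mapping_len).unwrap(),
            Req::Hash(1)
        );
        // Two different hash mappings can never be merged.
        Req::merge(Req::Hash(1), Req::Hash(2), mapping_len).unwrap_err();
    }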
}