1#![allow(
16 clippy::collapsible_if,
17 clippy::explicit_iter_loop,
18 reason = "generated by crepe"
19)]
20
21use std::collections::{BTreeMap, HashMap};
22use std::num::NonZeroUsize;
23
24use anyhow::Context;
25use either::Either;
26use enum_as_inner::EnumAsInner;
27use itertools::Itertools;
28use risingwave_common::hash::{ActorMapping, VnodeCountCompat, WorkerSlotId, WorkerSlotMapping};
29use risingwave_common::util::stream_graph_visitor::visit_fragment;
30use risingwave_common::{bail, hash};
31use risingwave_meta_model::WorkerId;
32use risingwave_pb::common::{ActorInfo, WorkerNode};
33use risingwave_pb::meta::table_fragments::fragment::{
34 FragmentDistributionType, PbFragmentDistributionType,
35};
36use risingwave_pb::stream_plan::DispatcherType::{self, *};
37
38use crate::MetaResult;
39use crate::model::{ActorId, Fragment};
40use crate::stream::schedule_units_for_slots;
41use crate::stream::stream_graph::fragment::CompleteStreamFragmentGraph;
42use crate::stream::stream_graph::id::GlobalFragmentId as Id;
43
44type HashMappingId = usize;
45
46#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
48enum Req {
49 Singleton(WorkerSlotId),
51 Hash(HashMappingId),
53 AnyVnodeCount(usize),
56}
57
58impl Req {
59 #[allow(non_upper_case_globals)]
61 const AnySingleton: Self = Self::AnyVnodeCount(1);
62
63 fn merge(a: Self, b: Self, mapping_len: impl Fn(HashMappingId) -> usize) -> MetaResult<Self> {
67 let merge = |a, b| match (a, b) {
69 (Self::AnySingleton, Self::Singleton(id)) => Some(Self::Singleton(id)),
70 (Self::AnyVnodeCount(count), Self::Hash(id)) if mapping_len(id) == count => {
71 Some(Self::Hash(id))
72 }
73 _ => None,
74 };
75
76 match merge(a, b).or_else(|| merge(b, a)) {
77 Some(req) => Ok(req),
78 None => bail!("incompatible requirements `{a:?}` and `{b:?}`"),
79 }
80 }
81}
82
83#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
85enum Fact {
86 Edge {
88 from: Id,
89 to: Id,
90 dt: DispatcherType,
91 },
92 Req { id: Id, req: Req },
94}
95
// Datalog program (evaluated by `crepe`) deriving every requirement that applies to each
// fragment from the input facts. `Requirement` is the only output relation.
crepe::crepe! {
    @input
    struct Input(Fact);

    struct Edge(Id, Id, DispatcherType);
    // NOTE(review): `ExternalReq` is declared but no rule in this file reads or writes it.
    struct ExternalReq(Id, Req);

    @output
    struct Requirement(Id, Req);

    // Unpack input facts into edges and directly stated requirements.
    Edge(from, to, dt) <- Input(f), let Fact::Edge { from, to, dt } = f;
    Requirement(id, req) <- Input(f), let Fact::Req { id, req } = f;

    // The downstream fragment of a `Simple` edge must be a singleton.
    Requirement(y, Req::AnySingleton) <- Edge(_, y, Simple);
    // Requirements propagate across `NoShuffle` edges in both directions, so both ends of
    // such an edge end up with the same set of requirements.
    Requirement(x, d) <- Edge(x, y, NoShuffle), Requirement(y, d);
    Requirement(y, d) <- Edge(x, y, NoShuffle), Requirement(x, d);
}
116
117#[derive(Debug, Clone, EnumAsInner)]
119pub(super) enum Distribution {
120 Singleton(WorkerSlotId),
122
123 Hash(WorkerSlotMapping),
125}
126
127impl Distribution {
128 pub fn parallelism(&self) -> usize {
130 self.worker_slots().count()
131 }
132
133 pub fn worker_slots(&self) -> impl Iterator<Item = WorkerSlotId> + '_ {
135 match self {
136 Distribution::Singleton(p) => Either::Left(std::iter::once(*p)),
137 Distribution::Hash(mapping) => Either::Right(mapping.iter_unique()),
138 }
139 }
140
141 pub fn vnode_count(&self) -> usize {
143 match self {
144 Distribution::Singleton(_) => 1, Distribution::Hash(mapping) => mapping.len(),
146 }
147 }
148
149 pub fn from_fragment(fragment: &Fragment, actor_location: &HashMap<ActorId, WorkerId>) -> Self {
151 match fragment.distribution_type {
152 FragmentDistributionType::Unspecified => unreachable!(),
153 FragmentDistributionType::Single => {
154 let actor_id = fragment.actors.iter().exactly_one().unwrap().actor_id;
155 let location = actor_location.get(&actor_id).unwrap();
156 let worker_slot_id = WorkerSlotId::new(*location as _, 0);
157 Distribution::Singleton(worker_slot_id)
158 }
159 FragmentDistributionType::Hash => {
160 let actor_bitmaps: HashMap<_, _> = fragment
161 .actors
162 .iter()
163 .map(|actor| {
164 (
165 actor.actor_id as hash::ActorId,
166 actor.vnode_bitmap.clone().unwrap(),
167 )
168 })
169 .collect();
170
171 let actor_mapping = ActorMapping::from_bitmaps(&actor_bitmaps);
172 let actor_location = actor_location
173 .iter()
174 .map(|(&k, &v)| (k, v as u32))
175 .collect();
176 let mapping = actor_mapping.to_worker_slot(&actor_location);
177
178 Distribution::Hash(mapping)
179 }
180 }
181 }
182
183 pub fn to_distribution_type(&self) -> PbFragmentDistributionType {
185 match self {
186 Distribution::Singleton(_) => PbFragmentDistributionType::Single,
187 Distribution::Hash(_) => PbFragmentDistributionType::Hash,
188 }
189 }
190}
191
192pub(super) struct Scheduler {
194 scheduled_worker_slots: Vec<WorkerSlotId>,
196
197 default_hash_mapping: WorkerSlotMapping,
199
200 default_singleton_worker_slot: WorkerSlotId,
202}
203
204impl Scheduler {
205 pub fn new(
213 streaming_job_id: u32,
214 workers: &HashMap<u32, WorkerNode>,
215 default_parallelism: NonZeroUsize,
216 expected_vnode_count: usize,
217 ) -> MetaResult<Self> {
218 let slots = workers
221 .iter()
222 .map(|(worker_id, worker)| (*worker_id as WorkerId, worker.compute_node_parallelism()))
223 .collect();
224
225 let parallelism = default_parallelism.get();
226 assert!(
227 parallelism <= expected_vnode_count,
228 "parallelism should be limited by vnode count in previous steps"
229 );
230
231 let scheduled = schedule_units_for_slots(&slots, parallelism, streaming_job_id)?;
232
233 let scheduled_worker_slots = scheduled
234 .into_iter()
235 .flat_map(|(worker_id, size)| {
236 (0..size).map(move |slot| WorkerSlotId::new(worker_id as _, slot))
237 })
238 .collect_vec();
239
240 assert_eq!(scheduled_worker_slots.len(), parallelism);
241
242 let default_hash_mapping =
244 WorkerSlotMapping::build_from_ids(&scheduled_worker_slots, expected_vnode_count);
245
246 let single_scheduled = schedule_units_for_slots(&slots, 1, streaming_job_id)?;
247 let default_single_worker_id = single_scheduled.keys().exactly_one().cloned().unwrap();
248 let default_singleton_worker_slot = WorkerSlotId::new(default_single_worker_id as _, 0);
249
250 Ok(Self {
251 scheduled_worker_slots,
252 default_hash_mapping,
253 default_singleton_worker_slot,
254 })
255 }
256
257 pub fn schedule(
260 &self,
261 graph: &CompleteStreamFragmentGraph,
262 ) -> MetaResult<HashMap<Id, Distribution>> {
263 let existing_distribution = graph.existing_distribution();
264
265 let all_hash_mappings = existing_distribution
267 .values()
268 .flat_map(|dist| dist.as_hash())
269 .cloned()
270 .unique()
271 .collect_vec();
272 let hash_mapping_id: HashMap<_, _> = all_hash_mappings
273 .iter()
274 .enumerate()
275 .map(|(i, m)| (m.clone(), i))
276 .collect();
277
278 let mut facts = Vec::new();
279
280 for (&id, fragment) in graph.building_fragments() {
282 if fragment.requires_singleton {
283 facts.push(Fact::Req {
284 id,
285 req: Req::AnySingleton,
286 });
287 }
288 }
289 for (&id, fragment) in graph.building_fragments() {
292 visit_fragment(fragment, |node| {
293 use risingwave_pb::stream_plan::stream_node::NodeBody;
294 let vnode_count = match node {
295 NodeBody::StreamScan(node) => {
296 if let Some(table) = &node.arrangement_table {
297 table.vnode_count()
298 } else if let Some(table) = &node.table_desc {
299 table.vnode_count()
300 } else {
301 return;
302 }
303 }
304 NodeBody::TemporalJoin(node) => node.get_table_desc().unwrap().vnode_count(),
305 NodeBody::BatchPlan(node) => node.get_table_desc().unwrap().vnode_count(),
306 NodeBody::Lookup(node) => node
307 .get_arrangement_table_info()
308 .unwrap()
309 .get_table_desc()
310 .unwrap()
311 .vnode_count(),
312 _ => return,
313 };
314 facts.push(Fact::Req {
315 id,
316 req: Req::AnyVnodeCount(vnode_count),
317 });
318 });
319 }
320 for (id, dist) in existing_distribution {
322 let req = match dist {
323 Distribution::Singleton(worker_slot_id) => Req::Singleton(worker_slot_id),
324 Distribution::Hash(mapping) => Req::Hash(hash_mapping_id[&mapping]),
325 };
326 facts.push(Fact::Req { id, req });
327 }
328 for (from, to, edge) in graph.all_edges() {
330 facts.push(Fact::Edge {
331 from,
332 to,
333 dt: edge.dispatch_strategy.r#type(),
334 });
335 }
336
337 let mut crepe = Crepe::new();
339 crepe.extend(facts.into_iter().map(Input));
340 let (reqs,) = crepe.run();
341 let reqs = reqs
342 .into_iter()
343 .map(|Requirement(id, req)| (id, req))
344 .into_group_map();
345
346 let mut distributions = HashMap::new();
348 for &id in graph.building_fragments().keys() {
349 let dist = match reqs.get(&id) {
350 Some(reqs) => {
352 let req = (reqs.iter().copied())
353 .try_reduce(|a, b| Req::merge(a, b, |id| all_hash_mappings[id].len()))
354 .with_context(|| {
355 format!("cannot fulfill scheduling requirements for fragment {id:?}")
356 })?
357 .unwrap();
358
359 match req {
361 Req::Singleton(worker_slot) => Distribution::Singleton(worker_slot),
362 Req::Hash(mapping) => {
363 Distribution::Hash(all_hash_mappings[mapping].clone())
364 }
365 Req::AnySingleton => {
366 Distribution::Singleton(self.default_singleton_worker_slot)
367 }
368 Req::AnyVnodeCount(vnode_count) => {
369 let len = self.scheduled_worker_slots.len().min(vnode_count);
370 let mapping = WorkerSlotMapping::build_from_ids(
371 &self.scheduled_worker_slots[..len],
372 vnode_count,
373 );
374 Distribution::Hash(mapping)
375 }
376 }
377 }
378 None => Distribution::Hash(self.default_hash_mapping.clone()),
380 };
381
382 distributions.insert(id, dist);
383 }
384
385 tracing::debug!(?distributions, "schedule fragments");
386
387 Ok(distributions)
388 }
389}
390
391#[cfg_attr(test, derive(Default))]
393pub struct Locations {
394 pub actor_locations: BTreeMap<ActorId, WorkerSlotId>,
396 pub worker_locations: HashMap<WorkerId, WorkerNode>,
398}
399
400impl Locations {
401 pub fn worker_actors(&self) -> HashMap<WorkerId, Vec<ActorId>> {
403 self.actor_locations
404 .iter()
405 .map(|(actor_id, worker_slot_id)| (worker_slot_id.worker_id() as WorkerId, *actor_id))
406 .into_group_map()
407 }
408
409 pub fn actor_infos(&self) -> impl Iterator<Item = ActorInfo> + '_ {
411 self.actor_locations
412 .iter()
413 .map(|(actor_id, worker_slot_id)| ActorInfo {
414 actor_id: *actor_id,
415 host: self.worker_locations[&(worker_slot_id.worker_id() as WorkerId)]
416 .host
417 .clone(),
418 })
419 }
420}
421
// Unit tests for the requirement derivation (Datalog program) and `Req::merge`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected outcome for one fragment in a test scenario.
    #[derive(Debug)]
    enum Result {
        /// No requirement reaches the fragment, so it would get the default hash mapping.
        DefaultHash,
        /// The fragment's merged requirement must equal this one.
        Required(Req),
    }

    impl Result {
        /// The fragment must end up with the "any singleton" requirement.
        #[allow(non_upper_case_globals)]
        const DefaultSingleton: Self = Self::Required(Req::AnySingleton);
    }

    /// Runs the Datalog program on `facts` and merges the requirements of each fragment.
    /// This mirrors the merge step of `Scheduler::schedule`.
    ///
    /// `mapping_len` gives the vnode count of each existing hash mapping id.
    fn run_and_merge(
        facts: impl IntoIterator<Item = Fact>,
        mapping_len: impl Fn(HashMappingId) -> usize,
    ) -> MetaResult<HashMap<Id, Req>> {
        let mut crepe = Crepe::new();
        crepe.extend(facts.into_iter().map(Input));
        let (reqs,) = crepe.run();

        let reqs = reqs
            .into_iter()
            .map(|Requirement(id, req)| (id, req))
            .into_group_map();

        let mut merged = HashMap::new();
        for (id, reqs) in reqs {
            // Each group is non-empty, so a successful reduction is always `Some`.
            let req = (reqs.iter().copied())
                .try_reduce(|a, b| Req::merge(a, b, &mapping_len))
                .with_context(|| {
                    format!("cannot fulfill scheduling requirements for fragment {id:?}")
                })?
                .unwrap();
            merged.insert(id, req);
        }

        Ok(merged)
    }

    /// Asserts success with the expected results, treating every existing hash mapping as
    /// having 0 vnodes.
    fn test_success(facts: impl IntoIterator<Item = Fact>, expected: HashMap<Id, Result>) {
        test_success_with_mapping_len(facts, expected, |_| 0);
    }

    /// Asserts that merging succeeds and that every fragment listed in `expected` gets the
    /// expected result. Fragments not listed in `expected` are not checked.
    fn test_success_with_mapping_len(
        facts: impl IntoIterator<Item = Fact>,
        expected: HashMap<Id, Result>,
        mapping_len: impl Fn(HashMappingId) -> usize,
    ) {
        let reqs = run_and_merge(facts, mapping_len).unwrap();

        for (id, expected) in expected {
            match (reqs.get(&id), expected) {
                (None, Result::DefaultHash) => {}
                (Some(actual), Result::Required(expected)) if *actual == expected => {}
                (actual, expected) => panic!(
                    "unexpected result for fragment {id:?}\nactual: {actual:?}\nexpected: {expected:?}"
                ),
            }
        }
    }

    /// Asserts that merging fails for some fragment, treating every existing hash mapping
    /// as having 0 vnodes.
    fn test_failed(facts: impl IntoIterator<Item = Fact>) {
        run_and_merge(facts, |_| 0).unwrap_err();
    }

    #[test]
    /// A fragment without any requirement gets the default hash distribution.
    fn test_single_fragment_hash() {
        #[rustfmt::skip]
        let facts = [];

        let expected = maplit::hashmap! {
            101.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    #[test]
    /// A fragment that requires a singleton gets the "any singleton" requirement.
    fn test_single_fragment_singleton() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 101.into(), req: Req::AnySingleton },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::DefaultSingleton,
        };

        test_success(facts, expected);
    }

    #[test]
    /// Fragments 101 and 102 are `NoShuffle`-downstream of existing fragments 1 (hash) and
    /// 2 (singleton), so they inherit those distributions.
    /// Fragment 103 only has `Hash` inputs, so it stays default hash.
    /// Fragment 104 is downstream of a `Simple` edge, so it must be a singleton.
    fn test_scheduling_mv_on_mv() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Singleton(WorkerSlotId::new(0, 2)) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 102.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 103.into(), to: 104.into(), dt: Simple },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::Required(Req::Singleton(WorkerSlotId::new(0, 2))),
            103.into() => Result::DefaultHash,
            104.into() => Result::DefaultSingleton,
        };

        test_success(facts, expected);
    }

    #[test]
    /// Delta-join shape: requirements propagate through chains of `NoShuffle` edges
    /// (1 -> 101 -> 103 and 2 -> 102 -> 104) but not through `Hash` edges.
    fn test_delta_join() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Hash(2) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 102.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: NoShuffle },
            Fact::Edge { from: 102.into(), to: 104.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 104.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 103.into(), to: 105.into(), dt: Hash },
            Fact::Edge { from: 104.into(), to: 105.into(), dt: Hash },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::Required(Req::Hash(2)),
            103.into() => Result::Required(Req::Hash(1)),
            104.into() => Result::Required(Req::Hash(2)),
            105.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    #[test]
    /// A singleton source feeding a hash-distributed fragment via `Broadcast` does not make
    /// the downstream fragment singleton.
    fn test_singleton_leaf() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Req { id: 102.into(), req: Req::AnySingleton },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Broadcast },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::DefaultSingleton,
            103.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    #[test]
    /// A fragment `NoShuffle`-downstream of two different hash mappings cannot be scheduled.
    fn test_upstream_hash_shard_failed() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Hash(2) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 101.into(), dt: NoShuffle },
        ];

        test_failed(facts);
    }

    #[test]
    /// Across a `Hash` edge the upstream mapping does not propagate, so the vnode-count
    /// requirement is kept as-is.
    fn test_arrangement_backfill_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: Hash },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::AnyVnodeCount(128)),
        };

        test_success(facts, expected);
    }

    #[test]
    /// Across a `NoShuffle` edge, a vnode-count requirement merges with an upstream hash
    /// mapping of the same length into that mapping.
    fn test_no_shuffle_backfill_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
        };

        test_success_with_mapping_len(facts, expected, |id| {
            assert_eq!(id, 1);
            128
        });
    }

    #[test]
    /// Same shape as above, but `test_failed` reports every mapping length as 0, which
    /// does not match the required 128 vnodes.
    fn test_no_shuffle_backfill_mismatched_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        test_failed(facts);
    }

    #[test]
    /// "Any singleton" merges with a concrete upstream singleton slot reached via
    /// `NoShuffle`.
    fn test_backfill_singleton_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Singleton(WorkerSlotId::new(0, 2)) },
            Fact::Req { id: 101.into(), req: Req::AnySingleton },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Singleton(WorkerSlotId::new(0, 2))),
        };

        test_success(facts, expected);
    }
}