#![allow(
    clippy::collapsible_if,
    clippy::explicit_iter_loop,
    reason = "generated by crepe"
)]

use std::collections::{BTreeMap, HashMap};
use std::num::NonZeroUsize;

use anyhow::Context;
use either::Either;
use enum_as_inner::EnumAsInner;
use itertools::Itertools;
use risingwave_common::hash::{
    ActorAlignmentId, ActorAlignmentMapping, ActorMapping, VnodeCountCompat,
};
use risingwave_common::util::stream_graph_visitor::visit_fragment;
use risingwave_common::{bail, hash};
use risingwave_meta_model::WorkerId;
use risingwave_pb::common::{ActorInfo, WorkerNode};
use risingwave_pb::meta::table_fragments::fragment::{
    FragmentDistributionType, PbFragmentDistributionType,
};
use risingwave_pb::stream_plan::DispatcherType::{self, *};

use crate::MetaResult;
use crate::model::{ActorId, Fragment};
use crate::stream::AssignerBuilder;
use crate::stream::stream_graph::fragment::CompleteStreamFragmentGraph;
use crate::stream::stream_graph::id::GlobalFragmentId as Id;

type HashMappingId = usize;

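/// A requirement on the distribution of a fragment, collected as facts and merged
/// per fragment during scheduling.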
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Req {
    /// The fragment must be a singleton scheduled to the given worker.
    Singleton(WorkerId),
    /// The fragment must be hash-distributed with the given existing mapping.
    Hash(HashMappingId),
    /// The fragment must have the given vnode count, regardless of the distribution.
    AnyVnodeCount(usize),
}

impl Req {
    /// A singleton requirement that is not bound to a specific worker, i.e., a vnode count of 1.
    #[allow(non_upper_case_globals)]
    const AnySingleton: Self = Self::AnyVnodeCount(1);

    /// Merge two requirements, or return an error if they are incompatible.
    fn merge(a: Self, b: Self, mapping_len: impl Fn(HashMappingId) -> usize) -> MetaResult<Self> {
        // Two requirements are compatible if one is a relaxed version of the other.
        let merge = |a, b| match (a, b) {
            (Self::AnySingleton, Self::Singleton(id)) => Some(Self::Singleton(id)),
            (Self::AnyVnodeCount(count), Self::Hash(id)) if mapping_len(id) == count => {
                Some(Self::Hash(id))
            }
            _ => None,
        };

        match merge(a, b).or_else(|| merge(b, a)) {
            Some(req) => Ok(req),
            None => bail!("incompatible requirements `{a:?}` and `{b:?}`"),
        }
    }
}

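/// Facts used as input to the `crepe` logic program below.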
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Fact {
    /// An edge in the fragment graph, with the type of the dispatcher on it.
    Edge {
        from: Id,
        to: Id,
        dt: DispatcherType,
    },
    /// A requirement on the distribution of a fragment, given as an input fact.
    Req { id: Id, req: Req },
}

crepe::crepe! {
    @input
    struct Input(Fact);

    struct Edge(Id, Id, DispatcherType);
    struct ExternalReq(Id, Req);

    @output
    struct Requirement(Id, Req);

    // Extract the input facts into relations.
    Edge(from, to, dt) <- Input(f), let Fact::Edge { from, to, dt } = f;
    Requirement(id, req) <- Input(f), let Fact::Req { id, req } = f;

    // The downstream fragment of a `Simple` dispatcher must be a singleton.
    Requirement(y, Req::AnySingleton) <- Edge(_, y, Simple);
    // Requirements propagate in both directions along `NoShuffle` edges.
    Requirement(x, d) <- Edge(x, y, NoShuffle), Requirement(y, d);
    Requirement(y, d) <- Edge(x, y, NoShuffle), Requirement(x, d);
}

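/// The distribution (scheduling result) of a fragment.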
#[derive(Debug, Clone, EnumAsInner)]
pub(super) enum Distribution {
    /// The fragment is a singleton and is scheduled to the given worker.
    Singleton(WorkerId),

    /// The fragment is hash-distributed according to the given actor alignment mapping.
    Hash(ActorAlignmentMapping),
}

impl Distribution {
    /// The parallelism of the distribution, i.e., the number of actors.
    pub fn parallelism(&self) -> usize {
        self.actors().count()
    }

    /// All actor alignment ids in the distribution.
    pub fn actors(&self) -> impl Iterator<Item = ActorAlignmentId> + '_ {
        match self {
            Distribution::Singleton(p) => {
                Either::Left(std::iter::once(ActorAlignmentId::new(*p as _, 0)))
            }
            Distribution::Hash(mapping) => Either::Right(mapping.iter_unique()),
        }
    }

    /// The vnode count corresponding to the distribution.
    pub fn vnode_count(&self) -> usize {
        match self {
            Distribution::Singleton(_) => 1,
            Distribution::Hash(mapping) => mapping.len(),
        }
    }

    /// Get the distribution of an existing fragment from its actors and their worker locations.
    pub fn from_fragment(fragment: &Fragment, actor_location: &HashMap<ActorId, WorkerId>) -> Self {
        match fragment.distribution_type {
            FragmentDistributionType::Unspecified => unreachable!(),
            FragmentDistributionType::Single => {
                let actor_id = fragment.actors.iter().exactly_one().unwrap().actor_id;
                let location = actor_location.get(&actor_id).unwrap();
                Distribution::Singleton(*location)
            }
            FragmentDistributionType::Hash => {
                let actor_bitmaps: HashMap<_, _> = fragment
                    .actors
                    .iter()
                    .map(|actor| {
                        (
                            actor.actor_id as hash::ActorId,
                            actor.vnode_bitmap.clone().unwrap(),
                        )
                    })
                    .collect();

                let actor_mapping = ActorMapping::from_bitmaps(&actor_bitmaps);
                let actor_location = actor_location
                    .iter()
                    .map(|(&k, &v)| (k, v as u32))
                    .collect();
                let mapping = actor_mapping.to_actor_alignment(&actor_location);

                Distribution::Hash(mapping)
            }
        }
    }

    /// Convert the distribution to the protobuf [`PbFragmentDistributionType`].
    pub fn to_distribution_type(&self) -> PbFragmentDistributionType {
        match self {
            Distribution::Singleton(_) => PbFragmentDistributionType::Single,
            Distribution::Hash(_) => PbFragmentDistributionType::Hash,
        }
    }
}

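/// [`Scheduler`] decides the distribution of each building fragment in the stream graph.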
pub(super) struct Scheduler {
    /// The default hash mapping for hash-distributed fragments without further requirements.
    default_hash_mapping: ActorAlignmentMapping,

    /// The default worker for singleton fragments without a specific worker requirement.
    default_singleton_worker: WorkerId,

    /// Builds a hash mapping limited to a given vnode count, for fragments that require one.
    dynamic_mapping_fn: Box<dyn Fn(usize) -> anyhow::Result<ActorAlignmentMapping>>,
}

impl Scheduler {
    /// Create a new [`Scheduler`] with the given workers and the default parallelism.
    ///
    /// Assignments are produced by an assigner built from the streaming job id.
    pub fn new(
        streaming_job_id: u32,
        workers: &HashMap<u32, WorkerNode>,
        default_parallelism: NonZeroUsize,
        expected_vnode_count: usize,
    ) -> MetaResult<Self> {
        let parallelism = default_parallelism.get();
        assert!(
            parallelism <= expected_vnode_count,
            "parallelism should be limited by vnode count in previous steps"
        );

        let assigner = AssignerBuilder::new(streaming_job_id).build();

        let worker_weights = workers
            .iter()
            .map(|(worker_id, worker)| {
                (
                    *worker_id,
                    NonZeroUsize::new(worker.compute_node_parallelism()).unwrap(),
                )
            })
            .collect();

        let actor_idxes = (0..parallelism).collect_vec();
        let vnodes = (0..expected_vnode_count).collect_vec();

        // Default mapping for hash-distributed fragments.
        let assignment = assigner.assign_hierarchical(&worker_weights, &actor_idxes, &vnodes)?;

        let default_hash_mapping =
            ActorAlignmentMapping::from_assignment(assignment, expected_vnode_count);

        // Default worker for singleton fragments: assign a single actor and take its worker.
        let single_actor_idxes = std::iter::once(0).collect_vec();

        let single_assignment =
            assigner.assign_hierarchical(&worker_weights, &single_actor_idxes, &vnodes)?;

        let default_singleton_worker =
            single_assignment.keys().exactly_one().cloned().unwrap() as _;

        // Mapping builder for fragments that require a specific (limited) vnode count.
        let dynamic_mapping_fn = Box::new(move |limited_count: usize| {
            let parallelism = parallelism.min(limited_count);

            let assignment = assigner.assign_hierarchical(
                &worker_weights,
                &(0..parallelism).collect_vec(),
                &(0..limited_count).collect_vec(),
            )?;

            let mapping = ActorAlignmentMapping::from_assignment(assignment, limited_count);
            Ok(mapping)
        });

        Ok(Self {
            default_hash_mapping,
            default_singleton_worker,
            dynamic_mapping_fn,
        })
    }

    /// Schedule the given complete graph and return the distribution of each building fragment.
    pub fn schedule(
        &self,
        graph: &CompleteStreamFragmentGraph,
    ) -> MetaResult<HashMap<Id, Distribution>> {
        let existing_distribution = graph.existing_distribution();

        // Number the existing hash mappings so that requirements can refer to them by index.
        let all_hash_mappings = existing_distribution
            .values()
            .flat_map(|dist| dist.as_hash())
            .cloned()
            .unique()
            .collect_vec();
        let hash_mapping_id: HashMap<_, _> = all_hash_mappings
            .iter()
            .enumerate()
            .map(|(i, m)| (m.clone(), i))
            .collect();

        let mut facts = Vec::new();

        // Singleton requirements of building fragments.
        for (&id, fragment) in graph.building_fragments() {
            if fragment.requires_singleton {
                facts.push(Fact::Req {
                    id,
                    req: Req::AnySingleton,
                });
            }
        }
        // Vnode count requirements: a fragment that scans or looks up an existing table
        // must use the same vnode count as that table.
        for (&id, fragment) in graph.building_fragments() {
            visit_fragment(fragment, |node| {
                use risingwave_pb::stream_plan::stream_node::NodeBody;
                let vnode_count = match node {
                    NodeBody::StreamScan(node) => {
                        if let Some(table) = &node.arrangement_table {
                            table.vnode_count()
                        } else if let Some(table) = &node.table_desc {
                            table.vnode_count()
                        } else {
                            return;
                        }
                    }
                    NodeBody::TemporalJoin(node) => node.get_table_desc().unwrap().vnode_count(),
                    NodeBody::BatchPlan(node) => node.get_table_desc().unwrap().vnode_count(),
                    NodeBody::Lookup(node) => node
                        .get_arrangement_table_info()
                        .unwrap()
                        .get_table_desc()
                        .unwrap()
                        .vnode_count(),
                    _ => return,
                };
                facts.push(Fact::Req {
                    id,
                    req: Req::AnyVnodeCount(vnode_count),
                });
            });
        }
        // Distributions of existing fragments.
        for (id, dist) in existing_distribution {
            let req = match dist {
                Distribution::Singleton(worker_id) => Req::Singleton(worker_id),
                Distribution::Hash(mapping) => Req::Hash(hash_mapping_id[&mapping]),
            };
            facts.push(Fact::Req { id, req });
        }
        // Edges in the graph.
        for (from, to, edge) in graph.all_edges() {
            facts.push(Fact::Edge {
                from,
                to,
                dt: edge.dispatch_strategy.r#type(),
            });
        }

        // Run the `crepe` logic program to derive all requirements.
        let mut crepe = Crepe::new();
        crepe.extend(facts.into_iter().map(Input));
        let (reqs,) = crepe.run();
        let reqs = reqs
            .into_iter()
            .map(|Requirement(id, req)| (id, req))
            .into_group_map();

        // Derive the distribution of each building fragment.
        let mut distributions = HashMap::new();
        for &id in graph.building_fragments().keys() {
            let dist = match reqs.get(&id) {
                // Merge all requirements of this fragment; they must be compatible.
                Some(reqs) => {
                    let req = (reqs.iter().copied())
                        .try_reduce(|a, b| Req::merge(a, b, |id| all_hash_mappings[id].len()))
                        .with_context(|| {
                            format!("cannot fulfill scheduling requirements for fragment {id:?}")
                        })?
                        .unwrap();

                    match req {
                        Req::Singleton(worker_id) => Distribution::Singleton(worker_id),
                        Req::Hash(mapping) => {
                            Distribution::Hash(all_hash_mappings[mapping].clone())
                        }
                        Req::AnySingleton => Distribution::Singleton(self.default_singleton_worker),
                        Req::AnyVnodeCount(vnode_count) => {
                            let mapping = (self.dynamic_mapping_fn)(vnode_count)
                                .with_context(|| {
                                    format!(
                                        "failed to build dynamic mapping for fragment {id:?} with vnode count {vnode_count}"
                                    )
                                })?;

                            Distribution::Hash(mapping)
                        }
                    }
                }
                // No requirement derived: use the default hash distribution.
                None => Distribution::Hash(self.default_hash_mapping.clone()),
            };

            distributions.insert(id, dist);
        }

        tracing::debug!(?distributions, "schedule fragments");

        Ok(distributions)
    }
}

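/// [`Locations`] records the location of each actor and the referenced worker nodes.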
#[cfg_attr(test, derive(Default))]
pub struct Locations {
    /// actor id => alignment id (which encodes the worker id)
    pub actor_locations: BTreeMap<ActorId, ActorAlignmentId>,
    /// worker id => worker node
    pub worker_locations: HashMap<WorkerId, WorkerNode>,
}

impl Locations {
    /// Returns all actors grouped by the worker node they are located on.
    pub fn worker_actors(&self) -> HashMap<WorkerId, Vec<ActorId>> {
        self.actor_locations
            .iter()
            .map(|(actor_id, alignment_id)| (alignment_id.worker_id() as WorkerId, *actor_id))
            .into_group_map()
    }

    /// Returns an iterator of `ActorInfo`, with the host of the worker each actor is located on.
    pub fn actor_infos(&self) -> impl Iterator<Item = ActorInfo> + '_ {
        self.actor_locations
            .iter()
            .map(|(actor_id, alignment_id)| ActorInfo {
                actor_id: *actor_id,
                host: self.worker_locations[&(alignment_id.worker_id() as WorkerId)]
                    .host
                    .clone(),
            })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The expected scheduling result of a fragment in the tests.
    #[derive(Debug)]
    enum Result {
        DefaultHash,
        Required(Req),
    }

    impl Result {
        #[allow(non_upper_case_globals)]
        const DefaultSingleton: Self = Self::Required(Req::AnySingleton);
    }

    fn run_and_merge(
        facts: impl IntoIterator<Item = Fact>,
        mapping_len: impl Fn(HashMappingId) -> usize,
    ) -> MetaResult<HashMap<Id, Req>> {
        let mut crepe = Crepe::new();
        crepe.extend(facts.into_iter().map(Input));
        let (reqs,) = crepe.run();

        let reqs = reqs
            .into_iter()
            .map(|Requirement(id, req)| (id, req))
            .into_group_map();

        let mut merged = HashMap::new();
        for (id, reqs) in reqs {
            let req = (reqs.iter().copied())
                .try_reduce(|a, b| Req::merge(a, b, &mapping_len))
                .with_context(|| {
                    format!("cannot fulfill scheduling requirements for fragment {id:?}")
                })?
                .unwrap();
            merged.insert(id, req);
        }

        Ok(merged)
    }

    fn test_success(facts: impl IntoIterator<Item = Fact>, expected: HashMap<Id, Result>) {
        test_success_with_mapping_len(facts, expected, |_| 0);
    }

    fn test_success_with_mapping_len(
        facts: impl IntoIterator<Item = Fact>,
        expected: HashMap<Id, Result>,
        mapping_len: impl Fn(HashMappingId) -> usize,
    ) {
        let reqs = run_and_merge(facts, mapping_len).unwrap();

        for (id, expected) in expected {
            match (reqs.get(&id), expected) {
                (None, Result::DefaultHash) => {}
                (Some(actual), Result::Required(expected)) if *actual == expected => {}
                (actual, expected) => panic!(
                    "unexpected result for fragment {id:?}\nactual: {actual:?}\nexpected: {expected:?}"
                ),
            }
        }
    }

    fn test_failed(facts: impl IntoIterator<Item = Fact>) {
        run_and_merge(facts, |_| 0).unwrap_err();
    }

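    /// A single fragment with no requirements falls back to the default hash distribution.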
    #[test]
    fn test_single_fragment_hash() {
        #[rustfmt::skip]
        let facts = [];

        let expected = maplit::hashmap! {
            101.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

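    /// A fragment requiring a singleton without a specific worker resolves to the default singleton.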
    #[test]
    fn test_single_fragment_singleton() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 101.into(), req: Req::AnySingleton },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::DefaultSingleton,
        };

        test_success(facts, expected);
    }

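    /// Building a job on top of existing fragments: the no-shuffle downstreams (101, 102) inherit
    /// the distributions of the existing upstreams (1, 2), while the hash and simple downstreams
    /// (103, 104) fall back to the defaults.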
    #[test]
    fn test_scheduling_mv_on_mv() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Singleton(0) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 102.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 103.into(), to: 104.into(), dt: Simple },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::Required(Req::Singleton(0)),
            103.into() => Result::DefaultHash,
            104.into() => Result::DefaultSingleton,
        };

        test_success(facts, expected);
    }

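    /// Delta join: distributions propagate only along the no-shuffle edges (1 -> 101 -> 103 and
    /// 2 -> 102 -> 104), while the hash edges do not constrain the union fragment (105).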
    #[test]
    fn test_delta_join() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Hash(2) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 102.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: NoShuffle },
            Fact::Edge { from: 102.into(), to: 104.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 104.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 103.into(), to: 105.into(), dt: Hash },
            Fact::Edge { from: 104.into(), to: 105.into(), dt: Hash },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::Required(Req::Hash(2)),
            103.into() => Result::Required(Req::Hash(1)),
            104.into() => Result::Required(Req::Hash(2)),
            105.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

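    /// A singleton leaf fragment (102) that broadcasts downstream stays a singleton and does not
    /// constrain its downstream (103).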
    #[test]
    fn test_singleton_leaf() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Req { id: 102.into(), req: Req::AnySingleton },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Broadcast },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::DefaultSingleton,
            103.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

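    /// A fragment with no-shuffle edges from two upstreams of different hash mappings cannot be scheduled.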
    #[test]
    fn test_upstream_hash_shard_failed() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Hash(2) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 101.into(), dt: NoShuffle },
        ];

        test_failed(facts);
    }

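    /// With a hash edge (as in arrangement backfill), the vnode count requirement of the backfill
    /// fragment is kept as-is and is not merged with the upstream's mapping.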
    #[test]
    fn test_arrangement_backfill_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: Hash },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::AnyVnodeCount(128)),
        };

        test_success(facts, expected);
    }

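    /// With a no-shuffle edge, the upstream mapping propagates to the backfill fragment and is
    /// compatible as long as its vnode count matches the required 128.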
    #[test]
    fn test_no_shuffle_backfill_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
        };

        test_success_with_mapping_len(facts, expected, |id| {
            assert_eq!(id, 1);
            128
        });
    }

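    /// Same facts as above, but the upstream mapping's vnode count does not match the required 128,
    /// so scheduling fails.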
    #[test]
    fn test_no_shuffle_backfill_mismatched_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        // `test_failed` uses a mapping length of 0, which does not match the required 128.
        test_failed(facts);
    }

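    /// A singleton backfill fragment under a no-shuffle edge inherits the upstream's worker.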
    #[test]
    fn test_backfill_singleton_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Singleton(0) },
            Fact::Req { id: 101.into(), req: Req::AnySingleton },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Singleton(0)),
        };

        test_success(facts, expected);
    }
}
697}