1#![allow(
16 clippy::collapsible_if,
17 clippy::explicit_iter_loop,
18 reason = "generated by crepe"
19)]
20
21use std::collections::{BTreeMap, HashMap};
22use std::num::NonZeroUsize;
23
24use anyhow::Context;
25use enum_as_inner::EnumAsInner;
26use itertools::Itertools;
27use risingwave_common::bail;
28use risingwave_common::hash::{ActorAlignmentId, VnodeCountCompat};
29use risingwave_common::id::JobId;
30use risingwave_common::util::stream_graph_visitor::visit_fragment;
31use risingwave_connector::source::cdc::{CDC_BACKFILL_MAX_PARALLELISM, CdcScanOptions};
32use risingwave_meta_model::WorkerId;
33use risingwave_pb::common::WorkerNode;
34use risingwave_pb::meta::table_fragments::fragment::{
35 FragmentDistributionType, PbFragmentDistributionType,
36};
37use risingwave_pb::stream_plan::DispatcherType::{self, *};
38
39use crate::MetaResult;
40use crate::model::{ActorId, Fragment};
41use crate::stream::stream_graph::fragment::CompleteStreamFragmentGraph;
42use crate::stream::stream_graph::id::GlobalFragmentId as Id;
43
/// Index into the deduplicated list of existing hash distributions collected in
/// `Scheduler::schedule` (`all_hash_mappings`), so requirements stay `Copy`able.
type HashMappingId = usize;
45
/// A distribution requirement on a fragment, used as a fact in the datalog
/// program below and merged per-fragment by [`Req::merge`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Req {
    /// The fragment must be a singleton.
    Singleton,
    /// The fragment must use the existing hash distribution identified by this id.
    Hash(HashMappingId),
    /// The fragment may use any distribution, as long as its vnode count matches.
    /// Compatible with a `Hash` requirement of the same mapping length (see `merge`).
    AnyVnodeCount(usize),
}
57
58impl Req {
59 #[allow(non_upper_case_globals)]
61 const AnySingleton: Self = Self::AnyVnodeCount(1);
62
63 fn merge(a: Self, b: Self, mapping_len: impl Fn(HashMappingId) -> usize) -> MetaResult<Self> {
67 let merge = |a, b| match (a, b) {
69 (Self::AnySingleton, Self::Singleton) => Some(Self::Singleton),
70 (Self::AnyVnodeCount(count), Self::Hash(id)) if mapping_len(id) == count => {
71 Some(Self::Hash(id))
72 }
73 _ => None,
74 };
75
76 match merge(a, b).or_else(|| merge(b, a)) {
77 Some(req) => Ok(req),
78 None => bail!("incompatible requirements `{a:?}` and `{b:?}`"),
79 }
80 }
81}
82
/// An input fact for the datalog program below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Fact {
    /// An edge between two fragments, carrying its dispatcher type.
    Edge {
        from: Id,
        to: Id,
        dt: DispatcherType,
    },
    /// A distribution requirement imposed on a fragment.
    Req { id: Id, req: Req },
}
95
crepe::crepe! {
    @input
    struct Input(Fact);

    // Relations unpacked from the input facts.
    struct Edge(Id, Id, DispatcherType);
    struct ExternalReq(Id, Req);

    @output
    struct Requirement(Id, Req);

    // Project each input fact into its relation.
    Edge(from, to, dt) <- Input(f), let Fact::Edge { from, to, dt } = f;
    Requirement(id, req) <- Input(f), let Fact::Req { id, req } = f;

    // The downstream side of a `Simple` dispatcher must hold exactly one vnode.
    Requirement(y, Req::AnySingleton) <- Edge(_, y, Simple);
    // `NoShuffle` edges propagate requirements in both directions.
    Requirement(x, d) <- Edge(x, y, NoShuffle), Requirement(y, d);
    Requirement(y, d) <- Edge(x, y, NoShuffle), Requirement(x, d);
}
116
/// The resolved distribution of a fragment.
#[derive(Debug, Clone, EnumAsInner)]
pub(super) enum Distribution {
    /// A single actor.
    Singleton,

    /// Hash-distributed over the given number of vnodes.
    Hash(usize),
}
126
127impl Distribution {
128 pub fn vnode_count(&self) -> usize {
130 match self {
131 Distribution::Singleton => 1, Distribution::Hash(vnode_count) => *vnode_count,
133 }
134 }
135
136 pub fn from_fragment(fragment: &Fragment) -> Self {
138 match fragment.distribution_type {
139 FragmentDistributionType::Single => Distribution::Singleton,
140 FragmentDistributionType::Hash => Distribution::Hash(fragment.vnode_count()),
141 PbFragmentDistributionType::Unspecified => {
142 unreachable!()
143 }
144 }
145 }
146
147 pub fn to_distribution_type(&self) -> PbFragmentDistributionType {
149 match self {
150 Distribution::Singleton => PbFragmentDistributionType::Single,
151 Distribution::Hash(_) => PbFragmentDistributionType::Hash,
152 }
153 }
154}
155
/// Schedules a [`Distribution`] for every building fragment of a stream graph.
pub(super) struct Scheduler {
    // Vnode count used for fragments that end up with no requirement at all.
    default_vnode_count: usize,
}
161
impl Scheduler {
    /// Create a new [`Scheduler`].
    ///
    /// `expected_vnode_count` is recorded as the default vnode count for
    /// unconstrained fragments. `default_parallelism` must already have been
    /// capped by the vnode count in earlier steps; this is asserted here.
    pub fn new(
        streaming_job_id: JobId,
        default_parallelism: NonZeroUsize,
        expected_vnode_count: usize,
    ) -> MetaResult<Self> {
        let parallelism = default_parallelism.get();
        assert!(
            parallelism <= expected_vnode_count,
            "parallelism should be limited by vnode count in previous steps for job {streaming_job_id}"
        );

        Ok(Self {
            default_vnode_count: expected_vnode_count,
        })
    }

    /// Schedule a [`Distribution`] for every building fragment in `graph`.
    ///
    /// Steps:
    /// 1. Collect [`Fact`]s: singleton flags, vnode counts pinned by tables the
    ///    fragment reads, requirements inherited from existing fragments, and
    ///    all edges of the complete graph.
    /// 2. Run the datalog program to propagate requirements along `Simple` and
    ///    `NoShuffle` edges.
    /// 3. Merge the requirements per fragment via [`Req::merge`], falling back
    ///    to `Hash(default_vnode_count)` when a fragment has no requirement.
    pub fn schedule(
        &self,
        graph: &CompleteStreamFragmentGraph,
    ) -> MetaResult<HashMap<Id, Distribution>> {
        let existing_distribution = graph.existing_distribution();

        // Deduplicate existing hash distributions and index them so that a
        // `Copy`able `HashMappingId` can be embedded in the facts.
        let all_hash_mappings = existing_distribution
            .values()
            .flat_map(|dist| dist.as_hash())
            .cloned()
            .unique()
            .collect_vec();
        let hash_mapping_id: HashMap<_, _> = all_hash_mappings
            .iter()
            .enumerate()
            .map(|(i, m)| (*m, i))
            .collect();

        let mut facts = Vec::new();

        // Fragments explicitly flagged as singleton.
        for (&id, fragment) in graph.building_fragments() {
            if fragment.requires_singleton {
                facts.push(Fact::Req {
                    id,
                    req: Req::AnySingleton,
                });
            }
        }
        // NOTE(review): this map is inserted into below but never read within
        // this function — confirm whether it is still needed or is dead code.
        let mut force_parallelism_fragment_ids: HashMap<_, _> = HashMap::default();
        // Fragments whose vnode count is pinned by a table they access.
        for (&id, fragment) in graph.building_fragments() {
            visit_fragment(fragment, |node| {
                use risingwave_pb::stream_plan::stream_node::NodeBody;
                let vnode_count = match node {
                    NodeBody::StreamScan(node) => {
                        // Prefer the arrangement table; otherwise the upstream
                        // table descriptor. Neither present: no constraint.
                        if let Some(table) = &node.arrangement_table {
                            table.vnode_count()
                        } else if let Some(table) = &node.table_desc {
                            table.vnode_count()
                        } else {
                            return;
                        }
                    }
                    NodeBody::TemporalJoin(node) => node.get_table_desc().unwrap().vnode_count(),
                    NodeBody::BatchPlan(node) => node.get_table_desc().unwrap().vnode_count(),
                    NodeBody::Lookup(node) => node
                        .get_arrangement_table_info()
                        .unwrap()
                        .get_table_desc()
                        .unwrap()
                        .vnode_count(),
                    NodeBody::StreamCdcScan(node) => {
                        let Some(ref options) = node.options else {
                            return;
                        };
                        let options = CdcScanOptions::from_proto(options);
                        // Parallelized CDC backfill pins the vnode count to the
                        // CDC backfill maximum; otherwise no constraint.
                        if options.is_parallelized_backfill() {
                            force_parallelism_fragment_ids
                                .insert(id, options.backfill_parallelism as usize);
                            CDC_BACKFILL_MAX_PARALLELISM as usize
                        } else {
                            return;
                        }
                    }
                    NodeBody::GapFill(node) => {
                        let buffer_table = node.get_state_table().unwrap();
                        // Only constrain if the state table has a vnode count set.
                        if let Some(vnode_count) = buffer_table.vnode_count_inner().value_opt() {
                            vnode_count
                        } else {
                            return;
                        }
                    }
                    _ => return,
                };
                facts.push(Fact::Req {
                    id,
                    req: Req::AnyVnodeCount(vnode_count),
                });
            });
        }
        // Requirements pinned by existing (already-scheduled) fragments.
        for (id, dist) in existing_distribution {
            let req = match dist {
                Distribution::Singleton => Req::Singleton,
                Distribution::Hash(mapping) => Req::Hash(hash_mapping_id[&mapping]),
            };
            facts.push(Fact::Req { id, req });
        }
        // All edges, carrying their dispatcher type.
        for (from, to, edge) in graph.all_edges() {
            facts.push(Fact::Edge {
                from,
                to,
                dt: edge.dispatch_strategy.r#type(),
            });
        }

        // Run the datalog program to propagate requirements through the graph.
        let mut crepe = Crepe::new();
        crepe.extend(facts.into_iter().map(Input));
        let (reqs,) = crepe.run();
        let reqs = reqs
            .into_iter()
            .map(|Requirement(id, req)| (id, req))
            .into_group_map();

        // Merge per-fragment requirements into a concrete distribution.
        let mut distributions = HashMap::new();
        for &id in graph.building_fragments().keys() {
            let dist = match reqs.get(&id) {
                Some(reqs) => {
                    // `reqs` is non-empty (it came from a group map), so
                    // `try_reduce` yields `Some` on success.
                    let req = (reqs.iter().copied())
                        .try_reduce(|a, b| Req::merge(a, b, |id| all_hash_mappings[id]))
                        .with_context(|| {
                            format!("cannot fulfill scheduling requirements for fragment {id:?}")
                        })?
                        .unwrap();

                    match req {
                        Req::Singleton => Distribution::Singleton,
                        Req::Hash(mapping) => Distribution::Hash(all_hash_mappings[mapping]),
                        Req::AnySingleton => Distribution::Singleton,
                        Req::AnyVnodeCount(vnode_count) => Distribution::Hash(vnode_count),
                    }
                }
                // Unconstrained: use the default vnode count.
                None => Distribution::Hash(self.default_vnode_count),
            };

            distributions.insert(id, dist);
        }

        tracing::debug!(?distributions, "schedule fragments");

        Ok(distributions)
    }
}
333
/// Locations of actors and workers for a scheduled streaming job.
#[cfg_attr(test, derive(Default))]
pub struct Locations {
    // Alignment id of each actor — presumably determines its worker placement;
    // see `ActorAlignmentId` in `risingwave_common::hash` for semantics.
    pub actor_locations: BTreeMap<ActorId, ActorAlignmentId>,
    // Worker nodes referenced by the locations above, keyed by worker id.
    pub worker_locations: HashMap<WorkerId, WorkerNode>,
}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected scheduling result for one fragment in the assertions below.
    #[derive(Debug)]
    enum Result {
        /// No requirement derived — the fragment is absent from the merged map
        /// and would fall back to the default hash distribution.
        DefaultHash,
        /// Exactly this requirement must have been derived.
        Required(Req),
    }

    impl Result {
        // Mirrors `Req::AnySingleton`: a singleton without a pinned mapping.
        #[allow(non_upper_case_globals)]
        const DefaultSingleton: Self = Self::Required(Req::AnySingleton);
    }

    /// Run the datalog program on `facts` and merge the derived requirements
    /// per fragment, mirroring the merge step of `Scheduler::schedule`.
    fn run_and_merge(
        facts: impl IntoIterator<Item = Fact>,
        mapping_len: impl Fn(HashMappingId) -> usize,
    ) -> MetaResult<HashMap<Id, Req>> {
        let mut crepe = Crepe::new();
        crepe.extend(facts.into_iter().map(Input));
        let (reqs,) = crepe.run();

        let reqs = reqs
            .into_iter()
            .map(|Requirement(id, req)| (id, req))
            .into_group_map();

        let mut merged = HashMap::new();
        for (id, reqs) in reqs {
            // Each group is non-empty, so `try_reduce` yields `Some` on success.
            let req = (reqs.iter().copied())
                .try_reduce(|a, b| Req::merge(a, b, &mapping_len))
                .with_context(|| {
                    format!("cannot fulfill scheduling requirements for fragment {id:?}")
                })?
                .unwrap();
            merged.insert(id, req);
        }

        Ok(merged)
    }

    /// Assert success with a dummy mapping length of 0 (used by tests that do
    /// not exercise `AnyVnodeCount`/`Hash` merging).
    fn test_success(facts: impl IntoIterator<Item = Fact>, expected: HashMap<Id, Result>) {
        test_success_with_mapping_len(facts, expected, |_| 0);
    }

    /// Assert that merging succeeds and every fragment matches its expectation.
    fn test_success_with_mapping_len(
        facts: impl IntoIterator<Item = Fact>,
        expected: HashMap<Id, Result>,
        mapping_len: impl Fn(HashMappingId) -> usize,
    ) {
        let reqs = run_and_merge(facts, mapping_len).unwrap();

        for (id, expected) in expected {
            match (reqs.get(&id), expected) {
                (None, Result::DefaultHash) => {}
                (Some(actual), Result::Required(expected)) if *actual == expected => {}
                (actual, expected) => panic!(
                    "unexpected result for fragment {id:?}\nactual: {actual:?}\nexpected: {expected:?}"
                ),
            }
        }
    }

    /// Assert that merging the requirements fails (incompatible constraints).
    /// Uses a dummy mapping length of 0.
    fn test_failed(facts: impl IntoIterator<Item = Fact>) {
        run_and_merge(facts, |_| 0).unwrap_err();
    }

    // A lone fragment with no facts: default hash distribution.
    #[test]
    fn test_single_fragment_hash() {
        #[rustfmt::skip]
        let facts = [];

        let expected = maplit::hashmap! {
            101.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    // A lone fragment flagged as singleton keeps the singleton requirement.
    #[test]
    fn test_single_fragment_singleton() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 101.into(), req: Req::AnySingleton },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::DefaultSingleton,
        };

        test_success(facts, expected);
    }

    // MV on MV: requirements of existing fragments 1 (hash) and 2 (singleton)
    // propagate over no-shuffle edges to 101/102; 103 (hash inputs) stays
    // default; 104 becomes singleton due to the `Simple` dispatcher.
    #[test]
    fn test_scheduling_mv_on_mv() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Singleton },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 102.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 103.into(), to: 104.into(), dt: Simple },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::Required(Req::Singleton),
            103.into() => Result::DefaultHash,
            104.into() => Result::DefaultSingleton,
        };

        test_success(facts, expected);
    }

    // Delta join: each arrangement (103/104) inherits the mapping of its
    // no-shuffle upstream chain, independent of the hash edges between them.
    #[test]
    fn test_delta_join() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Hash(2) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 102.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: NoShuffle },
            Fact::Edge { from: 102.into(), to: 104.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 104.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 103.into(), to: 105.into(), dt: Hash },
            Fact::Edge { from: 104.into(), to: 105.into(), dt: Hash },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::Required(Req::Hash(2)),
            103.into() => Result::Required(Req::Hash(1)),
            104.into() => Result::Required(Req::Hash(2)),
            105.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    // A singleton leaf (102) broadcasting into a hash fragment does not
    // constrain its downstream (103).
    #[test]
    fn test_singleton_leaf() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Req { id: 102.into(), req: Req::AnySingleton }, Fact::Edge { from: 101.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Broadcast },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::DefaultSingleton,
            103.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    // Two different hash mappings propagated to the same fragment over
    // no-shuffle edges cannot be merged.
    #[test]
    fn test_upstream_hash_shard_failed() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Hash(2) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 101.into(), dt: NoShuffle },
        ];

        test_failed(facts);
    }

    // Hash-dispatched backfill: the `AnyVnodeCount` requirement stands on its
    // own since the upstream's mapping is not propagated over a hash edge.
    #[test]
    fn test_arrangement_backfill_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: Hash },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::AnyVnodeCount(128)),
        };

        test_success(facts, expected);
    }

    // No-shuffle backfill: the upstream's `Hash(1)` merges with
    // `AnyVnodeCount(128)` because mapping 1 has length 128.
    #[test]
    fn test_no_shuffle_backfill_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
        };

        test_success_with_mapping_len(facts, expected, |id| {
            assert_eq!(id, 1);
            128
        });
    }

    // Same topology as above, but `test_failed` uses mapping length 0, which
    // mismatches the required vnode count of 128 — merging must fail.
    #[test]
    fn test_no_shuffle_backfill_mismatched_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        test_failed(facts);
    }

    // A singleton upstream over no-shuffle upgrades the backfill's
    // `AnySingleton` into a concrete `Singleton`.
    #[test]
    fn test_backfill_singleton_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Singleton },
            Fact::Req { id: 101.into(), req: Req::AnySingleton },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle }, ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Singleton),
        };

        test_success(facts, expected);
    }
}