1use std::collections::{BTreeMap, HashMap};
16
17use anyhow::Context;
18use enum_as_inner::EnumAsInner;
19use itertools::Itertools;
20use risingwave_common::bail;
21use risingwave_common::hash::{ActorAlignmentId, VnodeCountCompat};
22use risingwave_common::util::stream_graph_visitor::visit_fragment;
23use risingwave_connector::source::cdc::{CDC_BACKFILL_MAX_PARALLELISM, CdcScanOptions};
24use risingwave_meta_model::WorkerId;
25use risingwave_pb::common::WorkerNode;
26use risingwave_pb::meta::table_fragments::fragment::{
27 FragmentDistributionType, PbFragmentDistributionType,
28};
29use risingwave_pb::stream_plan::DispatcherType::{self, *};
30
31use crate::MetaResult;
32use crate::model::{ActorId, Fragment};
33use crate::stream::stream_graph::fragment::CompleteStreamFragmentGraph;
34use crate::stream::stream_graph::id::GlobalFragmentId as Id;
35
/// Index into the deduplicated list of existing hash mappings (see `Scheduler::schedule`).
type HashMappingId = usize;
37
/// A scheduling requirement imposed on a fragment's distribution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Req {
    /// The fragment must be a singleton.
    Singleton,
    /// The fragment must use the existing hash mapping with the given index.
    Hash(HashMappingId),
    /// The fragment may use any distribution, as long as its vnode count matches.
    /// Note that `AnyVnodeCount(1)` (`AnySingleton`) can also be fulfilled by `Singleton`.
    AnyVnodeCount(usize),
}
49
50impl Req {
51 #[expect(non_upper_case_globals)]
53 const AnySingleton: Self = Self::AnyVnodeCount(1);
54
55 fn merge(a: Self, b: Self, mapping_len: impl Fn(HashMappingId) -> usize) -> MetaResult<Self> {
59 let merge = |a, b| match (a, b) {
61 (Self::AnySingleton, Self::Singleton) => Some(Self::Singleton),
62 (Self::AnyVnodeCount(count), Self::Hash(id)) if mapping_len(id) == count => {
63 Some(Self::Hash(id))
64 }
65 _ => None,
66 };
67
68 match merge(a, b).or_else(|| merge(b, a)) {
69 Some(req) => Ok(req),
70 None => bail!("incompatible requirements `{a:?}` and `{b:?}`"),
71 }
72 }
73}
74
/// An input fact fed into the scheduling rules.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Fact {
    /// An edge in the fragment graph, from `from` to `to`, with the given
    /// dispatcher type.
    Edge {
        from: Id,
        to: Id,
        dt: DispatcherType,
    },
    /// An externally imposed requirement on fragment `id`.
    Req { id: Id, req: Req },
}
87
crepe::crepe! {
    @input
    struct Input(Fact);

    // Relations derived from the input facts.
    // NOTE(review): `ExternalReq` is declared but not used by any visible rule —
    // confirm whether it is a leftover or used elsewhere.
    struct Edge(Id, Id, DispatcherType);
    struct ExternalReq(Id, Req);

    @output
    struct Requirement(Id, Req);

    // Destructure the input facts into the corresponding relations.
    Edge(from, to, dt) <- Input(f), let Fact::Edge { from, to, dt } = f;
    Requirement(id, req) <- Input(f), let Fact::Req { id, req } = f;

    // The downstream side of a `Simple` dispatcher must be a singleton.
    Requirement(y, Req::AnySingleton) <- Edge(_, y, Simple);
    // Requirements propagate in both directions across `NoShuffle` edges,
    // since both endpoints must share the same distribution.
    Requirement(x, d) <- Edge(x, y, NoShuffle), Requirement(y, d);
    Requirement(y, d) <- Edge(x, y, NoShuffle), Requirement(x, d);
}
108
/// The distribution decided for a fragment by the [`Scheduler`].
#[derive(Debug, Clone, EnumAsInner)]
pub(super) enum Distribution {
    /// A singleton distribution.
    Singleton,

    /// A hash distribution with the given vnode count.
    Hash(usize),
}
118
119impl Distribution {
120 pub fn vnode_count(&self) -> usize {
122 match self {
123 Distribution::Singleton => 1, Distribution::Hash(vnode_count) => *vnode_count,
125 }
126 }
127
128 pub fn from_fragment(fragment: &Fragment) -> Self {
130 match fragment.distribution_type {
131 FragmentDistributionType::Single => Distribution::Singleton,
132 FragmentDistributionType::Hash => Distribution::Hash(fragment.vnode_count()),
133 PbFragmentDistributionType::Unspecified => {
134 unreachable!()
135 }
136 }
137 }
138
139 pub fn to_distribution_type(&self) -> PbFragmentDistributionType {
141 match self {
142 Distribution::Singleton => PbFragmentDistributionType::Single,
143 Distribution::Hash(_) => PbFragmentDistributionType::Hash,
144 }
145 }
146}
147
/// Decides the distribution ([`Distribution`]) of each building fragment in a
/// stream graph. See [`Scheduler::schedule`].
pub(super) struct Scheduler {
    /// The vnode count used for fragments that end up with no scheduling
    /// requirement at all.
    default_vnode_count: usize,
}
153
impl Scheduler {
    /// Create a scheduler with the given default (expected) vnode count.
    ///
    /// NOTE(review): this constructor is currently infallible even though it
    /// returns `MetaResult` — presumably kept for interface stability; confirm.
    pub fn new(expected_vnode_count: usize) -> MetaResult<Self> {
        Ok(Self {
            default_vnode_count: expected_vnode_count,
        })
    }

    /// Schedule a distribution for every building fragment in `graph`.
    ///
    /// Collects facts (external requirements and edges), runs the Datalog
    /// rules to propagate requirements, then merges the requirements of each
    /// fragment into a concrete [`Distribution`]. Fails if some fragment has
    /// incompatible requirements.
    pub fn schedule(
        &self,
        graph: &CompleteStreamFragmentGraph,
    ) -> MetaResult<HashMap<Id, Distribution>> {
        let existing_distribution = graph.existing_distribution();

        // Deduplicate the hash mappings of existing fragments and assign each
        // a dense `HashMappingId`, so requirements can refer to them by index.
        let all_hash_mappings = existing_distribution
            .values()
            .flat_map(|dist| dist.as_hash())
            .cloned()
            .unique()
            .collect_vec();
        let hash_mapping_id: HashMap<_, _> = all_hash_mappings
            .iter()
            .enumerate()
            .map(|(i, m)| (*m, i))
            .collect();

        let mut facts = Vec::new();

        // Fragments explicitly marked as singleton.
        for (&id, fragment) in graph.building_fragments() {
            if fragment.requires_singleton {
                facts.push(Fact::Req {
                    id,
                    req: Req::AnySingleton,
                });
            }
        }
        // NOTE(review): this map is filled below but never read within this
        // function — confirm whether it should influence the result (e.g. cap
        // the parallelism of parallelized CDC backfill) or be removed.
        let mut force_parallelism_fragment_ids: HashMap<_, _> = HashMap::default();
        // Fragments whose vnode count is pinned by a table/state they scan
        // (the distribution must align with the upstream table's vnode count).
        for (&id, fragment) in graph.building_fragments() {
            visit_fragment(fragment, |node| {
                use risingwave_pb::stream_plan::stream_node::NodeBody;
                let vnode_count = match node {
                    NodeBody::StreamScan(node) => {
                        if let Some(table) = &node.arrangement_table {
                            table.vnode_count()
                        } else if let Some(table) = &node.table_desc {
                            table.vnode_count()
                        } else {
                            return;
                        }
                    }
                    NodeBody::TemporalJoin(node) => node.get_table_desc().unwrap().vnode_count(),
                    NodeBody::BatchPlan(node) => node.get_table_desc().unwrap().vnode_count(),
                    NodeBody::Lookup(node) => node
                        .get_arrangement_table_info()
                        .unwrap()
                        .get_table_desc()
                        .unwrap()
                        .vnode_count(),
                    NodeBody::StreamCdcScan(node) => {
                        let Some(ref options) = node.options else {
                            return;
                        };
                        let options = CdcScanOptions::from_proto(options);
                        if options.is_parallelized_backfill() {
                            // Record the requested backfill parallelism and
                            // require the maximum CDC backfill vnode count.
                            force_parallelism_fragment_ids
                                .insert(id, options.backfill_parallelism as usize);
                            CDC_BACKFILL_MAX_PARALLELISM as usize
                        } else {
                            return;
                        }
                    }
                    NodeBody::GapFill(node) => {
                        let buffer_table = node.get_state_table().unwrap();
                        // Only constrain if the state table's vnode count is set.
                        if let Some(vnode_count) = buffer_table.vnode_count_inner().value_opt() {
                            vnode_count
                        } else {
                            return;
                        }
                    }
                    _ => return,
                };
                facts.push(Fact::Req {
                    id,
                    req: Req::AnyVnodeCount(vnode_count),
                });
            });
        }
        // Existing fragments impose their concrete distribution as a requirement.
        for (id, dist) in existing_distribution {
            let req = match dist {
                Distribution::Singleton => Req::Singleton,
                Distribution::Hash(mapping) => Req::Hash(hash_mapping_id[&mapping]),
            };
            facts.push(Fact::Req { id, req });
        }
        // All edges of the graph, carrying their dispatcher type.
        for (from, to, edge) in graph.all_edges() {
            facts.push(Fact::Edge {
                from,
                to,
                dt: edge.dispatch_strategy.r#type(),
            });
        }

        // Run the rules to propagate requirements, then group them by fragment.
        let mut crepe = Crepe::new();
        crepe.extend(facts.into_iter().map(Input));
        let (reqs,) = crepe.run();
        let reqs = reqs
            .into_iter()
            .map(|Requirement(id, req)| (id, req))
            .into_group_map();

        let mut distributions = HashMap::new();
        for &id in graph.building_fragments().keys() {
            let dist = match reqs.get(&id) {
                Some(reqs) => {
                    // Merge all requirements of this fragment pairwise; this
                    // fails if any pair is incompatible.
                    let req = (reqs.iter().copied())
                        .try_reduce(|a, b| Req::merge(a, b, |id| all_hash_mappings[id]))
                        .with_context(|| {
                            format!("cannot fulfill scheduling requirements for fragment {id:?}")
                        })?
                        .unwrap();

                    // Turn the merged requirement into a concrete distribution.
                    // Note: `AnySingleton` must be matched before the general
                    // `AnyVnodeCount`, since it is `AnyVnodeCount(1)`.
                    match req {
                        Req::Singleton => Distribution::Singleton,
                        Req::Hash(mapping) => Distribution::Hash(all_hash_mappings[mapping]),
                        Req::AnySingleton => Distribution::Singleton,
                        Req::AnyVnodeCount(vnode_count) => Distribution::Hash(vnode_count),
                    }
                }
                // No requirement at all: fall back to the default vnode count.
                None => Distribution::Hash(self.default_vnode_count),
            };

            distributions.insert(id, dist);
        }

        tracing::debug!(?distributions, "schedule fragments");

        Ok(distributions)
    }
}
309
/// Location information for actors and the workers they may run on.
#[cfg_attr(test, derive(Default))]
pub struct Locations {
    // NOTE(review): presumably maps each actor to the worker/alignment slot it
    // is placed on — confirm against the producers of this struct.
    pub actor_locations: BTreeMap<ActorId, ActorAlignmentId>,
    // Worker nodes referenced by `actor_locations`, keyed by worker id.
    pub worker_locations: HashMap<WorkerId, WorkerNode>,
}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected scheduling result of a fragment in the tests below.
    #[derive(Debug)]
    enum Result {
        /// No requirement was derived; the scheduler would use the default hash
        /// distribution.
        DefaultHash,
        /// The merged requirement must equal the given one.
        Required(Req),
    }

    impl Result {
        /// Shorthand for a fragment that must be scheduled as a singleton by
        /// default (i.e. merged requirement is `AnySingleton`).
        #[expect(non_upper_case_globals)]
        const DefaultSingleton: Self = Self::Required(Req::AnySingleton);
    }

    /// Run the Datalog rules on `facts` and merge the requirements of each
    /// fragment, mirroring what `Scheduler::schedule` does.
    fn run_and_merge(
        facts: impl IntoIterator<Item = Fact>,
        mapping_len: impl Fn(HashMappingId) -> usize,
    ) -> MetaResult<HashMap<Id, Req>> {
        let mut crepe = Crepe::new();
        crepe.extend(facts.into_iter().map(Input));
        let (reqs,) = crepe.run();

        let reqs = reqs
            .into_iter()
            .map(|Requirement(id, req)| (id, req))
            .into_group_map();

        let mut merged = HashMap::new();
        for (id, reqs) in reqs {
            let req = (reqs.iter().copied())
                .try_reduce(|a, b| Req::merge(a, b, &mapping_len))
                .with_context(|| {
                    format!("cannot fulfill scheduling requirements for fragment {id:?}")
                })?
                .unwrap();
            merged.insert(id, req);
        }

        Ok(merged)
    }

    /// Assert that merging succeeds and each fragment gets the expected result.
    fn test_success(facts: impl IntoIterator<Item = Fact>, expected: HashMap<Id, Result>) {
        test_success_with_mapping_len(facts, expected, |_| 0);
    }

    /// Like `test_success`, but with a custom mapping-length resolver.
    fn test_success_with_mapping_len(
        facts: impl IntoIterator<Item = Fact>,
        expected: HashMap<Id, Result>,
        mapping_len: impl Fn(HashMappingId) -> usize,
    ) {
        let reqs = run_and_merge(facts, mapping_len).unwrap();

        for (id, expected) in expected {
            match (reqs.get(&id), expected) {
                (None, Result::DefaultHash) => {}
                (Some(actual), Result::Required(expected)) if *actual == expected => {}
                (actual, expected) => panic!(
                    "unexpected result for fragment {id:?}\nactual: {actual:?}\nexpected: {expected:?}"
                ),
            }
        }
    }

    /// Assert that merging the requirements fails.
    fn test_failed(facts: impl IntoIterator<Item = Fact>) {
        run_and_merge(facts, |_| 0).unwrap_err();
    }

    // A lone fragment with no requirements falls back to the default hash.
    #[test]
    fn test_single_fragment_hash() {
        #[rustfmt::skip]
        let facts = [];

        let expected = maplit::hashmap! {
            101.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    // An `AnySingleton` requirement alone resolves to a default singleton.
    #[test]
    fn test_single_fragment_singleton() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 101.into(), req: Req::AnySingleton },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::DefaultSingleton,
        };

        test_success(facts, expected);
    }

    // MV-on-MV: existing distributions propagate over `NoShuffle`, while
    // `Hash`/`Simple` edges only constrain the downstream side.
    #[test]
    fn test_scheduling_mv_on_mv() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Singleton },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 102.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 103.into(), to: 104.into(), dt: Simple },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::Required(Req::Singleton),
            103.into() => Result::DefaultHash,
            104.into() => Result::DefaultSingleton,
        };

        test_success(facts, expected);
    }

    // Delta join: each arrangement follows its upstream over `NoShuffle`.
    #[test]
    fn test_delta_join() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Hash(2) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 102.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: NoShuffle },
            Fact::Edge { from: 102.into(), to: 104.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 104.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 103.into(), to: 105.into(), dt: Hash },
            Fact::Edge { from: 104.into(), to: 105.into(), dt: Hash },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::Required(Req::Hash(2)),
            103.into() => Result::Required(Req::Hash(1)),
            104.into() => Result::Required(Req::Hash(2)),
            105.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    // A singleton leaf connected via `Broadcast` does not constrain its
    // downstream fragment.
    #[test]
    fn test_singleton_leaf() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Req { id: 102.into(), req: Req::AnySingleton }, Fact::Edge { from: 101.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Broadcast },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::DefaultSingleton,
            103.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    // Two different hash mappings required on the same fragment via
    // `NoShuffle` edges cannot be merged.
    #[test]
    fn test_upstream_hash_shard_failed() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Hash(2) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 101.into(), dt: NoShuffle },
        ];

        test_failed(facts);
    }

    // A vnode-count requirement without a `NoShuffle` edge is kept as-is.
    #[test]
    fn test_arrangement_backfill_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: Hash },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::AnyVnodeCount(128)),
        };

        test_success(facts, expected);
    }

    // Over `NoShuffle`, a vnode-count requirement merges with the upstream
    // hash mapping when the lengths match.
    #[test]
    fn test_no_shuffle_backfill_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
        };

        test_success_with_mapping_len(facts, expected, |id| {
            assert_eq!(id, 1);
            128
        });
    }

    // ... and fails when the lengths do not match (mapping length is 0 here).
    #[test]
    fn test_no_shuffle_backfill_mismatched_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        test_failed(facts);
    }

    // A singleton upstream fulfills an `AnySingleton` requirement downstream.
    #[test]
    fn test_backfill_singleton_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Singleton },
            Fact::Req { id: 101.into(), req: Req::AnySingleton },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle }, ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Singleton),
        };

        test_success(facts, expected);
    }
}