risingwave_meta/stream/stream_graph/
schedule.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16
17use anyhow::Context;
18use enum_as_inner::EnumAsInner;
19use itertools::Itertools;
20use risingwave_common::bail;
21use risingwave_common::hash::{ActorAlignmentId, VnodeCountCompat};
22use risingwave_common::util::stream_graph_visitor::visit_fragment;
23use risingwave_connector::source::cdc::{CDC_BACKFILL_MAX_PARALLELISM, CdcScanOptions};
24use risingwave_meta_model::WorkerId;
25use risingwave_pb::common::WorkerNode;
26use risingwave_pb::meta::table_fragments::fragment::{
27    FragmentDistributionType, PbFragmentDistributionType,
28};
29use risingwave_pb::stream_plan::DispatcherType::{self, *};
30
31use crate::MetaResult;
32use crate::model::{ActorId, Fragment};
33use crate::stream::stream_graph::fragment::CompleteStreamFragmentGraph;
34use crate::stream::stream_graph::id::GlobalFragmentId as Id;
35
/// Index into the deduplicated list of hash mappings (`all_hash_mappings`)
/// collected at the beginning of [`Scheduler::schedule`].
type HashMappingId = usize;
37
/// The internal structure for processing scheduling requirements in the scheduler.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Req {
    /// The fragment must be singleton, derived from the distribution of an
    /// existing singleton fragment it's connected to.
    Singleton,
    /// The fragment must be hash-distributed according to the hash mapping
    /// identified by the given [`HashMappingId`].
    Hash(HashMappingId),
    /// The fragment must have the given vnode count, but can be scheduled anywhere.
    /// When the vnode count is 1, it means the fragment must be singleton.
    AnyVnodeCount(usize),
}
49
50impl Req {
51    /// Equivalent to `Req::AnyVnodeCount(1)`.
52    #[expect(non_upper_case_globals)]
53    const AnySingleton: Self = Self::AnyVnodeCount(1);
54
55    /// Merge two requirements. Returns an error if the requirements are incompatible.
56    ///
57    /// The `mapping_len` function is used to get the vnode count of a hash mapping by its id.
58    fn merge(a: Self, b: Self, mapping_len: impl Fn(HashMappingId) -> usize) -> MetaResult<Self> {
59        // Note that a and b are always different, as they come from a set.
60        let merge = |a, b| match (a, b) {
61            (Self::AnySingleton, Self::Singleton) => Some(Self::Singleton),
62            (Self::AnyVnodeCount(count), Self::Hash(id)) if mapping_len(id) == count => {
63                Some(Self::Hash(id))
64            }
65            _ => None,
66        };
67
68        match merge(a, b).or_else(|| merge(b, a)) {
69            Some(req) => Ok(req),
70            None => bail!("incompatible requirements `{a:?}` and `{b:?}`"),
71        }
72    }
73}
74
/// Facts as the input of the scheduler.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Fact {
    /// An edge in the fragment graph.
    Edge {
        /// Upstream fragment id.
        from: Id,
        /// Downstream fragment id.
        to: Id,
        /// How data is dispatched along this edge.
        dt: DispatcherType,
    },
    /// A scheduling requirement for a fragment.
    Req { id: Id, req: Req },
}
87
// Datalog program that propagates scheduling requirements through the graph.
// Facts go in as `Input`; the fixed point of `Requirement` comes out.
crepe::crepe! {
    @input
    struct Input(Fact);

    // Intermediate relations extracted from the input facts.
    struct Edge(Id, Id, DispatcherType);
    // NOTE(review): `ExternalReq` is declared but no rule produces or consumes
    // it — possibly leftover from a refactor; confirm before removing.
    struct ExternalReq(Id, Req);

    @output
    struct Requirement(Id, Req);

    // Extract facts.
    Edge(from, to, dt) <- Input(f), let Fact::Edge { from, to, dt } = f;
    Requirement(id, req) <- Input(f), let Fact::Req { id, req } = f;

    // The downstream fragment of a `Simple` edge must be singleton.
    Requirement(y, Req::AnySingleton) <- Edge(_, y, Simple);
    // Requirements propagate through `NoShuffle` edges, in both directions.
    Requirement(x, d) <- Edge(x, y, NoShuffle), Requirement(y, d);
    Requirement(y, d) <- Edge(x, y, NoShuffle), Requirement(x, d);
}
108
/// The distribution (scheduling result) of a fragment.
#[derive(Debug, Clone, EnumAsInner)]
pub(super) enum Distribution {
    /// The fragment is singleton (occupies only the single vnode).
    Singleton,

    /// The fragment is hash-distributed over the given number of vnodes.
    Hash(usize),
}
118
119impl Distribution {
120    /// Get the vnode count of the distribution.
121    pub fn vnode_count(&self) -> usize {
122        match self {
123            Distribution::Singleton => 1, // only `SINGLETON_VNODE`
124            Distribution::Hash(vnode_count) => *vnode_count,
125        }
126    }
127
128    /// Create a distribution from a persisted protobuf `Fragment`.
129    pub fn from_fragment(fragment: &Fragment) -> Self {
130        match fragment.distribution_type {
131            FragmentDistributionType::Single => Distribution::Singleton,
132            FragmentDistributionType::Hash => Distribution::Hash(fragment.vnode_count()),
133            PbFragmentDistributionType::Unspecified => {
134                unreachable!()
135            }
136        }
137    }
138
139    /// Convert the distribution to [`PbFragmentDistributionType`].
140    pub fn to_distribution_type(&self) -> PbFragmentDistributionType {
141        match self {
142            Distribution::Singleton => PbFragmentDistributionType::Single,
143            Distribution::Hash(_) => PbFragmentDistributionType::Hash,
144        }
145    }
146}
147
/// [`Scheduler`] schedules the distribution of fragments in a stream graph.
pub(super) struct Scheduler {
    /// The default vnode count for hash-distributed fragments, used when no
    /// requirement is derived for a fragment.
    default_vnode_count: usize,
}
153
impl Scheduler {
    /// Create a new [`Scheduler`] with the expected vnode count of the streaming job.
    ///
    /// Currently infallible; the `MetaResult` return type is kept for interface
    /// stability.
    pub fn new(expected_vnode_count: usize) -> MetaResult<Self> {
        Ok(Self {
            default_vnode_count: expected_vnode_count,
        })
    }

    /// Schedule the given complete graph and returns the distribution of each **building
    /// fragment**.
    ///
    /// Works in three phases:
    /// 1. Collect scheduling [`Fact`]s: singleton requirements, vnode-count
    ///    requirements from tables being looked up, distributions of existing
    ///    fragments, and graph edges.
    /// 2. Run the datalog program (the `crepe!` block above) to propagate
    ///    requirements through the graph.
    /// 3. Merge the requirements per fragment and derive a concrete
    ///    [`Distribution`].
    ///
    /// Returns an error if the requirements for some fragment are incompatible.
    pub fn schedule(
        &self,
        graph: &CompleteStreamFragmentGraph,
    ) -> MetaResult<HashMap<Id, Distribution>> {
        let existing_distribution = graph.existing_distribution();

        // Build an index map for all hash mappings. `Req::Hash` refers to
        // mappings by their index in `all_hash_mappings`.
        let all_hash_mappings = existing_distribution
            .values()
            .flat_map(|dist| dist.as_hash())
            .cloned()
            .unique()
            .collect_vec();
        let hash_mapping_id: HashMap<_, _> = all_hash_mappings
            .iter()
            .enumerate()
            .map(|(i, m)| (*m, i))
            .collect();

        let mut facts = Vec::new();

        // Singletons.
        for (&id, fragment) in graph.building_fragments() {
            if fragment.requires_singleton {
                facts.push(Fact::Req {
                    id,
                    req: Req::AnySingleton,
                });
            }
        }
        // NOTE(review): this map is populated below for parallelized CDC
        // backfill fragments but is never read within this function — looks
        // like an incomplete refactor; confirm whether it should influence the
        // result or be removed.
        let mut force_parallelism_fragment_ids: HashMap<_, _> = HashMap::default();
        // Vnode count requirements: if a fragment is going to look up an existing table,
        // it must have the same vnode count as that table.
        for (&id, fragment) in graph.building_fragments() {
            visit_fragment(fragment, |node| {
                use risingwave_pb::stream_plan::stream_node::NodeBody;
                // Determine the vnode count this node constrains the fragment
                // to, or `return` early if the node imposes no constraint.
                let vnode_count = match node {
                    NodeBody::StreamScan(node) => {
                        if let Some(table) = &node.arrangement_table {
                            table.vnode_count()
                        } else if let Some(table) = &node.table_desc {
                            table.vnode_count()
                        } else {
                            return;
                        }
                    }
                    NodeBody::TemporalJoin(node) => node.get_table_desc().unwrap().vnode_count(),
                    NodeBody::BatchPlan(node) => node.get_table_desc().unwrap().vnode_count(),
                    NodeBody::Lookup(node) => node
                        .get_arrangement_table_info()
                        .unwrap()
                        .get_table_desc()
                        .unwrap()
                        .vnode_count(),
                    NodeBody::StreamCdcScan(node) => {
                        let Some(ref options) = node.options else {
                            return;
                        };
                        let options = CdcScanOptions::from_proto(options);
                        if options.is_parallelized_backfill() {
                            // Parallelized CDC backfill always uses the maximum
                            // parallelism as its vnode count.
                            force_parallelism_fragment_ids
                                .insert(id, options.backfill_parallelism as usize);
                            CDC_BACKFILL_MAX_PARALLELISM as usize
                        } else {
                            return;
                        }
                    }
                    NodeBody::GapFill(node) => {
                        // GapFill node uses buffer_table for vnode count requirement
                        let buffer_table = node.get_state_table().unwrap();
                        // Check if vnode_count is a placeholder, skip if so as it will be filled later
                        if let Some(vnode_count) = buffer_table.vnode_count_inner().value_opt() {
                            vnode_count
                        } else {
                            // Skip this node as vnode_count is still a placeholder
                            return;
                        }
                    }
                    _ => return,
                };
                facts.push(Fact::Req {
                    id,
                    req: Req::AnyVnodeCount(vnode_count),
                });
            });
        }
        // Distributions of existing fragments. These are already fixed and are
        // recorded as hard requirements.
        for (id, dist) in existing_distribution {
            let req = match dist {
                Distribution::Singleton => Req::Singleton,
                Distribution::Hash(mapping) => Req::Hash(hash_mapping_id[&mapping]),
            };
            facts.push(Fact::Req { id, req });
        }
        // Edges.
        for (from, to, edge) in graph.all_edges() {
            facts.push(Fact::Edge {
                from,
                to,
                dt: edge.dispatch_strategy.r#type(),
            });
        }

        // Run the algorithm to propagate requirements.
        let mut crepe = Crepe::new();
        crepe.extend(facts.into_iter().map(Input));
        let (reqs,) = crepe.run();
        // Group the derived requirements by fragment id.
        let reqs = reqs
            .into_iter()
            .map(|Requirement(id, req)| (id, req))
            .into_group_map();

        // Derive scheduling result from requirements.
        let mut distributions = HashMap::new();
        for &id in graph.building_fragments().keys() {
            let dist = match reqs.get(&id) {
                // Merge all requirements.
                Some(reqs) => {
                    // `unwrap` is safe: `reqs` from `into_group_map` is never empty.
                    let req = (reqs.iter().copied())
                        .try_reduce(|a, b| Req::merge(a, b, |id| all_hash_mappings[id]))
                        .with_context(|| {
                            format!("cannot fulfill scheduling requirements for fragment {id:?}")
                        })?
                        .unwrap();

                    // Derive distribution from the merged requirement.
                    match req {
                        Req::Singleton => Distribution::Singleton,
                        Req::Hash(mapping) => Distribution::Hash(all_hash_mappings[mapping]),
                        Req::AnySingleton => Distribution::Singleton,
                        Req::AnyVnodeCount(vnode_count) => Distribution::Hash(vnode_count),
                    }
                }
                // No requirement, use the default.
                None => Distribution::Hash(self.default_vnode_count),
            };

            distributions.insert(id, dist);
        }

        tracing::debug!(?distributions, "schedule fragments");

        Ok(distributions)
    }
}
309
/// [`Locations`] represents the locations of the actors.
#[cfg_attr(test, derive(Default))]
pub struct Locations {
    /// Actor location map: each actor to the alignment id it's placed at.
    pub actor_locations: BTreeMap<ActorId, ActorAlignmentId>,
    /// Worker location map: each worker id to its node info.
    pub worker_locations: HashMap<WorkerId, WorkerNode>,
}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected scheduling outcome of a fragment in the tests below.
    ///
    /// Note: intentionally shadows `std::result::Result` within this module.
    #[derive(Debug)]
    enum Result {
        /// No requirement derived; the fragment falls back to the default hash
        /// distribution.
        DefaultHash,
        /// A (merged) requirement was derived for the fragment.
        Required(Req),
    }

    impl Result {
        /// Shorthand for `Result::Required(Req::AnySingleton)`.
        #[expect(non_upper_case_globals)]
        const DefaultSingleton: Self = Self::Required(Req::AnySingleton);
    }

    /// Run the datalog rules on `facts` and merge the derived requirements per
    /// fragment, mirroring what `Scheduler::schedule` does.
    ///
    /// `mapping_len` resolves a hash mapping id to its vnode count.
    fn run_and_merge(
        facts: impl IntoIterator<Item = Fact>,
        mapping_len: impl Fn(HashMappingId) -> usize,
    ) -> MetaResult<HashMap<Id, Req>> {
        let mut crepe = Crepe::new();
        crepe.extend(facts.into_iter().map(Input));
        let (reqs,) = crepe.run();

        let reqs = reqs
            .into_iter()
            .map(|Requirement(id, req)| (id, req))
            .into_group_map();

        let mut merged = HashMap::new();
        for (id, reqs) in reqs {
            // `unwrap` is safe: groups from `into_group_map` are never empty.
            let req = (reqs.iter().copied())
                .try_reduce(|a, b| Req::merge(a, b, &mapping_len))
                .with_context(|| {
                    format!("cannot fulfill scheduling requirements for fragment {id:?}")
                })?
                .unwrap();
            merged.insert(id, req);
        }

        Ok(merged)
    }

    /// Assert that merging succeeds and each fragment gets the expected result.
    /// Hash mapping lengths are irrelevant here, so they are stubbed to 0.
    fn test_success(facts: impl IntoIterator<Item = Fact>, expected: HashMap<Id, Result>) {
        test_success_with_mapping_len(facts, expected, |_| 0);
    }

    /// Like [`test_success`], but with a caller-provided `mapping_len` to
    /// resolve hash mapping ids to vnode counts.
    fn test_success_with_mapping_len(
        facts: impl IntoIterator<Item = Fact>,
        expected: HashMap<Id, Result>,
        mapping_len: impl Fn(HashMappingId) -> usize,
    ) {
        let reqs = run_and_merge(facts, mapping_len).unwrap();

        for (id, expected) in expected {
            match (reqs.get(&id), expected) {
                // No derived requirement matches `DefaultHash`.
                (None, Result::DefaultHash) => {}
                (Some(actual), Result::Required(expected)) if *actual == expected => {}
                (actual, expected) => panic!(
                    "unexpected result for fragment {id:?}\nactual: {actual:?}\nexpected: {expected:?}"
                ),
            }
        }
    }

    /// Assert that merging the requirements derived from `facts` fails.
    fn test_failed(facts: impl IntoIterator<Item = Fact>) {
        run_and_merge(facts, |_| 0).unwrap_err();
    }

    // 101
    #[test]
    fn test_single_fragment_hash() {
        #[rustfmt::skip]
        let facts = [];

        let expected = maplit::hashmap! {
            101.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    // 101
    #[test]
    fn test_single_fragment_singleton() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 101.into(), req: Req::AnySingleton },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::DefaultSingleton,
        };

        test_success(facts, expected);
    }

    // 1 -|-> 101 -->
    //                103 --> 104
    // 2 -|-> 102 -->
    #[test]
    fn test_scheduling_mv_on_mv() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Singleton },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 102.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 103.into(), to: 104.into(), dt: Simple },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::Required(Req::Singleton),
            103.into() => Result::DefaultHash,
            104.into() => Result::DefaultSingleton,
        };

        test_success(facts, expected);
    }

    // 1 -|-> 101 --> 103 -->
    //             X          105
    // 2 -|-> 102 --> 104 -->
    #[test]
    fn test_delta_join() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Hash(2) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 102.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 103.into(), dt: NoShuffle },
            Fact::Edge { from: 102.into(), to: 104.into(), dt: NoShuffle },
            Fact::Edge { from: 101.into(), to: 104.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 103.into(), to: 105.into(), dt: Hash },
            Fact::Edge { from: 104.into(), to: 105.into(), dt: Hash },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::Required(Req::Hash(2)),
            103.into() => Result::Required(Req::Hash(1)),
            104.into() => Result::Required(Req::Hash(2)),
            105.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    // 1 -|-> 101 -->
    //                103
    //        102 -->
    #[test]
    fn test_singleton_leaf() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Req { id: 102.into(), req: Req::AnySingleton }, // like `Now`
            Fact::Edge { from: 101.into(), to: 103.into(), dt: Hash },
            Fact::Edge { from: 102.into(), to: 103.into(), dt: Broadcast },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
            102.into() => Result::DefaultSingleton,
            103.into() => Result::DefaultHash,
        };

        test_success(facts, expected);
    }

    // 1 -|->
    //        101
    // 2 -|->
    #[test]
    fn test_upstream_hash_shard_failed() {
        // Two different hash mappings propagate to the same fragment through
        // `NoShuffle` edges, which cannot be reconciled.
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 2.into(), req: Req::Hash(2) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
            Fact::Edge { from: 2.into(), to: 101.into(), dt: NoShuffle },
        ];

        test_failed(facts);
    }

    // 1 -|~> 101
    #[test]
    fn test_arrangement_backfill_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: Hash },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::AnyVnodeCount(128)),
        };

        test_success(facts, expected);
    }

    // 1 -|~> 101
    #[test]
    fn test_no_shuffle_backfill_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Hash(1)),
        };

        // Mapping 1 has the matching vnode count of 128, so merging succeeds.
        test_success_with_mapping_len(facts, expected, |id| {
            assert_eq!(id, 1);
            128
        });
    }

    // 1 -|~> 101
    #[test]
    fn test_no_shuffle_backfill_mismatched_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Hash(1) },
            Fact::Req { id: 101.into(), req: Req::AnyVnodeCount(128) },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle },
        ];

        // Not specifying `mapping_len` should fail.
        test_failed(facts);
    }

    // 1 -|~> 101
    #[test]
    fn test_backfill_singleton_vnode_count() {
        #[rustfmt::skip]
        let facts = [
            Fact::Req { id: 1.into(), req: Req::Singleton },
            Fact::Req { id: 101.into(), req: Req::AnySingleton },
            Fact::Edge { from: 1.into(), to: 101.into(), dt: NoShuffle }, // or `Simple`
        ];

        let expected = maplit::hashmap! {
            101.into() => Result::Required(Req::Singleton),
        };

        test_success(facts, expected);
    }
}