risingwave_meta/model/
stream.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
16use std::ops::{AddAssign, Deref};
17use std::sync::Arc;
18
19use itertools::Itertools;
20use risingwave_common::bitmap::Bitmap;
21use risingwave_common::catalog::{FragmentTypeFlag, FragmentTypeMask, TableId};
22use risingwave_common::hash::{IsSingleton, VirtualNode, VnodeCount, VnodeCountCompat};
23use risingwave_common::id::JobId;
24use risingwave_common::system_param::AdaptiveParallelismStrategy;
25use risingwave_common::system_param::adaptive_parallelism_strategy::parse_strategy;
26use risingwave_common::util::stream_graph_visitor::{self, visit_stream_node_body};
27use risingwave_meta_model::{DispatcherType, SourceId, StreamingParallelism, WorkerId, fragment};
28use risingwave_pb::catalog::Table;
29use risingwave_pb::common::ActorInfo;
30use risingwave_pb::id::SubscriberId;
31use risingwave_pb::meta::table_fragments::fragment::{
32    FragmentDistributionType, PbFragmentDistributionType,
33};
34use risingwave_pb::meta::table_fragments::{PbActorStatus, PbFragment, State};
35use risingwave_pb::meta::table_parallelism::{
36    FixedParallelism, Parallelism, PbAdaptiveParallelism, PbCustomParallelism, PbFixedParallelism,
37    PbParallelism,
38};
39use risingwave_pb::meta::{PbTableFragments, PbTableParallelism};
40use risingwave_pb::plan_common::PbExprContext;
41use risingwave_pb::stream_plan::stream_node::NodeBody;
42use risingwave_pb::stream_plan::{
43    DispatchStrategy, Dispatcher, PbDispatchOutputMapping, PbDispatcher, PbStreamActor,
44    PbStreamContext, StreamNode,
45};
46use strum::Display;
47
48use super::{ActorId, FragmentId};
49
/// The parallelism for a `TableFragments`.
///
/// Converted to/from [`PbTableParallelism`] (protobuf) and
/// [`StreamingParallelism`] (meta model) via the `From` impls below.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum TableParallelism {
    /// This is when the system decides the parallelism, based on the available worker parallelisms.
    Adaptive,
    /// We set this when the `TableFragments` parallelism is changed.
    /// All fragments which are part of the `TableFragment` will have the same parallelism as this.
    Fixed(usize),
    /// We set this when the individual parallelisms of the `Fragments`
    /// can differ within a `TableFragments`.
    /// This is set for `risectl`, since it has a low-level interface,
    /// which can scale individual `Fragments` within a `TableFragments`.
    /// When that happens, the `TableFragments` no longer has a consistent
    /// parallelism, so we set this to indicate that.
    Custom,
}
66
67impl From<PbTableParallelism> for TableParallelism {
68    fn from(value: PbTableParallelism) -> Self {
69        use Parallelism::*;
70        match &value.parallelism {
71            Some(Fixed(FixedParallelism { parallelism: n })) => Self::Fixed(*n as usize),
72            Some(Adaptive(_)) | Some(Auto(_)) => Self::Adaptive,
73            Some(Custom(_)) => Self::Custom,
74            _ => unreachable!(),
75        }
76    }
77}
78
79impl From<TableParallelism> for PbTableParallelism {
80    fn from(value: TableParallelism) -> Self {
81        use TableParallelism::*;
82
83        let parallelism = match value {
84            Adaptive => PbParallelism::Adaptive(PbAdaptiveParallelism {}),
85            Fixed(n) => PbParallelism::Fixed(PbFixedParallelism {
86                parallelism: n as u32,
87            }),
88            Custom => PbParallelism::Custom(PbCustomParallelism {}),
89        };
90
91        Self {
92            parallelism: Some(parallelism),
93        }
94    }
95}
96
97impl From<StreamingParallelism> for TableParallelism {
98    fn from(value: StreamingParallelism) -> Self {
99        match value {
100            StreamingParallelism::Adaptive => TableParallelism::Adaptive,
101            StreamingParallelism::Fixed(n) => TableParallelism::Fixed(n),
102            StreamingParallelism::Custom => TableParallelism::Custom,
103        }
104    }
105}
106
107impl From<TableParallelism> for StreamingParallelism {
108    fn from(value: TableParallelism) -> Self {
109        match value {
110            TableParallelism::Adaptive => StreamingParallelism::Adaptive,
111            TableParallelism::Fixed(n) => StreamingParallelism::Fixed(n),
112            TableParallelism::Custom => StreamingParallelism::Custom,
113        }
114    }
115}
116
/// Per upstream fragment: the upstream actors (`actor_id` -> [`ActorInfo`]) an actor receives from.
pub type ActorUpstreams = BTreeMap<FragmentId, HashMap<ActorId, ActorInfo>>;
/// An actor bundled with the dispatchers it sends output through.
pub type StreamActorWithDispatchers = (StreamActor, Vec<PbDispatcher>);
/// An actor bundled with both its upstream actor infos and its dispatchers.
pub type StreamActorWithUpDownstreams = (StreamActor, ActorUpstreams, Vec<PbDispatcher>);
/// Fragment `fragment_id` -> (`actor_id` -> dispatchers of that actor).
pub type FragmentActorDispatchers = HashMap<FragmentId, HashMap<ActorId, Vec<PbDispatcher>>>;

/// Upstream `fragment_id` -> relations to each of its downstream fragments.
pub type FragmentDownstreamRelation = HashMap<FragmentId, Vec<DownstreamFragmentRelation>>;
/// downstream `fragment_id` -> original upstream `fragment_id` -> new upstream `fragment_id`
pub type FragmentReplaceUpstream = HashMap<FragmentId, HashMap<FragmentId, FragmentId>>;
/// The newly added no-shuffle actor dispatcher from upstream fragment to downstream fragment
/// upstream `fragment_id` -> downstream `fragment_id` -> upstream `actor_id` -> downstream `actor_id`
pub type ActorNewNoShuffle = HashMap<FragmentId, HashMap<FragmentId, HashMap<ActorId, ActorId>>>;
128
/// Describes how one fragment connects to a single downstream fragment.
#[derive(Debug, Clone)]
pub struct DownstreamFragmentRelation {
    /// The id of the downstream fragment.
    pub downstream_fragment_id: FragmentId,
    /// How data is dispatched to the downstream fragment.
    pub dispatcher_type: DispatcherType,
    /// Indices of the distribution key columns.
    pub dist_key_indices: Vec<u32>,
    /// Mapping of the output columns delivered downstream.
    pub output_mapping: PbDispatchOutputMapping,
}
136
137impl From<(FragmentId, DispatchStrategy)> for DownstreamFragmentRelation {
138    fn from((fragment_id, dispatch): (FragmentId, DispatchStrategy)) -> Self {
139        Self {
140            downstream_fragment_id: fragment_id,
141            dispatcher_type: dispatch.get_type().unwrap().into(),
142            dist_key_indices: dispatch.dist_key_indices,
143            output_mapping: dispatch.output_mapping.unwrap(),
144        }
145    }
146}
147
/// A [`StreamJobFragments`] bundled with the downstream relations of its
/// fragments, used when creating a streaming job.
#[derive(Debug, Clone)]
pub struct StreamJobFragmentsToCreate {
    /// The fragments of the job to create.
    pub inner: StreamJobFragments,
    /// Downstream relations of the fragments in `inner`.
    pub downstreams: FragmentDownstreamRelation,
}
153
impl Deref for StreamJobFragmentsToCreate {
    type Target = StreamJobFragments;

    /// Transparent read access to the wrapped [`StreamJobFragments`].
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
161
/// In-memory representation of an actor. Corresponds to [`PbStreamActor`],
/// except that the dispatchers are stored separately and attached at
/// serialization time (see `to_protobuf`).
#[derive(Clone, Debug)]
pub struct StreamActor {
    /// The id of this actor.
    pub actor_id: ActorId,
    /// The id of the fragment this actor belongs to.
    pub fragment_id: FragmentId,
    /// Vnode bitmap of this actor; `None` presumably for singleton fragments — confirm.
    pub vnode_bitmap: Option<Bitmap>,
    /// The SQL definition of the job this actor serves.
    pub mview_definition: String,
    /// Expression-evaluation context for this actor.
    pub expr_context: Option<PbExprContext>,
    // TODO: shall we merge `config_override` with `expr_context` to be a `StreamContext`?
    /// Partial config overriding the global config for this actor.
    pub config_override: Arc<str>,
}
172
173impl StreamActor {
174    fn to_protobuf(&self, dispatchers: impl Iterator<Item = Dispatcher>) -> PbStreamActor {
175        PbStreamActor {
176            actor_id: self.actor_id,
177            fragment_id: self.fragment_id,
178            dispatcher: dispatchers.collect(),
179            vnode_bitmap: self
180                .vnode_bitmap
181                .as_ref()
182                .map(|bitmap| bitmap.to_protobuf()),
183            mview_definition: self.mview_definition.clone(),
184            expr_context: self.expr_context.clone(),
185            config_override: self.config_override.to_string(),
186        }
187    }
188}
189
/// In-memory representation of a fragment. Corresponds to [`PbFragment`],
/// except that actors, upstreams and dispatchers are stored separately and
/// attached at serialization time (see `to_protobuf`).
#[derive(Clone, Debug, Default)]
pub struct Fragment {
    /// The id of this fragment.
    pub fragment_id: FragmentId,
    /// Bitmask of [`FragmentTypeFlag`]s describing what this fragment contains.
    pub fragment_type_mask: FragmentTypeMask,
    /// Distribution type; `Single` means singleton (see the `IsSingleton` impl).
    pub distribution_type: PbFragmentDistributionType,
    /// Ids of the state tables used by this fragment.
    pub state_table_ids: Vec<TableId>,
    /// Vnode count; `None` falls back to singleton-based inference in the
    /// `VnodeCountCompat` impl below.
    pub maybe_vnode_count: Option<u32>,
    /// The stream plan (root stream node) of this fragment.
    pub nodes: StreamNode,
}
199
200impl Fragment {
201    pub fn to_protobuf(
202        &self,
203        actors: &[StreamActor],
204        upstream_fragments: impl Iterator<Item = FragmentId>,
205        dispatchers: Option<&HashMap<ActorId, Vec<Dispatcher>>>,
206    ) -> PbFragment {
207        PbFragment {
208            fragment_id: self.fragment_id,
209            fragment_type_mask: self.fragment_type_mask.into(),
210            distribution_type: self.distribution_type as _,
211            actors: actors
212                .iter()
213                .map(|actor| {
214                    actor.to_protobuf(
215                        dispatchers
216                            .and_then(|dispatchers| dispatchers.get(&actor.actor_id))
217                            .into_iter()
218                            .flatten()
219                            .cloned(),
220                    )
221                })
222                .collect(),
223            state_table_ids: self.state_table_ids.clone(),
224            upstream_fragment_ids: upstream_fragments.collect(),
225            maybe_vnode_count: self.maybe_vnode_count,
226            nodes: Some(self.nodes.clone()),
227        }
228    }
229}
230
impl VnodeCountCompat for Fragment {
    fn vnode_count_inner(&self) -> VnodeCount {
        // Delegate to the recorded `maybe_vnode_count`; the closure lets
        // `from_protobuf` fall back on whether the fragment is a singleton
        // when no count was persisted.
        VnodeCount::from_protobuf(self.maybe_vnode_count, || self.is_singleton())
    }
}
236
impl IsSingleton for Fragment {
    fn is_singleton(&self) -> bool {
        // A fragment is a singleton iff its distribution type is `Single`.
        matches!(self.distribution_type, FragmentDistributionType::Single)
    }
}
242
/// Convert a persisted fragment model row into the in-memory [`Fragment`].
impl From<fragment::Model> for Fragment {
    fn from(model: fragment::Model) -> Self {
        Self {
            fragment_id: model.fragment_id,
            fragment_type_mask: FragmentTypeMask::from(model.fragment_type_mask),
            distribution_type: model.distribution_type.into(),
            state_table_ids: model.state_table_ids.into_inner(),
            // `VnodeCount::set` marks the persisted count as explicitly present.
            maybe_vnode_count: VnodeCount::set(model.vnode_count).to_protobuf(),
            nodes: model.stream_node.to_protobuf(),
        }
    }
}
255
/// Fragments of a streaming job. Corresponds to [`PbTableFragments`].
/// (It was previously called `TableFragments` due to historical reasons.)
///
/// We store whole fragments in a single column family as follow:
/// `stream_job_id` => `StreamJobFragments`.
#[derive(Debug, Clone)]
pub struct StreamJobFragments {
    /// The id of the streaming job. (Called "table id" for historical reasons.)
    pub stream_job_id: JobId,

    /// The state of the table fragments.
    pub state: State,

    /// The fragments of the job, keyed by fragment id.
    pub fragments: BTreeMap<FragmentId, Fragment>,

    /// The streaming context associated with this stream plan and its fragments
    pub ctx: StreamContext,

    /// The parallelism assigned to this table fragments
    pub assigned_parallelism: TableParallelism,

    /// The max parallelism specified when the streaming job was created, i.e., expected vnode count.
    ///
    /// The reason for persisting this value is mainly to check if a parallelism change (via `ALTER
    /// .. SET PARALLELISM`) is valid, so that the behavior can be consistent with the creation of
    /// the streaming job.
    ///
    /// Note that the actual vnode count, denoted by `vnode_count` in `fragments`, may be different
    /// from this value (see `StreamFragmentGraph.max_parallelism` for more details.). As a result,
    /// checking the parallelism change with this value can be inaccurate in some cases. However,
    /// when generating resizing plans, we still take the `vnode_count` of each fragment into account.
    pub max_parallelism: usize,
}
290
/// Job-level streaming context shared by all fragments of a streaming job.
/// Serialized to/from [`PbStreamContext`].
#[derive(Debug, Clone, Default)]
pub struct StreamContext {
    /// The timezone used to interpret timestamps and dates for conversion
    pub timezone: Option<String>,

    /// The partial config of this job to override the global config.
    pub config_override: Arc<str>,

    /// The adaptive parallelism strategy for this job if it overrides the system default.
    pub adaptive_parallelism_strategy: Option<AdaptiveParallelismStrategy>,
}
302
303impl StreamContext {
304    pub fn to_protobuf(&self) -> PbStreamContext {
305        PbStreamContext {
306            timezone: self.timezone.clone().unwrap_or("".into()),
307            config_override: self.config_override.to_string(),
308            adaptive_parallelism_strategy: self
309                .adaptive_parallelism_strategy
310                .as_ref()
311                .map(ToString::to_string)
312                .unwrap_or_default(),
313            backfill_adaptive_parallelism_strategy: String::new(),
314        }
315    }
316
317    pub fn to_expr_context(&self) -> PbExprContext {
318        PbExprContext {
319            // `self.timezone` must always be set; an invalid value is used here for debugging if it's not.
320            time_zone: self.timezone.clone().unwrap_or("Empty Time Zone".into()),
321            strict_mode: false,
322        }
323    }
324
325    pub fn from_protobuf(prost: &PbStreamContext) -> Self {
326        Self {
327            timezone: if prost.get_timezone().is_empty() {
328                None
329            } else {
330                Some(prost.get_timezone().clone())
331            },
332            config_override: prost.get_config_override().as_str().into(),
333            adaptive_parallelism_strategy: if prost.get_adaptive_parallelism_strategy().is_empty() {
334                None
335            } else {
336                Some(
337                    parse_strategy(prost.get_adaptive_parallelism_strategy())
338                        .expect("adaptive parallelism strategy should be validated in frontend"),
339                )
340            },
341        }
342    }
343}
344
// Extension trait on the persisted `streaming_job` model row for building a `StreamContext`.
#[easy_ext::ext(StreamingJobModelContextExt)]
impl risingwave_meta_model::streaming_job::Model {
    /// Assemble a [`StreamContext`] from the fields persisted on this job row.
    pub fn stream_context(&self) -> StreamContext {
        StreamContext {
            timezone: self.timezone.clone(),
            // A missing config override is treated as an empty string.
            config_override: self.config_override.clone().unwrap_or_default().into(),
            adaptive_parallelism_strategy: self.adaptive_parallelism_strategy.as_deref().map(|s| {
                parse_strategy(s).expect("strategy should be validated before persisting")
            }),
        }
    }
}
357
358impl StreamJobFragments {
359    pub fn to_protobuf(
360        &self,
361        fragment_actors: &HashMap<FragmentId, Vec<StreamActor>>,
362        fragment_upstreams: &HashMap<FragmentId, HashSet<FragmentId>>,
363        fragment_dispatchers: &FragmentActorDispatchers,
364        actor_status: HashMap<ActorId, PbActorStatus>,
365    ) -> PbTableFragments {
366        PbTableFragments {
367            table_id: self.stream_job_id,
368            state: self.state as _,
369            fragments: self
370                .fragments
371                .iter()
372                .map(|(id, fragment)| {
373                    let actors = fragment_actors.get(id).map(|a| a.as_slice()).unwrap_or(&[]);
374                    (
375                        *id,
376                        fragment.to_protobuf(
377                            actors,
378                            fragment_upstreams.get(id).into_iter().flatten().cloned(),
379                            fragment_dispatchers.get(id),
380                        ),
381                    )
382                })
383                .collect(),
384            actor_status,
385            ctx: Some(self.ctx.to_protobuf()),
386            parallelism: Some(self.assigned_parallelism.into()),
387            node_label: "".to_owned(),
388            backfill_done: true,
389            max_parallelism: Some(self.max_parallelism as _),
390        }
391    }
392}
393
/// Actors to create, grouped by worker and then by fragment:
/// worker id -> fragment id -> (the fragment's `StreamNode`, the actors with
/// their up/downstreams, subscriber ids).
pub type StreamJobActorsToCreate = HashMap<
    WorkerId,
    HashMap<
        FragmentId,
        (
            StreamNode,
            Vec<StreamActorWithUpDownstreams>,
            HashSet<SubscriberId>,
        ),
    >,
>;
405
impl StreamJobFragments {
    /// Create a new `TableFragments` with state of `Initial`, with other fields empty.
    pub fn for_test(job_id: JobId, fragments: BTreeMap<FragmentId, Fragment>) -> Self {
        Self::new(
            job_id,
            fragments,
            StreamContext::default(),
            TableParallelism::Adaptive,
            VirtualNode::COUNT_FOR_TEST,
        )
    }

    /// Create a new `TableFragments` with state of `Initial`.
    pub fn new(
        stream_job_id: JobId,
        fragments: BTreeMap<FragmentId, Fragment>,
        ctx: StreamContext,
        table_parallelism: TableParallelism,
        max_parallelism: usize,
    ) -> Self {
        Self {
            stream_job_id,
            state: State::Initial,
            fragments,
            ctx,
            assigned_parallelism: table_parallelism,
            max_parallelism,
        }
    }

    /// Returns the ids of all fragments of this job.
    pub fn fragment_ids(&self) -> impl Iterator<Item = FragmentId> + '_ {
        self.fragments.keys().cloned()
    }

    /// Returns an iterator over all fragments of this job.
    pub fn fragments(&self) -> impl Iterator<Item = &Fragment> {
        self.fragments.values()
    }

    /// Returns the id of the streaming job.
    pub fn stream_job_id(&self) -> JobId {
        self.stream_job_id
    }

    /// Returns the timezone of the table
    pub fn timezone(&self) -> Option<String> {
        self.ctx.timezone.clone()
    }

    /// Returns whether the table fragments is in `Created` state.
    pub fn is_created(&self) -> bool {
        self.state == State::Created
    }

    /// Returns mview fragment ids.
    #[cfg(test)]
    pub fn mview_fragment_ids(&self) -> Vec<FragmentId> {
        self.fragments
            .values()
            .filter(move |fragment| {
                fragment
                    .fragment_type_mask
                    .contains(FragmentTypeFlag::Mview)
            })
            .map(|fragment| fragment.fragment_id)
            .collect()
    }

    /// Returns actor ids that need to be tracked when creating MV.
    ///
    /// Takes `(fragment_type_mask, actor_ids)` pairs and selects the actors of
    /// backfill-like fragments (`Values`/`StreamScan`/`SourceScan`/
    /// `LocalityProvider`), each tagged with its [`BackfillUpstreamType`].
    /// Returns an empty list for CDC table jobs (see note below).
    pub fn tracking_progress_actor_ids_impl(
        fragments: impl IntoIterator<Item = (FragmentTypeMask, impl Iterator<Item = ActorId>)>,
    ) -> Vec<(ActorId, BackfillUpstreamType)> {
        let mut actor_ids = vec![];
        for (fragment_type_mask, actors) in fragments {
            if fragment_type_mask.contains(FragmentTypeFlag::CdcFilter) {
                // Note: CDC table job contains a StreamScan fragment (StreamCdcScan node) and a CdcFilter fragment.
                // We don't track any fragments' progress.
                return vec![];
            }
            if fragment_type_mask.contains_any([
                FragmentTypeFlag::Values,
                FragmentTypeFlag::StreamScan,
                FragmentTypeFlag::SourceScan,
                FragmentTypeFlag::LocalityProvider,
            ]) {
                actor_ids.extend(actors.map(|actor_id| {
                    (
                        actor_id,
                        BackfillUpstreamType::from_fragment_type_mask(fragment_type_mask),
                    )
                }));
            }
        }
        actor_ids
    }

    /// Returns the "root" fragment of this job, trying the mview fragment
    /// first, then the sink fragment, then the source fragment.
    pub fn root_fragment(&self) -> Option<Fragment> {
        self.mview_fragment()
            .or_else(|| self.sink_fragment())
            .or_else(|| self.source_fragment())
    }

    /// Returns the fragment with the `Mview` type flag.
    pub fn mview_fragment(&self) -> Option<Fragment> {
        self.fragments
            .values()
            .find(|fragment| {
                fragment
                    .fragment_type_mask
                    .contains(FragmentTypeFlag::Mview)
            })
            .cloned()
    }

    /// Returns the fragment with the `Source` type flag, if any.
    pub fn source_fragment(&self) -> Option<Fragment> {
        self.fragments
            .values()
            .find(|fragment| {
                fragment
                    .fragment_type_mask
                    .contains(FragmentTypeFlag::Source)
            })
            .cloned()
    }

    /// Returns the fragment with the `Sink` type flag, if any.
    pub fn sink_fragment(&self) -> Option<Fragment> {
        self.fragments
            .values()
            .find(|fragment| fragment.fragment_type_mask.contains(FragmentTypeFlag::Sink))
            .cloned()
    }

    /// Extract the fragments that include source executors that contains an external stream source,
    /// grouping by source id.
    pub fn stream_source_fragments(&self) -> HashMap<SourceId, BTreeSet<FragmentId>> {
        let mut source_fragments = HashMap::new();

        for fragment in self.fragments() {
            {
                if let Some(source_id) = fragment.nodes.find_stream_source() {
                    source_fragments
                        .entry(source_id)
                        .or_insert(BTreeSet::new())
                        .insert(fragment.fragment_id as FragmentId);
                }
            }
        }
        source_fragments
    }

    /// See [`Self::source_backfill_fragments_impl`]; operates on this job's fragments.
    pub fn source_backfill_fragments(
        &self,
    ) -> HashMap<SourceId, BTreeSet<(FragmentId, FragmentId)>> {
        Self::source_backfill_fragments_impl(
            self.fragments
                .iter()
                .map(|(fragment_id, fragment)| (*fragment_id, &fragment.nodes)),
        )
    }

    /// Returns (`source_id`, -> (`source_backfill_fragment_id`, `upstream_source_fragment_id`)).
    ///
    /// Note: the fragment `source_backfill_fragment_id` may actually have multiple upstream fragments,
    /// but only one of them is the upstream source fragment, which is what we return.
    pub fn source_backfill_fragments_impl(
        fragments: impl Iterator<Item = (FragmentId, &StreamNode)>,
    ) -> HashMap<SourceId, BTreeSet<(FragmentId, FragmentId)>> {
        let mut source_backfill_fragments = HashMap::new();

        for (fragment_id, fragment_node) in fragments {
            {
                if let Some((source_id, upstream_source_fragment_id)) =
                    fragment_node.find_source_backfill()
                {
                    source_backfill_fragments
                        .entry(source_id)
                        .or_insert(BTreeSet::new())
                        .insert((fragment_id, upstream_source_fragment_id));
                }
            }
        }
        source_backfill_fragments
    }

    /// Find the table job's `Union` fragment.
    /// Panics if not found.
    ///
    /// # Panics
    /// Also panics (via `assert_eq!`) if more than one distinct fragment
    /// contains a `Union` node.
    pub fn union_fragment_for_table(&mut self) -> &mut Fragment {
        let mut union_fragment_id = None;
        for (fragment_id, fragment) in &self.fragments {
            {
                {
                    visit_stream_node_body(&fragment.nodes, |body| {
                        if let NodeBody::Union(_) = body {
                            if let Some(union_fragment_id) = union_fragment_id.as_mut() {
                                // The union fragment should be unique.
                                assert_eq!(*union_fragment_id, *fragment_id);
                            } else {
                                union_fragment_id = Some(*fragment_id);
                            }
                        }
                    })
                }
            }
        }

        let union_fragment_id =
            union_fragment_id.expect("fragment of placeholder merger not found");

        // NOTE(review): the trailing `as _` looks like a reference-coercion
        // workaround — confirm it is still needed.
        (self
            .fragments
            .get_mut(&union_fragment_id)
            .unwrap_or_else(|| panic!("fragment {} not found", union_fragment_id))) as _
    }

    /// Resolve dependent table
    ///
    /// Recursively walks `stream_node` and its inputs, incrementing the count
    /// in `table_ids` for every table referenced by a `StreamScan`,
    /// `StreamCdcScan` or `LocalityProvider` node.
    fn resolve_dependent_table(stream_node: &StreamNode, table_ids: &mut HashMap<TableId, usize>) {
        let table_id = match stream_node.node_body.as_ref() {
            Some(NodeBody::StreamScan(stream_scan)) => Some(stream_scan.table_id),
            Some(NodeBody::StreamCdcScan(stream_scan)) => Some(stream_scan.table_id),
            Some(NodeBody::LocalityProvider(state)) => {
                Some(state.state_table.as_ref().expect("must have state").id)
            }
            _ => None,
        };
        if let Some(table_id) = table_id {
            table_ids.entry(table_id).or_default().add_assign(1);
        }

        for child in &stream_node.input {
            Self::resolve_dependent_table(child, table_ids);
        }
    }

    /// Returns how many times each upstream table is referenced by this job's fragments.
    pub fn upstream_table_counts(&self) -> HashMap<TableId, usize> {
        Self::upstream_table_counts_impl(self.fragments.values().map(|fragment| &fragment.nodes))
    }

    /// Returns upstream table counts.
    pub fn upstream_table_counts_impl(
        fragment_nodes: impl Iterator<Item = &StreamNode>,
    ) -> HashMap<TableId, usize> {
        let mut table_ids = HashMap::new();
        fragment_nodes.for_each(|node| {
            Self::resolve_dependent_table(node, &mut table_ids);
        });

        table_ids
    }

    /// Returns the state table id that is the job's mview table, if any.
    pub fn mv_table_id(&self) -> Option<TableId> {
        self.fragments
            .values()
            .flat_map(|f| f.state_table_ids.iter().copied())
            .find(|table_id| self.stream_job_id.is_mv_table_id(*table_id))
    }

    /// Collects all tables referenced by the given fragments' stream nodes,
    /// keyed by table id.
    ///
    /// # Panics
    /// Panics if the same table id appears more than once across the fragments.
    pub fn collect_tables(fragments: impl Iterator<Item = &Fragment>) -> BTreeMap<TableId, Table> {
        let mut tables = BTreeMap::new();
        for fragment in fragments {
            stream_graph_visitor::visit_stream_node_tables_inner(
                // The visitor requires mutable access, so visit a clone of the nodes.
                &mut fragment.nodes.clone(),
                false,
                true,
                |table, _| {
                    let table_id = table.id;
                    tables
                        .try_insert(table_id, table.clone())
                        .unwrap_or_else(|_| panic!("duplicated table id `{}`", table_id));
                },
            );
        }
        tables
    }

    /// Returns the internal table ids without the mview table.
    pub fn internal_table_ids(&self) -> Vec<TableId> {
        self.fragments
            .values()
            .flat_map(|f| f.state_table_ids.iter().copied())
            .filter(|&t| !self.stream_job_id.is_mv_table_id(t))
            .collect_vec()
    }

    /// Returns all internal table ids including the mview table.
    pub fn all_table_ids(&self) -> impl Iterator<Item = TableId> + '_ {
        self.fragments
            .values()
            .flat_map(|f| f.state_table_ids.clone())
    }
}
695
/// What kind of upstream a backfill fragment reads from; used when tracking
/// creation progress (see `tracking_progress_actor_ids_impl`).
#[derive(Debug, Display, Clone, Copy, PartialEq, Eq)]
pub enum BackfillUpstreamType {
    /// Backfilling from a materialized view (`StreamScan` fragment).
    MView,
    /// Backfilling from a `Values` fragment.
    Values,
    /// Backfilling from a source (`SourceScan` fragment).
    Source,
    /// Backfilling through a `LocalityProvider` fragment.
    LocalityProvider,
}
703
704impl BackfillUpstreamType {
705    pub fn from_fragment_type_mask(mask: FragmentTypeMask) -> Self {
706        let is_mview = mask.contains(FragmentTypeFlag::StreamScan);
707        let is_values = mask.contains(FragmentTypeFlag::Values);
708        let is_source = mask.contains(FragmentTypeFlag::SourceScan);
709        let is_locality_provider = mask.contains(FragmentTypeFlag::LocalityProvider);
710
711        // Note: in theory we can have multiple backfill executors in one fragment, but currently it's not possible.
712        // See <https://github.com/risingwavelabs/risingwave/issues/6236>.
713        debug_assert!(
714            is_mview as u8 + is_values as u8 + is_source as u8 + is_locality_provider as u8 == 1,
715            "a backfill fragment should either be mview, value, source, or locality provider, found {:?}",
716            mask
717        );
718
719        if is_mview {
720            BackfillUpstreamType::MView
721        } else if is_values {
722            BackfillUpstreamType::Values
723        } else if is_source {
724            BackfillUpstreamType::Source
725        } else if is_locality_provider {
726            BackfillUpstreamType::LocalityProvider
727        } else {
728            unreachable!("invalid fragment type mask: {:?}", mask);
729        }
730    }
731}