risingwave_meta/stream/stream_graph/fragment.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashMap, HashSet};
use std::num::NonZeroUsize;
use std::ops::{Deref, DerefMut};
use std::sync::LazyLock;

use anyhow::{Context, anyhow};
use enum_as_inner::EnumAsInner;
use itertools::Itertools;
use risingwave_common::bail;
use risingwave_common::catalog::{
    CDC_SOURCE_COLUMN_NUM, TableId, generate_internal_table_name_with_type,
};
use risingwave_common::hash::VnodeCount;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::util::stream_graph_visitor;
use risingwave_common::util::stream_graph_visitor::{
    visit_stream_node_cont, visit_stream_node_cont_mut,
};
use risingwave_meta_model::WorkerId;
use risingwave_pb::catalog::Table;
use risingwave_pb::ddl_service::TableJobType;
use risingwave_pb::stream_plan::stream_fragment_graph::{
    Parallelism, StreamFragment, StreamFragmentEdge as StreamFragmentEdgeProto,
};
use risingwave_pb::stream_plan::stream_node::NodeBody;
use risingwave_pb::stream_plan::{
    DispatchStrategy, DispatcherType, FragmentTypeFlag,
    StreamFragmentGraph as StreamFragmentGraphProto, StreamNode, StreamScanNode, StreamScanType,
};

use crate::MetaResult;
use crate::barrier::SnapshotBackfillInfo;
use crate::manager::{MetaSrvEnv, StreamingJob, StreamingJobType};
use crate::model::{ActorId, Fragment, FragmentId, StreamActor};
use crate::stream::stream_graph::id::{GlobalFragmentId, GlobalFragmentIdGen, GlobalTableIdGen};
use crate::stream::stream_graph::schedule::Distribution;

/// The fragment in the building phase, including the [`StreamFragment`] from the frontend and
/// several additional helper fields.
#[derive(Debug, Clone)]
pub(super) struct BuildingFragment {
    /// The fragment structure from the frontend, with the global fragment ID.
    inner: StreamFragment,

    /// The ID of the job if it contains the streaming job node.
    job_id: Option<u32>,

    /// The required column IDs of each upstream table.
    /// Will be converted to indices when building the edge connected to the upstream.
    ///
    /// For a shared CDC table on source, it's `vec![]`, since the upstream source's output schema is fixed.
    upstream_table_columns: HashMap<TableId, Vec<i32>>,
}

impl BuildingFragment {
    /// Create a new [`BuildingFragment`] from a [`StreamFragment`]. The global fragment ID and
    /// global table IDs will be correctly filled with the given `id` and `table_id_gen`.
    fn new(
        id: GlobalFragmentId,
        fragment: StreamFragment,
        job: &StreamingJob,
        table_id_gen: GlobalTableIdGen,
    ) -> Self {
        let mut fragment = StreamFragment {
            fragment_id: id.as_global_id(),
            ..fragment
        };

        // Fill the information of the internal tables in the fragment.
        Self::fill_internal_tables(&mut fragment, job, table_id_gen);

        let job_id = Self::fill_job(&mut fragment, job).then(|| job.id());
        let upstream_table_columns =
            Self::extract_upstream_table_columns_except_cross_db_backfill(&fragment);

        Self {
            inner: fragment,
            job_id,
            upstream_table_columns,
        }
    }

    /// Extract the internal tables from the fragment.
    fn extract_internal_tables(&self) -> Vec<Table> {
        let mut fragment = self.inner.to_owned();
        let mut tables = Vec::new();
        stream_graph_visitor::visit_internal_tables(&mut fragment, |table, _| {
            tables.push(table.clone());
        });
        tables
    }

    /// Fill the information of the internal tables in the fragment.
    fn fill_internal_tables(
        fragment: &mut StreamFragment,
        job: &StreamingJob,
        table_id_gen: GlobalTableIdGen,
    ) {
        let fragment_id = fragment.fragment_id;
        stream_graph_visitor::visit_internal_tables(fragment, |table, table_type_name| {
            table.id = table_id_gen.to_global_id(table.id).as_global_id();
            table.schema_id = job.schema_id();
            table.database_id = job.database_id();
            table.name = generate_internal_table_name_with_type(
                &job.name(),
                fragment_id,
                table.id,
                table_type_name,
            );
            table.fragment_id = fragment_id;
            table.owner = job.owner();
        });
    }

    /// Fill the information of the job in the fragment.
    fn fill_job(fragment: &mut StreamFragment, job: &StreamingJob) -> bool {
        let job_id = job.id();
        let fragment_id = fragment.fragment_id;
        let mut has_job = false;

        stream_graph_visitor::visit_fragment_mut(fragment, |node_body| match node_body {
            NodeBody::Materialize(materialize_node) => {
                materialize_node.table_id = job_id;

                // Fill the ID of the `Table`.
                let table = materialize_node.table.as_mut().unwrap();
                table.id = job_id;
                table.database_id = job.database_id();
                table.schema_id = job.schema_id();
                table.fragment_id = fragment_id;
                #[cfg(not(debug_assertions))]
                {
                    table.definition = job.name();
                }

                has_job = true;
            }
            NodeBody::Sink(sink_node) => {
                sink_node.sink_desc.as_mut().unwrap().id = job_id;

                has_job = true;
            }
            NodeBody::Dml(dml_node) => {
                dml_node.table_id = job_id;
                dml_node.table_version_id = job.table_version_id().unwrap();
            }
            NodeBody::StreamFsFetch(fs_fetch_node) => {
                if let StreamingJob::Table(table_source, _, _) = job {
                    if let Some(node_inner) = fs_fetch_node.node_inner.as_mut()
                        && let Some(source) = table_source
                    {
                        node_inner.source_id = source.id;
                    }
                }
            }
            NodeBody::Source(source_node) => {
                match job {
                    // Note: A table without a connector has a dummy `Source` node.
                    // Note: For a table with a connector, its source node has a source ID
                    // different from the table ID (job ID), assigned in `create_job_catalog`.
                    StreamingJob::Table(source, _table, _table_job_type) => {
                        if let Some(source_inner) = source_node.source_inner.as_mut() {
                            if let Some(source) = source {
                                debug_assert_ne!(source.id, job_id);
                                source_inner.source_id = source.id;
                            }
                        }
                    }
                    StreamingJob::Source(source) => {
                        has_job = true;
                        if let Some(source_inner) = source_node.source_inner.as_mut() {
                            debug_assert_eq!(source.id, job_id);
                            source_inner.source_id = source.id;
                        }
                    }
                    // For other job types, no need to fill the source id, since it refers to an existing source.
                    _ => {}
                }
            }
            NodeBody::StreamCdcScan(node) => {
                if let Some(table_desc) = node.cdc_table_desc.as_mut() {
                    table_desc.table_id = job_id;
                }
            }
            _ => {}
        });

        has_job
    }

    /// Extract the required columns (in IDs) of each upstream table except for cross-db backfill.
    fn extract_upstream_table_columns_except_cross_db_backfill(
        fragment: &StreamFragment,
    ) -> HashMap<TableId, Vec<i32>> {
        let mut table_columns = HashMap::new();

        stream_graph_visitor::visit_fragment(fragment, |node_body| {
            let (table_id, column_ids) = match node_body {
                NodeBody::StreamScan(stream_scan) => {
                    if stream_scan.get_stream_scan_type().unwrap()
                        == StreamScanType::CrossDbSnapshotBackfill
                    {
                        return;
                    }
                    (
                        stream_scan.table_id.into(),
                        stream_scan.upstream_column_ids.clone(),
                    )
                }
                NodeBody::CdcFilter(cdc_filter) => (cdc_filter.upstream_source_id.into(), vec![]),
                NodeBody::SourceBackfill(backfill) => (
                    backfill.upstream_source_id.into(),
                    // FIXME: only pass required columns instead of all columns here
                    backfill
                        .columns
                        .iter()
                        .map(|c| c.column_desc.as_ref().unwrap().column_id)
                        .collect(),
                ),
                _ => return,
            };
            table_columns
                .try_insert(table_id, column_ids)
                .expect("currently there should be no two same upstream tables in a fragment");
        });

        table_columns
    }

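    /// Returns `true` if the fragment contains a `StreamScan` of `ArrangementBackfill` or
    /// `SnapshotBackfill` type, i.e., a backfill that shuffles data instead of being
    /// colocated with the upstream via `NoShuffle`.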
    pub fn has_shuffled_backfill(&self) -> bool {
        let stream_node = match self.inner.node.as_ref() {
            Some(node) => node,
            _ => return false,
        };
        let mut has_shuffled_backfill = false;
        let has_shuffled_backfill_mut_ref = &mut has_shuffled_backfill;
        visit_stream_node_cont(stream_node, |node| {
            let is_shuffled_backfill = if let Some(node) = &node.node_body
                && let Some(node) = node.as_stream_scan()
            {
                node.stream_scan_type == StreamScanType::ArrangementBackfill as i32
                    || node.stream_scan_type == StreamScanType::SnapshotBackfill as i32
            } else {
                false
            };
            if is_shuffled_backfill {
                *has_shuffled_backfill_mut_ref = true;
                false
            } else {
                true
            }
        });
        has_shuffled_backfill
    }
}

impl Deref for BuildingFragment {
    type Target = StreamFragment;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

impl DerefMut for BuildingFragment {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}

/// The ID of an edge in the fragment graph. For different types of edges, the ID will be in
/// different variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumAsInner)]
pub(super) enum EdgeId {
    /// The edge between two building (internal) fragments.
    Internal {
        /// The ID generated by the frontend, generally the operator ID of `Exchange`.
        /// See [`StreamFragmentEdgeProto`].
        link_id: u64,
    },

    /// The edge between an upstream external fragment and a downstream building fragment. Used
    /// for MV on MV.
    UpstreamExternal {
        /// The ID of the upstream table or materialized view.
        upstream_table_id: TableId,
        /// The ID of the downstream fragment.
        downstream_fragment_id: GlobalFragmentId,
    },

    /// The edge between an upstream building fragment and a downstream external fragment. Used
    /// for schema change (replace table plan).
    DownstreamExternal(DownstreamExternalEdgeId),
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(super) struct DownstreamExternalEdgeId {
    /// The ID of the original upstream fragment (`Materialize`).
    pub(super) original_upstream_fragment_id: GlobalFragmentId,
    /// The ID of the downstream fragment.
    pub(super) downstream_fragment_id: GlobalFragmentId,
}

/// The edge in the fragment graph.
///
/// The edge can be either internal or external. This is distinguished by the [`EdgeId`].
#[derive(Debug, Clone)]
pub(super) struct StreamFragmentEdge {
    /// The ID of the edge.
    pub id: EdgeId,

    /// The strategy used for dispatching the data.
    pub dispatch_strategy: DispatchStrategy,
}

impl StreamFragmentEdge {
    fn from_protobuf(edge: &StreamFragmentEdgeProto) -> Self {
        Self {
            // By creating an edge from the protobuf, we know that the edge is from the frontend and
            // is internal.
            id: EdgeId::Internal {
                link_id: edge.link_id,
            },
            dispatch_strategy: edge.get_dispatch_strategy().unwrap().clone(),
        }
    }
}

/// In-memory representation of a **Fragment** Graph, built from the [`StreamFragmentGraphProto`]
/// from the frontend.
///
/// This only includes the nodes and edges of the current job itself. It will later be converted
/// to a [`CompleteStreamFragmentGraph`], which additionally contains the information of
/// pre-existing fragments that are connected to the graph's top-most or bottom-most fragments.
#[derive(Default, Debug)]
pub struct StreamFragmentGraph {
    /// Stores all the fragments in the graph.
    fragments: HashMap<GlobalFragmentId, BuildingFragment>,

    /// Stores edges between fragments: upstream => downstream.
    downstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,

    /// Stores edges between fragments: downstream => upstream.
    upstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,

    /// Dependent relations of this job.
    dependent_table_ids: HashSet<TableId>,

    /// The default parallelism of the job, specified by the `STREAMING_PARALLELISM` session
    /// variable. If not specified, all active worker slots will be used.
    specified_parallelism: Option<NonZeroUsize>,

    /// Specified max parallelism, i.e., expected vnode count for the graph.
    ///
    /// The scheduler on the meta service will use this as a hint to decide the vnode count
    /// for each fragment.
    ///
    /// Note that the actual vnode count may be different from this value.
    /// For example, a no-shuffle exchange between current fragment graph and an existing
    /// upstream fragment graph requires two fragments to be in the same distribution,
    /// thus the same vnode count.
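    ///
    /// As an illustrative example (vnode counts are hypothetical): if a new MV with this hint
    /// set to the typical default of 256 is connected via no-shuffle to an upstream MV that was
    /// created with 512 vnodes, the connected fragments end up with 512 vnodes regardless.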
    max_parallelism: usize,
}

impl StreamFragmentGraph {
    /// Create a new [`StreamFragmentGraph`] from the given [`StreamFragmentGraphProto`], with all
    /// global IDs correctly filled.
    pub fn new(
        env: &MetaSrvEnv,
        proto: StreamFragmentGraphProto,
        job: &StreamingJob,
    ) -> MetaResult<Self> {
        let fragment_id_gen =
            GlobalFragmentIdGen::new(env.id_gen_manager(), proto.fragments.len() as u64);
        // Note: in SQL backend, the ids generated here are fake and will be overwritten again
        // with `refill_internal_table_ids` later.
        // TODO: refactor the code to remove this step.
        let table_id_gen = GlobalTableIdGen::new(env.id_gen_manager(), proto.table_ids_cnt as u64);

        // Create nodes.
        let fragments: HashMap<_, _> = proto
            .fragments
            .into_iter()
            .map(|(id, fragment)| {
                let id = fragment_id_gen.to_global_id(id);
                let fragment = BuildingFragment::new(id, fragment, job, table_id_gen);
                (id, fragment)
            })
            .collect();

        assert_eq!(
            fragments
                .values()
                .map(|f| f.extract_internal_tables().len() as u32)
                .sum::<u32>(),
            proto.table_ids_cnt
        );

        // Create edges.
        let mut downstreams = HashMap::new();
        let mut upstreams = HashMap::new();

        for edge in proto.edges {
            let upstream_id = fragment_id_gen.to_global_id(edge.upstream_id);
            let downstream_id = fragment_id_gen.to_global_id(edge.downstream_id);
            let edge = StreamFragmentEdge::from_protobuf(&edge);

            upstreams
                .entry(downstream_id)
                .or_insert_with(HashMap::new)
                .try_insert(upstream_id, edge.clone())
                .unwrap();
            downstreams
                .entry(upstream_id)
                .or_insert_with(HashMap::new)
                .try_insert(downstream_id, edge)
                .unwrap();
        }

        // Note: Here we directly use the field `dependent_table_ids` in the proto (resolved in
        // frontend), instead of visiting the graph ourselves.
        let dependent_table_ids = proto
            .dependent_table_ids
            .iter()
            .map(TableId::from)
            .collect();

        let specified_parallelism = if let Some(Parallelism { parallelism }) = proto.parallelism {
            Some(NonZeroUsize::new(parallelism as usize).context("parallelism should not be 0")?)
        } else {
            None
        };

        let max_parallelism = proto.max_parallelism as usize;

        Ok(Self {
            fragments,
            downstreams,
            upstreams,
            dependent_table_ids,
            specified_parallelism,
            max_parallelism,
        })
    }

    /// Retrieve the **incomplete** internal tables map of the whole graph.
    ///
    /// Note that some fields in the table catalogs are not filled during the current phase, e.g.,
    /// `fragment_id`, `vnode_count`. They will all be filled after a `TableFragments` is built.
    /// Be careful when using the returned values.
    ///
    /// See also [`crate::model::StreamJobFragments::internal_tables`].
    pub fn incomplete_internal_tables(&self) -> BTreeMap<u32, Table> {
        let mut tables = BTreeMap::new();
        for fragment in self.fragments.values() {
            for table in fragment.extract_internal_tables() {
                let table_id = table.id;
                tables
                    .try_insert(table_id, table)
                    .unwrap_or_else(|_| panic!("duplicated table id `{}`", table_id));
            }
        }
        tables
    }

    /// Refill the internal tables' `table_id`s according to the given map, typically obtained from
    /// `create_internal_table_catalog`.
    pub fn refill_internal_table_ids(&mut self, table_id_map: HashMap<u32, u32>) {
        for fragment in self.fragments.values_mut() {
            stream_graph_visitor::visit_internal_tables(
                &mut fragment.inner,
                |table, _table_type_name| {
                    let target = table_id_map.get(&table.id).cloned().unwrap();
                    table.id = target;
                },
            );
        }
    }

    /// Set the internal tables' `table_id`s according to the given list of old internal tables.
    pub fn fit_internal_table_ids(
        &mut self,
        mut old_internal_tables: Vec<Table>,
    ) -> MetaResult<()> {
        let mut new_internal_table_ids = Vec::new();
        for fragment in self.fragments.values() {
            for table in &fragment.extract_internal_tables() {
                new_internal_table_ids.push(table.id);
            }
        }

        if new_internal_table_ids.len() != old_internal_tables.len() {
            bail!(
                "Different number of internal tables. New: {}, Old: {}",
                new_internal_table_ids.len(),
                old_internal_tables.len()
            );
        }
        old_internal_tables.sort_by(|a, b| a.id.cmp(&b.id));
        new_internal_table_ids.sort();

        let internal_table_id_map = new_internal_table_ids
            .into_iter()
            .zip_eq_fast(old_internal_tables.into_iter())
            .collect::<HashMap<_, _>>();

        for fragment in self.fragments.values_mut() {
            stream_graph_visitor::visit_internal_tables(
                &mut fragment.inner,
                |table, _table_type_name| {
                    let target = internal_table_id_map.get(&table.id).cloned().unwrap();
                    *table = target;
                },
            );
        }

        Ok(())
    }

    /// Returns the fragment ID where the streaming job node is located.
    pub fn table_fragment_id(&self) -> FragmentId {
        self.fragments
            .values()
            .filter(|b| b.job_id.is_some())
            .map(|b| b.fragment_id)
            .exactly_one()
            .expect("require exactly 1 materialize/sink/cdc source node when creating the streaming job")
    }

    /// Returns the fragment ID where the table DML is received.
    pub fn dml_fragment_id(&self) -> Option<FragmentId> {
        self.fragments
            .values()
            .filter(|b| b.fragment_type_mask & FragmentTypeFlag::Dml as u32 != 0)
            .map(|b| b.fragment_id)
            .at_most_one()
            .expect("require at most 1 dml node when creating the streaming job")
    }

    /// Get the dependent streaming job IDs of this job.
    pub fn dependent_table_ids(&self) -> &HashSet<TableId> {
        &self.dependent_table_ids
    }

    /// Get the parallelism of the job, if specified by the user.
    pub fn specified_parallelism(&self) -> Option<NonZeroUsize> {
        self.specified_parallelism
    }

    /// Get the expected vnode count of the graph. See documentation of the field for more details.
    pub fn max_parallelism(&self) -> usize {
        self.max_parallelism
    }

    /// Get downstreams of a fragment.
    fn get_downstreams(
        &self,
        fragment_id: GlobalFragmentId,
    ) -> &HashMap<GlobalFragmentId, StreamFragmentEdge> {
        self.downstreams.get(&fragment_id).unwrap_or(&EMPTY_HASHMAP)
    }

    /// Get upstreams of a fragment.
    fn get_upstreams(
        &self,
        fragment_id: GlobalFragmentId,
    ) -> &HashMap<GlobalFragmentId, StreamFragmentEdge> {
        self.upstreams.get(&fragment_id).unwrap_or(&EMPTY_HASHMAP)
    }

    /// Returns `Ok((Some(snapshot_backfill_info), cross_db_snapshot_backfill_info))`.
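    ///
    /// Scans all `StreamScan` nodes in the graph: cross-db snapshot backfill scans are collected
    /// separately, while the remaining scans must either all use snapshot backfill or all not
    /// use it; otherwise an error is returned.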
    pub fn collect_snapshot_backfill_info(
        &self,
    ) -> MetaResult<(Option<SnapshotBackfillInfo>, SnapshotBackfillInfo)> {
        let mut prev_stream_scan: Option<(Option<SnapshotBackfillInfo>, StreamScanNode)> = None;
        let mut cross_db_info = SnapshotBackfillInfo {
            upstream_mv_table_id_to_backfill_epoch: Default::default(),
        };
        let mut result = Ok(());
        for (node, fragment_type_mask) in self
            .fragments
            .values()
            .map(|fragment| (fragment.node.as_ref().unwrap(), fragment.fragment_type_mask))
        {
            visit_stream_node_cont(node, |node| {
                if let Some(NodeBody::StreamScan(stream_scan)) = node.node_body.as_ref() {
                    let stream_scan_type = StreamScanType::try_from(stream_scan.stream_scan_type)
                        .expect("invalid stream_scan_type");
                    let is_snapshot_backfill = match stream_scan_type {
                        StreamScanType::SnapshotBackfill => {
                            assert!(
                                (fragment_type_mask
                                    & (FragmentTypeFlag::SnapshotBackfillStreamScan as u32))
                                    > 0
                            );
                            true
                        }
                        StreamScanType::CrossDbSnapshotBackfill => {
                            assert!(
                                (fragment_type_mask
                                    & (FragmentTypeFlag::CrossDbSnapshotBackfillStreamScan as u32))
                                    > 0
                            );
                            cross_db_info
                                .upstream_mv_table_id_to_backfill_epoch
                                .insert(TableId::new(stream_scan.table_id), None);

                            return true;
                        }
                        _ => false,
                    };

                    match &mut prev_stream_scan {
                        Some((prev_snapshot_backfill_info, prev_stream_scan)) => {
                            match (prev_snapshot_backfill_info, is_snapshot_backfill) {
                                (Some(prev_snapshot_backfill_info), true) => {
                                    prev_snapshot_backfill_info
                                        .upstream_mv_table_id_to_backfill_epoch
                                        .insert(TableId::new(stream_scan.table_id), None);
                                    true
                                }
                                (None, false) => true,
                                (_, _) => {
                                    result = Err(anyhow!("must be either all snapshot_backfill or no snapshot_backfill. Curr: {stream_scan:?} Prev: {prev_stream_scan:?}").into());
                                    false
                                }
                            }
                        }
                        None => {
                            prev_stream_scan = Some((
                                if is_snapshot_backfill {
                                    Some(SnapshotBackfillInfo {
                                        upstream_mv_table_id_to_backfill_epoch: HashMap::from_iter(
                                            [(TableId::new(stream_scan.table_id), None)],
                                        ),
                                    })
                                } else {
                                    None
                                },
                                *stream_scan.clone(),
                            ));
                            true
                        }
                    }
                } else {
                    true
                }
            })
        }
        result.map(|_| {
            (
                prev_stream_scan
                    .map(|(snapshot_backfill_info, _)| snapshot_backfill_info)
                    .unwrap_or(None),
                cross_db_info,
            )
        })
    }
}

/// Fill the snapshot epoch for `StreamScanNode`s of `SnapshotBackfill` and
/// `CrossDbSnapshotBackfill` types. Returns `true` when any change is applied.
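///
/// Returns an error if an upstream table is not covered by the provided backfill info, if its
/// epoch has not been set, or if the epoch of a `StreamScanNode` has already been filled.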
pub fn fill_snapshot_backfill_epoch(
    node: &mut StreamNode,
    snapshot_backfill_info: Option<&SnapshotBackfillInfo>,
    cross_db_snapshot_backfill_info: &SnapshotBackfillInfo,
) -> MetaResult<bool> {
    let mut result = Ok(());
    let mut applied = false;
    visit_stream_node_cont_mut(node, |node| {
        if let Some(NodeBody::StreamScan(stream_scan)) = node.node_body.as_mut()
            && (stream_scan.stream_scan_type == StreamScanType::SnapshotBackfill as i32
                || stream_scan.stream_scan_type == StreamScanType::CrossDbSnapshotBackfill as i32)
        {
            result = try {
                let table_id = TableId::new(stream_scan.table_id);
                let snapshot_epoch = cross_db_snapshot_backfill_info
                    .upstream_mv_table_id_to_backfill_epoch
                    .get(&table_id)
                    .or_else(|| {
                        snapshot_backfill_info.and_then(|snapshot_backfill_info| {
                            snapshot_backfill_info
                                .upstream_mv_table_id_to_backfill_epoch
                                .get(&table_id)
                        })
                    })
                    .ok_or_else(|| anyhow!("upstream table id not covered: {}", table_id))?
                    .ok_or_else(|| anyhow!("upstream table id not set: {}", table_id))?;
                if let Some(prev_snapshot_epoch) =
                    stream_scan.snapshot_backfill_epoch.replace(snapshot_epoch)
                {
                    Err(anyhow!(
                        "snapshot backfill epoch set again: {} {} {}",
                        table_id,
                        prev_snapshot_epoch,
                        snapshot_epoch
                    ))?;
                }
                applied = true;
            };
            result.is_ok()
        } else {
            true
        }
    });
    result.map(|_| applied)
}

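/// A shared empty map, returned by [`StreamFragmentGraph::get_downstreams`] and
/// [`StreamFragmentGraph::get_upstreams`] when the fragment has no such edges, so that a
/// reference can be handed out without allocating.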
static EMPTY_HASHMAP: LazyLock<HashMap<GlobalFragmentId, StreamFragmentEdge>> =
    LazyLock::new(HashMap::new);

/// A fragment that is either being built or already exists. Used to generalize the logic of
/// [`crate::stream::ActorGraphBuilder`].
#[derive(Debug, Clone, EnumAsInner)]
pub(super) enum EitherFragment {
    /// An internal fragment that is being built for the current streaming job.
    Building(BuildingFragment),

    /// An existing fragment that is external but connected to the fragments being built.
    Existing(Fragment),
}

/// A wrapper of [`StreamFragmentGraph`] that contains the additional information of pre-existing
/// fragments, which are connected to the graph's top-most or bottom-most fragments.
///
/// For example,
/// - if we're going to build a mview on an existing mview, the upstream fragment containing the
///   `Materialize` node will be included in this structure.
/// - if we're going to replace the plan of a table with downstream mviews, the downstream fragments
///   containing the `StreamScan` nodes will be included in this structure.
#[derive(Debug)]
pub struct CompleteStreamFragmentGraph {
    /// The fragment graph of the streaming job being built.
    building_graph: StreamFragmentGraph,

    /// The required information of existing fragments.
    existing_fragments: HashMap<GlobalFragmentId, Fragment>,

    /// The location of the actors in the existing fragments.
    existing_actor_location: HashMap<ActorId, WorkerId>,

    /// Extra edges between existing fragments and the building fragments: upstream => downstream.
    extra_downstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,

    /// Extra edges between existing fragments and the building fragments: downstream => upstream.
    extra_upstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,
}

pub struct FragmentGraphUpstreamContext {
    /// The root fragment of the upstream stream graph, which can be either a mview fragment
    /// or a source fragment (for a CDC source job).
    upstream_root_fragments: HashMap<TableId, Fragment>,
    upstream_actor_location: HashMap<ActorId, WorkerId>,
}

pub struct FragmentGraphDownstreamContext {
    original_root_fragment_id: FragmentId,
    downstream_fragments: Vec<(DispatchStrategy, Fragment)>,
    downstream_actor_location: HashMap<ActorId, WorkerId>,
}

impl CompleteStreamFragmentGraph {
    /// Create a new [`CompleteStreamFragmentGraph`] with empty existing fragments, i.e., there's no
    /// upstream mviews.
    #[cfg(test)]
    pub fn for_test(graph: StreamFragmentGraph) -> Self {
        Self {
            building_graph: graph,
            existing_fragments: Default::default(),
            existing_actor_location: Default::default(),
            extra_downstreams: Default::default(),
            extra_upstreams: Default::default(),
        }
    }

    /// Create a new [`CompleteStreamFragmentGraph`] for a newly created job (which has no
    /// downstreams), e.g., an MV on MV or a CDC/Source table, with the existing upstream
    /// `Materialize` or `Source` fragments.
    pub fn with_upstreams(
        graph: StreamFragmentGraph,
        upstream_root_fragments: HashMap<TableId, Fragment>,
        existing_actor_location: HashMap<ActorId, WorkerId>,
        job_type: StreamingJobType,
    ) -> MetaResult<Self> {
        Self::build_helper(
            graph,
            Some(FragmentGraphUpstreamContext {
                upstream_root_fragments,
                upstream_actor_location: existing_actor_location,
            }),
            None,
            job_type,
        )
    }

    /// Create a new [`CompleteStreamFragmentGraph`] for replacing an existing table/source,
    /// with the downstream existing `StreamScan`/`StreamSourceScan` fragments.
    pub fn with_downstreams(
        graph: StreamFragmentGraph,
        original_root_fragment_id: FragmentId,
        downstream_fragments: Vec<(DispatchStrategy, Fragment)>,
        existing_actor_location: HashMap<ActorId, WorkerId>,
        job_type: StreamingJobType,
    ) -> MetaResult<Self> {
        Self::build_helper(
            graph,
            None,
            Some(FragmentGraphDownstreamContext {
                original_root_fragment_id,
                downstream_fragments,
                downstream_actor_location: existing_actor_location,
            }),
            job_type,
        )
    }

    /// For replacing an existing table based on a shared cdc source, which has both upstreams and downstreams.
    pub fn with_upstreams_and_downstreams(
        graph: StreamFragmentGraph,
        upstream_root_fragments: HashMap<TableId, Fragment>,
        upstream_actor_location: HashMap<ActorId, WorkerId>,
        original_root_fragment_id: FragmentId,
        downstream_fragments: Vec<(DispatchStrategy, Fragment)>,
        downstream_actor_location: HashMap<ActorId, WorkerId>,
        job_type: StreamingJobType,
    ) -> MetaResult<Self> {
        Self::build_helper(
            graph,
            Some(FragmentGraphUpstreamContext {
                upstream_root_fragments,
                upstream_actor_location,
            }),
            Some(FragmentGraphDownstreamContext {
                original_root_fragment_id,
                downstream_fragments,
                downstream_actor_location,
            }),
            job_type,
        )
    }

    /// The core logic of building a [`CompleteStreamFragmentGraph`], i.e., adding extra upstream/downstream fragments.
    fn build_helper(
        mut graph: StreamFragmentGraph,
        upstream_ctx: Option<FragmentGraphUpstreamContext>,
        downstream_ctx: Option<FragmentGraphDownstreamContext>,
        job_type: StreamingJobType,
    ) -> MetaResult<Self> {
        let mut extra_downstreams = HashMap::new();
        let mut extra_upstreams = HashMap::new();
        let mut existing_fragments = HashMap::new();

        let mut existing_actor_location = HashMap::new();

        if let Some(FragmentGraphUpstreamContext {
            upstream_root_fragments,
            upstream_actor_location,
        }) = upstream_ctx
        {
            for (&id, fragment) in &mut graph.fragments {
                let uses_shuffled_backfill = fragment.has_shuffled_backfill();
                for (&upstream_table_id, output_columns) in &fragment.upstream_table_columns {
                    let (up_fragment_id, edge) = match job_type {
                        StreamingJobType::Table(TableJobType::SharedCdcSource) => {
                            let source_fragment = upstream_root_fragments
                                .get(&upstream_table_id)
                                .context("upstream source fragment not found")?;
                            let source_job_id = GlobalFragmentId::new(source_fragment.fragment_id);

                            // We traverse all fragments in the graph to find the `CdcFilter`
                            // fragment and add an edge between the upstream source fragment and it.
                            assert_ne!(
                                (fragment.fragment_type_mask & FragmentTypeFlag::CdcFilter as u32),
                                0
                            );

                            tracing::debug!(
                                ?source_job_id,
                                ?output_columns,
                                identity = ?fragment.inner.get_node().unwrap().get_identity(),
                                current_frag_id=?id,
                                "CdcFilter with upstream source fragment"
                            );
                            let edge = StreamFragmentEdge {
                                id: EdgeId::UpstreamExternal {
                                    upstream_table_id,
                                    downstream_fragment_id: id,
                                },
                                // We always use `NoShuffle` for the exchange between the upstream
                                // `Source` and the downstream `StreamScan` of the new cdc table.
                                dispatch_strategy: DispatchStrategy {
                                    r#type: DispatcherType::NoShuffle as _,
                                    dist_key_indices: vec![], // not used for `NoShuffle`
                                    output_indices: (0..CDC_SOURCE_COLUMN_NUM as _).collect(),
                                },
                            };

                            (source_job_id, edge)
                        }
                        StreamingJobType::MaterializedView
                        | StreamingJobType::Sink
                        | StreamingJobType::Index => {
                            // Handle MV on MV/Source.

                            // Build the extra edges between the upstream `Materialize` and the downstream `StreamScan`
                            // of the new materialized view.
                            let upstream_fragment = upstream_root_fragments
                                .get(&upstream_table_id)
                                .context("upstream materialized view fragment not found")?;
                            let upstream_root_fragment_id =
                                GlobalFragmentId::new(upstream_fragment.fragment_id);

                            if upstream_fragment.fragment_type_mask & FragmentTypeFlag::Mview as u32
                                != 0
                            {
                                // Resolve the required output columns from the upstream materialized view.
                                let (dist_key_indices, output_indices) = {
                                    let nodes = &upstream_fragment.nodes;
                                    let mview_node =
                                        nodes.get_node_body().unwrap().as_materialize().unwrap();
                                    let all_column_ids = mview_node.column_ids();
                                    let dist_key_indices = mview_node.dist_key_indices();
                                    let output_indices = output_columns
                                        .iter()
                                        .map(|c| {
                                            all_column_ids
                                                .iter()
                                                .position(|&id| id == *c)
                                                .map(|i| i as u32)
                                        })
                                        .collect::<Option<Vec<_>>>()
                                        .context(
                                            "column not found in the upstream materialized view",
                                        )?;
                                    (dist_key_indices, output_indices)
                                };
                                let dispatch_strategy = mv_on_mv_dispatch_strategy(
                                    uses_shuffled_backfill,
                                    dist_key_indices,
                                    output_indices,
                                );
                                let edge = StreamFragmentEdge {
                                    id: EdgeId::UpstreamExternal {
                                        upstream_table_id,
                                        downstream_fragment_id: id,
                                    },
                                    dispatch_strategy,
                                };

                                (upstream_root_fragment_id, edge)
                            } else if upstream_fragment.fragment_type_mask
                                & FragmentTypeFlag::Source as u32
                                != 0
                            {
                                let source_fragment = upstream_root_fragments
                                    .get(&upstream_table_id)
                                    .context("upstream source fragment not found")?;
                                let source_job_id =
                                    GlobalFragmentId::new(source_fragment.fragment_id);

                                let output_indices = {
                                    let nodes = &upstream_fragment.nodes;
                                    let source_node =
                                        nodes.get_node_body().unwrap().as_source().unwrap();

                                    let all_column_ids = source_node.column_ids().unwrap();
                                    output_columns
                                        .iter()
                                        .map(|c| {
                                            all_column_ids
                                                .iter()
                                                .position(|&id| id == *c)
                                                .map(|i| i as u32)
                                        })
                                        .collect::<Option<Vec<_>>>()
                                        .context("column not found in the upstream source node")?
                                };

                                let edge = StreamFragmentEdge {
                                    id: EdgeId::UpstreamExternal {
                                        upstream_table_id,
                                        downstream_fragment_id: id,
                                    },
                                    // We always use `NoShuffle` for the exchange between the upstream
                                    // `Source` and the downstream `StreamScan` of the new MV.
                                    dispatch_strategy: DispatchStrategy {
                                        r#type: DispatcherType::NoShuffle as _,
                                        dist_key_indices: vec![], // not used for `NoShuffle`
                                        output_indices,
                                    },
                                };

                                (source_job_id, edge)
                            } else {
                                bail!(
                                    "the upstream fragment should be a MView or Source, got fragment type: {:b}",
                                    upstream_fragment.fragment_type_mask
                                )
                            }
                        }
                        StreamingJobType::Source | StreamingJobType::Table(_) => {
                            bail!(
                                "the streaming job shouldn't have an upstream fragment, job_type: {:?}",
                                job_type
                            )
                        }
                    };

                    // Put the edge into the extra edges.
                    extra_downstreams
                        .entry(up_fragment_id)
                        .or_insert_with(HashMap::new)
                        .try_insert(id, edge.clone())
                        .unwrap();
                    extra_upstreams
                        .entry(id)
                        .or_insert_with(HashMap::new)
                        .try_insert(up_fragment_id, edge)
                        .unwrap();
                }
            }

            existing_fragments.extend(
                upstream_root_fragments
                    .into_values()
                    .map(|f| (GlobalFragmentId::new(f.fragment_id), f)),
            );

            existing_actor_location.extend(upstream_actor_location);
        }

        if let Some(FragmentGraphDownstreamContext {
            original_root_fragment_id,
            downstream_fragments,
            downstream_actor_location,
        }) = downstream_ctx
        {
            let original_table_fragment_id = GlobalFragmentId::new(original_root_fragment_id);
            let table_fragment_id = GlobalFragmentId::new(graph.table_fragment_id());

            // Build the extra edges between the `Materialize` and the downstream `StreamScan` of the
            // existing materialized views.
            for (dispatch_strategy, fragment) in &downstream_fragments {
                let id = GlobalFragmentId::new(fragment.fragment_id);

                let edge = StreamFragmentEdge {
                    id: EdgeId::DownstreamExternal(DownstreamExternalEdgeId {
                        original_upstream_fragment_id: original_table_fragment_id,
                        downstream_fragment_id: id,
                    }),
                    dispatch_strategy: dispatch_strategy.clone(),
                };

                extra_downstreams
                    .entry(table_fragment_id)
                    .or_insert_with(HashMap::new)
                    .try_insert(id, edge.clone())
                    .unwrap();
                extra_upstreams
                    .entry(id)
                    .or_insert_with(HashMap::new)
                    .try_insert(table_fragment_id, edge)
                    .unwrap();
            }

            existing_fragments.extend(
                downstream_fragments
                    .into_iter()
                    .map(|(_, f)| (GlobalFragmentId::new(f.fragment_id), f)),
            );

            existing_actor_location.extend(downstream_actor_location);
        }

        Ok(Self {
            building_graph: graph,
            existing_fragments,
            existing_actor_location,
            extra_downstreams,
            extra_upstreams,
        })
    }
}

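/// Decide the dispatch strategy for the edge from an upstream `Materialize` fragment to the
/// backfill fragment of a new materialized view:
/// - shuffled backfill with a non-empty distribution key dispatches by `Hash`;
/// - shuffled backfill on a singleton upstream (empty distribution key) uses `Simple`;
/// - otherwise, `NoShuffle` colocates the backfill with the upstream.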
fn mv_on_mv_dispatch_strategy(
    uses_shuffled_backfill: bool,
    dist_key_indices: Vec<u32>,
    output_indices: Vec<u32>,
) -> DispatchStrategy {
    if uses_shuffled_backfill {
        if !dist_key_indices.is_empty() {
            DispatchStrategy {
                r#type: DispatcherType::Hash as _,
                dist_key_indices,
                output_indices,
            }
        } else {
            DispatchStrategy {
                r#type: DispatcherType::Simple as _,
                dist_key_indices: vec![], // empty for Simple
                output_indices,
            }
        }
    } else {
        DispatchStrategy {
            r#type: DispatcherType::NoShuffle as _,
            dist_key_indices: vec![], // not used for `NoShuffle`
            output_indices,
        }
    }
}

impl CompleteStreamFragmentGraph {
    /// Returns **all** fragment IDs in the complete graph, including the ones that are not in the
    /// building graph.
    pub(super) fn all_fragment_ids(&self) -> impl Iterator<Item = GlobalFragmentId> + '_ {
        self.building_graph
            .fragments
            .keys()
            .chain(self.existing_fragments.keys())
            .copied()
    }

    /// Returns an iterator of **all** edges in the complete graph, including the external edges.
    pub(super) fn all_edges(
        &self,
    ) -> impl Iterator<Item = (GlobalFragmentId, GlobalFragmentId, &StreamFragmentEdge)> + '_ {
        self.building_graph
            .downstreams
            .iter()
            .chain(self.extra_downstreams.iter())
            .flat_map(|(&from, tos)| tos.iter().map(move |(&to, edge)| (from, to, edge)))
    }

    /// Returns the distribution of the existing fragments.
    pub(super) fn existing_distribution(&self) -> HashMap<GlobalFragmentId, Distribution> {
        self.existing_fragments
            .iter()
            .map(|(&id, f)| {
                (
                    id,
                    Distribution::from_fragment(f, &self.existing_actor_location),
                )
            })
            .collect()
    }

    /// Generate the topological order of **all** fragments in this graph, including the ones that
    /// are not in the building graph. Returns an error if the graph is not a DAG and topological
    /// sort cannot be done.
    ///
    /// For MV on MV, the first fragment popped out will be the top-most node, i.e., the
    /// `Sink` / `Materialize` in the stream graph.
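    ///
    /// An illustrative example (fragment names are hypothetical; edges point downstream):
    ///
    /// ```text
    /// Source -> StreamScan -> Materialize
    /// topo_order() == [Materialize, StreamScan, Source]
    /// ```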
    pub(super) fn topo_order(&self) -> MetaResult<Vec<GlobalFragmentId>> {
        let mut topo = Vec::new();
        let mut downstream_cnts = HashMap::new();

        // Iterate all fragments.
        for fragment_id in self.all_fragment_ids() {
            // Count how many downstreams we have for a given fragment.
            let downstream_cnt = self.get_downstreams(fragment_id).count();
            if downstream_cnt == 0 {
                topo.push(fragment_id);
            } else {
                downstream_cnts.insert(fragment_id, downstream_cnt);
            }
        }

        let mut i = 0;
        while let Some(&fragment_id) = topo.get(i) {
            i += 1;
            // Find if we can process more fragments.
            for (upstream_id, _) in self.get_upstreams(fragment_id) {
                let downstream_cnt = downstream_cnts.get_mut(&upstream_id).unwrap();
                *downstream_cnt -= 1;
                if *downstream_cnt == 0 {
                    downstream_cnts.remove(&upstream_id);
                    topo.push(upstream_id);
                }
            }
        }

        if !downstream_cnts.is_empty() {
            // There are fragments that are not processed yet.
            bail!("graph is not a DAG");
        }

        Ok(topo)
    }

    /// Seal a [`BuildingFragment`] from the graph into a [`Fragment`], which will be further used
    /// to build actors on the compute nodes and persist into meta store.
    pub(super) fn seal_fragment(
        &self,
        id: GlobalFragmentId,
        actors: Vec<StreamActor>,
        distribution: Distribution,
        stream_node: StreamNode,
    ) -> Fragment {
        let building_fragment = self.get_fragment(id).into_building().unwrap();
        let internal_tables = building_fragment.extract_internal_tables();
        let BuildingFragment {
            inner,
            job_id,
            upstream_table_columns: _,
        } = building_fragment;

        let distribution_type = distribution.to_distribution_type();
        let vnode_count = distribution.vnode_count();

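        // For a `Materialize` fragment, the materialized table itself is also one of its state
        // tables, so the job ID is chained into `state_table_ids` below.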
        let materialized_fragment_id =
            if inner.fragment_type_mask & FragmentTypeFlag::Mview as u32 != 0 {
                job_id
            } else {
                None
            };

        let state_table_ids = internal_tables
            .iter()
            .map(|t| t.id)
            .chain(materialized_fragment_id)
            .collect();

        Fragment {
            fragment_id: inner.fragment_id,
            fragment_type_mask: inner.fragment_type_mask,
            distribution_type,
            actors,
            state_table_ids,
            maybe_vnode_count: VnodeCount::set(vnode_count).to_protobuf(),
            nodes: stream_node,
        }
    }

    /// Get a fragment from the complete graph, which can be either a building fragment or an
    /// existing fragment.
    pub(super) fn get_fragment(&self, fragment_id: GlobalFragmentId) -> EitherFragment {
        if let Some(fragment) = self.existing_fragments.get(&fragment_id) {
            EitherFragment::Existing(fragment.clone())
        } else {
            EitherFragment::Building(
                self.building_graph
                    .fragments
                    .get(&fragment_id)
                    .unwrap()
                    .clone(),
            )
        }
    }

    /// Get **all** downstreams of a fragment, including the ones that are not in the building
    /// graph.
    pub(super) fn get_downstreams(
        &self,
        fragment_id: GlobalFragmentId,
    ) -> impl Iterator<Item = (GlobalFragmentId, &StreamFragmentEdge)> {
        self.building_graph
            .get_downstreams(fragment_id)
            .iter()
            .chain(
                self.extra_downstreams
                    .get(&fragment_id)
                    .into_iter()
                    .flatten(),
            )
            .map(|(&id, edge)| (id, edge))
    }

    /// Get **all** upstreams of a fragment, including the ones that are not in the building
    /// graph.
    pub(super) fn get_upstreams(
        &self,
        fragment_id: GlobalFragmentId,
    ) -> impl Iterator<Item = (GlobalFragmentId, &StreamFragmentEdge)> {
        self.building_graph
            .get_upstreams(fragment_id)
            .iter()
            .chain(self.extra_upstreams.get(&fragment_id).into_iter().flatten())
            .map(|(&id, edge)| (id, edge))
    }

    /// Returns all building fragments in the graph.
    pub(super) fn building_fragments(&self) -> &HashMap<GlobalFragmentId, BuildingFragment> {
        &self.building_graph.fragments
    }

    /// Returns all building fragments in the graph, mutable.
    pub(super) fn building_fragments_mut(
        &mut self,
    ) -> &mut HashMap<GlobalFragmentId, BuildingFragment> {
        &mut self.building_graph.fragments
    }

    /// Get the expected vnode count of the building graph. See documentation of the field for more details.
    pub(super) fn max_parallelism(&self) -> usize {
        self.building_graph.max_parallelism()
    }
}