risingwave_meta/stream/stream_graph/fragment.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap, HashSet};
16use std::num::NonZeroUsize;
17use std::ops::{Deref, DerefMut};
18use std::sync::LazyLock;
19use std::sync::atomic::AtomicU32;
20
21use anyhow::{Context, anyhow};
22use enum_as_inner::EnumAsInner;
23use itertools::Itertools;
24use risingwave_common::bail;
25use risingwave_common::catalog::{
26    CDC_SOURCE_COLUMN_NUM, ColumnCatalog, Field, FragmentTypeFlag, FragmentTypeMask, TableId,
27    generate_internal_table_name_with_type,
28};
29use risingwave_common::hash::VnodeCount;
30use risingwave_common::id::JobId;
31use risingwave_common::util::iter_util::ZipEqFast;
32use risingwave_common::util::stream_graph_visitor::{
33    self, visit_stream_node_cont, visit_stream_node_cont_mut,
34};
35use risingwave_connector::sink::catalog::SinkType;
36use risingwave_meta_model::WorkerId;
37use risingwave_meta_model::streaming_job::BackfillOrders;
38use risingwave_pb::catalog::{PbSink, PbTable, Table};
39use risingwave_pb::ddl_service::TableJobType;
40use risingwave_pb::id::{RelationId, StreamNodeLocalOperatorId};
41use risingwave_pb::plan_common::{PbColumnCatalog, PbColumnDesc};
42use risingwave_pb::stream_plan::dispatch_output_mapping::TypePair;
43use risingwave_pb::stream_plan::stream_fragment_graph::{
44    Parallelism, StreamFragment, StreamFragmentEdge as StreamFragmentEdgeProto,
45};
46use risingwave_pb::stream_plan::stream_node::{NodeBody, PbNodeBody};
47use risingwave_pb::stream_plan::{
48    BackfillOrder, DispatchOutputMapping, DispatchStrategy, DispatcherType, PbStreamNode,
49    PbStreamScanType, StreamFragmentGraph as StreamFragmentGraphProto, StreamNode, StreamScanNode,
50    StreamScanType,
51};
52
53use crate::barrier::{SharedFragmentInfo, SnapshotBackfillInfo};
54use crate::controller::id::IdGeneratorManager;
55use crate::manager::{MetaSrvEnv, StreamingJob, StreamingJobType};
56use crate::model::{ActorId, Fragment, FragmentDownstreamRelation, FragmentId, StreamActor};
57use crate::stream::stream_graph::id::{
58    GlobalActorIdGen, GlobalFragmentId, GlobalFragmentIdGen, GlobalTableIdGen,
59};
60use crate::stream::stream_graph::schedule::Distribution;
61use crate::{MetaError, MetaResult};
62
63/// The fragment in the building phase, including the [`StreamFragment`] from the frontend and
64/// several additional helper fields.
65#[derive(Debug, Clone)]
66pub(super) struct BuildingFragment {
67    /// The fragment structure from the frontend, with the global fragment ID.
68    inner: StreamFragment,
69
70    /// The ID of the job if it contains the streaming job node.
71    job_id: Option<JobId>,
72
73    /// The required columns of each upstream table.
74    /// Will be converted to indices when building the edge connected to the upstream.
75    ///
76    /// For a shared CDC table on a source, it's `vec![]`, since the upstream source's output schema is fixed.
77    upstream_job_columns: HashMap<JobId, Vec<PbColumnDesc>>,
78}
79
80impl BuildingFragment {
81    /// Create a new [`BuildingFragment`] from a [`StreamFragment`]. The global fragment ID and
82    /// global table IDs will be correctly filled with the given `id` and `table_id_gen`.
83    fn new(
84        id: GlobalFragmentId,
85        fragment: StreamFragment,
86        job: &StreamingJob,
87        table_id_gen: GlobalTableIdGen,
88    ) -> Self {
89        let mut fragment = StreamFragment {
90            fragment_id: id.as_global_id(),
91            ..fragment
92        };
93
94        // Fill the information of the internal tables in the fragment.
95        Self::fill_internal_tables(&mut fragment, job, table_id_gen);
96
97        let job_id = Self::fill_job(&mut fragment, job).then(|| job.id());
98        let upstream_job_columns =
99            Self::extract_upstream_columns_except_cross_db_backfill(&fragment);
100
101        Self {
102            inner: fragment,
103            job_id,
104            upstream_job_columns,
105        }
106    }
107
108    /// Extract the internal tables from the fragment.
109    fn extract_internal_tables(&self) -> Vec<Table> {
110        let mut fragment = self.inner.clone();
111        let mut tables = Vec::new();
112        stream_graph_visitor::visit_internal_tables(&mut fragment, |table, _| {
113            tables.push(table.clone());
114        });
115        tables
116    }
117
118    /// Fill in the information of the internal tables in the fragment.
119    fn fill_internal_tables(
120        fragment: &mut StreamFragment,
121        job: &StreamingJob,
122        table_id_gen: GlobalTableIdGen,
123    ) {
124        let fragment_id = fragment.fragment_id;
125        stream_graph_visitor::visit_internal_tables(fragment, |table, table_type_name| {
126            table.id = table_id_gen
127                .to_global_id(table.id.as_raw_id())
128                .as_global_id();
129            table.schema_id = job.schema_id();
130            table.database_id = job.database_id();
131            table.name = generate_internal_table_name_with_type(
132                &job.name(),
133                fragment_id,
134                table.id,
135                table_type_name,
136            );
137            table.fragment_id = fragment_id;
138            table.owner = job.owner();
139            table.job_id = Some(job.id());
140        });
141    }
142
143    /// Fill the job information into the fragment. Returns whether the fragment contains the streaming job node.
144    fn fill_job(fragment: &mut StreamFragment, job: &StreamingJob) -> bool {
145        let job_id = job.id();
146        let fragment_id = fragment.fragment_id;
147        let mut has_job = false;
148
149        stream_graph_visitor::visit_fragment_mut(fragment, |node_body| match node_body {
150            NodeBody::Materialize(materialize_node) => {
151                materialize_node.table_id = job_id.as_mv_table_id();
152
153                // Fill the table field of `MaterializeNode` from the job.
154                let table = materialize_node.table.insert(job.table().unwrap().clone());
155                table.fragment_id = fragment_id; // this will later be synced back to `job.table` with `set_info_from_graph`
156                // In production, do not include the full definition in the table of the plan node.
157                if cfg!(not(debug_assertions)) {
158                    table.definition = job.name();
159                }
160
161                has_job = true;
162            }
163            NodeBody::Sink(sink_node) => {
164                sink_node.sink_desc.as_mut().unwrap().id = job_id.as_sink_id();
165
166                has_job = true;
167            }
168            NodeBody::Dml(dml_node) => {
169                dml_node.table_id = job_id.as_mv_table_id();
170                dml_node.table_version_id = job.table_version_id().unwrap();
171            }
172            NodeBody::StreamFsFetch(fs_fetch_node) => {
173                if let StreamingJob::Table(table_source, _, _) = job
174                    && let Some(node_inner) = fs_fetch_node.node_inner.as_mut()
175                    && let Some(source) = table_source
176                {
177                    node_inner.source_id = source.id;
178                    if let Some(id) = source.optional_associated_table_id {
179                        node_inner.associated_table_id = Some(id.into());
180                    }
181                }
182            }
183            NodeBody::Source(source_node) => {
184                match job {
185            // Note: For a table without a connector, it has a dummy Source node.
186            // Note: For a table with a connector, its source node has a source id different from the table id (job id), assigned in create_job_catalog.
187                    StreamingJob::Table(source, _table, _table_job_type) => {
188                        if let Some(source_inner) = source_node.source_inner.as_mut()
189                            && let Some(source) = source
190                        {
191                            debug_assert_ne!(source.id, job_id.as_raw_id());
192                            source_inner.source_id = source.id;
193                            if let Some(id) = source.optional_associated_table_id {
194                                source_inner.associated_table_id = Some(id.into());
195                            }
196                        }
197                    }
198                    StreamingJob::Source(source) => {
199                        has_job = true;
200                        if let Some(source_inner) = source_node.source_inner.as_mut() {
201                            debug_assert_eq!(source.id, job_id.as_raw_id());
202                            source_inner.source_id = source.id;
203                            if let Some(id) = source.optional_associated_table_id {
204                                source_inner.associated_table_id = Some(id.into());
205                            }
206                        }
207                    }
208                    // For other job types, no need to fill the source id, since it refers to an existing source.
209                    _ => {}
210                }
211            }
212            NodeBody::StreamCdcScan(node) => {
213                if let Some(table_desc) = node.cdc_table_desc.as_mut() {
214                    table_desc.table_id = job_id.as_mv_table_id();
215                }
216            }
217            NodeBody::VectorIndexWrite(node) => {
218                let table = node.table.as_mut().unwrap();
219                table.id = job_id.as_mv_table_id();
220                table.database_id = job.database_id();
221                table.schema_id = job.schema_id();
222                table.fragment_id = fragment_id;
223                #[cfg(not(debug_assertions))]
224                {
225                    table.definition = job.name();
226                }
227
228                has_job = true;
229            }
230            _ => {}
231        });
232
233        has_job
234    }
235
236    /// Extract the required columns of each upstream table except for cross-db backfill.
237    fn extract_upstream_columns_except_cross_db_backfill(
238        fragment: &StreamFragment,
239    ) -> HashMap<JobId, Vec<PbColumnDesc>> {
240        let mut table_columns = HashMap::new();
241
242        stream_graph_visitor::visit_fragment(fragment, |node_body| {
243            let (table_id, column_ids) = match node_body {
244                NodeBody::StreamScan(stream_scan) => {
245                    if stream_scan.get_stream_scan_type().unwrap()
246                        == StreamScanType::CrossDbSnapshotBackfill
247                    {
248                        return;
249                    }
250                    (
251                        stream_scan.table_id.as_job_id(),
252                        stream_scan.upstream_columns(),
253                    )
254                }
255                NodeBody::CdcFilter(cdc_filter) => (
256                    cdc_filter.upstream_source_id.as_share_source_job_id(),
257                    vec![],
258                ),
259                NodeBody::SourceBackfill(backfill) => (
260                    backfill.upstream_source_id.as_share_source_job_id(),
261                    // FIXME: only pass required columns instead of all columns here
262                    backfill.column_descs(),
263                ),
264                _ => return,
265            };
266            table_columns
267                .try_insert(table_id, column_ids)
268                .expect("currently there should be no two same upstream tables in a fragment");
269        });
270
271        table_columns
272    }
273
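    /// Returns whether this fragment contains a backfill node whose upstream data is shuffled,
    /// i.e., a `StreamScan` of type `ArrangementBackfill` or `SnapshotBackfill`.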
274    pub fn has_shuffled_backfill(&self) -> bool {
275        let stream_node = match self.inner.node.as_ref() {
276            Some(node) => node,
277            _ => return false,
278        };
279        let mut has_shuffled_backfill = false;
280        let has_shuffled_backfill_mut_ref = &mut has_shuffled_backfill;
281        visit_stream_node_cont(stream_node, |node| {
282            let is_shuffled_backfill = if let Some(node) = &node.node_body
283                && let Some(node) = node.as_stream_scan()
284            {
285                node.stream_scan_type == StreamScanType::ArrangementBackfill as i32
286                    || node.stream_scan_type == StreamScanType::SnapshotBackfill as i32
287            } else {
288                false
289            };
290            if is_shuffled_backfill {
291                *has_shuffled_backfill_mut_ref = true;
292                false
293            } else {
294                true
295            }
296        });
297        has_shuffled_backfill
298    }
299}
300
301impl Deref for BuildingFragment {
302    type Target = StreamFragment;
303
304    fn deref(&self) -> &Self::Target {
305        &self.inner
306    }
307}
308
309impl DerefMut for BuildingFragment {
310    fn deref_mut(&mut self) -> &mut Self::Target {
311        &mut self.inner
312    }
313}
314
315/// The ID of an edge in the fragment graph. Different types of edges are represented by
316/// different variants.
317#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumAsInner)]
318pub(super) enum EdgeId {
319    /// The edge between two building (internal) fragments.
320    Internal {
321        /// The ID generated by the frontend, generally the operator ID of `Exchange`.
322        /// See [`StreamFragmentEdgeProto`].
323        link_id: u64,
324    },
325
326    /// The edge between an upstream external fragment and downstream building fragment. Used for
327    /// MV on MV.
328    UpstreamExternal {
329        /// The ID of the upstream table or materialized view.
330        upstream_job_id: JobId,
331        /// The ID of the downstream fragment.
332        downstream_fragment_id: GlobalFragmentId,
333    },
334
335    /// The edge between an upstream building fragment and downstream external fragment. Used for
336    /// schema change (replace table plan).
337    DownstreamExternal(DownstreamExternalEdgeId),
338}
339
340#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
341pub(super) struct DownstreamExternalEdgeId {
342    /// The ID of the original upstream fragment (`Materialize`).
343    pub(super) original_upstream_fragment_id: GlobalFragmentId,
344    /// The ID of the downstream fragment.
345    pub(super) downstream_fragment_id: GlobalFragmentId,
346}
347
348/// The edge in the fragment graph.
349///
350/// The edge can be either internal or external. This is distinguished by the [`EdgeId`].
351#[derive(Debug, Clone)]
352pub(super) struct StreamFragmentEdge {
353    /// The ID of the edge.
354    pub id: EdgeId,
355
356    /// The strategy used for dispatching the data.
357    pub dispatch_strategy: DispatchStrategy,
358}
359
360impl StreamFragmentEdge {
361    fn from_protobuf(edge: &StreamFragmentEdgeProto) -> Self {
362        Self {
363            // By creating an edge from the protobuf, we know that the edge is from the frontend and
364            // is internal.
365            id: EdgeId::Internal {
366                link_id: edge.link_id,
367            },
368            dispatch_strategy: edge.get_dispatch_strategy().unwrap().clone(),
369        }
370    }
371}
372
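/// Clone a fragment, assigning a freshly generated fragment id and actor ids while keeping the
/// type mask, distribution, state table ids, vnode count and node bodies unchanged.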
373fn clone_fragment(
374    fragment: &Fragment,
375    id_generator_manager: &IdGeneratorManager,
376    actor_id_counter: &AtomicU32,
377) -> Fragment {
378    let fragment_id = GlobalFragmentIdGen::new(id_generator_manager, 1)
379        .to_global_id(0)
380        .as_global_id();
381    let actor_id_gen = GlobalActorIdGen::new(actor_id_counter, fragment.actors.len() as _);
382    Fragment {
383        fragment_id,
384        fragment_type_mask: fragment.fragment_type_mask,
385        distribution_type: fragment.distribution_type,
386        actors: fragment
387            .actors
388            .iter()
389            .enumerate()
390            .map(|(i, actor)| StreamActor {
391                actor_id: actor_id_gen.to_global_id(i as _).as_global_id(),
392                fragment_id,
393                vnode_bitmap: actor.vnode_bitmap.clone(),
394                mview_definition: actor.mview_definition.clone(),
395                expr_context: actor.expr_context.clone(),
396                config_override: actor.config_override.clone(),
397            })
398            .collect(),
399        state_table_ids: fragment.state_table_ids.clone(),
400        maybe_vnode_count: fragment.maybe_vnode_count,
401        nodes: fragment.nodes.clone(),
402    }
403}
404
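/// Check that a sink is eligible for auto schema refresh: it must consist of exactly one fragment
/// shaped as `Sink -> StreamScan (ArrangementBackfill) -> [Merge, BatchPlan]`.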
405pub fn check_sink_fragments_support_refresh_schema(
406    fragments: &BTreeMap<FragmentId, Fragment>,
407) -> MetaResult<()> {
408    if fragments.len() != 1 {
409        return Err(anyhow!(
410            "sink with auto schema change should have only 1 fragment, but got {:?}",
411            fragments.len()
412        )
413        .into());
414    }
415    let (_, fragment) = fragments.first_key_value().expect("non-empty");
416    let sink_node = &fragment.nodes;
417    let PbNodeBody::Sink(_) = sink_node.node_body.as_ref().unwrap() else {
418        return Err(anyhow!("expect PbNodeBody::Sink but got: {:?}", sink_node.node_body).into());
419    };
420    let [stream_scan_node] = sink_node.input.as_slice() else {
421        panic!("Sink should have exactly 1 input, but got: {:?}", sink_node.input);
422    };
423    let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_ref().unwrap() else {
424        return Err(anyhow!(
425            "expect PbNodeBody::StreamScan but got: {:?}",
426            stream_scan_node.node_body
427        )
428        .into());
429    };
430    let stream_scan_type = PbStreamScanType::try_from(scan.stream_scan_type).unwrap();
431    if stream_scan_type != PbStreamScanType::ArrangementBackfill {
432        return Err(anyhow!(
433            "unsupported stream_scan_type for auto refresh schema: {:?}",
434            stream_scan_type
435        )
436        .into());
437    }
438    let [merge_node, _batch_plan_node] = stream_scan_node.input.as_slice() else {
439        panic!(
440            "the number of StreamScan inputs is not 2: {:?}",
441            stream_scan_node.input
442        );
443    };
444    let NodeBody::Merge(_) = merge_node.node_body.as_ref().unwrap() else {
445        return Err(anyhow!(
446            "expect PbNodeBody::Merge but got: {:?}",
447            merge_node.node_body
448        )
449        .into());
450    };
451    Ok(())
452}
453
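/// Rewrite the sink fragment for auto schema refresh: clone the original fragment with new ids,
/// append the newly added upstream columns to the sink columns, the log store table (if any), and
/// the stream scan / merge nodes, and point the merge node at the new upstream table fragment.
/// Returns the new fragment, the extended sink columns, and the updated log store table (if any).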
454pub fn rewrite_refresh_schema_sink_fragment(
455    original_sink_fragment: &Fragment,
456    sink: &PbSink,
457    newly_added_columns: &[ColumnCatalog],
458    upstream_table: &PbTable,
459    upstream_table_fragment_id: FragmentId,
460    id_generator_manager: &IdGeneratorManager,
461    actor_id_counter: &AtomicU32,
462) -> MetaResult<(Fragment, Vec<PbColumnCatalog>, Option<PbTable>)> {
463    let mut new_sink_columns = sink.columns.clone();
464    fn extend_sink_columns(
465        sink_columns: &mut Vec<PbColumnCatalog>,
466        new_columns: &[ColumnCatalog],
467        get_column_name: impl Fn(&String) -> String,
468    ) {
469        let next_column_id = sink_columns
470            .iter()
471            .map(|col| col.column_desc.as_ref().unwrap().column_id + 1)
472            .max()
473            .unwrap_or(1);
474        sink_columns.extend(new_columns.iter().enumerate().map(|(i, col)| {
475            let mut col = col.to_protobuf();
476            let column_desc = col.column_desc.as_mut().unwrap();
477            column_desc.column_id = next_column_id + (i as i32);
478            column_desc.name = get_column_name(&column_desc.name);
479            col
480        }));
481    }
482    extend_sink_columns(&mut new_sink_columns, newly_added_columns, |name| {
483        name.clone()
484    });
485
486    let mut new_sink_fragment = clone_fragment(
487        original_sink_fragment,
488        id_generator_manager,
489        actor_id_counter,
490    );
491    let sink_node = &mut new_sink_fragment.nodes;
492    let PbNodeBody::Sink(sink_node_body) = sink_node.node_body.as_mut().unwrap() else {
493        return Err(anyhow!("expect PbNodeBody::Sink but got: {:?}", sink_node.node_body).into());
494    };
495    let [stream_scan_node] = sink_node.input.as_mut_slice() else {
496        panic!("Sink should have exactly 1 input, but got: {:?}", sink_node.input);
497    };
498    let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_mut().unwrap() else {
499        return Err(anyhow!(
500            "expect PbNodeBody::StreamScan but got: {:?}",
501            stream_scan_node.node_body
502        )
503        .into());
504    };
505    let [merge_node, _batch_plan_node] = stream_scan_node.input.as_mut_slice() else {
506        panic!(
507            "the number of StreamScan inputs is not 2: {:?}",
508            stream_scan_node.input
509        );
510    };
511    let NodeBody::Merge(merge) = merge_node.node_body.as_mut().unwrap() else {
512        return Err(anyhow!(
513            "expect PbNodeBody::Merge but got: {:?}",
514            merge_node.node_body
515        )
516        .into());
517    };
518    // update sink_node
519    // following logic in <StreamSink as Explain>::distill
520    sink_node.identity = {
521        let sink_type = SinkType::from_proto(sink.sink_type());
522        let sink_type_str = sink_type.type_str();
523        let column_names = new_sink_columns
524            .iter()
525            .map(|col| {
526                ColumnCatalog::from(col.clone())
527                    .name_with_hidden()
528                    .to_string()
529            })
530            .join(", ");
531        let downstream_pk = if !sink_type.is_append_only() {
532            let downstream_pk = sink
533                .downstream_pk
534                .iter()
535                .map(|i| &sink.columns[*i as usize].column_desc.as_ref().unwrap().name)
536                .collect_vec();
537            format!(", downstream_pk: {downstream_pk:?}")
538        } else {
539            "".to_owned()
540        };
541        format!("StreamSink {{ type: {sink_type_str}, columns: [{column_names}]{downstream_pk} }}")
542    };
543    sink_node
544        .fields
545        .extend(newly_added_columns.iter().map(|col| {
546            Field::new(
547                format!("{}.{}", upstream_table.name, col.column_desc.name),
548                col.data_type().clone(),
549            )
550            .to_prost()
551        }));
552
553    let new_log_store_table = if let Some(log_store_table) = &mut sink_node_body.table {
554        extend_sink_columns(&mut log_store_table.columns, newly_added_columns, |name| {
555            format!("{}_{}", upstream_table.name, name)
556        });
557        Some(log_store_table.clone())
558    } else {
559        None
560    };
561    sink_node_body.sink_desc.as_mut().unwrap().column_catalogs = new_sink_columns.clone();
562
563    // update stream scan node
564    stream_scan_node
565        .fields
566        .extend(newly_added_columns.iter().map(|col| {
567            Field::new(
568                format!("{}.{}", upstream_table.name, col.column_desc.name),
569                col.data_type().clone(),
570            )
571            .to_prost()
572        }));
573    // following logic in <StreamTableScan as Explain>::distill
574    stream_scan_node.identity = {
575        let columns = stream_scan_node
576            .fields
577            .iter()
578            .map(|col| &col.name)
579            .join(", ");
580        format!("StreamTableScan {{ table: t, columns: [{columns}] }}")
581    };
582
583    let stream_scan_type = PbStreamScanType::try_from(scan.stream_scan_type).unwrap();
584    if stream_scan_type != PbStreamScanType::ArrangementBackfill {
585        return Err(anyhow!(
586            "unsupported stream_scan_type for auto refresh schema: {:?}",
587            stream_scan_type
588        )
589        .into());
590    }
591    scan.arrangement_table = Some(upstream_table.clone());
592    scan.output_indices.extend(
593        (0..newly_added_columns.len()).map(|i| (i + scan.upstream_column_ids.len()) as u32),
594    );
595    scan.upstream_column_ids.extend(
596        newly_added_columns
597            .iter()
598            .map(|col| col.column_id().get_id()),
599    );
600    let table_desc = scan.table_desc.as_mut().unwrap();
601    table_desc
602        .value_indices
603        .extend((0..newly_added_columns.len()).map(|i| (i + table_desc.columns.len()) as u32));
604    table_desc.columns.extend(
605        newly_added_columns
606            .iter()
607            .map(|col| col.column_desc.to_protobuf()),
608    );
609
610    // update merge node
611    merge_node.fields = scan
612        .upstream_column_ids
613        .iter()
614        .map(|&column_id| {
615            let col = upstream_table
616                .columns
617                .iter()
618                .find(|c| c.column_desc.as_ref().unwrap().column_id == column_id)
619                .unwrap();
620            let col_desc = col.column_desc.as_ref().unwrap();
621            Field::new(
622                col_desc.name.clone(),
623                col_desc.column_type.as_ref().unwrap().into(),
624            )
625            .to_prost()
626        })
627        .collect();
628    merge.upstream_fragment_id = upstream_table_fragment_id;
629    Ok((new_sink_fragment, new_sink_columns, new_log_store_table))
630}
631
632/// Adjacency list (G) of backfill orders.
633/// `G[10] -> [1, 2, 11]`
634/// means the backfill node in `fragment 10`
635/// should be backfilled before the backfill nodes in `fragments 1, 2 and 11`.
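///
/// A minimal illustrative sketch of constructing such an order (the fragment ids here are hypothetical):
/// ```ignore
/// // Fragment 10 must finish backfilling before fragments 1, 2 and 11 start.
/// let order = UserDefinedFragmentBackfillOrder::new(
///     [(10, vec![1, 2, 11])].into_iter().collect(),
/// );
/// ```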
636#[derive(Clone, Debug, Default)]
637pub struct FragmentBackfillOrder<const EXTENDED: bool> {
638    inner: HashMap<FragmentId, Vec<FragmentId>>,
639}
640
641impl<const EXTENDED: bool> Deref for FragmentBackfillOrder<EXTENDED> {
642    type Target = HashMap<FragmentId, Vec<FragmentId>>;
643
644    fn deref(&self) -> &Self::Target {
645        &self.inner
646    }
647}
648
649impl UserDefinedFragmentBackfillOrder {
650    pub fn new(inner: HashMap<FragmentId, Vec<FragmentId>>) -> Self {
651        Self { inner }
652    }
653
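    /// Merge multiple backfill orders into one. If the same fragment appears as a key in more
    /// than one order, the entry from the later order wins.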
654    pub fn merge(orders: impl Iterator<Item = Self>) -> Self {
655        Self {
656            inner: orders.flat_map(|order| order.inner).collect(),
657        }
658    }
659
660    pub fn to_meta_model(&self) -> BackfillOrders {
661        self.inner.clone().into()
662    }
663}
664
665pub type UserDefinedFragmentBackfillOrder = FragmentBackfillOrder<false>;
666pub type ExtendedFragmentBackfillOrder = FragmentBackfillOrder<true>;
667
668/// In-memory representation of a **Fragment** Graph, built from the [`StreamFragmentGraphProto`]
669/// from the frontend.
670///
671/// This only includes the nodes and edges of the current job itself. It will later be converted to a
672/// [`CompleteStreamFragmentGraph`], which additionally contains information about pre-existing
673/// fragments connected to the graph's top-most or bottom-most fragments.
674#[derive(Default, Debug)]
675pub struct StreamFragmentGraph {
676    /// Stores all the fragments in the graph.
677    pub(super) fragments: HashMap<GlobalFragmentId, BuildingFragment>,
678
679    /// Stores edges between fragments: upstream -> downstream.
680    pub(super) downstreams:
681        HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,
682
683    /// Stores edges between fragments: downstream -> upstream.
684    pub(super) upstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,
685
686    /// Dependent relations of this job.
687    dependent_table_ids: HashSet<TableId>,
688
689    /// The default parallelism of the job, specified by the `STREAMING_PARALLELISM` session
690    /// variable. If not specified, all active worker slots will be used.
691    specified_parallelism: Option<NonZeroUsize>,
692    /// The parallelism to use during backfill, specified by the `STREAMING_PARALLELISM_FOR_BACKFILL`
693    /// session variable. If not specified, falls back to `specified_parallelism`.
694    specified_backfill_parallelism: Option<NonZeroUsize>,
695
696    /// Specified max parallelism, i.e., expected vnode count for the graph.
697    ///
698    /// The scheduler on the meta service will use this as a hint to decide the vnode count
699    /// for each fragment.
700    ///
701    /// Note that the actual vnode count may be different from this value.
702    /// For example, a no-shuffle exchange between current fragment graph and an existing
703    /// upstream fragment graph requires two fragments to be in the same distribution,
704    /// thus the same vnode count.
705    max_parallelism: usize,
706
707    /// The backfill ordering strategy of the graph.
708    backfill_order: BackfillOrder,
709}
710
711impl StreamFragmentGraph {
712    /// Create a new [`StreamFragmentGraph`] from the given [`StreamFragmentGraphProto`], with all
713    /// global IDs correctly filled.
714    pub fn new(
715        env: &MetaSrvEnv,
716        proto: StreamFragmentGraphProto,
717        job: &StreamingJob,
718    ) -> MetaResult<Self> {
719        let fragment_id_gen =
720            GlobalFragmentIdGen::new(env.id_gen_manager(), proto.fragments.len() as u64);
721        // Note: in SQL backend, the ids generated here are fake and will be overwritten again
722        // with `refill_internal_table_ids` later.
723        // TODO: refactor the code to remove this step.
724        let table_id_gen = GlobalTableIdGen::new(env.id_gen_manager(), proto.table_ids_cnt as u64);
725
726        // Create nodes.
727        let fragments: HashMap<_, _> = proto
728            .fragments
729            .into_iter()
730            .map(|(id, fragment)| {
731                let id = fragment_id_gen.to_global_id(id.as_raw_id());
732                let fragment = BuildingFragment::new(id, fragment, job, table_id_gen);
733                (id, fragment)
734            })
735            .collect();
736
737        assert_eq!(
738            fragments
739                .values()
740                .map(|f| f.extract_internal_tables().len() as u32)
741                .sum::<u32>(),
742            proto.table_ids_cnt
743        );
744
745        // Create edges.
746        let mut downstreams = HashMap::new();
747        let mut upstreams = HashMap::new();
748
749        for edge in proto.edges {
750            let upstream_id = fragment_id_gen.to_global_id(edge.upstream_id.as_raw_id());
751            let downstream_id = fragment_id_gen.to_global_id(edge.downstream_id.as_raw_id());
752            let edge = StreamFragmentEdge::from_protobuf(&edge);
753
754            upstreams
755                .entry(downstream_id)
756                .or_insert_with(HashMap::new)
757                .try_insert(upstream_id, edge.clone())
758                .unwrap();
759            downstreams
760                .entry(upstream_id)
761                .or_insert_with(HashMap::new)
762                .try_insert(downstream_id, edge)
763                .unwrap();
764        }
765
766        // Note: Here we directly use the field `dependent_table_ids` in the proto (resolved in
767        // frontend), instead of visiting the graph ourselves.
768        let dependent_table_ids = proto.dependent_table_ids.iter().copied().collect();
769
770        let specified_parallelism = if let Some(Parallelism { parallelism }) = proto.parallelism {
771            Some(NonZeroUsize::new(parallelism as usize).context("parallelism should not be 0")?)
772        } else {
773            None
774        };
775        let specified_backfill_parallelism =
776            if let Some(Parallelism { parallelism }) = proto.backfill_parallelism {
777                Some(
778                    NonZeroUsize::new(parallelism as usize)
779                        .context("backfill parallelism should not be 0")?,
780                )
781            } else {
782                None
783            };
784
785        let max_parallelism = proto.max_parallelism as usize;
786        let backfill_order = proto.backfill_order.unwrap_or(BackfillOrder {
787            order: Default::default(),
788        });
789
790        Ok(Self {
791            fragments,
792            downstreams,
793            upstreams,
794            dependent_table_ids,
795            specified_parallelism,
796            specified_backfill_parallelism,
797            max_parallelism,
798            backfill_order,
799        })
800    }
801
802    /// Retrieve the **incomplete** internal tables map of the whole graph.
803    ///
804    /// Note that some fields in the table catalogs are not filled during the current phase, e.g.,
805    /// `fragment_id`, `vnode_count`. They will all be filled after a `TableFragments` is built.
806    /// Be careful when using the returned values.
807    pub fn incomplete_internal_tables(&self) -> BTreeMap<TableId, Table> {
808        let mut tables = BTreeMap::new();
809        for fragment in self.fragments.values() {
810            for table in fragment.extract_internal_tables() {
811                let table_id = table.id;
812                tables
813                    .try_insert(table_id, table)
814                    .unwrap_or_else(|_| panic!("duplicated table id `{}`", table_id));
815            }
816        }
817        tables
818    }
819
820    /// Refill the internal tables' `table_id`s according to the given map, typically obtained from
821    /// `create_internal_table_catalog`.
822    pub fn refill_internal_table_ids(&mut self, table_id_map: HashMap<TableId, TableId>) {
823        for fragment in self.fragments.values_mut() {
824            stream_graph_visitor::visit_internal_tables(
825                &mut fragment.inner,
826                |table, _table_type_name| {
827                    let target = table_id_map.get(&table.id).cloned().unwrap();
828                    table.id = target;
829                },
830            );
831        }
832    }
833
834    /// Use a trivial algorithm to match the internal tables of the new graph for
835    /// `ALTER TABLE` or `ALTER SOURCE`.
836    pub fn fit_internal_tables_trivial(
837        &mut self,
838        mut old_internal_tables: Vec<Table>,
839    ) -> MetaResult<()> {
840        let mut new_internal_table_ids = Vec::new();
841        for fragment in self.fragments.values() {
842            for table in &fragment.extract_internal_tables() {
843                new_internal_table_ids.push(table.id);
844            }
845        }
846
847        if new_internal_table_ids.len() != old_internal_tables.len() {
848            bail!(
849                "Different number of internal tables. New: {}, Old: {}",
850                new_internal_table_ids.len(),
851                old_internal_tables.len()
852            );
853        }
854        old_internal_tables.sort_by(|a, b| a.id.cmp(&b.id));
855        new_internal_table_ids.sort();
856
857        let internal_table_id_map = new_internal_table_ids
858            .into_iter()
859            .zip_eq_fast(old_internal_tables.into_iter())
860            .collect::<HashMap<_, _>>();
861
862        // TODO(alter-mv): unify this with `fit_internal_table_ids_with_mapping` after we
863        // confirm the behavior is the same.
864        for fragment in self.fragments.values_mut() {
865            stream_graph_visitor::visit_internal_tables(
866                &mut fragment.inner,
867                |table, _table_type_name| {
868                    // XXX: this replaces the entire table, instead of just the id!
869                    let target = internal_table_id_map.get(&table.id).cloned().unwrap();
870                    *table = target;
871                },
872            );
873        }
874
875        Ok(())
876    }
877
878    /// Fit the internal tables' `table_id`s according to the given mapping.
879    pub fn fit_internal_table_ids_with_mapping(&mut self, mut matches: HashMap<TableId, Table>) {
880        for fragment in self.fragments.values_mut() {
881            stream_graph_visitor::visit_internal_tables(
882                &mut fragment.inner,
883                |table, _table_type_name| {
884                    let target = matches.remove(&table.id).unwrap_or_else(|| {
885                        panic!("no matching table for table {}({})", table.id, table.name)
886                    });
887                    table.id = target.id;
888                    table.maybe_vnode_count = target.maybe_vnode_count;
889                },
890            );
891        }
892    }
893
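    /// Fill the `snapshot_backfill_epoch` of every (cross-db) snapshot backfill `StreamScan` node,
    /// looking up the epoch by the node's operator id. Panics if no epoch is found for such a node.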
894    pub fn fit_snapshot_backfill_epochs(
895        &mut self,
896        mut snapshot_backfill_epochs: HashMap<StreamNodeLocalOperatorId, u64>,
897    ) {
898        for fragment in self.fragments.values_mut() {
899            visit_stream_node_cont_mut(fragment.node.as_mut().unwrap(), |node| {
900                if let PbNodeBody::StreamScan(scan) = node.node_body.as_mut().unwrap()
901                    && let StreamScanType::SnapshotBackfill
902                    | StreamScanType::CrossDbSnapshotBackfill = scan.stream_scan_type()
903                {
904                    let Some(epoch) = snapshot_backfill_epochs.remove(&node.operator_id) else {
905                        panic!("no snapshot epoch found for node {:?}", node)
906                    };
907                    scan.snapshot_backfill_epoch = Some(epoch);
908                }
909                true
910            })
911        }
912    }
913
914    /// Returns the fragment id where the streaming job node is located.
915    pub fn table_fragment_id(&self) -> FragmentId {
916        self.fragments
917            .values()
918            .filter(|b| b.job_id.is_some())
919            .map(|b| b.fragment_id)
920            .exactly_one()
921            .expect("require exactly 1 materialize/sink/cdc source node when creating the streaming job")
922    }
923
924    /// Returns the fragment id where the table DML is received.
925    pub fn dml_fragment_id(&self) -> Option<FragmentId> {
926        self.fragments
927            .values()
928            .filter(|b| {
929                FragmentTypeMask::from(b.fragment_type_mask).contains(FragmentTypeFlag::Dml)
930            })
931            .map(|b| b.fragment_id)
932            .at_most_one()
933            .expect("require at most 1 dml node when creating the streaming job")
934    }
935
936    /// Get the dependent streaming job ids of this job.
937    pub fn dependent_table_ids(&self) -> &HashSet<TableId> {
938        &self.dependent_table_ids
939    }
940
941    /// Get the parallelism of the job, if specified by the user.
942    pub fn specified_parallelism(&self) -> Option<NonZeroUsize> {
943        self.specified_parallelism
944    }
945
946    /// Get the backfill parallelism of the job, if specified by the user.
947    pub fn specified_backfill_parallelism(&self) -> Option<NonZeroUsize> {
948        self.specified_backfill_parallelism
949    }
950
951    /// Get the expected vnode count of the graph. See documentation of the field for more details.
952    pub fn max_parallelism(&self) -> usize {
953        self.max_parallelism
954    }
955
956    /// Get downstreams of a fragment.
957    fn get_downstreams(
958        &self,
959        fragment_id: GlobalFragmentId,
960    ) -> &HashMap<GlobalFragmentId, StreamFragmentEdge> {
961        self.downstreams.get(&fragment_id).unwrap_or(&EMPTY_HASHMAP)
962    }
963
964    /// Get upstreams of a fragment.
965    fn get_upstreams(
966        &self,
967        fragment_id: GlobalFragmentId,
968    ) -> &HashMap<GlobalFragmentId, StreamFragmentEdge> {
969        self.upstreams.get(&fragment_id).unwrap_or(&EMPTY_HASHMAP)
970    }
971
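    /// Collect the snapshot backfill info of all fragments in this graph.
    /// See [`Self::collect_snapshot_backfill_info_impl`].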
972    pub fn collect_snapshot_backfill_info(
973        &self,
974    ) -> MetaResult<(Option<SnapshotBackfillInfo>, SnapshotBackfillInfo)> {
975        Self::collect_snapshot_backfill_info_impl(self.fragments.values().map(|fragment| {
976            (
977                fragment.node.as_ref().unwrap(),
978                fragment.fragment_type_mask.into(),
979            )
980        }))
981    }
982
983    /// Returns `Ok((snapshot_backfill_info, cross_db_snapshot_backfill_info))`, where the first element is `Some` only if the job is created with snapshot backfill.
984    pub fn collect_snapshot_backfill_info_impl(
985        fragments: impl IntoIterator<Item = (&PbStreamNode, FragmentTypeMask)>,
986    ) -> MetaResult<(Option<SnapshotBackfillInfo>, SnapshotBackfillInfo)> {
987        let mut prev_stream_scan: Option<(Option<SnapshotBackfillInfo>, StreamScanNode)> = None;
988        let mut cross_db_info = SnapshotBackfillInfo {
989            upstream_mv_table_id_to_backfill_epoch: Default::default(),
990        };
991        let mut result = Ok(());
992        for (node, fragment_type_mask) in fragments {
993            visit_stream_node_cont(node, |node| {
994                if let Some(NodeBody::StreamScan(stream_scan)) = node.node_body.as_ref() {
995                    let stream_scan_type = StreamScanType::try_from(stream_scan.stream_scan_type)
996                        .expect("invalid stream_scan_type");
997                    let is_snapshot_backfill = match stream_scan_type {
998                        StreamScanType::SnapshotBackfill => {
999                            assert!(
1000                                fragment_type_mask
1001                                    .contains(FragmentTypeFlag::SnapshotBackfillStreamScan)
1002                            );
1003                            true
1004                        }
1005                        StreamScanType::CrossDbSnapshotBackfill => {
1006                            assert!(
1007                                fragment_type_mask
1008                                    .contains(FragmentTypeFlag::CrossDbSnapshotBackfillStreamScan)
1009                            );
1010                            cross_db_info
1011                                .upstream_mv_table_id_to_backfill_epoch
1012                                .insert(stream_scan.table_id, stream_scan.snapshot_backfill_epoch);
1013
1014                            return true;
1015                        }
1016                        _ => false,
1017                    };
1018
1019                    match &mut prev_stream_scan {
1020                        Some((prev_snapshot_backfill_info, prev_stream_scan)) => {
1021                            match (prev_snapshot_backfill_info, is_snapshot_backfill) {
1022                                (Some(prev_snapshot_backfill_info), true) => {
1023                                    prev_snapshot_backfill_info
1024                                        .upstream_mv_table_id_to_backfill_epoch
1025                                        .insert(
1026                                            stream_scan.table_id,
1027                                            stream_scan.snapshot_backfill_epoch,
1028                                        );
1029                                    true
1030                                }
1031                                (None, false) => true,
1032                                (_, _) => {
1033                                    result = Err(anyhow!("must be either all snapshot_backfill or no snapshot_backfill. Curr: {stream_scan:?} Prev: {prev_stream_scan:?}").into());
1034                                    false
1035                                }
1036                            }
1037                        }
1038                        None => {
1039                            prev_stream_scan = Some((
1040                                if is_snapshot_backfill {
1041                                    Some(SnapshotBackfillInfo {
1042                                        upstream_mv_table_id_to_backfill_epoch: HashMap::from_iter(
1043                                            [(
1044                                                stream_scan.table_id,
1045                                                stream_scan.snapshot_backfill_epoch,
1046                                            )],
1047                                        ),
1048                                    })
1049                                } else {
1050                                    None
1051                                },
1052                                *stream_scan.clone(),
1053                            ));
1054                            true
1055                        }
1056                    }
1057                } else {
1058                    true
1059                }
1060            })
1061        }
1062        result.map(|_| {
1063            (
1064                prev_stream_scan
1065                    .map(|(snapshot_backfill_info, _)| snapshot_backfill_info)
1066                    .unwrap_or(None),
1067                cross_db_info,
1068            )
1069        })
1070    }
1071
1072    /// Collect the mapping from upstream table / source id to the ids of the fragments that backfill from it.
1073    pub fn collect_backfill_mapping(
1074        fragments: impl Iterator<Item = (FragmentId, FragmentTypeMask, &PbStreamNode)>,
1075    ) -> HashMap<RelationId, Vec<FragmentId>> {
1076        let mut mapping = HashMap::new();
1077        for (fragment_id, fragment_type_mask, node) in fragments {
1078            let has_some_scan = fragment_type_mask
1079                .contains_any([FragmentTypeFlag::StreamScan, FragmentTypeFlag::SourceScan]);
1080            if has_some_scan {
1081                visit_stream_node_cont(node, |node| {
1082                    match node.node_body.as_ref() {
1083                        Some(NodeBody::StreamScan(stream_scan)) => {
1084                            let table_id = stream_scan.table_id;
1085                            let fragments: &mut Vec<_> =
1086                                mapping.entry(table_id.as_relation_id()).or_default();
1087                            fragments.push(fragment_id);
1088                            // each fragment should have only 1 scan node.
1089                            false
1090                        }
1091                        Some(NodeBody::SourceBackfill(source_backfill)) => {
1092                            let source_id = source_backfill.upstream_source_id;
1093                            let fragments: &mut Vec<_> =
1094                                mapping.entry(source_id.as_relation_id()).or_default();
1095                            fragments.push(fragment_id);
1096                            // each fragment should have only 1 scan node.
1097                            false
1098                        }
1099                        _ => true,
1100                    }
1101                })
1102            }
1103        }
1104        mapping
1105    }
1106
1107    /// Initially, the mapping that comes from the frontend is between `table_id`s.
1108    /// We remap it to the fragment level, since we track progress per actor, and we can get
1109    /// a fragment <-> actor mapping.
1110    pub fn create_fragment_backfill_ordering(&self) -> UserDefinedFragmentBackfillOrder {
1111        let mapping =
1112            Self::collect_backfill_mapping(self.fragments.iter().map(|(fragment_id, fragment)| {
1113                (
1114                    fragment_id.as_global_id(),
1115                    fragment.fragment_type_mask.into(),
1116                    fragment.node.as_ref().expect("should exist node"),
1117                )
1118            }));
1119        let mut fragment_ordering: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
1120
1121        // 1. Add backfill dependencies
1122        for (rel_id, downstream_rel_ids) in &self.backfill_order.order {
1123            let fragment_ids = mapping.get(rel_id).unwrap();
1124            for fragment_id in fragment_ids {
1125                let downstream_fragment_ids = downstream_rel_ids
1126                    .data
1127                    .iter()
1128                    .flat_map(|&downstream_rel_id| {
1129                        mapping.get(&downstream_rel_id.into()).unwrap().iter()
1130                    })
1131                    .copied()
1132                    .collect();
1133                fragment_ordering.insert(*fragment_id, downstream_fragment_ids);
1134            }
1135        }
1136
1137        UserDefinedFragmentBackfillOrder {
1138            inner: fragment_ordering,
1139        }
1140    }
1141
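    /// Extend the user-defined backfill ordering with `LocalityProvider` constraints: every backfill
    /// fragment must complete before the root `LocalityProvider` fragments start, and upstream
    /// `LocalityProvider` fragments are ordered before their downstream ones.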
1142    pub fn extend_fragment_backfill_ordering_with_locality_backfill<
1143        'a,
1144        FI: Iterator<Item = (FragmentId, FragmentTypeMask, &'a PbStreamNode)> + 'a,
1145    >(
1146        fragment_ordering: UserDefinedFragmentBackfillOrder,
1147        fragment_downstreams: &FragmentDownstreamRelation,
1148        get_fragments: impl Fn() -> FI,
1149    ) -> ExtendedFragmentBackfillOrder {
1150        let mut fragment_ordering = fragment_ordering.inner;
1151        let mapping = Self::collect_backfill_mapping(get_fragments());
1152        // If no backfill order is specified, we still need to ensure that all backfill fragments
1153        // run before LocalityProvider fragments.
1154        if fragment_ordering.is_empty() {
1155            for value in mapping.values() {
1156                for &fragment_id in value {
1157                    fragment_ordering.entry(fragment_id).or_default();
1158                }
1159            }
1160        }
1161
1162        // 2. Add dependencies: all backfill fragments should run before LocalityProvider fragments
1163        let locality_provider_dependencies = Self::find_locality_provider_dependencies(
1164            get_fragments().map(|(fragment_id, _, node)| (fragment_id, node)),
1165            fragment_downstreams,
1166        );
1167
1168        let backfill_fragments: HashSet<FragmentId> = mapping.values().flatten().copied().collect();
1169
1170        // Calculate LocalityProvider root fragments (zero indegree)
1171        // Root fragments are those that appear as keys but never appear as downstream dependencies
1172        let all_locality_provider_fragments: HashSet<FragmentId> =
1173            locality_provider_dependencies.keys().copied().collect();
1174        let downstream_locality_provider_fragments: HashSet<FragmentId> =
1175            locality_provider_dependencies
1176                .values()
1177                .flatten()
1178                .copied()
1179                .collect();
1180        let locality_provider_root_fragments: Vec<FragmentId> = all_locality_provider_fragments
1181            .difference(&downstream_locality_provider_fragments)
1182            .copied()
1183            .collect();
1184
1185        // For each backfill fragment, add only the root LocalityProvider fragments as dependents
1186        // This ensures backfill completes before any LocalityProvider starts, while minimizing dependencies
1187        for &backfill_fragment_id in &backfill_fragments {
1188            fragment_ordering
1189                .entry(backfill_fragment_id)
1190                .or_default()
1191                .extend(locality_provider_root_fragments.iter().copied());
1192        }
1193
1194        // 3. Add LocalityProvider internal dependencies
1195        for (fragment_id, downstream_fragments) in locality_provider_dependencies {
1196            fragment_ordering
1197                .entry(fragment_id)
1198                .or_default()
1199                .extend(downstream_fragments);
1200        }
1201
1202        // Deduplicate downstream entries per fragment; overlaps are common when the same fragment
1203        // is reached via multiple paths (e.g., with StreamShare) and would otherwise appear
1204        // multiple times.
1205        for downstream in fragment_ordering.values_mut() {
1206            let mut seen = HashSet::new();
1207            downstream.retain(|id| seen.insert(*id));
1208        }
1209
1210        ExtendedFragmentBackfillOrder {
1211            inner: fragment_ordering,
1212        }
1213    }
1214
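    /// Collect the state table ids of each fragment that contains a `LocalityProvider` node
    /// (the progress table is not included).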
1215    pub fn find_locality_provider_fragment_state_table_mapping(
1216        &self,
1217    ) -> HashMap<FragmentId, Vec<TableId>> {
1218        let mut mapping: HashMap<FragmentId, Vec<TableId>> = HashMap::new();
1219
1220        for (fragment_id, fragment) in &self.fragments {
1221            let fragment_id = fragment_id.as_global_id();
1222
1223            // Check if this fragment contains a LocalityProvider node
1224            if let Some(node) = fragment.node.as_ref() {
1225                let mut state_table_ids = Vec::new();
1226
1227                visit_stream_node_cont(node, |stream_node| {
1228                    if let Some(NodeBody::LocalityProvider(locality_provider)) =
1229                        stream_node.node_body.as_ref()
1230                    {
1231                        // Collect state table ID (except the progress table)
1232                        let state_table_id = locality_provider
1233                            .state_table
1234                            .as_ref()
1235                            .expect("must have state table")
1236                            .id;
1237                        state_table_ids.push(state_table_id);
1238                        false // Stop visiting once we find a LocalityProvider
1239                    } else {
1240                        true // Continue visiting
1241                    }
1242                });
1243
1244                if !state_table_ids.is_empty() {
1245                    mapping.insert(fragment_id, state_table_ids);
1246                }
1247            }
1248        }
1249
1250        mapping
1251    }
1252
1253    /// Find dependency relationships among fragments containing `LocalityProvider` nodes.
1254    /// Returns a mapping where each fragment ID maps to a list of fragment IDs that should be processed after it.
1255    /// Following the same semantics as `FragmentBackfillOrder`:
1256    /// `G[10] -> [1, 2, 11]` means `LocalityProvider` in fragment 10 should be processed
1257    /// before `LocalityProviders` in fragments 1, 2, and 11.
1258    ///
1259    /// This method assumes each fragment contains at most one `LocalityProvider` node.
1260    pub fn find_locality_provider_dependencies<'a>(
1261        fragments_nodes: impl Iterator<Item = (FragmentId, &'a PbStreamNode)>,
1262        fragment_downstreams: &FragmentDownstreamRelation,
1263    ) -> HashMap<FragmentId, Vec<FragmentId>> {
1264        let mut locality_provider_fragments = HashSet::new();
1265        let mut dependencies: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
1266
1267        // First, identify all fragments that contain LocalityProvider nodes
1268        for (fragment_id, node) in fragments_nodes {
1269            let has_locality_provider = Self::fragment_has_locality_provider(node);
1270
1271            if has_locality_provider {
1272                locality_provider_fragments.insert(fragment_id);
1273                dependencies.entry(fragment_id).or_default();
1274            }
1275        }
1276
1277        // Build dependency relationships between LocalityProvider fragments
1278        // For each LocalityProvider fragment, find all downstream LocalityProvider fragments
1279        // The upstream fragment should be processed before the downstream fragments
1280        for &provider_fragment_id in &locality_provider_fragments {
1281            // Find all fragments downstream from this LocalityProvider fragment
1282            let mut visited = HashSet::new();
1283            let mut downstream_locality_providers = Vec::new();
1284
1285            Self::collect_downstream_locality_providers(
1286                provider_fragment_id,
1287                &locality_provider_fragments,
1288                fragment_downstreams,
1289                &mut visited,
1290                &mut downstream_locality_providers,
1291            );
1292
1293            // This fragment should be processed before all its downstream LocalityProvider fragments
1294            dependencies
1295                .entry(provider_fragment_id)
1296                .or_default()
1297                .extend(downstream_locality_providers);
1298        }
1299
1300        dependencies
1301    }
1302
1303    fn fragment_has_locality_provider(node: &PbStreamNode) -> bool {
1304        let mut has_locality_provider = false;
1305
1306        {
1307            visit_stream_node_cont(node, |stream_node| {
1308                if let Some(NodeBody::LocalityProvider(_)) = stream_node.node_body.as_ref() {
1309                    has_locality_provider = true;
1310                    false // Stop visiting once we find a LocalityProvider
1311                } else {
1312                    true // Continue visiting
1313                }
1314            });
1315        }
1316
1317        has_locality_provider
1318    }
1319
1320    /// Recursively collect downstream `LocalityProvider` fragments
1321    fn collect_downstream_locality_providers(
1322        current_fragment_id: FragmentId,
1323        locality_provider_fragments: &HashSet<FragmentId>,
1324        fragment_downstreams: &FragmentDownstreamRelation,
1325        visited: &mut HashSet<FragmentId>,
1326        downstream_providers: &mut Vec<FragmentId>,
1327    ) {
1328        if visited.contains(&current_fragment_id) {
1329            return;
1330        }
1331        visited.insert(current_fragment_id);
1332
1333        // Check all downstream fragments
1334        for downstream_fragment_id in fragment_downstreams
1335            .get(&current_fragment_id)
1336            .into_iter()
1337            .flat_map(|downstreams| {
1338                downstreams
1339                    .iter()
1340                    .map(|downstream| downstream.downstream_fragment_id)
1341            })
1342        {
1343            // If the downstream fragment is a LocalityProvider, add it to results
1344            if locality_provider_fragments.contains(&downstream_fragment_id) {
1345                downstream_providers.push(downstream_fragment_id);
1346            }
1347
1348            // Recursively check further downstream
1349            Self::collect_downstream_locality_providers(
1350                downstream_fragment_id,
1351                locality_provider_fragments,
1352                fragment_downstreams,
1353                visited,
1354                downstream_providers,
1355            );
1356        }
1357    }
1358}
1359
1360/// Fill the snapshot epoch for the `StreamScanNode`s of `SnapshotBackfill`, preferring the
1361/// cross-db backfill info over the per-job info. Returns `true` if any change was applied.
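///
/// A minimal usage sketch (the `node` and the two info structs are assumed to be prepared by
/// the caller; the names are illustrative, not taken from actual call sites):
///
/// ```ignore
/// let applied =
///     fill_snapshot_backfill_epoch(&mut node, Some(&snapshot_backfill_info), &cross_db_info)?;
/// assert!(applied); // at least one backfill `StreamScan` had its snapshot epoch filled
/// ```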
1362pub fn fill_snapshot_backfill_epoch(
1363    node: &mut StreamNode,
1364    snapshot_backfill_info: Option<&SnapshotBackfillInfo>,
1365    cross_db_snapshot_backfill_info: &SnapshotBackfillInfo,
1366) -> MetaResult<bool> {
1367    let mut result = Ok(());
1368    let mut applied = false;
1369    visit_stream_node_cont_mut(node, |node| {
1370        if let Some(NodeBody::StreamScan(stream_scan)) = node.node_body.as_mut()
1371            && (stream_scan.stream_scan_type == StreamScanType::SnapshotBackfill as i32
1372                || stream_scan.stream_scan_type == StreamScanType::CrossDbSnapshotBackfill as i32)
1373        {
1374            result = try {
1375                let table_id = stream_scan.table_id;
1376                let snapshot_epoch = cross_db_snapshot_backfill_info
1377                    .upstream_mv_table_id_to_backfill_epoch
1378                    .get(&table_id)
1379                    .or_else(|| {
1380                        snapshot_backfill_info.and_then(|snapshot_backfill_info| {
1381                            snapshot_backfill_info
1382                                .upstream_mv_table_id_to_backfill_epoch
1383                                .get(&table_id)
1384                        })
1385                    })
1386                    .ok_or_else(|| anyhow!("upstream table id not covered: {}", table_id))?
1387                    .ok_or_else(|| anyhow!("upstream table id not set: {}", table_id))?;
1388                if let Some(prev_snapshot_epoch) =
1389                    stream_scan.snapshot_backfill_epoch.replace(snapshot_epoch)
1390                {
1391                    Err(anyhow!(
1392                        "snapshot backfill epoch set again: {} {} {}",
1393                        table_id,
1394                        prev_snapshot_epoch,
1395                        snapshot_epoch
1396                    ))?;
1397                }
1398                applied = true;
1399            };
1400            result.is_ok()
1401        } else {
1402            true
1403        }
1404    });
1405    result.map(|_| applied)
1406}
1407
1408static EMPTY_HASHMAP: LazyLock<HashMap<GlobalFragmentId, StreamFragmentEdge>> =
1409    LazyLock::new(HashMap::new);
1410
1411/// A fragment that is either being built or already exists. Used to generalize the logic of
1412/// [`crate::stream::ActorGraphBuilder`].
1413#[derive(Debug, Clone, EnumAsInner)]
1414pub(super) enum EitherFragment {
1415    /// An internal fragment that is being built for the current streaming job.
1416    Building(BuildingFragment),
1417
1418    /// An existing fragment that is external but connected to the fragments being built.
1419    Existing(SharedFragmentInfo),
1420}
1421
1422/// A wrapper of [`StreamFragmentGraph`] that contains the additional information of pre-existing
1423/// fragments, which are connected to the graph's top-most or bottom-most fragments.
1424///
1425/// For example,
1426/// - if we're going to build a mview on an existing mview, the upstream fragment containing the
1427///   `Materialize` node will be included in this structure.
1428/// - if we're going to replace the plan of a table with downstream mviews, the downstream fragments
1429///   containing the `StreamScan` nodes will be included in this structure.
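///
/// A rough sketch of the two cases (the labels are illustrative, not actual identifiers):
///
/// ```text
/// (existing) Materialize fragment --extra edge--> [building graph of the new MV]
/// [building graph of the replaced table] --extra edge--> (existing) StreamScan fragments
/// ```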
1430#[derive(Debug)]
1431pub struct CompleteStreamFragmentGraph {
1432    /// The fragment graph of the streaming job being built.
1433    building_graph: StreamFragmentGraph,
1434
1435    /// The required information of existing fragments.
1436    existing_fragments: HashMap<GlobalFragmentId, SharedFragmentInfo>,
1437
1438    /// The location of the actors in the existing fragments.
1439    existing_actor_location: HashMap<ActorId, WorkerId>,
1440
1441    /// Extra edges between existing and building fragments, keyed by the upstream fragment.
1442    extra_downstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,
1443
1444    /// Extra edges between existing and building fragments, keyed by the downstream fragment.
1445    extra_upstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,
1446}
1447
1448pub struct FragmentGraphUpstreamContext {
1449    /// The root fragment of each upstream stream graph, which can be a materialized view
1450    /// fragment, or a source fragment for a CDC source job.
1451    pub upstream_root_fragments: HashMap<JobId, (SharedFragmentInfo, PbStreamNode)>,
1452    pub upstream_actor_location: HashMap<ActorId, WorkerId>,
1453}
1454
1455pub struct FragmentGraphDownstreamContext {
1456    pub original_root_fragment_id: FragmentId,
1457    pub downstream_fragments: Vec<(DispatcherType, SharedFragmentInfo, PbStreamNode)>,
1458    pub downstream_actor_location: HashMap<ActorId, WorkerId>,
1459}
1460
1461impl CompleteStreamFragmentGraph {
1462    /// Create a new [`CompleteStreamFragmentGraph`] with empty existing fragments, i.e., there are no
1463    /// upstream mviews.
1464    #[cfg(test)]
1465    pub fn for_test(graph: StreamFragmentGraph) -> Self {
1466        Self {
1467            building_graph: graph,
1468            existing_fragments: Default::default(),
1469            existing_actor_location: Default::default(),
1470            extra_downstreams: Default::default(),
1471            extra_upstreams: Default::default(),
1472        }
1473    }
1474
1475    /// Create a new [`CompleteStreamFragmentGraph`] for a newly created job, which has no downstreams,
1476    /// e.g., an MV on MV or a CDC/Source table, connected to the existing upstream
1477    /// `Materialize` or `Source` fragments.
1478    pub fn with_upstreams(
1479        graph: StreamFragmentGraph,
1480        upstream_context: FragmentGraphUpstreamContext,
1481        job_type: StreamingJobType,
1482    ) -> MetaResult<Self> {
1483        Self::build_helper(graph, Some(upstream_context), None, job_type)
1484    }
1485
1486    /// Create a new [`CompleteStreamFragmentGraph`] for replacing an existing table/source,
1487    /// with the downstream existing `StreamScan`/`StreamSourceScan` fragments.
1488    pub fn with_downstreams(
1489        graph: StreamFragmentGraph,
1490        downstream_context: FragmentGraphDownstreamContext,
1491        job_type: StreamingJobType,
1492    ) -> MetaResult<Self> {
1493        Self::build_helper(graph, None, Some(downstream_context), job_type)
1494    }
1495
1496    /// For replacing an existing table based on a shared CDC source, which has both upstreams and downstreams.
1497    pub fn with_upstreams_and_downstreams(
1498        graph: StreamFragmentGraph,
1499        upstream_context: FragmentGraphUpstreamContext,
1500        downstream_context: FragmentGraphDownstreamContext,
1501        job_type: StreamingJobType,
1502    ) -> MetaResult<Self> {
1503        Self::build_helper(
1504            graph,
1505            Some(upstream_context),
1506            Some(downstream_context),
1507            job_type,
1508        )
1509    }
1510
1511    /// The core logic of building a [`CompleteStreamFragmentGraph`], i.e., adding extra upstream/downstream fragments.
1512    fn build_helper(
1513        mut graph: StreamFragmentGraph,
1514        upstream_ctx: Option<FragmentGraphUpstreamContext>,
1515        downstream_ctx: Option<FragmentGraphDownstreamContext>,
1516        job_type: StreamingJobType,
1517    ) -> MetaResult<Self> {
1518        let mut extra_downstreams = HashMap::new();
1519        let mut extra_upstreams = HashMap::new();
1520        let mut existing_fragments = HashMap::new();
1521
1522        let mut existing_actor_location = HashMap::new();
1523
1524        if let Some(FragmentGraphUpstreamContext {
1525            upstream_root_fragments,
1526            upstream_actor_location,
1527        }) = upstream_ctx
1528        {
1529            for (&id, fragment) in &mut graph.fragments {
1530                let uses_shuffled_backfill = fragment.has_shuffled_backfill();
1531
1532                for (&upstream_job_id, required_columns) in &fragment.upstream_job_columns {
1533                    let (upstream_fragment, nodes) = upstream_root_fragments
1534                        .get(&upstream_job_id)
1535                        .context("upstream fragment not found")?;
1536                    let upstream_root_fragment_id =
1537                        GlobalFragmentId::new(upstream_fragment.fragment_id);
1538
1539                    let edge = match job_type {
1540                        StreamingJobType::Table(TableJobType::SharedCdcSource) => {
1541                            // We traverse all fragments in the graph to find the `CdcFilter`
1542                            // fragment and add an edge between the upstream source fragment and it.
1543                            assert_ne!(
1544                                (fragment.fragment_type_mask & FragmentTypeFlag::CdcFilter as u32),
1545                                0
1546                            );
1547
1548                            tracing::debug!(
1549                                ?upstream_root_fragment_id,
1550                                ?required_columns,
1551                                identity = ?fragment.inner.get_node().unwrap().get_identity(),
1552                                current_frag_id=?id,
1553                                "CdcFilter with upstream source fragment"
1554                            );
1555
1556                            StreamFragmentEdge {
1557                                id: EdgeId::UpstreamExternal {
1558                                    upstream_job_id,
1559                                    downstream_fragment_id: id,
1560                                },
1561                                // We always use `NoShuffle` for the exchange between the upstream
1562                                // `Source` and the downstream `StreamScan` of the new cdc table.
1563                                dispatch_strategy: DispatchStrategy {
1564                                    r#type: DispatcherType::NoShuffle as _,
1565                                    dist_key_indices: vec![], // not used for `NoShuffle`
1566                                    output_mapping: DispatchOutputMapping::identical(
1567                                        CDC_SOURCE_COLUMN_NUM as _,
1568                                    )
1569                                    .into(),
1570                                },
1571                            }
1572                        }
1573
1574                        // handle MV on MV/Source
1575                        StreamingJobType::MaterializedView
1576                        | StreamingJobType::Sink
1577                        | StreamingJobType::Index => {
1578                            // Build the extra edges between the upstream `Materialize` and
1579                            // the downstream `StreamScan` of the new job.
1580                            if upstream_fragment
1581                                .fragment_type_mask
1582                                .contains(FragmentTypeFlag::Mview)
1583                            {
1584                                // Resolve the required output columns from the upstream materialized view.
1585                                let (dist_key_indices, output_mapping) = {
1586                                    let mview_node =
1587                                        nodes.get_node_body().unwrap().as_materialize().unwrap();
1588                                    let all_columns = mview_node.column_descs();
1589                                    let dist_key_indices = mview_node.dist_key_indices();
1590                                    let output_mapping = gen_output_mapping(
1591                                        required_columns,
1592                                        &all_columns,
1593                                    )
1594                                    .context(
1595                                        "BUG: column not found in the upstream materialized view",
1596                                    )?;
1597                                    (dist_key_indices, output_mapping)
1598                                };
1599                                let dispatch_strategy = mv_on_mv_dispatch_strategy(
1600                                    uses_shuffled_backfill,
1601                                    dist_key_indices,
1602                                    output_mapping,
1603                                );
1604
1605                                StreamFragmentEdge {
1606                                    id: EdgeId::UpstreamExternal {
1607                                        upstream_job_id,
1608                                        downstream_fragment_id: id,
1609                                    },
1610                                    dispatch_strategy,
1611                                }
1612                            }
1613                            // Build the extra edges between the upstream `Source` and
1614                            // the downstream `SourceBackfill` of the new job.
1615                            else if upstream_fragment
1616                                .fragment_type_mask
1617                                .contains(FragmentTypeFlag::Source)
1618                            {
1619                                let output_mapping = {
1620                                    let source_node =
1621                                        nodes.get_node_body().unwrap().as_source().unwrap();
1622
1623                                    let all_columns = source_node.column_descs().unwrap();
1624                                    gen_output_mapping(required_columns, &all_columns).context(
1625                                        "BUG: column not found in the upstream source node",
1626                                    )?
1627                                };
1628
1629                                StreamFragmentEdge {
1630                                    id: EdgeId::UpstreamExternal {
1631                                        upstream_job_id,
1632                                        downstream_fragment_id: id,
1633                                    },
1634                                    // We always use `NoShuffle` for the exchange between the upstream
1635                                    // `Source` and the downstream `SourceBackfill` of the new job.
1636                                    dispatch_strategy: DispatchStrategy {
1637                                        r#type: DispatcherType::NoShuffle as _,
1638                                        dist_key_indices: vec![], // not used for `NoShuffle`
1639                                        output_mapping: Some(output_mapping),
1640                                    },
1641                                }
1642                            } else {
1643                                bail!(
1644                                    "the upstream fragment should be a MView or Source, got fragment type: {:b}",
1645                                    upstream_fragment.fragment_type_mask
1646                                )
1647                            }
1648                        }
1649                        StreamingJobType::Source | StreamingJobType::Table(_) => {
1650                            bail!(
1651                                "the streaming job shouldn't have an upstream fragment, job_type: {:?}",
1652                                job_type
1653                            )
1654                        }
1655                    };
1656
1657                    // put the edge into the extra edges
1658                    extra_downstreams
1659                        .entry(upstream_root_fragment_id)
1660                        .or_insert_with(HashMap::new)
1661                        .try_insert(id, edge.clone())
1662                        .unwrap();
1663                    extra_upstreams
1664                        .entry(id)
1665                        .or_insert_with(HashMap::new)
1666                        .try_insert(upstream_root_fragment_id, edge)
1667                        .unwrap();
1668                }
1669            }
1670
1671            existing_fragments.extend(
1672                upstream_root_fragments
1673                    .into_values()
1674                    .map(|(f, _)| (GlobalFragmentId::new(f.fragment_id), f)),
1675            );
1676
1677            existing_actor_location.extend(upstream_actor_location);
1678        }
1679
1680        if let Some(FragmentGraphDownstreamContext {
1681            original_root_fragment_id,
1682            downstream_fragments,
1683            downstream_actor_location,
1684        }) = downstream_ctx
1685        {
1686            let original_table_fragment_id = GlobalFragmentId::new(original_root_fragment_id);
1687            let table_fragment_id = GlobalFragmentId::new(graph.table_fragment_id());
1688
1689            // Build the extra edges between the `Materialize` and the downstream `StreamScan` of the
1690            // existing materialized views.
1691            for (dispatcher_type, fragment, nodes) in &downstream_fragments {
1692                let id = GlobalFragmentId::new(fragment.fragment_id);
1693
1694                // Similar to `extract_upstream_columns_except_cross_db_backfill`.
1695                let output_columns = {
1696                    let mut res = None;
1697
1698                    stream_graph_visitor::visit_stream_node_body(nodes, |node_body| {
1699                        let columns = match node_body {
1700                            NodeBody::StreamScan(stream_scan) => stream_scan.upstream_columns(),
1701                            NodeBody::SourceBackfill(source_backfill) => {
1702                                // FIXME: only pass required columns instead of all columns here
1703                                source_backfill.column_descs()
1704                            }
1705                            _ => return,
1706                        };
1707                        res = Some(columns);
1708                    });
1709
1710                    res.context("failed to locate downstream scan")?
1711                };
1712
1713                let table_fragment = graph.fragments.get(&table_fragment_id).unwrap();
1714                let nodes = table_fragment.node.as_ref().unwrap();
1715
1716                let (dist_key_indices, output_mapping) = match job_type {
1717                    StreamingJobType::Table(_) | StreamingJobType::MaterializedView => {
1718                        let mview_node = nodes.get_node_body().unwrap().as_materialize().unwrap();
1719                        let all_columns = mview_node.column_descs();
1720                        let dist_key_indices = mview_node.dist_key_indices();
1721                        let output_mapping = gen_output_mapping(&output_columns, &all_columns)
1722                            .ok_or_else(|| {
1723                                MetaError::invalid_parameter(
1724                                    "unable to drop the column due to \
1725                                     being referenced by downstream materialized views or sinks",
1726                                )
1727                            })?;
1728                        (dist_key_indices, output_mapping)
1729                    }
1730
1731                    StreamingJobType::Source => {
1732                        let source_node = nodes.get_node_body().unwrap().as_source().unwrap();
1733                        let all_columns = source_node.column_descs().unwrap();
1734                        let output_mapping = gen_output_mapping(&output_columns, &all_columns)
1735                            .ok_or_else(|| {
1736                                MetaError::invalid_parameter(
1737                                    "unable to drop the column due to \
1738                                     being referenced by downstream materialized views or sinks",
1739                                )
1740                            })?;
1741                        assert_eq!(*dispatcher_type, DispatcherType::NoShuffle);
1742                        (
1743                            vec![], // not used for `NoShuffle`
1744                            output_mapping,
1745                        )
1746                    }
1747
1748                    _ => bail!("unsupported job type for replacement: {job_type:?}"),
1749                };
1750
1751                let edge = StreamFragmentEdge {
1752                    id: EdgeId::DownstreamExternal(DownstreamExternalEdgeId {
1753                        original_upstream_fragment_id: original_table_fragment_id,
1754                        downstream_fragment_id: id,
1755                    }),
1756                    dispatch_strategy: DispatchStrategy {
1757                        r#type: *dispatcher_type as i32,
1758                        output_mapping: Some(output_mapping),
1759                        dist_key_indices,
1760                    },
1761                };
1762
1763                extra_downstreams
1764                    .entry(table_fragment_id)
1765                    .or_insert_with(HashMap::new)
1766                    .try_insert(id, edge.clone())
1767                    .unwrap();
1768                extra_upstreams
1769                    .entry(id)
1770                    .or_insert_with(HashMap::new)
1771                    .try_insert(table_fragment_id, edge)
1772                    .unwrap();
1773            }
1774
1775            existing_fragments.extend(
1776                downstream_fragments
1777                    .into_iter()
1778                    .map(|(_, f, _)| (GlobalFragmentId::new(f.fragment_id), f)),
1779            );
1780
1781            existing_actor_location.extend(downstream_actor_location);
1782        }
1783
1784        Ok(Self {
1785            building_graph: graph,
1786            existing_fragments,
1787            existing_actor_location,
1788            extra_downstreams,
1789            extra_upstreams,
1790        })
1791    }
1792}
1793
1794/// Generate the `output_mapping` for [`DispatchStrategy`] from given columns.
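///
/// A minimal sketch of the semantics, with made-up column IDs and a hypothetical `col` helper:
/// if the upstream exposes columns with IDs `[1, 2, 3]` and the downstream requires `[3, 1]`
/// with unchanged types, the result selects upstream indices `[2, 0]` and leaves `types` empty.
///
/// ```ignore
/// let upstream = [col(1), col(2), col(3)];
/// let required = [col(3), col(1)];
/// let mapping = gen_output_mapping(&required, &upstream).unwrap();
/// assert_eq!(mapping.indices, vec![2, 0]);
/// assert!(mapping.types.is_empty()); // no `ALTER TABLE ALTER COLUMN TYPE` involved
/// ```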
1795fn gen_output_mapping(
1796    required_columns: &[PbColumnDesc],
1797    upstream_columns: &[PbColumnDesc],
1798) -> Option<DispatchOutputMapping> {
1799    let len = required_columns.len();
1800    let mut indices = vec![0; len];
1801    let mut types = None;
1802
1803    for (i, r) in required_columns.iter().enumerate() {
1804        let (ui, u) = upstream_columns
1805            .iter()
1806            .find_position(|&u| u.column_id == r.column_id)?;
1807        indices[i] = ui as u32;
1808
1809        // Only if we encounter type change (`ALTER TABLE ALTER COLUMN TYPE`) will we generate a
1810        // non-empty `types`.
1811        if u.column_type != r.column_type {
1812            types.get_or_insert_with(|| vec![TypePair::default(); len])[i] = TypePair {
1813                upstream: u.column_type.clone(),
1814                downstream: r.column_type.clone(),
1815            };
1816        }
1817    }
1818
1819    // If there's no type change, indicate it by empty `types`.
1820    let types = types.unwrap_or(Vec::new());
1821
1822    Some(DispatchOutputMapping { indices, types })
1823}
1824
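/// Choose the dispatch strategy for the exchange between the upstream `Materialize` and the
/// backfill of a new MV: `Hash` when shuffled backfill is used and a distribution key is
/// available, `Simple` when shuffled backfill is used without one, and `NoShuffle` otherwise.
///
/// An illustrative sketch, assuming `output_mapping` was produced by [`gen_output_mapping`]:
///
/// ```ignore
/// // Shuffled backfill over a hash-distributed upstream => Hash dispatch.
/// let s = mv_on_mv_dispatch_strategy(true, vec![0, 1], output_mapping.clone());
/// assert_eq!(s.r#type, DispatcherType::Hash as i32);
///
/// // No shuffled backfill => 1-to-1 NoShuffle exchange regardless of the dist key.
/// let s = mv_on_mv_dispatch_strategy(false, vec![0, 1], output_mapping);
/// assert_eq!(s.r#type, DispatcherType::NoShuffle as i32);
/// ```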
1825fn mv_on_mv_dispatch_strategy(
1826    uses_shuffled_backfill: bool,
1827    dist_key_indices: Vec<u32>,
1828    output_mapping: DispatchOutputMapping,
1829) -> DispatchStrategy {
1830    if uses_shuffled_backfill {
1831        if !dist_key_indices.is_empty() {
1832            DispatchStrategy {
1833                r#type: DispatcherType::Hash as _,
1834                dist_key_indices,
1835                output_mapping: Some(output_mapping),
1836            }
1837        } else {
1838            DispatchStrategy {
1839                r#type: DispatcherType::Simple as _,
1840                dist_key_indices: vec![], // empty for Simple
1841                output_mapping: Some(output_mapping),
1842            }
1843        }
1844    } else {
1845        DispatchStrategy {
1846            r#type: DispatcherType::NoShuffle as _,
1847            dist_key_indices: vec![], // not used for `NoShuffle`
1848            output_mapping: Some(output_mapping),
1849        }
1850    }
1851}
1852
1853impl CompleteStreamFragmentGraph {
1854    /// Returns **all** fragment IDs in the complete graph, including the ones that are not in the
1855    /// building graph.
1856    pub(super) fn all_fragment_ids(&self) -> impl Iterator<Item = GlobalFragmentId> + '_ {
1857        self.building_graph
1858            .fragments
1859            .keys()
1860            .chain(self.existing_fragments.keys())
1861            .copied()
1862    }
1863
1864    /// Returns an iterator of **all** edges in the complete graph, including the external edges.
1865    pub(super) fn all_edges(
1866        &self,
1867    ) -> impl Iterator<Item = (GlobalFragmentId, GlobalFragmentId, &StreamFragmentEdge)> + '_ {
1868        self.building_graph
1869            .downstreams
1870            .iter()
1871            .chain(self.extra_downstreams.iter())
1872            .flat_map(|(&from, tos)| tos.iter().map(move |(&to, edge)| (from, to, edge)))
1873    }
1874
1875    /// Returns the distribution of the existing fragments.
1876    pub(super) fn existing_distribution(&self) -> HashMap<GlobalFragmentId, Distribution> {
1877        self.existing_fragments
1878            .iter()
1879            .map(|(&id, f)| {
1880                (
1881                    id,
1882                    Distribution::from_fragment(f, &self.existing_actor_location),
1883                )
1884            })
1885            .collect()
1886    }
1887
1888    /// Generate topological order of **all** fragments in this graph, including the ones that are
1889    /// not in the building graph. Returns an error if the graph is not a DAG, in which case a
1890    /// topological sort cannot be done.
1891    ///
1892    /// For MV on MV, the first fragment popped out from the queue will be the top-most node, i.e., the
1893    /// `Sink` / `Materialize` in the stream graph.
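    ///
    /// Conceptually this is Kahn's algorithm on the reversed edges: fragments with no
    /// downstreams are emitted first, and an upstream fragment is emitted once all of its
    /// downstreams have been emitted. The following self-contained sketch uses plain `u32`
    /// IDs and a made-up topology rather than the actual meta types:
    ///
    /// ```
    /// use std::collections::HashMap;
    ///
    /// // Emit fragments whose downstreams have all been emitted, starting from sink-most ones.
    /// fn topo(downstreams: &HashMap<u32, Vec<u32>>, all: &[u32]) -> Vec<u32> {
    ///     let mut cnt: HashMap<u32, usize> = all
    ///         .iter()
    ///         .map(|&f| (f, downstreams.get(&f).map_or(0, |d| d.len())))
    ///         .collect();
    ///     let mut order: Vec<u32> = all.iter().copied().filter(|f| cnt[f] == 0).collect();
    ///     let mut i = 0;
    ///     while let Some(&f) = order.get(i) {
    ///         i += 1;
    ///         // Decrement the downstream count of every upstream of `f`.
    ///         for (&up, downs) in downstreams {
    ///             if downs.contains(&f) {
    ///                 let c = cnt.get_mut(&up).unwrap();
    ///                 *c -= 1;
    ///                 if *c == 0 {
    ///                     order.push(up);
    ///                 }
    ///             }
    ///         }
    ///     }
    ///     order
    /// }
    ///
    /// // Edges (upstream -> downstream): 1 -> {2, 3}, 2 -> 4, 3 -> 4.
    /// let downstreams = HashMap::from([(1, vec![2, 3]), (2, vec![4]), (3, vec![4])]);
    /// let order = topo(&downstreams, &[1, 2, 3, 4]);
    /// assert_eq!(order.first(), Some(&4)); // sink-most fragment comes first
    /// assert_eq!(order.last(), Some(&1));  // source-most fragment comes last
    /// ```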
1894    pub(super) fn topo_order(&self) -> MetaResult<Vec<GlobalFragmentId>> {
1895        let mut topo = Vec::new();
1896        let mut downstream_cnts = HashMap::new();
1897
1898        // Iterate all fragments.
1899        for fragment_id in self.all_fragment_ids() {
1900            // Count how many downstreams we have for a given fragment.
1901            let downstream_cnt = self.get_downstreams(fragment_id).count();
1902            if downstream_cnt == 0 {
1903                topo.push(fragment_id);
1904            } else {
1905                downstream_cnts.insert(fragment_id, downstream_cnt);
1906            }
1907        }
1908
1909        let mut i = 0;
1910        while let Some(&fragment_id) = topo.get(i) {
1911            i += 1;
1912            // Find if we can process more fragments.
1913            for (upstream_fragment_id, _) in self.get_upstreams(fragment_id) {
1914                let downstream_cnt = downstream_cnts.get_mut(&upstream_fragment_id).unwrap();
1915                *downstream_cnt -= 1;
1916                if *downstream_cnt == 0 {
1917                    downstream_cnts.remove(&upstream_fragment_id);
1918                    topo.push(upstream_fragment_id);
1919                }
1920            }
1921        }
1922
1923        if !downstream_cnts.is_empty() {
1924            // There are fragments that are not processed yet.
1925            bail!("graph is not a DAG");
1926        }
1927
1928        Ok(topo)
1929    }
1930
1931    /// Seal a [`BuildingFragment`] from the graph into a [`Fragment`], which will be further used
1932    /// to build actors on the compute nodes and persisted into the meta store.
1933    pub(super) fn seal_fragment(
1934        &self,
1935        id: GlobalFragmentId,
1936        actors: Vec<StreamActor>,
1937        distribution: Distribution,
1938        stream_node: StreamNode,
1939    ) -> Fragment {
1940        let building_fragment = self.get_fragment(id).into_building().unwrap();
1941        let internal_tables = building_fragment.extract_internal_tables();
1942        let BuildingFragment {
1943            inner,
1944            job_id,
1945            upstream_job_columns: _,
1946        } = building_fragment;
1947
1948        let distribution_type = distribution.to_distribution_type();
1949        let vnode_count = distribution.vnode_count();
1950
1951        let materialized_fragment_id =
1952            if FragmentTypeMask::from(inner.fragment_type_mask).contains(FragmentTypeFlag::Mview) {
1953                job_id.map(JobId::as_mv_table_id)
1954            } else {
1955                None
1956            };
1957
1958        let vector_index_fragment_id =
1959            if inner.fragment_type_mask & FragmentTypeFlag::VectorIndexWrite as u32 != 0 {
1960                job_id.map(JobId::as_mv_table_id)
1961            } else {
1962                None
1963            };
1964
1965        let state_table_ids = internal_tables
1966            .iter()
1967            .map(|t| t.id)
1968            .chain(materialized_fragment_id)
1969            .chain(vector_index_fragment_id)
1970            .collect();
1971
1972        Fragment {
1973            fragment_id: inner.fragment_id,
1974            fragment_type_mask: inner.fragment_type_mask.into(),
1975            distribution_type,
1976            actors,
1977            state_table_ids,
1978            maybe_vnode_count: VnodeCount::set(vnode_count).to_protobuf(),
1979            nodes: stream_node,
1980        }
1981    }
1982
1983    /// Get a fragment from the complete graph, which can be either a building fragment or an
1984    /// existing fragment.
1985    pub(super) fn get_fragment(&self, fragment_id: GlobalFragmentId) -> EitherFragment {
1986        if let Some(fragment) = self.existing_fragments.get(&fragment_id) {
1987            EitherFragment::Existing(fragment.clone())
1988        } else {
1989            EitherFragment::Building(
1990                self.building_graph
1991                    .fragments
1992                    .get(&fragment_id)
1993                    .unwrap()
1994                    .clone(),
1995            )
1996        }
1997    }
1998
1999    /// Get **all** downstreams of a fragment, including the ones that are not in the building
2000    /// graph.
2001    pub(super) fn get_downstreams(
2002        &self,
2003        fragment_id: GlobalFragmentId,
2004    ) -> impl Iterator<Item = (GlobalFragmentId, &StreamFragmentEdge)> {
2005        self.building_graph
2006            .get_downstreams(fragment_id)
2007            .iter()
2008            .chain(
2009                self.extra_downstreams
2010                    .get(&fragment_id)
2011                    .into_iter()
2012                    .flatten(),
2013            )
2014            .map(|(&id, edge)| (id, edge))
2015    }
2016
2017    /// Get **all** upstreams of a fragment, including the ones that are not in the building
2018    /// graph.
2019    pub(super) fn get_upstreams(
2020        &self,
2021        fragment_id: GlobalFragmentId,
2022    ) -> impl Iterator<Item = (GlobalFragmentId, &StreamFragmentEdge)> {
2023        self.building_graph
2024            .get_upstreams(fragment_id)
2025            .iter()
2026            .chain(self.extra_upstreams.get(&fragment_id).into_iter().flatten())
2027            .map(|(&id, edge)| (id, edge))
2028    }
2029
2030    /// Returns all building fragments in the graph.
2031    pub(super) fn building_fragments(&self) -> &HashMap<GlobalFragmentId, BuildingFragment> {
2032        &self.building_graph.fragments
2033    }
2034
2035    /// Returns all building fragments in the graph, mutable.
2036    pub(super) fn building_fragments_mut(
2037        &mut self,
2038    ) -> &mut HashMap<GlobalFragmentId, BuildingFragment> {
2039        &mut self.building_graph.fragments
2040    }
2041
2042    /// Get the expected vnode count of the building graph. See documentation of the field for more details.
2043    pub(super) fn max_parallelism(&self) -> usize {
2044        self.building_graph.max_parallelism()
2045    }
2046}