risingwave_meta/stream/stream_graph/
fragment.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap, HashSet};
16use std::num::NonZeroUsize;
17use std::ops::{Deref, DerefMut};
18use std::sync::LazyLock;
19
20use anyhow::{Context, anyhow};
21use enum_as_inner::EnumAsInner;
22use itertools::Itertools;
23use risingwave_common::bail;
24use risingwave_common::catalog::{
25    CDC_SOURCE_COLUMN_NUM, ColumnCatalog, ColumnId, Field, FragmentTypeFlag, FragmentTypeMask,
26    TableId, generate_internal_table_name_with_type,
27};
28use risingwave_common::hash::VnodeCount;
29use risingwave_common::id::JobId;
30use risingwave_common::util::iter_util::ZipEqFast;
31use risingwave_common::util::stream_graph_visitor::{
32    self, visit_stream_node_cont, visit_stream_node_cont_mut,
33};
34use risingwave_connector::sink::catalog::SinkType;
35use risingwave_meta_model::streaming_job::BackfillOrders;
36use risingwave_pb::catalog::{PbSink, PbTable, Table};
37use risingwave_pb::ddl_service::TableJobType;
38use risingwave_pb::expr::{ExprNode as PbExprNode, expr_node};
39use risingwave_pb::id::{RelationId, StreamNodeLocalOperatorId};
40use risingwave_pb::plan_common::{PbColumnCatalog, PbColumnDesc};
41use risingwave_pb::stream_plan::dispatch_output_mapping::TypePair;
42use risingwave_pb::stream_plan::stream_fragment_graph::{
43    Parallelism, StreamFragment, StreamFragmentEdge as StreamFragmentEdgeProto,
44};
45use risingwave_pb::stream_plan::stream_node::{NodeBody, PbNodeBody};
46use risingwave_pb::stream_plan::{
47    BackfillOrder, DispatchOutputMapping, DispatchStrategy, DispatcherType, PbStreamNode,
48    PbStreamScanType, StreamFragmentGraph as StreamFragmentGraphProto, StreamNode, StreamScanNode,
49    StreamScanType,
50};
51
52use crate::barrier::SnapshotBackfillInfo;
53use crate::controller::id::IdGeneratorManager;
54use crate::manager::{MetaSrvEnv, StreamingJob, StreamingJobType};
55use crate::model::{Fragment, FragmentDownstreamRelation, FragmentId};
56use crate::stream::stream_graph::id::{GlobalFragmentId, GlobalFragmentIdGen, GlobalTableIdGen};
57use crate::stream::stream_graph::schedule::Distribution;
58use crate::{MetaError, MetaResult};
59
/// The fragment in the building phase, including the [`StreamFragment`] from the frontend and
/// several additional helper fields.
#[derive(Debug, Clone)]
pub(super) struct BuildingFragment {
    /// The fragment structure from the frontend, with the global fragment ID.
    inner: StreamFragment,

    /// The ID of the job if it contains the streaming job node.
    job_id: Option<JobId>,

    /// The required column IDs of each upstream table.
    /// Will be converted to indices when building the edge connected to the upstream.
    ///
    /// For shared CDC table on source, it's `vec![]`, since the upstream source's output schema is fixed.
    upstream_job_columns: HashMap<JobId, Vec<PbColumnDesc>>,
}
76
impl BuildingFragment {
    /// Create a new [`BuildingFragment`] from a [`StreamFragment`]. The global fragment ID and
    /// global table IDs will be correctly filled with the given `id` and `table_id_gen`.
    fn new(
        id: GlobalFragmentId,
        fragment: StreamFragment,
        job: &StreamingJob,
        table_id_gen: GlobalTableIdGen,
    ) -> Self {
        // Overwrite the frontend-assigned (local) fragment ID with the global one.
        let mut fragment = StreamFragment {
            fragment_id: id.as_global_id(),
            ..fragment
        };

        // Fill the information of the internal tables in the fragment.
        Self::fill_internal_tables(&mut fragment, job, table_id_gen);

        // `fill_job` reports whether this fragment contains the streaming job node itself,
        // in which case we record the job ID on the building fragment.
        let job_id = Self::fill_job(&mut fragment, job).then(|| job.id());
        let upstream_job_columns =
            Self::extract_upstream_columns_except_cross_db_backfill(&fragment);

        Self {
            inner: fragment,
            job_id,
            upstream_job_columns,
        }
    }

    /// Extract the internal tables from the fragment.
    fn extract_internal_tables(&self) -> Vec<Table> {
        // The visitor requires a mutable fragment, so operate on a clone.
        let mut fragment = self.inner.clone();
        let mut tables = Vec::new();
        stream_graph_visitor::visit_internal_tables(&mut fragment, |table, _| {
            tables.push(table.clone());
        });
        tables
    }

    /// Fill the information with the internal tables in the fragment.
    fn fill_internal_tables(
        fragment: &mut StreamFragment,
        job: &StreamingJob,
        table_id_gen: GlobalTableIdGen,
    ) {
        let fragment_id = fragment.fragment_id;
        stream_graph_visitor::visit_internal_tables(fragment, |table, table_type_name| {
            // Remap the frontend-local table ID into the global one.
            table.id = table_id_gen
                .to_global_id(table.id.as_raw_id())
                .as_global_id();
            // Inherit catalog information from the owning streaming job.
            table.schema_id = job.schema_id();
            table.database_id = job.database_id();
            table.name = generate_internal_table_name_with_type(
                &job.name(),
                fragment_id,
                table.id,
                table_type_name,
            );
            table.fragment_id = fragment_id;
            table.owner = job.owner();
            table.job_id = Some(job.id());
        });
    }

    /// Fill the information with the job in the fragment.
    ///
    /// Returns `true` if the fragment contains the node that embodies the streaming job
    /// itself (`Materialize`, `Sink`, `Source` for a source job, or `VectorIndexWrite`).
    fn fill_job(fragment: &mut StreamFragment, job: &StreamingJob) -> bool {
        let job_id = job.id();
        let fragment_id = fragment.fragment_id;
        let mut has_job = false;

        stream_graph_visitor::visit_fragment_mut(fragment, |node_body| match node_body {
            NodeBody::Materialize(materialize_node) => {
                materialize_node.table_id = job_id.as_mv_table_id();

                // Fill the table field of `MaterializeNode` from the job.
                let table = materialize_node.table.insert(job.table().unwrap().clone());
                table.fragment_id = fragment_id; // this will later be synced back to `job.table` with `set_info_from_graph`
                // In production, do not include full definition in the table in plan node.
                if cfg!(not(debug_assertions)) {
                    table.definition = job.name();
                }

                has_job = true;
            }
            NodeBody::Sink(sink_node) => {
                sink_node.sink_desc.as_mut().unwrap().id = job_id.as_sink_id();

                has_job = true;
            }
            NodeBody::Dml(dml_node) => {
                dml_node.table_id = job_id.as_mv_table_id();
                dml_node.table_version_id = job.table_version_id().unwrap();
            }
            NodeBody::StreamFsFetch(fs_fetch_node) => {
                if let StreamingJob::Table(table_source, _, _) = job
                    && let Some(node_inner) = fs_fetch_node.node_inner.as_mut()
                    && let Some(source) = table_source
                {
                    node_inner.source_id = source.id;
                    if let Some(id) = source.optional_associated_table_id {
                        node_inner.associated_table_id = Some(id.into());
                    }
                }
            }
            NodeBody::Source(source_node) => {
                match job {
                    // Note: For a table without connector, it has a dummy Source node.
                    // Note: For a table with connector, its source node has a source ID different from the table ID (job ID), assigned in create_job_catalog.
                    StreamingJob::Table(source, _table, _table_job_type) => {
                        if let Some(source_inner) = source_node.source_inner.as_mut()
                            && let Some(source) = source
                        {
                            debug_assert_ne!(source.id, job_id.as_raw_id());
                            source_inner.source_id = source.id;
                            if let Some(id) = source.optional_associated_table_id {
                                source_inner.associated_table_id = Some(id.into());
                            }
                        }
                    }
                    StreamingJob::Source(source) => {
                        has_job = true;
                        if let Some(source_inner) = source_node.source_inner.as_mut() {
                            debug_assert_eq!(source.id, job_id.as_raw_id());
                            source_inner.source_id = source.id;
                            if let Some(id) = source.optional_associated_table_id {
                                source_inner.associated_table_id = Some(id.into());
                            }
                        }
                    }
                    // For other job types, no need to fill the source id, since it refers to an existing source.
                    _ => {}
                }
            }
            NodeBody::StreamCdcScan(node) => {
                if let Some(table_desc) = node.cdc_table_desc.as_mut() {
                    table_desc.table_id = job_id.as_mv_table_id();
                }
            }
            NodeBody::VectorIndexWrite(node) => {
                let table = node.table.as_mut().unwrap();
                table.id = job_id.as_mv_table_id();
                table.database_id = job.database_id();
                table.schema_id = job.schema_id();
                table.fragment_id = fragment_id;
                // In production, do not include full definition in the table in plan node.
                #[cfg(not(debug_assertions))]
                {
                    table.definition = job.name();
                }

                has_job = true;
            }
            _ => {}
        });

        has_job
    }

    /// Extract the required columns of each upstream table except for cross-db backfill.
    fn extract_upstream_columns_except_cross_db_backfill(
        fragment: &StreamFragment,
    ) -> HashMap<JobId, Vec<PbColumnDesc>> {
        let mut table_columns = HashMap::new();

        stream_graph_visitor::visit_fragment(fragment, |node_body| {
            let (table_id, column_ids) = match node_body {
                NodeBody::StreamScan(stream_scan) => {
                    // Skip cross-db snapshot backfill; its upstream columns are not
                    // collected here.
                    if stream_scan.get_stream_scan_type().unwrap()
                        == StreamScanType::CrossDbSnapshotBackfill
                    {
                        return;
                    }
                    (
                        stream_scan.table_id.as_job_id(),
                        stream_scan.upstream_columns(),
                    )
                }
                // For shared CDC source, the upstream source's output schema is fixed,
                // so no columns are required (see `upstream_job_columns` docs).
                NodeBody::CdcFilter(cdc_filter) => (
                    cdc_filter.upstream_source_id.as_share_source_job_id(),
                    vec![],
                ),
                NodeBody::SourceBackfill(backfill) => (
                    backfill.upstream_source_id.as_share_source_job_id(),
                    // FIXME: only pass required columns instead of all columns here
                    backfill.column_descs(),
                ),
                _ => return,
            };
            table_columns
                .try_insert(table_id, column_ids)
                .expect("currently there should be no two same upstream tables in a fragment");
        });

        table_columns
    }

    /// Returns `true` if this fragment contains an `ArrangementBackfill` or
    /// `SnapshotBackfill` stream scan.
    pub fn has_shuffled_backfill(&self) -> bool {
        let stream_node = match self.inner.node.as_ref() {
            Some(node) => node,
            _ => return false,
        };
        let mut has_shuffled_backfill = false;
        let has_shuffled_backfill_mut_ref = &mut has_shuffled_backfill;
        visit_stream_node_cont(stream_node, |node| {
            let is_shuffled_backfill = if let Some(node) = &node.node_body
                && let Some(node) = node.as_stream_scan()
            {
                node.stream_scan_type == StreamScanType::ArrangementBackfill as i32
                    || node.stream_scan_type == StreamScanType::SnapshotBackfill as i32
            } else {
                false
            };
            if is_shuffled_backfill {
                *has_shuffled_backfill_mut_ref = true;
                // Returning `false` stops the traversal once a match is found.
                false
            } else {
                true
            }
        });
        has_shuffled_backfill
    }
}
297
impl Deref for BuildingFragment {
    type Target = StreamFragment;

    // Expose the wrapped `StreamFragment` transparently for read access.
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
305
impl DerefMut for BuildingFragment {
    // Expose the wrapped `StreamFragment` transparently for mutation.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
311
/// The ID of an edge in the fragment graph. For different types of edges, the ID will be in
/// different variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumAsInner)]
pub(super) enum EdgeId {
    /// The edge between two building (internal) fragments.
    Internal {
        /// The ID generated by the frontend, generally the operator ID of `Exchange`.
        /// See [`StreamFragmentEdgeProto`].
        link_id: u64,
    },

    /// The edge between an upstream external fragment and downstream building fragment. Used for
    /// MV on MV.
    UpstreamExternal {
        /// The ID of the upstream table or materialized view.
        upstream_job_id: JobId,
        /// The ID of the downstream fragment.
        downstream_fragment_id: GlobalFragmentId,
    },

    /// The edge between an upstream building fragment and downstream external fragment. Used for
    /// schema change (replace table plan).
    /// See [`DownstreamExternalEdgeId`] for the fragment pair.
    DownstreamExternal(DownstreamExternalEdgeId),
}
336
/// Identifies an edge from an upstream building fragment to a downstream **external**
/// (already existing) fragment. See [`EdgeId::DownstreamExternal`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(super) struct DownstreamExternalEdgeId {
    /// The ID of the original upstream fragment (`Materialize`).
    pub(super) original_upstream_fragment_id: GlobalFragmentId,
    /// The ID of the downstream fragment.
    pub(super) downstream_fragment_id: GlobalFragmentId,
}
344
/// The edge in the fragment graph.
///
/// The edge can be either internal or external. This is distinguished by the [`EdgeId`].
#[derive(Debug, Clone)]
pub(super) struct StreamFragmentEdge {
    /// The ID of the edge. See [`EdgeId`] for the possible variants.
    pub id: EdgeId,

    /// The strategy used for dispatching the data.
    pub dispatch_strategy: DispatchStrategy,
}
356
357impl StreamFragmentEdge {
358    fn from_protobuf(edge: &StreamFragmentEdgeProto) -> Self {
359        Self {
360            // By creating an edge from the protobuf, we know that the edge is from the frontend and
361            // is internal.
362            id: EdgeId::Internal {
363                link_id: edge.link_id,
364            },
365            dispatch_strategy: edge.get_dispatch_strategy().unwrap().clone(),
366        }
367    }
368}
369
/// Clone a fragment, assigning it a freshly generated global fragment ID.
///
/// All other fields are copied verbatim from `fragment`.
fn clone_fragment(fragment: &Fragment, id_generator_manager: &IdGeneratorManager) -> Fragment {
    // Reserve exactly one new fragment ID from the global generator.
    let fragment_id = GlobalFragmentIdGen::new(id_generator_manager, 1)
        .to_global_id(0)
        .as_global_id();
    Fragment {
        fragment_id,
        fragment_type_mask: fragment.fragment_type_mask,
        distribution_type: fragment.distribution_type,
        state_table_ids: fragment.state_table_ids.clone(),
        maybe_vnode_count: fragment.maybe_vnode_count,
        nodes: fragment.nodes.clone(),
    }
}
383
/// Check that a sink job's fragments have the exact shape required for auto refresh
/// schema: a single fragment of the form
/// `Sink <- [Project] <- StreamScan(ArrangementBackfill) <- [Merge, BatchPlan]`.
pub fn check_sink_fragments_support_refresh_schema(
    fragments: &BTreeMap<FragmentId, Fragment>,
) -> MetaResult<()> {
    // Exactly one fragment is expected.
    if fragments.len() != 1 {
        return Err(anyhow!(
            "sink with auto schema change should have only 1 fragment, but got {:?}",
            fragments.len()
        )
        .into());
    }
    let (_, fragment) = fragments.first_key_value().expect("non-empty");
    // The fragment root must be a `Sink` node.
    let sink_node = &fragment.nodes;
    let PbNodeBody::Sink(_) = sink_node.node_body.as_ref().unwrap() else {
        return Err(anyhow!("expect PbNodeBody::Sink but got: {:?}", sink_node.node_body).into());
    };
    let [stream_input_node] = sink_node.input.as_slice() else {
        panic!("Sink has more than 1 input: {:?}", sink_node.input);
    };
    // The sink's input is either the stream scan directly, or a `Project` on top of it.
    let stream_scan_node = match stream_input_node.node_body.as_ref().unwrap() {
        PbNodeBody::StreamScan(_) => stream_input_node,
        PbNodeBody::Project(_) => {
            let [stream_scan_node] = stream_input_node.input.as_slice() else {
                return Err(anyhow!(
                    "Project node must have exactly 1 input for auto schema change, but got {:?}",
                    stream_input_node.input.len()
                )
                .into());
            };
            stream_scan_node
        }
        _ => {
            return Err(anyhow!(
                "expect PbNodeBody::StreamScan or PbNodeBody::Project but got: {:?}",
                stream_input_node.node_body
            )
            .into());
        }
    };
    let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_ref().unwrap() else {
        return Err(anyhow!(
            "expect PbNodeBody::StreamScan but got: {:?}",
            stream_scan_node.node_body
        )
        .into());
    };
    // Only arrangement backfill scans are supported.
    let stream_scan_type = PbStreamScanType::try_from(scan.stream_scan_type).unwrap();
    if stream_scan_type != PbStreamScanType::ArrangementBackfill {
        return Err(anyhow!(
            "unsupported stream_scan_type for auto refresh schema: {:?}",
            stream_scan_type
        )
        .into());
    }
    // A stream scan is expected to have exactly two inputs: `Merge` and the batch plan.
    let [merge_node, _batch_plan_node] = stream_scan_node.input.as_slice() else {
        panic!(
            "the number of StreamScan inputs is not 2: {:?}",
            stream_scan_node.input
        );
    };
    let NodeBody::Merge(_) = merge_node.node_body.as_ref().unwrap() else {
        return Err(anyhow!(
            "expect PbNodeBody::Merge but got: {:?}",
            merge_node.node_body
        )
        .into());
    };
    Ok(())
}
452
/// Output mapping info after rewriting a `StreamScan` node.
struct ScanRewriteResult {
    /// Maps an output position of the old scan to its position in the rewritten scan,
    /// for columns that survive the schema change.
    old_output_index_to_new_output_index: HashMap<u32, u32>,
    /// The output position of each column (by column ID) in the rewritten scan.
    new_output_index_by_column_id: HashMap<ColumnId, u32>,
    /// The output schema (fields) of the rewritten scan node.
    output_fields: Vec<risingwave_pb::plan_common::Field>,
}
459
460/// Append new columns to a sink/log-store column list with updated names/ids.
461fn extend_sink_columns(
462    sink_columns: &mut Vec<PbColumnCatalog>,
463    new_columns: &[ColumnCatalog],
464    get_column_name: impl Fn(&String) -> String,
465) {
466    let next_column_id = sink_columns
467        .iter()
468        .map(|col| col.column_desc.as_ref().unwrap().column_id + 1)
469        .max()
470        .unwrap_or(1);
471    sink_columns.extend(new_columns.iter().enumerate().map(|(i, col)| {
472        let mut col = col.to_protobuf();
473        let column_desc = col.column_desc.as_mut().unwrap();
474        column_desc.column_id = next_column_id + (i as i32);
475        column_desc.name = get_column_name(&column_desc.name);
476        col
477    }));
478}
479
480/// Build sink column list after removing and appending columns.
481fn build_new_sink_columns(
482    sink: &PbSink,
483    removed_column_ids: &HashSet<ColumnId>,
484    newly_added_columns: &[ColumnCatalog],
485) -> Vec<PbColumnCatalog> {
486    let mut columns: Vec<PbColumnCatalog> = sink
487        .columns
488        .iter()
489        .filter(|col| {
490            let column_id = ColumnId::new(col.column_desc.as_ref().unwrap().column_id as _);
491            !removed_column_ids.contains(&column_id)
492        })
493        .cloned()
494        .collect();
495    extend_sink_columns(&mut columns, newly_added_columns, |name| name.clone());
496    columns
497}
498
499/// Rewrite log store table columns for schema change.
500fn rewrite_log_store_table(
501    log_store_table: &mut PbTable,
502    removed_log_store_column_names: &HashSet<String>,
503    newly_added_columns: &[ColumnCatalog],
504    upstream_table_name: &str,
505) {
506    log_store_table.columns.retain(|col| {
507        !removed_log_store_column_names.contains(&col.column_desc.as_ref().unwrap().name)
508    });
509    extend_sink_columns(&mut log_store_table.columns, newly_added_columns, |name| {
510        format!("{}_{}", upstream_table_name, name)
511    });
512    log_store_table.value_indices = (0..log_store_table.columns.len() as i32).collect();
513}
514
/// Rewrite `StreamScan` + Merge to match the new upstream schema.
///
/// Removes `removed_column_ids` from the scan's upstream/output column lists, appends
/// `newly_added_columns` at the end, and refreshes the scan's table descriptor, fields,
/// identity, and the merge node's fields and upstream fragment accordingly.
///
/// Returns the index mappings needed to rewrite downstream nodes (e.g. `Project`).
fn rewrite_stream_scan_and_merge(
    stream_scan_node: &mut StreamNode,
    removed_column_ids: &HashSet<ColumnId>,
    newly_added_columns: &[ColumnCatalog],
    upstream_table: &PbTable,
    upstream_table_fragment_id: FragmentId,
) -> MetaResult<ScanRewriteResult> {
    let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_mut().unwrap() else {
        return Err(anyhow!(
            "expect PbNodeBody::StreamScan but got: {:?}",
            stream_scan_node.node_body
        )
        .into());
    };
    let [merge_node, _batch_plan_node] = stream_scan_node.input.as_mut_slice() else {
        panic!(
            "the number of StreamScan inputs is not 2: {:?}",
            stream_scan_node.input
        );
    };
    let NodeBody::Merge(merge) = merge_node.node_body.as_mut().unwrap() else {
        return Err(anyhow!(
            "expect PbNodeBody::Merge but got: {:?}",
            merge_node.node_body
        )
        .into());
    };

    // Only arrangement backfill is supported for auto refresh schema.
    let stream_scan_type = PbStreamScanType::try_from(scan.stream_scan_type).unwrap();
    if stream_scan_type != PbStreamScanType::ArrangementBackfill {
        return Err(anyhow!(
            "unsupported stream_scan_type for auto refresh schema: {:?}",
            stream_scan_type
        )
        .into());
    }

    // Index the new upstream table's column descriptors by column ID for lookups below.
    let upstream_columns_by_id: HashMap<i32, PbColumnDesc> = upstream_table
        .columns
        .iter()
        .map(|col| {
            let desc = col.column_desc.as_ref().unwrap().clone();
            (desc.column_id, desc)
        })
        .collect();

    // Rebuild `upstream_column_ids`: keep surviving columns in their original order,
    // recording how old upstream indices map to new ones.
    let old_upstream_column_ids = scan.upstream_column_ids.clone();
    let old_output_indices = scan.output_indices.clone();
    let mut old_upstream_index_to_new_upstream_index = HashMap::new();
    let mut new_upstream_column_ids = Vec::new();
    for (old_idx, &column_id) in old_upstream_column_ids.iter().enumerate() {
        if !removed_column_ids.contains(&ColumnId::new(column_id as _)) {
            let new_idx = new_upstream_column_ids.len() as u32;
            old_upstream_index_to_new_upstream_index.insert(old_idx as u32, new_idx);
            new_upstream_column_ids.push(column_id);
        }
    }
    // Rebuild `output_indices`: drop outputs referring to removed columns...
    let mut new_output_indices = Vec::new();
    for old_output_index in &old_output_indices {
        if let Some(new_index) = old_upstream_index_to_new_upstream_index.get(old_output_index) {
            new_output_indices.push(*new_index);
        }
    }
    // ...and append the newly added columns at the end of both lists.
    for col in newly_added_columns {
        let new_index = new_upstream_column_ids.len() as u32;
        new_upstream_column_ids.push(col.column_id().get_id());
        new_output_indices.push(new_index);
    }

    // Derive the output-position mappings returned to the caller.
    let new_output_column_ids: Vec<i32> = new_output_indices
        .iter()
        .map(|&idx| new_upstream_column_ids[idx as usize])
        .collect();
    let mut new_output_index_by_column_id = HashMap::new();
    for (pos, &column_id) in new_output_column_ids.iter().enumerate() {
        new_output_index_by_column_id.insert(ColumnId::new(column_id as _), pos as u32);
    }
    let mut old_output_index_to_new_output_index = HashMap::new();
    for (old_pos, old_output_index) in old_output_indices.iter().enumerate() {
        let column_id = old_upstream_column_ids[*old_output_index as usize];
        if let Some(new_pos) = new_output_index_by_column_id.get(&ColumnId::new(column_id as _)) {
            old_output_index_to_new_output_index.insert(old_pos as u32, *new_pos);
        }
    }

    // Write the rewritten metadata back into the scan node.
    scan.arrangement_table = Some(upstream_table.clone());
    scan.upstream_column_ids = new_upstream_column_ids;
    scan.output_indices = new_output_indices;
    let table_desc = scan.table_desc.as_mut().unwrap();
    table_desc.columns = scan
        .upstream_column_ids
        .iter()
        .map(|column_id| {
            upstream_columns_by_id
                .get(column_id)
                .unwrap_or_else(|| panic!("upstream column id not found: {}", column_id))
                .clone()
        })
        .collect();

    // The scan's output fields are named `<table>.<column>`.
    stream_scan_node.fields = new_output_column_ids
        .iter()
        .map(|column_id| {
            let col_desc = upstream_columns_by_id
                .get(column_id)
                .unwrap_or_else(|| panic!("upstream column id not found: {}", column_id));
            Field::new(
                format!("{}.{}", upstream_table.name, col_desc.name),
                col_desc.column_type.as_ref().unwrap().into(),
            )
            .to_prost()
        })
        .collect();
    // following logic in <StreamTableScan as Explain>::distill
    stream_scan_node.identity = {
        let columns = stream_scan_node
            .fields
            .iter()
            .map(|col| &col.name)
            .join(", ");
        format!("StreamTableScan {{ table: t, columns: [{columns}] }}")
    };

    // update merge node
    merge_node.fields = scan
        .upstream_column_ids
        .iter()
        .map(|&column_id| {
            let col_desc = upstream_columns_by_id
                .get(&column_id)
                .unwrap_or_else(|| panic!("upstream column id not found: {}", column_id));
            Field::new(
                col_desc.name.clone(),
                col_desc.column_type.as_ref().unwrap().into(),
            )
            .to_prost()
        })
        .collect();
    merge.upstream_fragment_id = upstream_table_fragment_id;

    Ok(ScanRewriteResult {
        old_output_index_to_new_output_index,
        new_output_index_by_column_id,
        output_fields: stream_scan_node.fields.clone(),
    })
}
662
/// Rewrite Project node input refs and extend with newly added columns.
///
/// `InputRef` expressions are remapped through `scan_rewrite`; ones referring to
/// removed columns are dropped along with their fields. A plain `InputRef` is then
/// appended for every newly added column. Non-`InputRef` expressions are only
/// supported when no column is being dropped.
fn rewrite_project_node(
    project_node: &mut StreamNode,
    scan_rewrite: &ScanRewriteResult,
    newly_added_columns: &[ColumnCatalog],
    removed_column_ids: &HashSet<ColumnId>,
    upstream_table_name: &str,
) -> MetaResult<()> {
    let PbNodeBody::Project(project_node_body) = project_node.node_body.as_mut().unwrap() else {
        return Err(anyhow!(
            "expect PbNodeBody::Project but got: {:?}",
            project_node.node_body
        )
        .into());
    };
    // With dropped columns, only pure `InputRef` select lists can be rewritten safely.
    let has_non_input_ref = project_node_body
        .select_list
        .iter()
        .any(|expr| !matches!(expr.rex_node, Some(expr_node::RexNode::InputRef(_))));
    if has_non_input_ref && !removed_column_ids.is_empty() {
        return Err(anyhow!(
            "auto schema change with drop column only supports Project with InputRef"
        )
        .into());
    }

    let mut new_select_list = Vec::with_capacity(project_node_body.select_list.len());
    let mut new_project_fields = Vec::with_capacity(project_node.fields.len());
    for (index, expr) in project_node_body.select_list.iter().enumerate() {
        let mut new_expr = expr.clone();
        if let Some(expr_node::RexNode::InputRef(old_index)) = new_expr.rex_node {
            // Drop the expression (and its field) if its input column no longer exists.
            let Some(&new_index) = scan_rewrite
                .old_output_index_to_new_output_index
                .get(&old_index)
            else {
                continue;
            };
            new_expr.rex_node = Some(expr_node::RexNode::InputRef(new_index));
        } else if !removed_column_ids.is_empty() {
            return Err(anyhow!(
                "auto schema change with drop column only supports Project with InputRef"
            )
            .into());
        }
        new_select_list.push(new_expr);
        new_project_fields.push(project_node.fields[index].clone());
    }

    // Append an `InputRef` (and a matching field) for each newly added column.
    for col in newly_added_columns {
        let Some(&new_index) = scan_rewrite
            .new_output_index_by_column_id
            .get(&col.column_id())
        else {
            return Err(anyhow!("new column id not found in scan output").into());
        };
        new_select_list.push(PbExprNode {
            function_type: expr_node::Type::Unspecified as i32,
            return_type: Some(col.data_type().to_protobuf()),
            rex_node: Some(expr_node::RexNode::InputRef(new_index)),
        });
        new_project_fields.push(
            Field::new(
                format!("{}.{}", upstream_table_name, col.column_desc.name),
                col.data_type().clone(),
            )
            .to_prost(),
        );
    }

    project_node_body.select_list = new_select_list;
    project_node.fields = new_project_fields;
    Ok(())
}
736
/// Rewrite a sink fragment to reflect a schema change of the upstream table
/// ("refresh schema" / auto schema change): drop `removed_columns` and append
/// `newly_added_columns` throughout the sink executor, its optional log-store
/// table, and the `StreamScan` (and optional `Project`) nodes feeding it.
///
/// Returns the rewritten fragment (a clone of `original_sink_fragment` with fresh
/// ids), the new sink column catalogs, and the rewritten log-store table, if any.
pub fn rewrite_refresh_schema_sink_fragment(
    original_sink_fragment: &Fragment,
    sink: &PbSink,
    newly_added_columns: &[ColumnCatalog],
    removed_columns: &[ColumnCatalog],
    upstream_table: &PbTable,
    upstream_table_fragment_id: FragmentId,
    id_generator_manager: &IdGeneratorManager,
) -> MetaResult<(Fragment, Vec<PbColumnCatalog>, Option<PbTable>)> {
    let removed_column_ids: HashSet<_> =
        removed_columns.iter().map(|col| col.column_id()).collect();
    // Log-store columns are named `<upstream_table>_<column>`; precompute the removed names.
    let removed_log_store_column_names: HashSet<_> = removed_columns
        .iter()
        .map(|col| format!("{}_{}", upstream_table.name, col.column_desc.name))
        .collect();
    let new_sink_columns = build_new_sink_columns(sink, &removed_column_ids, newly_added_columns);

    // Work on a clone of the fragment with freshly generated ids.
    let mut new_sink_fragment = clone_fragment(original_sink_fragment, id_generator_manager);
    let sink_node = &mut new_sink_fragment.nodes;
    let PbNodeBody::Sink(sink_node_body) = sink_node.node_body.as_mut().unwrap() else {
        return Err(anyhow!("expect PbNodeBody::Sink but got: {:?}", sink_node.node_body).into());
    };
    let [stream_input_node] = sink_node.input.as_mut_slice() else {
        panic!("Sink has more than 1 input: {:?}", sink_node.input);
    };
    // The sink's single input is either a `StreamScan` directly, or a `Project` on top of one.
    let stream_input_body = stream_input_node.node_body.as_ref().unwrap();
    let stream_input_is_project = matches!(stream_input_body, PbNodeBody::Project(_));
    let stream_input_is_scan = matches!(stream_input_body, PbNodeBody::StreamScan(_));
    if !stream_input_is_project && !stream_input_is_scan {
        return Err(anyhow!(
            "expect PbNodeBody::StreamScan or PbNodeBody::Project but got: {:?}",
            stream_input_body
        )
        .into());
    }

    // update sink_node
    // following logic in <StreamSink as Explain>::distill
    sink_node.identity = {
        let sink_type = SinkType::from_proto(sink.sink_type());
        let sink_type_str = sink_type.type_str();
        let column_names = new_sink_columns
            .iter()
            .map(|col| {
                ColumnCatalog::from(col.clone())
                    .name_with_hidden()
                    .to_string()
            })
            .join(", ");
        // Non-append-only sinks also display their downstream pk in the identity string.
        let downstream_pk = if !sink_type.is_append_only() {
            let downstream_pk = sink
                .downstream_pk
                .iter()
                .map(|i| &sink.columns[*i as usize].column_desc.as_ref().unwrap().name)
                .collect_vec();
            format!(", downstream_pk: {downstream_pk:?}")
        } else {
            "".to_owned()
        };
        format!("StreamSink {{ type: {sink_type_str}, columns: [{column_names}]{downstream_pk} }}")
    };
    // Rewrite the log-store table (present for sinks with a log store) to the new column set.
    let new_log_store_table = if let Some(log_store_table) = &mut sink_node_body.table {
        rewrite_log_store_table(
            log_store_table,
            &removed_log_store_column_names,
            newly_added_columns,
            &upstream_table.name,
        );
        Some(log_store_table.clone())
    } else {
        None
    };
    sink_node_body.sink_desc.as_mut().unwrap().column_catalogs = new_sink_columns.clone();

    // Locate the `StreamScan` node: it is one level deeper when a `Project` is in between.
    let stream_scan_node = if stream_input_is_project {
        let [stream_scan_node] = stream_input_node.input.as_mut_slice() else {
            return Err(anyhow!(
                "Project node must have exactly 1 input for auto schema change, but got {:?}",
                stream_input_node.input.len()
            )
            .into());
        };
        stream_scan_node
    } else {
        stream_input_node
    };
    let scan_rewrite = rewrite_stream_scan_and_merge(
        stream_scan_node,
        &removed_column_ids,
        newly_added_columns,
        upstream_table,
        upstream_table_fragment_id,
    )?;

    // Propagate the new output fields: through the `Project` node when present,
    // otherwise directly from the rewritten scan.
    if stream_input_is_project {
        let [project_node] = sink_node.input.as_mut_slice() else {
            panic!("Sink has more than 1 input: {:?}", sink_node.input);
        };
        rewrite_project_node(
            project_node,
            &scan_rewrite,
            newly_added_columns,
            &removed_column_ids,
            &upstream_table.name,
        )?;
        sink_node.fields = project_node.fields.clone();
    } else {
        sink_node.fields = scan_rewrite.output_fields;
    }
    Ok((new_sink_fragment, new_sink_columns, new_log_store_table))
}
848
/// Adjacency list (G) of backfill orders.
///
/// `G[10] -> [1, 2, 11]` means the backfill node in `fragment 10`
/// should be backfilled before the backfill nodes in `fragment 1, 2 and 11`.
///
/// The `EXTENDED` const parameter is a type-level marker distinguishing the
/// user-specified order ([`UserDefinedFragmentBackfillOrder`]) from one extended
/// with locality-backfill dependencies ([`ExtendedFragmentBackfillOrder`]).
#[derive(Clone, Debug, Default)]
pub struct FragmentBackfillOrder<const EXTENDED: bool> {
    /// fragment id -> fragment ids that must backfill after it.
    inner: HashMap<FragmentId, Vec<FragmentId>>,
}
857
impl<const EXTENDED: bool> Deref for FragmentBackfillOrder<EXTENDED> {
    type Target = HashMap<FragmentId, Vec<FragmentId>>;

    // Read-only access to the underlying adjacency list.
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
865
866impl UserDefinedFragmentBackfillOrder {
867    pub fn new(inner: HashMap<FragmentId, Vec<FragmentId>>) -> Self {
868        Self { inner }
869    }
870
871    pub fn merge(orders: impl Iterator<Item = Self>) -> Self {
872        Self {
873            inner: orders.flat_map(|order| order.inner).collect(),
874        }
875    }
876
877    pub fn to_meta_model(&self) -> BackfillOrders {
878        self.inner.clone().into()
879    }
880}
881
/// Backfill order as specified via the job's `BackfillOrder` from the frontend, before extension.
pub type UserDefinedFragmentBackfillOrder = FragmentBackfillOrder<false>;
/// Backfill order after being extended with locality-provider dependencies.
pub type ExtendedFragmentBackfillOrder = FragmentBackfillOrder<true>;
884
/// In-memory representation of a **Fragment** Graph, built from the [`StreamFragmentGraphProto`]
/// from the frontend.
///
/// This only includes nodes and edges of the current job itself. It will be converted to [`CompleteStreamFragmentGraph`] later,
/// that contains the additional information of pre-existing
/// fragments, which are connected to the graph's top-most or bottom-most fragments.
#[derive(Default, Debug)]
pub struct StreamFragmentGraph {
    /// stores all the fragments in the graph, keyed by their global fragment id.
    pub(super) fragments: HashMap<GlobalFragmentId, BuildingFragment>,

    /// stores edges between fragments: upstream => downstream.
    pub(super) downstreams:
        HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,

    /// stores edges between fragments: downstream -> upstream.
    /// Mirror of `downstreams`; both maps are kept in sync when edges are added.
    pub(super) upstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,

    /// Dependent relations of this job, resolved by the frontend.
    dependent_table_ids: HashSet<TableId>,

    /// The default parallelism of the job, specified by the `STREAMING_PARALLELISM` session
    /// variable. If not specified, all active worker slots will be used.
    specified_parallelism: Option<NonZeroUsize>,
    /// The parallelism to use during backfill, specified by the `STREAMING_PARALLELISM_FOR_BACKFILL`
    /// session variable. If not specified, falls back to `specified_parallelism`.
    specified_backfill_parallelism: Option<NonZeroUsize>,

    /// Specified max parallelism, i.e., expected vnode count for the graph.
    ///
    /// The scheduler on the meta service will use this as a hint to decide the vnode count
    /// for each fragment.
    ///
    /// Note that the actual vnode count may be different from this value.
    /// For example, a no-shuffle exchange between current fragment graph and an existing
    /// upstream fragment graph requires two fragments to be in the same distribution,
    /// thus the same vnode count.
    max_parallelism: usize,

    /// The backfill ordering strategy of the graph, as provided by the frontend
    /// (empty order if none was specified).
    backfill_order: BackfillOrder,
}
927
928impl StreamFragmentGraph {
    /// Create a new [`StreamFragmentGraph`] from the given [`StreamFragmentGraphProto`], with all
    /// global IDs correctly filled.
    ///
    /// Translates the frontend-local fragment/table ids in the proto into globally unique
    /// ids from the meta id generator, and materializes the edges into both directions.
    pub fn new(
        env: &MetaSrvEnv,
        proto: StreamFragmentGraphProto,
        job: &StreamingJob,
    ) -> MetaResult<Self> {
        let fragment_id_gen =
            GlobalFragmentIdGen::new(env.id_gen_manager(), proto.fragments.len() as u64);
        // Note: in SQL backend, the ids generated here are fake and will be overwritten again
        // with `refill_internal_table_ids` later.
        // TODO: refactor the code to remove this step.
        let table_id_gen = GlobalTableIdGen::new(env.id_gen_manager(), proto.table_ids_cnt as u64);

        // Create nodes.
        let fragments: HashMap<_, _> = proto
            .fragments
            .into_iter()
            .map(|(id, fragment)| {
                let id = fragment_id_gen.to_global_id(id.as_raw_id());
                let fragment = BuildingFragment::new(id, fragment, job, table_id_gen);
                (id, fragment)
            })
            .collect();

        // Sanity check: the internal tables extracted from the fragments must match the
        // count declared in the proto.
        assert_eq!(
            fragments
                .values()
                .map(|f| f.extract_internal_tables().len() as u32)
                .sum::<u32>(),
            proto.table_ids_cnt
        );

        // Create edges.
        let mut downstreams = HashMap::new();
        let mut upstreams = HashMap::new();

        for edge in proto.edges {
            let upstream_id = fragment_id_gen.to_global_id(edge.upstream_id.as_raw_id());
            let downstream_id = fragment_id_gen.to_global_id(edge.downstream_id.as_raw_id());
            let edge = StreamFragmentEdge::from_protobuf(&edge);

            // `try_insert(..).unwrap()` panics on duplicates: at most one edge may exist
            // between any pair of fragments.
            upstreams
                .entry(downstream_id)
                .or_insert_with(HashMap::new)
                .try_insert(upstream_id, edge.clone())
                .unwrap();
            downstreams
                .entry(upstream_id)
                .or_insert_with(HashMap::new)
                .try_insert(downstream_id, edge)
                .unwrap();
        }

        // Note: Here we directly use the field `dependent_table_ids` in the proto (resolved in
        // frontend), instead of visiting the graph ourselves.
        let dependent_table_ids = proto.dependent_table_ids.iter().copied().collect();

        // A specified parallelism of 0 is invalid and rejected here.
        let specified_parallelism = if let Some(Parallelism { parallelism }) = proto.parallelism {
            Some(NonZeroUsize::new(parallelism as usize).context("parallelism should not be 0")?)
        } else {
            None
        };
        let specified_backfill_parallelism =
            if let Some(Parallelism { parallelism }) = proto.backfill_parallelism {
                Some(
                    NonZeroUsize::new(parallelism as usize)
                        .context("backfill parallelism should not be 0")?,
                )
            } else {
                None
            };

        let max_parallelism = proto.max_parallelism as usize;
        // Default to an empty backfill order when the frontend didn't specify one.
        let backfill_order = proto.backfill_order.unwrap_or(BackfillOrder {
            order: Default::default(),
        });

        Ok(Self {
            fragments,
            downstreams,
            upstreams,
            dependent_table_ids,
            specified_parallelism,
            specified_backfill_parallelism,
            max_parallelism,
            backfill_order,
        })
    }
1018
1019    /// Retrieve the **incomplete** internal tables map of the whole graph.
1020    ///
1021    /// Note that some fields in the table catalogs are not filled during the current phase, e.g.,
1022    /// `fragment_id`, `vnode_count`. They will be all filled after a `TableFragments` is built.
1023    /// Be careful when using the returned values.
1024    pub fn incomplete_internal_tables(&self) -> BTreeMap<TableId, Table> {
1025        let mut tables = BTreeMap::new();
1026        for fragment in self.fragments.values() {
1027            for table in fragment.extract_internal_tables() {
1028                let table_id = table.id;
1029                tables
1030                    .try_insert(table_id, table)
1031                    .unwrap_or_else(|_| panic!("duplicated table id `{}`", table_id));
1032            }
1033        }
1034        tables
1035    }
1036
1037    /// Refill the internal tables' `table_id`s according to the given map, typically obtained from
1038    /// `create_internal_table_catalog`.
1039    pub fn refill_internal_table_ids(&mut self, table_id_map: HashMap<TableId, TableId>) {
1040        for fragment in self.fragments.values_mut() {
1041            stream_graph_visitor::visit_internal_tables(
1042                &mut fragment.inner,
1043                |table, _table_type_name| {
1044                    let target = table_id_map.get(&table.id).cloned().unwrap();
1045                    table.id = target;
1046                },
1047            );
1048        }
1049    }
1050
1051    /// Use a trivial algorithm to match the internal tables of the new graph for
1052    /// `ALTER TABLE` or `ALTER SOURCE`.
1053    pub fn fit_internal_tables_trivial(
1054        &mut self,
1055        mut old_internal_tables: Vec<Table>,
1056    ) -> MetaResult<()> {
1057        let mut new_internal_table_ids = Vec::new();
1058        for fragment in self.fragments.values() {
1059            for table in &fragment.extract_internal_tables() {
1060                new_internal_table_ids.push(table.id);
1061            }
1062        }
1063
1064        if new_internal_table_ids.len() != old_internal_tables.len() {
1065            bail!(
1066                "Different number of internal tables. New: {}, Old: {}",
1067                new_internal_table_ids.len(),
1068                old_internal_tables.len()
1069            );
1070        }
1071        old_internal_tables.sort_by(|a, b| a.id.cmp(&b.id));
1072        new_internal_table_ids.sort();
1073
1074        let internal_table_id_map = new_internal_table_ids
1075            .into_iter()
1076            .zip_eq_fast(old_internal_tables.into_iter())
1077            .collect::<HashMap<_, _>>();
1078
1079        // TODO(alter-mv): unify this with `fit_internal_table_ids_with_mapping` after we
1080        // confirm the behavior is the same.
1081        for fragment in self.fragments.values_mut() {
1082            stream_graph_visitor::visit_internal_tables(
1083                &mut fragment.inner,
1084                |table, _table_type_name| {
1085                    // XXX: this replaces the entire table, instead of just the id!
1086                    let target = internal_table_id_map.get(&table.id).cloned().unwrap();
1087                    *table = target;
1088                },
1089            );
1090        }
1091
1092        Ok(())
1093    }
1094
1095    /// Fit the internal tables' `table_id`s according to the given mapping.
1096    pub fn fit_internal_table_ids_with_mapping(&mut self, mut matches: HashMap<TableId, Table>) {
1097        for fragment in self.fragments.values_mut() {
1098            stream_graph_visitor::visit_internal_tables(
1099                &mut fragment.inner,
1100                |table, _table_type_name| {
1101                    let target = matches.remove(&table.id).unwrap_or_else(|| {
1102                        panic!("no matching table for table {}({})", table.id, table.name)
1103                    });
1104                    table.id = target.id;
1105                    table.maybe_vnode_count = target.maybe_vnode_count;
1106                },
1107            );
1108        }
1109    }
1110
    /// Fill in the `snapshot_backfill_epoch` of every (cross-db) snapshot-backfill
    /// `StreamScan` node, looked up by the node's operator id.
    ///
    /// Panics if no epoch is provided for some snapshot-backfill scan node.
    pub fn fit_snapshot_backfill_epochs(
        &mut self,
        mut snapshot_backfill_epochs: HashMap<StreamNodeLocalOperatorId, u64>,
    ) {
        for fragment in self.fragments.values_mut() {
            visit_stream_node_cont_mut(fragment.node.as_mut().unwrap(), |node| {
                // Only snapshot-backfill (including cross-db) stream scans carry an epoch.
                if let PbNodeBody::StreamScan(scan) = node.node_body.as_mut().unwrap()
                    && let StreamScanType::SnapshotBackfill
                    | StreamScanType::CrossDbSnapshotBackfill = scan.stream_scan_type()
                {
                    let Some(epoch) = snapshot_backfill_epochs.remove(&node.operator_id) else {
                        panic!("no snapshot epoch found for node {:?}", node)
                    };
                    scan.snapshot_backfill_epoch = Some(epoch);
                }
                true // always continue the traversal
            })
        }
    }
1130
1131    /// Returns the fragment id where the streaming job node located.
1132    pub fn table_fragment_id(&self) -> FragmentId {
1133        self.fragments
1134            .values()
1135            .filter(|b| b.job_id.is_some())
1136            .map(|b| b.fragment_id)
1137            .exactly_one()
1138            .expect("require exactly 1 materialize/sink/cdc source node when creating the streaming job")
1139    }
1140
1141    /// Returns the fragment id where the table dml is received.
1142    pub fn dml_fragment_id(&self) -> Option<FragmentId> {
1143        self.fragments
1144            .values()
1145            .filter(|b| {
1146                FragmentTypeMask::from(b.fragment_type_mask).contains(FragmentTypeFlag::Dml)
1147            })
1148            .map(|b| b.fragment_id)
1149            .at_most_one()
1150            .expect("require at most 1 dml node when creating the streaming job")
1151    }
1152
    /// Get the dependent streaming job ids of this job.
    ///
    /// Resolved by the frontend and carried in the proto; see [`Self::new`].
    pub fn dependent_table_ids(&self) -> &HashSet<TableId> {
        &self.dependent_table_ids
    }
1157
    /// Get the parallelism of the job, if specified by the user.
    ///
    /// `None` means all active worker slots will be used.
    pub fn specified_parallelism(&self) -> Option<NonZeroUsize> {
        self.specified_parallelism
    }
1162
    /// Get the backfill parallelism of the job, if specified by the user.
    ///
    /// `None` means falling back to `specified_parallelism`.
    pub fn specified_backfill_parallelism(&self) -> Option<NonZeroUsize> {
        self.specified_backfill_parallelism
    }
1167
    /// Get the expected vnode count of the graph. See documentation of the field for more details.
    ///
    /// Note that this is only a scheduling hint; the actual vnode count may differ.
    pub fn max_parallelism(&self) -> usize {
        self.max_parallelism
    }
1172
    /// Get downstreams of a fragment.
    ///
    /// Falls back to a shared empty map when the fragment has no downstream edges.
    fn get_downstreams(
        &self,
        fragment_id: GlobalFragmentId,
    ) -> &HashMap<GlobalFragmentId, StreamFragmentEdge> {
        self.downstreams.get(&fragment_id).unwrap_or(&EMPTY_HASHMAP)
    }
1180
    /// Get upstreams of a fragment.
    ///
    /// Falls back to a shared empty map when the fragment has no upstream edges.
    fn get_upstreams(
        &self,
        fragment_id: GlobalFragmentId,
    ) -> &HashMap<GlobalFragmentId, StreamFragmentEdge> {
        self.upstreams.get(&fragment_id).unwrap_or(&EMPTY_HASHMAP)
    }
1188
    /// Collect snapshot-backfill info over all fragments of this graph.
    ///
    /// Delegates to [`Self::collect_snapshot_backfill_info_impl`]; see it for the meaning
    /// of the returned tuple.
    pub fn collect_snapshot_backfill_info(
        &self,
    ) -> MetaResult<(Option<SnapshotBackfillInfo>, SnapshotBackfillInfo)> {
        Self::collect_snapshot_backfill_info_impl(self.fragments.values().map(|fragment| {
            (
                fragment.node.as_ref().unwrap(),
                fragment.fragment_type_mask.into(),
            )
        }))
    }
1199
    /// Returns `Ok((Some(``snapshot_backfill_info``), ``cross_db_snapshot_backfill_info``))`
    ///
    /// Walks every `StreamScan` node of the given fragments. Cross-db snapshot-backfill
    /// scans are collected separately into the second tuple element. For the remaining
    /// scans, enforces the invariant that either all of them or none of them use
    /// snapshot backfill; otherwise returns an error.
    pub fn collect_snapshot_backfill_info_impl(
        fragments: impl IntoIterator<Item = (&PbStreamNode, FragmentTypeMask)>,
    ) -> MetaResult<(Option<SnapshotBackfillInfo>, SnapshotBackfillInfo)> {
        // Last non-cross-db stream scan seen (for error messages), paired with the
        // accumulated info (`Some` iff the scans seen so far are snapshot-backfill).
        let mut prev_stream_scan: Option<(Option<SnapshotBackfillInfo>, StreamScanNode)> = None;
        let mut cross_db_info = SnapshotBackfillInfo {
            upstream_mv_table_id_to_backfill_epoch: Default::default(),
        };
        // Deferred error: the visitor closure can only signal early-exit via `false`.
        let mut result = Ok(());
        for (node, fragment_type_mask) in fragments {
            visit_stream_node_cont(node, |node| {
                if let Some(NodeBody::StreamScan(stream_scan)) = node.node_body.as_ref() {
                    let stream_scan_type = StreamScanType::try_from(stream_scan.stream_scan_type)
                        .expect("invalid stream_scan_type");
                    let is_snapshot_backfill = match stream_scan_type {
                        StreamScanType::SnapshotBackfill => {
                            // The fragment type mask must agree with the scan type.
                            assert!(
                                fragment_type_mask
                                    .contains(FragmentTypeFlag::SnapshotBackfillStreamScan)
                            );
                            true
                        }
                        StreamScanType::CrossDbSnapshotBackfill => {
                            assert!(
                                fragment_type_mask
                                    .contains(FragmentTypeFlag::CrossDbSnapshotBackfillStreamScan)
                            );
                            // Cross-db scans are recorded separately and skip the
                            // all-or-nothing check below.
                            cross_db_info
                                .upstream_mv_table_id_to_backfill_epoch
                                .insert(stream_scan.table_id, stream_scan.snapshot_backfill_epoch);

                            return true;
                        }
                        _ => false,
                    };

                    match &mut prev_stream_scan {
                        Some((prev_snapshot_backfill_info, prev_stream_scan)) => {
                            match (prev_snapshot_backfill_info, is_snapshot_backfill) {
                                (Some(prev_snapshot_backfill_info), true) => {
                                    prev_snapshot_backfill_info
                                        .upstream_mv_table_id_to_backfill_epoch
                                        .insert(
                                            stream_scan.table_id,
                                            stream_scan.snapshot_backfill_epoch,
                                        );
                                    true
                                }
                                (None, false) => true,
                                (_, _) => {
                                    // Mixed snapshot/non-snapshot backfill: record the error
                                    // and stop visiting.
                                    result = Err(anyhow!("must be either all snapshot_backfill or no snapshot_backfill. Curr: {stream_scan:?} Prev: {prev_stream_scan:?}").into());
                                    false
                                }
                            }
                        }
                        None => {
                            // First stream scan seen: initialize the accumulator.
                            prev_stream_scan = Some((
                                if is_snapshot_backfill {
                                    Some(SnapshotBackfillInfo {
                                        upstream_mv_table_id_to_backfill_epoch: HashMap::from_iter(
                                            [(
                                                stream_scan.table_id,
                                                stream_scan.snapshot_backfill_epoch,
                                            )],
                                        ),
                                    })
                                } else {
                                    None
                                },
                                *stream_scan.clone(),
                            ));
                            true
                        }
                    }
                } else {
                    true
                }
            })
        }
        result.map(|_| {
            (
                prev_stream_scan
                    .map(|(snapshot_backfill_info, _)| snapshot_backfill_info)
                    .unwrap_or(None),
                cross_db_info,
            )
        })
    }
1288
1289    /// Collect the mapping from table / `source_id` -> `fragment_id`
1290    pub fn collect_backfill_mapping(
1291        fragments: impl Iterator<Item = (FragmentId, FragmentTypeMask, &PbStreamNode)>,
1292    ) -> HashMap<RelationId, Vec<FragmentId>> {
1293        let mut mapping = HashMap::new();
1294        for (fragment_id, fragment_type_mask, node) in fragments {
1295            let has_some_scan = fragment_type_mask
1296                .contains_any([FragmentTypeFlag::StreamScan, FragmentTypeFlag::SourceScan]);
1297            if has_some_scan {
1298                visit_stream_node_cont(node, |node| {
1299                    match node.node_body.as_ref() {
1300                        Some(NodeBody::StreamScan(stream_scan)) => {
1301                            let table_id = stream_scan.table_id;
1302                            let fragments: &mut Vec<_> =
1303                                mapping.entry(table_id.as_relation_id()).or_default();
1304                            fragments.push(fragment_id);
1305                            // each fragment should have only 1 scan node.
1306                            false
1307                        }
1308                        Some(NodeBody::SourceBackfill(source_backfill)) => {
1309                            let source_id = source_backfill.upstream_source_id;
1310                            let fragments: &mut Vec<_> =
1311                                mapping.entry(source_id.as_relation_id()).or_default();
1312                            fragments.push(fragment_id);
1313                            // each fragment should have only 1 scan node.
1314                            false
1315                        }
1316                        _ => true,
1317                    }
1318                })
1319            }
1320        }
1321        mapping
1322    }
1323
1324    /// Initially the mapping that comes from frontend is between `table_ids`.
1325    /// We should remap it to fragment level, since we track progress by actor, and we can get
1326    /// a fragment <-> actor mapping
1327    pub fn create_fragment_backfill_ordering(&self) -> UserDefinedFragmentBackfillOrder {
1328        let mapping =
1329            Self::collect_backfill_mapping(self.fragments.iter().map(|(fragment_id, fragment)| {
1330                (
1331                    fragment_id.as_global_id(),
1332                    fragment.fragment_type_mask.into(),
1333                    fragment.node.as_ref().expect("should exist node"),
1334                )
1335            }));
1336        let mut fragment_ordering: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
1337
1338        // 1. Add backfill dependencies
1339        for (rel_id, downstream_rel_ids) in &self.backfill_order.order {
1340            let fragment_ids = mapping.get(rel_id).unwrap();
1341            for fragment_id in fragment_ids {
1342                let downstream_fragment_ids = downstream_rel_ids
1343                    .data
1344                    .iter()
1345                    .flat_map(|&downstream_rel_id| mapping.get(&downstream_rel_id).unwrap().iter())
1346                    .copied()
1347                    .collect();
1348                fragment_ordering.insert(*fragment_id, downstream_fragment_ids);
1349            }
1350        }
1351
1352        UserDefinedFragmentBackfillOrder {
1353            inner: fragment_ordering,
1354        }
1355    }
1356
    /// Extend a user-defined backfill order with dependencies induced by
    /// `LocalityProvider` fragments: all backfill fragments must finish before the
    /// root `LocalityProvider` fragments start, and `LocalityProvider` fragments
    /// additionally respect their own upstream/downstream ordering.
    ///
    /// `get_fragments` is a factory because the fragment iterator is consumed twice:
    /// once for the backfill mapping and once for locality-provider discovery.
    pub fn extend_fragment_backfill_ordering_with_locality_backfill<
        'a,
        FI: Iterator<Item = (FragmentId, FragmentTypeMask, &'a PbStreamNode)> + 'a,
    >(
        fragment_ordering: UserDefinedFragmentBackfillOrder,
        fragment_downstreams: &FragmentDownstreamRelation,
        get_fragments: impl Fn() -> FI,
    ) -> ExtendedFragmentBackfillOrder {
        let mut fragment_ordering = fragment_ordering.inner;
        let mapping = Self::collect_backfill_mapping(get_fragments());
        // If no backfill order is specified, we still need to ensure that all backfill fragments
        // run before LocalityProvider fragments.
        if fragment_ordering.is_empty() {
            for value in mapping.values() {
                for &fragment_id in value {
                    fragment_ordering.entry(fragment_id).or_default();
                }
            }
        }

        // 2. Add dependencies: all backfill fragments should run before LocalityProvider fragments
        let locality_provider_dependencies = Self::find_locality_provider_dependencies(
            get_fragments().map(|(fragment_id, _, node)| (fragment_id, node)),
            fragment_downstreams,
        );

        let backfill_fragments: HashSet<FragmentId> = mapping.values().flatten().copied().collect();

        // Calculate LocalityProvider root fragments (zero indegree)
        // Root fragments are those that appear as keys but never appear as downstream dependencies
        let all_locality_provider_fragments: HashSet<FragmentId> =
            locality_provider_dependencies.keys().copied().collect();
        let downstream_locality_provider_fragments: HashSet<FragmentId> =
            locality_provider_dependencies
                .values()
                .flatten()
                .copied()
                .collect();
        let locality_provider_root_fragments: Vec<FragmentId> = all_locality_provider_fragments
            .difference(&downstream_locality_provider_fragments)
            .copied()
            .collect();

        // For each backfill fragment, add only the root LocalityProvider fragments as dependents
        // This ensures backfill completes before any LocalityProvider starts, while minimizing dependencies
        for &backfill_fragment_id in &backfill_fragments {
            fragment_ordering
                .entry(backfill_fragment_id)
                .or_default()
                .extend(locality_provider_root_fragments.iter().copied());
        }

        // 3. Add LocalityProvider internal dependencies
        for (fragment_id, downstream_fragments) in locality_provider_dependencies {
            fragment_ordering
                .entry(fragment_id)
                .or_default()
                .extend(downstream_fragments);
        }

        // Deduplicate downstream entries per fragment; overlaps are common when the same fragment
        // is reached via multiple paths (e.g., with StreamShare) and would otherwise appear
        // multiple times.
        for downstream in fragment_ordering.values_mut() {
            let mut seen = HashSet::new();
            downstream.retain(|id| seen.insert(*id));
        }

        ExtendedFragmentBackfillOrder {
            inner: fragment_ordering,
        }
    }
1429
1430    pub fn find_locality_provider_fragment_state_table_mapping(
1431        &self,
1432    ) -> HashMap<FragmentId, Vec<TableId>> {
1433        let mut mapping: HashMap<FragmentId, Vec<TableId>> = HashMap::new();
1434
1435        for (fragment_id, fragment) in &self.fragments {
1436            let fragment_id = fragment_id.as_global_id();
1437
1438            // Check if this fragment contains a LocalityProvider node
1439            if let Some(node) = fragment.node.as_ref() {
1440                let mut state_table_ids = Vec::new();
1441
1442                visit_stream_node_cont(node, |stream_node| {
1443                    if let Some(NodeBody::LocalityProvider(locality_provider)) =
1444                        stream_node.node_body.as_ref()
1445                    {
1446                        // Collect state table ID (except the progress table)
1447                        let state_table_id = locality_provider
1448                            .state_table
1449                            .as_ref()
1450                            .expect("must have state table")
1451                            .id;
1452                        state_table_ids.push(state_table_id);
1453                        false // Stop visiting once we find a LocalityProvider
1454                    } else {
1455                        true // Continue visiting
1456                    }
1457                });
1458
1459                if !state_table_ids.is_empty() {
1460                    mapping.insert(fragment_id, state_table_ids);
1461                }
1462            }
1463        }
1464
1465        mapping
1466    }
1467
1468    /// Find dependency relationships among fragments containing `LocalityProvider` nodes.
1469    /// Returns a mapping where each fragment ID maps to a list of fragment IDs that should be processed after it.
1470    /// Following the same semantics as `FragmentBackfillOrder`:
1471    /// `G[10] -> [1, 2, 11]` means `LocalityProvider` in fragment 10 should be processed
1472    /// before `LocalityProviders` in fragments 1, 2, and 11.
1473    ///
1474    /// This method assumes each fragment contains at most one `LocalityProvider` node.
1475    pub fn find_locality_provider_dependencies<'a>(
1476        fragments_nodes: impl Iterator<Item = (FragmentId, &'a PbStreamNode)>,
1477        fragment_downstreams: &FragmentDownstreamRelation,
1478    ) -> HashMap<FragmentId, Vec<FragmentId>> {
1479        let mut locality_provider_fragments = HashSet::new();
1480        let mut dependencies: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
1481
1482        // First, identify all fragments that contain LocalityProvider nodes
1483        for (fragment_id, node) in fragments_nodes {
1484            let has_locality_provider = Self::fragment_has_locality_provider(node);
1485
1486            if has_locality_provider {
1487                locality_provider_fragments.insert(fragment_id);
1488                dependencies.entry(fragment_id).or_default();
1489            }
1490        }
1491
1492        // Build dependency relationships between LocalityProvider fragments
1493        // For each LocalityProvider fragment, find all downstream LocalityProvider fragments
1494        // The upstream fragment should be processed before the downstream fragments
1495        for &provider_fragment_id in &locality_provider_fragments {
1496            // Find all fragments downstream from this LocalityProvider fragment
1497            let mut visited = HashSet::new();
1498            let mut downstream_locality_providers = Vec::new();
1499
1500            Self::collect_downstream_locality_providers(
1501                provider_fragment_id,
1502                &locality_provider_fragments,
1503                fragment_downstreams,
1504                &mut visited,
1505                &mut downstream_locality_providers,
1506            );
1507
1508            // This fragment should be processed before all its downstream LocalityProvider fragments
1509            dependencies
1510                .entry(provider_fragment_id)
1511                .or_default()
1512                .extend(downstream_locality_providers);
1513        }
1514
1515        dependencies
1516    }
1517
1518    fn fragment_has_locality_provider(node: &PbStreamNode) -> bool {
1519        let mut has_locality_provider = false;
1520
1521        {
1522            visit_stream_node_cont(node, |stream_node| {
1523                if let Some(NodeBody::LocalityProvider(_)) = stream_node.node_body.as_ref() {
1524                    has_locality_provider = true;
1525                    false // Stop visiting once we find a LocalityProvider
1526                } else {
1527                    true // Continue visiting
1528                }
1529            });
1530        }
1531
1532        has_locality_provider
1533    }
1534
1535    /// Recursively collect downstream `LocalityProvider` fragments
1536    fn collect_downstream_locality_providers(
1537        current_fragment_id: FragmentId,
1538        locality_provider_fragments: &HashSet<FragmentId>,
1539        fragment_downstreams: &FragmentDownstreamRelation,
1540        visited: &mut HashSet<FragmentId>,
1541        downstream_providers: &mut Vec<FragmentId>,
1542    ) {
1543        if visited.contains(&current_fragment_id) {
1544            return;
1545        }
1546        visited.insert(current_fragment_id);
1547
1548        // Check all downstream fragments
1549        for downstream_fragment_id in fragment_downstreams
1550            .get(&current_fragment_id)
1551            .into_iter()
1552            .flat_map(|downstreams| {
1553                downstreams
1554                    .iter()
1555                    .map(|downstream| downstream.downstream_fragment_id)
1556            })
1557        {
1558            // If the downstream fragment is a LocalityProvider, add it to results
1559            if locality_provider_fragments.contains(&downstream_fragment_id) {
1560                downstream_providers.push(downstream_fragment_id);
1561            }
1562
1563            // Recursively check further downstream
1564            Self::collect_downstream_locality_providers(
1565                downstream_fragment_id,
1566                locality_provider_fragments,
1567                fragment_downstreams,
1568                visited,
1569                downstream_providers,
1570            );
1571        }
1572    }
1573}
1574
/// Fill snapshot epoch for `StreamScanNode` of `SnapshotBackfill`.
/// Return `true` when has change applied.
///
/// For every `StreamScan` node whose scan type is `SnapshotBackfill` or
/// `CrossDbSnapshotBackfill`, looks up the backfill epoch of its upstream table —
/// first in `cross_db_snapshot_backfill_info`, then falling back to
/// `snapshot_backfill_info` — and stores it into the node. Errors if the table is
/// not covered, the epoch is unset, or the node already had an epoch filled in.
pub fn fill_snapshot_backfill_epoch(
    node: &mut StreamNode,
    snapshot_backfill_info: Option<&SnapshotBackfillInfo>,
    cross_db_snapshot_backfill_info: &SnapshotBackfillInfo,
) -> MetaResult<bool> {
    // The visitor closure cannot propagate `?` directly, so the first error is
    // captured in `result` and the traversal is cut short via the `bool` return.
    let mut result = Ok(());
    let mut applied = false;
    visit_stream_node_cont_mut(node, |node| {
        if let Some(NodeBody::StreamScan(stream_scan)) = node.node_body.as_mut()
            && (stream_scan.stream_scan_type == StreamScanType::SnapshotBackfill as i32
                || stream_scan.stream_scan_type == StreamScanType::CrossDbSnapshotBackfill as i32)
        {
            result = try {
                let table_id = stream_scan.table_id;
                // Cross-db info takes precedence over same-db info for the same table.
                let snapshot_epoch = cross_db_snapshot_backfill_info
                    .upstream_mv_table_id_to_backfill_epoch
                    .get(&table_id)
                    .or_else(|| {
                        snapshot_backfill_info.and_then(|snapshot_backfill_info| {
                            snapshot_backfill_info
                                .upstream_mv_table_id_to_backfill_epoch
                                .get(&table_id)
                        })
                    })
                    .ok_or_else(|| anyhow!("upstream table id not covered: {}", table_id))?
                    .ok_or_else(|| anyhow!("upstream table id not set: {}", table_id))?;
                // `replace` yields the previous value; an already-set epoch is a bug.
                if let Some(prev_snapshot_epoch) =
                    stream_scan.snapshot_backfill_epoch.replace(snapshot_epoch)
                {
                    Err(anyhow!(
                        "snapshot backfill epoch set again: {} {} {}",
                        table_id,
                        prev_snapshot_epoch,
                        snapshot_epoch
                    ))?;
                }
                applied = true;
            };
            // Continue visiting only while no error has occurred.
            result.is_ok()
        } else {
            true
        }
    });
    result.map(|_| applied)
}
1622
/// A shared, lazily-initialized empty map of fragment edges.
// NOTE(review): presumably handed out as a cheap default when an edge lookup
// misses, avoiding a fresh allocation per call — confirm at the use sites
// (which are outside this chunk).
static EMPTY_HASHMAP: LazyLock<HashMap<GlobalFragmentId, StreamFragmentEdge>> =
    LazyLock::new(HashMap::new);
1625
/// A fragment that is either being built or already exists. Used for generalize the logic of
/// [`crate::stream::ActorGraphBuilder`].
#[derive(Debug, Clone, EnumAsInner)]
pub(super) enum EitherFragment {
    /// An internal fragment that is being built for the current streaming job.
    Building(BuildingFragment),

    /// An existing fragment that is external but connected to the fragments being built.
    /// Carries no payload: the fragment data itself is kept in
    /// `CompleteStreamFragmentGraph::existing_fragments`.
    Existing,
}
1636
/// A wrapper of [`StreamFragmentGraph`] that contains the additional information of pre-existing
/// fragments, which are connected to the graph's top-most or bottom-most fragments.
///
/// For example,
/// - if we're going to build a mview on an existing mview, the upstream fragment containing the
///   `Materialize` node will be included in this structure.
/// - if we're going to replace the plan of a table with downstream mviews, the downstream fragments
///   containing the `StreamScan` nodes will be included in this structure.
#[derive(Debug)]
pub struct CompleteStreamFragmentGraph {
    /// The fragment graph of the streaming job being built.
    building_graph: StreamFragmentGraph,

    /// The required information of existing fragments.
    existing_fragments: HashMap<GlobalFragmentId, Fragment>,

    /// Extra edges between existing fragments and the building fragments,
    /// keyed by the *upstream* fragment ID (downstream adjacency lists).
    extra_downstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,

    /// The reverse index of `extra_downstreams`: the same extra edges,
    /// keyed by the *downstream* fragment ID (upstream adjacency lists).
    extra_upstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,
}
1659
/// Context describing the pre-existing *upstream* fragments to attach to a building graph.
pub struct FragmentGraphUpstreamContext {
    /// Root fragment is the root of upstream stream graph, which can be a
    /// mview fragment or source fragment for cdc source job.
    ///
    /// Keyed by the ID of the upstream streaming job.
    pub upstream_root_fragments: HashMap<JobId, Fragment>,
}
1665
/// Context describing the pre-existing *downstream* fragments to attach to a building graph,
/// used when replacing the plan of an existing table/source.
pub struct FragmentGraphDownstreamContext {
    /// The root fragment ID of the job being replaced, which the downstream
    /// fragments were originally connected to.
    pub original_root_fragment_id: FragmentId,
    /// The downstream fragments (containing `StreamScan`/`SourceBackfill` nodes),
    /// each paired with the dispatcher type of the edge connecting it to the root.
    pub downstream_fragments: Vec<(DispatcherType, Fragment)>,
}
1670
impl CompleteStreamFragmentGraph {
    /// Create a new [`CompleteStreamFragmentGraph`] with empty existing fragments, i.e., there's no
    /// upstream mviews.
    #[cfg(test)]
    pub fn for_test(graph: StreamFragmentGraph) -> Self {
        Self {
            building_graph: graph,
            existing_fragments: Default::default(),
            extra_downstreams: Default::default(),
            extra_upstreams: Default::default(),
        }
    }

    /// Create a new [`CompleteStreamFragmentGraph`] for newly created job (which has no downstreams).
    /// e.g., MV on MV and CDC/Source Table with the upstream existing
    /// `Materialize` or `Source` fragments.
    pub fn with_upstreams(
        graph: StreamFragmentGraph,
        upstream_context: FragmentGraphUpstreamContext,
        job_type: StreamingJobType,
    ) -> MetaResult<Self> {
        Self::build_helper(graph, Some(upstream_context), None, job_type)
    }

    /// Create a new [`CompleteStreamFragmentGraph`] for replacing an existing table/source,
    /// with the downstream existing `StreamScan`/`StreamSourceScan` fragments.
    pub fn with_downstreams(
        graph: StreamFragmentGraph,
        downstream_context: FragmentGraphDownstreamContext,
        job_type: StreamingJobType,
    ) -> MetaResult<Self> {
        Self::build_helper(graph, None, Some(downstream_context), job_type)
    }

    /// For replacing an existing table based on shared cdc source, which has both upstreams and downstreams.
    pub fn with_upstreams_and_downstreams(
        graph: StreamFragmentGraph,
        upstream_context: FragmentGraphUpstreamContext,
        downstream_context: FragmentGraphDownstreamContext,
        job_type: StreamingJobType,
    ) -> MetaResult<Self> {
        Self::build_helper(
            graph,
            Some(upstream_context),
            Some(downstream_context),
            job_type,
        )
    }

    /// The core logic of building a [`CompleteStreamFragmentGraph`], i.e., adding extra upstream/downstream fragments.
    ///
    /// For each provided context, this derives the dispatch strategy and output-column
    /// mapping of every external edge, records the edge in both `extra_downstreams`
    /// and `extra_upstreams`, and moves the external fragments into
    /// `existing_fragments`.
    fn build_helper(
        mut graph: StreamFragmentGraph,
        upstream_ctx: Option<FragmentGraphUpstreamContext>,
        downstream_ctx: Option<FragmentGraphDownstreamContext>,
        job_type: StreamingJobType,
    ) -> MetaResult<Self> {
        let mut extra_downstreams = HashMap::new();
        let mut extra_upstreams = HashMap::new();
        let mut existing_fragments = HashMap::new();

        if let Some(FragmentGraphUpstreamContext {
            upstream_root_fragments,
        }) = upstream_ctx
        {
            // Connect each building fragment that requires columns from an upstream
            // job to the root fragment of that upstream job.
            for (&id, fragment) in &mut graph.fragments {
                let uses_shuffled_backfill = fragment.has_shuffled_backfill();

                for (&upstream_job_id, required_columns) in &fragment.upstream_job_columns {
                    let upstream_fragment = upstream_root_fragments
                        .get(&upstream_job_id)
                        .context("upstream fragment not found")?;
                    let upstream_root_fragment_id =
                        GlobalFragmentId::new(upstream_fragment.fragment_id);

                    let edge = match job_type {
                        StreamingJobType::Table(TableJobType::SharedCdcSource) => {
                            // we traverse all fragments in the graph, and we should find out the
                            // CdcFilter fragment and add an edge between upstream source fragment and it.
                            assert_ne!(
                                (fragment.fragment_type_mask & FragmentTypeFlag::CdcFilter as u32),
                                0
                            );

                            tracing::debug!(
                                ?upstream_root_fragment_id,
                                ?required_columns,
                                identity = ?fragment.inner.get_node().unwrap().get_identity(),
                                current_frag_id=?id,
                                "CdcFilter with upstream source fragment"
                            );

                            StreamFragmentEdge {
                                id: EdgeId::UpstreamExternal {
                                    upstream_job_id,
                                    downstream_fragment_id: id,
                                },
                                // We always use `NoShuffle` for the exchange between the upstream
                                // `Source` and the downstream `StreamScan` of the new cdc table.
                                dispatch_strategy: DispatchStrategy {
                                    r#type: DispatcherType::NoShuffle as _,
                                    dist_key_indices: vec![], // not used for `NoShuffle`
                                    output_mapping: DispatchOutputMapping::identical(
                                        CDC_SOURCE_COLUMN_NUM as _,
                                    )
                                    .into(),
                                },
                            }
                        }

                        // handle MV on MV/Source
                        StreamingJobType::MaterializedView
                        | StreamingJobType::Sink
                        | StreamingJobType::Index => {
                            // Build the extra edges between the upstream `Materialize` and
                            // the downstream `StreamScan` of the new job.
                            if upstream_fragment
                                .fragment_type_mask
                                .contains(FragmentTypeFlag::Mview)
                            {
                                // Resolve the required output columns from the upstream materialized view.
                                let (dist_key_indices, output_mapping) = {
                                    let mview_node = upstream_fragment
                                        .nodes
                                        .get_node_body()
                                        .unwrap()
                                        .as_materialize()
                                        .unwrap();
                                    let all_columns = mview_node.column_descs();
                                    let dist_key_indices = mview_node.dist_key_indices();
                                    let output_mapping = gen_output_mapping(
                                        required_columns,
                                        &all_columns,
                                    )
                                    .context(
                                        "BUG: column not found in the upstream materialized view",
                                    )?;
                                    (dist_key_indices, output_mapping)
                                };
                                let dispatch_strategy = mv_on_mv_dispatch_strategy(
                                    uses_shuffled_backfill,
                                    dist_key_indices,
                                    output_mapping,
                                );

                                StreamFragmentEdge {
                                    id: EdgeId::UpstreamExternal {
                                        upstream_job_id,
                                        downstream_fragment_id: id,
                                    },
                                    dispatch_strategy,
                                }
                            }
                            // Build the extra edges between the upstream `Source` and
                            // the downstream `SourceBackfill` of the new job.
                            else if upstream_fragment
                                .fragment_type_mask
                                .contains(FragmentTypeFlag::Source)
                            {
                                let output_mapping = {
                                    let source_node = upstream_fragment
                                        .nodes
                                        .get_node_body()
                                        .unwrap()
                                        .as_source()
                                        .unwrap();

                                    let all_columns = source_node.column_descs().unwrap();
                                    gen_output_mapping(required_columns, &all_columns).context(
                                        "BUG: column not found in the upstream source node",
                                    )?
                                };

                                StreamFragmentEdge {
                                    id: EdgeId::UpstreamExternal {
                                        upstream_job_id,
                                        downstream_fragment_id: id,
                                    },
                                    // We always use `NoShuffle` for the exchange between the upstream
                                    // `Source` and the downstream `StreamScan` of the new MV.
                                    dispatch_strategy: DispatchStrategy {
                                        r#type: DispatcherType::NoShuffle as _,
                                        dist_key_indices: vec![], // not used for `NoShuffle`
                                        output_mapping: Some(output_mapping),
                                    },
                                }
                            } else {
                                bail!(
                                    "the upstream fragment should be a MView or Source, got fragment type: {:b}",
                                    upstream_fragment.fragment_type_mask
                                )
                            }
                        }
                        StreamingJobType::Source | StreamingJobType::Table(_) => {
                            bail!(
                                "the streaming job shouldn't have an upstream fragment, job_type: {:?}",
                                job_type
                            )
                        }
                    };

                    // put the edge into the extra edges
                    // (recorded in both directions; `try_insert` panics on a duplicate,
                    // which would indicate a bug in edge construction)
                    extra_downstreams
                        .entry(upstream_root_fragment_id)
                        .or_insert_with(HashMap::new)
                        .try_insert(id, edge.clone())
                        .unwrap();
                    extra_upstreams
                        .entry(id)
                        .or_insert_with(HashMap::new)
                        .try_insert(upstream_root_fragment_id, edge)
                        .unwrap();
                }
            }

            existing_fragments.extend(
                upstream_root_fragments
                    .into_values()
                    .map(|f| (GlobalFragmentId::new(f.fragment_id), f)),
            );
        }

        if let Some(FragmentGraphDownstreamContext {
            original_root_fragment_id,
            downstream_fragments,
        }) = downstream_ctx
        {
            // The downstream fragments are re-attached to the root fragment of the
            // *new* (replacement) graph, identified by `table_fragment_id`.
            let original_table_fragment_id = GlobalFragmentId::new(original_root_fragment_id);
            let table_fragment_id = GlobalFragmentId::new(graph.table_fragment_id());

            // Build the extra edges between the `Materialize` and the downstream `StreamScan` of the
            // existing materialized views.
            for (dispatcher_type, fragment) in &downstream_fragments {
                let id = GlobalFragmentId::new(fragment.fragment_id);

                // Similar to `extract_upstream_columns_except_cross_db_backfill`.
                // The columns the downstream scan expects from its upstream.
                let output_columns = {
                    let mut res = None;

                    stream_graph_visitor::visit_stream_node_body(&fragment.nodes, |node_body| {
                        let columns = match node_body {
                            NodeBody::StreamScan(stream_scan) => stream_scan.upstream_columns(),
                            NodeBody::SourceBackfill(source_backfill) => {
                                // FIXME: only pass required columns instead of all columns here
                                source_backfill.column_descs()
                            }
                            _ => return,
                        };
                        res = Some(columns);
                    });

                    res.context("failed to locate downstream scan")?
                };

                let table_fragment = graph.fragments.get(&table_fragment_id).unwrap();
                let nodes = table_fragment.node.as_ref().unwrap();

                // Map the downstream's expected columns onto the replacement fragment's
                // output; failure means a referenced column was dropped.
                let (dist_key_indices, output_mapping) = match job_type {
                    StreamingJobType::Table(_) | StreamingJobType::MaterializedView => {
                        let mview_node = nodes.get_node_body().unwrap().as_materialize().unwrap();
                        let all_columns = mview_node.column_descs();
                        let dist_key_indices = mview_node.dist_key_indices();
                        let output_mapping = gen_output_mapping(&output_columns, &all_columns)
                            .ok_or_else(|| {
                                MetaError::invalid_parameter(
                                    "unable to drop the column due to \
                                     being referenced by downstream materialized views or sinks",
                                )
                            })?;
                        (dist_key_indices, output_mapping)
                    }

                    StreamingJobType::Source => {
                        let source_node = nodes.get_node_body().unwrap().as_source().unwrap();
                        let all_columns = source_node.column_descs().unwrap();
                        let output_mapping = gen_output_mapping(&output_columns, &all_columns)
                            .ok_or_else(|| {
                                MetaError::invalid_parameter(
                                    "unable to drop the column due to \
                                     being referenced by downstream materialized views or sinks",
                                )
                            })?;
                        assert_eq!(*dispatcher_type, DispatcherType::NoShuffle);
                        (
                            vec![], // not used for `NoShuffle`
                            output_mapping,
                        )
                    }

                    _ => bail!("unsupported job type for replacement: {job_type:?}"),
                };

                let edge = StreamFragmentEdge {
                    id: EdgeId::DownstreamExternal(DownstreamExternalEdgeId {
                        original_upstream_fragment_id: original_table_fragment_id,
                        downstream_fragment_id: id,
                    }),
                    dispatch_strategy: DispatchStrategy {
                        r#type: *dispatcher_type as i32,
                        output_mapping: Some(output_mapping),
                        dist_key_indices,
                    },
                };

                extra_downstreams
                    .entry(table_fragment_id)
                    .or_insert_with(HashMap::new)
                    .try_insert(id, edge.clone())
                    .unwrap();
                extra_upstreams
                    .entry(id)
                    .or_insert_with(HashMap::new)
                    .try_insert(table_fragment_id, edge)
                    .unwrap();
            }

            existing_fragments.extend(
                downstream_fragments
                    .into_iter()
                    .map(|(_, f)| (GlobalFragmentId::new(f.fragment_id), f)),
            );
        }

        Ok(Self {
            building_graph: graph,
            existing_fragments,
            extra_downstreams,
            extra_upstreams,
        })
    }
}
2001
2002/// Generate the `output_mapping` for [`DispatchStrategy`] from given columns.
2003fn gen_output_mapping(
2004    required_columns: &[PbColumnDesc],
2005    upstream_columns: &[PbColumnDesc],
2006) -> Option<DispatchOutputMapping> {
2007    let len = required_columns.len();
2008    let mut indices = vec![0; len];
2009    let mut types = None;
2010
2011    for (i, r) in required_columns.iter().enumerate() {
2012        let (ui, u) = upstream_columns
2013            .iter()
2014            .find_position(|&u| u.column_id == r.column_id)?;
2015        indices[i] = ui as u32;
2016
2017        // Only if we encounter type change (`ALTER TABLE ALTER COLUMN TYPE`) will we generate a
2018        // non-empty `types`.
2019        if u.column_type != r.column_type {
2020            types.get_or_insert_with(|| vec![TypePair::default(); len])[i] = TypePair {
2021                upstream: u.column_type.clone(),
2022                downstream: r.column_type.clone(),
2023            };
2024        }
2025    }
2026
2027    // If there's no type change, indicate it by empty `types`.
2028    let types = types.unwrap_or(Vec::new());
2029
2030    Some(DispatchOutputMapping { indices, types })
2031}
2032
2033fn mv_on_mv_dispatch_strategy(
2034    uses_shuffled_backfill: bool,
2035    dist_key_indices: Vec<u32>,
2036    output_mapping: DispatchOutputMapping,
2037) -> DispatchStrategy {
2038    if uses_shuffled_backfill {
2039        if !dist_key_indices.is_empty() {
2040            DispatchStrategy {
2041                r#type: DispatcherType::Hash as _,
2042                dist_key_indices,
2043                output_mapping: Some(output_mapping),
2044            }
2045        } else {
2046            DispatchStrategy {
2047                r#type: DispatcherType::Simple as _,
2048                dist_key_indices: vec![], // empty for Simple
2049                output_mapping: Some(output_mapping),
2050            }
2051        }
2052    } else {
2053        DispatchStrategy {
2054            r#type: DispatcherType::NoShuffle as _,
2055            dist_key_indices: vec![], // not used for `NoShuffle`
2056            output_mapping: Some(output_mapping),
2057        }
2058    }
2059}
2060
2061impl CompleteStreamFragmentGraph {
2062    /// Returns **all** fragment IDs in the complete graph, including the ones that are not in the
2063    /// building graph.
2064    pub(super) fn all_fragment_ids(&self) -> impl Iterator<Item = GlobalFragmentId> + '_ {
2065        self.building_graph
2066            .fragments
2067            .keys()
2068            .chain(self.existing_fragments.keys())
2069            .copied()
2070    }
2071
2072    /// Returns an iterator of **all** edges in the complete graph, including the external edges.
2073    pub(super) fn all_edges(
2074        &self,
2075    ) -> impl Iterator<Item = (GlobalFragmentId, GlobalFragmentId, &StreamFragmentEdge)> + '_ {
2076        self.building_graph
2077            .downstreams
2078            .iter()
2079            .chain(self.extra_downstreams.iter())
2080            .flat_map(|(&from, tos)| tos.iter().map(move |(&to, edge)| (from, to, edge)))
2081    }
2082
2083    /// Returns the distribution of the existing fragments.
2084    pub(super) fn existing_distribution(&self) -> HashMap<GlobalFragmentId, Distribution> {
2085        self.existing_fragments
2086            .iter()
2087            .map(|(&id, f)| (id, Distribution::from_fragment(f)))
2088            .collect()
2089    }
2090
2091    /// Generate topological order of **all** fragments in this graph, including the ones that are
2092    /// not in the building graph. Returns error if the graph is not a DAG and topological sort can
2093    /// not be done.
2094    ///
2095    /// For MV on MV, the first fragment popped out from the heap will be the top-most node, or the
2096    /// `Sink` / `Materialize` in stream graph.
2097    pub(super) fn topo_order(&self) -> MetaResult<Vec<GlobalFragmentId>> {
2098        let mut topo = Vec::new();
2099        let mut downstream_cnts = HashMap::new();
2100
2101        // Iterate all fragments.
2102        for fragment_id in self.all_fragment_ids() {
2103            // Count how many downstreams we have for a given fragment.
2104            let downstream_cnt = self.get_downstreams(fragment_id).count();
2105            if downstream_cnt == 0 {
2106                topo.push(fragment_id);
2107            } else {
2108                downstream_cnts.insert(fragment_id, downstream_cnt);
2109            }
2110        }
2111
2112        let mut i = 0;
2113        while let Some(&fragment_id) = topo.get(i) {
2114            i += 1;
2115            // Find if we can process more fragments.
2116            for (upstream_job_id, _) in self.get_upstreams(fragment_id) {
2117                let downstream_cnt = downstream_cnts.get_mut(&upstream_job_id).unwrap();
2118                *downstream_cnt -= 1;
2119                if *downstream_cnt == 0 {
2120                    downstream_cnts.remove(&upstream_job_id);
2121                    topo.push(upstream_job_id);
2122                }
2123            }
2124        }
2125
2126        if !downstream_cnts.is_empty() {
2127            // There are fragments that are not processed yet.
2128            bail!("graph is not a DAG");
2129        }
2130
2131        Ok(topo)
2132    }
2133
2134    /// Seal a [`BuildingFragment`] from the graph into a [`Fragment`], which will be further used
2135    /// to build actors on the compute nodes and persist into meta store.
2136    pub(super) fn seal_fragment(
2137        &self,
2138        id: GlobalFragmentId,
2139        distribution: Distribution,
2140        stream_node: StreamNode,
2141    ) -> Fragment {
2142        let building_fragment = self.get_fragment(id).into_building().unwrap();
2143        let internal_tables = building_fragment.extract_internal_tables();
2144        let BuildingFragment {
2145            inner,
2146            job_id,
2147            upstream_job_columns: _,
2148        } = building_fragment;
2149
2150        let distribution_type = distribution.to_distribution_type();
2151        let vnode_count = distribution.vnode_count();
2152
2153        let materialized_fragment_id =
2154            if FragmentTypeMask::from(inner.fragment_type_mask).contains(FragmentTypeFlag::Mview) {
2155                job_id.map(JobId::as_mv_table_id)
2156            } else {
2157                None
2158            };
2159
2160        let vector_index_fragment_id =
2161            if inner.fragment_type_mask & FragmentTypeFlag::VectorIndexWrite as u32 != 0 {
2162                job_id.map(JobId::as_mv_table_id)
2163            } else {
2164                None
2165            };
2166
2167        let state_table_ids = internal_tables
2168            .iter()
2169            .map(|t| t.id)
2170            .chain(materialized_fragment_id)
2171            .chain(vector_index_fragment_id)
2172            .collect();
2173
2174        Fragment {
2175            fragment_id: inner.fragment_id,
2176            fragment_type_mask: inner.fragment_type_mask.into(),
2177            distribution_type,
2178            state_table_ids,
2179            maybe_vnode_count: VnodeCount::set(vnode_count).to_protobuf(),
2180            nodes: stream_node,
2181        }
2182    }
2183
2184    /// Get a fragment from the complete graph, which can be either a building fragment or an
2185    /// existing fragment.
2186    pub(super) fn get_fragment(&self, fragment_id: GlobalFragmentId) -> EitherFragment {
2187        if self.existing_fragments.contains_key(&fragment_id) {
2188            EitherFragment::Existing
2189        } else {
2190            EitherFragment::Building(
2191                self.building_graph
2192                    .fragments
2193                    .get(&fragment_id)
2194                    .unwrap()
2195                    .clone(),
2196            )
2197        }
2198    }
2199
2200    /// Get **all** downstreams of a fragment, including the ones that are not in the building
2201    /// graph.
2202    pub(super) fn get_downstreams(
2203        &self,
2204        fragment_id: GlobalFragmentId,
2205    ) -> impl Iterator<Item = (GlobalFragmentId, &StreamFragmentEdge)> {
2206        self.building_graph
2207            .get_downstreams(fragment_id)
2208            .iter()
2209            .chain(
2210                self.extra_downstreams
2211                    .get(&fragment_id)
2212                    .into_iter()
2213                    .flatten(),
2214            )
2215            .map(|(&id, edge)| (id, edge))
2216    }
2217
2218    /// Get **all** upstreams of a fragment, including the ones that are not in the building
2219    /// graph.
2220    pub(super) fn get_upstreams(
2221        &self,
2222        fragment_id: GlobalFragmentId,
2223    ) -> impl Iterator<Item = (GlobalFragmentId, &StreamFragmentEdge)> {
2224        self.building_graph
2225            .get_upstreams(fragment_id)
2226            .iter()
2227            .chain(self.extra_upstreams.get(&fragment_id).into_iter().flatten())
2228            .map(|(&id, edge)| (id, edge))
2229    }
2230
    /// Returns all building fragments in the graph, keyed by their global fragment id.
    ///
    /// Existing (already-running) fragments are not included; see [`Self::get_fragment`].
    pub(super) fn building_fragments(&self) -> &HashMap<GlobalFragmentId, BuildingFragment> {
        &self.building_graph.fragments
    }
2235
    /// Returns all building fragments in the graph, mutable.
    ///
    /// Mutable counterpart of [`Self::building_fragments`].
    pub(super) fn building_fragments_mut(
        &mut self,
    ) -> &mut HashMap<GlobalFragmentId, BuildingFragment> {
        &mut self.building_graph.fragments
    }
2242
    /// Get the expected vnode count of the building graph. See documentation of the field for more details.
    ///
    /// Delegates directly to the underlying building graph.
    pub(super) fn max_parallelism(&self) -> usize {
        self.building_graph.max_parallelism()
    }
2247}
2248
#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{ColumnDesc, ColumnId};
    use risingwave_common::types::DataType;
    use risingwave_pb::catalog::SinkType as PbSinkType;
    use risingwave_pb::meta::table_fragments::fragment::PbFragmentDistributionType;
    use risingwave_pb::plan_common::StorageTableDesc;
    use risingwave_pb::stream_plan::{
        BatchPlanNode, MergeNode, ProjectNode, SinkDesc, SinkLogStoreType, SinkNode, StreamNode,
        StreamScanNode, StreamScanType,
    };

    use super::*;

    /// Builds a visible column catalog with the given name, id and data type.
    fn make_column(name: &str, id: i32, data_type: DataType) -> ColumnCatalog {
        ColumnCatalog::visible(ColumnDesc::named(name, ColumnId::new(id), data_type))
    }

    /// Builds a protobuf field named `<table>.<column>` with the column's data type,
    /// matching the naming convention used for scan output fields in these tests.
    fn make_field(table_name: &str, column: &ColumnCatalog) -> risingwave_pb::plan_common::Field {
        Field::new(
            format!("{}.{}", table_name, column.column_desc.name),
            column.data_type().clone(),
        )
        .to_prost()
    }

    /// Builds an `InputRef` expression node referring to input column `index`.
    fn make_input_ref(index: u32, data_type: &DataType) -> PbExprNode {
        PbExprNode {
            function_type: expr_node::Type::Unspecified as i32,
            return_type: Some(data_type.to_protobuf()),
            rex_node: Some(expr_node::RexNode::InputRef(index)),
        }
    }

    /// Builds a `StreamScan` node (arrangement backfill) over the given columns,
    /// with the conventional `Merge` + `BatchPlan` child inputs.
    fn make_stream_scan_node(
        table_name: &str,
        table_id: u32,
        columns: &[ColumnCatalog],
    ) -> StreamNode {
        // Merge input from the upstream fragment (placeholder fragment id 0).
        let merge_node = StreamNode {
            node_body: Some(NodeBody::Merge(Box::new(MergeNode {
                upstream_fragment_id: 0.into(),
                ..Default::default()
            }))),
            fields: columns
                .iter()
                .map(|col| make_field(table_name, col))
                .collect(),
            ..Default::default()
        };
        // Batch-plan input; contents are irrelevant for these tests.
        let batch_plan_node = StreamNode {
            node_body: Some(NodeBody::BatchPlan(Box::new(BatchPlanNode {
                ..Default::default()
            }))),
            ..Default::default()
        };
        // The scan outputs every column in order, backed by a versioned table desc.
        let stream_scan_node = StreamScanNode {
            table_id: table_id.into(),
            upstream_column_ids: columns.iter().map(|c| c.column_id().get_id()).collect(),
            output_indices: (0..columns.len()).map(|i| i as u32).collect(),
            stream_scan_type: StreamScanType::ArrangementBackfill as i32,
            table_desc: Some(StorageTableDesc {
                table_id: table_id.into(),
                columns: columns
                    .iter()
                    .map(|col| col.column_desc.to_protobuf())
                    .collect(),
                value_indices: (0..columns.len()).map(|i| i as u32).collect(),
                versioned: true,
                ..Default::default()
            }),
            ..Default::default()
        };
        StreamNode {
            node_body: Some(NodeBody::StreamScan(Box::new(stream_scan_node))),
            fields: columns
                .iter()
                .map(|col| make_field(table_name, col))
                .collect(),
            input: vec![merge_node, batch_plan_node],
            ..Default::default()
        }
    }

    /// Builds a pass-through `Project` node: one `InputRef` per column, in order,
    /// on top of the given input node.
    fn make_project_node(
        table_name: &str,
        columns: &[ColumnCatalog],
        input: StreamNode,
    ) -> StreamNode {
        let select_list = columns
            .iter()
            .enumerate()
            .map(|(i, col)| make_input_ref(i as u32, col.data_type()))
            .collect();
        StreamNode {
            node_body: Some(NodeBody::Project(Box::new(ProjectNode {
                select_list,
                ..Default::default()
            }))),
            fields: columns
                .iter()
                .map(|col| make_field(table_name, col))
                .collect(),
            input: vec![input],
            ..Default::default()
        }
    }

    /// Refresh-schema rewrite when the upstream table gained a column (`c`):
    /// the Sink -> Project -> StreamScan chain must be extended to carry the
    /// new column through the scan and the project.
    #[tokio::test]
    async fn test_rewrite_refresh_schema_sink_fragment_with_project() {
        let env = MetaSrvEnv::for_test().await;
        let id_gen_manager = env.id_gen_manager().as_ref();

        // Sink currently covers columns `a`, `b`; the upstream table also has `c`.
        let table_name = "t";
        let columns = vec![
            make_column("a", 1, DataType::Int64),
            make_column("b", 2, DataType::Int64),
        ];
        let new_column = make_column("c", 3, DataType::Varchar);

        let mut upstream_columns = columns.clone();
        upstream_columns.push(new_column.clone());
        let upstream_table = PbTable {
            name: table_name.to_owned(),
            columns: upstream_columns
                .iter()
                .map(|col| col.to_protobuf())
                .collect(),
            ..Default::default()
        };

        let sink = PbSink {
            columns: columns.iter().map(|col| col.to_protobuf()).collect(),
            sink_type: PbSinkType::AppendOnly as i32,
            ..Default::default()
        };

        let sink_desc = SinkDesc {
            sink_type: PbSinkType::AppendOnly as i32,
            column_catalogs: sink.columns.clone(),
            ..Default::default()
        };

        // Original plan: Sink <- Project <- StreamScan over `a`, `b`.
        let stream_scan_node = make_stream_scan_node(table_name, 1, &columns);
        let project_node = make_project_node(table_name, &columns, stream_scan_node);

        // Log-store table columns are prefixed with the table name.
        let log_store_table = PbTable {
            columns: columns
                .iter()
                .cloned()
                .map(|mut col| {
                    col.column_desc.name = format!("{}_{}", table_name, col.column_desc.name);
                    col.to_protobuf()
                })
                .collect(),
            value_indices: (0..columns.len()).map(|i| i as i32).collect(),
            ..Default::default()
        };

        let original_fragment = Fragment {
            fragment_id: 1.into(),
            fragment_type_mask: FragmentTypeMask::default(),
            distribution_type: PbFragmentDistributionType::Single,
            state_table_ids: vec![],
            maybe_vnode_count: None,
            nodes: StreamNode {
                node_body: Some(NodeBody::Sink(Box::new(SinkNode {
                    sink_desc: Some(sink_desc),
                    table: Some(log_store_table),
                    ..Default::default()
                }))),
                fields: columns
                    .iter()
                    .map(|col| make_field(table_name, col))
                    .collect(),
                input: vec![project_node],
                ..Default::default()
            },
        };

        // Rewrite with one added column (`c`) and no dropped columns.
        let (new_fragment, _, _) = rewrite_refresh_schema_sink_fragment(
            &original_fragment,
            &sink,
            std::slice::from_ref(&new_column),
            &[],
            &upstream_table,
            7.into(),
            id_gen_manager,
        )
        .unwrap();

        let sink_node = &new_fragment.nodes;
        let [project_node] = sink_node.input.as_slice() else {
            panic!("Sink has more than 1 input: {:?}", sink_node.input);
        };
        let PbNodeBody::Project(project_body) = project_node.node_body.as_ref().unwrap() else {
            panic!(
                "expect PbNodeBody::Project but got: {:?}",
                project_node.node_body
            );
        };
        // The project gains one expression, an `InputRef` to the appended column.
        assert_eq!(project_body.select_list.len(), columns.len() + 1);
        let last_expr = project_body.select_list.last().unwrap();
        assert!(
            matches!(last_expr.rex_node, Some(expr_node::RexNode::InputRef(idx)) if idx == columns.len() as u32)
        );
        assert_eq!(project_node.fields.len(), columns.len() + 1);

        let [stream_scan_node] = project_node.input.as_slice() else {
            panic!("Project has more than 1 input: {:?}", project_node.input);
        };
        let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_ref().unwrap() else {
            panic!(
                "expect PbNodeBody::StreamScan but got: {:?}",
                stream_scan_node.node_body
            );
        };
        // The scan now also reads and outputs the new column, last in order.
        assert_eq!(
            scan.upstream_column_ids.last().copied(),
            Some(new_column.column_id().get_id())
        );
        assert_eq!(
            scan.output_indices.last().copied(),
            Some(columns.len() as u32)
        );
        assert_eq!(
            stream_scan_node.fields.last().unwrap().name,
            format!("{}.{}", table_name, new_column.column_desc.name)
        );
    }

    /// Refresh-schema rewrite when the upstream table dropped a column (`tmp`):
    /// the column must disappear from the scan, the project, the sink schema,
    /// and the kv log-store table.
    #[tokio::test]
    async fn test_rewrite_refresh_schema_sink_fragment_drop_column_with_project() {
        let env = MetaSrvEnv::for_test().await;
        let id_gen_manager = env.id_gen_manager().as_ref();

        // Sink currently covers `a`, `b`, `tmp`; upstream keeps only `a`, `b`.
        let table_name = "t";
        let columns = vec![
            make_column("a", 1, DataType::Int64),
            make_column("b", 2, DataType::Int64),
            make_column("tmp", 3, DataType::Varchar),
        ];
        let removed_column = columns.last().unwrap().clone();
        let upstream_columns = columns[..2].to_vec();

        let upstream_table = PbTable {
            name: table_name.to_owned(),
            columns: upstream_columns
                .iter()
                .map(|col| col.to_protobuf())
                .collect(),
            ..Default::default()
        };

        let sink = PbSink {
            columns: columns.iter().map(|col| col.to_protobuf()).collect(),
            sink_type: PbSinkType::AppendOnly as i32,
            ..Default::default()
        };

        let sink_desc = SinkDesc {
            sink_type: PbSinkType::AppendOnly as i32,
            column_catalogs: sink.columns.clone(),
            ..Default::default()
        };

        // Original plan: Sink <- Project <- StreamScan over all three columns.
        let stream_scan_node = make_stream_scan_node(table_name, 1, &columns);
        let project_node = make_project_node(table_name, &columns, stream_scan_node);

        // Log-store table columns are prefixed with the table name.
        let log_store_table = PbTable {
            columns: columns
                .iter()
                .cloned()
                .map(|mut col| {
                    col.column_desc.name = format!("{}_{}", table_name, col.column_desc.name);
                    col.to_protobuf()
                })
                .collect(),
            value_indices: (0..columns.len()).map(|i| i as i32).collect(),
            ..Default::default()
        };

        // Unlike the add-column test, this sink uses a kv log store, so the
        // rewrite is also expected to produce an updated log-store table.
        let original_fragment = Fragment {
            fragment_id: 1.into(),
            fragment_type_mask: FragmentTypeMask::default(),
            distribution_type: PbFragmentDistributionType::Single,
            state_table_ids: vec![],
            maybe_vnode_count: None,
            nodes: StreamNode {
                node_body: Some(NodeBody::Sink(Box::new(SinkNode {
                    sink_desc: Some(sink_desc),
                    table: Some(log_store_table),
                    log_store_type: SinkLogStoreType::KvLogStore as i32,
                    ..Default::default()
                }))),
                fields: columns
                    .iter()
                    .map(|col| make_field(table_name, col))
                    .collect(),
                input: vec![project_node],
                ..Default::default()
            },
        };

        // Rewrite with no added columns and one dropped column (`tmp`).
        let (new_fragment, new_schema, new_log_store_table) = rewrite_refresh_schema_sink_fragment(
            &original_fragment,
            &sink,
            &[],
            std::slice::from_ref(&removed_column),
            &upstream_table,
            7.into(),
            id_gen_manager,
        )
        .unwrap();

        // The new sink schema no longer contains the dropped column.
        assert_eq!(new_schema.len(), 2);
        assert!(
            new_schema.iter().all(|col| {
                col.column_desc.as_ref().map(|desc| desc.name.as_str()) != Some("tmp")
            })
        );

        let sink_node = &new_fragment.nodes;
        let [project_node] = sink_node.input.as_slice() else {
            panic!("Sink has more than 1 input: {:?}", sink_node.input);
        };
        let PbNodeBody::Project(project_body) = project_node.node_body.as_ref().unwrap() else {
            panic!(
                "expect PbNodeBody::Project but got: {:?}",
                project_node.node_body
            );
        };
        // The project shrinks to the two surviving columns.
        assert_eq!(project_body.select_list.len(), 2);
        assert!(project_node.fields.iter().all(|f| !f.name.contains("tmp")));

        let [stream_scan_node] = project_node.input.as_slice() else {
            panic!("Project has more than 1 input: {:?}", project_node.input);
        };
        let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_ref().unwrap() else {
            panic!(
                "expect PbNodeBody::StreamScan but got: {:?}",
                stream_scan_node.node_body
            );
        };
        // The scan no longer references the dropped column.
        assert!(
            !scan
                .upstream_column_ids
                .iter()
                .any(|&id| id == removed_column.column_id().get_id())
        );
        assert!(
            stream_scan_node
                .fields
                .iter()
                .all(|f| !f.name.contains("tmp"))
        );

        // The kv log-store table is rewritten: dropped column removed and
        // value indices renumbered to stay dense.
        let new_log_store_table = new_log_store_table.expect("log store table should be updated");
        assert!(
            new_log_store_table.columns.iter().all(|col| !col
                .column_desc
                .as_ref()
                .unwrap()
                .name
                .contains("tmp"))
        );
        assert_eq!(
            new_log_store_table.value_indices,
            (0..new_log_store_table.columns.len() as i32).collect::<Vec<_>>()
        );
    }
}