risingwave_meta/stream/stream_graph/
fragment.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap, HashSet};
16use std::num::NonZeroUsize;
17use std::ops::{Deref, DerefMut};
18use std::sync::LazyLock;
19
20use anyhow::{Context, anyhow};
21use enum_as_inner::EnumAsInner;
22use itertools::Itertools;
23use risingwave_common::bail;
24use risingwave_common::catalog::{
25    CDC_SOURCE_COLUMN_NUM, ColumnCatalog, ColumnId, Field, FragmentTypeFlag, FragmentTypeMask,
26    TableId, generate_internal_table_name_with_type,
27};
28use risingwave_common::hash::VnodeCount;
29use risingwave_common::id::JobId;
30use risingwave_common::util::iter_util::ZipEqFast;
31use risingwave_common::util::stream_graph_visitor::{
32    self, visit_stream_node_cont, visit_stream_node_cont_mut,
33};
34use risingwave_connector::sink::catalog::SinkType;
35use risingwave_meta_model::streaming_job::BackfillOrders;
36use risingwave_pb::catalog::{PbSink, PbTable, Table};
37use risingwave_pb::ddl_service::TableJobType;
38use risingwave_pb::expr::{ExprNode as PbExprNode, expr_node};
39use risingwave_pb::id::{RelationId, StreamNodeLocalOperatorId};
40use risingwave_pb::plan_common::{PbColumnCatalog, PbColumnDesc};
41use risingwave_pb::stream_plan::dispatch_output_mapping::TypePair;
42use risingwave_pb::stream_plan::stream_fragment_graph::{
43    Parallelism, StreamFragment, StreamFragmentEdge as StreamFragmentEdgeProto,
44};
45use risingwave_pb::stream_plan::stream_node::{NodeBody, PbNodeBody};
46use risingwave_pb::stream_plan::{
47    BackfillOrder, DispatchOutputMapping, DispatchStrategy, DispatcherType, PbStreamNode,
48    PbStreamScanType, StreamFragmentGraph as StreamFragmentGraphProto, StreamNode, StreamScanNode,
49    StreamScanType,
50};
51
52use crate::barrier::SnapshotBackfillInfo;
53use crate::controller::id::IdGeneratorManager;
54use crate::manager::{MetaSrvEnv, StreamingJob, StreamingJobType};
55use crate::model::{Fragment, FragmentDownstreamRelation, FragmentId};
56use crate::stream::stream_graph::id::{GlobalFragmentId, GlobalFragmentIdGen, GlobalTableIdGen};
57use crate::stream::stream_graph::schedule::Distribution;
58use crate::{MetaError, MetaResult};
59
/// The fragment in the building phase, including the [`StreamFragment`] from the frontend and
/// several additional helper fields.
///
/// Dereferences (immutably and mutably) to the wrapped [`StreamFragment`].
#[derive(Debug, Clone)]
pub(super) struct BuildingFragment {
    /// The fragment structure from the frontend, with the global fragment ID.
    inner: StreamFragment,

    /// The ID of the job if it contains the streaming job node.
    job_id: Option<JobId>,

    /// The required columns of each upstream job, keyed by the ID of the upstream job.
    /// Will be converted to indices when building the edge connected to the upstream.
    ///
    /// For a shared CDC table on source, it's `vec![]`, since the upstream source's output schema
    /// is fixed.
    upstream_job_columns: HashMap<JobId, Vec<PbColumnDesc>>,
}
76
77impl BuildingFragment {
78    /// Create a new [`BuildingFragment`] from a [`StreamFragment`]. The global fragment ID and
79    /// global table IDs will be correctly filled with the given `id` and `table_id_gen`.
80    fn new(
81        id: GlobalFragmentId,
82        fragment: StreamFragment,
83        job: &StreamingJob,
84        table_id_gen: GlobalTableIdGen,
85    ) -> Self {
86        let mut fragment = StreamFragment {
87            fragment_id: id.as_global_id(),
88            ..fragment
89        };
90
91        // Fill the information of the internal tables in the fragment.
92        Self::fill_internal_tables(&mut fragment, job, table_id_gen);
93
94        let job_id = Self::fill_job(&mut fragment, job).then(|| job.id());
95        let upstream_job_columns =
96            Self::extract_upstream_columns_except_cross_db_backfill(&fragment);
97
98        Self {
99            inner: fragment,
100            job_id,
101            upstream_job_columns,
102        }
103    }
104
105    /// Extract the internal tables from the fragment.
106    fn extract_internal_tables(&self) -> Vec<Table> {
107        let mut fragment = self.inner.clone();
108        let mut tables = Vec::new();
109        stream_graph_visitor::visit_internal_tables(&mut fragment, |table, _| {
110            tables.push(table.clone());
111        });
112        tables
113    }
114
115    /// Fill the information with the internal tables in the fragment.
116    fn fill_internal_tables(
117        fragment: &mut StreamFragment,
118        job: &StreamingJob,
119        table_id_gen: GlobalTableIdGen,
120    ) {
121        let fragment_id = fragment.fragment_id;
122        stream_graph_visitor::visit_internal_tables(fragment, |table, table_type_name| {
123            table.id = table_id_gen
124                .to_global_id(table.id.as_raw_id())
125                .as_global_id();
126            table.schema_id = job.schema_id();
127            table.database_id = job.database_id();
128            table.name = generate_internal_table_name_with_type(
129                &job.name(),
130                fragment_id,
131                table.id,
132                table_type_name,
133            );
134            table.fragment_id = fragment_id;
135            table.owner = job.owner();
136            table.job_id = Some(job.id());
137        });
138    }
139
    /// Fill the information with the job in the fragment.
    ///
    /// Visits every node body of `fragment` and writes the IDs derived from `job` into the nodes
    /// that need them.
    ///
    /// Returns `true` if the fragment contains a node marking it as the job's own fragment:
    /// `Materialize`, `Sink`, `IcebergWithPkIndexWriter`, `IcebergWithPkIndexDvMerger`,
    /// `VectorIndexWrite`, or a `Source` node when the job itself is a [`StreamingJob::Source`].
    /// `Dml`, `StreamFsFetch`, `StreamCdcScan` and `Source` nodes of a table job are filled in but
    /// do not affect the return value.
    ///
    /// # Panics
    ///
    /// Panics if a required field is unset: `job.table()` for `Materialize`, `sink_desc` for the
    /// sink-like nodes, `job.table_version_id()` for `Dml`, or `table` for `VectorIndexWrite`.
    fn fill_job(fragment: &mut StreamFragment, job: &StreamingJob) -> bool {
        let job_id = job.id();
        let fragment_id = fragment.fragment_id;
        let mut has_job = false;

        stream_graph_visitor::visit_fragment_mut(fragment, |node_body| match node_body {
            NodeBody::Materialize(materialize_node) => {
                materialize_node.table_id = job_id.as_mv_table_id();

                // Fill the table field of `MaterializeNode` from the job.
                let table = materialize_node.table.insert(job.table().unwrap().clone());
                table.fragment_id = fragment_id; // this will later be synced back to `job.table` with `set_info_from_graph`
                // In production, do not include full definition in the table in plan node.
                if cfg!(not(debug_assertions)) {
                    table.definition = job.name();
                }

                has_job = true;
            }
            NodeBody::Sink(sink_node) => {
                sink_node.sink_desc.as_mut().unwrap().id = job_id.as_sink_id();

                has_job = true;
            }
            NodeBody::IcebergWithPkIndexWriter(writer_node) => {
                writer_node.sink_desc.as_mut().unwrap().id = job_id.as_sink_id();

                has_job = true;
            }
            NodeBody::IcebergWithPkIndexDvMerger(merger_node) => {
                merger_node.sink_desc.as_mut().unwrap().id = job_id.as_sink_id();

                has_job = true;
            }
            NodeBody::Dml(dml_node) => {
                // A `Dml` node alone does not mark the fragment as the job fragment.
                dml_node.table_id = job_id.as_mv_table_id();
                dml_node.table_version_id = job.table_version_id().unwrap();
            }
            NodeBody::StreamFsFetch(fs_fetch_node) => {
                // Only filled for a table job that carries an associated source.
                if let StreamingJob::Table(table_source, _, _) = job
                    && let Some(node_inner) = fs_fetch_node.node_inner.as_mut()
                    && let Some(source) = table_source
                {
                    node_inner.source_id = source.id;
                    if let Some(id) = source.optional_associated_table_id {
                        node_inner.associated_table_id = Some(id.into());
                    }
                }
            }
            NodeBody::Source(source_node) => {
                match job {
                    // Note: For a table without connector, it has a dummy Source node.
                    // Note: For a table with connector, its source node has a source ID different from the table ID (job ID), assigned in create_job_catalog.
                    StreamingJob::Table(source, _table, _table_job_type) => {
                        if let Some(source_inner) = source_node.source_inner.as_mut()
                            && let Some(source) = source
                        {
                            debug_assert_ne!(source.id, job_id.as_raw_id());
                            source_inner.source_id = source.id;
                            if let Some(id) = source.optional_associated_table_id {
                                source_inner.associated_table_id = Some(id.into());
                            }
                        }
                    }
                    StreamingJob::Source(source) => {
                        // The source is the job itself, so this fragment is the job fragment
                        // even when `source_inner` is unset.
                        has_job = true;
                        if let Some(source_inner) = source_node.source_inner.as_mut() {
                            debug_assert_eq!(source.id, job_id.as_raw_id());
                            source_inner.source_id = source.id;
                            if let Some(id) = source.optional_associated_table_id {
                                source_inner.associated_table_id = Some(id.into());
                            }
                        }
                    }
                    // For other job types, no need to fill the source id, since it refers to an existing source.
                    _ => {}
                }
            }
            NodeBody::StreamCdcScan(node) => {
                if let Some(table_desc) = node.cdc_table_desc.as_mut() {
                    table_desc.table_id = job_id.as_mv_table_id();
                }
            }
            NodeBody::VectorIndexWrite(node) => {
                let table = node.table.as_mut().unwrap();
                table.id = job_id.as_mv_table_id();
                table.database_id = job.database_id();
                table.schema_id = job.schema_id();
                table.fragment_id = fragment_id;
                // As for `Materialize`: outside debug builds, replace the definition with the
                // job name.
                #[cfg(not(debug_assertions))]
                {
                    table.definition = job.name();
                }

                has_job = true;
            }
            _ => {}
        });

        has_job
    }
242
243    /// Extract the required columns of each upstream table except for cross-db backfill.
244    fn extract_upstream_columns_except_cross_db_backfill(
245        fragment: &StreamFragment,
246    ) -> HashMap<JobId, Vec<PbColumnDesc>> {
247        let mut table_columns = HashMap::new();
248
249        stream_graph_visitor::visit_fragment(fragment, |node_body| {
250            let (table_id, column_ids) = match node_body {
251                NodeBody::StreamScan(stream_scan) => {
252                    if stream_scan.get_stream_scan_type().unwrap()
253                        == StreamScanType::CrossDbSnapshotBackfill
254                    {
255                        return;
256                    }
257                    (
258                        stream_scan.table_id.as_job_id(),
259                        stream_scan.upstream_columns(),
260                    )
261                }
262                NodeBody::CdcFilter(cdc_filter) => (
263                    cdc_filter.upstream_source_id.as_share_source_job_id(),
264                    vec![],
265                ),
266                NodeBody::SourceBackfill(backfill) => (
267                    backfill.upstream_source_id.as_share_source_job_id(),
268                    // FIXME: only pass required columns instead of all columns here
269                    backfill.column_descs(),
270                ),
271                _ => return,
272            };
273            table_columns
274                .try_insert(table_id, column_ids)
275                .expect("currently there should be no two same upstream tables in a fragment");
276        });
277
278        table_columns
279    }
280
281    pub fn has_shuffled_backfill(&self) -> bool {
282        let stream_node = match self.inner.node.as_ref() {
283            Some(node) => node,
284            _ => return false,
285        };
286        let mut has_shuffled_backfill = false;
287        let has_shuffled_backfill_mut_ref = &mut has_shuffled_backfill;
288        visit_stream_node_cont(stream_node, |node| {
289            let is_shuffled_backfill = if let Some(node) = &node.node_body
290                && let Some(node) = node.as_stream_scan()
291            {
292                node.stream_scan_type == StreamScanType::ArrangementBackfill as i32
293                    || node.stream_scan_type == StreamScanType::SnapshotBackfill as i32
294            } else {
295                false
296            };
297            if is_shuffled_backfill {
298                *has_shuffled_backfill_mut_ref = true;
299                false
300            } else {
301                true
302            }
303        });
304        has_shuffled_backfill
305    }
306}
307
/// Gives read access to the fields and methods of the wrapped [`StreamFragment`] directly on a
/// [`BuildingFragment`].
impl Deref for BuildingFragment {
    type Target = StreamFragment;

    /// Borrows the inner [`StreamFragment`].
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
315
/// Gives mutable access to the wrapped [`StreamFragment`] directly on a [`BuildingFragment`].
impl DerefMut for BuildingFragment {
    /// Mutably borrows the inner [`StreamFragment`].
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
321
/// The ID of an edge in the fragment graph. For different types of edges, the ID will be in
/// different variants.
///
/// Stored as the `id` of a [`StreamFragmentEdge`]. Edges converted from the frontend protobuf
/// always use the [`EdgeId::Internal`] variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumAsInner)]
pub(super) enum EdgeId {
    /// The edge between two building (internal) fragments.
    Internal {
        /// The ID generated by the frontend, generally the operator ID of `Exchange`.
        /// See [`StreamFragmentEdgeProto`].
        link_id: u64,
    },

    /// The edge between an upstream external fragment and downstream building fragment. Used for
    /// MV on MV.
    UpstreamExternal {
        /// The ID of the upstream table or materialized view.
        upstream_job_id: JobId,
        /// The ID of the downstream fragment.
        downstream_fragment_id: GlobalFragmentId,
    },

    /// The edge between an upstream building fragment and downstream external fragment. Used for
    /// schema change (replace table plan).
    DownstreamExternal(DownstreamExternalEdgeId),
}
346
/// The payload of [`EdgeId::DownstreamExternal`]: identifies an edge by the original upstream
/// fragment and the downstream fragment it connects.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(super) struct DownstreamExternalEdgeId {
    /// The ID of the original upstream fragment (`Materialize`).
    pub(super) original_upstream_fragment_id: GlobalFragmentId,
    /// The ID of the downstream fragment.
    pub(super) downstream_fragment_id: GlobalFragmentId,
}
354
/// The edge in the fragment graph.
///
/// The edge can be either internal or external. This is distinguished by the [`EdgeId`].
/// Internal edges are created from the frontend protobuf via `from_protobuf`.
#[derive(Debug, Clone)]
pub(super) struct StreamFragmentEdge {
    /// The ID of the edge.
    pub id: EdgeId,

    /// The strategy used for dispatching the data.
    pub dispatch_strategy: DispatchStrategy,
}
366
367impl StreamFragmentEdge {
368    fn from_protobuf(edge: &StreamFragmentEdgeProto) -> Self {
369        Self {
370            // By creating an edge from the protobuf, we know that the edge is from the frontend and
371            // is internal.
372            id: EdgeId::Internal {
373                link_id: edge.link_id,
374            },
375            dispatch_strategy: edge.get_dispatch_strategy().unwrap().clone(),
376        }
377    }
378}
379
380fn clone_fragment(fragment: &Fragment, id_generator_manager: &IdGeneratorManager) -> Fragment {
381    let fragment_id = GlobalFragmentIdGen::new(id_generator_manager, 1)
382        .to_global_id(0)
383        .as_global_id();
384    Fragment {
385        fragment_id,
386        fragment_type_mask: fragment.fragment_type_mask,
387        distribution_type: fragment.distribution_type,
388        state_table_ids: fragment.state_table_ids.clone(),
389        maybe_vnode_count: fragment.maybe_vnode_count,
390        nodes: fragment.nodes.clone(),
391    }
392}
393
394pub fn check_sink_fragments_support_refresh_schema(
395    fragments: &BTreeMap<FragmentId, Fragment>,
396) -> MetaResult<()> {
397    if fragments.len() != 1 {
398        return Err(anyhow!(
399            "sink with auto schema change should have only 1 fragment, but got {:?}",
400            fragments.len()
401        )
402        .into());
403    }
404    let (_, fragment) = fragments.first_key_value().expect("non-empty");
405    let sink_node = &fragment.nodes;
406    let PbNodeBody::Sink(_) = sink_node.node_body.as_ref().unwrap() else {
407        return Err(anyhow!("expect PbNodeBody::Sink but got: {:?}", sink_node.node_body).into());
408    };
409    let [stream_input_node] = sink_node.input.as_slice() else {
410        panic!("Sink has more than 1 input: {:?}", sink_node.input);
411    };
412    let stream_scan_node = match stream_input_node.node_body.as_ref().unwrap() {
413        PbNodeBody::StreamScan(_) => stream_input_node,
414        PbNodeBody::Project(_) => {
415            let [stream_scan_node] = stream_input_node.input.as_slice() else {
416                return Err(anyhow!(
417                    "Project node must have exactly 1 input for auto schema change, but got {:?}",
418                    stream_input_node.input.len()
419                )
420                .into());
421            };
422            stream_scan_node
423        }
424        _ => {
425            return Err(anyhow!(
426                "expect PbNodeBody::StreamScan or PbNodeBody::Project but got: {:?}",
427                stream_input_node.node_body
428            )
429            .into());
430        }
431    };
432    let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_ref().unwrap() else {
433        return Err(anyhow!(
434            "expect PbNodeBody::StreamScan but got: {:?}",
435            stream_scan_node.node_body
436        )
437        .into());
438    };
439    let stream_scan_type = PbStreamScanType::try_from(scan.stream_scan_type).unwrap();
440    if stream_scan_type != PbStreamScanType::ArrangementBackfill {
441        return Err(anyhow!(
442            "unsupported stream_scan_type for auto refresh schema: {:?}",
443            stream_scan_type
444        )
445        .into());
446    }
447    let [merge_node, _batch_plan_node] = stream_scan_node.input.as_slice() else {
448        panic!(
449            "the number of StreamScan inputs is not 2: {:?}",
450            stream_scan_node.input
451        );
452    };
453    let NodeBody::Merge(_) = merge_node.node_body.as_ref().unwrap() else {
454        return Err(anyhow!(
455            "expect PbNodeBody::Merge but got: {:?}",
456            merge_node.node_body
457        )
458        .into());
459    };
460    Ok(())
461}
462
/// Output mapping info after rewriting a `StreamScan` node.
///
/// Produced by `rewrite_stream_scan_and_merge` and consumed by `rewrite_project_node`.
struct ScanRewriteResult {
    /// Maps a position in the scan's old output to the position of the same column in the new
    /// output. Positions whose column was removed have no entry.
    old_output_index_to_new_output_index: HashMap<u32, u32>,
    /// Maps a column ID to its position in the scan's new output.
    new_output_index_by_column_id: HashMap<ColumnId, u32>,
    /// A copy of the scan node's `fields` after the rewrite.
    output_fields: Vec<risingwave_pb::plan_common::Field>,
}
469
470/// Append new columns to a sink/log-store column list with updated names/ids.
471fn extend_sink_columns(
472    sink_columns: &mut Vec<PbColumnCatalog>,
473    new_columns: &[ColumnCatalog],
474    get_column_name: impl Fn(&String) -> String,
475) {
476    let next_column_id = sink_columns
477        .iter()
478        .map(|col| col.column_desc.as_ref().unwrap().column_id + 1)
479        .max()
480        .unwrap_or(1);
481    sink_columns.extend(new_columns.iter().enumerate().map(|(i, col)| {
482        let mut col = col.to_protobuf();
483        let column_desc = col.column_desc.as_mut().unwrap();
484        column_desc.column_id = next_column_id + (i as i32);
485        column_desc.name = get_column_name(&column_desc.name);
486        col
487    }));
488}
489
490/// Build sink column list after removing and appending columns.
491fn build_new_sink_columns(
492    sink: &PbSink,
493    removed_column_names: &HashSet<String>,
494    newly_added_columns: &[ColumnCatalog],
495) -> Vec<PbColumnCatalog> {
496    let mut columns: Vec<PbColumnCatalog> = sink
497        .columns
498        .iter()
499        .filter(|col| {
500            let column_name = &col.column_desc.as_ref().unwrap().name;
501            !removed_column_names.contains(column_name)
502        })
503        .cloned()
504        .collect();
505    extend_sink_columns(&mut columns, newly_added_columns, |name| name.clone());
506    columns
507}
508
509/// Rewrite log store table columns for schema change.
510fn rewrite_log_store_table(
511    log_store_table: &mut PbTable,
512    removed_log_store_column_names: &HashSet<String>,
513    newly_added_columns: &[ColumnCatalog],
514    upstream_table_name: &str,
515) {
516    log_store_table.columns.retain(|col| {
517        !removed_log_store_column_names.contains(&col.column_desc.as_ref().unwrap().name)
518    });
519    extend_sink_columns(&mut log_store_table.columns, newly_added_columns, |name| {
520        format!("{}_{}", upstream_table_name, name)
521    });
522    log_store_table.value_indices = (0..log_store_table.columns.len() as i32).collect();
523}
524
525/// Rewrite `StreamScan` + Merge to match the new upstream schema.
526fn rewrite_stream_scan_and_merge(
527    stream_scan_node: &mut StreamNode,
528    removed_column_ids: &HashSet<ColumnId>,
529    newly_added_columns: &[ColumnCatalog],
530    upstream_table: &PbTable,
531    upstream_table_fragment_id: FragmentId,
532) -> MetaResult<ScanRewriteResult> {
533    let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_mut().unwrap() else {
534        return Err(anyhow!(
535            "expect PbNodeBody::StreamScan but got: {:?}",
536            stream_scan_node.node_body
537        )
538        .into());
539    };
540    let [merge_node, _batch_plan_node] = stream_scan_node.input.as_mut_slice() else {
541        panic!(
542            "the number of StreamScan inputs is not 2: {:?}",
543            stream_scan_node.input
544        );
545    };
546    let NodeBody::Merge(merge) = merge_node.node_body.as_mut().unwrap() else {
547        return Err(anyhow!(
548            "expect PbNodeBody::Merge but got: {:?}",
549            merge_node.node_body
550        )
551        .into());
552    };
553
554    let stream_scan_type = PbStreamScanType::try_from(scan.stream_scan_type).unwrap();
555    if stream_scan_type != PbStreamScanType::ArrangementBackfill {
556        return Err(anyhow!(
557            "unsupported stream_scan_type for auto refresh schema: {:?}",
558            stream_scan_type
559        )
560        .into());
561    }
562
563    let upstream_columns_by_id: HashMap<i32, PbColumnDesc> = upstream_table
564        .columns
565        .iter()
566        .map(|col| {
567            let desc = col.column_desc.as_ref().unwrap().clone();
568            (desc.column_id, desc)
569        })
570        .collect();
571
572    let old_upstream_column_ids = scan.upstream_column_ids.clone();
573    let old_output_indices = scan.output_indices.clone();
574    let mut old_upstream_index_to_new_upstream_index = HashMap::new();
575    let mut new_upstream_column_ids = Vec::new();
576    for (old_idx, &column_id) in old_upstream_column_ids.iter().enumerate() {
577        if !removed_column_ids.contains(&ColumnId::new(column_id as _)) {
578            let new_idx = new_upstream_column_ids.len() as u32;
579            old_upstream_index_to_new_upstream_index.insert(old_idx as u32, new_idx);
580            new_upstream_column_ids.push(column_id);
581        }
582    }
583    let mut new_output_indices = Vec::new();
584    for old_output_index in &old_output_indices {
585        if let Some(new_index) = old_upstream_index_to_new_upstream_index.get(old_output_index) {
586            new_output_indices.push(*new_index);
587        }
588    }
589    for col in newly_added_columns {
590        let new_index = new_upstream_column_ids.len() as u32;
591        new_upstream_column_ids.push(col.column_id().get_id());
592        new_output_indices.push(new_index);
593    }
594
595    let new_output_column_ids: Vec<i32> = new_output_indices
596        .iter()
597        .map(|&idx| new_upstream_column_ids[idx as usize])
598        .collect();
599    let mut new_output_index_by_column_id = HashMap::new();
600    for (pos, &column_id) in new_output_column_ids.iter().enumerate() {
601        new_output_index_by_column_id.insert(ColumnId::new(column_id as _), pos as u32);
602    }
603    let mut old_output_index_to_new_output_index = HashMap::new();
604    for (old_pos, old_output_index) in old_output_indices.iter().enumerate() {
605        let column_id = old_upstream_column_ids[*old_output_index as usize];
606        if let Some(new_pos) = new_output_index_by_column_id.get(&ColumnId::new(column_id as _)) {
607            old_output_index_to_new_output_index.insert(old_pos as u32, *new_pos);
608        }
609    }
610
611    scan.arrangement_table = Some(upstream_table.clone());
612    scan.upstream_column_ids = new_upstream_column_ids;
613    scan.output_indices = new_output_indices;
614    let table_desc = scan.table_desc.as_mut().unwrap();
615    table_desc.columns = scan
616        .upstream_column_ids
617        .iter()
618        .map(|column_id| {
619            upstream_columns_by_id
620                .get(column_id)
621                .unwrap_or_else(|| panic!("upstream column id not found: {}", column_id))
622                .clone()
623        })
624        .collect();
625
626    stream_scan_node.fields = new_output_column_ids
627        .iter()
628        .map(|column_id| {
629            let col_desc = upstream_columns_by_id
630                .get(column_id)
631                .unwrap_or_else(|| panic!("upstream column id not found: {}", column_id));
632            Field::new(
633                format!("{}.{}", upstream_table.name, col_desc.name),
634                col_desc.column_type.as_ref().unwrap().into(),
635            )
636            .to_prost()
637        })
638        .collect();
639    // following logic in <StreamTableScan as Explain>::distill
640    stream_scan_node.identity = {
641        let columns = stream_scan_node
642            .fields
643            .iter()
644            .map(|col| &col.name)
645            .join(", ");
646        format!("StreamTableScan {{ table: t, columns: [{columns}] }}")
647    };
648
649    // update merge node
650    merge_node.fields = scan
651        .upstream_column_ids
652        .iter()
653        .map(|&column_id| {
654            let col_desc = upstream_columns_by_id
655                .get(&column_id)
656                .unwrap_or_else(|| panic!("upstream column id not found: {}", column_id));
657            Field::new(
658                col_desc.name.clone(),
659                col_desc.column_type.as_ref().unwrap().into(),
660            )
661            .to_prost()
662        })
663        .collect();
664    merge.upstream_fragment_id = upstream_table_fragment_id;
665
666    Ok(ScanRewriteResult {
667        old_output_index_to_new_output_index,
668        new_output_index_by_column_id,
669        output_fields: stream_scan_node.fields.clone(),
670    })
671}
672
673/// Rewrite Project node input refs and extend with newly added columns.
674fn rewrite_project_node(
675    project_node: &mut StreamNode,
676    scan_rewrite: &ScanRewriteResult,
677    newly_added_columns: &[ColumnCatalog],
678    removed_column_ids: &HashSet<ColumnId>,
679    upstream_table_name: &str,
680) -> MetaResult<()> {
681    let PbNodeBody::Project(project_node_body) = project_node.node_body.as_mut().unwrap() else {
682        return Err(anyhow!(
683            "expect PbNodeBody::Project but got: {:?}",
684            project_node.node_body
685        )
686        .into());
687    };
688    let has_non_input_ref = project_node_body
689        .select_list
690        .iter()
691        .any(|expr| !matches!(expr.rex_node, Some(expr_node::RexNode::InputRef(_))));
692    if has_non_input_ref && !removed_column_ids.is_empty() {
693        return Err(anyhow!(
694            "auto schema change with drop column only supports Project with InputRef"
695        )
696        .into());
697    }
698
699    let mut new_select_list = Vec::with_capacity(project_node_body.select_list.len());
700    let mut new_project_fields = Vec::with_capacity(project_node.fields.len());
701    for (index, expr) in project_node_body.select_list.iter().enumerate() {
702        let mut new_expr = expr.clone();
703        if let Some(expr_node::RexNode::InputRef(old_index)) = new_expr.rex_node {
704            let Some(&new_index) = scan_rewrite
705                .old_output_index_to_new_output_index
706                .get(&old_index)
707            else {
708                continue;
709            };
710            new_expr.rex_node = Some(expr_node::RexNode::InputRef(new_index));
711        } else if !removed_column_ids.is_empty() {
712            return Err(anyhow!(
713                "auto schema change with drop column only supports Project with InputRef"
714            )
715            .into());
716        }
717        new_select_list.push(new_expr);
718        new_project_fields.push(project_node.fields[index].clone());
719    }
720
721    for col in newly_added_columns {
722        let Some(&new_index) = scan_rewrite
723            .new_output_index_by_column_id
724            .get(&col.column_id())
725        else {
726            return Err(anyhow!("new column id not found in scan output").into());
727        };
728        new_select_list.push(PbExprNode {
729            function_type: expr_node::Type::Unspecified as i32,
730            return_type: Some(col.data_type().to_protobuf()),
731            rex_node: Some(expr_node::RexNode::InputRef(new_index)),
732        });
733        new_project_fields.push(
734            Field::new(
735                format!("{}.{}", upstream_table_name, col.column_desc.name),
736                col.data_type().clone(),
737            )
738            .to_prost(),
739        );
740    }
741
742    project_node_body.select_list = new_select_list;
743    project_node.fields = new_project_fields;
744    Ok(())
745}
746
747pub fn rewrite_refresh_schema_sink_fragment(
748    original_sink_fragment: &Fragment,
749    sink: &PbSink,
750    newly_added_columns: &[ColumnCatalog],
751    removed_columns: &[ColumnCatalog],
752    upstream_table: &PbTable,
753    upstream_table_fragment_id: FragmentId,
754    id_generator_manager: &IdGeneratorManager,
755) -> MetaResult<(Fragment, Vec<PbColumnCatalog>, Option<PbTable>)> {
756    let removed_column_ids: HashSet<_> =
757        removed_columns.iter().map(|col| col.column_id()).collect();
758    let removed_log_store_column_names: HashSet<_> = removed_columns
759        .iter()
760        .map(|col| format!("{}_{}", upstream_table.name, col.column_desc.name))
761        .collect();
762    let removed_sink_column_names: HashSet<_> = removed_columns
763        .iter()
764        .map(|col| col.column_desc.name.clone())
765        .collect();
766    let new_sink_columns =
767        build_new_sink_columns(sink, &removed_sink_column_names, newly_added_columns);
768
769    let mut new_sink_fragment = clone_fragment(original_sink_fragment, id_generator_manager);
770    let sink_node = &mut new_sink_fragment.nodes;
771    let PbNodeBody::Sink(sink_node_body) = sink_node.node_body.as_mut().unwrap() else {
772        return Err(anyhow!("expect PbNodeBody::Sink but got: {:?}", sink_node.node_body).into());
773    };
774    let [stream_input_node] = sink_node.input.as_mut_slice() else {
775        panic!("Sink has more than 1 input: {:?}", sink_node.input);
776    };
777    let stream_input_body = stream_input_node.node_body.as_ref().unwrap();
778    let stream_input_is_project = matches!(stream_input_body, PbNodeBody::Project(_));
779    let stream_input_is_scan = matches!(stream_input_body, PbNodeBody::StreamScan(_));
780    if !stream_input_is_project && !stream_input_is_scan {
781        return Err(anyhow!(
782            "expect PbNodeBody::StreamScan or PbNodeBody::Project but got: {:?}",
783            stream_input_body
784        )
785        .into());
786    }
787
788    // update sink_node
789    // following logic in <StreamSink as Explain>::distill
790    sink_node.identity = {
791        let sink_type = SinkType::from_proto(sink.sink_type());
792        let sink_type_str = sink_type.type_str();
793        let column_names = new_sink_columns
794            .iter()
795            .map(|col| {
796                ColumnCatalog::from(col.clone())
797                    .name_with_hidden()
798                    .to_string()
799            })
800            .join(", ");
801        let downstream_pk = if !sink_type.is_append_only() {
802            let downstream_pk = sink
803                .downstream_pk
804                .iter()
805                .map(|i| &sink.columns[*i as usize].column_desc.as_ref().unwrap().name)
806                .collect_vec();
807            format!(", downstream_pk: {downstream_pk:?}")
808        } else {
809            "".to_owned()
810        };
811        format!("StreamSink {{ type: {sink_type_str}, columns: [{column_names}]{downstream_pk} }}")
812    };
813    let new_log_store_table = if let Some(log_store_table) = &mut sink_node_body.table {
814        rewrite_log_store_table(
815            log_store_table,
816            &removed_log_store_column_names,
817            newly_added_columns,
818            &upstream_table.name,
819        );
820        Some(log_store_table.clone())
821    } else {
822        None
823    };
824    sink_node_body.sink_desc.as_mut().unwrap().column_catalogs = new_sink_columns.clone();
825
826    let stream_scan_node = if stream_input_is_project {
827        let [stream_scan_node] = stream_input_node.input.as_mut_slice() else {
828            return Err(anyhow!(
829                "Project node must have exactly 1 input for auto schema change, but got {:?}",
830                stream_input_node.input.len()
831            )
832            .into());
833        };
834        stream_scan_node
835    } else {
836        stream_input_node
837    };
838    let scan_rewrite = rewrite_stream_scan_and_merge(
839        stream_scan_node,
840        &removed_column_ids,
841        newly_added_columns,
842        upstream_table,
843        upstream_table_fragment_id,
844    )?;
845
846    if stream_input_is_project {
847        let [project_node] = sink_node.input.as_mut_slice() else {
848            panic!("Sink has more than 1 input: {:?}", sink_node.input);
849        };
850        rewrite_project_node(
851            project_node,
852            &scan_rewrite,
853            newly_added_columns,
854            &removed_column_ids,
855            &upstream_table.name,
856        )?;
857        sink_node.fields = project_node.fields.clone();
858    } else {
859        sink_node.fields = scan_rewrite.output_fields;
860    }
861    Ok((new_sink_fragment, new_sink_columns, new_log_store_table))
862}
863
/// Adjacency list (G) of backfill orders, keyed by fragment id.
///
/// `G[10] -> [1, 2, 11]` means that the backfill node in fragment 10
/// should be backfilled before the backfill nodes in fragments 1, 2 and 11.
///
/// The `EXTENDED` const parameter is not used by any field; it only tells apart the two
/// aliases of this type, [`UserDefinedFragmentBackfillOrder`] (`false`) and
/// [`ExtendedFragmentBackfillOrder`] (`true`).
#[derive(Clone, Debug, Default)]
pub struct FragmentBackfillOrder<const EXTENDED: bool> {
    /// Fragment id => ids of the fragments that should be backfilled after it.
    inner: HashMap<FragmentId, Vec<FragmentId>>,
}
872
873impl<const EXTENDED: bool> Deref for FragmentBackfillOrder<EXTENDED> {
874    type Target = HashMap<FragmentId, Vec<FragmentId>>;
875
876    fn deref(&self) -> &Self::Target {
877        &self.inner
878    }
879}
880
881impl UserDefinedFragmentBackfillOrder {
882    pub fn new(inner: HashMap<FragmentId, Vec<FragmentId>>) -> Self {
883        Self { inner }
884    }
885
886    pub fn merge(orders: impl Iterator<Item = Self>) -> Self {
887        Self {
888            inner: orders.flat_map(|order| order.inner).collect(),
889        }
890    }
891
892    pub fn to_meta_model(&self) -> BackfillOrders {
893        self.inner.clone().into()
894    }
895}
896
/// Backfill order between fragments as derived from the backfill order carried by the job's
/// fragment graph (see `StreamFragmentGraph::create_fragment_backfill_ordering`).
pub type UserDefinedFragmentBackfillOrder = FragmentBackfillOrder<false>;
/// Backfill order that additionally contains the dependencies involving `LocalityProvider`
/// fragments (see
/// `StreamFragmentGraph::extend_fragment_backfill_ordering_with_locality_backfill`).
pub type ExtendedFragmentBackfillOrder = FragmentBackfillOrder<true>;
899
/// In-memory representation of a **Fragment** Graph, built from the [`StreamFragmentGraphProto`]
/// from the frontend.
///
/// This only includes nodes and edges of the current job itself. It will be converted to
/// [`CompleteStreamFragmentGraph`] later, which contains the additional information of
/// pre-existing fragments that are connected to the graph's top-most or bottom-most fragments.
#[derive(Default, Debug)]
pub struct StreamFragmentGraph {
    /// Stores all the fragments in the graph, keyed by their global fragment id.
    pub(super) fragments: HashMap<GlobalFragmentId, BuildingFragment>,

    /// Stores edges between fragments: upstream => downstream.
    pub(super) downstreams:
        HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,

    /// Stores edges between fragments: downstream => upstream.
    pub(super) upstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,

    /// Dependent relations of this job.
    dependent_table_ids: HashSet<TableId>,

    /// The default parallelism of the job, specified by the `STREAMING_PARALLELISM` session
    /// variable. If not specified, all active worker slots will be used.
    specified_parallelism: Option<NonZeroUsize>,
    /// The parallelism to use during backfill, specified by the `STREAMING_PARALLELISM_FOR_BACKFILL`
    /// session variable. If not specified, falls back to `specified_parallelism`.
    specified_backfill_parallelism: Option<NonZeroUsize>,

    /// Specified max parallelism, i.e., expected vnode count for the graph.
    ///
    /// The scheduler on the meta service will use this as a hint to decide the vnode count
    /// for each fragment.
    ///
    /// Note that the actual vnode count may be different from this value.
    /// For example, a no-shuffle exchange between current fragment graph and an existing
    /// upstream fragment graph requires two fragments to be in the same distribution,
    /// thus the same vnode count.
    max_parallelism: usize,

    /// The backfill ordering strategy of the graph, taken from the proto (empty order if the
    /// proto does not carry one).
    backfill_order: BackfillOrder,
}
942
943impl StreamFragmentGraph {
944    /// Create a new [`StreamFragmentGraph`] from the given [`StreamFragmentGraphProto`], with all
945    /// global IDs correctly filled.
946    pub fn new(
947        env: &MetaSrvEnv,
948        proto: StreamFragmentGraphProto,
949        job: &StreamingJob,
950    ) -> MetaResult<Self> {
951        let fragment_id_gen =
952            GlobalFragmentIdGen::new(env.id_gen_manager(), proto.fragments.len() as u64);
953        // Note: in SQL backend, the ids generated here are fake and will be overwritten again
954        // with `refill_internal_table_ids` later.
955        // TODO: refactor the code to remove this step.
956        let table_id_gen = GlobalTableIdGen::new(env.id_gen_manager(), proto.table_ids_cnt as u64);
957
958        // Create nodes.
959        let fragments: HashMap<_, _> = proto
960            .fragments
961            .into_iter()
962            .map(|(id, fragment)| {
963                let id = fragment_id_gen.to_global_id(id.as_raw_id());
964                let fragment = BuildingFragment::new(id, fragment, job, table_id_gen);
965                (id, fragment)
966            })
967            .collect();
968
969        assert_eq!(
970            fragments
971                .values()
972                .map(|f| f.extract_internal_tables().len() as u32)
973                .sum::<u32>(),
974            proto.table_ids_cnt
975        );
976
977        // Create edges.
978        let mut downstreams = HashMap::new();
979        let mut upstreams = HashMap::new();
980
981        for edge in proto.edges {
982            let upstream_id = fragment_id_gen.to_global_id(edge.upstream_id.as_raw_id());
983            let downstream_id = fragment_id_gen.to_global_id(edge.downstream_id.as_raw_id());
984            let edge = StreamFragmentEdge::from_protobuf(&edge);
985
986            upstreams
987                .entry(downstream_id)
988                .or_insert_with(HashMap::new)
989                .try_insert(upstream_id, edge.clone())
990                .unwrap();
991            downstreams
992                .entry(upstream_id)
993                .or_insert_with(HashMap::new)
994                .try_insert(downstream_id, edge)
995                .unwrap();
996        }
997
998        // Note: Here we directly use the field `dependent_table_ids` in the proto (resolved in
999        // frontend), instead of visiting the graph ourselves.
1000        let dependent_table_ids = proto.dependent_table_ids.iter().copied().collect();
1001
1002        let specified_parallelism = if let Some(Parallelism { parallelism }) = proto.parallelism {
1003            Some(NonZeroUsize::new(parallelism as usize).context("parallelism should not be 0")?)
1004        } else {
1005            None
1006        };
1007        let specified_backfill_parallelism =
1008            if let Some(Parallelism { parallelism }) = proto.backfill_parallelism {
1009                Some(
1010                    NonZeroUsize::new(parallelism as usize)
1011                        .context("backfill parallelism should not be 0")?,
1012                )
1013            } else {
1014                None
1015            };
1016
1017        let max_parallelism = proto.max_parallelism as usize;
1018        let backfill_order = proto.backfill_order.unwrap_or(BackfillOrder {
1019            order: Default::default(),
1020        });
1021
1022        Ok(Self {
1023            fragments,
1024            downstreams,
1025            upstreams,
1026            dependent_table_ids,
1027            specified_parallelism,
1028            specified_backfill_parallelism,
1029            max_parallelism,
1030            backfill_order,
1031        })
1032    }
1033
1034    /// Retrieve the **incomplete** internal tables map of the whole graph.
1035    ///
1036    /// Note that some fields in the table catalogs are not filled during the current phase, e.g.,
1037    /// `fragment_id`, `vnode_count`. They will be all filled after a `TableFragments` is built.
1038    /// Be careful when using the returned values.
1039    pub fn incomplete_internal_tables(&self) -> BTreeMap<TableId, Table> {
1040        let mut tables = BTreeMap::new();
1041        for fragment in self.fragments.values() {
1042            for table in fragment.extract_internal_tables() {
1043                let table_id = table.id;
1044                tables
1045                    .try_insert(table_id, table)
1046                    .unwrap_or_else(|_| panic!("duplicated table id `{}`", table_id));
1047            }
1048        }
1049        tables
1050    }
1051
1052    /// Refill the internal tables' `table_id`s according to the given map, typically obtained from
1053    /// `create_internal_table_catalog`.
1054    pub fn refill_internal_table_ids(&mut self, table_id_map: HashMap<TableId, TableId>) {
1055        for fragment in self.fragments.values_mut() {
1056            stream_graph_visitor::visit_internal_tables(
1057                &mut fragment.inner,
1058                |table, _table_type_name| {
1059                    let target = table_id_map.get(&table.id).cloned().unwrap();
1060                    table.id = target;
1061                },
1062            );
1063        }
1064    }
1065
1066    /// Use a trivial algorithm to match the internal tables of the new graph for
1067    /// `ALTER TABLE` or `ALTER SOURCE`.
1068    pub fn fit_internal_tables_trivial(
1069        &mut self,
1070        mut old_internal_tables: Vec<Table>,
1071    ) -> MetaResult<()> {
1072        let mut new_internal_table_ids = Vec::new();
1073        for fragment in self.fragments.values() {
1074            for table in &fragment.extract_internal_tables() {
1075                new_internal_table_ids.push(table.id);
1076            }
1077        }
1078
1079        if new_internal_table_ids.len() != old_internal_tables.len() {
1080            bail!(
1081                "Different number of internal tables. New: {}, Old: {}",
1082                new_internal_table_ids.len(),
1083                old_internal_tables.len()
1084            );
1085        }
1086        old_internal_tables.sort_by(|a, b| a.id.cmp(&b.id));
1087        new_internal_table_ids.sort();
1088
1089        let internal_table_id_map = new_internal_table_ids
1090            .into_iter()
1091            .zip_eq_fast(old_internal_tables.into_iter())
1092            .collect::<HashMap<_, _>>();
1093
1094        // TODO(alter-mv): unify this with `fit_internal_table_ids_with_mapping` after we
1095        // confirm the behavior is the same.
1096        for fragment in self.fragments.values_mut() {
1097            stream_graph_visitor::visit_internal_tables(
1098                &mut fragment.inner,
1099                |table, _table_type_name| {
1100                    // XXX: this replaces the entire table, instead of just the id!
1101                    let target = internal_table_id_map.get(&table.id).cloned().unwrap();
1102                    *table = target;
1103                },
1104            );
1105        }
1106
1107        Ok(())
1108    }
1109
1110    /// Fit the internal tables' `table_id`s according to the given mapping.
1111    pub fn fit_internal_table_ids_with_mapping(&mut self, mut matches: HashMap<TableId, Table>) {
1112        for fragment in self.fragments.values_mut() {
1113            stream_graph_visitor::visit_internal_tables(
1114                &mut fragment.inner,
1115                |table, _table_type_name| {
1116                    let target = matches.remove(&table.id).unwrap_or_else(|| {
1117                        panic!("no matching table for table {}({})", table.id, table.name)
1118                    });
1119                    table.id = target.id;
1120                    table.maybe_vnode_count = target.maybe_vnode_count;
1121                },
1122            );
1123        }
1124    }
1125
1126    pub fn fit_snapshot_backfill_epochs(
1127        &mut self,
1128        mut snapshot_backfill_epochs: HashMap<StreamNodeLocalOperatorId, u64>,
1129    ) {
1130        for fragment in self.fragments.values_mut() {
1131            visit_stream_node_cont_mut(fragment.node.as_mut().unwrap(), |node| {
1132                if let PbNodeBody::StreamScan(scan) = node.node_body.as_mut().unwrap()
1133                    && let StreamScanType::SnapshotBackfill
1134                    | StreamScanType::CrossDbSnapshotBackfill = scan.stream_scan_type()
1135                {
1136                    let Some(epoch) = snapshot_backfill_epochs.remove(&node.operator_id) else {
1137                        panic!("no snapshot epoch found for node {:?}", node)
1138                    };
1139                    scan.snapshot_backfill_epoch = Some(epoch);
1140                }
1141                true
1142            })
1143        }
1144    }
1145
1146    /// Returns the fragment id where the streaming job node located.
1147    pub fn table_fragment_id(&self) -> FragmentId {
1148        self.fragments
1149            .values()
1150            .filter(|b| b.job_id.is_some())
1151            .map(|b| b.fragment_id)
1152            .exactly_one()
1153            .expect("require exactly 1 materialize/sink/cdc source node when creating the streaming job")
1154    }
1155
1156    /// Returns the fragment id where the table dml is received.
1157    pub fn dml_fragment_id(&self) -> Option<FragmentId> {
1158        self.fragments
1159            .values()
1160            .filter(|b| {
1161                FragmentTypeMask::from(b.fragment_type_mask).contains(FragmentTypeFlag::Dml)
1162            })
1163            .map(|b| b.fragment_id)
1164            .at_most_one()
1165            .expect("require at most 1 dml node when creating the streaming job")
1166    }
1167
    /// Get the ids of the relations this job depends on, as stored in the
    /// `dependent_table_ids` field.
    pub fn dependent_table_ids(&self) -> &HashSet<TableId> {
        &self.dependent_table_ids
    }
1172
    /// Get the parallelism of the job, if specified by the user; `None` otherwise.
    pub fn specified_parallelism(&self) -> Option<NonZeroUsize> {
        self.specified_parallelism
    }
1177
    /// Get the backfill parallelism of the job, if specified by the user; `None` otherwise.
    pub fn specified_backfill_parallelism(&self) -> Option<NonZeroUsize> {
        self.specified_backfill_parallelism
    }
1182
    /// Get the expected vnode count of the graph. See documentation of the `max_parallelism`
    /// field for more details.
    pub fn max_parallelism(&self) -> usize {
        self.max_parallelism
    }
1187
1188    /// Get downstreams of a fragment.
1189    fn get_downstreams(
1190        &self,
1191        fragment_id: GlobalFragmentId,
1192    ) -> &HashMap<GlobalFragmentId, StreamFragmentEdge> {
1193        self.downstreams.get(&fragment_id).unwrap_or(&EMPTY_HASHMAP)
1194    }
1195
1196    /// Get upstreams of a fragment.
1197    fn get_upstreams(
1198        &self,
1199        fragment_id: GlobalFragmentId,
1200    ) -> &HashMap<GlobalFragmentId, StreamFragmentEdge> {
1201        self.upstreams.get(&fragment_id).unwrap_or(&EMPTY_HASHMAP)
1202    }
1203
1204    pub fn collect_snapshot_backfill_info(
1205        &self,
1206    ) -> MetaResult<(Option<SnapshotBackfillInfo>, SnapshotBackfillInfo)> {
1207        Self::collect_snapshot_backfill_info_impl(self.fragments.values().map(|fragment| {
1208            (
1209                fragment.node.as_ref().unwrap(),
1210                fragment.fragment_type_mask.into(),
1211            )
1212        }))
1213    }
1214
    /// Scans the `StreamScan` nodes of the given fragments and returns
    /// `Ok((snapshot_backfill_info, cross_db_snapshot_backfill_info))`.
    ///
    /// - `cross_db_snapshot_backfill_info` records `table_id => snapshot_backfill_epoch` for
    ///   every scan of type `CrossDbSnapshotBackfill`. These scans do not take part in the
    ///   consistency check below.
    /// - `snapshot_backfill_info` is `Some` when the remaining scans are all of type
    ///   `SnapshotBackfill` (recording `table_id => snapshot_backfill_epoch` for each), and
    ///   `None` when none of them is, or when there is no such scan at all.
    ///
    /// # Errors
    ///
    /// Fails if `SnapshotBackfill` scans are mixed with scans of other (non cross-db) types.
    ///
    /// # Panics
    ///
    /// Panics if a scan has an invalid `stream_scan_type`, or if the fragment type mask does
    /// not carry the flag matching a (cross-db) snapshot backfill scan.
    pub fn collect_snapshot_backfill_info_impl(
        fragments: impl IntoIterator<Item = (&PbStreamNode, FragmentTypeMask)>,
    ) -> MetaResult<(Option<SnapshotBackfillInfo>, SnapshotBackfillInfo)> {
        // The first non cross-db scan seen so far, together with the snapshot backfill info
        // accumulated up to now (`None` if that first scan was not a snapshot backfill).
        let mut prev_stream_scan: Option<(Option<SnapshotBackfillInfo>, StreamScanNode)> = None;
        let mut cross_db_info = SnapshotBackfillInfo {
            upstream_mv_table_id_to_backfill_epoch: Default::default(),
        };
        // Errors cannot be returned from inside the visitor closure, so they are stashed here.
        let mut result = Ok(());
        for (node, fragment_type_mask) in fragments {
            visit_stream_node_cont(node, |node| {
                if let Some(NodeBody::StreamScan(stream_scan)) = node.node_body.as_ref() {
                    let stream_scan_type = StreamScanType::try_from(stream_scan.stream_scan_type)
                        .expect("invalid stream_scan_type");
                    let is_snapshot_backfill = match stream_scan_type {
                        StreamScanType::SnapshotBackfill => {
                            assert!(
                                fragment_type_mask
                                    .contains(FragmentTypeFlag::SnapshotBackfillStreamScan)
                            );
                            true
                        }
                        StreamScanType::CrossDbSnapshotBackfill => {
                            assert!(
                                fragment_type_mask
                                    .contains(FragmentTypeFlag::CrossDbSnapshotBackfillStreamScan)
                            );
                            cross_db_info
                                .upstream_mv_table_id_to_backfill_epoch
                                .insert(stream_scan.table_id, stream_scan.snapshot_backfill_epoch);

                            // Cross-db scans are tracked separately and skip the check below.
                            return true;
                        }
                        _ => false,
                    };

                    match &mut prev_stream_scan {
                        Some((prev_snapshot_backfill_info, prev_stream_scan)) => {
                            match (prev_snapshot_backfill_info, is_snapshot_backfill) {
                                // Another snapshot backfill scan: record its table and epoch.
                                (Some(prev_snapshot_backfill_info), true) => {
                                    prev_snapshot_backfill_info
                                        .upstream_mv_table_id_to_backfill_epoch
                                        .insert(
                                            stream_scan.table_id,
                                            stream_scan.snapshot_backfill_epoch,
                                        );
                                    true
                                }
                                // Still no snapshot backfill anywhere: nothing to record.
                                (None, false) => true,
                                // This scan disagrees with the first one seen.
                                (_, _) => {
                                    result = Err(anyhow!("must be either all snapshot_backfill or no snapshot_backfill. Curr: {stream_scan:?} Prev: {prev_stream_scan:?}").into());
                                    false
                                }
                            }
                        }
                        // First non cross-db scan: it decides which kind all others must be.
                        None => {
                            prev_stream_scan = Some((
                                if is_snapshot_backfill {
                                    Some(SnapshotBackfillInfo {
                                        upstream_mv_table_id_to_backfill_epoch: HashMap::from_iter(
                                            [(
                                                stream_scan.table_id,
                                                stream_scan.snapshot_backfill_epoch,
                                            )],
                                        ),
                                    })
                                } else {
                                    None
                                },
                                *stream_scan.clone(),
                            ));
                            true
                        }
                    }
                } else {
                    true
                }
            })
        }
        result.map(|_| {
            (
                prev_stream_scan
                    .map(|(snapshot_backfill_info, _)| snapshot_backfill_info)
                    .unwrap_or(None),
                cross_db_info,
            )
        })
    }
1303
1304    /// Collect the mapping from table / `source_id` -> `fragment_id`
1305    pub fn collect_backfill_mapping(
1306        fragments: impl Iterator<Item = (FragmentId, FragmentTypeMask, &PbStreamNode)>,
1307    ) -> HashMap<RelationId, Vec<FragmentId>> {
1308        let mut mapping = HashMap::new();
1309        for (fragment_id, fragment_type_mask, node) in fragments {
1310            let has_some_scan = fragment_type_mask
1311                .contains_any([FragmentTypeFlag::StreamScan, FragmentTypeFlag::SourceScan]);
1312            if has_some_scan {
1313                visit_stream_node_cont(node, |node| {
1314                    match node.node_body.as_ref() {
1315                        Some(NodeBody::StreamScan(stream_scan)) => {
1316                            let table_id = stream_scan.table_id;
1317                            let fragments: &mut Vec<_> =
1318                                mapping.entry(table_id.as_relation_id()).or_default();
1319                            fragments.push(fragment_id);
1320                            // each fragment should have only 1 scan node.
1321                            false
1322                        }
1323                        Some(NodeBody::SourceBackfill(source_backfill)) => {
1324                            let source_id = source_backfill.upstream_source_id;
1325                            let fragments: &mut Vec<_> =
1326                                mapping.entry(source_id.as_relation_id()).or_default();
1327                            fragments.push(fragment_id);
1328                            // each fragment should have only 1 scan node.
1329                            false
1330                        }
1331                        _ => true,
1332                    }
1333                })
1334            }
1335        }
1336        mapping
1337    }
1338
1339    /// Initially the mapping that comes from frontend is between `table_ids`.
1340    /// We should remap it to fragment level, since we track progress by actor, and we can get
1341    /// a fragment <-> actor mapping
1342    pub fn create_fragment_backfill_ordering(&self) -> UserDefinedFragmentBackfillOrder {
1343        let mapping =
1344            Self::collect_backfill_mapping(self.fragments.iter().map(|(fragment_id, fragment)| {
1345                (
1346                    fragment_id.as_global_id(),
1347                    fragment.fragment_type_mask.into(),
1348                    fragment.node.as_ref().expect("should exist node"),
1349                )
1350            }));
1351        let mut fragment_ordering: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
1352
1353        // 1. Add backfill dependencies
1354        for (rel_id, downstream_rel_ids) in &self.backfill_order.order {
1355            let fragment_ids = mapping.get(rel_id).unwrap();
1356            for fragment_id in fragment_ids {
1357                let downstream_fragment_ids = downstream_rel_ids
1358                    .data
1359                    .iter()
1360                    .flat_map(|&downstream_rel_id| mapping.get(&downstream_rel_id).unwrap().iter())
1361                    .copied()
1362                    .collect();
1363                fragment_ordering.insert(*fragment_id, downstream_fragment_ids);
1364            }
1365        }
1366
1367        UserDefinedFragmentBackfillOrder {
1368            inner: fragment_ordering,
1369        }
1370    }
1371
1372    pub fn extend_fragment_backfill_ordering_with_locality_backfill<
1373        'a,
1374        FI: Iterator<Item = (FragmentId, FragmentTypeMask, &'a PbStreamNode)> + 'a,
1375    >(
1376        fragment_ordering: UserDefinedFragmentBackfillOrder,
1377        fragment_downstreams: &FragmentDownstreamRelation,
1378        get_fragments: impl Fn() -> FI,
1379    ) -> ExtendedFragmentBackfillOrder {
1380        let mut fragment_ordering = fragment_ordering.inner;
1381        let mapping = Self::collect_backfill_mapping(get_fragments());
1382        // If no backfill order is specified, we still need to ensure that all backfill fragments
1383        // run before LocalityProvider fragments.
1384        if fragment_ordering.is_empty() {
1385            for value in mapping.values() {
1386                for &fragment_id in value {
1387                    fragment_ordering.entry(fragment_id).or_default();
1388                }
1389            }
1390        }
1391
1392        // 2. Add dependencies: all backfill fragments should run before LocalityProvider fragments
1393        let locality_provider_dependencies = Self::find_locality_provider_dependencies(
1394            get_fragments().map(|(fragment_id, _, node)| (fragment_id, node)),
1395            fragment_downstreams,
1396        );
1397
1398        let backfill_fragments: HashSet<FragmentId> = mapping.values().flatten().copied().collect();
1399
1400        // Calculate LocalityProvider root fragments (zero indegree)
1401        // Root fragments are those that appear as keys but never appear as downstream dependencies
1402        let all_locality_provider_fragments: HashSet<FragmentId> =
1403            locality_provider_dependencies.keys().copied().collect();
1404        let downstream_locality_provider_fragments: HashSet<FragmentId> =
1405            locality_provider_dependencies
1406                .values()
1407                .flatten()
1408                .copied()
1409                .collect();
1410        let locality_provider_root_fragments: Vec<FragmentId> = all_locality_provider_fragments
1411            .difference(&downstream_locality_provider_fragments)
1412            .copied()
1413            .collect();
1414
1415        // For each backfill fragment, add only the root LocalityProvider fragments as dependents
1416        // This ensures backfill completes before any LocalityProvider starts, while minimizing dependencies
1417        for &backfill_fragment_id in &backfill_fragments {
1418            fragment_ordering
1419                .entry(backfill_fragment_id)
1420                .or_default()
1421                .extend(locality_provider_root_fragments.iter().copied());
1422        }
1423
1424        // 3. Add LocalityProvider internal dependencies
1425        for (fragment_id, downstream_fragments) in locality_provider_dependencies {
1426            fragment_ordering
1427                .entry(fragment_id)
1428                .or_default()
1429                .extend(downstream_fragments);
1430        }
1431
1432        // Deduplicate downstream entries per fragment; overlaps are common when the same fragment
1433        // is reached via multiple paths (e.g., with StreamShare) and would otherwise appear
1434        // multiple times.
1435        for downstream in fragment_ordering.values_mut() {
1436            let mut seen = HashSet::new();
1437            downstream.retain(|id| seen.insert(*id));
1438        }
1439
1440        ExtendedFragmentBackfillOrder {
1441            inner: fragment_ordering,
1442        }
1443    }
1444
1445    pub fn find_locality_provider_fragment_state_table_mapping(
1446        &self,
1447    ) -> HashMap<FragmentId, Vec<TableId>> {
1448        let mut mapping: HashMap<FragmentId, Vec<TableId>> = HashMap::new();
1449
1450        for (fragment_id, fragment) in &self.fragments {
1451            let fragment_id = fragment_id.as_global_id();
1452
1453            // Check if this fragment contains a LocalityProvider node
1454            if let Some(node) = fragment.node.as_ref() {
1455                let mut state_table_ids = Vec::new();
1456
1457                visit_stream_node_cont(node, |stream_node| {
1458                    if let Some(NodeBody::LocalityProvider(locality_provider)) =
1459                        stream_node.node_body.as_ref()
1460                    {
1461                        // Collect state table ID (except the progress table)
1462                        let state_table_id = locality_provider
1463                            .state_table
1464                            .as_ref()
1465                            .expect("must have state table")
1466                            .id;
1467                        state_table_ids.push(state_table_id);
1468                        false // Stop visiting once we find a LocalityProvider
1469                    } else {
1470                        true // Continue visiting
1471                    }
1472                });
1473
1474                if !state_table_ids.is_empty() {
1475                    mapping.insert(fragment_id, state_table_ids);
1476                }
1477            }
1478        }
1479
1480        mapping
1481    }
1482
1483    /// Find dependency relationships among fragments containing `LocalityProvider` nodes.
1484    /// Returns a mapping where each fragment ID maps to a list of fragment IDs that should be processed after it.
1485    /// Following the same semantics as `FragmentBackfillOrder`:
1486    /// `G[10] -> [1, 2, 11]` means `LocalityProvider` in fragment 10 should be processed
1487    /// before `LocalityProviders` in fragments 1, 2, and 11.
1488    ///
1489    /// This method assumes each fragment contains at most one `LocalityProvider` node.
1490    pub fn find_locality_provider_dependencies<'a>(
1491        fragments_nodes: impl Iterator<Item = (FragmentId, &'a PbStreamNode)>,
1492        fragment_downstreams: &FragmentDownstreamRelation,
1493    ) -> HashMap<FragmentId, Vec<FragmentId>> {
1494        let mut locality_provider_fragments = HashSet::new();
1495        let mut dependencies: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
1496
1497        // First, identify all fragments that contain LocalityProvider nodes
1498        for (fragment_id, node) in fragments_nodes {
1499            let has_locality_provider = Self::fragment_has_locality_provider(node);
1500
1501            if has_locality_provider {
1502                locality_provider_fragments.insert(fragment_id);
1503                dependencies.entry(fragment_id).or_default();
1504            }
1505        }
1506
1507        // Build dependency relationships between LocalityProvider fragments
1508        // For each LocalityProvider fragment, find all downstream LocalityProvider fragments
1509        // The upstream fragment should be processed before the downstream fragments
1510        for &provider_fragment_id in &locality_provider_fragments {
1511            // Find all fragments downstream from this LocalityProvider fragment
1512            let mut visited = HashSet::new();
1513            let mut downstream_locality_providers = Vec::new();
1514
1515            Self::collect_downstream_locality_providers(
1516                provider_fragment_id,
1517                &locality_provider_fragments,
1518                fragment_downstreams,
1519                &mut visited,
1520                &mut downstream_locality_providers,
1521            );
1522
1523            // This fragment should be processed before all its downstream LocalityProvider fragments
1524            dependencies
1525                .entry(provider_fragment_id)
1526                .or_default()
1527                .extend(downstream_locality_providers);
1528        }
1529
1530        dependencies
1531    }
1532
1533    fn fragment_has_locality_provider(node: &PbStreamNode) -> bool {
1534        let mut has_locality_provider = false;
1535
1536        {
1537            visit_stream_node_cont(node, |stream_node| {
1538                if let Some(NodeBody::LocalityProvider(_)) = stream_node.node_body.as_ref() {
1539                    has_locality_provider = true;
1540                    false // Stop visiting once we find a LocalityProvider
1541                } else {
1542                    true // Continue visiting
1543                }
1544            });
1545        }
1546
1547        has_locality_provider
1548    }
1549
1550    /// Recursively collect downstream `LocalityProvider` fragments
1551    fn collect_downstream_locality_providers(
1552        current_fragment_id: FragmentId,
1553        locality_provider_fragments: &HashSet<FragmentId>,
1554        fragment_downstreams: &FragmentDownstreamRelation,
1555        visited: &mut HashSet<FragmentId>,
1556        downstream_providers: &mut Vec<FragmentId>,
1557    ) {
1558        if visited.contains(&current_fragment_id) {
1559            return;
1560        }
1561        visited.insert(current_fragment_id);
1562
1563        // Check all downstream fragments
1564        for downstream_fragment_id in fragment_downstreams
1565            .get(&current_fragment_id)
1566            .into_iter()
1567            .flat_map(|downstreams| {
1568                downstreams
1569                    .iter()
1570                    .map(|downstream| downstream.downstream_fragment_id)
1571            })
1572        {
1573            // If the downstream fragment is a LocalityProvider, add it to results
1574            if locality_provider_fragments.contains(&downstream_fragment_id) {
1575                downstream_providers.push(downstream_fragment_id);
1576            }
1577
1578            // Recursively check further downstream
1579            Self::collect_downstream_locality_providers(
1580                downstream_fragment_id,
1581                locality_provider_fragments,
1582                fragment_downstreams,
1583                visited,
1584                downstream_providers,
1585            );
1586        }
1587    }
1588}
1589
1590/// Fill snapshot epoch for `StreamScanNode` of `SnapshotBackfill`.
1591/// Return `true` when has change applied.
1592pub fn fill_snapshot_backfill_epoch(
1593    node: &mut StreamNode,
1594    snapshot_backfill_info: Option<&SnapshotBackfillInfo>,
1595    cross_db_snapshot_backfill_info: &SnapshotBackfillInfo,
1596) -> MetaResult<bool> {
1597    let mut result = Ok(());
1598    let mut applied = false;
1599    visit_stream_node_cont_mut(node, |node| {
1600        if let Some(NodeBody::StreamScan(stream_scan)) = node.node_body.as_mut()
1601            && (stream_scan.stream_scan_type == StreamScanType::SnapshotBackfill as i32
1602                || stream_scan.stream_scan_type == StreamScanType::CrossDbSnapshotBackfill as i32)
1603        {
1604            result = try {
1605                let table_id = stream_scan.table_id;
1606                let snapshot_epoch = cross_db_snapshot_backfill_info
1607                    .upstream_mv_table_id_to_backfill_epoch
1608                    .get(&table_id)
1609                    .or_else(|| {
1610                        snapshot_backfill_info.and_then(|snapshot_backfill_info| {
1611                            snapshot_backfill_info
1612                                .upstream_mv_table_id_to_backfill_epoch
1613                                .get(&table_id)
1614                        })
1615                    })
1616                    .ok_or_else(|| anyhow!("upstream table id not covered: {}", table_id))?
1617                    .ok_or_else(|| anyhow!("upstream table id not set: {}", table_id))?;
1618                if let Some(prev_snapshot_epoch) =
1619                    stream_scan.snapshot_backfill_epoch.replace(snapshot_epoch)
1620                {
1621                    Err(anyhow!(
1622                        "snapshot backfill epoch set again: {} {} {}",
1623                        table_id,
1624                        prev_snapshot_epoch,
1625                        snapshot_epoch
1626                    ))?;
1627                }
1628                applied = true;
1629            };
1630            result.is_ok()
1631        } else {
1632            true
1633        }
1634    });
1635    result.map(|_| applied)
1636}
1637
/// A lazily-initialized, always-empty edge map with `'static` lifetime.
///
/// NOTE(review): presumably handed out by reference when a fragment has no recorded edges,
/// so that lookups can return a `&HashMap` without allocating. The usage is outside this
/// part of the file; confirm before relying on it.
static EMPTY_HASHMAP: LazyLock<HashMap<GlobalFragmentId, StreamFragmentEdge>> =
    LazyLock::new(HashMap::new);
1640
/// A fragment that is either being built or already exists. Used to generalize the logic of
/// [`crate::stream::ActorGraphBuilder`].
#[derive(Debug, Clone, EnumAsInner)]
pub(super) enum EitherFragment {
    /// An internal fragment that is being built for the current streaming job.
    Building(BuildingFragment),

    /// An existing fragment that is external but connected to the fragments being built.
    /// No fragment data is carried by this variant.
    Existing,
}
1651
/// A wrapper of [`StreamFragmentGraph`] that contains the additional information of pre-existing
/// fragments, which are connected to the graph's top-most or bottom-most fragments.
///
/// For example,
/// - if we're going to build a mview on an existing mview, the upstream fragment containing the
///   `Materialize` node will be included in this structure.
/// - if we're going to replace the plan of a table with downstream mviews, the downstream fragments
///   containing the `StreamScan` nodes will be included in this structure.
#[derive(Debug)]
pub struct CompleteStreamFragmentGraph {
    /// The fragment graph of the streaming job being built.
    building_graph: StreamFragmentGraph,

    /// The required information of existing fragments, keyed by their global fragment ID.
    existing_fragments: HashMap<GlobalFragmentId, Fragment>,

    /// Extra edges between existing fragments and the building fragments, indexed by the
    /// upstream side: upstream fragment ID -> (downstream fragment ID -> edge).
    extra_downstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,

    /// The same extra edges as in `extra_downstreams`, indexed by the downstream side:
    /// downstream fragment ID -> (upstream fragment ID -> edge).
    extra_upstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,
}
1674
/// Information about the existing upstream jobs that a newly built fragment graph connects to.
pub struct FragmentGraphUpstreamContext {
    /// Root fragment is the root of upstream stream graph, which can be a
    /// mview fragment or source fragment for cdc source job.
    /// Keyed by the ID of the upstream job.
    pub upstream_root_fragments: HashMap<JobId, Fragment>,
}
1680
/// Information about the existing downstream fragments that a replacement fragment graph
/// connects to.
pub struct FragmentGraphDownstreamContext {
    /// ID of the original root fragment being replaced. It is recorded as the
    /// `original_upstream_fragment_id` of the external edges to the downstream fragments.
    pub original_root_fragment_id: FragmentId,
    /// The existing downstream fragments, each paired with the dispatcher type that is used
    /// for the edge from the new root fragment to it.
    pub downstream_fragments: Vec<(DispatcherType, Fragment)>,
}
1685
1686impl CompleteStreamFragmentGraph {
1687    /// Create a new [`CompleteStreamFragmentGraph`] with empty existing fragments, i.e., there's no
1688    /// upstream mviews.
1689    #[cfg(test)]
1690    pub fn for_test(graph: StreamFragmentGraph) -> Self {
1691        Self {
1692            building_graph: graph,
1693            existing_fragments: Default::default(),
1694            extra_downstreams: Default::default(),
1695            extra_upstreams: Default::default(),
1696        }
1697    }
1698
1699    /// Create a new [`CompleteStreamFragmentGraph`] for newly created job (which has no downstreams).
1700    /// e.g., MV on MV and CDC/Source Table with the upstream existing
1701    /// `Materialize` or `Source` fragments.
1702    pub fn with_upstreams(
1703        graph: StreamFragmentGraph,
1704        upstream_context: FragmentGraphUpstreamContext,
1705        job_type: StreamingJobType,
1706    ) -> MetaResult<Self> {
1707        Self::build_helper(graph, Some(upstream_context), None, job_type)
1708    }
1709
1710    /// Create a new [`CompleteStreamFragmentGraph`] for replacing an existing table/source,
1711    /// with the downstream existing `StreamScan`/`StreamSourceScan` fragments.
1712    pub fn with_downstreams(
1713        graph: StreamFragmentGraph,
1714        downstream_context: FragmentGraphDownstreamContext,
1715        job_type: StreamingJobType,
1716    ) -> MetaResult<Self> {
1717        Self::build_helper(graph, None, Some(downstream_context), job_type)
1718    }
1719
1720    /// For replacing an existing table based on shared cdc source, which has both upstreams and downstreams.
1721    pub fn with_upstreams_and_downstreams(
1722        graph: StreamFragmentGraph,
1723        upstream_context: FragmentGraphUpstreamContext,
1724        downstream_context: FragmentGraphDownstreamContext,
1725        job_type: StreamingJobType,
1726    ) -> MetaResult<Self> {
1727        Self::build_helper(
1728            graph,
1729            Some(upstream_context),
1730            Some(downstream_context),
1731            job_type,
1732        )
1733    }
1734
    /// The core logic of building a [`CompleteStreamFragmentGraph`], i.e., adding extra upstream/downstream fragments.
    ///
    /// - With `upstream_ctx`: for every building fragment and every upstream job listed in its
    ///   `upstream_job_columns`, an external edge from the upstream job's root fragment to the
    ///   building fragment is created. The upstream root fragments become existing fragments.
    /// - With `downstream_ctx`: for every given downstream fragment, an external edge from the
    ///   building graph's table fragment to that downstream fragment is created. The
    ///   downstream fragments become existing fragments.
    ///
    /// Every edge is recorded twice: in `extra_downstreams` (keyed by the upstream side) and in
    /// `extra_upstreams` (keyed by the downstream side).
    ///
    /// # Errors
    ///
    /// - An upstream job referenced by a building fragment is missing from `upstream_ctx`.
    /// - A required column cannot be found in the upstream materialized view or source.
    /// - The upstream root fragment is neither a `Mview` nor a `Source` fragment.
    /// - `job_type` is not expected to have upstreams (when `upstream_ctx` is given) or is not
    ///   supported for replacement (when `downstream_ctx` is given).
    /// - No `StreamScan`/`SourceBackfill` node can be located in a downstream fragment.
    /// - A column referenced by a downstream fragment no longer exists in the new table/source.
    ///
    /// # Panics
    ///
    /// Panics if an expected node body is missing or of an unexpected kind, if a shared-cdc
    /// building fragment with upstream columns is not a `CdcFilter` fragment, if a downstream
    /// of a replaced source does not use `NoShuffle`, or if the same edge is inserted twice.
    fn build_helper(
        mut graph: StreamFragmentGraph,
        upstream_ctx: Option<FragmentGraphUpstreamContext>,
        downstream_ctx: Option<FragmentGraphDownstreamContext>,
        job_type: StreamingJobType,
    ) -> MetaResult<Self> {
        let mut extra_downstreams = HashMap::new();
        let mut extra_upstreams = HashMap::new();
        let mut existing_fragments = HashMap::new();

        // Part 1: connect existing upstream root fragments to the building fragments.
        if let Some(FragmentGraphUpstreamContext {
            upstream_root_fragments,
        }) = upstream_ctx
        {
            for (&id, fragment) in &mut graph.fragments {
                let uses_shuffled_backfill = fragment.has_shuffled_backfill();

                for (&upstream_job_id, required_columns) in &fragment.upstream_job_columns {
                    let upstream_fragment = upstream_root_fragments
                        .get(&upstream_job_id)
                        .context("upstream fragment not found")?;
                    let upstream_root_fragment_id =
                        GlobalFragmentId::new(upstream_fragment.fragment_id);

                    let edge = match job_type {
                        StreamingJobType::Table(TableJobType::SharedCdcSource) => {
                            // we traverse all fragments in the graph, and we should find out the
                            // CdcFilter fragment and add an edge between upstream source fragment and it.
                            assert_ne!(
                                (fragment.fragment_type_mask & FragmentTypeFlag::CdcFilter as u32),
                                0
                            );

                            tracing::debug!(
                                ?upstream_root_fragment_id,
                                ?required_columns,
                                identity = ?fragment.inner.get_node().unwrap().get_identity(),
                                current_frag_id=?id,
                                "CdcFilter with upstream source fragment"
                            );

                            StreamFragmentEdge {
                                id: EdgeId::UpstreamExternal {
                                    upstream_job_id,
                                    downstream_fragment_id: id,
                                },
                                // We always use `NoShuffle` for the exchange between the upstream
                                // `Source` and the downstream `StreamScan` of the new cdc table.
                                dispatch_strategy: DispatchStrategy {
                                    r#type: DispatcherType::NoShuffle as _,
                                    dist_key_indices: vec![], // not used for `NoShuffle`
                                    output_mapping: DispatchOutputMapping::identical(
                                        CDC_SOURCE_COLUMN_NUM as _,
                                    )
                                    .into(),
                                },
                            }
                        }

                        // handle MV on MV/Source
                        StreamingJobType::MaterializedView
                        | StreamingJobType::Sink
                        | StreamingJobType::Index => {
                            // Build the extra edges between the upstream `Materialize` and
                            // the downstream `StreamScan` of the new job.
                            if upstream_fragment
                                .fragment_type_mask
                                .contains(FragmentTypeFlag::Mview)
                            {
                                // Resolve the required output columns from the upstream materialized view.
                                let (dist_key_indices, output_mapping) = {
                                    let mview_node = upstream_fragment
                                        .nodes
                                        .get_node_body()
                                        .unwrap()
                                        .as_materialize()
                                        .unwrap();
                                    let all_columns = mview_node.column_descs();
                                    let dist_key_indices = mview_node.dist_key_indices();
                                    let output_mapping = gen_output_mapping(
                                        required_columns,
                                        &all_columns,
                                    )
                                    .context(
                                        "BUG: column not found in the upstream materialized view",
                                    )?;
                                    (dist_key_indices, output_mapping)
                                };
                                // The dispatcher type depends on whether the building fragment
                                // uses shuffled backfill.
                                let dispatch_strategy = mv_on_mv_dispatch_strategy(
                                    uses_shuffled_backfill,
                                    dist_key_indices,
                                    output_mapping,
                                );

                                StreamFragmentEdge {
                                    id: EdgeId::UpstreamExternal {
                                        upstream_job_id,
                                        downstream_fragment_id: id,
                                    },
                                    dispatch_strategy,
                                }
                            }
                            // Build the extra edges between the upstream `Source` and
                            // the downstream `SourceBackfill` of the new job.
                            else if upstream_fragment
                                .fragment_type_mask
                                .contains(FragmentTypeFlag::Source)
                            {
                                let output_mapping = {
                                    let source_node = upstream_fragment
                                        .nodes
                                        .get_node_body()
                                        .unwrap()
                                        .as_source()
                                        .unwrap();

                                    let all_columns = source_node.column_descs().unwrap();
                                    gen_output_mapping(required_columns, &all_columns).context(
                                        "BUG: column not found in the upstream source node",
                                    )?
                                };

                                StreamFragmentEdge {
                                    id: EdgeId::UpstreamExternal {
                                        upstream_job_id,
                                        downstream_fragment_id: id,
                                    },
                                    // We always use `NoShuffle` for the exchange between the upstream
                                    // `Source` and the downstream `StreamScan` of the new MV.
                                    dispatch_strategy: DispatchStrategy {
                                        r#type: DispatcherType::NoShuffle as _,
                                        dist_key_indices: vec![], // not used for `NoShuffle`
                                        output_mapping: Some(output_mapping),
                                    },
                                }
                            } else {
                                bail!(
                                    "the upstream fragment should be a MView or Source, got fragment type: {:b}",
                                    upstream_fragment.fragment_type_mask
                                )
                            }
                        }
                        // Any other source/table job is not expected to reference upstream jobs.
                        StreamingJobType::Source | StreamingJobType::Table(_) => {
                            bail!(
                                "the streaming job shouldn't have an upstream fragment, job_type: {:?}",
                                job_type
                            )
                        }
                    };

                    // put the edge into the extra edges
                    // (`try_insert` + `unwrap` panics if the same edge was already recorded)
                    extra_downstreams
                        .entry(upstream_root_fragment_id)
                        .or_insert_with(HashMap::new)
                        .try_insert(id, edge.clone())
                        .unwrap();
                    extra_upstreams
                        .entry(id)
                        .or_insert_with(HashMap::new)
                        .try_insert(upstream_root_fragment_id, edge)
                        .unwrap();
                }
            }

            // All given upstream root fragments are registered as existing fragments.
            existing_fragments.extend(
                upstream_root_fragments
                    .into_values()
                    .map(|f| (GlobalFragmentId::new(f.fragment_id), f)),
            );
        }

        // Part 2: connect the new table fragment to the existing downstream fragments.
        if let Some(FragmentGraphDownstreamContext {
            original_root_fragment_id,
            downstream_fragments,
        }) = downstream_ctx
        {
            let original_table_fragment_id = GlobalFragmentId::new(original_root_fragment_id);
            let table_fragment_id = GlobalFragmentId::new(graph.table_fragment_id());

            // Build the extra edges between the `Materialize` and the downstream `StreamScan` of the
            // existing materialized views.
            for (dispatcher_type, fragment) in &downstream_fragments {
                let id = GlobalFragmentId::new(fragment.fragment_id);

                // Similar to `extract_upstream_columns_except_cross_db_backfill`.
                let output_columns = {
                    let mut res = None;

                    // If several matching nodes are visited, the last one wins.
                    stream_graph_visitor::visit_stream_node_body(&fragment.nodes, |node_body| {
                        let columns = match node_body {
                            NodeBody::StreamScan(stream_scan) => stream_scan.upstream_columns(),
                            NodeBody::SourceBackfill(source_backfill) => {
                                // FIXME: only pass required columns instead of all columns here
                                source_backfill.column_descs()
                            }
                            _ => return,
                        };
                        res = Some(columns);
                    });

                    res.context("failed to locate downstream scan")?
                };

                let table_fragment = graph.fragments.get(&table_fragment_id).unwrap();
                let nodes = table_fragment.node.as_ref().unwrap();

                // Map the columns required by the downstream onto the columns of the new
                // table/source. A missing column means the replacement dropped it.
                let (dist_key_indices, output_mapping) = match job_type {
                    StreamingJobType::Table(_) | StreamingJobType::MaterializedView => {
                        let mview_node = nodes.get_node_body().unwrap().as_materialize().unwrap();
                        let all_columns = mview_node.column_descs();
                        let dist_key_indices = mview_node.dist_key_indices();
                        let output_mapping = gen_output_mapping(&output_columns, &all_columns)
                            .ok_or_else(|| {
                                MetaError::invalid_parameter(
                                    "unable to drop the column due to \
                                     being referenced by downstream materialized views or sinks",
                                )
                            })?;
                        (dist_key_indices, output_mapping)
                    }

                    StreamingJobType::Source => {
                        let source_node = nodes.get_node_body().unwrap().as_source().unwrap();
                        let all_columns = source_node.column_descs().unwrap();
                        let output_mapping = gen_output_mapping(&output_columns, &all_columns)
                            .ok_or_else(|| {
                                MetaError::invalid_parameter(
                                    "unable to drop the column due to \
                                     being referenced by downstream materialized views or sinks",
                                )
                            })?;
                        assert_eq!(*dispatcher_type, DispatcherType::NoShuffle);
                        (
                            vec![], // not used for `NoShuffle`
                            output_mapping,
                        )
                    }

                    _ => bail!("unsupported job type for replacement: {job_type:?}"),
                };

                // The edge ID refers to the original root fragment, while the edge itself is
                // stored under the ID of the new table fragment below.
                let edge = StreamFragmentEdge {
                    id: EdgeId::DownstreamExternal(DownstreamExternalEdgeId {
                        original_upstream_fragment_id: original_table_fragment_id,
                        downstream_fragment_id: id,
                    }),
                    dispatch_strategy: DispatchStrategy {
                        r#type: *dispatcher_type as i32,
                        output_mapping: Some(output_mapping),
                        dist_key_indices,
                    },
                };

                extra_downstreams
                    .entry(table_fragment_id)
                    .or_insert_with(HashMap::new)
                    .try_insert(id, edge.clone())
                    .unwrap();
                extra_upstreams
                    .entry(id)
                    .or_insert_with(HashMap::new)
                    .try_insert(table_fragment_id, edge)
                    .unwrap();
            }

            // All given downstream fragments are registered as existing fragments.
            existing_fragments.extend(
                downstream_fragments
                    .into_iter()
                    .map(|(_, f)| (GlobalFragmentId::new(f.fragment_id), f)),
            );
        }

        Ok(Self {
            building_graph: graph,
            existing_fragments,
            extra_downstreams,
            extra_upstreams,
        })
    }
2015}
2016
2017/// Generate the `output_mapping` for [`DispatchStrategy`] from given columns.
2018fn gen_output_mapping(
2019    required_columns: &[PbColumnDesc],
2020    upstream_columns: &[PbColumnDesc],
2021) -> Option<DispatchOutputMapping> {
2022    let len = required_columns.len();
2023    let mut indices = vec![0; len];
2024    let mut types = None;
2025
2026    for (i, r) in required_columns.iter().enumerate() {
2027        let (ui, u) = upstream_columns
2028            .iter()
2029            .find_position(|&u| u.column_id == r.column_id)?;
2030        indices[i] = ui as u32;
2031
2032        // Only if we encounter type change (`ALTER TABLE ALTER COLUMN TYPE`) will we generate a
2033        // non-empty `types`.
2034        if u.column_type != r.column_type {
2035            types.get_or_insert_with(|| vec![TypePair::default(); len])[i] = TypePair {
2036                upstream: u.column_type.clone(),
2037                downstream: r.column_type.clone(),
2038            };
2039        }
2040    }
2041
2042    // If there's no type change, indicate it by empty `types`.
2043    let types = types.unwrap_or(Vec::new());
2044
2045    Some(DispatchOutputMapping { indices, types })
2046}
2047
2048fn mv_on_mv_dispatch_strategy(
2049    uses_shuffled_backfill: bool,
2050    dist_key_indices: Vec<u32>,
2051    output_mapping: DispatchOutputMapping,
2052) -> DispatchStrategy {
2053    if uses_shuffled_backfill {
2054        if !dist_key_indices.is_empty() {
2055            DispatchStrategy {
2056                r#type: DispatcherType::Hash as _,
2057                dist_key_indices,
2058                output_mapping: Some(output_mapping),
2059            }
2060        } else {
2061            DispatchStrategy {
2062                r#type: DispatcherType::Simple as _,
2063                dist_key_indices: vec![], // empty for Simple
2064                output_mapping: Some(output_mapping),
2065            }
2066        }
2067    } else {
2068        DispatchStrategy {
2069            r#type: DispatcherType::NoShuffle as _,
2070            dist_key_indices: vec![], // not used for `NoShuffle`
2071            output_mapping: Some(output_mapping),
2072        }
2073    }
2074}
2075
2076impl CompleteStreamFragmentGraph {
2077    /// Returns **all** fragment IDs in the complete graph, including the ones that are not in the
2078    /// building graph.
2079    pub(super) fn all_fragment_ids(&self) -> impl Iterator<Item = GlobalFragmentId> + '_ {
2080        self.building_graph
2081            .fragments
2082            .keys()
2083            .chain(self.existing_fragments.keys())
2084            .copied()
2085    }
2086
2087    /// Returns an iterator of **all** edges in the complete graph, including the external edges.
2088    pub(super) fn all_edges(
2089        &self,
2090    ) -> impl Iterator<Item = (GlobalFragmentId, GlobalFragmentId, &StreamFragmentEdge)> + '_ {
2091        self.building_graph
2092            .downstreams
2093            .iter()
2094            .chain(self.extra_downstreams.iter())
2095            .flat_map(|(&from, tos)| tos.iter().map(move |(&to, edge)| (from, to, edge)))
2096    }
2097
2098    /// Returns the distribution of the existing fragments.
2099    pub(super) fn existing_distribution(&self) -> HashMap<GlobalFragmentId, Distribution> {
2100        self.existing_fragments
2101            .iter()
2102            .map(|(&id, f)| (id, Distribution::from_fragment(f)))
2103            .collect()
2104    }
2105
2106    /// Generate topological order of **all** fragments in this graph, including the ones that are
2107    /// not in the building graph. Returns error if the graph is not a DAG and topological sort can
2108    /// not be done.
2109    ///
2110    /// For MV on MV, the first fragment popped out from the heap will be the top-most node, or the
2111    /// `Sink` / `Materialize` in stream graph.
2112    pub(super) fn topo_order(&self) -> MetaResult<Vec<GlobalFragmentId>> {
2113        let mut topo = Vec::new();
2114        let mut downstream_cnts = HashMap::new();
2115
2116        // Iterate all fragments.
2117        for fragment_id in self.all_fragment_ids() {
2118            // Count how many downstreams we have for a given fragment.
2119            let downstream_cnt = self.get_downstreams(fragment_id).count();
2120            if downstream_cnt == 0 {
2121                topo.push(fragment_id);
2122            } else {
2123                downstream_cnts.insert(fragment_id, downstream_cnt);
2124            }
2125        }
2126
2127        let mut i = 0;
2128        while let Some(&fragment_id) = topo.get(i) {
2129            i += 1;
2130            // Find if we can process more fragments.
2131            for (upstream_job_id, _) in self.get_upstreams(fragment_id) {
2132                let downstream_cnt = downstream_cnts.get_mut(&upstream_job_id).unwrap();
2133                *downstream_cnt -= 1;
2134                if *downstream_cnt == 0 {
2135                    downstream_cnts.remove(&upstream_job_id);
2136                    topo.push(upstream_job_id);
2137                }
2138            }
2139        }
2140
2141        if !downstream_cnts.is_empty() {
2142            // There are fragments that are not processed yet.
2143            bail!("graph is not a DAG");
2144        }
2145
2146        Ok(topo)
2147    }
2148
2149    /// Seal a [`BuildingFragment`] from the graph into a [`Fragment`], which will be further used
2150    /// to build actors on the compute nodes and persist into meta store.
2151    pub(super) fn seal_fragment(
2152        &self,
2153        id: GlobalFragmentId,
2154        distribution: Distribution,
2155        stream_node: StreamNode,
2156    ) -> Fragment {
2157        let building_fragment = self.get_fragment(id).into_building().unwrap();
2158        let internal_tables = building_fragment.extract_internal_tables();
2159        let BuildingFragment {
2160            inner,
2161            job_id,
2162            upstream_job_columns: _,
2163        } = building_fragment;
2164
2165        let distribution_type = distribution.to_distribution_type();
2166        let vnode_count = distribution.vnode_count();
2167
2168        let materialized_fragment_id =
2169            if FragmentTypeMask::from(inner.fragment_type_mask).contains(FragmentTypeFlag::Mview) {
2170                job_id.map(JobId::as_mv_table_id)
2171            } else {
2172                None
2173            };
2174
2175        let vector_index_fragment_id =
2176            if inner.fragment_type_mask & FragmentTypeFlag::VectorIndexWrite as u32 != 0 {
2177                job_id.map(JobId::as_mv_table_id)
2178            } else {
2179                None
2180            };
2181
2182        let state_table_ids = internal_tables
2183            .iter()
2184            .map(|t| t.id)
2185            .chain(materialized_fragment_id)
2186            .chain(vector_index_fragment_id)
2187            .collect();
2188
2189        Fragment {
2190            fragment_id: inner.fragment_id,
2191            fragment_type_mask: inner.fragment_type_mask.into(),
2192            distribution_type,
2193            state_table_ids,
2194            maybe_vnode_count: VnodeCount::set(vnode_count).to_protobuf(),
2195            nodes: stream_node,
2196        }
2197    }
2198
2199    /// Get a fragment from the complete graph, which can be either a building fragment or an
2200    /// existing fragment.
2201    pub(super) fn get_fragment(&self, fragment_id: GlobalFragmentId) -> EitherFragment {
2202        if self.existing_fragments.contains_key(&fragment_id) {
2203            EitherFragment::Existing
2204        } else {
2205            EitherFragment::Building(
2206                self.building_graph
2207                    .fragments
2208                    .get(&fragment_id)
2209                    .unwrap()
2210                    .clone(),
2211            )
2212        }
2213    }
2214
2215    /// Get **all** downstreams of a fragment, including the ones that are not in the building
2216    /// graph.
2217    pub(super) fn get_downstreams(
2218        &self,
2219        fragment_id: GlobalFragmentId,
2220    ) -> impl Iterator<Item = (GlobalFragmentId, &StreamFragmentEdge)> {
2221        self.building_graph
2222            .get_downstreams(fragment_id)
2223            .iter()
2224            .chain(
2225                self.extra_downstreams
2226                    .get(&fragment_id)
2227                    .into_iter()
2228                    .flatten(),
2229            )
2230            .map(|(&id, edge)| (id, edge))
2231    }
2232
2233    /// Get **all** upstreams of a fragment, including the ones that are not in the building
2234    /// graph.
2235    pub(super) fn get_upstreams(
2236        &self,
2237        fragment_id: GlobalFragmentId,
2238    ) -> impl Iterator<Item = (GlobalFragmentId, &StreamFragmentEdge)> {
2239        self.building_graph
2240            .get_upstreams(fragment_id)
2241            .iter()
2242            .chain(self.extra_upstreams.get(&fragment_id).into_iter().flatten())
2243            .map(|(&id, edge)| (id, edge))
2244    }
2245
2246    /// Returns all building fragments in the graph.
2247    pub(super) fn building_fragments(&self) -> &HashMap<GlobalFragmentId, BuildingFragment> {
2248        &self.building_graph.fragments
2249    }
2250
2251    /// Returns all building fragments in the graph, mutable.
2252    pub(super) fn building_fragments_mut(
2253        &mut self,
2254    ) -> &mut HashMap<GlobalFragmentId, BuildingFragment> {
2255        &mut self.building_graph.fragments
2256    }
2257
2258    /// Get the expected vnode count of the building graph. See documentation of the field for more details.
2259    pub(super) fn max_parallelism(&self) -> usize {
2260        self.building_graph.max_parallelism()
2261    }
2262}
2263
2264#[cfg(test)]
2265mod tests {
2266    use risingwave_common::catalog::{ColumnDesc, ColumnId};
2267    use risingwave_common::types::DataType;
2268    use risingwave_pb::catalog::SinkType as PbSinkType;
2269    use risingwave_pb::meta::table_fragments::fragment::PbFragmentDistributionType;
2270    use risingwave_pb::plan_common::StorageTableDesc;
2271    use risingwave_pb::stream_plan::{
2272        BatchPlanNode, MergeNode, ProjectNode, SinkDesc, SinkLogStoreType, SinkNode, StreamNode,
2273        StreamScanNode, StreamScanType,
2274    };
2275
2276    use super::*;
2277
2278    fn make_column(name: &str, id: i32, data_type: DataType) -> ColumnCatalog {
2279        ColumnCatalog::visible(ColumnDesc::named(name, ColumnId::new(id), data_type))
2280    }
2281
2282    fn make_field(table_name: &str, column: &ColumnCatalog) -> risingwave_pb::plan_common::Field {
2283        Field::new(
2284            format!("{}.{}", table_name, column.column_desc.name),
2285            column.data_type().clone(),
2286        )
2287        .to_prost()
2288    }
2289
2290    fn make_input_ref(index: u32, data_type: &DataType) -> PbExprNode {
2291        PbExprNode {
2292            function_type: expr_node::Type::Unspecified as i32,
2293            return_type: Some(data_type.to_protobuf()),
2294            rex_node: Some(expr_node::RexNode::InputRef(index)),
2295        }
2296    }
2297
2298    fn make_stream_scan_node(
2299        table_name: &str,
2300        table_id: u32,
2301        columns: &[ColumnCatalog],
2302    ) -> StreamNode {
2303        let merge_node = StreamNode {
2304            node_body: Some(NodeBody::Merge(Box::new(MergeNode {
2305                upstream_fragment_id: 0.into(),
2306                ..Default::default()
2307            }))),
2308            fields: columns
2309                .iter()
2310                .map(|col| make_field(table_name, col))
2311                .collect(),
2312            ..Default::default()
2313        };
2314        let batch_plan_node = StreamNode {
2315            node_body: Some(NodeBody::BatchPlan(Box::new(BatchPlanNode {
2316                ..Default::default()
2317            }))),
2318            ..Default::default()
2319        };
2320        let stream_scan_node = StreamScanNode {
2321            table_id: table_id.into(),
2322            upstream_column_ids: columns.iter().map(|c| c.column_id().get_id()).collect(),
2323            output_indices: (0..columns.len()).map(|i| i as u32).collect(),
2324            stream_scan_type: StreamScanType::ArrangementBackfill as i32,
2325            table_desc: Some(StorageTableDesc {
2326                table_id: table_id.into(),
2327                columns: columns
2328                    .iter()
2329                    .map(|col| col.column_desc.to_protobuf())
2330                    .collect(),
2331                value_indices: (0..columns.len()).map(|i| i as u32).collect(),
2332                versioned: true,
2333                ..Default::default()
2334            }),
2335            ..Default::default()
2336        };
2337        StreamNode {
2338            node_body: Some(NodeBody::StreamScan(Box::new(stream_scan_node))),
2339            fields: columns
2340                .iter()
2341                .map(|col| make_field(table_name, col))
2342                .collect(),
2343            input: vec![merge_node, batch_plan_node],
2344            ..Default::default()
2345        }
2346    }
2347
2348    fn make_project_node(
2349        table_name: &str,
2350        columns: &[ColumnCatalog],
2351        input: StreamNode,
2352    ) -> StreamNode {
2353        let select_list = columns
2354            .iter()
2355            .enumerate()
2356            .map(|(i, col)| make_input_ref(i as u32, col.data_type()))
2357            .collect();
2358        StreamNode {
2359            node_body: Some(NodeBody::Project(Box::new(ProjectNode {
2360                select_list,
2361                ..Default::default()
2362            }))),
2363            fields: columns
2364                .iter()
2365                .map(|col| make_field(table_name, col))
2366                .collect(),
2367            input: vec![input],
2368            ..Default::default()
2369        }
2370    }
2371
2372    #[tokio::test]
2373    async fn test_rewrite_refresh_schema_sink_fragment_with_project() {
2374        let env = MetaSrvEnv::for_test().await;
2375        let id_gen_manager = env.id_gen_manager().as_ref();
2376
2377        let table_name = "t";
2378        let columns = vec![
2379            make_column("a", 1, DataType::Int64),
2380            make_column("b", 2, DataType::Int64),
2381        ];
2382        let new_column = make_column("c", 3, DataType::Varchar);
2383
2384        let mut upstream_columns = columns.clone();
2385        upstream_columns.push(new_column.clone());
2386        let upstream_table = PbTable {
2387            name: table_name.to_owned(),
2388            columns: upstream_columns
2389                .iter()
2390                .map(|col| col.to_protobuf())
2391                .collect(),
2392            ..Default::default()
2393        };
2394
2395        let sink = PbSink {
2396            columns: columns.iter().map(|col| col.to_protobuf()).collect(),
2397            sink_type: PbSinkType::AppendOnly as i32,
2398            ..Default::default()
2399        };
2400
2401        let sink_desc = SinkDesc {
2402            sink_type: PbSinkType::AppendOnly as i32,
2403            column_catalogs: sink.columns.clone(),
2404            ..Default::default()
2405        };
2406
2407        let stream_scan_node = make_stream_scan_node(table_name, 1, &columns);
2408        let project_node = make_project_node(table_name, &columns, stream_scan_node);
2409
2410        let log_store_table = PbTable {
2411            columns: columns
2412                .iter()
2413                .cloned()
2414                .map(|mut col| {
2415                    col.column_desc.name = format!("{}_{}", table_name, col.column_desc.name);
2416                    col.to_protobuf()
2417                })
2418                .collect(),
2419            value_indices: (0..columns.len()).map(|i| i as i32).collect(),
2420            ..Default::default()
2421        };
2422
2423        let original_fragment = Fragment {
2424            fragment_id: 1.into(),
2425            fragment_type_mask: FragmentTypeMask::default(),
2426            distribution_type: PbFragmentDistributionType::Single,
2427            state_table_ids: vec![],
2428            maybe_vnode_count: None,
2429            nodes: StreamNode {
2430                node_body: Some(NodeBody::Sink(Box::new(SinkNode {
2431                    sink_desc: Some(sink_desc),
2432                    table: Some(log_store_table),
2433                    ..Default::default()
2434                }))),
2435                fields: columns
2436                    .iter()
2437                    .map(|col| make_field(table_name, col))
2438                    .collect(),
2439                input: vec![project_node],
2440                ..Default::default()
2441            },
2442        };
2443
2444        let (new_fragment, _, _) = rewrite_refresh_schema_sink_fragment(
2445            &original_fragment,
2446            &sink,
2447            std::slice::from_ref(&new_column),
2448            &[],
2449            &upstream_table,
2450            7.into(),
2451            id_gen_manager,
2452        )
2453        .unwrap();
2454
2455        let sink_node = &new_fragment.nodes;
2456        let [project_node] = sink_node.input.as_slice() else {
2457            panic!("Sink has more than 1 input: {:?}", sink_node.input);
2458        };
2459        let PbNodeBody::Project(project_body) = project_node.node_body.as_ref().unwrap() else {
2460            panic!(
2461                "expect PbNodeBody::Project but got: {:?}",
2462                project_node.node_body
2463            );
2464        };
2465        assert_eq!(project_body.select_list.len(), columns.len() + 1);
2466        let last_expr = project_body.select_list.last().unwrap();
2467        assert!(
2468            matches!(last_expr.rex_node, Some(expr_node::RexNode::InputRef(idx)) if idx == columns.len() as u32)
2469        );
2470        assert_eq!(project_node.fields.len(), columns.len() + 1);
2471
2472        let [stream_scan_node] = project_node.input.as_slice() else {
2473            panic!("Project has more than 1 input: {:?}", project_node.input);
2474        };
2475        let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_ref().unwrap() else {
2476            panic!(
2477                "expect PbNodeBody::StreamScan but got: {:?}",
2478                stream_scan_node.node_body
2479            );
2480        };
2481        assert_eq!(
2482            scan.upstream_column_ids.last().copied(),
2483            Some(new_column.column_id().get_id())
2484        );
2485        assert_eq!(
2486            scan.output_indices.last().copied(),
2487            Some(columns.len() as u32)
2488        );
2489        assert_eq!(
2490            stream_scan_node.fields.last().unwrap().name,
2491            format!("{}.{}", table_name, new_column.column_desc.name)
2492        );
2493    }
2494
2495    #[tokio::test]
2496    async fn test_rewrite_refresh_schema_sink_fragment_drop_column_with_project() {
2497        let env = MetaSrvEnv::for_test().await;
2498        let id_gen_manager = env.id_gen_manager().as_ref();
2499
2500        let table_name = "t";
2501        let columns = vec![
2502            make_column("a", 1, DataType::Int64),
2503            make_column("b", 2, DataType::Int64),
2504            make_column("tmp", 3, DataType::Varchar),
2505        ];
2506        let removed_column = columns.last().unwrap().clone();
2507        let upstream_columns = columns[..2].to_vec();
2508
2509        let upstream_table = PbTable {
2510            name: table_name.to_owned(),
2511            columns: upstream_columns
2512                .iter()
2513                .map(|col| col.to_protobuf())
2514                .collect(),
2515            ..Default::default()
2516        };
2517
2518        let sink = PbSink {
2519            columns: columns.iter().map(|col| col.to_protobuf()).collect(),
2520            sink_type: PbSinkType::AppendOnly as i32,
2521            ..Default::default()
2522        };
2523
2524        let sink_desc = SinkDesc {
2525            sink_type: PbSinkType::AppendOnly as i32,
2526            column_catalogs: sink.columns.clone(),
2527            ..Default::default()
2528        };
2529
2530        let stream_scan_node = make_stream_scan_node(table_name, 1, &columns);
2531        let project_node = make_project_node(table_name, &columns, stream_scan_node);
2532
2533        let log_store_table = PbTable {
2534            columns: columns
2535                .iter()
2536                .cloned()
2537                .map(|mut col| {
2538                    col.column_desc.name = format!("{}_{}", table_name, col.column_desc.name);
2539                    col.to_protobuf()
2540                })
2541                .collect(),
2542            value_indices: (0..columns.len()).map(|i| i as i32).collect(),
2543            ..Default::default()
2544        };
2545
2546        let original_fragment = Fragment {
2547            fragment_id: 1.into(),
2548            fragment_type_mask: FragmentTypeMask::default(),
2549            distribution_type: PbFragmentDistributionType::Single,
2550            state_table_ids: vec![],
2551            maybe_vnode_count: None,
2552            nodes: StreamNode {
2553                node_body: Some(NodeBody::Sink(Box::new(SinkNode {
2554                    sink_desc: Some(sink_desc),
2555                    table: Some(log_store_table),
2556                    log_store_type: SinkLogStoreType::KvLogStore as i32,
2557                    ..Default::default()
2558                }))),
2559                fields: columns
2560                    .iter()
2561                    .map(|col| make_field(table_name, col))
2562                    .collect(),
2563                input: vec![project_node],
2564                ..Default::default()
2565            },
2566        };
2567
2568        let (new_fragment, new_schema, new_log_store_table) = rewrite_refresh_schema_sink_fragment(
2569            &original_fragment,
2570            &sink,
2571            &[],
2572            std::slice::from_ref(&removed_column),
2573            &upstream_table,
2574            7.into(),
2575            id_gen_manager,
2576        )
2577        .unwrap();
2578
2579        assert_eq!(new_schema.len(), 2);
2580        assert!(
2581            new_schema.iter().all(|col| {
2582                col.column_desc.as_ref().map(|desc| desc.name.as_str()) != Some("tmp")
2583            })
2584        );
2585
2586        let sink_node = &new_fragment.nodes;
2587        let [project_node] = sink_node.input.as_slice() else {
2588            panic!("Sink has more than 1 input: {:?}", sink_node.input);
2589        };
2590        let PbNodeBody::Project(project_body) = project_node.node_body.as_ref().unwrap() else {
2591            panic!(
2592                "expect PbNodeBody::Project but got: {:?}",
2593                project_node.node_body
2594            );
2595        };
2596        assert_eq!(project_body.select_list.len(), 2);
2597        assert!(project_node.fields.iter().all(|f| !f.name.contains("tmp")));
2598
2599        let [stream_scan_node] = project_node.input.as_slice() else {
2600            panic!("Project has more than 1 input: {:?}", project_node.input);
2601        };
2602        let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_ref().unwrap() else {
2603            panic!(
2604                "expect PbNodeBody::StreamScan but got: {:?}",
2605                stream_scan_node.node_body
2606            );
2607        };
2608        assert!(
2609            !scan
2610                .upstream_column_ids
2611                .iter()
2612                .any(|&id| id == removed_column.column_id().get_id())
2613        );
2614        assert!(
2615            stream_scan_node
2616                .fields
2617                .iter()
2618                .all(|f| !f.name.contains("tmp"))
2619        );
2620
2621        let new_log_store_table = new_log_store_table.expect("log store table should be updated");
2622        assert!(
2623            new_log_store_table.columns.iter().all(|col| !col
2624                .column_desc
2625                .as_ref()
2626                .unwrap()
2627                .name
2628                .contains("tmp"))
2629        );
2630        assert_eq!(
2631            new_log_store_table.value_indices,
2632            (0..new_log_store_table.columns.len() as i32).collect::<Vec<_>>()
2633        );
2634    }
2635}