1use std::collections::{BTreeMap, HashMap, HashSet};
16use std::num::NonZeroUsize;
17use std::ops::{Deref, DerefMut};
18use std::sync::LazyLock;
19
20use anyhow::{Context, anyhow};
21use enum_as_inner::EnumAsInner;
22use itertools::Itertools;
23use risingwave_common::bail;
24use risingwave_common::catalog::{
25 CDC_SOURCE_COLUMN_NUM, ColumnCatalog, ColumnId, Field, FragmentTypeFlag, FragmentTypeMask,
26 TableId, generate_internal_table_name_with_type,
27};
28use risingwave_common::hash::VnodeCount;
29use risingwave_common::id::JobId;
30use risingwave_common::util::iter_util::ZipEqFast;
31use risingwave_common::util::stream_graph_visitor::{
32 self, visit_stream_node_cont, visit_stream_node_cont_mut,
33};
34use risingwave_connector::sink::catalog::SinkType;
35use risingwave_meta_model::streaming_job::BackfillOrders;
36use risingwave_pb::catalog::{PbSink, PbTable, Table};
37use risingwave_pb::ddl_service::TableJobType;
38use risingwave_pb::expr::{ExprNode as PbExprNode, expr_node};
39use risingwave_pb::id::{RelationId, StreamNodeLocalOperatorId};
40use risingwave_pb::plan_common::{PbColumnCatalog, PbColumnDesc};
41use risingwave_pb::stream_plan::dispatch_output_mapping::TypePair;
42use risingwave_pb::stream_plan::stream_fragment_graph::{
43 Parallelism, StreamFragment, StreamFragmentEdge as StreamFragmentEdgeProto,
44};
45use risingwave_pb::stream_plan::stream_node::{NodeBody, PbNodeBody};
46use risingwave_pb::stream_plan::{
47 BackfillOrder, DispatchOutputMapping, DispatchStrategy, DispatcherType, PbStreamNode,
48 PbStreamScanType, StreamFragmentGraph as StreamFragmentGraphProto, StreamNode, StreamScanNode,
49 StreamScanType,
50};
51
52use crate::barrier::SnapshotBackfillInfo;
53use crate::controller::id::IdGeneratorManager;
54use crate::manager::{MetaSrvEnv, StreamingJob, StreamingJobType};
55use crate::model::{Fragment, FragmentDownstreamRelation, FragmentId};
56use crate::stream::stream_graph::id::{GlobalFragmentId, GlobalFragmentIdGen, GlobalTableIdGen};
57use crate::stream::stream_graph::schedule::Distribution;
58use crate::{MetaError, MetaResult};
59
/// A fragment in the process of being built for a streaming job.
///
/// Wraps the [`StreamFragment`] proto after global ids have been assigned,
/// together with extra information collected while filling job-specific
/// fields into its nodes (see `BuildingFragment::new`).
#[derive(Debug, Clone)]
pub(super) struct BuildingFragment {
    // The underlying fragment proto, with global fragment id and global
    // internal-table ids already filled in.
    inner: StreamFragment,

    // Id of the streaming job, set only when this fragment contains the node
    // that "owns" the job (materialize / sink / shared source / vector index
    // write) — see `fill_job`.
    job_id: Option<JobId>,

    // Output columns consumed from each upstream job by this fragment.
    // Cross-db snapshot backfill scans are deliberately excluded — see
    // `extract_upstream_columns_except_cross_db_backfill`.
    upstream_job_columns: HashMap<JobId, Vec<PbColumnDesc>>,
}
76
impl BuildingFragment {
    /// Create a new [`BuildingFragment`] from a proto fragment: assign the
    /// global fragment id, rewrite internal tables, fill job-related fields
    /// into the nodes, and collect the upstream job columns.
    fn new(
        id: GlobalFragmentId,
        fragment: StreamFragment,
        job: &StreamingJob,
        table_id_gen: GlobalTableIdGen,
    ) -> Self {
        // Replace the local fragment id with the freshly assigned global one.
        let mut fragment = StreamFragment {
            fragment_id: id.as_global_id(),
            ..fragment
        };

        Self::fill_internal_tables(&mut fragment, job, table_id_gen);

        // `fill_job` reports whether this fragment owns the job's main node.
        let job_id = Self::fill_job(&mut fragment, job).then(|| job.id());
        let upstream_job_columns =
            Self::extract_upstream_columns_except_cross_db_backfill(&fragment);

        Self {
            inner: fragment,
            job_id,
            upstream_job_columns,
        }
    }

    /// Collect all internal (state) tables contained in this fragment.
    ///
    /// The fragment is cloned because the visitor requires mutable access,
    /// even though we only read the tables here.
    fn extract_internal_tables(&self) -> Vec<Table> {
        let mut fragment = self.inner.clone();
        let mut tables = Vec::new();
        stream_graph_visitor::visit_internal_tables(&mut fragment, |table, _| {
            tables.push(table.clone());
        });
        tables
    }

    /// Rewrite every internal table of the fragment: assign the global table
    /// id and fill catalog metadata (name, schema/database, owner, fragment
    /// id, job id) derived from the owning job.
    fn fill_internal_tables(
        fragment: &mut StreamFragment,
        job: &StreamingJob,
        table_id_gen: GlobalTableIdGen,
    ) {
        let fragment_id = fragment.fragment_id;
        stream_graph_visitor::visit_internal_tables(fragment, |table, table_type_name| {
            // Local table id -> global table id.
            table.id = table_id_gen
                .to_global_id(table.id.as_raw_id())
                .as_global_id();
            table.schema_id = job.schema_id();
            table.database_id = job.database_id();
            table.name = generate_internal_table_name_with_type(
                &job.name(),
                fragment_id,
                table.id,
                table_type_name,
            );
            table.fragment_id = fragment_id;
            table.owner = job.owner();
            table.job_id = Some(job.id());
        });
    }

    /// Fill job-related fields into the nodes of the fragment.
    ///
    /// Returns `true` if this fragment contains the node that owns the job:
    /// a materialize, sink, shared source, or vector index write node.
    fn fill_job(fragment: &mut StreamFragment, job: &StreamingJob) -> bool {
        let job_id = job.id();
        let fragment_id = fragment.fragment_id;
        let mut has_job = false;

        stream_graph_visitor::visit_fragment_mut(fragment, |node_body| match node_body {
            NodeBody::Materialize(materialize_node) => {
                materialize_node.table_id = job_id.as_mv_table_id();

                // Attach the job's table catalog to the materialize node.
                let table = materialize_node.table.insert(job.table().unwrap().clone());
                table.fragment_id = fragment_id;
                // In release builds, overwrite the stored definition with the
                // job name. NOTE(review): presumably to avoid persisting the
                // full SQL text here — confirm with catalog handling.
                if cfg!(not(debug_assertions)) {
                    table.definition = job.name();
                }

                has_job = true;
            }
            NodeBody::Sink(sink_node) => {
                sink_node.sink_desc.as_mut().unwrap().id = job_id.as_sink_id();

                has_job = true;
            }
            NodeBody::Dml(dml_node) => {
                dml_node.table_id = job_id.as_mv_table_id();
                dml_node.table_version_id = job.table_version_id().unwrap();
            }
            NodeBody::StreamFsFetch(fs_fetch_node) => {
                // Tables with an associated connector source carry the
                // source id into the fetch node.
                if let StreamingJob::Table(table_source, _, _) = job
                    && let Some(node_inner) = fs_fetch_node.node_inner.as_mut()
                    && let Some(source) = table_source
                {
                    node_inner.source_id = source.id;
                    if let Some(id) = source.optional_associated_table_id {
                        node_inner.associated_table_id = Some(id.into());
                    }
                }
            }
            NodeBody::Source(source_node) => {
                match job {
                    StreamingJob::Table(source, _table, _table_job_type) => {
                        if let Some(source_inner) = source_node.source_inner.as_mut()
                            && let Some(source) = source
                        {
                            // A table's associated source has its own id,
                            // distinct from the job id.
                            debug_assert_ne!(source.id, job_id.as_raw_id());
                            source_inner.source_id = source.id;
                            if let Some(id) = source.optional_associated_table_id {
                                source_inner.associated_table_id = Some(id.into());
                            }
                        }
                    }
                    StreamingJob::Source(source) => {
                        // A standalone (shared) source job owns this node.
                        has_job = true;
                        if let Some(source_inner) = source_node.source_inner.as_mut() {
                            debug_assert_eq!(source.id, job_id.as_raw_id());
                            source_inner.source_id = source.id;
                            if let Some(id) = source.optional_associated_table_id {
                                source_inner.associated_table_id = Some(id.into());
                            }
                        }
                    }
                    _ => {}
                }
            }
            NodeBody::StreamCdcScan(node) => {
                if let Some(table_desc) = node.cdc_table_desc.as_mut() {
                    table_desc.table_id = job_id.as_mv_table_id();
                }
            }
            NodeBody::VectorIndexWrite(node) => {
                let table = node.table.as_mut().unwrap();
                table.id = job_id.as_mv_table_id();
                table.database_id = job.database_id();
                table.schema_id = job.schema_id();
                table.fragment_id = fragment_id;
                // Same release-build definition stripping as Materialize
                // above, expressed with a cfg attribute instead of `cfg!`.
                #[cfg(not(debug_assertions))]
                {
                    table.definition = job.name();
                }

                has_job = true;
            }
            _ => {}
        });

        has_job
    }

    /// Collect, per upstream job, the columns this fragment consumes from it.
    ///
    /// Cross-db snapshot backfill scans are skipped. CDC filter nodes
    /// contribute an empty column list. Panics (via `expect`) if the same
    /// upstream job appears twice in one fragment.
    fn extract_upstream_columns_except_cross_db_backfill(
        fragment: &StreamFragment,
    ) -> HashMap<JobId, Vec<PbColumnDesc>> {
        let mut table_columns = HashMap::new();

        stream_graph_visitor::visit_fragment(fragment, |node_body| {
            let (table_id, column_ids) = match node_body {
                NodeBody::StreamScan(stream_scan) => {
                    if stream_scan.get_stream_scan_type().unwrap()
                        == StreamScanType::CrossDbSnapshotBackfill
                    {
                        return;
                    }
                    (
                        stream_scan.table_id.as_job_id(),
                        stream_scan.upstream_columns(),
                    )
                }
                NodeBody::CdcFilter(cdc_filter) => (
                    cdc_filter.upstream_source_id.as_share_source_job_id(),
                    vec![],
                ),
                NodeBody::SourceBackfill(backfill) => (
                    backfill.upstream_source_id.as_share_source_job_id(),
                    backfill.column_descs(),
                ),
                _ => return,
            };
            table_columns
                .try_insert(table_id, column_ids)
                .expect("currently there should be no two same upstream tables in a fragment");
        });

        table_columns
    }

    /// Returns whether this fragment contains a `StreamScan` of type
    /// `ArrangementBackfill` or `SnapshotBackfill`.
    ///
    /// Traversal stops at the first matching node.
    pub fn has_shuffled_backfill(&self) -> bool {
        let stream_node = match self.inner.node.as_ref() {
            Some(node) => node,
            _ => return false,
        };
        let mut has_shuffled_backfill = false;
        let has_shuffled_backfill_mut_ref = &mut has_shuffled_backfill;
        visit_stream_node_cont(stream_node, |node| {
            let is_shuffled_backfill = if let Some(node) = &node.node_body
                && let Some(node) = node.as_stream_scan()
            {
                node.stream_scan_type == StreamScanType::ArrangementBackfill as i32
                    || node.stream_scan_type == StreamScanType::SnapshotBackfill as i32
            } else {
                false
            };
            if is_shuffled_backfill {
                *has_shuffled_backfill_mut_ref = true;
                // Returning `false` stops further traversal.
                false
            } else {
                true
            }
        });
        has_shuffled_backfill
    }
}
297
impl Deref for BuildingFragment {
    type Target = StreamFragment;

    /// Transparent read access to the wrapped [`StreamFragment`].
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
305
impl DerefMut for BuildingFragment {
    /// Transparent mutable access to the wrapped [`StreamFragment`].
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
311
/// Identifier of an edge between two fragments.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumAsInner)]
pub(super) enum EdgeId {
    /// An edge between two fragments of the same streaming job, as carried
    /// in the frontend-provided protobuf graph.
    Internal {
        // Link id taken from `StreamFragmentEdgeProto`.
        link_id: u64,
    },

    /// An edge from a fragment of an existing ("external") upstream job to a
    /// fragment of the job being built.
    UpstreamExternal {
        // The existing upstream streaming job.
        upstream_job_id: JobId,
        // The consuming fragment of the job being built.
        downstream_fragment_id: GlobalFragmentId,
    },

    /// An edge from a fragment of the job being built to a fragment of an
    /// existing ("external") downstream job.
    DownstreamExternal(DownstreamExternalEdgeId),
}
336
/// Identifier for an edge pointing to an existing downstream job's fragment.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(super) struct DownstreamExternalEdgeId {
    // The fragment of the job being built that originally produced the edge.
    pub(super) original_upstream_fragment_id: GlobalFragmentId,
    // The fragment of the existing downstream job that consumes it.
    pub(super) downstream_fragment_id: GlobalFragmentId,
}
344
/// An edge between two fragments in the stream graph.
#[derive(Debug, Clone)]
pub(super) struct StreamFragmentEdge {
    /// Identifier of this edge.
    pub id: EdgeId,

    /// How records are dispatched from the upstream to the downstream
    /// fragment along this edge.
    pub dispatch_strategy: DispatchStrategy,
}
356
357impl StreamFragmentEdge {
358 fn from_protobuf(edge: &StreamFragmentEdgeProto) -> Self {
359 Self {
360 id: EdgeId::Internal {
363 link_id: edge.link_id,
364 },
365 dispatch_strategy: edge.get_dispatch_strategy().unwrap().clone(),
366 }
367 }
368}
369
/// Clone `fragment` with a freshly generated global fragment id; every other
/// field is copied verbatim.
fn clone_fragment(fragment: &Fragment, id_generator_manager: &IdGeneratorManager) -> Fragment {
    // Reserve exactly one new id from the global fragment id generator.
    let fragment_id = GlobalFragmentIdGen::new(id_generator_manager, 1)
        .to_global_id(0)
        .as_global_id();
    Fragment {
        fragment_id,
        fragment_type_mask: fragment.fragment_type_mask,
        distribution_type: fragment.distribution_type,
        state_table_ids: fragment.state_table_ids.clone(),
        maybe_vnode_count: fragment.maybe_vnode_count,
        nodes: fragment.nodes.clone(),
    }
}
383
/// Validate that a sink's fragment graph has the shape required for automatic
/// schema refresh:
/// a single fragment of the form
/// `Sink <- [Project] <- StreamScan(ArrangementBackfill) <- [Merge, BatchPlan]`.
///
/// # Errors
/// Returns an error describing the first structural mismatch found.
///
/// # Panics
/// Panics if the sink node does not have exactly 1 input, or if the
/// `StreamScan` node does not have exactly 2 inputs.
pub fn check_sink_fragments_support_refresh_schema(
    fragments: &BTreeMap<FragmentId, Fragment>,
) -> MetaResult<()> {
    // Auto schema change is only supported for single-fragment sinks.
    if fragments.len() != 1 {
        return Err(anyhow!(
            "sink with auto schema change should have only 1 fragment, but got {:?}",
            fragments.len()
        )
        .into());
    }
    let (_, fragment) = fragments.first_key_value().expect("non-empty");
    // The fragment root must be the sink node itself.
    let sink_node = &fragment.nodes;
    let PbNodeBody::Sink(_) = sink_node.node_body.as_ref().unwrap() else {
        return Err(anyhow!("expect PbNodeBody::Sink but got: {:?}", sink_node.node_body).into());
    };
    let [stream_input_node] = sink_node.input.as_slice() else {
        panic!("Sink has more than 1 input: {:?}", sink_node.input);
    };
    // The sink's input is either the scan directly, or a Project on top of it.
    let stream_scan_node = match stream_input_node.node_body.as_ref().unwrap() {
        PbNodeBody::StreamScan(_) => stream_input_node,
        PbNodeBody::Project(_) => {
            let [stream_scan_node] = stream_input_node.input.as_slice() else {
                return Err(anyhow!(
                    "Project node must have exactly 1 input for auto schema change, but got {:?}",
                    stream_input_node.input.len()
                )
                .into());
            };
            stream_scan_node
        }
        _ => {
            return Err(anyhow!(
                "expect PbNodeBody::StreamScan or PbNodeBody::Project but got: {:?}",
                stream_input_node.node_body
            )
            .into());
        }
    };
    let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_ref().unwrap() else {
        return Err(anyhow!(
            "expect PbNodeBody::StreamScan but got: {:?}",
            stream_scan_node.node_body
        )
        .into());
    };
    // Only arrangement backfill scans can be rewritten for schema refresh.
    let stream_scan_type = PbStreamScanType::try_from(scan.stream_scan_type).unwrap();
    if stream_scan_type != PbStreamScanType::ArrangementBackfill {
        return Err(anyhow!(
            "unsupported stream_scan_type for auto refresh schema: {:?}",
            stream_scan_type
        )
        .into());
    }
    // A StreamScan has exactly two inputs: the upstream Merge and a BatchPlan.
    let [merge_node, _batch_plan_node] = stream_scan_node.input.as_slice() else {
        panic!(
            "the number of StreamScan inputs is not 2: {:?}",
            stream_scan_node.input
        );
    };
    let NodeBody::Merge(_) = merge_node.node_body.as_ref().unwrap() else {
        return Err(anyhow!(
            "expect PbNodeBody::Merge but got: {:?}",
            merge_node.node_body
        )
        .into());
    };
    Ok(())
}
452
/// Index remappings produced by `rewrite_stream_scan_and_merge`, used to fix
/// up nodes above the scan (e.g. `Project`) after its output changed.
struct ScanRewriteResult {
    // Old scan output position -> new scan output position. Positions whose
    // columns were removed have no entry.
    old_output_index_to_new_output_index: HashMap<u32, u32>,
    // Column id -> position in the new scan output.
    new_output_index_by_column_id: HashMap<ColumnId, u32>,
    // The rewritten output fields of the scan node.
    output_fields: Vec<risingwave_pb::plan_common::Field>,
}
459
460fn extend_sink_columns(
462 sink_columns: &mut Vec<PbColumnCatalog>,
463 new_columns: &[ColumnCatalog],
464 get_column_name: impl Fn(&String) -> String,
465) {
466 let next_column_id = sink_columns
467 .iter()
468 .map(|col| col.column_desc.as_ref().unwrap().column_id + 1)
469 .max()
470 .unwrap_or(1);
471 sink_columns.extend(new_columns.iter().enumerate().map(|(i, col)| {
472 let mut col = col.to_protobuf();
473 let column_desc = col.column_desc.as_mut().unwrap();
474 column_desc.column_id = next_column_id + (i as i32);
475 column_desc.name = get_column_name(&column_desc.name);
476 col
477 }));
478}
479
480fn build_new_sink_columns(
482 sink: &PbSink,
483 removed_column_ids: &HashSet<ColumnId>,
484 newly_added_columns: &[ColumnCatalog],
485) -> Vec<PbColumnCatalog> {
486 let mut columns: Vec<PbColumnCatalog> = sink
487 .columns
488 .iter()
489 .filter(|col| {
490 let column_id = ColumnId::new(col.column_desc.as_ref().unwrap().column_id as _);
491 !removed_column_ids.contains(&column_id)
492 })
493 .cloned()
494 .collect();
495 extend_sink_columns(&mut columns, newly_added_columns, |name| name.clone());
496 columns
497}
498
499fn rewrite_log_store_table(
501 log_store_table: &mut PbTable,
502 removed_log_store_column_names: &HashSet<String>,
503 newly_added_columns: &[ColumnCatalog],
504 upstream_table_name: &str,
505) {
506 log_store_table.columns.retain(|col| {
507 !removed_log_store_column_names.contains(&col.column_desc.as_ref().unwrap().name)
508 });
509 extend_sink_columns(&mut log_store_table.columns, newly_added_columns, |name| {
510 format!("{}_{}", upstream_table_name, name)
511 });
512 log_store_table.value_indices = (0..log_store_table.columns.len() as i32).collect();
513}
514
/// Rewrite a `StreamScan` node (and its upstream `Merge`) so that it scans
/// the refreshed schema of `upstream_table`: removed columns are dropped
/// from the upstream column list and output, newly added columns are
/// appended at the end of the output.
///
/// Returns the index remappings needed to fix up nodes above the scan.
///
/// # Errors
/// Fails when the node shapes don't match `StreamScan`/`Merge`, or when the
/// scan type is not `ArrangementBackfill`.
///
/// # Panics
/// Panics when the scan does not have exactly 2 inputs, or when a referenced
/// column id is missing from `upstream_table`.
fn rewrite_stream_scan_and_merge(
    stream_scan_node: &mut StreamNode,
    removed_column_ids: &HashSet<ColumnId>,
    newly_added_columns: &[ColumnCatalog],
    upstream_table: &PbTable,
    upstream_table_fragment_id: FragmentId,
) -> MetaResult<ScanRewriteResult> {
    let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_mut().unwrap() else {
        return Err(anyhow!(
            "expect PbNodeBody::StreamScan but got: {:?}",
            stream_scan_node.node_body
        )
        .into());
    };
    // StreamScan has exactly two inputs: the upstream Merge and a BatchPlan.
    let [merge_node, _batch_plan_node] = stream_scan_node.input.as_mut_slice() else {
        panic!(
            "the number of StreamScan inputs is not 2: {:?}",
            stream_scan_node.input
        );
    };
    let NodeBody::Merge(merge) = merge_node.node_body.as_mut().unwrap() else {
        return Err(anyhow!(
            "expect PbNodeBody::Merge but got: {:?}",
            merge_node.node_body
        )
        .into());
    };

    // Only arrangement backfill scans can be rewritten for schema refresh.
    let stream_scan_type = PbStreamScanType::try_from(scan.stream_scan_type).unwrap();
    if stream_scan_type != PbStreamScanType::ArrangementBackfill {
        return Err(anyhow!(
            "unsupported stream_scan_type for auto refresh schema: {:?}",
            stream_scan_type
        )
        .into());
    }

    // Index of the refreshed upstream schema, keyed by column id.
    let upstream_columns_by_id: HashMap<i32, PbColumnDesc> = upstream_table
        .columns
        .iter()
        .map(|col| {
            let desc = col.column_desc.as_ref().unwrap().clone();
            (desc.column_id, desc)
        })
        .collect();

    let old_upstream_column_ids = scan.upstream_column_ids.clone();
    let old_output_indices = scan.output_indices.clone();
    // Build the new upstream column list by dropping removed columns,
    // recording where each surviving old position moved to.
    let mut old_upstream_index_to_new_upstream_index = HashMap::new();
    let mut new_upstream_column_ids = Vec::new();
    for (old_idx, &column_id) in old_upstream_column_ids.iter().enumerate() {
        if !removed_column_ids.contains(&ColumnId::new(column_id as _)) {
            let new_idx = new_upstream_column_ids.len() as u32;
            old_upstream_index_to_new_upstream_index.insert(old_idx as u32, new_idx);
            new_upstream_column_ids.push(column_id);
        }
    }
    // Remap the output indices over surviving columns only...
    let mut new_output_indices = Vec::new();
    for old_output_index in &old_output_indices {
        if let Some(new_index) = old_upstream_index_to_new_upstream_index.get(old_output_index) {
            new_output_indices.push(*new_index);
        }
    }
    // ...then append the newly added columns at the end of both lists.
    for col in newly_added_columns {
        let new_index = new_upstream_column_ids.len() as u32;
        new_upstream_column_ids.push(col.column_id().get_id());
        new_output_indices.push(new_index);
    }

    // Column ids in new output order, used to build fields and remappings.
    let new_output_column_ids: Vec<i32> = new_output_indices
        .iter()
        .map(|&idx| new_upstream_column_ids[idx as usize])
        .collect();
    let mut new_output_index_by_column_id = HashMap::new();
    for (pos, &column_id) in new_output_column_ids.iter().enumerate() {
        new_output_index_by_column_id.insert(ColumnId::new(column_id as _), pos as u32);
    }
    // Old output position -> new output position (absent for removed columns).
    let mut old_output_index_to_new_output_index = HashMap::new();
    for (old_pos, old_output_index) in old_output_indices.iter().enumerate() {
        let column_id = old_upstream_column_ids[*old_output_index as usize];
        if let Some(new_pos) = new_output_index_by_column_id.get(&ColumnId::new(column_id as _)) {
            old_output_index_to_new_output_index.insert(old_pos as u32, *new_pos);
        }
    }

    // Apply the rewrite to the scan node itself.
    scan.arrangement_table = Some(upstream_table.clone());
    scan.upstream_column_ids = new_upstream_column_ids;
    scan.output_indices = new_output_indices;
    let table_desc = scan.table_desc.as_mut().unwrap();
    table_desc.columns = scan
        .upstream_column_ids
        .iter()
        .map(|column_id| {
            upstream_columns_by_id
                .get(column_id)
                .unwrap_or_else(|| panic!("upstream column id not found: {}", column_id))
                .clone()
        })
        .collect();

    // Output fields are qualified with the upstream table name.
    stream_scan_node.fields = new_output_column_ids
        .iter()
        .map(|column_id| {
            let col_desc = upstream_columns_by_id
                .get(column_id)
                .unwrap_or_else(|| panic!("upstream column id not found: {}", column_id));
            Field::new(
                format!("{}.{}", upstream_table.name, col_desc.name),
                col_desc.column_type.as_ref().unwrap().into(),
            )
            .to_prost()
        })
        .collect();
    stream_scan_node.identity = {
        let columns = stream_scan_node
            .fields
            .iter()
            .map(|col| &col.name)
            .join(", ");
        format!("StreamTableScan {{ table: t, columns: [{columns}] }}")
    };

    // The Merge carries the full (unprojected) upstream column list.
    merge_node.fields = scan
        .upstream_column_ids
        .iter()
        .map(|&column_id| {
            let col_desc = upstream_columns_by_id
                .get(&column_id)
                .unwrap_or_else(|| panic!("upstream column id not found: {}", column_id));
            Field::new(
                col_desc.name.clone(),
                col_desc.column_type.as_ref().unwrap().into(),
            )
            .to_prost()
        })
        .collect();
    merge.upstream_fragment_id = upstream_table_fragment_id;

    Ok(ScanRewriteResult {
        old_output_index_to_new_output_index,
        new_output_index_by_column_id,
        output_fields: stream_scan_node.fields.clone(),
    })
}
662
/// Rewrite a `Project` node on top of a rewritten scan: re-point `InputRef`
/// expressions at the scan's new output positions, drop expressions that
/// referenced removed columns, and append plain `InputRef`s for the newly
/// added columns.
///
/// # Errors
/// Fails when columns are being dropped and the select list contains
/// non-`InputRef` expressions (those cannot be remapped safely), or when a
/// newly added column is missing from the scan output.
fn rewrite_project_node(
    project_node: &mut StreamNode,
    scan_rewrite: &ScanRewriteResult,
    newly_added_columns: &[ColumnCatalog],
    removed_column_ids: &HashSet<ColumnId>,
    upstream_table_name: &str,
) -> MetaResult<()> {
    let PbNodeBody::Project(project_node_body) = project_node.node_body.as_mut().unwrap() else {
        return Err(anyhow!(
            "expect PbNodeBody::Project but got: {:?}",
            project_node.node_body
        )
        .into());
    };
    // Reject drop-column rewrites when any expression is not a plain InputRef.
    let has_non_input_ref = project_node_body
        .select_list
        .iter()
        .any(|expr| !matches!(expr.rex_node, Some(expr_node::RexNode::InputRef(_))));
    if has_non_input_ref && !removed_column_ids.is_empty() {
        return Err(anyhow!(
            "auto schema change with drop column only supports Project with InputRef"
        )
        .into());
    }

    let mut new_select_list = Vec::with_capacity(project_node_body.select_list.len());
    let mut new_project_fields = Vec::with_capacity(project_node.fields.len());
    for (index, expr) in project_node_body.select_list.iter().enumerate() {
        let mut new_expr = expr.clone();
        if let Some(expr_node::RexNode::InputRef(old_index)) = new_expr.rex_node {
            // An InputRef without a new index refers to a removed column:
            // drop it (and its field) from the projection.
            let Some(&new_index) = scan_rewrite
                .old_output_index_to_new_output_index
                .get(&old_index)
            else {
                continue;
            };
            new_expr.rex_node = Some(expr_node::RexNode::InputRef(new_index));
        } else if !removed_column_ids.is_empty() {
            return Err(anyhow!(
                "auto schema change with drop column only supports Project with InputRef"
            )
            .into());
        }
        new_select_list.push(new_expr);
        new_project_fields.push(project_node.fields[index].clone());
    }

    // Append a pass-through InputRef for every newly added column.
    for col in newly_added_columns {
        let Some(&new_index) = scan_rewrite
            .new_output_index_by_column_id
            .get(&col.column_id())
        else {
            return Err(anyhow!("new column id not found in scan output").into());
        };
        new_select_list.push(PbExprNode {
            function_type: expr_node::Type::Unspecified as i32,
            return_type: Some(col.data_type().to_protobuf()),
            rex_node: Some(expr_node::RexNode::InputRef(new_index)),
        });
        new_project_fields.push(
            Field::new(
                format!("{}.{}", upstream_table_name, col.column_desc.name),
                col.data_type().clone(),
            )
            .to_prost(),
        );
    }

    project_node_body.select_list = new_select_list;
    project_node.fields = new_project_fields;
    Ok(())
}
736
/// Rewrite a sink fragment for automatic schema refresh after its upstream
/// table's schema changed.
///
/// Clones the fragment under a new fragment id, rewrites the sink columns,
/// log store table, stream scan / merge, and (if present) the project node.
///
/// Returns the rewritten fragment, the new sink columns, and the rewritten
/// log store table (if the sink has one).
///
/// # Panics
/// Panics if the sink node does not have exactly 1 input. Other shape
/// mismatches surface as errors. The fragment shape is expected to have been
/// validated by `check_sink_fragments_support_refresh_schema`.
pub fn rewrite_refresh_schema_sink_fragment(
    original_sink_fragment: &Fragment,
    sink: &PbSink,
    newly_added_columns: &[ColumnCatalog],
    removed_columns: &[ColumnCatalog],
    upstream_table: &PbTable,
    upstream_table_fragment_id: FragmentId,
    id_generator_manager: &IdGeneratorManager,
) -> MetaResult<(Fragment, Vec<PbColumnCatalog>, Option<PbTable>)> {
    let removed_column_ids: HashSet<_> =
        removed_columns.iter().map(|col| col.column_id()).collect();
    // Log store columns are named `<upstream table>_<column>`.
    let removed_log_store_column_names: HashSet<_> = removed_columns
        .iter()
        .map(|col| format!("{}_{}", upstream_table.name, col.column_desc.name))
        .collect();
    let new_sink_columns = build_new_sink_columns(sink, &removed_column_ids, newly_added_columns);

    // Work on a clone with a freshly assigned fragment id.
    let mut new_sink_fragment = clone_fragment(original_sink_fragment, id_generator_manager);
    let sink_node = &mut new_sink_fragment.nodes;
    let PbNodeBody::Sink(sink_node_body) = sink_node.node_body.as_mut().unwrap() else {
        return Err(anyhow!("expect PbNodeBody::Sink but got: {:?}", sink_node.node_body).into());
    };
    let [stream_input_node] = sink_node.input.as_mut_slice() else {
        panic!("Sink has more than 1 input: {:?}", sink_node.input);
    };
    // The sink's input is either the scan directly, or a Project on top.
    let stream_input_body = stream_input_node.node_body.as_ref().unwrap();
    let stream_input_is_project = matches!(stream_input_body, PbNodeBody::Project(_));
    let stream_input_is_scan = matches!(stream_input_body, PbNodeBody::StreamScan(_));
    if !stream_input_is_project && !stream_input_is_scan {
        return Err(anyhow!(
            "expect PbNodeBody::StreamScan or PbNodeBody::Project but got: {:?}",
            stream_input_body
        )
        .into());
    }

    // Rebuild the human-readable identity string from the new column set.
    sink_node.identity = {
        let sink_type = SinkType::from_proto(sink.sink_type());
        let sink_type_str = sink_type.type_str();
        let column_names = new_sink_columns
            .iter()
            .map(|col| {
                ColumnCatalog::from(col.clone())
                    .name_with_hidden()
                    .to_string()
            })
            .join(", ");
        let downstream_pk = if !sink_type.is_append_only() {
            let downstream_pk = sink
                .downstream_pk
                .iter()
                .map(|i| &sink.columns[*i as usize].column_desc.as_ref().unwrap().name)
                .collect_vec();
            format!(", downstream_pk: {downstream_pk:?}")
        } else {
            "".to_owned()
        };
        format!("StreamSink {{ type: {sink_type_str}, columns: [{column_names}]{downstream_pk} }}")
    };
    // Rewrite the log store table if this sink has one.
    let new_log_store_table = if let Some(log_store_table) = &mut sink_node_body.table {
        rewrite_log_store_table(
            log_store_table,
            &removed_log_store_column_names,
            newly_added_columns,
            &upstream_table.name,
        );
        Some(log_store_table.clone())
    } else {
        None
    };
    sink_node_body.sink_desc.as_mut().unwrap().column_catalogs = new_sink_columns.clone();

    // Locate the scan node, descending through the Project if present.
    let stream_scan_node = if stream_input_is_project {
        let [stream_scan_node] = stream_input_node.input.as_mut_slice() else {
            return Err(anyhow!(
                "Project node must have exactly 1 input for auto schema change, but got {:?}",
                stream_input_node.input.len()
            )
            .into());
        };
        stream_scan_node
    } else {
        stream_input_node
    };
    let scan_rewrite = rewrite_stream_scan_and_merge(
        stream_scan_node,
        &removed_column_ids,
        newly_added_columns,
        upstream_table,
        upstream_table_fragment_id,
    )?;

    // Fix up the projection against the remapped scan output, and propagate
    // the resulting fields to the sink node.
    if stream_input_is_project {
        let [project_node] = sink_node.input.as_mut_slice() else {
            panic!("Sink has more than 1 input: {:?}", sink_node.input);
        };
        rewrite_project_node(
            project_node,
            &scan_rewrite,
            newly_added_columns,
            &removed_column_ids,
            &upstream_table.name,
        )?;
        sink_node.fields = project_node.fields.clone();
    } else {
        sink_node.fields = scan_rewrite.output_fields;
    }
    Ok((new_sink_fragment, new_sink_columns, new_log_store_table))
}
848
/// Backfill ordering constraints between fragments (fragment id -> dependent
/// fragment ids).
///
/// The `EXTENDED` const parameter distinguishes, at the type level, the
/// user-specified order from its extended form (see the type aliases); the
/// extension itself happens outside this struct.
#[derive(Clone, Debug, Default)]
pub struct FragmentBackfillOrder<const EXTENDED: bool> {
    // NOTE(review): assumes values are the fragments ordered after the key
    // fragment — confirm with the scheduler that consumes this.
    inner: HashMap<FragmentId, Vec<FragmentId>>,
}
857
impl<const EXTENDED: bool> Deref for FragmentBackfillOrder<EXTENDED> {
    type Target = HashMap<FragmentId, Vec<FragmentId>>;

    /// Read-only access to the underlying ordering map.
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
865
866impl UserDefinedFragmentBackfillOrder {
867 pub fn new(inner: HashMap<FragmentId, Vec<FragmentId>>) -> Self {
868 Self { inner }
869 }
870
871 pub fn merge(orders: impl Iterator<Item = Self>) -> Self {
872 Self {
873 inner: orders.flat_map(|order| order.inner).collect(),
874 }
875 }
876
877 pub fn to_meta_model(&self) -> BackfillOrders {
878 self.inner.clone().into()
879 }
880}
881
/// Backfill order as explicitly specified by the user.
pub type UserDefinedFragmentBackfillOrder = FragmentBackfillOrder<false>;
/// Backfill order in its extended form (produced from the user-defined order
/// elsewhere; the extension logic is not in this module section).
pub type ExtendedFragmentBackfillOrder = FragmentBackfillOrder<true>;
884
/// In-memory fragment graph of a streaming job during building, with global
/// fragment and internal-table ids already assigned (see `Self::new`).
#[derive(Default, Debug)]
pub struct StreamFragmentGraph {
    /// All fragments, keyed by global fragment id.
    pub(super) fragments: HashMap<GlobalFragmentId, BuildingFragment>,

    /// Downstream edge index: upstream fragment -> (downstream fragment -> edge).
    pub(super) downstreams:
        HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,

    /// Upstream edge index: downstream fragment -> (upstream fragment -> edge).
    pub(super) upstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,

    // Tables this job depends on, as reported by the frontend proto.
    dependent_table_ids: HashSet<TableId>,

    // Parallelism explicitly specified by the user, if any (non-zero).
    specified_parallelism: Option<NonZeroUsize>,
    // Parallelism to use during backfill, if explicitly specified (non-zero).
    specified_backfill_parallelism: Option<NonZeroUsize>,

    // Maximum parallelism allowed for this job.
    max_parallelism: usize,

    // Backfill ordering constraints from the frontend plan (empty when the
    // proto carried none).
    backfill_order: BackfillOrder,
}
927
928impl StreamFragmentGraph {
    /// Build a [`StreamFragmentGraph`] from the protobuf graph sent by the
    /// frontend, assigning global ids to all fragments and internal tables.
    ///
    /// # Errors
    /// Fails when a specified (backfill) parallelism is zero.
    ///
    /// # Panics
    /// Asserts that every reserved internal table id is actually used, and
    /// that no duplicate edge connects the same pair of fragments.
    pub fn new(
        env: &MetaSrvEnv,
        proto: StreamFragmentGraphProto,
        job: &StreamingJob,
    ) -> MetaResult<Self> {
        // Reserve contiguous ranges of global fragment/table ids up front.
        let fragment_id_gen =
            GlobalFragmentIdGen::new(env.id_gen_manager(), proto.fragments.len() as u64);
        let table_id_gen = GlobalTableIdGen::new(env.id_gen_manager(), proto.table_ids_cnt as u64);

        // Build the fragments, rewriting local ids to global ids and filling
        // job-specific fields along the way.
        let fragments: HashMap<_, _> = proto
            .fragments
            .into_iter()
            .map(|(id, fragment)| {
                let id = fragment_id_gen.to_global_id(id.as_raw_id());
                let fragment = BuildingFragment::new(id, fragment, job, table_id_gen);
                (id, fragment)
            })
            .collect();

        // Sanity check: internal table count must match what was reserved.
        assert_eq!(
            fragments
                .values()
                .map(|f| f.extract_internal_tables().len() as u32)
                .sum::<u32>(),
            proto.table_ids_cnt
        );

        // Build the bidirectional edge index. `try_insert(..).unwrap()`
        // rejects duplicate edges between the same pair of fragments.
        let mut downstreams = HashMap::new();
        let mut upstreams = HashMap::new();

        for edge in proto.edges {
            let upstream_id = fragment_id_gen.to_global_id(edge.upstream_id.as_raw_id());
            let downstream_id = fragment_id_gen.to_global_id(edge.downstream_id.as_raw_id());
            let edge = StreamFragmentEdge::from_protobuf(&edge);

            upstreams
                .entry(downstream_id)
                .or_insert_with(HashMap::new)
                .try_insert(upstream_id, edge.clone())
                .unwrap();
            downstreams
                .entry(upstream_id)
                .or_insert_with(HashMap::new)
                .try_insert(downstream_id, edge)
                .unwrap();
        }

        let dependent_table_ids = proto.dependent_table_ids.iter().copied().collect();

        // A specified parallelism of 0 is invalid and surfaces as an error.
        let specified_parallelism = if let Some(Parallelism { parallelism }) = proto.parallelism {
            Some(NonZeroUsize::new(parallelism as usize).context("parallelism should not be 0")?)
        } else {
            None
        };
        let specified_backfill_parallelism =
            if let Some(Parallelism { parallelism }) = proto.backfill_parallelism {
                Some(
                    NonZeroUsize::new(parallelism as usize)
                        .context("backfill parallelism should not be 0")?,
                )
            } else {
                None
            };

        let max_parallelism = proto.max_parallelism as usize;
        // An absent backfill order means "no ordering constraints".
        let backfill_order = proto.backfill_order.unwrap_or(BackfillOrder {
            order: Default::default(),
        });

        Ok(Self {
            fragments,
            downstreams,
            upstreams,
            dependent_table_ids,
            specified_parallelism,
            specified_backfill_parallelism,
            max_parallelism,
            backfill_order,
        })
    }
1018
1019 pub fn incomplete_internal_tables(&self) -> BTreeMap<TableId, Table> {
1025 let mut tables = BTreeMap::new();
1026 for fragment in self.fragments.values() {
1027 for table in fragment.extract_internal_tables() {
1028 let table_id = table.id;
1029 tables
1030 .try_insert(table_id, table)
1031 .unwrap_or_else(|_| panic!("duplicated table id `{}`", table_id));
1032 }
1033 }
1034 tables
1035 }
1036
1037 pub fn refill_internal_table_ids(&mut self, table_id_map: HashMap<TableId, TableId>) {
1040 for fragment in self.fragments.values_mut() {
1041 stream_graph_visitor::visit_internal_tables(
1042 &mut fragment.inner,
1043 |table, _table_type_name| {
1044 let target = table_id_map.get(&table.id).cloned().unwrap();
1045 table.id = target;
1046 },
1047 );
1048 }
1049 }
1050
    /// Replace the graph's (freshly generated) internal tables with
    /// `old_internal_tables`, pairing them "trivially": both sides are sorted
    /// by table id and matched positionally.
    ///
    /// This assumes sorting by id yields a correct correspondence between the
    /// new graph's tables and the old job's tables.
    ///
    /// # Errors
    /// Fails when the two sides have different numbers of internal tables.
    pub fn fit_internal_tables_trivial(
        &mut self,
        mut old_internal_tables: Vec<Table>,
    ) -> MetaResult<()> {
        let mut new_internal_table_ids = Vec::new();
        for fragment in self.fragments.values() {
            for table in &fragment.extract_internal_tables() {
                new_internal_table_ids.push(table.id);
            }
        }

        if new_internal_table_ids.len() != old_internal_tables.len() {
            bail!(
                "Different number of internal tables. New: {}, Old: {}",
                new_internal_table_ids.len(),
                old_internal_tables.len()
            );
        }
        // Pair new ids with old tables in ascending id order.
        old_internal_tables.sort_by(|a, b| a.id.cmp(&b.id));
        new_internal_table_ids.sort();

        let internal_table_id_map = new_internal_table_ids
            .into_iter()
            .zip_eq_fast(old_internal_tables.into_iter())
            .collect::<HashMap<_, _>>();

        // Overwrite each internal table in place with its matched old table.
        for fragment in self.fragments.values_mut() {
            stream_graph_visitor::visit_internal_tables(
                &mut fragment.inner,
                |table, _table_type_name| {
                    let target = internal_table_id_map.get(&table.id).cloned().unwrap();
                    *table = target;
                },
            );
        }

        Ok(())
    }
1094
1095 pub fn fit_internal_table_ids_with_mapping(&mut self, mut matches: HashMap<TableId, Table>) {
1097 for fragment in self.fragments.values_mut() {
1098 stream_graph_visitor::visit_internal_tables(
1099 &mut fragment.inner,
1100 |table, _table_type_name| {
1101 let target = matches.remove(&table.id).unwrap_or_else(|| {
1102 panic!("no matching table for table {}({})", table.id, table.name)
1103 });
1104 table.id = target.id;
1105 table.maybe_vnode_count = target.maybe_vnode_count;
1106 },
1107 );
1108 }
1109 }
1110
    /// Fill the snapshot backfill epoch into every (cross-db) snapshot
    /// backfill `StreamScan` node, looked up by the node's operator id.
    ///
    /// # Panics
    /// Panics when a snapshot backfill scan has no epoch in the provided map.
    pub fn fit_snapshot_backfill_epochs(
        &mut self,
        mut snapshot_backfill_epochs: HashMap<StreamNodeLocalOperatorId, u64>,
    ) {
        for fragment in self.fragments.values_mut() {
            visit_stream_node_cont_mut(fragment.node.as_mut().unwrap(), |node| {
                if let PbNodeBody::StreamScan(scan) = node.node_body.as_mut().unwrap()
                    && let StreamScanType::SnapshotBackfill
                    | StreamScanType::CrossDbSnapshotBackfill = scan.stream_scan_type()
                {
                    let Some(epoch) = snapshot_backfill_epochs.remove(&node.operator_id) else {
                        panic!("no snapshot epoch found for node {:?}", node)
                    };
                    scan.snapshot_backfill_epoch = Some(epoch);
                }
                // Always continue visiting the rest of the node tree.
                true
            })
        }
    }
1130
1131 pub fn table_fragment_id(&self) -> FragmentId {
1133 self.fragments
1134 .values()
1135 .filter(|b| b.job_id.is_some())
1136 .map(|b| b.fragment_id)
1137 .exactly_one()
1138 .expect("require exactly 1 materialize/sink/cdc source node when creating the streaming job")
1139 }
1140
1141 pub fn dml_fragment_id(&self) -> Option<FragmentId> {
1143 self.fragments
1144 .values()
1145 .filter(|b| {
1146 FragmentTypeMask::from(b.fragment_type_mask).contains(FragmentTypeFlag::Dml)
1147 })
1148 .map(|b| b.fragment_id)
1149 .at_most_one()
1150 .expect("require at most 1 dml node when creating the streaming job")
1151 }
1152
    /// Returns the ids of the tables this streaming job depends on.
    pub fn dependent_table_ids(&self) -> &HashSet<TableId> {
        &self.dependent_table_ids
    }
1157
    /// Returns the parallelism explicitly specified for the job, if any.
    pub fn specified_parallelism(&self) -> Option<NonZeroUsize> {
        self.specified_parallelism
    }
1162
    /// Returns the backfill parallelism explicitly specified for the job, if any.
    pub fn specified_backfill_parallelism(&self) -> Option<NonZeroUsize> {
        self.specified_backfill_parallelism
    }
1167
    /// Returns the maximum parallelism (i.e. expected vnode count) of the job.
    pub fn max_parallelism(&self) -> usize {
        self.max_parallelism
    }
1172
    /// Returns the downstream edges of the given fragment, or an empty map if it has none.
    fn get_downstreams(
        &self,
        fragment_id: GlobalFragmentId,
    ) -> &HashMap<GlobalFragmentId, StreamFragmentEdge> {
        self.downstreams.get(&fragment_id).unwrap_or(&EMPTY_HASHMAP)
    }
1180
    /// Returns the upstream edges of the given fragment, or an empty map if it has none.
    fn get_upstreams(
        &self,
        fragment_id: GlobalFragmentId,
    ) -> &HashMap<GlobalFragmentId, StreamFragmentEdge> {
        self.upstreams.get(&fragment_id).unwrap_or(&EMPTY_HASHMAP)
    }
1188
    /// Collects the snapshot-backfill information of all fragments in this graph.
    ///
    /// Returns `(same-db snapshot backfill info if used, cross-db snapshot backfill info)`.
    pub fn collect_snapshot_backfill_info(
        &self,
    ) -> MetaResult<(Option<SnapshotBackfillInfo>, SnapshotBackfillInfo)> {
        Self::collect_snapshot_backfill_info_impl(self.fragments.values().map(|fragment| {
            (
                fragment.node.as_ref().unwrap(),
                fragment.fragment_type_mask.into(),
            )
        }))
    }
1199
    /// Scans all stream-scan nodes of the given fragments and validates snapshot-backfill
    /// usage: either *all* non-cross-db scans use snapshot backfill, or *none* do.
    ///
    /// Cross-db snapshot-backfill scans are collected independently into the second
    /// returned value and never conflict with the same-db decision.
    pub fn collect_snapshot_backfill_info_impl(
        fragments: impl IntoIterator<Item = (&PbStreamNode, FragmentTypeMask)>,
    ) -> MetaResult<(Option<SnapshotBackfillInfo>, SnapshotBackfillInfo)> {
        // Remembers the first stream scan seen, together with whether it used snapshot
        // backfill; later scans must agree with it.
        let mut prev_stream_scan: Option<(Option<SnapshotBackfillInfo>, StreamScanNode)> = None;
        let mut cross_db_info = SnapshotBackfillInfo {
            upstream_mv_table_id_to_backfill_epoch: Default::default(),
        };
        let mut result = Ok(());
        for (node, fragment_type_mask) in fragments {
            visit_stream_node_cont(node, |node| {
                if let Some(NodeBody::StreamScan(stream_scan)) = node.node_body.as_ref() {
                    let stream_scan_type = StreamScanType::try_from(stream_scan.stream_scan_type)
                        .expect("invalid stream_scan_type");
                    let is_snapshot_backfill = match stream_scan_type {
                        StreamScanType::SnapshotBackfill => {
                            // The fragment flag must be consistent with the node type.
                            assert!(
                                fragment_type_mask
                                    .contains(FragmentTypeFlag::SnapshotBackfillStreamScan)
                            );
                            true
                        }
                        StreamScanType::CrossDbSnapshotBackfill => {
                            assert!(
                                fragment_type_mask
                                    .contains(FragmentTypeFlag::CrossDbSnapshotBackfillStreamScan)
                            );
                            // Cross-db scans are recorded separately and do not take part
                            // in the all-or-nothing check below.
                            cross_db_info
                                .upstream_mv_table_id_to_backfill_epoch
                                .insert(stream_scan.table_id, stream_scan.snapshot_backfill_epoch);

                            return true;
                        }
                        _ => false,
                    };

                    match &mut prev_stream_scan {
                        Some((prev_snapshot_backfill_info, prev_stream_scan)) => {
                            match (prev_snapshot_backfill_info, is_snapshot_backfill) {
                                // Consistent: all scans so far use snapshot backfill.
                                (Some(prev_snapshot_backfill_info), true) => {
                                    prev_snapshot_backfill_info
                                        .upstream_mv_table_id_to_backfill_epoch
                                        .insert(
                                            stream_scan.table_id,
                                            stream_scan.snapshot_backfill_epoch,
                                        );
                                    true
                                }
                                // Consistent: no scan uses snapshot backfill.
                                (None, false) => true,
                                // Mixed usage: record the error and stop visiting.
                                (_, _) => {
                                    result = Err(anyhow!("must be either all snapshot_backfill or no snapshot_backfill. Curr: {stream_scan:?} Prev: {prev_stream_scan:?}").into());
                                    false
                                }
                            }
                        }
                        None => {
                            // First stream scan encountered: it fixes the decision.
                            prev_stream_scan = Some((
                                if is_snapshot_backfill {
                                    Some(SnapshotBackfillInfo {
                                        upstream_mv_table_id_to_backfill_epoch: HashMap::from_iter(
                                            [(
                                                stream_scan.table_id,
                                                stream_scan.snapshot_backfill_epoch,
                                            )],
                                        ),
                                    })
                                } else {
                                    None
                                },
                                *stream_scan.clone(),
                            ));
                            true
                        }
                    }
                } else {
                    true
                }
            })
        }
        result.map(|_| {
            (
                prev_stream_scan
                    .map(|(snapshot_backfill_info, _)| snapshot_backfill_info)
                    .unwrap_or(None),
                cross_db_info,
            )
        })
    }
1288
    /// Builds a mapping from each upstream relation id (MV table or source) being
    /// backfilled to the fragments that perform that backfill.
    ///
    /// Only fragments flagged `StreamScan` or `SourceScan` are visited; returning `false`
    /// stops descending once the scan node of a fragment is found.
    pub fn collect_backfill_mapping(
        fragments: impl Iterator<Item = (FragmentId, FragmentTypeMask, &PbStreamNode)>,
    ) -> HashMap<RelationId, Vec<FragmentId>> {
        let mut mapping = HashMap::new();
        for (fragment_id, fragment_type_mask, node) in fragments {
            let has_some_scan = fragment_type_mask
                .contains_any([FragmentTypeFlag::StreamScan, FragmentTypeFlag::SourceScan]);
            if has_some_scan {
                visit_stream_node_cont(node, |node| {
                    match node.node_body.as_ref() {
                        Some(NodeBody::StreamScan(stream_scan)) => {
                            let table_id = stream_scan.table_id;
                            let fragments: &mut Vec<_> =
                                mapping.entry(table_id.as_relation_id()).or_default();
                            fragments.push(fragment_id);
                            // Scan found: stop descending into this fragment.
                            false
                        }
                        Some(NodeBody::SourceBackfill(source_backfill)) => {
                            let source_id = source_backfill.upstream_source_id;
                            let fragments: &mut Vec<_> =
                                mapping.entry(source_id.as_relation_id()).or_default();
                            fragments.push(fragment_id);
                            // Scan found: stop descending into this fragment.
                            false
                        }
                        _ => true,
                    }
                })
            }
        }
        mapping
    }
1323
    /// Translates the user-specified relation-level backfill order of this graph into a
    /// fragment-level order, using the backfill mapping of the graph's fragments.
    ///
    /// Panics if the backfill order references a relation with no backfill fragment.
    pub fn create_fragment_backfill_ordering(&self) -> UserDefinedFragmentBackfillOrder {
        let mapping =
            Self::collect_backfill_mapping(self.fragments.iter().map(|(fragment_id, fragment)| {
                (
                    fragment_id.as_global_id(),
                    fragment.fragment_type_mask.into(),
                    fragment.node.as_ref().expect("should exist node"),
                )
            }));
        let mut fragment_ordering: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();

        // Expand each relation-level edge into edges between all fragment pairs.
        for (rel_id, downstream_rel_ids) in &self.backfill_order.order {
            let fragment_ids = mapping.get(rel_id).unwrap();
            for fragment_id in fragment_ids {
                let downstream_fragment_ids = downstream_rel_ids
                    .data
                    .iter()
                    .flat_map(|&downstream_rel_id| mapping.get(&downstream_rel_id).unwrap().iter())
                    .copied()
                    .collect();
                fragment_ordering.insert(*fragment_id, downstream_fragment_ids);
            }
        }

        UserDefinedFragmentBackfillOrder {
            inner: fragment_ordering,
        }
    }
1356
    /// Extends a user-defined fragment backfill order with edges enforcing that all
    /// locality-provider fragments finish before any regular backfill starts, and that
    /// locality providers respect their own downstream dependencies.
    ///
    /// `get_fragments` is a factory since the fragment iterator is consumed twice.
    pub fn extend_fragment_backfill_ordering_with_locality_backfill<
        'a,
        FI: Iterator<Item = (FragmentId, FragmentTypeMask, &'a PbStreamNode)> + 'a,
    >(
        fragment_ordering: UserDefinedFragmentBackfillOrder,
        fragment_downstreams: &FragmentDownstreamRelation,
        get_fragments: impl Fn() -> FI,
    ) -> ExtendedFragmentBackfillOrder {
        let mut fragment_ordering = fragment_ordering.inner;
        let mapping = Self::collect_backfill_mapping(get_fragments());
        // With no user-defined order, seed every backfill fragment with an empty entry so
        // the locality edges below still apply to all of them.
        if fragment_ordering.is_empty() {
            for value in mapping.values() {
                for &fragment_id in value {
                    fragment_ordering.entry(fragment_id).or_default();
                }
            }
        }

        let locality_provider_dependencies = Self::find_locality_provider_dependencies(
            get_fragments().map(|(fragment_id, _, node)| (fragment_id, node)),
            fragment_downstreams,
        );

        let backfill_fragments: HashSet<FragmentId> = mapping.values().flatten().copied().collect();

        // Roots are locality providers that are not downstream of any other provider.
        let all_locality_provider_fragments: HashSet<FragmentId> =
            locality_provider_dependencies.keys().copied().collect();
        let downstream_locality_provider_fragments: HashSet<FragmentId> =
            locality_provider_dependencies
                .values()
                .flatten()
                .copied()
                .collect();
        let locality_provider_root_fragments: Vec<FragmentId> = all_locality_provider_fragments
            .difference(&downstream_locality_provider_fragments)
            .copied()
            .collect();

        // Every backfill fragment must wait for all locality-provider roots.
        for &backfill_fragment_id in &backfill_fragments {
            fragment_ordering
                .entry(backfill_fragment_id)
                .or_default()
                .extend(locality_provider_root_fragments.iter().copied());
        }

        // Preserve the internal ordering among locality providers themselves.
        for (fragment_id, downstream_fragments) in locality_provider_dependencies {
            fragment_ordering
                .entry(fragment_id)
                .or_default()
                .extend(downstream_fragments);
        }

        // Deduplicate the downstream lists while keeping the first occurrence order.
        for downstream in fragment_ordering.values_mut() {
            let mut seen = HashSet::new();
            downstream.retain(|id| seen.insert(*id));
        }

        ExtendedFragmentBackfillOrder {
            inner: fragment_ordering,
        }
    }
1429
    /// Builds a mapping from each fragment containing `LocalityProvider` nodes to the ids
    /// of their state tables. Fragments without a locality provider are omitted.
    pub fn find_locality_provider_fragment_state_table_mapping(
        &self,
    ) -> HashMap<FragmentId, Vec<TableId>> {
        let mut mapping: HashMap<FragmentId, Vec<TableId>> = HashMap::new();

        for (fragment_id, fragment) in &self.fragments {
            let fragment_id = fragment_id.as_global_id();

            if let Some(node) = fragment.node.as_ref() {
                let mut state_table_ids = Vec::new();

                visit_stream_node_cont(node, |stream_node| {
                    if let Some(NodeBody::LocalityProvider(locality_provider)) =
                        stream_node.node_body.as_ref()
                    {
                        let state_table_id = locality_provider
                            .state_table
                            .as_ref()
                            .expect("must have state table")
                            .id;
                        state_table_ids.push(state_table_id);
                        // Stop descending once the provider of this subtree is found.
                        false
                    } else {
                        true
                    }
                });

                if !state_table_ids.is_empty() {
                    mapping.insert(fragment_id, state_table_ids);
                }
            }
        }

        mapping
    }
1467
    /// Computes the dependency map of locality-provider fragments: for each fragment
    /// containing a `LocalityProvider` node, the locality-provider fragments reachable
    /// (transitively) downstream of it.
    ///
    /// Every locality-provider fragment gets an entry, possibly with an empty list.
    pub fn find_locality_provider_dependencies<'a>(
        fragments_nodes: impl Iterator<Item = (FragmentId, &'a PbStreamNode)>,
        fragment_downstreams: &FragmentDownstreamRelation,
    ) -> HashMap<FragmentId, Vec<FragmentId>> {
        let mut locality_provider_fragments = HashSet::new();
        let mut dependencies: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();

        // Pass 1: find all fragments that contain a locality provider.
        for (fragment_id, node) in fragments_nodes {
            let has_locality_provider = Self::fragment_has_locality_provider(node);

            if has_locality_provider {
                locality_provider_fragments.insert(fragment_id);
                dependencies.entry(fragment_id).or_default();
            }
        }

        // Pass 2: for each provider, DFS downstream and collect reachable providers.
        for &provider_fragment_id in &locality_provider_fragments {
            let mut visited = HashSet::new();
            let mut downstream_locality_providers = Vec::new();

            Self::collect_downstream_locality_providers(
                provider_fragment_id,
                &locality_provider_fragments,
                fragment_downstreams,
                &mut visited,
                &mut downstream_locality_providers,
            );

            dependencies
                .entry(provider_fragment_id)
                .or_default()
                .extend(downstream_locality_providers);
        }

        dependencies
    }
1517
1518 fn fragment_has_locality_provider(node: &PbStreamNode) -> bool {
1519 let mut has_locality_provider = false;
1520
1521 {
1522 visit_stream_node_cont(node, |stream_node| {
1523 if let Some(NodeBody::LocalityProvider(_)) = stream_node.node_body.as_ref() {
1524 has_locality_provider = true;
1525 false } else {
1527 true }
1529 });
1530 }
1531
1532 has_locality_provider
1533 }
1534
1535 fn collect_downstream_locality_providers(
1537 current_fragment_id: FragmentId,
1538 locality_provider_fragments: &HashSet<FragmentId>,
1539 fragment_downstreams: &FragmentDownstreamRelation,
1540 visited: &mut HashSet<FragmentId>,
1541 downstream_providers: &mut Vec<FragmentId>,
1542 ) {
1543 if visited.contains(¤t_fragment_id) {
1544 return;
1545 }
1546 visited.insert(current_fragment_id);
1547
1548 for downstream_fragment_id in fragment_downstreams
1550 .get(¤t_fragment_id)
1551 .into_iter()
1552 .flat_map(|downstreams| {
1553 downstreams
1554 .iter()
1555 .map(|downstream| downstream.downstream_fragment_id)
1556 })
1557 {
1558 if locality_provider_fragments.contains(&downstream_fragment_id) {
1560 downstream_providers.push(downstream_fragment_id);
1561 }
1562
1563 Self::collect_downstream_locality_providers(
1565 downstream_fragment_id,
1566 locality_provider_fragments,
1567 fragment_downstreams,
1568 visited,
1569 downstream_providers,
1570 );
1571 }
1572 }
1573}
1574
/// Fills the `snapshot_backfill_epoch` of every (cross-db) snapshot-backfill stream scan
/// in `node`, looking the epoch up first in `cross_db_snapshot_backfill_info` and then in
/// `snapshot_backfill_info`.
///
/// Returns `Ok(true)` if at least one scan node was filled. Errors if an upstream table
/// is not covered by either info, has no epoch set, or was already filled before.
pub fn fill_snapshot_backfill_epoch(
    node: &mut StreamNode,
    snapshot_backfill_info: Option<&SnapshotBackfillInfo>,
    cross_db_snapshot_backfill_info: &SnapshotBackfillInfo,
) -> MetaResult<bool> {
    let mut result = Ok(());
    let mut applied = false;
    visit_stream_node_cont_mut(node, |node| {
        if let Some(NodeBody::StreamScan(stream_scan)) = node.node_body.as_mut()
            && (stream_scan.stream_scan_type == StreamScanType::SnapshotBackfill as i32
                || stream_scan.stream_scan_type == StreamScanType::CrossDbSnapshotBackfill as i32)
        {
            result = try {
                let table_id = stream_scan.table_id;
                // Cross-db info takes precedence over the same-db info.
                let snapshot_epoch = cross_db_snapshot_backfill_info
                    .upstream_mv_table_id_to_backfill_epoch
                    .get(&table_id)
                    .or_else(|| {
                        snapshot_backfill_info.and_then(|snapshot_backfill_info| {
                            snapshot_backfill_info
                                .upstream_mv_table_id_to_backfill_epoch
                                .get(&table_id)
                        })
                    })
                    .ok_or_else(|| anyhow!("upstream table id not covered: {}", table_id))?
                    .ok_or_else(|| anyhow!("upstream table id not set: {}", table_id))?;
                // The epoch must not have been filled before.
                if let Some(prev_snapshot_epoch) =
                    stream_scan.snapshot_backfill_epoch.replace(snapshot_epoch)
                {
                    Err(anyhow!(
                        "snapshot backfill epoch set again: {} {} {}",
                        table_id,
                        prev_snapshot_epoch,
                        snapshot_epoch
                    ))?;
                }
                applied = true;
            };
            // Stop visiting on the first error.
            result.is_ok()
        } else {
            true
        }
    });
    result.map(|_| applied)
}
1622
/// Shared empty edge map, returned when a fragment has no upstreams or downstreams.
static EMPTY_HASHMAP: LazyLock<HashMap<GlobalFragmentId, StreamFragmentEdge>> =
    LazyLock::new(HashMap::new);
1625
/// A fragment in the complete graph: either one being built by the current DDL, or one
/// that already exists in the cluster.
#[derive(Debug, Clone, EnumAsInner)]
pub(super) enum EitherFragment {
    /// A fragment of the streaming job under construction.
    Building(BuildingFragment),

    /// A fragment that already exists (an upstream or downstream of the building job).
    Existing,
}
1636
/// A wrapper of [`StreamFragmentGraph`] that additionally records the existing fragments
/// the building job connects to, together with the extra edges between them.
#[derive(Debug)]
pub struct CompleteStreamFragmentGraph {
    /// The fragment graph of the streaming job being built.
    building_graph: StreamFragmentGraph,

    /// Existing fragments in the cluster referenced by the building graph.
    existing_fragments: HashMap<GlobalFragmentId, Fragment>,

    /// Extra downstream edges introduced by the upstream/downstream contexts.
    extra_downstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,

    /// Extra upstream edges introduced by the upstream/downstream contexts.
    extra_upstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,
}
1659
/// Context for connecting a building graph to the root fragments of its upstream jobs.
pub struct FragmentGraphUpstreamContext {
    /// The root fragment of each upstream job, keyed by the upstream job id.
    pub upstream_root_fragments: HashMap<JobId, Fragment>,
}
1665
/// Context for connecting a building (replacement) graph to existing downstream fragments.
pub struct FragmentGraphDownstreamContext {
    /// The root fragment id of the original job being replaced.
    pub original_root_fragment_id: FragmentId,
    /// The downstream fragments, each paired with the dispatcher type used to reach it.
    pub downstream_fragments: Vec<(DispatcherType, Fragment)>,
}
1670
1671impl CompleteStreamFragmentGraph {
    /// Creates a complete graph from a building graph alone, without any existing
    /// upstream or downstream fragments. For testing only.
    #[cfg(test)]
    pub fn for_test(graph: StreamFragmentGraph) -> Self {
        Self {
            building_graph: graph,
            existing_fragments: Default::default(),
            extra_downstreams: Default::default(),
            extra_upstreams: Default::default(),
        }
    }
1683
    /// Creates a complete graph connecting the building graph to its upstream jobs only.
    pub fn with_upstreams(
        graph: StreamFragmentGraph,
        upstream_context: FragmentGraphUpstreamContext,
        job_type: StreamingJobType,
    ) -> MetaResult<Self> {
        Self::build_helper(graph, Some(upstream_context), None, job_type)
    }
1694
    /// Creates a complete graph connecting the building (replacement) graph to its
    /// existing downstream fragments only.
    pub fn with_downstreams(
        graph: StreamFragmentGraph,
        downstream_context: FragmentGraphDownstreamContext,
        job_type: StreamingJobType,
    ) -> MetaResult<Self> {
        Self::build_helper(graph, None, Some(downstream_context), job_type)
    }
1704
    /// Creates a complete graph connecting the building graph to both its upstream jobs
    /// and its existing downstream fragments.
    pub fn with_upstreams_and_downstreams(
        graph: StreamFragmentGraph,
        upstream_context: FragmentGraphUpstreamContext,
        downstream_context: FragmentGraphDownstreamContext,
        job_type: StreamingJobType,
    ) -> MetaResult<Self> {
        Self::build_helper(
            graph,
            Some(upstream_context),
            Some(downstream_context),
            job_type,
        )
    }
1719
    /// Builds the complete graph by attaching the optional upstream and downstream
    /// contexts to the building graph.
    ///
    /// For each external connection, an extra [`StreamFragmentEdge`] is recorded with a
    /// dispatch strategy derived from `job_type` and the kind of the external fragment.
    fn build_helper(
        mut graph: StreamFragmentGraph,
        upstream_ctx: Option<FragmentGraphUpstreamContext>,
        downstream_ctx: Option<FragmentGraphDownstreamContext>,
        job_type: StreamingJobType,
    ) -> MetaResult<Self> {
        let mut extra_downstreams = HashMap::new();
        let mut extra_upstreams = HashMap::new();
        let mut existing_fragments = HashMap::new();

        if let Some(FragmentGraphUpstreamContext {
            upstream_root_fragments,
        }) = upstream_ctx
        {
            // Connect every building fragment that reads upstream job columns to the
            // root fragment of the corresponding upstream job.
            for (&id, fragment) in &mut graph.fragments {
                let uses_shuffled_backfill = fragment.has_shuffled_backfill();

                for (&upstream_job_id, required_columns) in &fragment.upstream_job_columns {
                    let upstream_fragment = upstream_root_fragments
                        .get(&upstream_job_id)
                        .context("upstream fragment not found")?;
                    let upstream_root_fragment_id =
                        GlobalFragmentId::new(upstream_fragment.fragment_id);

                    let edge = match job_type {
                        StreamingJobType::Table(TableJobType::SharedCdcSource) => {
                            // A table on a shared CDC source connects via its CdcFilter.
                            assert_ne!(
                                (fragment.fragment_type_mask & FragmentTypeFlag::CdcFilter as u32),
                                0
                            );

                            tracing::debug!(
                                ?upstream_root_fragment_id,
                                ?required_columns,
                                identity = ?fragment.inner.get_node().unwrap().get_identity(),
                                current_frag_id=?id,
                                "CdcFilter with upstream source fragment"
                            );

                            // The upstream CDC source forwards all its columns 1:1.
                            StreamFragmentEdge {
                                id: EdgeId::UpstreamExternal {
                                    upstream_job_id,
                                    downstream_fragment_id: id,
                                },
                                dispatch_strategy: DispatchStrategy {
                                    r#type: DispatcherType::NoShuffle as _,
                                    dist_key_indices: vec![],
                                    output_mapping: DispatchOutputMapping::identical(
                                        CDC_SOURCE_COLUMN_NUM as _,
                                    )
                                    .into(),
                                },
                            }
                        }

                        StreamingJobType::MaterializedView
                        | StreamingJobType::Sink
                        | StreamingJobType::Index => {
                            if upstream_fragment
                                .fragment_type_mask
                                .contains(FragmentTypeFlag::Mview)
                            {
                                // MV-on-MV: map the required columns from the upstream
                                // `Materialize` node and pick a dispatch strategy based
                                // on the backfill kind.
                                let (dist_key_indices, output_mapping) = {
                                    let mview_node = upstream_fragment
                                        .nodes
                                        .get_node_body()
                                        .unwrap()
                                        .as_materialize()
                                        .unwrap();
                                    let all_columns = mview_node.column_descs();
                                    let dist_key_indices = mview_node.dist_key_indices();
                                    let output_mapping = gen_output_mapping(
                                        required_columns,
                                        &all_columns,
                                    )
                                    .context(
                                        "BUG: column not found in the upstream materialized view",
                                    )?;
                                    (dist_key_indices, output_mapping)
                                };
                                let dispatch_strategy = mv_on_mv_dispatch_strategy(
                                    uses_shuffled_backfill,
                                    dist_key_indices,
                                    output_mapping,
                                );

                                StreamFragmentEdge {
                                    id: EdgeId::UpstreamExternal {
                                        upstream_job_id,
                                        downstream_fragment_id: id,
                                    },
                                    dispatch_strategy,
                                }
                            } else if upstream_fragment
                                .fragment_type_mask
                                .contains(FragmentTypeFlag::Source)
                            {
                                // MV on a shared source: map columns from the upstream
                                // `Source` node; always connected with NoShuffle.
                                let output_mapping = {
                                    let source_node = upstream_fragment
                                        .nodes
                                        .get_node_body()
                                        .unwrap()
                                        .as_source()
                                        .unwrap();

                                    let all_columns = source_node.column_descs().unwrap();
                                    gen_output_mapping(required_columns, &all_columns).context(
                                        "BUG: column not found in the upstream source node",
                                    )?
                                };

                                StreamFragmentEdge {
                                    id: EdgeId::UpstreamExternal {
                                        upstream_job_id,
                                        downstream_fragment_id: id,
                                    },
                                    dispatch_strategy: DispatchStrategy {
                                        r#type: DispatcherType::NoShuffle as _,
                                        dist_key_indices: vec![],
                                        output_mapping: Some(output_mapping),
                                    },
                                }
                            } else {
                                bail!(
                                    "the upstream fragment should be a MView or Source, got fragment type: {:b}",
                                    upstream_fragment.fragment_type_mask
                                )
                            }
                        }
                        StreamingJobType::Source | StreamingJobType::Table(_) => {
                            bail!(
                                "the streaming job shouldn't have an upstream fragment, job_type: {:?}",
                                job_type
                            )
                        }
                    };

                    // Record the edge in both directions; duplicates are a bug.
                    extra_downstreams
                        .entry(upstream_root_fragment_id)
                        .or_insert_with(HashMap::new)
                        .try_insert(id, edge.clone())
                        .unwrap();
                    extra_upstreams
                        .entry(id)
                        .or_insert_with(HashMap::new)
                        .try_insert(upstream_root_fragment_id, edge)
                        .unwrap();
                }
            }

            existing_fragments.extend(
                upstream_root_fragments
                    .into_values()
                    .map(|f| (GlobalFragmentId::new(f.fragment_id), f)),
            );
        }

        if let Some(FragmentGraphDownstreamContext {
            original_root_fragment_id,
            downstream_fragments,
        }) = downstream_ctx
        {
            let original_table_fragment_id = GlobalFragmentId::new(original_root_fragment_id);
            let table_fragment_id = GlobalFragmentId::new(graph.table_fragment_id());

            // Reconnect every existing downstream fragment to the replacement's root.
            for (dispatcher_type, fragment) in &downstream_fragments {
                let id = GlobalFragmentId::new(fragment.fragment_id);

                // Locate the scan node of the downstream fragment to learn which
                // columns it requires from the (replaced) upstream.
                let output_columns = {
                    let mut res = None;

                    stream_graph_visitor::visit_stream_node_body(&fragment.nodes, |node_body| {
                        let columns = match node_body {
                            NodeBody::StreamScan(stream_scan) => stream_scan.upstream_columns(),
                            NodeBody::SourceBackfill(source_backfill) => {
                                source_backfill.column_descs()
                            }
                            _ => return,
                        };
                        res = Some(columns);
                    });

                    res.context("failed to locate downstream scan")?
                };

                let table_fragment = graph.fragments.get(&table_fragment_id).unwrap();
                let nodes = table_fragment.node.as_ref().unwrap();

                let (dist_key_indices, output_mapping) = match job_type {
                    StreamingJobType::Table(_) | StreamingJobType::MaterializedView => {
                        let mview_node = nodes.get_node_body().unwrap().as_materialize().unwrap();
                        let all_columns = mview_node.column_descs();
                        let dist_key_indices = mview_node.dist_key_indices();
                        // A missing column means a downstream job still references it.
                        let output_mapping = gen_output_mapping(&output_columns, &all_columns)
                            .ok_or_else(|| {
                                MetaError::invalid_parameter(
                                    "unable to drop the column due to \
                                    being referenced by downstream materialized views or sinks",
                                )
                            })?;
                        (dist_key_indices, output_mapping)
                    }

                    StreamingJobType::Source => {
                        let source_node = nodes.get_node_body().unwrap().as_source().unwrap();
                        let all_columns = source_node.column_descs().unwrap();
                        let output_mapping = gen_output_mapping(&output_columns, &all_columns)
                            .ok_or_else(|| {
                                MetaError::invalid_parameter(
                                    "unable to drop the column due to \
                                    being referenced by downstream materialized views or sinks",
                                )
                            })?;
                        // Shared-source downstreams are always NoShuffle-connected.
                        assert_eq!(*dispatcher_type, DispatcherType::NoShuffle);
                        (vec![], output_mapping)
                    }

                    _ => bail!("unsupported job type for replacement: {job_type:?}"),
                };

                let edge = StreamFragmentEdge {
                    id: EdgeId::DownstreamExternal(DownstreamExternalEdgeId {
                        original_upstream_fragment_id: original_table_fragment_id,
                        downstream_fragment_id: id,
                    }),
                    dispatch_strategy: DispatchStrategy {
                        r#type: *dispatcher_type as i32,
                        output_mapping: Some(output_mapping),
                        dist_key_indices,
                    },
                };

                // Record the edge in both directions; duplicates are a bug.
                extra_downstreams
                    .entry(table_fragment_id)
                    .or_insert_with(HashMap::new)
                    .try_insert(id, edge.clone())
                    .unwrap();
                extra_upstreams
                    .entry(id)
                    .or_insert_with(HashMap::new)
                    .try_insert(table_fragment_id, edge)
                    .unwrap();
            }

            existing_fragments.extend(
                downstream_fragments
                    .into_iter()
                    .map(|(_, f)| (GlobalFragmentId::new(f.fragment_id), f)),
            );
        }

        Ok(Self {
            building_graph: graph,
            existing_fragments,
            extra_downstreams,
            extra_upstreams,
        })
    }
2000}
2001
2002fn gen_output_mapping(
2004 required_columns: &[PbColumnDesc],
2005 upstream_columns: &[PbColumnDesc],
2006) -> Option<DispatchOutputMapping> {
2007 let len = required_columns.len();
2008 let mut indices = vec![0; len];
2009 let mut types = None;
2010
2011 for (i, r) in required_columns.iter().enumerate() {
2012 let (ui, u) = upstream_columns
2013 .iter()
2014 .find_position(|&u| u.column_id == r.column_id)?;
2015 indices[i] = ui as u32;
2016
2017 if u.column_type != r.column_type {
2020 types.get_or_insert_with(|| vec![TypePair::default(); len])[i] = TypePair {
2021 upstream: u.column_type.clone(),
2022 downstream: r.column_type.clone(),
2023 };
2024 }
2025 }
2026
2027 let types = types.unwrap_or(Vec::new());
2029
2030 Some(DispatchOutputMapping { indices, types })
2031}
2032
2033fn mv_on_mv_dispatch_strategy(
2034 uses_shuffled_backfill: bool,
2035 dist_key_indices: Vec<u32>,
2036 output_mapping: DispatchOutputMapping,
2037) -> DispatchStrategy {
2038 if uses_shuffled_backfill {
2039 if !dist_key_indices.is_empty() {
2040 DispatchStrategy {
2041 r#type: DispatcherType::Hash as _,
2042 dist_key_indices,
2043 output_mapping: Some(output_mapping),
2044 }
2045 } else {
2046 DispatchStrategy {
2047 r#type: DispatcherType::Simple as _,
2048 dist_key_indices: vec![], output_mapping: Some(output_mapping),
2050 }
2051 }
2052 } else {
2053 DispatchStrategy {
2054 r#type: DispatcherType::NoShuffle as _,
2055 dist_key_indices: vec![], output_mapping: Some(output_mapping),
2057 }
2058 }
2059}
2060
2061impl CompleteStreamFragmentGraph {
    /// Returns the ids of all fragments, both building and existing.
    pub(super) fn all_fragment_ids(&self) -> impl Iterator<Item = GlobalFragmentId> + '_ {
        self.building_graph
            .fragments
            .keys()
            .chain(self.existing_fragments.keys())
            .copied()
    }
2071
    /// Returns all edges in the complete graph, both the internal ones of the building
    /// graph and the extra external ones, as `(upstream, downstream, edge)` tuples.
    pub(super) fn all_edges(
        &self,
    ) -> impl Iterator<Item = (GlobalFragmentId, GlobalFragmentId, &StreamFragmentEdge)> + '_ {
        self.building_graph
            .downstreams
            .iter()
            .chain(self.extra_downstreams.iter())
            .flat_map(|(&from, tos)| tos.iter().map(move |(&to, edge)| (from, to, edge)))
    }
2082
    /// Returns the distribution of each existing fragment, used as a constraint when
    /// scheduling the building fragments.
    pub(super) fn existing_distribution(&self) -> HashMap<GlobalFragmentId, Distribution> {
        self.existing_fragments
            .iter()
            .map(|(&id, f)| (id, Distribution::from_fragment(f)))
            .collect()
    }
2090
2091 pub(super) fn topo_order(&self) -> MetaResult<Vec<GlobalFragmentId>> {
2098 let mut topo = Vec::new();
2099 let mut downstream_cnts = HashMap::new();
2100
2101 for fragment_id in self.all_fragment_ids() {
2103 let downstream_cnt = self.get_downstreams(fragment_id).count();
2105 if downstream_cnt == 0 {
2106 topo.push(fragment_id);
2107 } else {
2108 downstream_cnts.insert(fragment_id, downstream_cnt);
2109 }
2110 }
2111
2112 let mut i = 0;
2113 while let Some(&fragment_id) = topo.get(i) {
2114 i += 1;
2115 for (upstream_job_id, _) in self.get_upstreams(fragment_id) {
2117 let downstream_cnt = downstream_cnts.get_mut(&upstream_job_id).unwrap();
2118 *downstream_cnt -= 1;
2119 if *downstream_cnt == 0 {
2120 downstream_cnts.remove(&upstream_job_id);
2121 topo.push(upstream_job_id);
2122 }
2123 }
2124 }
2125
2126 if !downstream_cnts.is_empty() {
2127 bail!("graph is not a DAG");
2129 }
2130
2131 Ok(topo)
2132 }
2133
2134 pub(super) fn seal_fragment(
2137 &self,
2138 id: GlobalFragmentId,
2139 distribution: Distribution,
2140 stream_node: StreamNode,
2141 ) -> Fragment {
2142 let building_fragment = self.get_fragment(id).into_building().unwrap();
2143 let internal_tables = building_fragment.extract_internal_tables();
2144 let BuildingFragment {
2145 inner,
2146 job_id,
2147 upstream_job_columns: _,
2148 } = building_fragment;
2149
2150 let distribution_type = distribution.to_distribution_type();
2151 let vnode_count = distribution.vnode_count();
2152
2153 let materialized_fragment_id =
2154 if FragmentTypeMask::from(inner.fragment_type_mask).contains(FragmentTypeFlag::Mview) {
2155 job_id.map(JobId::as_mv_table_id)
2156 } else {
2157 None
2158 };
2159
2160 let vector_index_fragment_id =
2161 if inner.fragment_type_mask & FragmentTypeFlag::VectorIndexWrite as u32 != 0 {
2162 job_id.map(JobId::as_mv_table_id)
2163 } else {
2164 None
2165 };
2166
2167 let state_table_ids = internal_tables
2168 .iter()
2169 .map(|t| t.id)
2170 .chain(materialized_fragment_id)
2171 .chain(vector_index_fragment_id)
2172 .collect();
2173
2174 Fragment {
2175 fragment_id: inner.fragment_id,
2176 fragment_type_mask: inner.fragment_type_mask.into(),
2177 distribution_type,
2178 state_table_ids,
2179 maybe_vnode_count: VnodeCount::set(vnode_count).to_protobuf(),
2180 nodes: stream_node,
2181 }
2182 }
2183
2184 pub(super) fn get_fragment(&self, fragment_id: GlobalFragmentId) -> EitherFragment {
2187 if self.existing_fragments.contains_key(&fragment_id) {
2188 EitherFragment::Existing
2189 } else {
2190 EitherFragment::Building(
2191 self.building_graph
2192 .fragments
2193 .get(&fragment_id)
2194 .unwrap()
2195 .clone(),
2196 )
2197 }
2198 }
2199
    /// Returns all downstreams of `fragment_id`, chaining the edges inside the building
    /// graph with the extra edges to existing fragments.
    pub(super) fn get_downstreams(
        &self,
        fragment_id: GlobalFragmentId,
    ) -> impl Iterator<Item = (GlobalFragmentId, &StreamFragmentEdge)> {
        self.building_graph
            .get_downstreams(fragment_id)
            .iter()
            .chain(
                self.extra_downstreams
                    .get(&fragment_id)
                    .into_iter()
                    .flatten(),
            )
            .map(|(&id, edge)| (id, edge))
    }
2217
    /// Returns all upstreams of `fragment_id`, chaining the edges inside the building
    /// graph with the extra edges to existing fragments.
    pub(super) fn get_upstreams(
        &self,
        fragment_id: GlobalFragmentId,
    ) -> impl Iterator<Item = (GlobalFragmentId, &StreamFragmentEdge)> {
        self.building_graph
            .get_upstreams(fragment_id)
            .iter()
            .chain(self.extra_upstreams.get(&fragment_id).into_iter().flatten())
            .map(|(&id, edge)| (id, edge))
    }
2230
    /// Returns the fragments of the streaming job being built.
    pub(super) fn building_fragments(&self) -> &HashMap<GlobalFragmentId, BuildingFragment> {
        &self.building_graph.fragments
    }
2235
    /// Returns the fragments of the streaming job being built, mutably.
    pub(super) fn building_fragments_mut(
        &mut self,
    ) -> &mut HashMap<GlobalFragmentId, BuildingFragment> {
        &mut self.building_graph.fragments
    }
2242
    /// Returns the maximum parallelism of the building graph.
    pub(super) fn max_parallelism(&self) -> usize {
        self.building_graph.max_parallelism()
    }
2247}
2248
2249#[cfg(test)]
2250mod tests {
2251 use risingwave_common::catalog::{ColumnDesc, ColumnId};
2252 use risingwave_common::types::DataType;
2253 use risingwave_pb::catalog::SinkType as PbSinkType;
2254 use risingwave_pb::meta::table_fragments::fragment::PbFragmentDistributionType;
2255 use risingwave_pb::plan_common::StorageTableDesc;
2256 use risingwave_pb::stream_plan::{
2257 BatchPlanNode, MergeNode, ProjectNode, SinkDesc, SinkLogStoreType, SinkNode, StreamNode,
2258 StreamScanNode, StreamScanType,
2259 };
2260
2261 use super::*;
2262
    /// Creates a visible column catalog with the given name, id and type.
    fn make_column(name: &str, id: i32, data_type: DataType) -> ColumnCatalog {
        ColumnCatalog::visible(ColumnDesc::named(name, ColumnId::new(id), data_type))
    }
2266
    /// Creates a protobuf `Field` named `table_name.column_name` with the column's type.
    fn make_field(table_name: &str, column: &ColumnCatalog) -> risingwave_pb::plan_common::Field {
        Field::new(
            format!("{}.{}", table_name, column.column_desc.name),
            column.data_type().clone(),
        )
        .to_prost()
    }
2274
    /// Creates an `InputRef` expression node referring to the input column at `index`.
    fn make_input_ref(index: u32, data_type: &DataType) -> PbExprNode {
        PbExprNode {
            function_type: expr_node::Type::Unspecified as i32,
            return_type: Some(data_type.to_protobuf()),
            rex_node: Some(expr_node::RexNode::InputRef(index)),
        }
    }
2282
    /// Builds an arrangement-backfill `StreamScan` node over the given table/columns,
    /// with the conventional `Merge` + `BatchPlan` inputs.
    fn make_stream_scan_node(
        table_name: &str,
        table_id: u32,
        columns: &[ColumnCatalog],
    ) -> StreamNode {
        // Merge input receiving the upstream changelog.
        let merge_node = StreamNode {
            node_body: Some(NodeBody::Merge(Box::new(MergeNode {
                upstream_fragment_id: 0.into(),
                ..Default::default()
            }))),
            fields: columns
                .iter()
                .map(|col| make_field(table_name, col))
                .collect(),
            ..Default::default()
        };
        // Placeholder batch plan for the snapshot side.
        let batch_plan_node = StreamNode {
            node_body: Some(NodeBody::BatchPlan(Box::new(BatchPlanNode {
                ..Default::default()
            }))),
            ..Default::default()
        };
        let stream_scan_node = StreamScanNode {
            table_id: table_id.into(),
            upstream_column_ids: columns.iter().map(|c| c.column_id().get_id()).collect(),
            output_indices: (0..columns.len()).map(|i| i as u32).collect(),
            stream_scan_type: StreamScanType::ArrangementBackfill as i32,
            table_desc: Some(StorageTableDesc {
                table_id: table_id.into(),
                columns: columns
                    .iter()
                    .map(|col| col.column_desc.to_protobuf())
                    .collect(),
                value_indices: (0..columns.len()).map(|i| i as u32).collect(),
                versioned: true,
                ..Default::default()
            }),
            ..Default::default()
        };
        StreamNode {
            node_body: Some(NodeBody::StreamScan(Box::new(stream_scan_node))),
            fields: columns
                .iter()
                .map(|col| make_field(table_name, col))
                .collect(),
            input: vec![merge_node, batch_plan_node],
            ..Default::default()
        }
    }
2332
2333 fn make_project_node(
2334 table_name: &str,
2335 columns: &[ColumnCatalog],
2336 input: StreamNode,
2337 ) -> StreamNode {
2338 let select_list = columns
2339 .iter()
2340 .enumerate()
2341 .map(|(i, col)| make_input_ref(i as u32, col.data_type()))
2342 .collect();
2343 StreamNode {
2344 node_body: Some(NodeBody::Project(Box::new(ProjectNode {
2345 select_list,
2346 ..Default::default()
2347 }))),
2348 fields: columns
2349 .iter()
2350 .map(|col| make_field(table_name, col))
2351 .collect(),
2352 input: vec![input],
2353 ..Default::default()
2354 }
2355 }
2356
    /// Adding a column (`c`) to the upstream table should extend both the
    /// `Project` and the `StreamScan` node of the sink fragment with the new
    /// column, appended after the existing ones.
    #[tokio::test]
    async fn test_rewrite_refresh_schema_sink_fragment_with_project() {
        let env = MetaSrvEnv::for_test().await;
        let id_gen_manager = env.id_gen_manager().as_ref();

        // Sink originally covers (a, b); `c` is the newly added upstream column.
        let table_name = "t";
        let columns = vec![
            make_column("a", 1, DataType::Int64),
            make_column("b", 2, DataType::Int64),
        ];
        let new_column = make_column("c", 3, DataType::Varchar);

        // The upstream table catalog already contains the new column.
        let mut upstream_columns = columns.clone();
        upstream_columns.push(new_column.clone());
        let upstream_table = PbTable {
            name: table_name.to_owned(),
            columns: upstream_columns
                .iter()
                .map(|col| col.to_protobuf())
                .collect(),
            ..Default::default()
        };

        // Sink catalog and its descriptor still only know the old columns.
        let sink = PbSink {
            columns: columns.iter().map(|col| col.to_protobuf()).collect(),
            sink_type: PbSinkType::AppendOnly as i32,
            ..Default::default()
        };

        let sink_desc = SinkDesc {
            sink_type: PbSinkType::AppendOnly as i32,
            column_catalogs: sink.columns.clone(),
            ..Default::default()
        };

        // Fragment plan shape under test: Sink <- Project <- StreamScan.
        let stream_scan_node = make_stream_scan_node(table_name, 1, &columns);
        let project_node = make_project_node(table_name, &columns, stream_scan_node);

        // Log-store table with the `<table>_<column>` naming scheme.
        let log_store_table = PbTable {
            columns: columns
                .iter()
                .cloned()
                .map(|mut col| {
                    col.column_desc.name = format!("{}_{}", table_name, col.column_desc.name);
                    col.to_protobuf()
                })
                .collect(),
            value_indices: (0..columns.len()).map(|i| i as i32).collect(),
            ..Default::default()
        };

        let original_fragment = Fragment {
            fragment_id: 1.into(),
            fragment_type_mask: FragmentTypeMask::default(),
            distribution_type: PbFragmentDistributionType::Single,
            state_table_ids: vec![],
            maybe_vnode_count: None,
            nodes: StreamNode {
                node_body: Some(NodeBody::Sink(Box::new(SinkNode {
                    sink_desc: Some(sink_desc),
                    table: Some(log_store_table),
                    ..Default::default()
                }))),
                fields: columns
                    .iter()
                    .map(|col| make_field(table_name, col))
                    .collect(),
                input: vec![project_node],
                ..Default::default()
            },
        };

        // Rewrite with one added column and no dropped columns.
        let (new_fragment, _, _) = rewrite_refresh_schema_sink_fragment(
            &original_fragment,
            &sink,
            std::slice::from_ref(&new_column),
            &[],
            &upstream_table,
            7.into(),
            id_gen_manager,
        )
        .unwrap();

        // The Project node must now emit one extra expression/field, with the
        // last select-list entry referencing the appended scan column.
        let sink_node = &new_fragment.nodes;
        let [project_node] = sink_node.input.as_slice() else {
            panic!("Sink has more than 1 input: {:?}", sink_node.input);
        };
        let PbNodeBody::Project(project_body) = project_node.node_body.as_ref().unwrap() else {
            panic!(
                "expect PbNodeBody::Project but got: {:?}",
                project_node.node_body
            );
        };
        assert_eq!(project_body.select_list.len(), columns.len() + 1);
        let last_expr = project_body.select_list.last().unwrap();
        assert!(
            matches!(last_expr.rex_node, Some(expr_node::RexNode::InputRef(idx)) if idx == columns.len() as u32)
        );
        assert_eq!(project_node.fields.len(), columns.len() + 1);

        // The StreamScan must read the new column last, both in upstream
        // column ids and in its output indices/fields.
        let [stream_scan_node] = project_node.input.as_slice() else {
            panic!("Project has more than 1 input: {:?}", project_node.input);
        };
        let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_ref().unwrap() else {
            panic!(
                "expect PbNodeBody::StreamScan but got: {:?}",
                stream_scan_node.node_body
            );
        };
        assert_eq!(
            scan.upstream_column_ids.last().copied(),
            Some(new_column.column_id().get_id())
        );
        assert_eq!(
            scan.output_indices.last().copied(),
            Some(columns.len() as u32)
        );
        assert_eq!(
            stream_scan_node.fields.last().unwrap().name,
            format!("{}.{}", table_name, new_column.column_desc.name)
        );
    }
2479
    /// Dropping a column (`tmp`) from the upstream table should remove it from
    /// the sink schema, the `Project` and `StreamScan` nodes, and the KV
    /// log-store table of the rewritten sink fragment.
    #[tokio::test]
    async fn test_rewrite_refresh_schema_sink_fragment_drop_column_with_project() {
        let env = MetaSrvEnv::for_test().await;
        let id_gen_manager = env.id_gen_manager().as_ref();

        // Sink originally covers (a, b, tmp); `tmp` is dropped upstream.
        let table_name = "t";
        let columns = vec![
            make_column("a", 1, DataType::Int64),
            make_column("b", 2, DataType::Int64),
            make_column("tmp", 3, DataType::Varchar),
        ];
        let removed_column = columns.last().unwrap().clone();
        let upstream_columns = columns[..2].to_vec();

        // The upstream table catalog no longer contains the dropped column.
        let upstream_table = PbTable {
            name: table_name.to_owned(),
            columns: upstream_columns
                .iter()
                .map(|col| col.to_protobuf())
                .collect(),
            ..Default::default()
        };

        // Sink catalog and its descriptor still include the dropped column.
        let sink = PbSink {
            columns: columns.iter().map(|col| col.to_protobuf()).collect(),
            sink_type: PbSinkType::AppendOnly as i32,
            ..Default::default()
        };

        let sink_desc = SinkDesc {
            sink_type: PbSinkType::AppendOnly as i32,
            column_catalogs: sink.columns.clone(),
            ..Default::default()
        };

        // Fragment plan shape under test: Sink <- Project <- StreamScan.
        let stream_scan_node = make_stream_scan_node(table_name, 1, &columns);
        let project_node = make_project_node(table_name, &columns, stream_scan_node);

        // Log-store table with the `<table>_<column>` naming scheme.
        let log_store_table = PbTable {
            columns: columns
                .iter()
                .cloned()
                .map(|mut col| {
                    col.column_desc.name = format!("{}_{}", table_name, col.column_desc.name);
                    col.to_protobuf()
                })
                .collect(),
            value_indices: (0..columns.len()).map(|i| i as i32).collect(),
            ..Default::default()
        };

        // Unlike the add-column test, use a KV log store so the rewrite must
        // also produce an updated log-store table.
        let original_fragment = Fragment {
            fragment_id: 1.into(),
            fragment_type_mask: FragmentTypeMask::default(),
            distribution_type: PbFragmentDistributionType::Single,
            state_table_ids: vec![],
            maybe_vnode_count: None,
            nodes: StreamNode {
                node_body: Some(NodeBody::Sink(Box::new(SinkNode {
                    sink_desc: Some(sink_desc),
                    table: Some(log_store_table),
                    log_store_type: SinkLogStoreType::KvLogStore as i32,
                    ..Default::default()
                }))),
                fields: columns
                    .iter()
                    .map(|col| make_field(table_name, col))
                    .collect(),
                input: vec![project_node],
                ..Default::default()
            },
        };

        // Rewrite with no added columns and one dropped column.
        let (new_fragment, new_schema, new_log_store_table) = rewrite_refresh_schema_sink_fragment(
            &original_fragment,
            &sink,
            &[],
            std::slice::from_ref(&removed_column),
            &upstream_table,
            7.into(),
            id_gen_manager,
        )
        .unwrap();

        // The new sink schema must shrink to two columns, none of them `tmp`.
        assert_eq!(new_schema.len(), 2);
        assert!(
            new_schema.iter().all(|col| {
                col.column_desc.as_ref().map(|desc| desc.name.as_str()) != Some("tmp")
            })
        );

        // The Project node must drop the corresponding expression and field.
        let sink_node = &new_fragment.nodes;
        let [project_node] = sink_node.input.as_slice() else {
            panic!("Sink has more than 1 input: {:?}", sink_node.input);
        };
        let PbNodeBody::Project(project_body) = project_node.node_body.as_ref().unwrap() else {
            panic!(
                "expect PbNodeBody::Project but got: {:?}",
                project_node.node_body
            );
        };
        assert_eq!(project_body.select_list.len(), 2);
        assert!(project_node.fields.iter().all(|f| !f.name.contains("tmp")));

        // The StreamScan must no longer read the dropped column.
        let [stream_scan_node] = project_node.input.as_slice() else {
            panic!("Project has more than 1 input: {:?}", project_node.input);
        };
        let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_ref().unwrap() else {
            panic!(
                "expect PbNodeBody::StreamScan but got: {:?}",
                stream_scan_node.node_body
            );
        };
        assert!(
            !scan
                .upstream_column_ids
                .iter()
                .any(|&id| id == removed_column.column_id().get_id())
        );
        assert!(
            stream_scan_node
                .fields
                .iter()
                .all(|f| !f.name.contains("tmp"))
        );

        // The KV log-store table must be rewritten: `tmp` removed and the
        // value indices re-densified to 0..len.
        let new_log_store_table = new_log_store_table.expect("log store table should be updated");
        assert!(
            new_log_store_table.columns.iter().all(|col| !col
                .column_desc
                .as_ref()
                .unwrap()
                .name
                .contains("tmp"))
        );
        assert_eq!(
            new_log_store_table.value_indices,
            (0..new_log_store_table.columns.len() as i32).collect::<Vec<_>>()
        );
    }
2620}