use std::collections::{BTreeMap, HashMap, HashSet};
use std::num::NonZeroUsize;
use std::ops::{Deref, DerefMut};
use std::sync::LazyLock;

use anyhow::{Context, anyhow};
use enum_as_inner::EnumAsInner;
use itertools::Itertools;
use risingwave_common::bail;
use risingwave_common::catalog::{
    CDC_SOURCE_COLUMN_NUM, ColumnCatalog, ColumnId, Field, FragmentTypeFlag, FragmentTypeMask,
    TableId, generate_internal_table_name_with_type,
};
use risingwave_common::hash::VnodeCount;
use risingwave_common::id::JobId;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::util::stream_graph_visitor::{
    self, visit_stream_node_cont, visit_stream_node_cont_mut,
};
use risingwave_connector::sink::catalog::SinkType;
use risingwave_meta_model::streaming_job::BackfillOrders;
use risingwave_pb::catalog::{PbSink, PbTable, Table};
use risingwave_pb::ddl_service::TableJobType;
use risingwave_pb::expr::{ExprNode as PbExprNode, expr_node};
use risingwave_pb::id::{RelationId, StreamNodeLocalOperatorId};
use risingwave_pb::plan_common::{PbColumnCatalog, PbColumnDesc};
use risingwave_pb::stream_plan::dispatch_output_mapping::TypePair;
use risingwave_pb::stream_plan::stream_fragment_graph::{
    Parallelism, StreamFragment, StreamFragmentEdge as StreamFragmentEdgeProto,
};
use risingwave_pb::stream_plan::stream_node::{NodeBody, PbNodeBody};
use risingwave_pb::stream_plan::{
    BackfillOrder, DispatchOutputMapping, DispatchStrategy, DispatcherType, PbStreamNode,
    PbStreamScanType, StreamFragmentGraph as StreamFragmentGraphProto, StreamNode, StreamScanNode,
    StreamScanType,
};

use crate::barrier::SnapshotBackfillInfo;
use crate::controller::id::IdGeneratorManager;
use crate::manager::{MetaSrvEnv, StreamingJob, StreamingJobType};
use crate::model::{Fragment, FragmentDownstreamRelation, FragmentId};
use crate::stream::stream_graph::id::{GlobalFragmentId, GlobalFragmentIdGen, GlobalTableIdGen};
use crate::stream::stream_graph::schedule::Distribution;
use crate::{MetaError, MetaResult};
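
/// A wrapper of [`StreamFragment`] that carries the extra metadata needed while building the
/// graph of a streaming job, such as the ID of the job it belongs to and the columns required
/// from each upstream job.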
#[derive(Debug, Clone)]
pub(super) struct BuildingFragment {
    /// The fragment structure from the frontend, with the global fragment ID filled.
    inner: StreamFragment,

    /// The ID of the streaming job, if this fragment contains the job's root node
    /// (e.g. `Materialize`, `Sink`, or shared `Source`).
    job_id: Option<JobId>,

    /// The columns required from each upstream job.
    upstream_job_columns: HashMap<JobId, Vec<PbColumnDesc>>,
}

impl BuildingFragment {
    /// Create a new [`BuildingFragment`] from a [`StreamFragment`]. The global fragment ID and
    /// the global IDs of the internal tables will be filled.
    fn new(
        id: GlobalFragmentId,
        fragment: StreamFragment,
        job: &StreamingJob,
        table_id_gen: GlobalTableIdGen,
    ) -> Self {
        let mut fragment = StreamFragment {
            fragment_id: id.as_global_id(),
            ..fragment
        };

        // Fill the information of the internal tables in the fragment.
        Self::fill_internal_tables(&mut fragment, job, table_id_gen);

        let job_id = Self::fill_job(&mut fragment, job).then(|| job.id());
        let upstream_job_columns =
            Self::extract_upstream_columns_except_cross_db_backfill(&fragment);

        Self {
            inner: fragment,
            job_id,
            upstream_job_columns,
        }
    }

    /// Extract the internal tables from the fragment.
    fn extract_internal_tables(&self) -> Vec<Table> {
        let mut fragment = self.inner.clone();
        let mut tables = Vec::new();
        stream_graph_visitor::visit_internal_tables(&mut fragment, |table, _| {
            tables.push(table.clone());
        });
        tables
    }

    /// Fill the global IDs and the job-related metadata of the internal tables in the fragment.
    fn fill_internal_tables(
        fragment: &mut StreamFragment,
        job: &StreamingJob,
        table_id_gen: GlobalTableIdGen,
    ) {
        let fragment_id = fragment.fragment_id;
        stream_graph_visitor::visit_internal_tables(fragment, |table, table_type_name| {
            table.id = table_id_gen
                .to_global_id(table.id.as_raw_id())
                .as_global_id();
            table.schema_id = job.schema_id();
            table.database_id = job.database_id();
            table.name = generate_internal_table_name_with_type(
                &job.name(),
                fragment_id,
                table.id,
                table_type_name,
            );
            table.fragment_id = fragment_id;
            table.owner = job.owner();
            table.job_id = Some(job.id());
        });
    }

    /// Fill the information of the streaming job into the fragment. Returns whether the fragment
    /// contains the root node of the job.
    fn fill_job(fragment: &mut StreamFragment, job: &StreamingJob) -> bool {
        let job_id = job.id();
        let fragment_id = fragment.fragment_id;
        let mut has_job = false;

        stream_graph_visitor::visit_fragment_mut(fragment, |node_body| match node_body {
            NodeBody::Materialize(materialize_node) => {
                materialize_node.table_id = job_id.as_mv_table_id();

                let table = materialize_node.table.insert(job.table().unwrap().clone());
                table.fragment_id = fragment_id;
                if cfg!(not(debug_assertions)) {
                    table.definition = job.name();
                }

                has_job = true;
            }
            NodeBody::Sink(sink_node) => {
                sink_node.sink_desc.as_mut().unwrap().id = job_id.as_sink_id();

                has_job = true;
            }
            NodeBody::Dml(dml_node) => {
                dml_node.table_id = job_id.as_mv_table_id();
                dml_node.table_version_id = job.table_version_id().unwrap();
            }
            NodeBody::StreamFsFetch(fs_fetch_node) => {
                if let StreamingJob::Table(table_source, _, _) = job
                    && let Some(node_inner) = fs_fetch_node.node_inner.as_mut()
                    && let Some(source) = table_source
                {
                    node_inner.source_id = source.id;
                    if let Some(id) = source.optional_associated_table_id {
                        node_inner.associated_table_id = Some(id.into());
                    }
                }
            }
            NodeBody::Source(source_node) => {
                match job {
                    // The source node inside a table job refers to the associated source, whose
                    // ID is different from the job ID.
                    StreamingJob::Table(source, _table, _table_job_type) => {
                        if let Some(source_inner) = source_node.source_inner.as_mut()
                            && let Some(source) = source
                        {
                            debug_assert_ne!(source.id, job_id.as_raw_id());
                            source_inner.source_id = source.id;
                            if let Some(id) = source.optional_associated_table_id {
                                source_inner.associated_table_id = Some(id.into());
                            }
                        }
                    }
                    StreamingJob::Source(source) => {
                        has_job = true;
                        if let Some(source_inner) = source_node.source_inner.as_mut() {
                            debug_assert_eq!(source.id, job_id.as_raw_id());
                            source_inner.source_id = source.id;
                            if let Some(id) = source.optional_associated_table_id {
                                source_inner.associated_table_id = Some(id.into());
                            }
                        }
                    }
                    _ => {}
                }
            }
            NodeBody::StreamCdcScan(node) => {
                if let Some(table_desc) = node.cdc_table_desc.as_mut() {
                    table_desc.table_id = job_id.as_mv_table_id();
                }
            }
            NodeBody::VectorIndexWrite(node) => {
                let table = node.table.as_mut().unwrap();
                table.id = job_id.as_mv_table_id();
                table.database_id = job.database_id();
                table.schema_id = job.schema_id();
                table.fragment_id = fragment_id;
                #[cfg(not(debug_assertions))]
                {
                    table.definition = job.name();
                }

                has_job = true;
            }
            _ => {}
        });

        has_job
    }

    /// Extract the required columns of each upstream job, keyed by the upstream job ID.
    /// Cross-database snapshot backfill scans are excluded.
    fn extract_upstream_columns_except_cross_db_backfill(
        fragment: &StreamFragment,
    ) -> HashMap<JobId, Vec<PbColumnDesc>> {
        let mut table_columns = HashMap::new();

        stream_graph_visitor::visit_fragment(fragment, |node_body| {
            let (table_id, column_ids) = match node_body {
                NodeBody::StreamScan(stream_scan) => {
                    if stream_scan.get_stream_scan_type().unwrap()
                        == StreamScanType::CrossDbSnapshotBackfill
                    {
                        return;
                    }
                    (
                        stream_scan.table_id.as_job_id(),
                        stream_scan.upstream_columns(),
                    )
                }
                NodeBody::CdcFilter(cdc_filter) => (
                    cdc_filter.upstream_source_id.as_share_source_job_id(),
                    vec![],
                ),
                NodeBody::SourceBackfill(backfill) => (
                    backfill.upstream_source_id.as_share_source_job_id(),
                    backfill.column_descs(),
                ),
                _ => return,
            };
            table_columns
                .try_insert(table_id, column_ids)
                .expect("currently there should be no two same upstream tables in a fragment");
        });

        table_columns
    }
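
    /// Returns whether the fragment contains a backfill node that requires shuffled input,
    /// i.e. an arrangement backfill or a snapshot backfill stream scan.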
    pub fn has_shuffled_backfill(&self) -> bool {
        let stream_node = match self.inner.node.as_ref() {
            Some(node) => node,
            _ => return false,
        };
        let mut has_shuffled_backfill = false;
        let has_shuffled_backfill_mut_ref = &mut has_shuffled_backfill;
        visit_stream_node_cont(stream_node, |node| {
            let is_shuffled_backfill = if let Some(node) = &node.node_body
                && let Some(node) = node.as_stream_scan()
            {
                node.stream_scan_type == StreamScanType::ArrangementBackfill as i32
                    || node.stream_scan_type == StreamScanType::SnapshotBackfill as i32
            } else {
                false
            };
            if is_shuffled_backfill {
                *has_shuffled_backfill_mut_ref = true;
                false
            } else {
                true
            }
        });
        has_shuffled_backfill
    }
}

impl Deref for BuildingFragment {
    type Target = StreamFragment;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

impl DerefMut for BuildingFragment {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
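
/// The ID of an edge in the fragment graph. The variant distinguishes whether the edge is
/// internal to the building job or connects to an existing external fragment.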
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumAsInner)]
pub(super) enum EdgeId {
    /// The edge between two fragments in the graph of the building streaming job.
    Internal {
        /// The ID of this link, assigned by the frontend.
        link_id: u64,
    },

    /// The edge from the root fragment of an existing upstream job to a fragment of the
    /// building job.
    UpstreamExternal {
        /// The ID of the upstream job.
        upstream_job_id: JobId,
        /// The ID of the downstream fragment in the building job.
        downstream_fragment_id: GlobalFragmentId,
    },

    /// The edge from a fragment of the building (replacement) job to an existing downstream
    /// fragment.
    DownstreamExternal(DownstreamExternalEdgeId),
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(super) struct DownstreamExternalEdgeId {
    /// The ID of the original upstream fragment that is being replaced.
    pub(super) original_upstream_fragment_id: GlobalFragmentId,
    /// The ID of the existing downstream fragment.
    pub(super) downstream_fragment_id: GlobalFragmentId,
}
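
/// An edge in the fragment graph, carrying the dispatch strategy from the upstream fragment to
/// the downstream fragment.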
#[derive(Debug, Clone)]
pub(super) struct StreamFragmentEdge {
    /// The ID of this edge.
    pub id: EdgeId,

    /// The strategy used to dispatch data from the upstream fragment to the downstream fragment.
    pub dispatch_strategy: DispatchStrategy,
}

impl StreamFragmentEdge {
    fn from_protobuf(edge: &StreamFragmentEdgeProto) -> Self {
        Self {
            // An edge from the frontend is always internal to the building job.
            id: EdgeId::Internal {
                link_id: edge.link_id,
            },
            dispatch_strategy: edge.get_dispatch_strategy().unwrap().clone(),
        }
    }
}

/// Clone a fragment with a newly generated global fragment ID, keeping everything else unchanged.
fn clone_fragment(fragment: &Fragment, id_generator_manager: &IdGeneratorManager) -> Fragment {
    let fragment_id = GlobalFragmentIdGen::new(id_generator_manager, 1)
        .to_global_id(0)
        .as_global_id();
    Fragment {
        fragment_id,
        fragment_type_mask: fragment.fragment_type_mask,
        distribution_type: fragment.distribution_type,
        state_table_ids: fragment.state_table_ids.clone(),
        maybe_vnode_count: fragment.maybe_vnode_count,
        nodes: fragment.nodes.clone(),
    }
}

/// Check that a sink job has the shape required for automatic schema refresh: a single fragment
/// whose plan is `Sink -> (Project ->)? StreamScan (arrangement backfill) -> Merge`.
pub fn check_sink_fragments_support_refresh_schema(
    fragments: &BTreeMap<FragmentId, Fragment>,
) -> MetaResult<()> {
    if fragments.len() != 1 {
        return Err(anyhow!(
            "sink with auto schema change should have only 1 fragment, but got {:?}",
            fragments.len()
        )
        .into());
    }
    let (_, fragment) = fragments.first_key_value().expect("non-empty");
    let sink_node = &fragment.nodes;
    let PbNodeBody::Sink(_) = sink_node.node_body.as_ref().unwrap() else {
        return Err(anyhow!("expect PbNodeBody::Sink but got: {:?}", sink_node.node_body).into());
    };
    let [stream_input_node] = sink_node.input.as_slice() else {
        panic!("Sink has more than 1 input: {:?}", sink_node.input);
    };
    let stream_scan_node = match stream_input_node.node_body.as_ref().unwrap() {
        PbNodeBody::StreamScan(_) => stream_input_node,
        PbNodeBody::Project(_) => {
            let [stream_scan_node] = stream_input_node.input.as_slice() else {
                return Err(anyhow!(
                    "Project node must have exactly 1 input for auto schema change, but got {:?}",
                    stream_input_node.input.len()
                )
                .into());
            };
            stream_scan_node
        }
        _ => {
            return Err(anyhow!(
                "expect PbNodeBody::StreamScan or PbNodeBody::Project but got: {:?}",
                stream_input_node.node_body
            )
            .into());
        }
    };
    let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_ref().unwrap() else {
        return Err(anyhow!(
            "expect PbNodeBody::StreamScan but got: {:?}",
            stream_scan_node.node_body
        )
        .into());
    };
    let stream_scan_type = PbStreamScanType::try_from(scan.stream_scan_type).unwrap();
    if stream_scan_type != PbStreamScanType::ArrangementBackfill {
        return Err(anyhow!(
            "unsupported stream_scan_type for auto refresh schema: {:?}",
            stream_scan_type
        )
        .into());
    }
    let [merge_node, _batch_plan_node] = stream_scan_node.input.as_slice() else {
        panic!(
            "the number of StreamScan inputs is not 2: {:?}",
            stream_scan_node.input
        );
    };
    let NodeBody::Merge(_) = merge_node.node_body.as_ref().unwrap() else {
        return Err(anyhow!(
            "expect PbNodeBody::Merge but got: {:?}",
            merge_node.node_body
        )
        .into());
    };
    Ok(())
}
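
/// The result of rewriting a `StreamScan` (and its `Merge` input) during sink schema refresh,
/// describing how the scan's output column indices are remapped.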
struct ScanRewriteResult {
    old_output_index_to_new_output_index: HashMap<u32, u32>,
    new_output_index_by_column_id: HashMap<ColumnId, u32>,
    output_fields: Vec<risingwave_pb::plan_common::Field>,
}

/// Append `new_columns` to `sink_columns`, assigning fresh column IDs and renaming them with
/// `get_column_name`.
fn extend_sink_columns(
    sink_columns: &mut Vec<PbColumnCatalog>,
    new_columns: &[ColumnCatalog],
    get_column_name: impl Fn(&String) -> String,
) {
    let next_column_id = sink_columns
        .iter()
        .map(|col| col.column_desc.as_ref().unwrap().column_id + 1)
        .max()
        .unwrap_or(1);
    sink_columns.extend(new_columns.iter().enumerate().map(|(i, col)| {
        let mut col = col.to_protobuf();
        let column_desc = col.column_desc.as_mut().unwrap();
        column_desc.column_id = next_column_id + (i as i32);
        column_desc.name = get_column_name(&column_desc.name);
        col
    }));
}

/// Build the new column catalog of the sink by removing the dropped columns and appending the
/// newly added ones.
fn build_new_sink_columns(
    sink: &PbSink,
    removed_column_ids: &HashSet<ColumnId>,
    newly_added_columns: &[ColumnCatalog],
) -> Vec<PbColumnCatalog> {
    let mut columns: Vec<PbColumnCatalog> = sink
        .columns
        .iter()
        .filter(|col| {
            let column_id = ColumnId::new(col.column_desc.as_ref().unwrap().column_id as _);
            !removed_column_ids.contains(&column_id)
        })
        .cloned()
        .collect();
    extend_sink_columns(&mut columns, newly_added_columns, |name| name.clone());
    columns
}

/// Apply the schema change to the log store table of the sink: drop the removed columns and
/// append the newly added ones, prefixed with the upstream table name.
fn rewrite_log_store_table(
    log_store_table: &mut PbTable,
    removed_log_store_column_names: &HashSet<String>,
    newly_added_columns: &[ColumnCatalog],
    upstream_table_name: &str,
) {
    log_store_table.columns.retain(|col| {
        !removed_log_store_column_names.contains(&col.column_desc.as_ref().unwrap().name)
    });
    extend_sink_columns(&mut log_store_table.columns, newly_added_columns, |name| {
        format!("{}_{}", upstream_table_name, name)
    });
}
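
/// Rewrite the `StreamScan` node (and its `Merge` input) of a sink fragment against the new
/// schema of the upstream table, dropping removed columns and appending newly added ones.
/// Returns the index remapping needed to rewrite the downstream `Project` / `Sink` nodes.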
fn rewrite_stream_scan_and_merge(
    stream_scan_node: &mut StreamNode,
    removed_column_ids: &HashSet<ColumnId>,
    newly_added_columns: &[ColumnCatalog],
    upstream_table: &PbTable,
    upstream_table_fragment_id: FragmentId,
) -> MetaResult<ScanRewriteResult> {
    let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_mut().unwrap() else {
        return Err(anyhow!(
            "expect PbNodeBody::StreamScan but got: {:?}",
            stream_scan_node.node_body
        )
        .into());
    };
    let [merge_node, _batch_plan_node] = stream_scan_node.input.as_mut_slice() else {
        panic!(
            "the number of StreamScan inputs is not 2: {:?}",
            stream_scan_node.input
        );
    };
    let NodeBody::Merge(merge) = merge_node.node_body.as_mut().unwrap() else {
        return Err(anyhow!(
            "expect PbNodeBody::Merge but got: {:?}",
            merge_node.node_body
        )
        .into());
    };

    let stream_scan_type = PbStreamScanType::try_from(scan.stream_scan_type).unwrap();
    if stream_scan_type != PbStreamScanType::ArrangementBackfill {
        return Err(anyhow!(
            "unsupported stream_scan_type for auto refresh schema: {:?}",
            stream_scan_type
        )
        .into());
    }

    let upstream_columns_by_id: HashMap<i32, PbColumnDesc> = upstream_table
        .columns
        .iter()
        .map(|col| {
            let desc = col.column_desc.as_ref().unwrap().clone();
            (desc.column_id, desc)
        })
        .collect();

    let old_upstream_column_ids = scan.upstream_column_ids.clone();
    let old_output_indices = scan.output_indices.clone();
    let mut old_upstream_index_to_new_upstream_index = HashMap::new();
    let mut new_upstream_column_ids = Vec::new();
    for (old_idx, &column_id) in old_upstream_column_ids.iter().enumerate() {
        if !removed_column_ids.contains(&ColumnId::new(column_id as _)) {
            let new_idx = new_upstream_column_ids.len() as u32;
            old_upstream_index_to_new_upstream_index.insert(old_idx as u32, new_idx);
            new_upstream_column_ids.push(column_id);
        }
    }
    let mut new_output_indices = Vec::new();
    for old_output_index in &old_output_indices {
        if let Some(new_index) = old_upstream_index_to_new_upstream_index.get(old_output_index) {
            new_output_indices.push(*new_index);
        }
    }
    for col in newly_added_columns {
        let new_index = new_upstream_column_ids.len() as u32;
        new_upstream_column_ids.push(col.column_id().get_id());
        new_output_indices.push(new_index);
    }

    let new_output_column_ids: Vec<i32> = new_output_indices
        .iter()
        .map(|&idx| new_upstream_column_ids[idx as usize])
        .collect();
    let mut new_output_index_by_column_id = HashMap::new();
    for (pos, &column_id) in new_output_column_ids.iter().enumerate() {
        new_output_index_by_column_id.insert(ColumnId::new(column_id as _), pos as u32);
    }
    let mut old_output_index_to_new_output_index = HashMap::new();
    for (old_pos, old_output_index) in old_output_indices.iter().enumerate() {
        let column_id = old_upstream_column_ids[*old_output_index as usize];
        if let Some(new_pos) = new_output_index_by_column_id.get(&ColumnId::new(column_id as _)) {
            old_output_index_to_new_output_index.insert(old_pos as u32, *new_pos);
        }
    }

    scan.arrangement_table = Some(upstream_table.clone());
    scan.upstream_column_ids = new_upstream_column_ids;
    scan.output_indices = new_output_indices;
    let table_desc = scan.table_desc.as_mut().unwrap();
    table_desc.columns = scan
        .upstream_column_ids
        .iter()
        .map(|column_id| {
            upstream_columns_by_id
                .get(column_id)
                .unwrap_or_else(|| panic!("upstream column id not found: {}", column_id))
                .clone()
        })
        .collect();

    stream_scan_node.fields = new_output_column_ids
        .iter()
        .map(|column_id| {
            let col_desc = upstream_columns_by_id
                .get(column_id)
                .unwrap_or_else(|| panic!("upstream column id not found: {}", column_id));
            Field::new(
                format!("{}.{}", upstream_table.name, col_desc.name),
                col_desc.column_type.as_ref().unwrap().into(),
            )
            .to_prost()
        })
        .collect();
    stream_scan_node.identity = {
        let columns = stream_scan_node
            .fields
            .iter()
            .map(|col| &col.name)
            .join(", ");
        format!("StreamTableScan {{ table: t, columns: [{columns}] }}")
    };

    merge_node.fields = scan
        .upstream_column_ids
        .iter()
        .map(|&column_id| {
            let col_desc = upstream_columns_by_id
                .get(&column_id)
                .unwrap_or_else(|| panic!("upstream column id not found: {}", column_id));
            Field::new(
                col_desc.name.clone(),
                col_desc.column_type.as_ref().unwrap().into(),
            )
            .to_prost()
        })
        .collect();
    merge.upstream_fragment_id = upstream_table_fragment_id;

    Ok(ScanRewriteResult {
        old_output_index_to_new_output_index,
        new_output_index_by_column_id,
        output_fields: stream_scan_node.fields.clone(),
    })
}
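
/// Rewrite the `Project` node above the rewritten `StreamScan`: remap its `InputRef` expressions
/// to the new scan output indices and append `InputRef`s for the newly added columns.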
fn rewrite_project_node(
    project_node: &mut StreamNode,
    scan_rewrite: &ScanRewriteResult,
    newly_added_columns: &[ColumnCatalog],
    removed_column_ids: &HashSet<ColumnId>,
    upstream_table_name: &str,
) -> MetaResult<()> {
    let PbNodeBody::Project(project_node_body) = project_node.node_body.as_mut().unwrap() else {
        return Err(anyhow!(
            "expect PbNodeBody::Project but got: {:?}",
            project_node.node_body
        )
        .into());
    };
    let has_non_input_ref = project_node_body
        .select_list
        .iter()
        .any(|expr| !matches!(expr.rex_node, Some(expr_node::RexNode::InputRef(_))));
    if has_non_input_ref && !removed_column_ids.is_empty() {
        return Err(anyhow!(
            "auto schema change with drop column only supports Project with InputRef"
        )
        .into());
    }

    let mut new_select_list = Vec::with_capacity(project_node_body.select_list.len());
    let mut new_project_fields = Vec::with_capacity(project_node.fields.len());
    for (index, expr) in project_node_body.select_list.iter().enumerate() {
        let mut new_expr = expr.clone();
        if let Some(expr_node::RexNode::InputRef(old_index)) = new_expr.rex_node {
            let Some(&new_index) = scan_rewrite
                .old_output_index_to_new_output_index
                .get(&old_index)
            else {
                continue;
            };
            new_expr.rex_node = Some(expr_node::RexNode::InputRef(new_index));
        } else if !removed_column_ids.is_empty() {
            return Err(anyhow!(
                "auto schema change with drop column only supports Project with InputRef"
            )
            .into());
        }
        new_select_list.push(new_expr);
        new_project_fields.push(project_node.fields[index].clone());
    }

    for col in newly_added_columns {
        let Some(&new_index) = scan_rewrite
            .new_output_index_by_column_id
            .get(&col.column_id())
        else {
            return Err(anyhow!("new column id not found in scan output").into());
        };
        new_select_list.push(PbExprNode {
            function_type: expr_node::Type::Unspecified as i32,
            return_type: Some(col.data_type().to_protobuf()),
            rex_node: Some(expr_node::RexNode::InputRef(new_index)),
        });
        new_project_fields.push(
            Field::new(
                format!("{}.{}", upstream_table_name, col.column_desc.name),
                col.data_type().clone(),
            )
            .to_prost(),
        );
    }

    project_node_body.select_list = new_select_list;
    project_node.fields = new_project_fields;
    Ok(())
}
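
/// Rewrite the sink fragment for automatic schema refresh: clone it with a new fragment ID,
/// update the sink columns and (if any) the log store table, and rewrite the `StreamScan`,
/// `Merge` and optional `Project` nodes against the new upstream table schema.
/// Returns the new fragment, the new sink columns, and the rewritten log store table.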
pub fn rewrite_refresh_schema_sink_fragment(
    original_sink_fragment: &Fragment,
    sink: &PbSink,
    newly_added_columns: &[ColumnCatalog],
    removed_columns: &[ColumnCatalog],
    upstream_table: &PbTable,
    upstream_table_fragment_id: FragmentId,
    id_generator_manager: &IdGeneratorManager,
) -> MetaResult<(Fragment, Vec<PbColumnCatalog>, Option<PbTable>)> {
    let removed_column_ids: HashSet<_> =
        removed_columns.iter().map(|col| col.column_id()).collect();
    let removed_log_store_column_names: HashSet<_> = removed_columns
        .iter()
        .map(|col| format!("{}_{}", upstream_table.name, col.column_desc.name))
        .collect();
    let new_sink_columns = build_new_sink_columns(sink, &removed_column_ids, newly_added_columns);

    let mut new_sink_fragment = clone_fragment(original_sink_fragment, id_generator_manager);
    let sink_node = &mut new_sink_fragment.nodes;
    let PbNodeBody::Sink(sink_node_body) = sink_node.node_body.as_mut().unwrap() else {
        return Err(anyhow!("expect PbNodeBody::Sink but got: {:?}", sink_node.node_body).into());
    };
    let [stream_input_node] = sink_node.input.as_mut_slice() else {
        panic!("Sink has more than 1 input: {:?}", sink_node.input);
    };
    let stream_input_body = stream_input_node.node_body.as_ref().unwrap();
    let stream_input_is_project = matches!(stream_input_body, PbNodeBody::Project(_));
    let stream_input_is_scan = matches!(stream_input_body, PbNodeBody::StreamScan(_));
    if !stream_input_is_project && !stream_input_is_scan {
        return Err(anyhow!(
            "expect PbNodeBody::StreamScan or PbNodeBody::Project but got: {:?}",
            stream_input_body
        )
        .into());
    }

    // Regenerate the identity string of the sink node to reflect the new columns.
    sink_node.identity = {
        let sink_type = SinkType::from_proto(sink.sink_type());
        let sink_type_str = sink_type.type_str();
        let column_names = new_sink_columns
            .iter()
            .map(|col| {
                ColumnCatalog::from(col.clone())
                    .name_with_hidden()
                    .to_string()
            })
            .join(", ");
        let downstream_pk = if !sink_type.is_append_only() {
            let downstream_pk = sink
                .downstream_pk
                .iter()
                .map(|i| &sink.columns[*i as usize].column_desc.as_ref().unwrap().name)
                .collect_vec();
            format!(", downstream_pk: {downstream_pk:?}")
        } else {
            "".to_owned()
        };
        format!("StreamSink {{ type: {sink_type_str}, columns: [{column_names}]{downstream_pk} }}")
    };
    let new_log_store_table = if let Some(log_store_table) = &mut sink_node_body.table {
        rewrite_log_store_table(
            log_store_table,
            &removed_log_store_column_names,
            newly_added_columns,
            &upstream_table.name,
        );
        Some(log_store_table.clone())
    } else {
        None
    };
    sink_node_body.sink_desc.as_mut().unwrap().column_catalogs = new_sink_columns.clone();

    let stream_scan_node = if stream_input_is_project {
        let [stream_scan_node] = stream_input_node.input.as_mut_slice() else {
            return Err(anyhow!(
                "Project node must have exactly 1 input for auto schema change, but got {:?}",
                stream_input_node.input.len()
            )
            .into());
        };
        stream_scan_node
    } else {
        stream_input_node
    };
    let scan_rewrite = rewrite_stream_scan_and_merge(
        stream_scan_node,
        &removed_column_ids,
        newly_added_columns,
        upstream_table,
        upstream_table_fragment_id,
    )?;

    if stream_input_is_project {
        let [project_node] = sink_node.input.as_mut_slice() else {
            panic!("Sink has more than 1 input: {:?}", sink_node.input);
        };
        rewrite_project_node(
            project_node,
            &scan_rewrite,
            newly_added_columns,
            &removed_column_ids,
            &upstream_table.name,
        )?;
        sink_node.fields = project_node.fields.clone();
    } else {
        sink_node.fields = scan_rewrite.output_fields;
    }
    Ok((new_sink_fragment, new_sink_columns, new_log_store_table))
}
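
/// A fragment-level backfill ordering, stored as an adjacency list between fragment IDs.
///
/// The `EXTENDED` const parameter distinguishes the ordering derived from the user-specified
/// relation-level order ([`UserDefinedFragmentBackfillOrder`]) from the one further extended
/// with `LocalityProvider` dependencies ([`ExtendedFragmentBackfillOrder`]).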
#[derive(Clone, Debug, Default)]
pub struct FragmentBackfillOrder<const EXTENDED: bool> {
    inner: HashMap<FragmentId, Vec<FragmentId>>,
}

impl<const EXTENDED: bool> Deref for FragmentBackfillOrder<EXTENDED> {
    type Target = HashMap<FragmentId, Vec<FragmentId>>;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

impl UserDefinedFragmentBackfillOrder {
    pub fn new(inner: HashMap<FragmentId, Vec<FragmentId>>) -> Self {
        Self { inner }
    }

    pub fn merge(orders: impl Iterator<Item = Self>) -> Self {
        Self {
            inner: orders.flat_map(|order| order.inner).collect(),
        }
    }

    pub fn to_meta_model(&self) -> BackfillOrders {
        self.inner.clone().into()
    }
}

pub type UserDefinedFragmentBackfillOrder = FragmentBackfillOrder<false>;
pub type ExtendedFragmentBackfillOrder = FragmentBackfillOrder<true>;
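
/// In-memory representation of the fragment graph of a streaming job, built from the
/// [`StreamFragmentGraphProto`] provided by the frontend. It only contains the fragments and
/// edges of the job itself; connections to pre-existing fragments are added later when
/// constructing a [`CompleteStreamFragmentGraph`].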
#[derive(Default, Debug)]
pub struct StreamFragmentGraph {
    /// All the fragments in the graph.
    pub(super) fragments: HashMap<GlobalFragmentId, BuildingFragment>,

    /// Edges between fragments: upstream fragment => downstream fragments.
    pub(super) downstreams:
        HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,

    /// Edges between fragments: downstream fragment => upstream fragments.
    pub(super) upstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,

    /// The IDs of the relations that this job depends on.
    dependent_table_ids: HashSet<TableId>,

    /// The parallelism specified for the job, if any.
    specified_parallelism: Option<NonZeroUsize>,
    /// The parallelism specified for the backfill phase of the job, if any.
    specified_backfill_parallelism: Option<NonZeroUsize>,

    /// The maximum parallelism of the job.
    max_parallelism: usize,

    /// The user-specified backfill order among the relations scanned by this job.
    backfill_order: BackfillOrder,
}

impl StreamFragmentGraph {
    /// Create a new [`StreamFragmentGraph`] from the protobuf representation, converting all
    /// local IDs into global IDs through the ID generator.
    pub fn new(
        env: &MetaSrvEnv,
        proto: StreamFragmentGraphProto,
        job: &StreamingJob,
    ) -> MetaResult<Self> {
        let fragment_id_gen =
            GlobalFragmentIdGen::new(env.id_gen_manager(), proto.fragments.len() as u64);
        let table_id_gen = GlobalTableIdGen::new(env.id_gen_manager(), proto.table_ids_cnt as u64);

        // Create the fragments with global IDs filled.
        let fragments: HashMap<_, _> = proto
            .fragments
            .into_iter()
            .map(|(id, fragment)| {
                let id = fragment_id_gen.to_global_id(id.as_raw_id());
                let fragment = BuildingFragment::new(id, fragment, job, table_id_gen);
                (id, fragment)
            })
            .collect();

        assert_eq!(
            fragments
                .values()
                .map(|f| f.extract_internal_tables().len() as u32)
                .sum::<u32>(),
            proto.table_ids_cnt
        );

        // Create the edges between fragments.
        let mut downstreams = HashMap::new();
        let mut upstreams = HashMap::new();

        for edge in proto.edges {
            let upstream_id = fragment_id_gen.to_global_id(edge.upstream_id.as_raw_id());
            let downstream_id = fragment_id_gen.to_global_id(edge.downstream_id.as_raw_id());
            let edge = StreamFragmentEdge::from_protobuf(&edge);

            upstreams
                .entry(downstream_id)
                .or_insert_with(HashMap::new)
                .try_insert(upstream_id, edge.clone())
                .unwrap();
            downstreams
                .entry(upstream_id)
                .or_insert_with(HashMap::new)
                .try_insert(downstream_id, edge)
                .unwrap();
        }

        let dependent_table_ids = proto.dependent_table_ids.iter().copied().collect();

        let specified_parallelism = if let Some(Parallelism { parallelism }) = proto.parallelism {
            Some(NonZeroUsize::new(parallelism as usize).context("parallelism should not be 0")?)
        } else {
            None
        };
        let specified_backfill_parallelism =
            if let Some(Parallelism { parallelism }) = proto.backfill_parallelism {
                Some(
                    NonZeroUsize::new(parallelism as usize)
                        .context("backfill parallelism should not be 0")?,
                )
            } else {
                None
            };

        let max_parallelism = proto.max_parallelism as usize;
        let backfill_order = proto.backfill_order.unwrap_or(BackfillOrder {
            order: Default::default(),
        });

        Ok(Self {
            fragments,
            downstreams,
            upstreams,
            dependent_table_ids,
            specified_parallelism,
            specified_backfill_parallelism,
            max_parallelism,
            backfill_order,
        })
    }

    /// Retrieve the internal tables map of the whole graph.
    ///
    /// Note that the returned table catalogs are not yet fully filled; the remaining fields are
    /// completed in later phases of building the job.
    pub fn incomplete_internal_tables(&self) -> BTreeMap<TableId, Table> {
        let mut tables = BTreeMap::new();
        for fragment in self.fragments.values() {
            for table in fragment.extract_internal_tables() {
                let table_id = table.id;
                tables
                    .try_insert(table_id, table)
                    .unwrap_or_else(|_| panic!("duplicated table id `{}`", table_id));
            }
        }
        tables
    }

    /// Rewrite the IDs of the internal tables according to the given mapping.
    pub fn refill_internal_table_ids(&mut self, table_id_map: HashMap<TableId, TableId>) {
        for fragment in self.fragments.values_mut() {
            stream_graph_visitor::visit_internal_tables(
                &mut fragment.inner,
                |table, _table_type_name| {
                    let target = table_id_map.get(&table.id).cloned().unwrap();
                    table.id = target;
                },
            );
        }
    }

    /// Replace the newly generated internal tables with the given old internal tables, matched
    /// by the sorted order of their IDs. The numbers of tables must be the same.
    pub fn fit_internal_tables_trivial(
        &mut self,
        mut old_internal_tables: Vec<Table>,
    ) -> MetaResult<()> {
        let mut new_internal_table_ids = Vec::new();
        for fragment in self.fragments.values() {
            for table in &fragment.extract_internal_tables() {
                new_internal_table_ids.push(table.id);
            }
        }

        if new_internal_table_ids.len() != old_internal_tables.len() {
            bail!(
                "Different number of internal tables. New: {}, Old: {}",
                new_internal_table_ids.len(),
                old_internal_tables.len()
            );
        }
        old_internal_tables.sort_by(|a, b| a.id.cmp(&b.id));
        new_internal_table_ids.sort();

        let internal_table_id_map = new_internal_table_ids
            .into_iter()
            .zip_eq_fast(old_internal_tables.into_iter())
            .collect::<HashMap<_, _>>();

        // Replace the internal tables in the fragments with the matched old tables.
        for fragment in self.fragments.values_mut() {
            stream_graph_visitor::visit_internal_tables(
                &mut fragment.inner,
                |table, _table_type_name| {
                    let target = internal_table_id_map.get(&table.id).cloned().unwrap();
                    *table = target;
                },
            );
        }

        Ok(())
    }

    /// Fit the internal table IDs and vnode counts according to the given mapping from the new
    /// table ID to the matched existing table.
    pub fn fit_internal_table_ids_with_mapping(&mut self, mut matches: HashMap<TableId, Table>) {
        for fragment in self.fragments.values_mut() {
            stream_graph_visitor::visit_internal_tables(
                &mut fragment.inner,
                |table, _table_type_name| {
                    let target = matches.remove(&table.id).unwrap_or_else(|| {
                        panic!("no matching table for table {}({})", table.id, table.name)
                    });
                    table.id = target.id;
                    table.maybe_vnode_count = target.maybe_vnode_count;
                },
            );
        }
    }

    /// Fill the snapshot backfill epochs of the snapshot backfill `StreamScan` nodes, keyed by
    /// the operator ID.
    pub fn fit_snapshot_backfill_epochs(
        &mut self,
        mut snapshot_backfill_epochs: HashMap<StreamNodeLocalOperatorId, u64>,
    ) {
        for fragment in self.fragments.values_mut() {
            visit_stream_node_cont_mut(fragment.node.as_mut().unwrap(), |node| {
                if let PbNodeBody::StreamScan(scan) = node.node_body.as_mut().unwrap()
                    && let StreamScanType::SnapshotBackfill
                    | StreamScanType::CrossDbSnapshotBackfill = scan.stream_scan_type()
                {
                    let Some(epoch) = snapshot_backfill_epochs.remove(&node.operator_id) else {
                        panic!("no snapshot epoch found for node {:?}", node)
                    };
                    scan.snapshot_backfill_epoch = Some(epoch);
                }
                true
            })
        }
    }

    /// Returns the fragment ID of the fragment that contains the root node of the streaming job
    /// (e.g. `Materialize`, `Sink`, or shared `Source`).
    pub fn table_fragment_id(&self) -> FragmentId {
        self.fragments
            .values()
            .filter(|b| b.job_id.is_some())
            .map(|b| b.fragment_id)
            .exactly_one()
            .expect("require exactly 1 materialize/sink/cdc source node when creating the streaming job")
    }

    /// Returns the fragment ID of the fragment that contains the `Dml` node, if any.
    pub fn dml_fragment_id(&self) -> Option<FragmentId> {
        self.fragments
            .values()
            .filter(|b| {
                FragmentTypeMask::from(b.fragment_type_mask).contains(FragmentTypeFlag::Dml)
            })
            .map(|b| b.fragment_id)
            .at_most_one()
            .expect("require at most 1 dml node when creating the streaming job")
    }

    /// Get the relations that this job depends on.
    pub fn dependent_table_ids(&self) -> &HashSet<TableId> {
        &self.dependent_table_ids
    }

    /// Get the parallelism specified for this job, if any.
    pub fn specified_parallelism(&self) -> Option<NonZeroUsize> {
        self.specified_parallelism
    }

    /// Get the backfill parallelism specified for this job, if any.
    pub fn specified_backfill_parallelism(&self) -> Option<NonZeroUsize> {
        self.specified_backfill_parallelism
    }

    /// Get the maximum parallelism of this job.
    pub fn max_parallelism(&self) -> usize {
        self.max_parallelism
    }

    /// Get the downstream fragments and the corresponding edges of the given fragment.
    fn get_downstreams(
        &self,
        fragment_id: GlobalFragmentId,
    ) -> &HashMap<GlobalFragmentId, StreamFragmentEdge> {
        self.downstreams.get(&fragment_id).unwrap_or(&EMPTY_HASHMAP)
    }

    /// Get the upstream fragments and the corresponding edges of the given fragment.
    fn get_upstreams(
        &self,
        fragment_id: GlobalFragmentId,
    ) -> &HashMap<GlobalFragmentId, StreamFragmentEdge> {
        self.upstreams.get(&fragment_id).unwrap_or(&EMPTY_HASHMAP)
    }

    pub fn collect_snapshot_backfill_info(
        &self,
    ) -> MetaResult<(Option<SnapshotBackfillInfo>, SnapshotBackfillInfo)> {
        Self::collect_snapshot_backfill_info_impl(self.fragments.values().map(|fragment| {
            (
                fragment.node.as_ref().unwrap(),
                fragment.fragment_type_mask.into(),
            )
        }))
    }

    pub fn collect_snapshot_backfill_info_impl(
        fragments: impl IntoIterator<Item = (&PbStreamNode, FragmentTypeMask)>,
    ) -> MetaResult<(Option<SnapshotBackfillInfo>, SnapshotBackfillInfo)> {
        let mut prev_stream_scan: Option<(Option<SnapshotBackfillInfo>, StreamScanNode)> = None;
        let mut cross_db_info = SnapshotBackfillInfo {
            upstream_mv_table_id_to_backfill_epoch: Default::default(),
        };
        let mut result = Ok(());
        for (node, fragment_type_mask) in fragments {
            visit_stream_node_cont(node, |node| {
                if let Some(NodeBody::StreamScan(stream_scan)) = node.node_body.as_ref() {
                    let stream_scan_type = StreamScanType::try_from(stream_scan.stream_scan_type)
                        .expect("invalid stream_scan_type");
                    let is_snapshot_backfill = match stream_scan_type {
                        StreamScanType::SnapshotBackfill => {
                            assert!(
                                fragment_type_mask
                                    .contains(FragmentTypeFlag::SnapshotBackfillStreamScan)
                            );
                            true
                        }
                        StreamScanType::CrossDbSnapshotBackfill => {
                            assert!(
                                fragment_type_mask
                                    .contains(FragmentTypeFlag::CrossDbSnapshotBackfillStreamScan)
                            );
                            cross_db_info
                                .upstream_mv_table_id_to_backfill_epoch
                                .insert(stream_scan.table_id, stream_scan.snapshot_backfill_epoch);

                            return true;
                        }
                        _ => false,
                    };

                    match &mut prev_stream_scan {
                        Some((prev_snapshot_backfill_info, prev_stream_scan)) => {
                            match (prev_snapshot_backfill_info, is_snapshot_backfill) {
                                (Some(prev_snapshot_backfill_info), true) => {
                                    prev_snapshot_backfill_info
                                        .upstream_mv_table_id_to_backfill_epoch
                                        .insert(
                                            stream_scan.table_id,
                                            stream_scan.snapshot_backfill_epoch,
                                        );
                                    true
                                }
                                (None, false) => true,
                                (_, _) => {
                                    result = Err(anyhow!("must be either all snapshot_backfill or no snapshot_backfill. Curr: {stream_scan:?} Prev: {prev_stream_scan:?}").into());
                                    false
                                }
                            }
                        }
                        None => {
                            prev_stream_scan = Some((
                                if is_snapshot_backfill {
                                    Some(SnapshotBackfillInfo {
                                        upstream_mv_table_id_to_backfill_epoch: HashMap::from_iter(
                                            [(
                                                stream_scan.table_id,
                                                stream_scan.snapshot_backfill_epoch,
                                            )],
                                        ),
                                    })
                                } else {
                                    None
                                },
                                *stream_scan.clone(),
                            ));
                            true
                        }
                    }
                } else {
                    true
                }
            })
        }
        result.map(|_| {
            (
                prev_stream_scan
                    .map(|(snapshot_backfill_info, _)| snapshot_backfill_info)
                    .unwrap_or(None),
                cross_db_info,
            )
        })
    }

    /// Collect the mapping from each scanned upstream relation (table or source) to the
    /// fragments that backfill from it.
    pub fn collect_backfill_mapping(
        fragments: impl Iterator<Item = (FragmentId, FragmentTypeMask, &PbStreamNode)>,
    ) -> HashMap<RelationId, Vec<FragmentId>> {
        let mut mapping = HashMap::new();
        for (fragment_id, fragment_type_mask, node) in fragments {
            let has_some_scan = fragment_type_mask
                .contains_any([FragmentTypeFlag::StreamScan, FragmentTypeFlag::SourceScan]);
            if has_some_scan {
                visit_stream_node_cont(node, |node| {
                    match node.node_body.as_ref() {
                        Some(NodeBody::StreamScan(stream_scan)) => {
                            let table_id = stream_scan.table_id;
                            let fragments: &mut Vec<_> =
                                mapping.entry(table_id.as_relation_id()).or_default();
                            fragments.push(fragment_id);
                            false
                        }
                        Some(NodeBody::SourceBackfill(source_backfill)) => {
                            let source_id = source_backfill.upstream_source_id;
                            let fragments: &mut Vec<_> =
                                mapping.entry(source_id.as_relation_id()).or_default();
                            fragments.push(fragment_id);
                            false
                        }
                        _ => true,
                    }
                })
            }
        }
        mapping
    }

    /// Convert the user-specified relation-level backfill order into a fragment-level backfill
    /// order, using the mapping from relations to their backfill fragments.
    pub fn create_fragment_backfill_ordering(&self) -> UserDefinedFragmentBackfillOrder {
        let mapping =
            Self::collect_backfill_mapping(self.fragments.iter().map(|(fragment_id, fragment)| {
                (
                    fragment_id.as_global_id(),
                    fragment.fragment_type_mask.into(),
                    fragment.node.as_ref().expect("should exist node"),
                )
            }));
        let mut fragment_ordering: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();

        for (rel_id, downstream_rel_ids) in &self.backfill_order.order {
            let fragment_ids = mapping.get(rel_id).unwrap();
            for fragment_id in fragment_ids {
                let downstream_fragment_ids = downstream_rel_ids
                    .data
                    .iter()
                    .flat_map(|&downstream_rel_id| mapping.get(&downstream_rel_id).unwrap().iter())
                    .copied()
                    .collect();
                fragment_ordering.insert(*fragment_id, downstream_fragment_ids);
            }
        }

        UserDefinedFragmentBackfillOrder {
            inner: fragment_ordering,
        }
    }

    /// Extend the user-defined backfill order with ordering constraints between the backfill
    /// fragments and the `LocalityProvider` fragments, and among the `LocalityProvider`
    /// fragments themselves.
    pub fn extend_fragment_backfill_ordering_with_locality_backfill<
        'a,
        FI: Iterator<Item = (FragmentId, FragmentTypeMask, &'a PbStreamNode)> + 'a,
    >(
        fragment_ordering: UserDefinedFragmentBackfillOrder,
        fragment_downstreams: &FragmentDownstreamRelation,
        get_fragments: impl Fn() -> FI,
    ) -> ExtendedFragmentBackfillOrder {
        let mut fragment_ordering = fragment_ordering.inner;
        let mapping = Self::collect_backfill_mapping(get_fragments());
        // If no order is specified by the user, start with all backfill fragments and no
        // ordering constraints.
        if fragment_ordering.is_empty() {
            for value in mapping.values() {
                for &fragment_id in value {
                    fragment_ordering.entry(fragment_id).or_default();
                }
            }
        }

        let locality_provider_dependencies = Self::find_locality_provider_dependencies(
            get_fragments().map(|(fragment_id, _, node)| (fragment_id, node)),
            fragment_downstreams,
        );

        let backfill_fragments: HashSet<FragmentId> = mapping.values().flatten().copied().collect();

        // The root locality provider fragments are those that are not downstream of any other
        // locality provider fragment.
        let all_locality_provider_fragments: HashSet<FragmentId> =
            locality_provider_dependencies.keys().copied().collect();
        let downstream_locality_provider_fragments: HashSet<FragmentId> =
            locality_provider_dependencies
                .values()
                .flatten()
                .copied()
                .collect();
        let locality_provider_root_fragments: Vec<FragmentId> = all_locality_provider_fragments
            .difference(&downstream_locality_provider_fragments)
            .copied()
            .collect();

        // Add ordering constraints between every backfill fragment and the root locality
        // provider fragments.
        for &backfill_fragment_id in &backfill_fragments {
            fragment_ordering
                .entry(backfill_fragment_id)
                .or_default()
                .extend(locality_provider_root_fragments.iter().copied());
        }

        // Add ordering constraints between each locality provider fragment and its downstream
        // locality provider fragments.
        for (fragment_id, downstream_fragments) in locality_provider_dependencies {
            fragment_ordering
                .entry(fragment_id)
                .or_default()
                .extend(downstream_fragments);
        }

        // Deduplicate each list while preserving its order.
        for downstream in fragment_ordering.values_mut() {
            let mut seen = HashSet::new();
            downstream.retain(|id| seen.insert(*id));
        }

        ExtendedFragmentBackfillOrder {
            inner: fragment_ordering,
        }
    }

    /// Find the mapping from fragments that contain `LocalityProvider` nodes to the IDs of the
    /// state tables of those nodes.
    pub fn find_locality_provider_fragment_state_table_mapping(
        &self,
    ) -> HashMap<FragmentId, Vec<TableId>> {
        let mut mapping: HashMap<FragmentId, Vec<TableId>> = HashMap::new();

        for (fragment_id, fragment) in &self.fragments {
            let fragment_id = fragment_id.as_global_id();

            if let Some(node) = fragment.node.as_ref() {
                let mut state_table_ids = Vec::new();

                visit_stream_node_cont(node, |stream_node| {
                    if let Some(NodeBody::LocalityProvider(locality_provider)) =
                        stream_node.node_body.as_ref()
                    {
                        let state_table_id = locality_provider
                            .state_table
                            .as_ref()
                            .expect("must have state table")
                            .id;
                        state_table_ids.push(state_table_id);
                        false
                    } else {
                        true
                    }
                });

                if !state_table_ids.is_empty() {
                    mapping.insert(fragment_id, state_table_ids);
                }
            }
        }

        mapping
    }

    /// Find, for each fragment containing a `LocalityProvider` node, the other
    /// `LocalityProvider` fragments that are reachable downstream of it.
    pub fn find_locality_provider_dependencies<'a>(
        fragments_nodes: impl Iterator<Item = (FragmentId, &'a PbStreamNode)>,
        fragment_downstreams: &FragmentDownstreamRelation,
    ) -> HashMap<FragmentId, Vec<FragmentId>> {
        let mut locality_provider_fragments = HashSet::new();
        let mut dependencies: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();

        // Collect all fragments that contain a `LocalityProvider` node.
        for (fragment_id, node) in fragments_nodes {
            let has_locality_provider = Self::fragment_has_locality_provider(node);

            if has_locality_provider {
                locality_provider_fragments.insert(fragment_id);
                dependencies.entry(fragment_id).or_default();
            }
        }

        // For each locality provider fragment, traverse its downstreams to find the other
        // locality provider fragments reachable from it.
        for &provider_fragment_id in &locality_provider_fragments {
            let mut visited = HashSet::new();
            let mut downstream_locality_providers = Vec::new();

            Self::collect_downstream_locality_providers(
                provider_fragment_id,
                &locality_provider_fragments,
                fragment_downstreams,
                &mut visited,
                &mut downstream_locality_providers,
            );

            dependencies
                .entry(provider_fragment_id)
                .or_default()
                .extend(downstream_locality_providers);
        }

        dependencies
    }

    /// Returns whether the fragment contains a `LocalityProvider` node.
    fn fragment_has_locality_provider(node: &PbStreamNode) -> bool {
        let mut has_locality_provider = false;

        {
            visit_stream_node_cont(node, |stream_node| {
                if let Some(NodeBody::LocalityProvider(_)) = stream_node.node_body.as_ref() {
                    has_locality_provider = true;
                    false
                } else {
                    true
                }
            });
        }

        has_locality_provider
    }

    /// Recursively traverse the downstream fragments of `current_fragment_id`, collecting the
    /// fragments that contain `LocalityProvider` nodes.
    fn collect_downstream_locality_providers(
        current_fragment_id: FragmentId,
        locality_provider_fragments: &HashSet<FragmentId>,
        fragment_downstreams: &FragmentDownstreamRelation,
        visited: &mut HashSet<FragmentId>,
        downstream_providers: &mut Vec<FragmentId>,
    ) {
        if visited.contains(&current_fragment_id) {
            return;
        }
        visited.insert(current_fragment_id);

        for downstream_fragment_id in fragment_downstreams
            .get(&current_fragment_id)
            .into_iter()
            .flat_map(|downstreams| {
                downstreams
                    .iter()
                    .map(|downstream| downstream.downstream_fragment_id)
            })
        {
            if locality_provider_fragments.contains(&downstream_fragment_id) {
                downstream_providers.push(downstream_fragment_id);
            }

            Self::collect_downstream_locality_providers(
                downstream_fragment_id,
                locality_provider_fragments,
                fragment_downstreams,
                visited,
                downstream_providers,
            );
        }
    }
}
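
/// Fill the snapshot backfill epoch into the (cross-db) snapshot backfill `StreamScan` nodes of
/// the given stream node, using the collected backfill info. Returns whether any epoch was
/// filled, or an error if an upstream table is not covered or an epoch is set twice.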
pub fn fill_snapshot_backfill_epoch(
    node: &mut StreamNode,
    snapshot_backfill_info: Option<&SnapshotBackfillInfo>,
    cross_db_snapshot_backfill_info: &SnapshotBackfillInfo,
) -> MetaResult<bool> {
    let mut result = Ok(());
    let mut applied = false;
    visit_stream_node_cont_mut(node, |node| {
        if let Some(NodeBody::StreamScan(stream_scan)) = node.node_body.as_mut()
            && (stream_scan.stream_scan_type == StreamScanType::SnapshotBackfill as i32
                || stream_scan.stream_scan_type == StreamScanType::CrossDbSnapshotBackfill as i32)
        {
            result = try {
                let table_id = stream_scan.table_id;
                let snapshot_epoch = cross_db_snapshot_backfill_info
                    .upstream_mv_table_id_to_backfill_epoch
                    .get(&table_id)
                    .or_else(|| {
                        snapshot_backfill_info.and_then(|snapshot_backfill_info| {
                            snapshot_backfill_info
                                .upstream_mv_table_id_to_backfill_epoch
                                .get(&table_id)
                        })
                    })
                    .ok_or_else(|| anyhow!("upstream table id not covered: {}", table_id))?
                    .ok_or_else(|| anyhow!("upstream table id not set: {}", table_id))?;
                if let Some(prev_snapshot_epoch) =
                    stream_scan.snapshot_backfill_epoch.replace(snapshot_epoch)
                {
                    Err(anyhow!(
                        "snapshot backfill epoch set again: {} {} {}",
                        table_id,
                        prev_snapshot_epoch,
                        snapshot_epoch
                    ))?;
                }
                applied = true;
            };
            result.is_ok()
        } else {
            true
        }
    });
    result.map(|_| applied)
}

static EMPTY_HASHMAP: LazyLock<HashMap<GlobalFragmentId, StreamFragmentEdge>> =
    LazyLock::new(HashMap::new);
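
/// A fragment in the complete graph: either a fragment being built for the current streaming
/// job, or a pre-existing fragment connected to it.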
#[derive(Debug, Clone, EnumAsInner)]
pub(super) enum EitherFragment {
    /// An internal fragment that is being built for the current streaming job.
    Building(BuildingFragment),

    /// An existing fragment that is external but connected to the fragments being built.
    Existing,
}
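
/// A wrapper of [`StreamFragmentGraph`] that also contains the pre-existing fragments connected
/// to the building job, e.g. the upstream `Materialize` fragments of a new materialized view, or
/// the downstream `StreamScan` fragments when replacing a table plan, along with the extra edges
/// between them and the building fragments.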
#[derive(Debug)]
pub struct CompleteStreamFragmentGraph {
    /// The fragment graph of the streaming job being built.
    building_graph: StreamFragmentGraph,

    /// The pre-existing fragments that the building graph is connected to.
    existing_fragments: HashMap<GlobalFragmentId, Fragment>,

    /// Extra edges to existing fragments, indexed by the upstream fragment.
    extra_downstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,

    /// Extra edges to existing fragments, indexed by the downstream fragment.
    extra_upstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,
}

/// The context of the existing upstream fragments when building a job that reads from other jobs.
pub struct FragmentGraphUpstreamContext {
    /// The root fragment (e.g. `Materialize` or `Source`) of each upstream job that the building
    /// job depends on.
    pub upstream_root_fragments: HashMap<JobId, Fragment>,
}

/// The context of the existing downstream fragments when replacing the plan of an existing job.
pub struct FragmentGraphDownstreamContext {
    pub original_root_fragment_id: FragmentId,
    pub downstream_fragments: Vec<(DispatcherType, Fragment)>,
}

impl CompleteStreamFragmentGraph {
    /// Create a [`CompleteStreamFragmentGraph`] without any pre-existing fragments, for testing.
    #[cfg(test)]
    pub fn for_test(graph: StreamFragmentGraph) -> Self {
        Self {
            building_graph: graph,
            existing_fragments: Default::default(),
            extra_downstreams: Default::default(),
            extra_upstreams: Default::default(),
        }
    }

    /// Create a [`CompleteStreamFragmentGraph`] for a job that reads from existing upstream jobs,
    /// e.g. an MV on MV, or a table/MV on a shared source.
    pub fn with_upstreams(
        graph: StreamFragmentGraph,
        upstream_context: FragmentGraphUpstreamContext,
        job_type: StreamingJobType,
    ) -> MetaResult<Self> {
        Self::build_helper(graph, Some(upstream_context), None, job_type)
    }

    /// Create a [`CompleteStreamFragmentGraph`] for replacing an existing job that has downstream
    /// jobs reading from it.
    pub fn with_downstreams(
        graph: StreamFragmentGraph,
        downstream_context: FragmentGraphDownstreamContext,
        job_type: StreamingJobType,
    ) -> MetaResult<Self> {
        Self::build_helper(graph, None, Some(downstream_context), job_type)
    }

    /// Create a [`CompleteStreamFragmentGraph`] with both upstream and downstream existing
    /// fragments.
    pub fn with_upstreams_and_downstreams(
        graph: StreamFragmentGraph,
        upstream_context: FragmentGraphUpstreamContext,
        downstream_context: FragmentGraphDownstreamContext,
        job_type: StreamingJobType,
    ) -> MetaResult<Self> {
        Self::build_helper(
            graph,
            Some(upstream_context),
            Some(downstream_context),
            job_type,
        )
    }

    /// The core logic of building the complete graph: attach the upstream root fragments and/or
    /// the existing downstream fragments to the building graph with proper dispatch strategies.
    fn build_helper(
        mut graph: StreamFragmentGraph,
        upstream_ctx: Option<FragmentGraphUpstreamContext>,
        downstream_ctx: Option<FragmentGraphDownstreamContext>,
        job_type: StreamingJobType,
    ) -> MetaResult<Self> {
        let mut extra_downstreams = HashMap::new();
        let mut extra_upstreams = HashMap::new();
        let mut existing_fragments = HashMap::new();

        if let Some(FragmentGraphUpstreamContext {
            upstream_root_fragments,
        }) = upstream_ctx
        {
            for (&id, fragment) in &mut graph.fragments {
                let uses_shuffled_backfill = fragment.has_shuffled_backfill();

                for (&upstream_job_id, required_columns) in &fragment.upstream_job_columns {
                    let upstream_fragment = upstream_root_fragments
                        .get(&upstream_job_id)
                        .context("upstream fragment not found")?;
                    let upstream_root_fragment_id =
                        GlobalFragmentId::new(upstream_fragment.fragment_id);

                    let edge = match job_type {
                        StreamingJobType::Table(TableJobType::SharedCdcSource) => {
                            // The fragment downstream of a shared CDC source must be a
                            // `CdcFilter` fragment.
                            assert_ne!(
                                (fragment.fragment_type_mask & FragmentTypeFlag::CdcFilter as u32),
                                0
                            );

                            tracing::debug!(
                                ?upstream_root_fragment_id,
                                ?required_columns,
                                identity = ?fragment.inner.get_node().unwrap().get_identity(),
                                current_frag_id=?id,
                                "CdcFilter with upstream source fragment"
                            );

                            StreamFragmentEdge {
                                id: EdgeId::UpstreamExternal {
                                    upstream_job_id,
                                    downstream_fragment_id: id,
                                },
                                dispatch_strategy: DispatchStrategy {
                                    r#type: DispatcherType::NoShuffle as _,
                                    dist_key_indices: vec![],
                                    output_mapping: DispatchOutputMapping::identical(
                                        CDC_SOURCE_COLUMN_NUM as _,
                                    )
                                    .into(),
                                },
                            }
                        }

                        StreamingJobType::MaterializedView
                        | StreamingJobType::Sink
                        | StreamingJobType::Index => {
                            if upstream_fragment
                                .fragment_type_mask
                                .contains(FragmentTypeFlag::Mview)
                            {
                                let (dist_key_indices, output_mapping) = {
                                    let mview_node = upstream_fragment
                                        .nodes
                                        .get_node_body()
                                        .unwrap()
                                        .as_materialize()
                                        .unwrap();
                                    let all_columns = mview_node.column_descs();
                                    let dist_key_indices = mview_node.dist_key_indices();
                                    let output_mapping = gen_output_mapping(
                                        required_columns,
                                        &all_columns,
                                    )
                                    .context(
                                        "BUG: column not found in the upstream materialized view",
                                    )?;
                                    (dist_key_indices, output_mapping)
                                };
                                let dispatch_strategy = mv_on_mv_dispatch_strategy(
                                    uses_shuffled_backfill,
                                    dist_key_indices,
                                    output_mapping,
                                );

                                StreamFragmentEdge {
                                    id: EdgeId::UpstreamExternal {
                                        upstream_job_id,
                                        downstream_fragment_id: id,
                                    },
                                    dispatch_strategy,
                                }
                            }
                            // The upstream root fragment may also be a shared source fragment.
                            else if upstream_fragment
                                .fragment_type_mask
                                .contains(FragmentTypeFlag::Source)
                            {
                                let output_mapping = {
                                    let source_node = upstream_fragment
                                        .nodes
                                        .get_node_body()
                                        .unwrap()
                                        .as_source()
                                        .unwrap();

                                    let all_columns = source_node.column_descs().unwrap();
                                    gen_output_mapping(required_columns, &all_columns).context(
                                        "BUG: column not found in the upstream source node",
                                    )?
                                };

                                StreamFragmentEdge {
                                    id: EdgeId::UpstreamExternal {
                                        upstream_job_id,
                                        downstream_fragment_id: id,
                                    },
                                    dispatch_strategy: DispatchStrategy {
                                        r#type: DispatcherType::NoShuffle as _,
                                        dist_key_indices: vec![],
                                        output_mapping: Some(output_mapping),
                                    },
                                }
                            } else {
                                bail!(
                                    "the upstream fragment should be a MView or Source, got fragment type: {:b}",
                                    upstream_fragment.fragment_type_mask
                                )
                            }
                        }
                        StreamingJobType::Source | StreamingJobType::Table(_) => {
                            bail!(
                                "the streaming job shouldn't have an upstream fragment, job_type: {:?}",
                                job_type
                            )
                        }
                    };

                    // Record the extra edge in both directions.
                    extra_downstreams
                        .entry(upstream_root_fragment_id)
                        .or_insert_with(HashMap::new)
                        .try_insert(id, edge.clone())
                        .unwrap();
                    extra_upstreams
                        .entry(id)
                        .or_insert_with(HashMap::new)
                        .try_insert(upstream_root_fragment_id, edge)
                        .unwrap();
                }
            }

            existing_fragments.extend(
                upstream_root_fragments
                    .into_values()
                    .map(|f| (GlobalFragmentId::new(f.fragment_id), f)),
            );
        }

        if let Some(FragmentGraphDownstreamContext {
            original_root_fragment_id,
            downstream_fragments,
        }) = downstream_ctx
        {
            let original_table_fragment_id = GlobalFragmentId::new(original_root_fragment_id);
            let table_fragment_id = GlobalFragmentId::new(graph.table_fragment_id());

            // Attach the existing downstream fragments to the root fragment of the new graph.
            for (dispatcher_type, fragment) in &downstream_fragments {
                let id = GlobalFragmentId::new(fragment.fragment_id);

                // Extract the columns that the downstream scan requires from the upstream.
                let output_columns = {
                    let mut res = None;

                    stream_graph_visitor::visit_stream_node_body(&fragment.nodes, |node_body| {
                        let columns = match node_body {
                            NodeBody::StreamScan(stream_scan) => stream_scan.upstream_columns(),
                            NodeBody::SourceBackfill(source_backfill) => {
                                source_backfill.column_descs()
                            }
                            _ => return,
                        };
                        res = Some(columns);
                    });

                    res.context("failed to locate downstream scan")?
                };

                let table_fragment = graph.fragments.get(&table_fragment_id).unwrap();
                let nodes = table_fragment.node.as_ref().unwrap();

                let (dist_key_indices, output_mapping) = match job_type {
                    StreamingJobType::Table(_) | StreamingJobType::MaterializedView => {
                        let mview_node = nodes.get_node_body().unwrap().as_materialize().unwrap();
                        let all_columns = mview_node.column_descs();
                        let dist_key_indices = mview_node.dist_key_indices();
                        let output_mapping = gen_output_mapping(&output_columns, &all_columns)
                            .ok_or_else(|| {
                                MetaError::invalid_parameter(
                                    "unable to drop the column due to \
                                     being referenced by downstream materialized views or sinks",
                                )
                            })?;
                        (dist_key_indices, output_mapping)
                    }

                    StreamingJobType::Source => {
                        let source_node = nodes.get_node_body().unwrap().as_source().unwrap();
                        let all_columns = source_node.column_descs().unwrap();
                        let output_mapping = gen_output_mapping(&output_columns, &all_columns)
                            .ok_or_else(|| {
                                MetaError::invalid_parameter(
                                    "unable to drop the column due to \
                                     being referenced by downstream materialized views or sinks",
                                )
                            })?;
                        assert_eq!(*dispatcher_type, DispatcherType::NoShuffle);
                        (
                            vec![],
                            output_mapping,
                        )
                    }

                    _ => bail!("unsupported job type for replacement: {job_type:?}"),
                };

                let edge = StreamFragmentEdge {
                    id: EdgeId::DownstreamExternal(DownstreamExternalEdgeId {
                        original_upstream_fragment_id: original_table_fragment_id,
                        downstream_fragment_id: id,
                    }),
                    dispatch_strategy: DispatchStrategy {
                        r#type: *dispatcher_type as i32,
                        output_mapping: Some(output_mapping),
                        dist_key_indices,
                    },
                };

                extra_downstreams
                    .entry(table_fragment_id)
                    .or_insert_with(HashMap::new)
                    .try_insert(id, edge.clone())
                    .unwrap();
                extra_upstreams
                    .entry(id)
                    .or_insert_with(HashMap::new)
                    .try_insert(table_fragment_id, edge)
                    .unwrap();
            }

            existing_fragments.extend(
                downstream_fragments
                    .into_iter()
                    .map(|(_, f)| (GlobalFragmentId::new(f.fragment_id), f)),
            );
        }

        Ok(Self {
            building_graph: graph,
            existing_fragments,
            extra_downstreams,
            extra_upstreams,
        })
    }
}
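
/// Generate the [`DispatchOutputMapping`] that projects the upstream columns to the required
/// downstream columns, matched by column ID. Returns `None` if any required column is missing
/// from the upstream.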
fn gen_output_mapping(
    required_columns: &[PbColumnDesc],
    upstream_columns: &[PbColumnDesc],
) -> Option<DispatchOutputMapping> {
    let len = required_columns.len();
    let mut indices = vec![0; len];
    let mut types = None;

    for (i, r) in required_columns.iter().enumerate() {
        let (ui, u) = upstream_columns
            .iter()
            .find_position(|&u| u.column_id == r.column_id)?;
        indices[i] = ui as u32;

        // Only record a type pair if the upstream and downstream types differ, i.e. a type
        // conversion is needed.
        if u.column_type != r.column_type {
            types.get_or_insert_with(|| vec![TypePair::default(); len])[i] = TypePair {
                upstream: u.column_type.clone(),
                downstream: r.column_type.clone(),
            };
        }
    }

    // If no type conversion is needed at all, leave `types` empty.
    let types = types.unwrap_or(Vec::new());

    Some(DispatchOutputMapping { indices, types })
}

fn mv_on_mv_dispatch_strategy(
    uses_shuffled_backfill: bool,
    dist_key_indices: Vec<u32>,
    output_mapping: DispatchOutputMapping,
) -> DispatchStrategy {
    if uses_shuffled_backfill {
        if !dist_key_indices.is_empty() {
            DispatchStrategy {
                r#type: DispatcherType::Hash as _,
                dist_key_indices,
                output_mapping: Some(output_mapping),
            }
        } else {
            DispatchStrategy {
                r#type: DispatcherType::Simple as _,
                dist_key_indices: vec![],
                output_mapping: Some(output_mapping),
            }
        }
    } else {
        DispatchStrategy {
            r#type: DispatcherType::NoShuffle as _,
            dist_key_indices: vec![],
            output_mapping: Some(output_mapping),
        }
    }
}
2059
impl CompleteStreamFragmentGraph {
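    /// Returns the IDs of all fragments in the complete graph, including both the fragments
    /// being built and the existing external ones.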
    pub(super) fn all_fragment_ids(&self) -> impl Iterator<Item = GlobalFragmentId> + '_ {
        self.building_graph
            .fragments
            .keys()
            .chain(self.existing_fragments.keys())
            .copied()
    }

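    /// Returns all edges in the complete graph: the internal edges of the building graph,
    /// chained with the extra edges connecting to existing fragments.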
    pub(super) fn all_edges(
        &self,
    ) -> impl Iterator<Item = (GlobalFragmentId, GlobalFragmentId, &StreamFragmentEdge)> + '_ {
        self.building_graph
            .downstreams
            .iter()
            .chain(self.extra_downstreams.iter())
            .flat_map(|(&from, tos)| tos.iter().map(move |(&to, edge)| (from, to, edge)))
    }

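    /// Returns the distribution of all existing external fragments, resolved from their
    /// fragment catalogs.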
    pub(super) fn existing_distribution(&self) -> HashMap<GlobalFragmentId, Distribution> {
        self.existing_fragments
            .iter()
            .map(|(&id, f)| (id, Distribution::from_fragment(f)))
            .collect()
    }

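    /// Generates a topological order of the fragments in this graph, with the most downstream
    /// fragments first, by repeatedly releasing fragments whose downstreams have all been
    /// visited (Kahn's algorithm). Returns an error if the graph is not a DAG.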
    pub(super) fn topo_order(&self) -> MetaResult<Vec<GlobalFragmentId>> {
        let mut topo = Vec::new();
        let mut downstream_cnts = HashMap::new();

        // Initialization: fragments with no downstreams are the starting points.
        for fragment_id in self.all_fragment_ids() {
            let downstream_cnt = self.get_downstreams(fragment_id).count();
            if downstream_cnt == 0 {
                topo.push(fragment_id);
            } else {
                downstream_cnts.insert(fragment_id, downstream_cnt);
            }
        }

        // Visit fragments in order; an upstream becomes ready once all of its downstreams
        // have been visited.
        let mut i = 0;
        while let Some(&fragment_id) = topo.get(i) {
            i += 1;
            for (upstream_fragment_id, _) in self.get_upstreams(fragment_id) {
                let downstream_cnt = downstream_cnts.get_mut(&upstream_fragment_id).unwrap();
                *downstream_cnt -= 1;
                if *downstream_cnt == 0 {
                    downstream_cnts.remove(&upstream_fragment_id);
                    topo.push(upstream_fragment_id);
                }
            }
        }

        if !downstream_cnts.is_empty() {
            // Some fragments still have unvisited downstreams, i.e. there is a cycle.
            bail!("graph is not a DAG");
        }

        Ok(topo)
    }

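    /// Seals a building fragment into a catalog [`Fragment`] with the given scheduled
    /// distribution and the finalized stream node body.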
    pub(super) fn seal_fragment(
        &self,
        id: GlobalFragmentId,
        distribution: Distribution,
        stream_node: StreamNode,
    ) -> Fragment {
        let building_fragment = self.get_fragment(id).into_building().unwrap();
        let internal_tables = building_fragment.extract_internal_tables();
        let BuildingFragment {
            inner,
            job_id,
            upstream_job_columns: _,
        } = building_fragment;

        let distribution_type = distribution.to_distribution_type();
        let vnode_count = distribution.vnode_count();

        // A `Materialize` or `VectorIndexWrite` fragment additionally owns the job's table as
        // one of its state tables.
        let mview_table_id =
            if FragmentTypeMask::from(inner.fragment_type_mask).contains(FragmentTypeFlag::Mview) {
                job_id.map(JobId::as_mv_table_id)
            } else {
                None
            };

        let vector_index_table_id = if FragmentTypeMask::from(inner.fragment_type_mask)
            .contains(FragmentTypeFlag::VectorIndexWrite)
        {
            job_id.map(JobId::as_mv_table_id)
        } else {
            None
        };

        let state_table_ids = internal_tables
            .iter()
            .map(|t| t.id)
            .chain(mview_table_id)
            .chain(vector_index_table_id)
            .collect();

        Fragment {
            fragment_id: inner.fragment_id,
            fragment_type_mask: inner.fragment_type_mask.into(),
            distribution_type,
            state_table_ids,
            maybe_vnode_count: VnodeCount::set(vnode_count).to_protobuf(),
            nodes: stream_node,
        }
    }

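    /// Returns the fragment with the given ID, which is either a fragment being built or an
    /// existing external fragment.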
    pub(super) fn get_fragment(&self, fragment_id: GlobalFragmentId) -> EitherFragment {
        if let Some(fragment) = self.existing_fragments.get(&fragment_id) {
            EitherFragment::Existing(fragment.clone())
        } else {
            EitherFragment::Building(
                self.building_graph
                    .fragments
                    .get(&fragment_id)
                    .unwrap()
                    .clone(),
            )
        }
    }

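    /// Returns all downstreams of the given fragment, covering both the edges inside the
    /// building graph and the extra edges towards existing fragments.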
    pub(super) fn get_downstreams(
        &self,
        fragment_id: GlobalFragmentId,
    ) -> impl Iterator<Item = (GlobalFragmentId, &StreamFragmentEdge)> {
        self.building_graph
            .get_downstreams(fragment_id)
            .iter()
            .chain(
                self.extra_downstreams
                    .get(&fragment_id)
                    .into_iter()
                    .flatten(),
            )
            .map(|(&id, edge)| (id, edge))
    }

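    /// Returns all upstreams of the given fragment, covering both the edges inside the
    /// building graph and the extra edges from existing fragments.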
    pub(super) fn get_upstreams(
        &self,
        fragment_id: GlobalFragmentId,
    ) -> impl Iterator<Item = (GlobalFragmentId, &StreamFragmentEdge)> {
        self.building_graph
            .get_upstreams(fragment_id)
            .iter()
            .chain(self.extra_upstreams.get(&fragment_id).into_iter().flatten())
            .map(|(&id, edge)| (id, edge))
    }

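    /// Returns all fragments being built, keyed by their global fragment ID.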
    pub(super) fn building_fragments(&self) -> &HashMap<GlobalFragmentId, BuildingFragment> {
        &self.building_graph.fragments
    }

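    /// Returns a mutable reference to all fragments being built.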
    pub(super) fn building_fragments_mut(
        &mut self,
    ) -> &mut HashMap<GlobalFragmentId, BuildingFragment> {
        &mut self.building_graph.fragments
    }

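    /// Returns the maximum parallelism of the graph being built.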
    pub(super) fn max_parallelism(&self) -> usize {
        self.building_graph.max_parallelism()
    }
}

#[cfg(test)]
mod tests {
    use risingwave_common::catalog::{ColumnDesc, ColumnId};
    use risingwave_common::types::DataType;
    use risingwave_pb::catalog::SinkType as PbSinkType;
    use risingwave_pb::meta::table_fragments::fragment::PbFragmentDistributionType;
    use risingwave_pb::plan_common::StorageTableDesc;
    use risingwave_pb::stream_plan::{
        BatchPlanNode, MergeNode, ProjectNode, SinkDesc, SinkLogStoreType, SinkNode, StreamNode,
        StreamScanNode, StreamScanType,
    };

    use super::*;

    /// Builds a visible column catalog with the given name, id and type.
    fn make_column(name: &str, id: i32, data_type: DataType) -> ColumnCatalog {
        ColumnCatalog::visible(ColumnDesc::named(name, ColumnId::new(id), data_type))
    }

    /// Builds a `table_name.column_name` field for a stream node's schema.
    fn make_field(table_name: &str, column: &ColumnCatalog) -> risingwave_pb::plan_common::Field {
        Field::new(
            format!("{}.{}", table_name, column.column_desc.name),
            column.data_type().clone(),
        )
        .to_prost()
    }

    /// Builds an `InputRef` expression node referring to the given input index.
    fn make_input_ref(index: u32, data_type: &DataType) -> PbExprNode {
        PbExprNode {
            function_type: expr_node::Type::Unspecified as i32,
            return_type: Some(data_type.to_protobuf()),
            rex_node: Some(expr_node::RexNode::InputRef(index)),
        }
    }

    /// Builds a `StreamScan` node (arrangement backfill) over the given columns, with a
    /// `Merge` and a `BatchPlan` node as its inputs.
    fn make_stream_scan_node(
        table_name: &str,
        table_id: u32,
        columns: &[ColumnCatalog],
    ) -> StreamNode {
        let merge_node = StreamNode {
            node_body: Some(NodeBody::Merge(Box::new(MergeNode {
                upstream_fragment_id: 0.into(),
                ..Default::default()
            }))),
            fields: columns
                .iter()
                .map(|col| make_field(table_name, col))
                .collect(),
            ..Default::default()
        };
        let batch_plan_node = StreamNode {
            node_body: Some(NodeBody::BatchPlan(Box::new(BatchPlanNode {
                ..Default::default()
            }))),
            ..Default::default()
        };
        let stream_scan_node = StreamScanNode {
            table_id: table_id.into(),
            upstream_column_ids: columns.iter().map(|c| c.column_id().get_id()).collect(),
            output_indices: (0..columns.len()).map(|i| i as u32).collect(),
            stream_scan_type: StreamScanType::ArrangementBackfill as i32,
            table_desc: Some(StorageTableDesc {
                table_id: table_id.into(),
                columns: columns
                    .iter()
                    .map(|col| col.column_desc.to_protobuf())
                    .collect(),
                value_indices: (0..columns.len()).map(|i| i as u32).collect(),
                versioned: true,
                ..Default::default()
            }),
            ..Default::default()
        };
        StreamNode {
            node_body: Some(NodeBody::StreamScan(Box::new(stream_scan_node))),
            fields: columns
                .iter()
                .map(|col| make_field(table_name, col))
                .collect(),
            input: vec![merge_node, batch_plan_node],
            ..Default::default()
        }
    }

    /// Wraps the input in a `Project` node that forwards every column via an `InputRef`.
    fn make_project_node(
        table_name: &str,
        columns: &[ColumnCatalog],
        input: StreamNode,
    ) -> StreamNode {
        let select_list = columns
            .iter()
            .enumerate()
            .map(|(i, col)| make_input_ref(i as u32, col.data_type()))
            .collect();
        StreamNode {
            node_body: Some(NodeBody::Project(Box::new(ProjectNode {
                select_list,
                ..Default::default()
            }))),
            fields: columns
                .iter()
                .map(|col| make_field(table_name, col))
                .collect(),
            input: vec![input],
            ..Default::default()
        }
    }

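    /// Adding a column to the upstream table should extend the sink fragment's `Project` and
    /// `StreamScan` nodes with a reference to the new column while keeping the existing ones.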
    #[tokio::test]
    async fn test_rewrite_refresh_schema_sink_fragment_with_project() {
        let env = MetaSrvEnv::for_test().await;
        let id_gen_manager = env.id_gen_manager().as_ref();

        let table_name = "t";
        let columns = vec![
            make_column("a", 1, DataType::Int64),
            make_column("b", 2, DataType::Int64),
        ];
        let new_column = make_column("c", 3, DataType::Varchar);

        let mut upstream_columns = columns.clone();
        upstream_columns.push(new_column.clone());
        let upstream_table = PbTable {
            name: table_name.to_owned(),
            columns: upstream_columns
                .iter()
                .map(|col| col.to_protobuf())
                .collect(),
            ..Default::default()
        };

        let sink = PbSink {
            columns: columns.iter().map(|col| col.to_protobuf()).collect(),
            sink_type: PbSinkType::AppendOnly as i32,
            ..Default::default()
        };

        let sink_desc = SinkDesc {
            sink_type: PbSinkType::AppendOnly as i32,
            column_catalogs: sink.columns.clone(),
            ..Default::default()
        };

        let stream_scan_node = make_stream_scan_node(table_name, 1, &columns);
        let project_node = make_project_node(table_name, &columns, stream_scan_node);

        let log_store_table = PbTable {
            columns: columns
                .iter()
                .cloned()
                .map(|mut col| {
                    col.column_desc.name = format!("{}_{}", table_name, col.column_desc.name);
                    col.to_protobuf()
                })
                .collect(),
            value_indices: (0..columns.len()).map(|i| i as i32).collect(),
            ..Default::default()
        };

        let original_fragment = Fragment {
            fragment_id: 1.into(),
            fragment_type_mask: FragmentTypeMask::default(),
            distribution_type: PbFragmentDistributionType::Single,
            state_table_ids: vec![],
            maybe_vnode_count: None,
            nodes: StreamNode {
                node_body: Some(NodeBody::Sink(Box::new(SinkNode {
                    sink_desc: Some(sink_desc),
                    table: Some(log_store_table),
                    ..Default::default()
                }))),
                fields: columns
                    .iter()
                    .map(|col| make_field(table_name, col))
                    .collect(),
                input: vec![project_node],
                ..Default::default()
            },
        };

        let (new_fragment, _, _) = rewrite_refresh_schema_sink_fragment(
            &original_fragment,
            &sink,
            std::slice::from_ref(&new_column),
            &[],
            &upstream_table,
            7.into(),
            id_gen_manager,
        )
        .unwrap();

        let sink_node = &new_fragment.nodes;
        let [project_node] = sink_node.input.as_slice() else {
            panic!("Sink has more than 1 input: {:?}", sink_node.input);
        };
        let PbNodeBody::Project(project_body) = project_node.node_body.as_ref().unwrap() else {
            panic!(
                "expect PbNodeBody::Project but got: {:?}",
                project_node.node_body
            );
        };
        assert_eq!(project_body.select_list.len(), columns.len() + 1);
        let last_expr = project_body.select_list.last().unwrap();
        assert!(
            matches!(last_expr.rex_node, Some(expr_node::RexNode::InputRef(idx)) if idx == columns.len() as u32)
        );
        assert_eq!(project_node.fields.len(), columns.len() + 1);

        let [stream_scan_node] = project_node.input.as_slice() else {
            panic!("Project has more than 1 input: {:?}", project_node.input);
        };
        let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_ref().unwrap() else {
            panic!(
                "expect PbNodeBody::StreamScan but got: {:?}",
                stream_scan_node.node_body
            );
        };
        assert_eq!(
            scan.upstream_column_ids.last().copied(),
            Some(new_column.column_id().get_id())
        );
        assert_eq!(
            scan.output_indices.last().copied(),
            Some(columns.len() as u32)
        );
        assert_eq!(
            stream_scan_node.fields.last().unwrap().name,
            format!("{}.{}", table_name, new_column.column_desc.name)
        );
    }

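    /// Dropping a column from the upstream table should remove it from the sink schema as well
    /// as from the `Project` and `StreamScan` nodes, and produce an updated log store table.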
    #[tokio::test]
    async fn test_rewrite_refresh_schema_sink_fragment_drop_column_with_project() {
        let env = MetaSrvEnv::for_test().await;
        let id_gen_manager = env.id_gen_manager().as_ref();

        let table_name = "t";
        let columns = vec![
            make_column("a", 1, DataType::Int64),
            make_column("b", 2, DataType::Int64),
            make_column("tmp", 3, DataType::Varchar),
        ];
        let removed_column = columns.last().unwrap().clone();
        let upstream_columns = columns[..2].to_vec();

        let upstream_table = PbTable {
            name: table_name.to_owned(),
            columns: upstream_columns
                .iter()
                .map(|col| col.to_protobuf())
                .collect(),
            ..Default::default()
        };

        let sink = PbSink {
            columns: columns.iter().map(|col| col.to_protobuf()).collect(),
            sink_type: PbSinkType::AppendOnly as i32,
            ..Default::default()
        };

        let sink_desc = SinkDesc {
            sink_type: PbSinkType::AppendOnly as i32,
            column_catalogs: sink.columns.clone(),
            ..Default::default()
        };

        let stream_scan_node = make_stream_scan_node(table_name, 1, &columns);
        let project_node = make_project_node(table_name, &columns, stream_scan_node);

        let log_store_table = PbTable {
            columns: columns
                .iter()
                .cloned()
                .map(|mut col| {
                    col.column_desc.name = format!("{}_{}", table_name, col.column_desc.name);
                    col.to_protobuf()
                })
                .collect(),
            value_indices: (0..columns.len()).map(|i| i as i32).collect(),
            ..Default::default()
        };

        let original_fragment = Fragment {
            fragment_id: 1.into(),
            fragment_type_mask: FragmentTypeMask::default(),
            distribution_type: PbFragmentDistributionType::Single,
            state_table_ids: vec![],
            maybe_vnode_count: None,
            nodes: StreamNode {
                node_body: Some(NodeBody::Sink(Box::new(SinkNode {
                    sink_desc: Some(sink_desc),
                    table: Some(log_store_table),
                    log_store_type: SinkLogStoreType::KvLogStore as i32,
                    ..Default::default()
                }))),
                fields: columns
                    .iter()
                    .map(|col| make_field(table_name, col))
                    .collect(),
                input: vec![project_node],
                ..Default::default()
            },
        };

        let (new_fragment, new_schema, new_log_store_table) = rewrite_refresh_schema_sink_fragment(
            &original_fragment,
            &sink,
            &[],
            std::slice::from_ref(&removed_column),
            &upstream_table,
            7.into(),
            id_gen_manager,
        )
        .unwrap();

        assert_eq!(new_schema.len(), 2);
        assert!(
            new_schema.iter().all(|col| {
                col.column_desc.as_ref().map(|desc| desc.name.as_str()) != Some("tmp")
            })
        );

        let sink_node = &new_fragment.nodes;
        let [project_node] = sink_node.input.as_slice() else {
            panic!("Sink has more than 1 input: {:?}", sink_node.input);
        };
        let PbNodeBody::Project(project_body) = project_node.node_body.as_ref().unwrap() else {
            panic!(
                "expect PbNodeBody::Project but got: {:?}",
                project_node.node_body
            );
        };
        assert_eq!(project_body.select_list.len(), 2);
        assert!(project_node.fields.iter().all(|f| !f.name.contains("tmp")));

        let [stream_scan_node] = project_node.input.as_slice() else {
            panic!("Project has more than 1 input: {:?}", project_node.input);
        };
        let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_ref().unwrap() else {
            panic!(
                "expect PbNodeBody::StreamScan but got: {:?}",
                stream_scan_node.node_body
            );
        };
        assert!(
            !scan
                .upstream_column_ids
                .iter()
                .any(|&id| id == removed_column.column_id().get_id())
        );
        assert!(
            stream_scan_node
                .fields
                .iter()
                .all(|f| !f.name.contains("tmp"))
        );

        let new_log_store_table = new_log_store_table.expect("log store table should be updated");
        assert_eq!(
            new_log_store_table.value_indices,
            (0..columns.len()).map(|i| i as i32).collect::<Vec<_>>()
        );
    }
}