1use std::collections::{BTreeMap, HashMap, HashSet};
16use std::num::NonZeroUsize;
17use std::ops::{Deref, DerefMut};
18use std::sync::LazyLock;
19
20use anyhow::{Context, anyhow};
21use enum_as_inner::EnumAsInner;
22use itertools::Itertools;
23use risingwave_common::bail;
24use risingwave_common::catalog::{
25 CDC_SOURCE_COLUMN_NUM, ColumnCatalog, ColumnId, Field, FragmentTypeFlag, FragmentTypeMask,
26 TableId, generate_internal_table_name_with_type,
27};
28use risingwave_common::hash::VnodeCount;
29use risingwave_common::id::JobId;
30use risingwave_common::util::iter_util::ZipEqFast;
31use risingwave_common::util::stream_graph_visitor::{
32 self, visit_stream_node_cont, visit_stream_node_cont_mut,
33};
34use risingwave_connector::sink::catalog::SinkType;
35use risingwave_meta_model::streaming_job::BackfillOrders;
36use risingwave_pb::catalog::{PbSink, PbTable, Table};
37use risingwave_pb::ddl_service::TableJobType;
38use risingwave_pb::expr::{ExprNode as PbExprNode, expr_node};
39use risingwave_pb::id::{RelationId, StreamNodeLocalOperatorId};
40use risingwave_pb::plan_common::{PbColumnCatalog, PbColumnDesc};
41use risingwave_pb::stream_plan::dispatch_output_mapping::TypePair;
42use risingwave_pb::stream_plan::stream_fragment_graph::{
43 Parallelism, StreamFragment, StreamFragmentEdge as StreamFragmentEdgeProto,
44};
45use risingwave_pb::stream_plan::stream_node::{NodeBody, PbNodeBody};
46use risingwave_pb::stream_plan::{
47 BackfillOrder, DispatchOutputMapping, DispatchStrategy, DispatcherType, PbStreamNode,
48 PbStreamScanType, StreamFragmentGraph as StreamFragmentGraphProto, StreamNode, StreamScanNode,
49 StreamScanType,
50};
51
52use crate::barrier::SnapshotBackfillInfo;
53use crate::controller::id::IdGeneratorManager;
54use crate::manager::{MetaSrvEnv, StreamingJob, StreamingJobType};
55use crate::model::{Fragment, FragmentDownstreamRelation, FragmentId};
56use crate::stream::stream_graph::id::{GlobalFragmentId, GlobalFragmentIdGen, GlobalTableIdGen};
57use crate::stream::stream_graph::schedule::Distribution;
58use crate::{MetaError, MetaResult};
59
/// A [`StreamFragment`] under construction, bundled with the job-related information
/// that [`BuildingFragment::new`] derives from it.
#[derive(Debug, Clone)]
pub(super) struct BuildingFragment {
    /// The wrapped fragment; reachable through the `Deref`/`DerefMut` impls.
    inner: StreamFragment,

    /// `Some(job.id())` when `fill_job` reported that this fragment holds the job's own
    /// node, `None` otherwise.
    job_id: Option<JobId>,

    /// Column descriptors collected per upstream job id by
    /// `extract_upstream_columns_except_cross_db_backfill`.
    upstream_job_columns: HashMap<JobId, Vec<PbColumnDesc>>,
}
76
77impl BuildingFragment {
78 fn new(
81 id: GlobalFragmentId,
82 fragment: StreamFragment,
83 job: &StreamingJob,
84 table_id_gen: GlobalTableIdGen,
85 ) -> Self {
86 let mut fragment = StreamFragment {
87 fragment_id: id.as_global_id(),
88 ..fragment
89 };
90
91 Self::fill_internal_tables(&mut fragment, job, table_id_gen);
93
94 let job_id = Self::fill_job(&mut fragment, job).then(|| job.id());
95 let upstream_job_columns =
96 Self::extract_upstream_columns_except_cross_db_backfill(&fragment);
97
98 Self {
99 inner: fragment,
100 job_id,
101 upstream_job_columns,
102 }
103 }
104
105 fn extract_internal_tables(&self) -> Vec<Table> {
107 let mut fragment = self.inner.clone();
108 let mut tables = Vec::new();
109 stream_graph_visitor::visit_internal_tables(&mut fragment, |table, _| {
110 tables.push(table.clone());
111 });
112 tables
113 }
114
115 fn fill_internal_tables(
117 fragment: &mut StreamFragment,
118 job: &StreamingJob,
119 table_id_gen: GlobalTableIdGen,
120 ) {
121 let fragment_id = fragment.fragment_id;
122 stream_graph_visitor::visit_internal_tables(fragment, |table, table_type_name| {
123 table.id = table_id_gen
124 .to_global_id(table.id.as_raw_id())
125 .as_global_id();
126 table.schema_id = job.schema_id();
127 table.database_id = job.database_id();
128 table.name = generate_internal_table_name_with_type(
129 &job.name(),
130 fragment_id,
131 table.id,
132 table_type_name,
133 );
134 table.fragment_id = fragment_id;
135 table.owner = job.owner();
136 table.job_id = Some(job.id());
137 });
138 }
139
    /// Fills job-specific identifiers into the nodes of `fragment`.
    ///
    /// Returns `true` when the fragment contains a node that represents the job itself:
    /// `Materialize`, `Sink`, `IcebergWithPkIndexWriter`, `IcebergWithPkIndexDvMerger`,
    /// `VectorIndexWrite`, or `Source` when the job is a `StreamingJob::Source`.
    /// `Dml`, `StreamFsFetch` and `StreamCdcScan` nodes are filled as well but do not
    /// affect the returned flag.
    ///
    /// # Panics
    ///
    /// Panics if a required piece is missing: `job.table()` for `Materialize`,
    /// `sink_desc` for the sink-like nodes, `job.table_version_id()` for `Dml`, or
    /// `table` for `VectorIndexWrite`.
    fn fill_job(fragment: &mut StreamFragment, job: &StreamingJob) -> bool {
        let job_id = job.id();
        let fragment_id = fragment.fragment_id;
        let mut has_job = false;

        stream_graph_visitor::visit_fragment_mut(fragment, |node_body| match node_body {
            NodeBody::Materialize(materialize_node) => {
                materialize_node.table_id = job_id.as_mv_table_id();

                // Replace the table of the materialize node with a copy of the job's table.
                let table = materialize_node.table.insert(job.table().unwrap().clone());
                table.fragment_id = fragment_id;
                // Outside of debug builds, the definition is replaced by the job name.
                if cfg!(not(debug_assertions)) {
                    table.definition = job.name();
                }

                has_job = true;
            }
            NodeBody::Sink(sink_node) => {
                sink_node.sink_desc.as_mut().unwrap().id = job_id.as_sink_id();

                has_job = true;
            }
            NodeBody::IcebergWithPkIndexWriter(writer_node) => {
                writer_node.sink_desc.as_mut().unwrap().id = job_id.as_sink_id();

                has_job = true;
            }
            NodeBody::IcebergWithPkIndexDvMerger(merger_node) => {
                merger_node.sink_desc.as_mut().unwrap().id = job_id.as_sink_id();

                has_job = true;
            }
            NodeBody::Dml(dml_node) => {
                dml_node.table_id = job_id.as_mv_table_id();
                dml_node.table_version_id = job.table_version_id().unwrap();
            }
            NodeBody::StreamFsFetch(fs_fetch_node) => {
                // Only filled for a table job that carries a source.
                if let StreamingJob::Table(table_source, _, _) = job
                    && let Some(node_inner) = fs_fetch_node.node_inner.as_mut()
                    && let Some(source) = table_source
                {
                    node_inner.source_id = source.id;
                    if let Some(id) = source.optional_associated_table_id {
                        node_inner.associated_table_id = Some(id.into());
                    }
                }
            }
            NodeBody::Source(source_node) => {
                match job {
                    StreamingJob::Table(source, _table, _table_job_type) => {
                        // A table job may have no source; nothing is filled then.
                        if let Some(source_inner) = source_node.source_inner.as_mut()
                            && let Some(source) = source
                        {
                            // The source of a table job has an id different from the job id.
                            debug_assert_ne!(source.id, job_id.as_raw_id());
                            source_inner.source_id = source.id;
                            if let Some(id) = source.optional_associated_table_id {
                                source_inner.associated_table_id = Some(id.into());
                            }
                        }
                    }
                    StreamingJob::Source(source) => {
                        has_job = true;
                        if let Some(source_inner) = source_node.source_inner.as_mut() {
                            // For a source job, the source id is the job id.
                            debug_assert_eq!(source.id, job_id.as_raw_id());
                            source_inner.source_id = source.id;
                            if let Some(id) = source.optional_associated_table_id {
                                source_inner.associated_table_id = Some(id.into());
                            }
                        }
                    }
                    // Other job types leave the source node untouched.
                    _ => {}
                }
            }
            NodeBody::StreamCdcScan(node) => {
                if let Some(table_desc) = node.cdc_table_desc.as_mut() {
                    table_desc.table_id = job_id.as_mv_table_id();
                }
            }
            NodeBody::VectorIndexWrite(node) => {
                let table = node.table.as_mut().unwrap();
                table.id = job_id.as_mv_table_id();
                table.database_id = job.database_id();
                table.schema_id = job.schema_id();
                table.fragment_id = fragment_id;
                // Outside of debug builds, the definition is replaced by the job name.
                #[cfg(not(debug_assertions))]
                {
                    table.definition = job.name();
                }

                has_job = true;
            }
            _ => {}
        });

        has_job
    }
242
243 fn extract_upstream_columns_except_cross_db_backfill(
245 fragment: &StreamFragment,
246 ) -> HashMap<JobId, Vec<PbColumnDesc>> {
247 let mut table_columns = HashMap::new();
248
249 stream_graph_visitor::visit_fragment(fragment, |node_body| {
250 let (table_id, column_ids) = match node_body {
251 NodeBody::StreamScan(stream_scan) => {
252 if stream_scan.get_stream_scan_type().unwrap()
253 == StreamScanType::CrossDbSnapshotBackfill
254 {
255 return;
256 }
257 (
258 stream_scan.table_id.as_job_id(),
259 stream_scan.upstream_columns(),
260 )
261 }
262 NodeBody::CdcFilter(cdc_filter) => (
263 cdc_filter.upstream_source_id.as_share_source_job_id(),
264 vec![],
265 ),
266 NodeBody::SourceBackfill(backfill) => (
267 backfill.upstream_source_id.as_share_source_job_id(),
268 backfill.column_descs(),
270 ),
271 _ => return,
272 };
273 table_columns
274 .try_insert(table_id, column_ids)
275 .expect("currently there should be no two same upstream tables in a fragment");
276 });
277
278 table_columns
279 }
280
281 pub fn has_shuffled_backfill(&self) -> bool {
282 let stream_node = match self.inner.node.as_ref() {
283 Some(node) => node,
284 _ => return false,
285 };
286 let mut has_shuffled_backfill = false;
287 let has_shuffled_backfill_mut_ref = &mut has_shuffled_backfill;
288 visit_stream_node_cont(stream_node, |node| {
289 let is_shuffled_backfill = if let Some(node) = &node.node_body
290 && let Some(node) = node.as_stream_scan()
291 {
292 node.stream_scan_type == StreamScanType::ArrangementBackfill as i32
293 || node.stream_scan_type == StreamScanType::SnapshotBackfill as i32
294 } else {
295 false
296 };
297 if is_shuffled_backfill {
298 *has_shuffled_backfill_mut_ref = true;
299 false
300 } else {
301 true
302 }
303 });
304 has_shuffled_backfill
305 }
306}
307
/// Read-only access to the wrapped [`StreamFragment`], so that its fields can be read
/// directly on a [`BuildingFragment`].
impl Deref for BuildingFragment {
    type Target = StreamFragment;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
315
/// Mutable access to the wrapped [`StreamFragment`].
impl DerefMut for BuildingFragment {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
321
/// Identifier of an edge between fragments while the graph is being built.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumAsInner)]
pub(super) enum EdgeId {
    /// An edge built from a protobuf edge (see `StreamFragmentEdge::from_protobuf`).
    Internal {
        /// The `link_id` copied from the protobuf edge.
        link_id: u64,
    },

    /// An edge identified by an upstream job and a downstream fragment.
    /// NOTE(review): presumably links an existing upstream job to a fragment of this
    /// graph; confirm against the code that constructs this variant.
    UpstreamExternal {
        /// Id of the upstream job.
        upstream_job_id: JobId,
        /// Id of the downstream fragment.
        downstream_fragment_id: GlobalFragmentId,
    },

    /// An edge identified by a [`DownstreamExternalEdgeId`].
    DownstreamExternal(DownstreamExternalEdgeId),
}
346
/// Payload of [`EdgeId::DownstreamExternal`]: a pair of fragment ids.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(super) struct DownstreamExternalEdgeId {
    /// Id of the original upstream fragment.
    /// NOTE(review): "original" presumably refers to the fragment being replaced; confirm.
    pub(super) original_upstream_fragment_id: GlobalFragmentId,
    /// Id of the downstream fragment.
    pub(super) downstream_fragment_id: GlobalFragmentId,
}
354
/// An edge of the fragment graph: its identifier plus the strategy used to dispatch
/// data along it.
#[derive(Debug, Clone)]
pub(super) struct StreamFragmentEdge {
    /// Identifier of the edge.
    pub id: EdgeId,

    /// Dispatch strategy of the edge.
    pub dispatch_strategy: DispatchStrategy,
}
366
367impl StreamFragmentEdge {
368 fn from_protobuf(edge: &StreamFragmentEdgeProto) -> Self {
369 Self {
370 id: EdgeId::Internal {
373 link_id: edge.link_id,
374 },
375 dispatch_strategy: edge.get_dispatch_strategy().unwrap().clone(),
376 }
377 }
378}
379
380fn clone_fragment(fragment: &Fragment, id_generator_manager: &IdGeneratorManager) -> Fragment {
381 let fragment_id = GlobalFragmentIdGen::new(id_generator_manager, 1)
382 .to_global_id(0)
383 .as_global_id();
384 Fragment {
385 fragment_id,
386 fragment_type_mask: fragment.fragment_type_mask,
387 distribution_type: fragment.distribution_type,
388 state_table_ids: fragment.state_table_ids.clone(),
389 maybe_vnode_count: fragment.maybe_vnode_count,
390 nodes: fragment.nodes.clone(),
391 }
392}
393
394pub fn check_sink_fragments_support_refresh_schema(
395 fragments: &BTreeMap<FragmentId, Fragment>,
396) -> MetaResult<()> {
397 if fragments.len() != 1 {
398 return Err(anyhow!(
399 "sink with auto schema change should have only 1 fragment, but got {:?}",
400 fragments.len()
401 )
402 .into());
403 }
404 let (_, fragment) = fragments.first_key_value().expect("non-empty");
405 let sink_node = &fragment.nodes;
406 let PbNodeBody::Sink(_) = sink_node.node_body.as_ref().unwrap() else {
407 return Err(anyhow!("expect PbNodeBody::Sink but got: {:?}", sink_node.node_body).into());
408 };
409 let [stream_input_node] = sink_node.input.as_slice() else {
410 panic!("Sink has more than 1 input: {:?}", sink_node.input);
411 };
412 let stream_scan_node = match stream_input_node.node_body.as_ref().unwrap() {
413 PbNodeBody::StreamScan(_) => stream_input_node,
414 PbNodeBody::Project(_) => {
415 let [stream_scan_node] = stream_input_node.input.as_slice() else {
416 return Err(anyhow!(
417 "Project node must have exactly 1 input for auto schema change, but got {:?}",
418 stream_input_node.input.len()
419 )
420 .into());
421 };
422 stream_scan_node
423 }
424 _ => {
425 return Err(anyhow!(
426 "expect PbNodeBody::StreamScan or PbNodeBody::Project but got: {:?}",
427 stream_input_node.node_body
428 )
429 .into());
430 }
431 };
432 let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_ref().unwrap() else {
433 return Err(anyhow!(
434 "expect PbNodeBody::StreamScan but got: {:?}",
435 stream_scan_node.node_body
436 )
437 .into());
438 };
439 let stream_scan_type = PbStreamScanType::try_from(scan.stream_scan_type).unwrap();
440 if stream_scan_type != PbStreamScanType::ArrangementBackfill {
441 return Err(anyhow!(
442 "unsupported stream_scan_type for auto refresh schema: {:?}",
443 stream_scan_type
444 )
445 .into());
446 }
447 let [merge_node, _batch_plan_node] = stream_scan_node.input.as_slice() else {
448 panic!(
449 "the number of StreamScan inputs is not 2: {:?}",
450 stream_scan_node.input
451 );
452 };
453 let NodeBody::Merge(_) = merge_node.node_body.as_ref().unwrap() else {
454 return Err(anyhow!(
455 "expect PbNodeBody::Merge but got: {:?}",
456 merge_node.node_body
457 )
458 .into());
459 };
460 Ok(())
461}
462
/// Index mappings and output schema produced by `rewrite_stream_scan_and_merge`.
struct ScanRewriteResult {
    /// Maps a position in the scan's old output to the position of the same column in
    /// the new output. Positions of removed columns have no entry.
    old_output_index_to_new_output_index: HashMap<u32, u32>,
    /// Maps a column id to its position in the scan's new output.
    new_output_index_by_column_id: HashMap<ColumnId, u32>,
    /// The fields of the rewritten scan node.
    output_fields: Vec<risingwave_pb::plan_common::Field>,
}
469
470fn extend_sink_columns(
472 sink_columns: &mut Vec<PbColumnCatalog>,
473 new_columns: &[ColumnCatalog],
474 get_column_name: impl Fn(&String) -> String,
475) {
476 let next_column_id = sink_columns
477 .iter()
478 .map(|col| col.column_desc.as_ref().unwrap().column_id + 1)
479 .max()
480 .unwrap_or(1);
481 sink_columns.extend(new_columns.iter().enumerate().map(|(i, col)| {
482 let mut col = col.to_protobuf();
483 let column_desc = col.column_desc.as_mut().unwrap();
484 column_desc.column_id = next_column_id + (i as i32);
485 column_desc.name = get_column_name(&column_desc.name);
486 col
487 }));
488}
489
490fn build_new_sink_columns(
492 sink: &PbSink,
493 removed_column_names: &HashSet<String>,
494 newly_added_columns: &[ColumnCatalog],
495) -> Vec<PbColumnCatalog> {
496 let mut columns: Vec<PbColumnCatalog> = sink
497 .columns
498 .iter()
499 .filter(|col| {
500 let column_name = &col.column_desc.as_ref().unwrap().name;
501 !removed_column_names.contains(column_name)
502 })
503 .cloned()
504 .collect();
505 extend_sink_columns(&mut columns, newly_added_columns, |name| name.clone());
506 columns
507}
508
509fn rewrite_log_store_table(
511 log_store_table: &mut PbTable,
512 removed_log_store_column_names: &HashSet<String>,
513 newly_added_columns: &[ColumnCatalog],
514 upstream_table_name: &str,
515) {
516 log_store_table.columns.retain(|col| {
517 !removed_log_store_column_names.contains(&col.column_desc.as_ref().unwrap().name)
518 });
519 extend_sink_columns(&mut log_store_table.columns, newly_added_columns, |name| {
520 format!("{}_{}", upstream_table_name, name)
521 });
522 log_store_table.value_indices = (0..log_store_table.columns.len() as i32).collect();
523}
524
/// Rewrites a `StreamScan` node and its `Merge` input for a changed upstream schema.
///
/// Columns in `removed_column_ids` are dropped from the scan, and
/// `newly_added_columns` are appended to both the upstream columns and the output.
/// The scan's arrangement table, table descriptor columns, fields and identity, as
/// well as the merge node's fields and upstream fragment id, are updated accordingly.
///
/// Returns the index mappings between the old and the new scan output together with
/// the new output fields.
///
/// # Errors
///
/// Returns an error if `stream_scan_node` is not a `StreamScan`, if its first input is
/// not a `Merge`, or if the scan type is not `ArrangementBackfill`.
///
/// # Panics
///
/// Panics if the scan does not have exactly two inputs, if the scan has no
/// `table_desc`, or if a resulting upstream column id is not a column of
/// `upstream_table`.
fn rewrite_stream_scan_and_merge(
    stream_scan_node: &mut StreamNode,
    removed_column_ids: &HashSet<ColumnId>,
    newly_added_columns: &[ColumnCatalog],
    upstream_table: &PbTable,
    upstream_table_fragment_id: FragmentId,
) -> MetaResult<ScanRewriteResult> {
    let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_mut().unwrap() else {
        return Err(anyhow!(
            "expect PbNodeBody::StreamScan but got: {:?}",
            stream_scan_node.node_body
        )
        .into());
    };
    let [merge_node, _batch_plan_node] = stream_scan_node.input.as_mut_slice() else {
        panic!(
            "the number of StreamScan inputs is not 2: {:?}",
            stream_scan_node.input
        );
    };
    let NodeBody::Merge(merge) = merge_node.node_body.as_mut().unwrap() else {
        return Err(anyhow!(
            "expect PbNodeBody::Merge but got: {:?}",
            merge_node.node_body
        )
        .into());
    };

    let stream_scan_type = PbStreamScanType::try_from(scan.stream_scan_type).unwrap();
    if stream_scan_type != PbStreamScanType::ArrangementBackfill {
        return Err(anyhow!(
            "unsupported stream_scan_type for auto refresh schema: {:?}",
            stream_scan_type
        )
        .into());
    }

    // Column descriptors of the (new) upstream table, keyed by column id.
    let upstream_columns_by_id: HashMap<i32, PbColumnDesc> = upstream_table
        .columns
        .iter()
        .map(|col| {
            let desc = col.column_desc.as_ref().unwrap().clone();
            (desc.column_id, desc)
        })
        .collect();

    let old_upstream_column_ids = scan.upstream_column_ids.clone();
    let old_output_indices = scan.output_indices.clone();
    // Keep the surviving upstream columns in their original order and remember where
    // each old upstream position ends up.
    let mut old_upstream_index_to_new_upstream_index = HashMap::new();
    let mut new_upstream_column_ids = Vec::new();
    for (old_idx, &column_id) in old_upstream_column_ids.iter().enumerate() {
        if !removed_column_ids.contains(&ColumnId::new(column_id as _)) {
            let new_idx = new_upstream_column_ids.len() as u32;
            old_upstream_index_to_new_upstream_index.insert(old_idx as u32, new_idx);
            new_upstream_column_ids.push(column_id);
        }
    }
    // Re-point the output indices at the new upstream positions; outputs that referred
    // to a removed column are dropped.
    let mut new_output_indices = Vec::new();
    for old_output_index in &old_output_indices {
        if let Some(new_index) = old_upstream_index_to_new_upstream_index.get(old_output_index) {
            new_output_indices.push(*new_index);
        }
    }
    // Newly added columns are appended to both the upstream columns and the output.
    for col in newly_added_columns {
        let new_index = new_upstream_column_ids.len() as u32;
        new_upstream_column_ids.push(col.column_id().get_id());
        new_output_indices.push(new_index);
    }

    // Column id at each position of the new output.
    let new_output_column_ids: Vec<i32> = new_output_indices
        .iter()
        .map(|&idx| new_upstream_column_ids[idx as usize])
        .collect();
    let mut new_output_index_by_column_id = HashMap::new();
    for (pos, &column_id) in new_output_column_ids.iter().enumerate() {
        new_output_index_by_column_id.insert(ColumnId::new(column_id as _), pos as u32);
    }
    // Old output position -> new output position, matched through the column id.
    let mut old_output_index_to_new_output_index = HashMap::new();
    for (old_pos, old_output_index) in old_output_indices.iter().enumerate() {
        let column_id = old_upstream_column_ids[*old_output_index as usize];
        if let Some(new_pos) = new_output_index_by_column_id.get(&ColumnId::new(column_id as _)) {
            old_output_index_to_new_output_index.insert(old_pos as u32, *new_pos);
        }
    }

    scan.arrangement_table = Some(upstream_table.clone());
    scan.upstream_column_ids = new_upstream_column_ids;
    scan.output_indices = new_output_indices;
    let table_desc = scan.table_desc.as_mut().unwrap();
    table_desc.columns = scan
        .upstream_column_ids
        .iter()
        .map(|column_id| {
            upstream_columns_by_id
                .get(column_id)
                .unwrap_or_else(|| panic!("upstream column id not found: {}", column_id))
                .clone()
        })
        .collect();

    // Output fields of the scan are named `<upstream table name>.<column name>`.
    stream_scan_node.fields = new_output_column_ids
        .iter()
        .map(|column_id| {
            let col_desc = upstream_columns_by_id
                .get(column_id)
                .unwrap_or_else(|| panic!("upstream column id not found: {}", column_id));
            Field::new(
                format!("{}.{}", upstream_table.name, col_desc.name),
                col_desc.column_type.as_ref().unwrap().into(),
            )
            .to_prost()
        })
        .collect();
    // Rebuild the identity string from the new field names.
    stream_scan_node.identity = {
        let columns = stream_scan_node
            .fields
            .iter()
            .map(|col| &col.name)
            .join(", ");
        format!("StreamTableScan {{ table: t, columns: [{columns}] }}")
    };

    // The merge node outputs all upstream columns, under their plain column names.
    merge_node.fields = scan
        .upstream_column_ids
        .iter()
        .map(|&column_id| {
            let col_desc = upstream_columns_by_id
                .get(&column_id)
                .unwrap_or_else(|| panic!("upstream column id not found: {}", column_id));
            Field::new(
                col_desc.name.clone(),
                col_desc.column_type.as_ref().unwrap().into(),
            )
            .to_prost()
        })
        .collect();
    merge.upstream_fragment_id = upstream_table_fragment_id;

    Ok(ScanRewriteResult {
        old_output_index_to_new_output_index,
        new_output_index_by_column_id,
        output_fields: stream_scan_node.fields.clone(),
    })
}
672
/// Rewrites a `Project` node that sits on top of a rewritten scan.
///
/// Existing `InputRef` expressions are re-pointed at the new scan output using
/// `scan_rewrite`; expressions that refer to a removed scan output are dropped
/// together with their field. One `InputRef` expression and one field named
/// `<upstream_table_name>.<column name>` are appended for each newly added column.
///
/// # Errors
///
/// Returns an error if `project_node` is not a `Project`, if columns are being removed
/// while the select list contains an expression other than `InputRef`, or if a newly
/// added column is not part of the scan output.
fn rewrite_project_node(
    project_node: &mut StreamNode,
    scan_rewrite: &ScanRewriteResult,
    newly_added_columns: &[ColumnCatalog],
    removed_column_ids: &HashSet<ColumnId>,
    upstream_table_name: &str,
) -> MetaResult<()> {
    let PbNodeBody::Project(project_node_body) = project_node.node_body.as_mut().unwrap() else {
        return Err(anyhow!(
            "expect PbNodeBody::Project but got: {:?}",
            project_node.node_body
        )
        .into());
    };
    // Dropping columns is only supported when every projected expression is an `InputRef`.
    let has_non_input_ref = project_node_body
        .select_list
        .iter()
        .any(|expr| !matches!(expr.rex_node, Some(expr_node::RexNode::InputRef(_))));
    if has_non_input_ref && !removed_column_ids.is_empty() {
        return Err(anyhow!(
            "auto schema change with drop column only supports Project with InputRef"
        )
        .into());
    }

    let mut new_select_list = Vec::with_capacity(project_node_body.select_list.len());
    let mut new_project_fields = Vec::with_capacity(project_node.fields.len());
    for (index, expr) in project_node_body.select_list.iter().enumerate() {
        let mut new_expr = expr.clone();
        if let Some(expr_node::RexNode::InputRef(old_index)) = new_expr.rex_node {
            // No mapping means the referenced scan output was removed: skip the
            // expression and its field.
            let Some(&new_index) = scan_rewrite
                .old_output_index_to_new_output_index
                .get(&old_index)
            else {
                continue;
            };
            new_expr.rex_node = Some(expr_node::RexNode::InputRef(new_index));
        } else if !removed_column_ids.is_empty() {
            return Err(anyhow!(
                "auto schema change with drop column only supports Project with InputRef"
            )
            .into());
        }
        new_select_list.push(new_expr);
        new_project_fields.push(project_node.fields[index].clone());
    }

    // Project each newly added column straight from the scan output.
    for col in newly_added_columns {
        let Some(&new_index) = scan_rewrite
            .new_output_index_by_column_id
            .get(&col.column_id())
        else {
            return Err(anyhow!("new column id not found in scan output").into());
        };
        new_select_list.push(PbExprNode {
            function_type: expr_node::Type::Unspecified as i32,
            return_type: Some(col.data_type().to_protobuf()),
            rex_node: Some(expr_node::RexNode::InputRef(new_index)),
        });
        new_project_fields.push(
            Field::new(
                format!("{}.{}", upstream_table_name, col.column_desc.name),
                col.data_type().clone(),
            )
            .to_prost(),
        );
    }

    project_node_body.select_list = new_select_list;
    project_node.fields = new_project_fields;
    Ok(())
}
746
/// Builds a rewritten copy of a sink fragment for a changed upstream table schema.
///
/// The original fragment is copied under a newly generated fragment id, then the sink
/// node, its optional log store table, the optional `Project` node, and the
/// `StreamScan`/`Merge` nodes are rewritten to drop `removed_columns` and to add
/// `newly_added_columns`.
///
/// Returns the new fragment, the new sink columns, and the rewritten log store table
/// if the sink node has one.
///
/// # Errors
///
/// Returns an error if the fragment root is not a `Sink`, if the sink input is neither
/// a `StreamScan` nor a `Project`, if the `Project` does not have exactly one input,
/// or if rewriting the scan or the project fails.
///
/// # Panics
///
/// Panics if the sink does not have exactly one input, or if the sink node has no
/// `sink_desc`.
pub fn rewrite_refresh_schema_sink_fragment(
    original_sink_fragment: &Fragment,
    sink: &PbSink,
    newly_added_columns: &[ColumnCatalog],
    removed_columns: &[ColumnCatalog],
    upstream_table: &PbTable,
    upstream_table_fragment_id: FragmentId,
    id_generator_manager: &IdGeneratorManager,
) -> MetaResult<(Fragment, Vec<PbColumnCatalog>, Option<PbTable>)> {
    let removed_column_ids: HashSet<_> =
        removed_columns.iter().map(|col| col.column_id()).collect();
    // Log store columns are matched by `<upstream table name>_<column name>`.
    let removed_log_store_column_names: HashSet<_> = removed_columns
        .iter()
        .map(|col| format!("{}_{}", upstream_table.name, col.column_desc.name))
        .collect();
    // Sink columns are matched by the plain column name.
    let removed_sink_column_names: HashSet<_> = removed_columns
        .iter()
        .map(|col| col.column_desc.name.clone())
        .collect();
    let new_sink_columns =
        build_new_sink_columns(sink, &removed_sink_column_names, newly_added_columns);

    let mut new_sink_fragment = clone_fragment(original_sink_fragment, id_generator_manager);
    let sink_node = &mut new_sink_fragment.nodes;
    let PbNodeBody::Sink(sink_node_body) = sink_node.node_body.as_mut().unwrap() else {
        return Err(anyhow!("expect PbNodeBody::Sink but got: {:?}", sink_node.node_body).into());
    };
    let [stream_input_node] = sink_node.input.as_mut_slice() else {
        panic!("Sink has more than 1 input: {:?}", sink_node.input);
    };
    let stream_input_body = stream_input_node.node_body.as_ref().unwrap();
    let stream_input_is_project = matches!(stream_input_body, PbNodeBody::Project(_));
    let stream_input_is_scan = matches!(stream_input_body, PbNodeBody::StreamScan(_));
    if !stream_input_is_project && !stream_input_is_scan {
        return Err(anyhow!(
            "expect PbNodeBody::StreamScan or PbNodeBody::Project but got: {:?}",
            stream_input_body
        )
        .into());
    }

    // Rebuild the identity string of the sink node from the new columns.
    sink_node.identity = {
        let sink_type = SinkType::from_proto(sink.sink_type());
        let sink_type_str = sink_type.type_str();
        let column_names = new_sink_columns
            .iter()
            .map(|col| {
                ColumnCatalog::from(col.clone())
                    .name_with_hidden()
                    .to_string()
            })
            .join(", ");
        // The downstream pk names are looked up in the original `sink.columns`.
        let downstream_pk = if !sink_type.is_append_only() {
            let downstream_pk = sink
                .downstream_pk
                .iter()
                .map(|i| &sink.columns[*i as usize].column_desc.as_ref().unwrap().name)
                .collect_vec();
            format!(", downstream_pk: {downstream_pk:?}")
        } else {
            "".to_owned()
        };
        format!("StreamSink {{ type: {sink_type_str}, columns: [{column_names}]{downstream_pk} }}")
    };
    let new_log_store_table = if let Some(log_store_table) = &mut sink_node_body.table {
        rewrite_log_store_table(
            log_store_table,
            &removed_log_store_column_names,
            newly_added_columns,
            &upstream_table.name,
        );
        Some(log_store_table.clone())
    } else {
        None
    };
    sink_node_body.sink_desc.as_mut().unwrap().column_catalogs = new_sink_columns.clone();

    // Locate the scan: directly below the sink, or below the single Project.
    let stream_scan_node = if stream_input_is_project {
        let [stream_scan_node] = stream_input_node.input.as_mut_slice() else {
            return Err(anyhow!(
                "Project node must have exactly 1 input for auto schema change, but got {:?}",
                stream_input_node.input.len()
            )
            .into());
        };
        stream_scan_node
    } else {
        stream_input_node
    };
    let scan_rewrite = rewrite_stream_scan_and_merge(
        stream_scan_node,
        &removed_column_ids,
        newly_added_columns,
        upstream_table,
        upstream_table_fragment_id,
    )?;

    // The sink outputs whatever its direct input outputs after the rewrite.
    if stream_input_is_project {
        let [project_node] = sink_node.input.as_mut_slice() else {
            panic!("Sink has more than 1 input: {:?}", sink_node.input);
        };
        rewrite_project_node(
            project_node,
            &scan_rewrite,
            newly_added_columns,
            &removed_column_ids,
            &upstream_table.name,
        )?;
        sink_node.fields = project_node.fields.clone();
    } else {
        sink_node.fields = scan_rewrite.output_fields;
    }
    Ok((new_sink_fragment, new_sink_columns, new_log_store_table))
}
863
/// A backfill ordering between fragments, stored as a map from a fragment id to a list
/// of fragment ids.
///
/// The `EXTENDED` const parameter distinguishes the two aliases
/// `UserDefinedFragmentBackfillOrder` (`false`) and `ExtendedFragmentBackfillOrder`
/// (`true`) at the type level.
/// NOTE(review): presumably each key must finish backfilling before the fragments in
/// its list start; confirm against the consumers of this type.
#[derive(Clone, Debug, Default)]
pub struct FragmentBackfillOrder<const EXTENDED: bool> {
    /// The underlying ordering map, exposed read-only through `Deref`.
    inner: HashMap<FragmentId, Vec<FragmentId>>,
}
872
/// Read-only access to the underlying ordering map.
impl<const EXTENDED: bool> Deref for FragmentBackfillOrder<EXTENDED> {
    type Target = HashMap<FragmentId, Vec<FragmentId>>;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
880
881impl UserDefinedFragmentBackfillOrder {
882 pub fn new(inner: HashMap<FragmentId, Vec<FragmentId>>) -> Self {
883 Self { inner }
884 }
885
886 pub fn merge(orders: impl Iterator<Item = Self>) -> Self {
887 Self {
888 inner: orders.flat_map(|order| order.inner).collect(),
889 }
890 }
891
892 pub fn to_meta_model(&self) -> BackfillOrders {
893 self.inner.clone().into()
894 }
895}
896
/// A [`FragmentBackfillOrder`] with `EXTENDED = false`.
pub type UserDefinedFragmentBackfillOrder = FragmentBackfillOrder<false>;
/// A [`FragmentBackfillOrder`] with `EXTENDED = true`.
pub type ExtendedFragmentBackfillOrder = FragmentBackfillOrder<true>;
899
/// The fragment graph of a streaming job, built from its protobuf form by
/// [`StreamFragmentGraph::new`].
#[derive(Default, Debug)]
pub struct StreamFragmentGraph {
    /// All fragments of the graph, keyed by their global fragment id.
    pub(super) fragments: HashMap<GlobalFragmentId, BuildingFragment>,

    /// Edges indexed as `downstreams[upstream][downstream]`.
    pub(super) downstreams:
        HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,

    /// The same edges indexed as `upstreams[downstream][upstream]`.
    pub(super) upstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,

    /// The `dependent_table_ids` taken from the protobuf graph.
    dependent_table_ids: HashSet<TableId>,

    /// The parallelism specified in the protobuf graph, if any.
    specified_parallelism: Option<NonZeroUsize>,
    /// The backfill parallelism specified in the protobuf graph, if any.
    specified_backfill_parallelism: Option<NonZeroUsize>,

    /// The `max_parallelism` taken from the protobuf graph.
    max_parallelism: usize,

    /// The backfill order taken from the protobuf graph, or an order built with a
    /// default `order` field when the graph specifies none.
    backfill_order: BackfillOrder,
}
942
943impl StreamFragmentGraph {
944 pub fn new(
947 env: &MetaSrvEnv,
948 proto: StreamFragmentGraphProto,
949 job: &StreamingJob,
950 ) -> MetaResult<Self> {
951 let fragment_id_gen =
952 GlobalFragmentIdGen::new(env.id_gen_manager(), proto.fragments.len() as u64);
953 let table_id_gen = GlobalTableIdGen::new(env.id_gen_manager(), proto.table_ids_cnt as u64);
957
958 let fragments: HashMap<_, _> = proto
960 .fragments
961 .into_iter()
962 .map(|(id, fragment)| {
963 let id = fragment_id_gen.to_global_id(id.as_raw_id());
964 let fragment = BuildingFragment::new(id, fragment, job, table_id_gen);
965 (id, fragment)
966 })
967 .collect();
968
969 assert_eq!(
970 fragments
971 .values()
972 .map(|f| f.extract_internal_tables().len() as u32)
973 .sum::<u32>(),
974 proto.table_ids_cnt
975 );
976
977 let mut downstreams = HashMap::new();
979 let mut upstreams = HashMap::new();
980
981 for edge in proto.edges {
982 let upstream_id = fragment_id_gen.to_global_id(edge.upstream_id.as_raw_id());
983 let downstream_id = fragment_id_gen.to_global_id(edge.downstream_id.as_raw_id());
984 let edge = StreamFragmentEdge::from_protobuf(&edge);
985
986 upstreams
987 .entry(downstream_id)
988 .or_insert_with(HashMap::new)
989 .try_insert(upstream_id, edge.clone())
990 .unwrap();
991 downstreams
992 .entry(upstream_id)
993 .or_insert_with(HashMap::new)
994 .try_insert(downstream_id, edge)
995 .unwrap();
996 }
997
998 let dependent_table_ids = proto.dependent_table_ids.iter().copied().collect();
1001
1002 let specified_parallelism = if let Some(Parallelism { parallelism }) = proto.parallelism {
1003 Some(NonZeroUsize::new(parallelism as usize).context("parallelism should not be 0")?)
1004 } else {
1005 None
1006 };
1007 let specified_backfill_parallelism =
1008 if let Some(Parallelism { parallelism }) = proto.backfill_parallelism {
1009 Some(
1010 NonZeroUsize::new(parallelism as usize)
1011 .context("backfill parallelism should not be 0")?,
1012 )
1013 } else {
1014 None
1015 };
1016
1017 let max_parallelism = proto.max_parallelism as usize;
1018 let backfill_order = proto.backfill_order.unwrap_or(BackfillOrder {
1019 order: Default::default(),
1020 });
1021
1022 Ok(Self {
1023 fragments,
1024 downstreams,
1025 upstreams,
1026 dependent_table_ids,
1027 specified_parallelism,
1028 specified_backfill_parallelism,
1029 max_parallelism,
1030 backfill_order,
1031 })
1032 }
1033
1034 pub fn incomplete_internal_tables(&self) -> BTreeMap<TableId, Table> {
1040 let mut tables = BTreeMap::new();
1041 for fragment in self.fragments.values() {
1042 for table in fragment.extract_internal_tables() {
1043 let table_id = table.id;
1044 tables
1045 .try_insert(table_id, table)
1046 .unwrap_or_else(|_| panic!("duplicated table id `{}`", table_id));
1047 }
1048 }
1049 tables
1050 }
1051
1052 pub fn refill_internal_table_ids(&mut self, table_id_map: HashMap<TableId, TableId>) {
1055 for fragment in self.fragments.values_mut() {
1056 stream_graph_visitor::visit_internal_tables(
1057 &mut fragment.inner,
1058 |table, _table_type_name| {
1059 let target = table_id_map.get(&table.id).cloned().unwrap();
1060 table.id = target;
1061 },
1062 );
1063 }
1064 }
1065
1066 pub fn fit_internal_tables_trivial(
1069 &mut self,
1070 mut old_internal_tables: Vec<Table>,
1071 ) -> MetaResult<()> {
1072 let mut new_internal_table_ids = Vec::new();
1073 for fragment in self.fragments.values() {
1074 for table in &fragment.extract_internal_tables() {
1075 new_internal_table_ids.push(table.id);
1076 }
1077 }
1078
1079 if new_internal_table_ids.len() != old_internal_tables.len() {
1080 bail!(
1081 "Different number of internal tables. New: {}, Old: {}",
1082 new_internal_table_ids.len(),
1083 old_internal_tables.len()
1084 );
1085 }
1086 old_internal_tables.sort_by(|a, b| a.id.cmp(&b.id));
1087 new_internal_table_ids.sort();
1088
1089 let internal_table_id_map = new_internal_table_ids
1090 .into_iter()
1091 .zip_eq_fast(old_internal_tables.into_iter())
1092 .collect::<HashMap<_, _>>();
1093
1094 for fragment in self.fragments.values_mut() {
1097 stream_graph_visitor::visit_internal_tables(
1098 &mut fragment.inner,
1099 |table, _table_type_name| {
1100 let target = internal_table_id_map.get(&table.id).cloned().unwrap();
1102 *table = target;
1103 },
1104 );
1105 }
1106
1107 Ok(())
1108 }
1109
1110 pub fn fit_internal_table_ids_with_mapping(&mut self, mut matches: HashMap<TableId, Table>) {
1112 for fragment in self.fragments.values_mut() {
1113 stream_graph_visitor::visit_internal_tables(
1114 &mut fragment.inner,
1115 |table, _table_type_name| {
1116 let target = matches.remove(&table.id).unwrap_or_else(|| {
1117 panic!("no matching table for table {}({})", table.id, table.name)
1118 });
1119 table.id = target.id;
1120 table.maybe_vnode_count = target.maybe_vnode_count;
1121 },
1122 );
1123 }
1124 }
1125
1126 pub fn fit_snapshot_backfill_epochs(
1127 &mut self,
1128 mut snapshot_backfill_epochs: HashMap<StreamNodeLocalOperatorId, u64>,
1129 ) {
1130 for fragment in self.fragments.values_mut() {
1131 visit_stream_node_cont_mut(fragment.node.as_mut().unwrap(), |node| {
1132 if let PbNodeBody::StreamScan(scan) = node.node_body.as_mut().unwrap()
1133 && let StreamScanType::SnapshotBackfill
1134 | StreamScanType::CrossDbSnapshotBackfill = scan.stream_scan_type()
1135 {
1136 let Some(epoch) = snapshot_backfill_epochs.remove(&node.operator_id) else {
1137 panic!("no snapshot epoch found for node {:?}", node)
1138 };
1139 scan.snapshot_backfill_epoch = Some(epoch);
1140 }
1141 true
1142 })
1143 }
1144 }
1145
1146 pub fn table_fragment_id(&self) -> FragmentId {
1148 self.fragments
1149 .values()
1150 .filter(|b| b.job_id.is_some())
1151 .map(|b| b.fragment_id)
1152 .exactly_one()
1153 .expect("require exactly 1 materialize/sink/cdc source node when creating the streaming job")
1154 }
1155
1156 pub fn dml_fragment_id(&self) -> Option<FragmentId> {
1158 self.fragments
1159 .values()
1160 .filter(|b| {
1161 FragmentTypeMask::from(b.fragment_type_mask).contains(FragmentTypeFlag::Dml)
1162 })
1163 .map(|b| b.fragment_id)
1164 .at_most_one()
1165 .expect("require at most 1 dml node when creating the streaming job")
1166 }
1167
/// Returns a reference to the `dependent_table_ids` set stored on this graph.
pub fn dependent_table_ids(&self) -> &HashSet<TableId> {
    &self.dependent_table_ids
}
1172
/// Returns the `specified_parallelism` stored on this graph, or `None` if it is unset.
pub fn specified_parallelism(&self) -> Option<NonZeroUsize> {
    self.specified_parallelism
}
1177
/// Returns the `specified_backfill_parallelism` stored on this graph, or `None` if it is
/// unset.
pub fn specified_backfill_parallelism(&self) -> Option<NonZeroUsize> {
    self.specified_backfill_parallelism
}
1182
/// Returns the `max_parallelism` value stored on this graph.
pub fn max_parallelism(&self) -> usize {
    self.max_parallelism
}
1187
1188 fn get_downstreams(
1190 &self,
1191 fragment_id: GlobalFragmentId,
1192 ) -> &HashMap<GlobalFragmentId, StreamFragmentEdge> {
1193 self.downstreams.get(&fragment_id).unwrap_or(&EMPTY_HASHMAP)
1194 }
1195
1196 fn get_upstreams(
1198 &self,
1199 fragment_id: GlobalFragmentId,
1200 ) -> &HashMap<GlobalFragmentId, StreamFragmentEdge> {
1201 self.upstreams.get(&fragment_id).unwrap_or(&EMPTY_HASHMAP)
1202 }
1203
1204 pub fn collect_snapshot_backfill_info(
1205 &self,
1206 ) -> MetaResult<(Option<SnapshotBackfillInfo>, SnapshotBackfillInfo)> {
1207 Self::collect_snapshot_backfill_info_impl(self.fragments.values().map(|fragment| {
1208 (
1209 fragment.node.as_ref().unwrap(),
1210 fragment.fragment_type_mask.into(),
1211 )
1212 }))
1213 }
1214
1215 pub fn collect_snapshot_backfill_info_impl(
1217 fragments: impl IntoIterator<Item = (&PbStreamNode, FragmentTypeMask)>,
1218 ) -> MetaResult<(Option<SnapshotBackfillInfo>, SnapshotBackfillInfo)> {
1219 let mut prev_stream_scan: Option<(Option<SnapshotBackfillInfo>, StreamScanNode)> = None;
1220 let mut cross_db_info = SnapshotBackfillInfo {
1221 upstream_mv_table_id_to_backfill_epoch: Default::default(),
1222 };
1223 let mut result = Ok(());
1224 for (node, fragment_type_mask) in fragments {
1225 visit_stream_node_cont(node, |node| {
1226 if let Some(NodeBody::StreamScan(stream_scan)) = node.node_body.as_ref() {
1227 let stream_scan_type = StreamScanType::try_from(stream_scan.stream_scan_type)
1228 .expect("invalid stream_scan_type");
1229 let is_snapshot_backfill = match stream_scan_type {
1230 StreamScanType::SnapshotBackfill => {
1231 assert!(
1232 fragment_type_mask
1233 .contains(FragmentTypeFlag::SnapshotBackfillStreamScan)
1234 );
1235 true
1236 }
1237 StreamScanType::CrossDbSnapshotBackfill => {
1238 assert!(
1239 fragment_type_mask
1240 .contains(FragmentTypeFlag::CrossDbSnapshotBackfillStreamScan)
1241 );
1242 cross_db_info
1243 .upstream_mv_table_id_to_backfill_epoch
1244 .insert(stream_scan.table_id, stream_scan.snapshot_backfill_epoch);
1245
1246 return true;
1247 }
1248 _ => false,
1249 };
1250
1251 match &mut prev_stream_scan {
1252 Some((prev_snapshot_backfill_info, prev_stream_scan)) => {
1253 match (prev_snapshot_backfill_info, is_snapshot_backfill) {
1254 (Some(prev_snapshot_backfill_info), true) => {
1255 prev_snapshot_backfill_info
1256 .upstream_mv_table_id_to_backfill_epoch
1257 .insert(
1258 stream_scan.table_id,
1259 stream_scan.snapshot_backfill_epoch,
1260 );
1261 true
1262 }
1263 (None, false) => true,
1264 (_, _) => {
1265 result = Err(anyhow!("must be either all snapshot_backfill or no snapshot_backfill. Curr: {stream_scan:?} Prev: {prev_stream_scan:?}").into());
1266 false
1267 }
1268 }
1269 }
1270 None => {
1271 prev_stream_scan = Some((
1272 if is_snapshot_backfill {
1273 Some(SnapshotBackfillInfo {
1274 upstream_mv_table_id_to_backfill_epoch: HashMap::from_iter(
1275 [(
1276 stream_scan.table_id,
1277 stream_scan.snapshot_backfill_epoch,
1278 )],
1279 ),
1280 })
1281 } else {
1282 None
1283 },
1284 *stream_scan.clone(),
1285 ));
1286 true
1287 }
1288 }
1289 } else {
1290 true
1291 }
1292 })
1293 }
1294 result.map(|_| {
1295 (
1296 prev_stream_scan
1297 .map(|(snapshot_backfill_info, _)| snapshot_backfill_info)
1298 .unwrap_or(None),
1299 cross_db_info,
1300 )
1301 })
1302 }
1303
1304 pub fn collect_backfill_mapping(
1306 fragments: impl Iterator<Item = (FragmentId, FragmentTypeMask, &PbStreamNode)>,
1307 ) -> HashMap<RelationId, Vec<FragmentId>> {
1308 let mut mapping = HashMap::new();
1309 for (fragment_id, fragment_type_mask, node) in fragments {
1310 let has_some_scan = fragment_type_mask
1311 .contains_any([FragmentTypeFlag::StreamScan, FragmentTypeFlag::SourceScan]);
1312 if has_some_scan {
1313 visit_stream_node_cont(node, |node| {
1314 match node.node_body.as_ref() {
1315 Some(NodeBody::StreamScan(stream_scan)) => {
1316 let table_id = stream_scan.table_id;
1317 let fragments: &mut Vec<_> =
1318 mapping.entry(table_id.as_relation_id()).or_default();
1319 fragments.push(fragment_id);
1320 false
1322 }
1323 Some(NodeBody::SourceBackfill(source_backfill)) => {
1324 let source_id = source_backfill.upstream_source_id;
1325 let fragments: &mut Vec<_> =
1326 mapping.entry(source_id.as_relation_id()).or_default();
1327 fragments.push(fragment_id);
1328 false
1330 }
1331 _ => true,
1332 }
1333 })
1334 }
1335 }
1336 mapping
1337 }
1338
1339 pub fn create_fragment_backfill_ordering(&self) -> UserDefinedFragmentBackfillOrder {
1343 let mapping =
1344 Self::collect_backfill_mapping(self.fragments.iter().map(|(fragment_id, fragment)| {
1345 (
1346 fragment_id.as_global_id(),
1347 fragment.fragment_type_mask.into(),
1348 fragment.node.as_ref().expect("should exist node"),
1349 )
1350 }));
1351 let mut fragment_ordering: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
1352
1353 for (rel_id, downstream_rel_ids) in &self.backfill_order.order {
1355 let fragment_ids = mapping.get(rel_id).unwrap();
1356 for fragment_id in fragment_ids {
1357 let downstream_fragment_ids = downstream_rel_ids
1358 .data
1359 .iter()
1360 .flat_map(|&downstream_rel_id| mapping.get(&downstream_rel_id).unwrap().iter())
1361 .copied()
1362 .collect();
1363 fragment_ordering.insert(*fragment_id, downstream_fragment_ids);
1364 }
1365 }
1366
1367 UserDefinedFragmentBackfillOrder {
1368 inner: fragment_ordering,
1369 }
1370 }
1371
1372 pub fn extend_fragment_backfill_ordering_with_locality_backfill<
1373 'a,
1374 FI: Iterator<Item = (FragmentId, FragmentTypeMask, &'a PbStreamNode)> + 'a,
1375 >(
1376 fragment_ordering: UserDefinedFragmentBackfillOrder,
1377 fragment_downstreams: &FragmentDownstreamRelation,
1378 get_fragments: impl Fn() -> FI,
1379 ) -> ExtendedFragmentBackfillOrder {
1380 let mut fragment_ordering = fragment_ordering.inner;
1381 let mapping = Self::collect_backfill_mapping(get_fragments());
1382 if fragment_ordering.is_empty() {
1385 for value in mapping.values() {
1386 for &fragment_id in value {
1387 fragment_ordering.entry(fragment_id).or_default();
1388 }
1389 }
1390 }
1391
1392 let locality_provider_dependencies = Self::find_locality_provider_dependencies(
1394 get_fragments().map(|(fragment_id, _, node)| (fragment_id, node)),
1395 fragment_downstreams,
1396 );
1397
1398 let backfill_fragments: HashSet<FragmentId> = mapping.values().flatten().copied().collect();
1399
1400 let all_locality_provider_fragments: HashSet<FragmentId> =
1403 locality_provider_dependencies.keys().copied().collect();
1404 let downstream_locality_provider_fragments: HashSet<FragmentId> =
1405 locality_provider_dependencies
1406 .values()
1407 .flatten()
1408 .copied()
1409 .collect();
1410 let locality_provider_root_fragments: Vec<FragmentId> = all_locality_provider_fragments
1411 .difference(&downstream_locality_provider_fragments)
1412 .copied()
1413 .collect();
1414
1415 for &backfill_fragment_id in &backfill_fragments {
1418 fragment_ordering
1419 .entry(backfill_fragment_id)
1420 .or_default()
1421 .extend(locality_provider_root_fragments.iter().copied());
1422 }
1423
1424 for (fragment_id, downstream_fragments) in locality_provider_dependencies {
1426 fragment_ordering
1427 .entry(fragment_id)
1428 .or_default()
1429 .extend(downstream_fragments);
1430 }
1431
1432 for downstream in fragment_ordering.values_mut() {
1436 let mut seen = HashSet::new();
1437 downstream.retain(|id| seen.insert(*id));
1438 }
1439
1440 ExtendedFragmentBackfillOrder {
1441 inner: fragment_ordering,
1442 }
1443 }
1444
1445 pub fn find_locality_provider_fragment_state_table_mapping(
1446 &self,
1447 ) -> HashMap<FragmentId, Vec<TableId>> {
1448 let mut mapping: HashMap<FragmentId, Vec<TableId>> = HashMap::new();
1449
1450 for (fragment_id, fragment) in &self.fragments {
1451 let fragment_id = fragment_id.as_global_id();
1452
1453 if let Some(node) = fragment.node.as_ref() {
1455 let mut state_table_ids = Vec::new();
1456
1457 visit_stream_node_cont(node, |stream_node| {
1458 if let Some(NodeBody::LocalityProvider(locality_provider)) =
1459 stream_node.node_body.as_ref()
1460 {
1461 let state_table_id = locality_provider
1463 .state_table
1464 .as_ref()
1465 .expect("must have state table")
1466 .id;
1467 state_table_ids.push(state_table_id);
1468 false } else {
1470 true }
1472 });
1473
1474 if !state_table_ids.is_empty() {
1475 mapping.insert(fragment_id, state_table_ids);
1476 }
1477 }
1478 }
1479
1480 mapping
1481 }
1482
1483 pub fn find_locality_provider_dependencies<'a>(
1491 fragments_nodes: impl Iterator<Item = (FragmentId, &'a PbStreamNode)>,
1492 fragment_downstreams: &FragmentDownstreamRelation,
1493 ) -> HashMap<FragmentId, Vec<FragmentId>> {
1494 let mut locality_provider_fragments = HashSet::new();
1495 let mut dependencies: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
1496
1497 for (fragment_id, node) in fragments_nodes {
1499 let has_locality_provider = Self::fragment_has_locality_provider(node);
1500
1501 if has_locality_provider {
1502 locality_provider_fragments.insert(fragment_id);
1503 dependencies.entry(fragment_id).or_default();
1504 }
1505 }
1506
1507 for &provider_fragment_id in &locality_provider_fragments {
1511 let mut visited = HashSet::new();
1513 let mut downstream_locality_providers = Vec::new();
1514
1515 Self::collect_downstream_locality_providers(
1516 provider_fragment_id,
1517 &locality_provider_fragments,
1518 fragment_downstreams,
1519 &mut visited,
1520 &mut downstream_locality_providers,
1521 );
1522
1523 dependencies
1525 .entry(provider_fragment_id)
1526 .or_default()
1527 .extend(downstream_locality_providers);
1528 }
1529
1530 dependencies
1531 }
1532
1533 fn fragment_has_locality_provider(node: &PbStreamNode) -> bool {
1534 let mut has_locality_provider = false;
1535
1536 {
1537 visit_stream_node_cont(node, |stream_node| {
1538 if let Some(NodeBody::LocalityProvider(_)) = stream_node.node_body.as_ref() {
1539 has_locality_provider = true;
1540 false } else {
1542 true }
1544 });
1545 }
1546
1547 has_locality_provider
1548 }
1549
1550 fn collect_downstream_locality_providers(
1552 current_fragment_id: FragmentId,
1553 locality_provider_fragments: &HashSet<FragmentId>,
1554 fragment_downstreams: &FragmentDownstreamRelation,
1555 visited: &mut HashSet<FragmentId>,
1556 downstream_providers: &mut Vec<FragmentId>,
1557 ) {
1558 if visited.contains(¤t_fragment_id) {
1559 return;
1560 }
1561 visited.insert(current_fragment_id);
1562
1563 for downstream_fragment_id in fragment_downstreams
1565 .get(¤t_fragment_id)
1566 .into_iter()
1567 .flat_map(|downstreams| {
1568 downstreams
1569 .iter()
1570 .map(|downstream| downstream.downstream_fragment_id)
1571 })
1572 {
1573 if locality_provider_fragments.contains(&downstream_fragment_id) {
1575 downstream_providers.push(downstream_fragment_id);
1576 }
1577
1578 Self::collect_downstream_locality_providers(
1580 downstream_fragment_id,
1581 locality_provider_fragments,
1582 fragment_downstreams,
1583 visited,
1584 downstream_providers,
1585 );
1586 }
1587 }
1588}
1589
1590pub fn fill_snapshot_backfill_epoch(
1593 node: &mut StreamNode,
1594 snapshot_backfill_info: Option<&SnapshotBackfillInfo>,
1595 cross_db_snapshot_backfill_info: &SnapshotBackfillInfo,
1596) -> MetaResult<bool> {
1597 let mut result = Ok(());
1598 let mut applied = false;
1599 visit_stream_node_cont_mut(node, |node| {
1600 if let Some(NodeBody::StreamScan(stream_scan)) = node.node_body.as_mut()
1601 && (stream_scan.stream_scan_type == StreamScanType::SnapshotBackfill as i32
1602 || stream_scan.stream_scan_type == StreamScanType::CrossDbSnapshotBackfill as i32)
1603 {
1604 result = try {
1605 let table_id = stream_scan.table_id;
1606 let snapshot_epoch = cross_db_snapshot_backfill_info
1607 .upstream_mv_table_id_to_backfill_epoch
1608 .get(&table_id)
1609 .or_else(|| {
1610 snapshot_backfill_info.and_then(|snapshot_backfill_info| {
1611 snapshot_backfill_info
1612 .upstream_mv_table_id_to_backfill_epoch
1613 .get(&table_id)
1614 })
1615 })
1616 .ok_or_else(|| anyhow!("upstream table id not covered: {}", table_id))?
1617 .ok_or_else(|| anyhow!("upstream table id not set: {}", table_id))?;
1618 if let Some(prev_snapshot_epoch) =
1619 stream_scan.snapshot_backfill_epoch.replace(snapshot_epoch)
1620 {
1621 Err(anyhow!(
1622 "snapshot backfill epoch set again: {} {} {}",
1623 table_id,
1624 prev_snapshot_epoch,
1625 snapshot_epoch
1626 ))?;
1627 }
1628 applied = true;
1629 };
1630 result.is_ok()
1631 } else {
1632 true
1633 }
1634 });
1635 result.map(|_| applied)
1636}
1637
/// Lazily-initialized empty edge map. `StreamFragmentGraph::get_downstreams` and
/// `StreamFragmentGraph::get_upstreams` return a reference to it for fragments that have
/// no entry in the corresponding edge map.
static EMPTY_HASHMAP: LazyLock<HashMap<GlobalFragmentId, StreamFragmentEdge>> =
    LazyLock::new(HashMap::new);
1640
/// The result of looking up a fragment in a [`CompleteStreamFragmentGraph`]: either a
/// fragment of the graph being built, or a marker for an already existing fragment.
#[derive(Debug, Clone, EnumAsInner)]
pub(super) enum EitherFragment {
    /// A fragment that belongs to the building graph.
    Building(BuildingFragment),

    /// A fragment found in `existing_fragments`; the variant carries no payload.
    Existing,
}
1651
/// A [`StreamFragmentGraph`] under construction, plus the already existing fragments it
/// connects to and the extra edges between the two.
#[derive(Debug)]
pub struct CompleteStreamFragmentGraph {
    /// The fragment graph that is being built.
    building_graph: StreamFragmentGraph,

    /// Existing fragments, keyed by their fragment id. Populated from the upstream root
    /// fragments and/or the downstream fragments supplied when the graph is constructed.
    existing_fragments: HashMap<GlobalFragmentId, Fragment>,

    /// Extra edges that are not part of `building_graph`, indexed `upstream -> downstream`.
    extra_downstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,

    /// The same extra edges as `extra_downstreams`, indexed `downstream -> upstream`.
    extra_upstreams: HashMap<GlobalFragmentId, HashMap<GlobalFragmentId, StreamFragmentEdge>>,
}
1674
/// Upstream-side input for building a [`CompleteStreamFragmentGraph`].
pub struct FragmentGraphUpstreamContext {
    /// Root fragment of each upstream job, keyed by the upstream job id. Looked up with the
    /// job ids found in each building fragment's `upstream_job_columns`.
    pub upstream_root_fragments: HashMap<JobId, Fragment>,
}
1680
/// Downstream-side input for building a [`CompleteStreamFragmentGraph`].
pub struct FragmentGraphDownstreamContext {
    /// Recorded as `original_upstream_fragment_id` in the id of every downstream edge.
    pub original_root_fragment_id: FragmentId,
    /// Existing downstream fragments, each paired with the dispatcher type to use for the
    /// edge that leads to it.
    pub downstream_fragments: Vec<(DispatcherType, Fragment)>,
}
1685
1686impl CompleteStreamFragmentGraph {
/// Test-only constructor: wraps `graph` without any existing fragments or extra edges.
#[cfg(test)]
pub fn for_test(graph: StreamFragmentGraph) -> Self {
    Self {
        building_graph: graph,
        existing_fragments: HashMap::new(),
        extra_downstreams: HashMap::new(),
        extra_upstreams: HashMap::new(),
    }
}
1698
1699 pub fn with_upstreams(
1703 graph: StreamFragmentGraph,
1704 upstream_context: FragmentGraphUpstreamContext,
1705 job_type: StreamingJobType,
1706 ) -> MetaResult<Self> {
1707 Self::build_helper(graph, Some(upstream_context), None, job_type)
1708 }
1709
1710 pub fn with_downstreams(
1713 graph: StreamFragmentGraph,
1714 downstream_context: FragmentGraphDownstreamContext,
1715 job_type: StreamingJobType,
1716 ) -> MetaResult<Self> {
1717 Self::build_helper(graph, None, Some(downstream_context), job_type)
1718 }
1719
1720 pub fn with_upstreams_and_downstreams(
1722 graph: StreamFragmentGraph,
1723 upstream_context: FragmentGraphUpstreamContext,
1724 downstream_context: FragmentGraphDownstreamContext,
1725 job_type: StreamingJobType,
1726 ) -> MetaResult<Self> {
1727 Self::build_helper(
1728 graph,
1729 Some(upstream_context),
1730 Some(downstream_context),
1731 job_type,
1732 )
1733 }
1734
/// Shared constructor behind `with_upstreams`, `with_downstreams` and
/// `with_upstreams_and_downstreams`.
///
/// For each provided context it creates the extra edges between the building graph and
/// the existing fragments, records every edge in both `extra_downstreams` and
/// `extra_upstreams`, and registers the existing fragments.
///
/// # Errors
///
/// Fails if an upstream job has no root fragment in the upstream context, if a required
/// column cannot be found in the upstream/new column list, if an upstream root fragment is
/// neither an `Mview` nor a `Source` fragment, if no scan node can be located in a
/// downstream fragment, or if `job_type` is not supported for the given context.
///
/// # Panics
///
/// Panics on violated structural expectations (e.g. a root node that is not of the
/// expected kind) and if the same extra edge would be inserted twice.
fn build_helper(
    mut graph: StreamFragmentGraph,
    upstream_ctx: Option<FragmentGraphUpstreamContext>,
    downstream_ctx: Option<FragmentGraphDownstreamContext>,
    job_type: StreamingJobType,
) -> MetaResult<Self> {
    let mut extra_downstreams = HashMap::new();
    let mut extra_upstreams = HashMap::new();
    let mut existing_fragments = HashMap::new();

    // Part 1: edges from existing upstream root fragments into the building graph.
    if let Some(FragmentGraphUpstreamContext {
        upstream_root_fragments,
    }) = upstream_ctx
    {
        for (&id, fragment) in &mut graph.fragments {
            let uses_shuffled_backfill = fragment.has_shuffled_backfill();

            // One edge per upstream job that this building fragment reads from.
            for (&upstream_job_id, required_columns) in &fragment.upstream_job_columns {
                let upstream_fragment = upstream_root_fragments
                    .get(&upstream_job_id)
                    .context("upstream fragment not found")?;
                let upstream_root_fragment_id =
                    GlobalFragmentId::new(upstream_fragment.fragment_id);

                let edge = match job_type {
                    StreamingJobType::Table(TableJobType::SharedCdcSource) => {
                        // The building fragment is expected to carry the `CdcFilter` flag.
                        assert_ne!(
                            (fragment.fragment_type_mask & FragmentTypeFlag::CdcFilter as u32),
                            0
                        );

                        tracing::debug!(
                            ?upstream_root_fragment_id,
                            ?required_columns,
                            identity = ?fragment.inner.get_node().unwrap().get_identity(),
                            current_frag_id=?id,
                            "CdcFilter with upstream source fragment"
                        );

                        StreamFragmentEdge {
                            id: EdgeId::UpstreamExternal {
                                upstream_job_id,
                                downstream_fragment_id: id,
                            },
                            // `NoShuffle` with an identical mapping over
                            // `CDC_SOURCE_COLUMN_NUM` columns.
                            dispatch_strategy: DispatchStrategy {
                                r#type: DispatcherType::NoShuffle as _,
                                dist_key_indices: vec![],
                                output_mapping: DispatchOutputMapping::identical(
                                    CDC_SOURCE_COLUMN_NUM as _,
                                )
                                .into(),
                            },
                        }
                    }

                    StreamingJobType::MaterializedView
                    | StreamingJobType::Sink
                    | StreamingJobType::Index => {
                        // Upstream root is a materialized view: derive the mapping from the
                        // materialize node's columns and pick the dispatch strategy based on
                        // whether shuffled backfill is used.
                        if upstream_fragment
                            .fragment_type_mask
                            .contains(FragmentTypeFlag::Mview)
                        {
                            let (dist_key_indices, output_mapping) = {
                                let mview_node = upstream_fragment
                                    .nodes
                                    .get_node_body()
                                    .unwrap()
                                    .as_materialize()
                                    .unwrap();
                                let all_columns = mview_node.column_descs();
                                let dist_key_indices = mview_node.dist_key_indices();
                                let output_mapping = gen_output_mapping(
                                    required_columns,
                                    &all_columns,
                                )
                                .context(
                                    "BUG: column not found in the upstream materialized view",
                                )?;
                                (dist_key_indices, output_mapping)
                            };
                            let dispatch_strategy = mv_on_mv_dispatch_strategy(
                                uses_shuffled_backfill,
                                dist_key_indices,
                                output_mapping,
                            );

                            StreamFragmentEdge {
                                id: EdgeId::UpstreamExternal {
                                    upstream_job_id,
                                    downstream_fragment_id: id,
                                },
                                dispatch_strategy,
                            }
                        }
                        // Upstream root is a source: always `NoShuffle`, mapping derived from
                        // the source node's columns.
                        else if upstream_fragment
                            .fragment_type_mask
                            .contains(FragmentTypeFlag::Source)
                        {
                            let output_mapping = {
                                let source_node = upstream_fragment
                                    .nodes
                                    .get_node_body()
                                    .unwrap()
                                    .as_source()
                                    .unwrap();

                                let all_columns = source_node.column_descs().unwrap();
                                gen_output_mapping(required_columns, &all_columns).context(
                                    "BUG: column not found in the upstream source node",
                                )?
                            };

                            StreamFragmentEdge {
                                id: EdgeId::UpstreamExternal {
                                    upstream_job_id,
                                    downstream_fragment_id: id,
                                },
                                dispatch_strategy: DispatchStrategy {
                                    r#type: DispatcherType::NoShuffle as _,
                                    dist_key_indices: vec![],
                                    output_mapping: Some(output_mapping),
                                },
                            }
                        } else {
                            bail!(
                                "the upstream fragment should be a MView or Source, got fragment type: {:b}",
                                upstream_fragment.fragment_type_mask
                            )
                        }
                    }
                    StreamingJobType::Source | StreamingJobType::Table(_) => {
                        bail!(
                            "the streaming job shouldn't have an upstream fragment, job_type: {:?}",
                            job_type
                        )
                    }
                };

                // Record the edge in both directions; `try_insert(..).unwrap()` panics if
                // the same (upstream, downstream) pair is seen twice.
                extra_downstreams
                    .entry(upstream_root_fragment_id)
                    .or_insert_with(HashMap::new)
                    .try_insert(id, edge.clone())
                    .unwrap();
                extra_upstreams
                    .entry(id)
                    .or_insert_with(HashMap::new)
                    .try_insert(upstream_root_fragment_id, edge)
                    .unwrap();
            }
        }

        existing_fragments.extend(
            upstream_root_fragments
                .into_values()
                .map(|f| (GlobalFragmentId::new(f.fragment_id), f)),
        );
    }

    // Part 2: edges from the building graph's table fragment to existing downstream
    // fragments.
    if let Some(FragmentGraphDownstreamContext {
        original_root_fragment_id,
        downstream_fragments,
    }) = downstream_ctx
    {
        let original_table_fragment_id = GlobalFragmentId::new(original_root_fragment_id);
        let table_fragment_id = GlobalFragmentId::new(graph.table_fragment_id());

        for (dispatcher_type, fragment) in &downstream_fragments {
            let id = GlobalFragmentId::new(fragment.fragment_id);

            // Columns requested by the downstream fragment, taken from its `StreamScan` or
            // `SourceBackfill` node. If several such nodes are visited, the last one wins.
            let output_columns = {
                let mut res = None;

                stream_graph_visitor::visit_stream_node_body(&fragment.nodes, |node_body| {
                    let columns = match node_body {
                        NodeBody::StreamScan(stream_scan) => stream_scan.upstream_columns(),
                        NodeBody::SourceBackfill(source_backfill) => {
                            source_backfill.column_descs()
                        }
                        _ => return,
                    };
                    res = Some(columns);
                });

                res.context("failed to locate downstream scan")?
            };

            let table_fragment = graph.fragments.get(&table_fragment_id).unwrap();
            let nodes = table_fragment.node.as_ref().unwrap();

            let (dist_key_indices, output_mapping) = match job_type {
                StreamingJobType::Table(_) | StreamingJobType::MaterializedView => {
                    let mview_node = nodes.get_node_body().unwrap().as_materialize().unwrap();
                    let all_columns = mview_node.column_descs();
                    let dist_key_indices = mview_node.dist_key_indices();
                    // A missing column means a downstream still references a column that
                    // the new definition no longer provides.
                    let output_mapping = gen_output_mapping(&output_columns, &all_columns)
                        .ok_or_else(|| {
                            MetaError::invalid_parameter(
                                "unable to drop the column due to \
                                 being referenced by downstream materialized views or sinks",
                            )
                        })?;
                    (dist_key_indices, output_mapping)
                }

                StreamingJobType::Source => {
                    let source_node = nodes.get_node_body().unwrap().as_source().unwrap();
                    let all_columns = source_node.column_descs().unwrap();
                    let output_mapping = gen_output_mapping(&output_columns, &all_columns)
                        .ok_or_else(|| {
                            MetaError::invalid_parameter(
                                "unable to drop the column due to \
                                 being referenced by downstream materialized views or sinks",
                            )
                        })?;
                    // Downstreams of a source are expected to use `NoShuffle`.
                    assert_eq!(*dispatcher_type, DispatcherType::NoShuffle);
                    (
                        vec![],
                        output_mapping,
                    )
                }

                _ => bail!("unsupported job type for replacement: {job_type:?}"),
            };

            let edge = StreamFragmentEdge {
                id: EdgeId::DownstreamExternal(DownstreamExternalEdgeId {
                    original_upstream_fragment_id: original_table_fragment_id,
                    downstream_fragment_id: id,
                }),
                dispatch_strategy: DispatchStrategy {
                    r#type: *dispatcher_type as i32,
                    output_mapping: Some(output_mapping),
                    dist_key_indices,
                },
            };

            // Record the edge in both directions, as in part 1.
            extra_downstreams
                .entry(table_fragment_id)
                .or_insert_with(HashMap::new)
                .try_insert(id, edge.clone())
                .unwrap();
            extra_upstreams
                .entry(id)
                .or_insert_with(HashMap::new)
                .try_insert(table_fragment_id, edge)
                .unwrap();
        }

        existing_fragments.extend(
            downstream_fragments
                .into_iter()
                .map(|(_, f)| (GlobalFragmentId::new(f.fragment_id), f)),
        );
    }

    Ok(Self {
        building_graph: graph,
        existing_fragments,
        extra_downstreams,
        extra_upstreams,
    })
}
2015}
2016
2017fn gen_output_mapping(
2019 required_columns: &[PbColumnDesc],
2020 upstream_columns: &[PbColumnDesc],
2021) -> Option<DispatchOutputMapping> {
2022 let len = required_columns.len();
2023 let mut indices = vec![0; len];
2024 let mut types = None;
2025
2026 for (i, r) in required_columns.iter().enumerate() {
2027 let (ui, u) = upstream_columns
2028 .iter()
2029 .find_position(|&u| u.column_id == r.column_id)?;
2030 indices[i] = ui as u32;
2031
2032 if u.column_type != r.column_type {
2035 types.get_or_insert_with(|| vec![TypePair::default(); len])[i] = TypePair {
2036 upstream: u.column_type.clone(),
2037 downstream: r.column_type.clone(),
2038 };
2039 }
2040 }
2041
2042 let types = types.unwrap_or(Vec::new());
2044
2045 Some(DispatchOutputMapping { indices, types })
2046}
2047
2048fn mv_on_mv_dispatch_strategy(
2049 uses_shuffled_backfill: bool,
2050 dist_key_indices: Vec<u32>,
2051 output_mapping: DispatchOutputMapping,
2052) -> DispatchStrategy {
2053 if uses_shuffled_backfill {
2054 if !dist_key_indices.is_empty() {
2055 DispatchStrategy {
2056 r#type: DispatcherType::Hash as _,
2057 dist_key_indices,
2058 output_mapping: Some(output_mapping),
2059 }
2060 } else {
2061 DispatchStrategy {
2062 r#type: DispatcherType::Simple as _,
2063 dist_key_indices: vec![], output_mapping: Some(output_mapping),
2065 }
2066 }
2067 } else {
2068 DispatchStrategy {
2069 r#type: DispatcherType::NoShuffle as _,
2070 dist_key_indices: vec![], output_mapping: Some(output_mapping),
2072 }
2073 }
2074}
2075
2076impl CompleteStreamFragmentGraph {
2077 pub(super) fn all_fragment_ids(&self) -> impl Iterator<Item = GlobalFragmentId> + '_ {
2080 self.building_graph
2081 .fragments
2082 .keys()
2083 .chain(self.existing_fragments.keys())
2084 .copied()
2085 }
2086
2087 pub(super) fn all_edges(
2089 &self,
2090 ) -> impl Iterator<Item = (GlobalFragmentId, GlobalFragmentId, &StreamFragmentEdge)> + '_ {
2091 self.building_graph
2092 .downstreams
2093 .iter()
2094 .chain(self.extra_downstreams.iter())
2095 .flat_map(|(&from, tos)| tos.iter().map(move |(&to, edge)| (from, to, edge)))
2096 }
2097
2098 pub(super) fn existing_distribution(&self) -> HashMap<GlobalFragmentId, Distribution> {
2100 self.existing_fragments
2101 .iter()
2102 .map(|(&id, f)| (id, Distribution::from_fragment(f)))
2103 .collect()
2104 }
2105
2106 pub(super) fn topo_order(&self) -> MetaResult<Vec<GlobalFragmentId>> {
2113 let mut topo = Vec::new();
2114 let mut downstream_cnts = HashMap::new();
2115
2116 for fragment_id in self.all_fragment_ids() {
2118 let downstream_cnt = self.get_downstreams(fragment_id).count();
2120 if downstream_cnt == 0 {
2121 topo.push(fragment_id);
2122 } else {
2123 downstream_cnts.insert(fragment_id, downstream_cnt);
2124 }
2125 }
2126
2127 let mut i = 0;
2128 while let Some(&fragment_id) = topo.get(i) {
2129 i += 1;
2130 for (upstream_job_id, _) in self.get_upstreams(fragment_id) {
2132 let downstream_cnt = downstream_cnts.get_mut(&upstream_job_id).unwrap();
2133 *downstream_cnt -= 1;
2134 if *downstream_cnt == 0 {
2135 downstream_cnts.remove(&upstream_job_id);
2136 topo.push(upstream_job_id);
2137 }
2138 }
2139 }
2140
2141 if !downstream_cnts.is_empty() {
2142 bail!("graph is not a DAG");
2144 }
2145
2146 Ok(topo)
2147 }
2148
2149 pub(super) fn seal_fragment(
2152 &self,
2153 id: GlobalFragmentId,
2154 distribution: Distribution,
2155 stream_node: StreamNode,
2156 ) -> Fragment {
2157 let building_fragment = self.get_fragment(id).into_building().unwrap();
2158 let internal_tables = building_fragment.extract_internal_tables();
2159 let BuildingFragment {
2160 inner,
2161 job_id,
2162 upstream_job_columns: _,
2163 } = building_fragment;
2164
2165 let distribution_type = distribution.to_distribution_type();
2166 let vnode_count = distribution.vnode_count();
2167
2168 let materialized_fragment_id =
2169 if FragmentTypeMask::from(inner.fragment_type_mask).contains(FragmentTypeFlag::Mview) {
2170 job_id.map(JobId::as_mv_table_id)
2171 } else {
2172 None
2173 };
2174
2175 let vector_index_fragment_id =
2176 if inner.fragment_type_mask & FragmentTypeFlag::VectorIndexWrite as u32 != 0 {
2177 job_id.map(JobId::as_mv_table_id)
2178 } else {
2179 None
2180 };
2181
2182 let state_table_ids = internal_tables
2183 .iter()
2184 .map(|t| t.id)
2185 .chain(materialized_fragment_id)
2186 .chain(vector_index_fragment_id)
2187 .collect();
2188
2189 Fragment {
2190 fragment_id: inner.fragment_id,
2191 fragment_type_mask: inner.fragment_type_mask.into(),
2192 distribution_type,
2193 state_table_ids,
2194 maybe_vnode_count: VnodeCount::set(vnode_count).to_protobuf(),
2195 nodes: stream_node,
2196 }
2197 }
2198
2199 pub(super) fn get_fragment(&self, fragment_id: GlobalFragmentId) -> EitherFragment {
2202 if self.existing_fragments.contains_key(&fragment_id) {
2203 EitherFragment::Existing
2204 } else {
2205 EitherFragment::Building(
2206 self.building_graph
2207 .fragments
2208 .get(&fragment_id)
2209 .unwrap()
2210 .clone(),
2211 )
2212 }
2213 }
2214
2215 pub(super) fn get_downstreams(
2218 &self,
2219 fragment_id: GlobalFragmentId,
2220 ) -> impl Iterator<Item = (GlobalFragmentId, &StreamFragmentEdge)> {
2221 self.building_graph
2222 .get_downstreams(fragment_id)
2223 .iter()
2224 .chain(
2225 self.extra_downstreams
2226 .get(&fragment_id)
2227 .into_iter()
2228 .flatten(),
2229 )
2230 .map(|(&id, edge)| (id, edge))
2231 }
2232
2233 pub(super) fn get_upstreams(
2236 &self,
2237 fragment_id: GlobalFragmentId,
2238 ) -> impl Iterator<Item = (GlobalFragmentId, &StreamFragmentEdge)> {
2239 self.building_graph
2240 .get_upstreams(fragment_id)
2241 .iter()
2242 .chain(self.extra_upstreams.get(&fragment_id).into_iter().flatten())
2243 .map(|(&id, edge)| (id, edge))
2244 }
2245
/// Returns the fragments of the building graph, keyed by fragment id.
pub(super) fn building_fragments(&self) -> &HashMap<GlobalFragmentId, BuildingFragment> {
    &self.building_graph.fragments
}
2250
/// Returns mutable access to the fragments of the building graph, keyed by fragment id.
pub(super) fn building_fragments_mut(
    &mut self,
) -> &mut HashMap<GlobalFragmentId, BuildingFragment> {
    &mut self.building_graph.fragments
}
2257
/// Returns the max parallelism of the building graph.
pub(super) fn max_parallelism(&self) -> usize {
    self.building_graph.max_parallelism()
}
2262}
2263
2264#[cfg(test)]
2265mod tests {
2266 use risingwave_common::catalog::{ColumnDesc, ColumnId};
2267 use risingwave_common::types::DataType;
2268 use risingwave_pb::catalog::SinkType as PbSinkType;
2269 use risingwave_pb::meta::table_fragments::fragment::PbFragmentDistributionType;
2270 use risingwave_pb::plan_common::StorageTableDesc;
2271 use risingwave_pb::stream_plan::{
2272 BatchPlanNode, MergeNode, ProjectNode, SinkDesc, SinkLogStoreType, SinkNode, StreamNode,
2273 StreamScanNode, StreamScanType,
2274 };
2275
2276 use super::*;
2277
2278 fn make_column(name: &str, id: i32, data_type: DataType) -> ColumnCatalog {
2279 ColumnCatalog::visible(ColumnDesc::named(name, ColumnId::new(id), data_type))
2280 }
2281
2282 fn make_field(table_name: &str, column: &ColumnCatalog) -> risingwave_pb::plan_common::Field {
2283 Field::new(
2284 format!("{}.{}", table_name, column.column_desc.name),
2285 column.data_type().clone(),
2286 )
2287 .to_prost()
2288 }
2289
2290 fn make_input_ref(index: u32, data_type: &DataType) -> PbExprNode {
2291 PbExprNode {
2292 function_type: expr_node::Type::Unspecified as i32,
2293 return_type: Some(data_type.to_protobuf()),
2294 rex_node: Some(expr_node::RexNode::InputRef(index)),
2295 }
2296 }
2297
2298 fn make_stream_scan_node(
2299 table_name: &str,
2300 table_id: u32,
2301 columns: &[ColumnCatalog],
2302 ) -> StreamNode {
2303 let merge_node = StreamNode {
2304 node_body: Some(NodeBody::Merge(Box::new(MergeNode {
2305 upstream_fragment_id: 0.into(),
2306 ..Default::default()
2307 }))),
2308 fields: columns
2309 .iter()
2310 .map(|col| make_field(table_name, col))
2311 .collect(),
2312 ..Default::default()
2313 };
2314 let batch_plan_node = StreamNode {
2315 node_body: Some(NodeBody::BatchPlan(Box::new(BatchPlanNode {
2316 ..Default::default()
2317 }))),
2318 ..Default::default()
2319 };
2320 let stream_scan_node = StreamScanNode {
2321 table_id: table_id.into(),
2322 upstream_column_ids: columns.iter().map(|c| c.column_id().get_id()).collect(),
2323 output_indices: (0..columns.len()).map(|i| i as u32).collect(),
2324 stream_scan_type: StreamScanType::ArrangementBackfill as i32,
2325 table_desc: Some(StorageTableDesc {
2326 table_id: table_id.into(),
2327 columns: columns
2328 .iter()
2329 .map(|col| col.column_desc.to_protobuf())
2330 .collect(),
2331 value_indices: (0..columns.len()).map(|i| i as u32).collect(),
2332 versioned: true,
2333 ..Default::default()
2334 }),
2335 ..Default::default()
2336 };
2337 StreamNode {
2338 node_body: Some(NodeBody::StreamScan(Box::new(stream_scan_node))),
2339 fields: columns
2340 .iter()
2341 .map(|col| make_field(table_name, col))
2342 .collect(),
2343 input: vec![merge_node, batch_plan_node],
2344 ..Default::default()
2345 }
2346 }
2347
2348 fn make_project_node(
2349 table_name: &str,
2350 columns: &[ColumnCatalog],
2351 input: StreamNode,
2352 ) -> StreamNode {
2353 let select_list = columns
2354 .iter()
2355 .enumerate()
2356 .map(|(i, col)| make_input_ref(i as u32, col.data_type()))
2357 .collect();
2358 StreamNode {
2359 node_body: Some(NodeBody::Project(Box::new(ProjectNode {
2360 select_list,
2361 ..Default::default()
2362 }))),
2363 fields: columns
2364 .iter()
2365 .map(|col| make_field(table_name, col))
2366 .collect(),
2367 input: vec![input],
2368 ..Default::default()
2369 }
2370 }
2371
2372 #[tokio::test]
2373 async fn test_rewrite_refresh_schema_sink_fragment_with_project() {
2374 let env = MetaSrvEnv::for_test().await;
2375 let id_gen_manager = env.id_gen_manager().as_ref();
2376
2377 let table_name = "t";
2378 let columns = vec![
2379 make_column("a", 1, DataType::Int64),
2380 make_column("b", 2, DataType::Int64),
2381 ];
2382 let new_column = make_column("c", 3, DataType::Varchar);
2383
2384 let mut upstream_columns = columns.clone();
2385 upstream_columns.push(new_column.clone());
2386 let upstream_table = PbTable {
2387 name: table_name.to_owned(),
2388 columns: upstream_columns
2389 .iter()
2390 .map(|col| col.to_protobuf())
2391 .collect(),
2392 ..Default::default()
2393 };
2394
2395 let sink = PbSink {
2396 columns: columns.iter().map(|col| col.to_protobuf()).collect(),
2397 sink_type: PbSinkType::AppendOnly as i32,
2398 ..Default::default()
2399 };
2400
2401 let sink_desc = SinkDesc {
2402 sink_type: PbSinkType::AppendOnly as i32,
2403 column_catalogs: sink.columns.clone(),
2404 ..Default::default()
2405 };
2406
2407 let stream_scan_node = make_stream_scan_node(table_name, 1, &columns);
2408 let project_node = make_project_node(table_name, &columns, stream_scan_node);
2409
2410 let log_store_table = PbTable {
2411 columns: columns
2412 .iter()
2413 .cloned()
2414 .map(|mut col| {
2415 col.column_desc.name = format!("{}_{}", table_name, col.column_desc.name);
2416 col.to_protobuf()
2417 })
2418 .collect(),
2419 value_indices: (0..columns.len()).map(|i| i as i32).collect(),
2420 ..Default::default()
2421 };
2422
2423 let original_fragment = Fragment {
2424 fragment_id: 1.into(),
2425 fragment_type_mask: FragmentTypeMask::default(),
2426 distribution_type: PbFragmentDistributionType::Single,
2427 state_table_ids: vec![],
2428 maybe_vnode_count: None,
2429 nodes: StreamNode {
2430 node_body: Some(NodeBody::Sink(Box::new(SinkNode {
2431 sink_desc: Some(sink_desc),
2432 table: Some(log_store_table),
2433 ..Default::default()
2434 }))),
2435 fields: columns
2436 .iter()
2437 .map(|col| make_field(table_name, col))
2438 .collect(),
2439 input: vec![project_node],
2440 ..Default::default()
2441 },
2442 };
2443
2444 let (new_fragment, _, _) = rewrite_refresh_schema_sink_fragment(
2445 &original_fragment,
2446 &sink,
2447 std::slice::from_ref(&new_column),
2448 &[],
2449 &upstream_table,
2450 7.into(),
2451 id_gen_manager,
2452 )
2453 .unwrap();
2454
2455 let sink_node = &new_fragment.nodes;
2456 let [project_node] = sink_node.input.as_slice() else {
2457 panic!("Sink has more than 1 input: {:?}", sink_node.input);
2458 };
2459 let PbNodeBody::Project(project_body) = project_node.node_body.as_ref().unwrap() else {
2460 panic!(
2461 "expect PbNodeBody::Project but got: {:?}",
2462 project_node.node_body
2463 );
2464 };
2465 assert_eq!(project_body.select_list.len(), columns.len() + 1);
2466 let last_expr = project_body.select_list.last().unwrap();
2467 assert!(
2468 matches!(last_expr.rex_node, Some(expr_node::RexNode::InputRef(idx)) if idx == columns.len() as u32)
2469 );
2470 assert_eq!(project_node.fields.len(), columns.len() + 1);
2471
2472 let [stream_scan_node] = project_node.input.as_slice() else {
2473 panic!("Project has more than 1 input: {:?}", project_node.input);
2474 };
2475 let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_ref().unwrap() else {
2476 panic!(
2477 "expect PbNodeBody::StreamScan but got: {:?}",
2478 stream_scan_node.node_body
2479 );
2480 };
2481 assert_eq!(
2482 scan.upstream_column_ids.last().copied(),
2483 Some(new_column.column_id().get_id())
2484 );
2485 assert_eq!(
2486 scan.output_indices.last().copied(),
2487 Some(columns.len() as u32)
2488 );
2489 assert_eq!(
2490 stream_scan_node.fields.last().unwrap().name,
2491 format!("{}.{}", table_name, new_column.column_desc.name)
2492 );
2493 }
2494
2495 #[tokio::test]
2496 async fn test_rewrite_refresh_schema_sink_fragment_drop_column_with_project() {
2497 let env = MetaSrvEnv::for_test().await;
2498 let id_gen_manager = env.id_gen_manager().as_ref();
2499
2500 let table_name = "t";
2501 let columns = vec![
2502 make_column("a", 1, DataType::Int64),
2503 make_column("b", 2, DataType::Int64),
2504 make_column("tmp", 3, DataType::Varchar),
2505 ];
2506 let removed_column = columns.last().unwrap().clone();
2507 let upstream_columns = columns[..2].to_vec();
2508
2509 let upstream_table = PbTable {
2510 name: table_name.to_owned(),
2511 columns: upstream_columns
2512 .iter()
2513 .map(|col| col.to_protobuf())
2514 .collect(),
2515 ..Default::default()
2516 };
2517
2518 let sink = PbSink {
2519 columns: columns.iter().map(|col| col.to_protobuf()).collect(),
2520 sink_type: PbSinkType::AppendOnly as i32,
2521 ..Default::default()
2522 };
2523
2524 let sink_desc = SinkDesc {
2525 sink_type: PbSinkType::AppendOnly as i32,
2526 column_catalogs: sink.columns.clone(),
2527 ..Default::default()
2528 };
2529
2530 let stream_scan_node = make_stream_scan_node(table_name, 1, &columns);
2531 let project_node = make_project_node(table_name, &columns, stream_scan_node);
2532
2533 let log_store_table = PbTable {
2534 columns: columns
2535 .iter()
2536 .cloned()
2537 .map(|mut col| {
2538 col.column_desc.name = format!("{}_{}", table_name, col.column_desc.name);
2539 col.to_protobuf()
2540 })
2541 .collect(),
2542 value_indices: (0..columns.len()).map(|i| i as i32).collect(),
2543 ..Default::default()
2544 };
2545
2546 let original_fragment = Fragment {
2547 fragment_id: 1.into(),
2548 fragment_type_mask: FragmentTypeMask::default(),
2549 distribution_type: PbFragmentDistributionType::Single,
2550 state_table_ids: vec![],
2551 maybe_vnode_count: None,
2552 nodes: StreamNode {
2553 node_body: Some(NodeBody::Sink(Box::new(SinkNode {
2554 sink_desc: Some(sink_desc),
2555 table: Some(log_store_table),
2556 log_store_type: SinkLogStoreType::KvLogStore as i32,
2557 ..Default::default()
2558 }))),
2559 fields: columns
2560 .iter()
2561 .map(|col| make_field(table_name, col))
2562 .collect(),
2563 input: vec![project_node],
2564 ..Default::default()
2565 },
2566 };
2567
2568 let (new_fragment, new_schema, new_log_store_table) = rewrite_refresh_schema_sink_fragment(
2569 &original_fragment,
2570 &sink,
2571 &[],
2572 std::slice::from_ref(&removed_column),
2573 &upstream_table,
2574 7.into(),
2575 id_gen_manager,
2576 )
2577 .unwrap();
2578
2579 assert_eq!(new_schema.len(), 2);
2580 assert!(
2581 new_schema.iter().all(|col| {
2582 col.column_desc.as_ref().map(|desc| desc.name.as_str()) != Some("tmp")
2583 })
2584 );
2585
2586 let sink_node = &new_fragment.nodes;
2587 let [project_node] = sink_node.input.as_slice() else {
2588 panic!("Sink has more than 1 input: {:?}", sink_node.input);
2589 };
2590 let PbNodeBody::Project(project_body) = project_node.node_body.as_ref().unwrap() else {
2591 panic!(
2592 "expect PbNodeBody::Project but got: {:?}",
2593 project_node.node_body
2594 );
2595 };
2596 assert_eq!(project_body.select_list.len(), 2);
2597 assert!(project_node.fields.iter().all(|f| !f.name.contains("tmp")));
2598
2599 let [stream_scan_node] = project_node.input.as_slice() else {
2600 panic!("Project has more than 1 input: {:?}", project_node.input);
2601 };
2602 let PbNodeBody::StreamScan(scan) = stream_scan_node.node_body.as_ref().unwrap() else {
2603 panic!(
2604 "expect PbNodeBody::StreamScan but got: {:?}",
2605 stream_scan_node.node_body
2606 );
2607 };
2608 assert!(
2609 !scan
2610 .upstream_column_ids
2611 .iter()
2612 .any(|&id| id == removed_column.column_id().get_id())
2613 );
2614 assert!(
2615 stream_scan_node
2616 .fields
2617 .iter()
2618 .all(|f| !f.name.contains("tmp"))
2619 );
2620
2621 let new_log_store_table = new_log_store_table.expect("log store table should be updated");
2622 assert!(
2623 new_log_store_table.columns.iter().all(|col| !col
2624 .column_desc
2625 .as_ref()
2626 .unwrap()
2627 .name
2628 .contains("tmp"))
2629 );
2630 assert_eq!(
2631 new_log_store_table.value_indices,
2632 (0..new_log_store_table.columns.len() as i32).collect::<Vec<_>>()
2633 );
2634 }
2635}