use std::cmp::min;
use std::collections::{HashMap, HashSet};
use std::fmt::{Debug, Display, Formatter};
use std::num::NonZeroU64;

use anyhow::anyhow;
use async_recursion::async_recursion;
use enum_as_inner::EnumAsInner;
use futures::TryStreamExt;
use iceberg::expr::Predicate as IcebergPredicate;
use itertools::Itertools;
use petgraph::{Directed, Graph};
use pgwire::pg_server::SessionId;
use risingwave_batch::error::BatchError;
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::bitmap::{Bitmap, BitmapBuilder};
use risingwave_common::catalog::Schema;
use risingwave_common::hash::table_distribution::TableDistribution;
use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
use risingwave_common::util::scan_range::ScanRange;
use risingwave_connector::source::filesystem::opendal_source::opendal_enumerator::OpendalEnumerator;
use risingwave_connector::source::filesystem::opendal_source::{
    BatchPosixFsEnumerator, OpendalAzblob, OpendalGcs, OpendalS3,
};
use risingwave_connector::source::iceberg::IcebergSplitEnumerator;
use risingwave_connector::source::kafka::KafkaSplitEnumerator;
use risingwave_connector::source::prelude::DatagenSplitEnumerator;
use risingwave_connector::source::reader::reader::build_opendal_fs_list_for_batch;
use risingwave_connector::source::{
    ConnectorProperties, SourceEnumeratorContext, SplitEnumerator, SplitImpl,
};
use risingwave_pb::batch_plan::iceberg_scan_node::IcebergScanType;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::{ExchangeInfo, ScanRange as ScanRangeProto};
use risingwave_pb::plan_common::Field as PbField;
use serde::ser::SerializeStruct;
use serde::{Serialize, Serializer};
use uuid::Uuid;

use super::SchedulerError;
use crate::TableCatalog;
use crate::catalog::TableId;
use crate::catalog::catalog_service::CatalogReader;
use crate::optimizer::plan_node::generic::{GenericPlanRef, PhysicalPlanRef};
use crate::optimizer::plan_node::{
    BatchIcebergScan, BatchKafkaScan, BatchPlanNodeType, BatchPlanRef as PlanRef, BatchSource,
    PlanNodeId,
};
use crate::optimizer::property::Distribution;
use crate::scheduler::SchedulerResult;

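/// Unique identifier of a batch query, a freshly generated UUID by default.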
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct QueryId {
    pub id: String,
}

impl Display for QueryId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "QueryId:{}", self.id)
    }
}

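/// Identifier of a stage within a query. Stage ids are assigned sequentially
/// by the fragmenter, starting from 0 at the root stage.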
#[derive(Copy, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub struct StageId(u32);

impl Display for StageId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl Debug for StageId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self.0)
    }
}

impl From<StageId> for u32 {
    fn from(value: StageId) -> Self {
        value.0
    }
}

impl From<u32> for StageId {
    fn from(value: u32) -> Self {
        StageId(value)
    }
}

impl Serialize for StageId {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        self.0.serialize(serializer)
    }
}

impl StageId {
    pub fn inc(&mut self) {
        self.0 += 1;
    }
}

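/// The root stage is always executed by exactly one task.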
pub const ROOT_TASK_ID: u64 = 0;
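/// The root task has a single output.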
pub const ROOT_TASK_OUTPUT_ID: u64 = 0;
pub type TaskId = u64;

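/// A plan node in the pruned plan tree of a single stage. Each `BatchExchange`
/// is cut off at the stage boundary and instead records the id of the child
/// stage that feeds it.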
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct ExecutionPlanNode {
    pub plan_node_id: PlanNodeId,
    pub plan_node_type: BatchPlanNodeType,
    pub node: NodeBody,
    pub schema: Vec<PbField>,

    pub children: Vec<ExecutionPlanNode>,

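    /// The id of the child stage whose output this node consumes.
    /// `None` unless this node is a `BatchExchange`.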
    pub source_stage_id: Option<StageId>,
}

impl Serialize for ExecutionPlanNode {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut state = serializer.serialize_struct("ExecutionPlanNode", 5)?;
        state.serialize_field("plan_node_id", &self.plan_node_id)?;
        state.serialize_field("plan_node_type", &self.plan_node_type)?;
        state.serialize_field("schema", &self.schema)?;
        state.serialize_field("children", &self.children)?;
        state.serialize_field("source_stage_id", &self.source_stage_id)?;
        state.end()
    }
}

impl TryFrom<PlanRef> for ExecutionPlanNode {
    type Error = SchedulerError;

    fn try_from(plan_node: PlanRef) -> Result<Self, Self::Error> {
        Ok(Self {
            plan_node_id: plan_node.plan_base().id(),
            plan_node_type: plan_node.node_type(),
            node: plan_node.try_to_batch_prost_body()?,
            children: vec![],
            schema: plan_node.schema().to_prost(),
            source_stage_id: None,
        })
    }
}

impl ExecutionPlanNode {
    pub fn node_type(&self) -> BatchPlanNodeType {
        self.plan_node_type
    }
}

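/// Splits a batch query plan into stages at each `BatchExchange` boundary.
///
/// A sketch of the intended two-step flow (the `worker_node_manager`,
/// `catalog_reader`, and optimized `batch_node` below are assumed to be at
/// hand):
///
/// ```ignore
/// let fragmenter = BatchPlanFragmenter::new(
///     worker_node_manager,
///     catalog_reader,
///     batch_parallelism, // Option<NonZeroU64> from the session config
///     batch_node,
/// )?;
/// // Listing source splits is async, hence the separate completion step.
/// let query = fragmenter.generate_complete_query().await?;
/// ```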
pub struct BatchPlanFragmenter {
    query_id: QueryId,
    next_stage_id: StageId,
    worker_node_manager: WorkerNodeSelector,
    catalog_reader: CatalogReader,

    batch_parallelism: usize,

    stage_graph_builder: Option<StageGraphBuilder>,
    stage_graph: Option<StageGraph>,
}

impl Default for QueryId {
    fn default() -> Self {
        Self {
            id: Uuid::new_v4().to_string(),
        }
    }
}

impl BatchPlanFragmenter {
    pub fn new(
        worker_node_manager: WorkerNodeSelector,
        catalog_reader: CatalogReader,
        batch_parallelism: Option<NonZeroU64>,
        batch_node: PlanRef,
    ) -> SchedulerResult<Self> {
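        // Cap the configured parallelism at the number of schedulable worker
        // slots; if unset, use all available slots. Either may be 0 when no
        // serving worker is available.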
        let batch_parallelism = if let Some(num) = batch_parallelism {
            min(
                num.get() as usize,
                worker_node_manager.schedule_unit_count(),
            )
        } else {
            worker_node_manager.schedule_unit_count()
        };

        let mut plan_fragmenter = Self {
            query_id: Default::default(),
            next_stage_id: 0.into(),
            worker_node_manager,
            catalog_reader,
            batch_parallelism,
            stage_graph_builder: Some(StageGraphBuilder::new(batch_parallelism)),
            stage_graph: None,
        };
        plan_fragmenter.split_into_stage(batch_node)?;
        Ok(plan_fragmenter)
    }

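    /// Split the plan into stages, cutting at each `BatchExchange`. The root
    /// stage gathers the final result with single distribution.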
    fn split_into_stage(&mut self, batch_node: PlanRef) -> SchedulerResult<()> {
        let root_stage_id = self.new_stage(
            batch_node,
            Some(Distribution::Single.to_prost(
                1,
                &self.catalog_reader,
                &self.worker_node_manager,
            )?),
        )?;
        self.stage_graph = Some(
            self.stage_graph_builder
                .take()
                .unwrap()
                .build(root_stage_id),
        );
        Ok(())
    }
}

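/// A fragmented query: the stage graph produced by the fragmenter, addressed
/// by a unique query id.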
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct Query {
    pub query_id: QueryId,
    pub stage_graph: StageGraph,
}

impl Query {
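    /// Stages with no children; scheduling can start from these.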
    pub fn leaf_stages(&self) -> Vec<StageId> {
        let mut ret_leaf_stages = Vec::new();
        for stage_id in self.stage_graph.stages.keys() {
            if self
                .stage_graph
                .get_child_stages_unchecked(stage_id)
                .is_empty()
            {
                ret_leaf_stages.push(*stage_id);
            }
        }
        ret_leaf_stages
    }

    pub fn get_parents(&self, stage_id: &StageId) -> &HashSet<StageId> {
        self.stage_graph.parent_edges.get(stage_id).unwrap()
    }

    pub fn root_stage_id(&self) -> StageId {
        self.stage_graph.root_stage_id
    }

    pub fn query_id(&self) -> &QueryId {
        &self.query_id
    }

    pub fn stages_with_table_scan(&self) -> HashSet<StageId> {
        self.stage_graph
            .stages
            .iter()
            .filter_map(|(stage_id, stage_query)| {
                if stage_query.has_table_scan() {
                    Some(*stage_id)
                } else {
                    None
                }
            })
            .collect()
    }

    pub fn has_lookup_join_stage(&self) -> bool {
        self.stage_graph
            .stages
            .iter()
            .any(|(_stage_id, stage_query)| stage_query.has_lookup_join())
    }

    pub fn stage(&self, stage_id: StageId) -> &QueryStage {
        &self.stage_graph.stages[&stage_id]
    }
}

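/// Connector-specific parameters the fragmenter needs in order to list the
/// splits of a batch source scan.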
#[derive(Debug, Clone)]
pub enum SourceFetchParameters {
    IcebergSpecificInfo(IcebergSpecificInfo),
    KafkaTimebound {
        lower: Option<i64>,
        upper: Option<i64>,
    },
    Empty,
}

#[derive(Debug, Clone)]
pub struct SourceFetchInfo {
    pub schema: Schema,
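    /// The user-configured connector properties, e.g. connection endpoints and
    /// credentials.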
    pub connector: ConnectorProperties,
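    /// Parameters derived from the plan node, e.g. the Kafka timestamp bound
    /// or the Iceberg predicate and scan type.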
    pub fetch_parameters: SourceFetchParameters,
}

#[derive(Debug, Clone)]
pub struct IcebergSpecificInfo {
    pub iceberg_scan_type: IcebergScanType,
    pub predicate: IcebergPredicate,
    pub snapshot_id: Option<i64>,
}

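/// Split assignment of a batch source scan: `Incomplete` right after plan
/// fragmentation, resolved to `Complete` by [`SourceScanInfo::complete`].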
#[derive(Clone, Debug)]
pub enum SourceScanInfo {
    Incomplete(SourceFetchInfo),
    Complete(Vec<SplitImpl>),
}

impl SourceScanInfo {
    pub fn new(fetch_info: SourceFetchInfo) -> Self {
        Self::Incomplete(fetch_info)
    }

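    /// Contact the external system and enumerate the concrete splits to read.
    /// Must only be called on an `Incomplete` scan info.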
    pub async fn complete(self, batch_parallelism: usize) -> SchedulerResult<Self> {
        let fetch_info = match self {
            SourceScanInfo::Incomplete(fetch_info) => fetch_info,
            SourceScanInfo::Complete(_) => {
                unreachable!("Never call complete when SourceScanInfo is already complete")
            }
        };
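        // Dispatch on the connector type: each supported connector lists its
        // splits with its own enumerator; anything else is rejected below.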
        match (fetch_info.connector, fetch_info.fetch_parameters) {
            (
                ConnectorProperties::Kafka(prop),
                SourceFetchParameters::KafkaTimebound { lower, upper },
            ) => {
                let mut kafka_enumerator =
                    KafkaSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;
                let split_info = kafka_enumerator
                    .list_splits_batch(lower, upper)
                    .await?
                    .into_iter()
                    .map(SplitImpl::Kafka)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(split_info))
            }
            (ConnectorProperties::Datagen(prop), SourceFetchParameters::Empty) => {
                let mut datagen_enumerator =
                    DatagenSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;
                let split_info = datagen_enumerator.list_splits().await?;
                let res = split_info.into_iter().map(SplitImpl::Datagen).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::OpendalS3(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalS3> = OpendalEnumerator::new_s3_source(
                    &prop.s3_properties,
                    prop.assume_role,
                    prop.fs_common.compression_format,
                )?;
                let stream = build_opendal_fs_list_for_batch(lister);

                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res
                    .into_iter()
                    .map(SplitImpl::OpendalS3)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::Gcs(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalGcs> =
                    OpendalEnumerator::new_gcs_source(*prop)?;
                let stream = build_opendal_fs_list_for_batch(lister);
                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res.into_iter().map(SplitImpl::Gcs).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::Azblob(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalAzblob> =
                    OpendalEnumerator::new_azblob_source(*prop)?;
                let stream = build_opendal_fs_list_for_batch(lister);
                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res.into_iter().map(SplitImpl::Azblob).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::BatchPosixFs(prop), SourceFetchParameters::Empty) => {
                let mut enumerator = BatchPosixFsEnumerator::new(
                    *prop,
                    SourceEnumeratorContext::dummy().into(),
                )
                .await?;
                let splits = enumerator.list_splits().await?;
                let res = splits
                    .into_iter()
                    .map(SplitImpl::BatchPosixFs)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (
                ConnectorProperties::Iceberg(prop),
                SourceFetchParameters::IcebergSpecificInfo(iceberg_specific_info),
            ) => {
                let iceberg_enumerator =
                    IcebergSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;

                let split_info = iceberg_enumerator
                    .list_splits_batch(
                        fetch_info.schema,
                        iceberg_specific_info.snapshot_id,
                        batch_parallelism,
                        iceberg_specific_info.iceberg_scan_type,
                        iceberg_specific_info.predicate,
                    )
                    .await?
                    .into_iter()
                    .map(SplitImpl::Iceberg)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(split_info))
            }
            (connector, _) => Err(SchedulerError::Internal(anyhow!(
                "querying directly from a {} source is not supported, \
                 please create a table or streaming job from it",
                connector.kind()
            ))),
        }
    }

    pub fn split_info(&self) -> SchedulerResult<&Vec<SplitImpl>> {
        match self {
            Self::Incomplete(_) => Err(SchedulerError::Internal(anyhow!(
                "Should not get split info from incomplete source scan info"
            ))),
            Self::Complete(split_info) => Ok(split_info),
        }
    }
}

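/// Information about the table scanned by a stage, including how the scan is
/// partitioned across worker slots.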
#[derive(Clone, Debug)]
pub struct TableScanInfo {
    name: String,

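    /// The partition to be read by each scan task, keyed by worker slot.
    /// `None` for system tables, which carry no partition info.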
    partitions: Option<HashMap<WorkerSlotId, TablePartitionInfo>>,
}

impl TableScanInfo {
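    /// For a normal table, the partition info is always provided.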
    pub fn new(name: String, partitions: HashMap<WorkerSlotId, TablePartitionInfo>) -> Self {
        Self {
            name,
            partitions: Some(partitions),
        }
    }

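    /// For a system table, there is no partition info.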
    pub fn system_table(name: String) -> Self {
        Self {
            name,
            partitions: None,
        }
    }

    pub fn name(&self) -> &str {
        self.name.as_ref()
    }

    pub fn partitions(&self) -> Option<&HashMap<WorkerSlotId, TablePartitionInfo>> {
        self.partitions.as_ref()
    }
}

#[derive(Clone, Debug)]
pub struct TablePartitionInfo {
    pub vnode_bitmap: Bitmap,
    pub scan_ranges: Vec<ScanRangeProto>,
}

#[derive(Clone, Debug, EnumAsInner)]
pub enum PartitionInfo {
    Table(TablePartitionInfo),
    Source(Vec<SplitImpl>),
    File(Vec<String>),
}

#[derive(Clone, Debug)]
pub struct FileScanInfo {
    pub file_location: Vec<String>,
}

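/// A fragment of the batch plan that is scheduled and executed as a unit.
/// `parallelism` stays `None` until the stage's source splits are resolved.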
#[cfg_attr(test, derive(Clone))]
pub struct QueryStage {
    pub id: StageId,
    pub root: ExecutionPlanNode,
    pub exchange_info: Option<ExchangeInfo>,
    pub parallelism: Option<u32>,
    pub table_scan_info: Option<TableScanInfo>,
    pub source_info: Option<SourceScanInfo>,
    pub file_scan_info: Option<FileScanInfo>,
    pub has_lookup_join: bool,
    pub dml_table_id: Option<TableId>,
    pub session_id: SessionId,
    pub batch_enable_distributed_dml: bool,

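    /// The distribution of each child's exchange, kept so the children's
    /// exchange info can be derived once this stage's parallelism is known.
    /// `Some` only while `parallelism` is still `None`.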
    children_exchange_distribution: Option<HashMap<StageId, Distribution>>,
}

impl QueryStage {
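    /// Whether this stage contains a table scan executor, i.e. one that reads
    /// data from storage when the stage is executed.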
    pub fn has_table_scan(&self) -> bool {
        self.table_scan_info.is_some()
    }

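    /// Whether this stage contains a lookup join executor.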
    pub fn has_lookup_join(&self) -> bool {
        self.has_lookup_join
    }

    pub fn with_exchange_info(
        self,
        exchange_info: Option<ExchangeInfo>,
        parallelism: Option<u32>,
    ) -> Self {
        if let Some(exchange_info) = exchange_info {
            Self {
                exchange_info: Some(exchange_info),
                parallelism,
                ..self
            }
        } else {
            self
        }
    }

    pub fn with_exchange_info_and_complete_source_info(
        self,
        exchange_info: Option<ExchangeInfo>,
        source_info: SourceScanInfo,
        task_parallelism: u32,
    ) -> Self {
        assert!(matches!(source_info, SourceScanInfo::Complete(_)));
        let exchange_info = exchange_info.or(self.exchange_info);
        Self {
            id: self.id,
            root: self.root,
            exchange_info,
            parallelism: Some(task_parallelism),
            table_scan_info: self.table_scan_info,
            source_info: Some(source_info),
            file_scan_info: self.file_scan_info,
            has_lookup_join: self.has_lookup_join,
            dml_table_id: self.dml_table_id,
            session_id: self.session_id,
            batch_enable_distributed_dml: self.batch_enable_distributed_dml,
            children_exchange_distribution: None,
        }
    }
}

impl Debug for QueryStage {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QueryStage")
            .field("id", &self.id)
            .field("parallelism", &self.parallelism)
            .field("exchange_info", &self.exchange_info)
            .field("has_table_scan", &self.has_table_scan())
            .finish()
    }
}

impl Serialize for QueryStage {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut state = serializer.serialize_struct("QueryStage", 3)?;
        state.serialize_field("root", &self.root)?;
        state.serialize_field("parallelism", &self.parallelism)?;
        state.serialize_field("exchange_info", &self.exchange_info)?;
        state.end()
    }
}

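/// Accumulates the pieces of a [`QueryStage`] while the fragmenter walks the
/// plan tree; `finish` registers the stage in the [`StageGraphBuilder`].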
struct QueryStageBuilder {
    id: StageId,
    root: Option<ExecutionPlanNode>,
    parallelism: Option<u32>,
    exchange_info: Option<ExchangeInfo>,

    children_stages: Vec<StageId>,
    table_scan_info: Option<TableScanInfo>,
    source_info: Option<SourceScanInfo>,
    file_scan_info: Option<FileScanInfo>,
    has_lookup_join: bool,
    dml_table_id: Option<TableId>,
    session_id: SessionId,
    batch_enable_distributed_dml: bool,

    children_exchange_distribution: HashMap<StageId, Distribution>,
}

impl QueryStageBuilder {
    #[allow(clippy::too_many_arguments)]
    fn new(
        id: StageId,
        parallelism: Option<u32>,
        exchange_info: Option<ExchangeInfo>,
        table_scan_info: Option<TableScanInfo>,
        source_info: Option<SourceScanInfo>,
        file_scan_info: Option<FileScanInfo>,
        has_lookup_join: bool,
        dml_table_id: Option<TableId>,
        session_id: SessionId,
        batch_enable_distributed_dml: bool,
    ) -> Self {
        Self {
            id,
            root: None,
            parallelism,
            exchange_info,
            children_stages: vec![],
            table_scan_info,
            source_info,
            file_scan_info,
            has_lookup_join,
            dml_table_id,
            session_id,
            batch_enable_distributed_dml,
            children_exchange_distribution: HashMap::new(),
        }
    }

    fn finish(self, stage_graph_builder: &mut StageGraphBuilder) -> StageId {
        // The children's exchange distributions are only carried over while
        // this stage's parallelism is still undetermined.
        let children_exchange_distribution = if self.parallelism.is_none() {
            Some(self.children_exchange_distribution)
        } else {
            None
        };
        let stage = QueryStage {
            id: self.id,
            root: self.root.unwrap(),
            exchange_info: self.exchange_info,
            parallelism: self.parallelism,
            table_scan_info: self.table_scan_info,
            source_info: self.source_info,
            file_scan_info: self.file_scan_info,
            has_lookup_join: self.has_lookup_join,
            dml_table_id: self.dml_table_id,
            session_id: self.session_id,
            batch_enable_distributed_dml: self.batch_enable_distributed_dml,
            children_exchange_distribution,
        };

        let stage_id = stage.id;
        stage_graph_builder.add_node(stage);
        for child_stage_id in self.children_stages {
            stage_graph_builder.link_to_child(self.id, child_stage_id);
        }
        stage_id
    }
}

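/// The DAG of stages produced by fragmenting a batch query, with the edges
/// kept in both directions.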
#[derive(Debug, Serialize)]
#[cfg_attr(test, derive(Clone))]
pub struct StageGraph {
    pub root_stage_id: StageId,
    pub stages: HashMap<StageId, QueryStage>,
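    /// Edges from a stage to its children, traversed top-down when splitting
    /// the plan into stages.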
    child_edges: HashMap<StageId, HashSet<StageId>>,
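    /// Edges from a stage to its parents, traversed bottom-up when scheduling.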
    parent_edges: HashMap<StageId, HashSet<StageId>>,

    batch_parallelism: usize,
}

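/// The information resolved for a stage during completion: its exchange info
/// and parallelism, plus the source splits for stages with a batch source.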
enum StageCompleteInfo {
    ExchangeInfo((Option<ExchangeInfo>, Option<u32>)),
    ExchangeWithSourceInfo((Option<ExchangeInfo>, SourceScanInfo, u32)),
}

impl StageGraph {
    pub fn get_child_stages_unchecked(&self, stage_id: &StageId) -> &HashSet<StageId> {
        self.child_edges.get(stage_id).unwrap()
    }

    pub fn get_child_stages(&self, stage_id: &StageId) -> Option<&HashSet<StageId>> {
        self.child_edges.get(stage_id)
    }

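    /// Returns the stage ids in topological order, i.e. a child stage always
    /// appears before its parents.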
    pub fn stage_ids_by_topo_order(&self) -> impl Iterator<Item = StageId> {
        let mut stack = Vec::with_capacity(self.stages.len());
        stack.push(self.root_stage_id);
        let mut ret = Vec::with_capacity(self.stages.len());
        let mut existing = HashSet::with_capacity(self.stages.len());

        while let Some(s) = stack.pop() {
            if !existing.contains(&s) {
                ret.push(s);
                existing.insert(s);
                stack.extend(&self.child_edges[&s]);
            }
        }

        ret.into_iter().rev()
    }

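    /// Resolve the remaining information of every stage (exchange info,
    /// parallelism, and source splits), starting from the root.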
    async fn complete(
        self,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
    ) -> SchedulerResult<StageGraph> {
        let mut complete_stages = HashMap::new();
        self.complete_stage(
            self.root_stage_id,
            None,
            &mut complete_stages,
            catalog_reader,
            worker_node_manager,
        )
        .await?;
        let mut stages = self.stages;
        Ok(StageGraph {
            root_stage_id: self.root_stage_id,
            stages: complete_stages
                .into_iter()
                .map(|(stage_id, info)| {
                    let stage = stages.remove(&stage_id).expect("should exist");
                    let stage = match info {
                        StageCompleteInfo::ExchangeInfo((exchange_info, parallelism)) => {
                            stage.with_exchange_info(exchange_info, parallelism)
                        }
                        StageCompleteInfo::ExchangeWithSourceInfo((
                            exchange_info,
                            source_info,
                            parallelism,
                        )) => stage.with_exchange_info_and_complete_source_info(
                            exchange_info,
                            source_info,
                            parallelism,
                        ),
                    };
                    (stage_id, stage)
                })
                .collect(),
            child_edges: self.child_edges,
            parent_edges: self.parent_edges,
            batch_parallelism: self.batch_parallelism,
        })
    }

    #[async_recursion]
    async fn complete_stage(
        &self,
        stage_id: StageId,
        exchange_info: Option<ExchangeInfo>,
        complete_stages: &mut HashMap<StageId, StageCompleteInfo>,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
    ) -> SchedulerResult<()> {
        let stage = &self.stages[&stage_id];
        let parallelism = if stage.parallelism.is_some() {
            // The stage's parallelism was already determined at plan time.
            complete_stages.insert(
                stage.id,
                StageCompleteInfo::ExchangeInfo((exchange_info, stage.parallelism)),
            );
            None
        } else if matches!(stage.source_info, Some(SourceScanInfo::Incomplete(_))) {
            let complete_source_info = stage
                .source_info
                .as_ref()
                .unwrap()
                .clone()
                .complete(self.batch_parallelism)
                .await?;

            // For file-backed sources (GCS, S3, Azblob) the number of splits
            // can be very large, so rather than spawning one task per file, the
            // task count is capped at half the batch parallelism. Other sources
            // get one task per split.
            let task_parallelism = match &stage.source_info {
                Some(SourceScanInfo::Incomplete(source_fetch_info)) => {
                    match source_fetch_info.connector {
                        ConnectorProperties::Gcs(_)
                        | ConnectorProperties::OpendalS3(_)
                        | ConnectorProperties::Azblob(_) => (min(
                            complete_source_info.split_info().unwrap().len() as u32,
                            (self.batch_parallelism / 2) as u32,
                        ))
                        .max(1),
                        _ => complete_source_info.split_info().unwrap().len() as u32,
                    }
                }
                _ => unreachable!(),
            };
            let complete_stage_info = StageCompleteInfo::ExchangeWithSourceInfo((
                exchange_info,
                complete_source_info,
                task_parallelism,
            ));
            complete_stages.insert(stage.id, complete_stage_info);
            Some(task_parallelism)
        } else {
            assert!(stage.file_scan_info.is_some());
            let parallelism = min(
                self.batch_parallelism / 2,
                stage.file_scan_info.as_ref().unwrap().file_location.len(),
            );
            complete_stages.insert(
                stage.id,
                StageCompleteInfo::ExchangeInfo((exchange_info, Some(parallelism as u32))),
            );
            None
        };

        for child_stage_id in self
            .child_edges
            .get(&stage.id)
            .map(|edges| edges.iter())
            .into_iter()
            .flatten()
        {
            // If this stage's parallelism was only now resolved, derive each
            // child's exchange info from the recorded distribution; otherwise
            // the children already received theirs at plan time.
            let exchange_info = if let Some(parallelism) = parallelism {
                let exchange_distribution = stage
                    .children_exchange_distribution
                    .as_ref()
                    .unwrap()
                    .get(child_stage_id)
                    .expect("Exchange distribution is not consistent with the stage graph");
                Some(exchange_distribution.to_prost(
                    parallelism,
                    catalog_reader,
                    worker_node_manager,
                )?)
            } else {
                None
            };
            self.complete_stage(
                *child_stage_id,
                exchange_info,
                complete_stages,
                catalog_reader,
                worker_node_manager,
            )
            .await?;
        }

        Ok(())
    }

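    /// Render the stage graph as a petgraph [`Graph`], e.g. for debugging or
    /// visualization.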
    pub fn to_petgraph(&self) -> Graph<String, String, Directed> {
        let mut graph = Graph::<String, String, Directed>::new();

        let mut node_indices = HashMap::new();

        for (&stage_id, stage_ref) in self.stages.iter().sorted_by_key(|(id, _)| **id) {
            let node_label = format!("Stage {}: {:?}", stage_id, stage_ref);
            let node_index = graph.add_node(node_label);
            node_indices.insert(stage_id, node_index);
        }

        for (&parent_id, children) in &self.child_edges {
            if let Some(&parent_index) = node_indices.get(&parent_id) {
                for &child_id in children {
                    if let Some(&child_index) = node_indices.get(&child_id) {
                        graph.add_edge(parent_index, child_index, "".to_owned());
                    }
                }
            }
        }

        graph
    }
}

struct StageGraphBuilder {
    stages: HashMap<StageId, QueryStage>,
    child_edges: HashMap<StageId, HashSet<StageId>>,
    parent_edges: HashMap<StageId, HashSet<StageId>>,
    batch_parallelism: usize,
}

impl StageGraphBuilder {
    pub fn new(batch_parallelism: usize) -> Self {
        Self {
            stages: HashMap::new(),
            child_edges: HashMap::new(),
            parent_edges: HashMap::new(),
            batch_parallelism,
        }
    }

    pub fn build(self, root_stage_id: StageId) -> StageGraph {
        StageGraph {
            root_stage_id,
            stages: self.stages,
            child_edges: self.child_edges,
            parent_edges: self.parent_edges,
            batch_parallelism: self.batch_parallelism,
        }
    }

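    /// Link a parent stage to a child stage, maintaining both the
    /// parent -> child and the child -> parent edge maps.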
    pub fn link_to_child(&mut self, parent_id: StageId, child_id: StageId) {
        self.child_edges
            .get_mut(&parent_id)
            .unwrap()
            .insert(child_id);
        self.parent_edges
            .get_mut(&child_id)
            .unwrap()
            .insert(parent_id);
    }

    pub fn add_node(&mut self, stage: QueryStage) {
        self.child_edges.insert(stage.id, HashSet::new());
        self.parent_edges.insert(stage.id, HashSet::new());
        self.stages.insert(stage.id, stage);
    }
}

impl BatchPlanFragmenter {
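    /// After splitting, the stage graph may still contain stages whose source
    /// info is incomplete. This step fetches the concrete splits and completes
    /// them; it is separate from splitting because listing splits is async.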
    pub async fn generate_complete_query(self) -> SchedulerResult<Query> {
        let stage_graph = self.stage_graph.unwrap();
        let new_stage_graph = stage_graph
            .complete(&self.catalog_reader, &self.worker_node_manager)
            .await?;
        Ok(Query {
            query_id: self.query_id,
            stage_graph: new_stage_graph,
        })
    }

    fn new_stage(
        &mut self,
        root: PlanRef,
        exchange_info: Option<ExchangeInfo>,
    ) -> SchedulerResult<StageId> {
        let next_stage_id = self.next_stage_id;
        self.next_stage_id.inc();

        let mut table_scan_info = None;
        let mut source_info = None;
        let mut file_scan_info = None;

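        // A stage is assumed to contain at most one of: a table scan, a batch
        // source, or a file scan.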
        if let Some(info) = self.collect_stage_table_scan(root.clone())? {
            table_scan_info = Some(info);
        } else if let Some(info) = Self::collect_stage_source(root.clone())? {
            source_info = Some(info);
        } else if let Some(info) = Self::collect_stage_file_scan(root.clone())? {
            file_scan_info = Some(info);
        }

        let mut has_lookup_join = false;
        let parallelism = match root.distribution() {
            Distribution::Single => {
                if let Some(info) = &mut table_scan_info {
                    if let Some(partitions) = &mut info.partitions {
                        if partitions.len() != 1 {
                            tracing::warn!(
                                "The stage has single distribution, but contains a scan of table `{}` with {} partitions. A single random worker will be assigned",
                                info.name,
                                partitions.len()
                            );

                            *partitions = partitions
                                .drain()
                                .take(1)
                                .update(|(_, info)| {
                                    info.vnode_bitmap = Bitmap::ones(info.vnode_bitmap.len());
                                })
                                .collect();
                        }
                    } else {
                        // `partitions` is `None` for a system table scan, which
                        // is fine under single distribution.
                    }
                } else if source_info.is_some() {
                    return Err(SchedulerError::Internal(anyhow!(
                        "The stage has single distribution, but contains a source operator"
                    )));
                }
                1
            }
            _ => {
                if let Some(table_scan_info) = &table_scan_info {
                    table_scan_info
                        .partitions
                        .as_ref()
                        .map(|m| m.len())
                        .unwrap_or(1)
                } else if let Some(lookup_join_parallelism) =
                    self.collect_stage_lookup_join_parallelism(root.clone())?
                {
                    has_lookup_join = true;
                    lookup_join_parallelism
                } else if source_info.is_some() {
                    // The parallelism of a source stage is not known until its
                    // splits are listed; 0 is a placeholder meaning "decide
                    // later".
                    0
                } else if file_scan_info.is_some() {
                    1
                } else {
                    self.batch_parallelism
                }
            }
        };
        if source_info.is_none() && file_scan_info.is_none() && parallelism == 0 {
            return Err(BatchError::EmptyWorkerNodes.into());
        }
        let parallelism = if parallelism == 0 {
            None
        } else {
            Some(parallelism as u32)
        };
        let dml_table_id = Self::collect_dml_table_id(&root);
        let mut builder = QueryStageBuilder::new(
            next_stage_id,
            parallelism,
            exchange_info,
            table_scan_info,
            source_info,
            file_scan_info,
            has_lookup_join,
            dml_table_id,
            root.ctx().session_ctx().session_id(),
            root.ctx()
                .session_ctx()
                .config()
                .batch_enable_distributed_dml(),
        );

        self.visit_node(root, &mut builder, None)?;

        Ok(builder.finish(self.stage_graph_builder.as_mut().unwrap()))
    }

    fn visit_node(
        &mut self,
        node: PlanRef,
        builder: &mut QueryStageBuilder,
        parent_exec_node: Option<&mut ExecutionPlanNode>,
    ) -> SchedulerResult<()> {
        match node.node_type() {
            BatchPlanNodeType::BatchExchange => {
                self.visit_exchange(node, builder, parent_exec_node)?;
            }
            _ => {
                let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;

                for child in node.inputs() {
                    self.visit_node(child, builder, Some(&mut execution_plan_node))?;
                }

                if let Some(parent) = parent_exec_node {
                    parent.children.push(execution_plan_node);
                } else {
                    builder.root = Some(execution_plan_node);
                }
            }
        }
        Ok(())
    }

    fn visit_exchange(
        &mut self,
        node: PlanRef,
        builder: &mut QueryStageBuilder,
        parent_exec_node: Option<&mut ExecutionPlanNode>,
    ) -> SchedulerResult<()> {
        let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;
        let child_exchange_info = if let Some(parallelism) = builder.parallelism {
            Some(node.distribution().to_prost(
                parallelism,
                &self.catalog_reader,
                &self.worker_node_manager,
            )?)
        } else {
            None
        };
        let child_stage_id = self.new_stage(node.inputs()[0].clone(), child_exchange_info)?;
        execution_plan_node.source_stage_id = Some(child_stage_id);
        if builder.parallelism.is_none() {
            builder
                .children_exchange_distribution
                .insert(child_stage_id, node.distribution().clone());
        }

        if let Some(parent) = parent_exec_node {
            parent.children.push(execution_plan_node);
        } else {
            builder.root = Some(execution_plan_node);
        }

        builder.children_stages.push(child_stage_id);
        Ok(())
    }

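    /// Check whether this stage contains a batch source node and, if so,
    /// collect what is needed to list its splits later. The search stops at
    /// exchange nodes, i.e. at the boundary of the stage.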
    fn collect_stage_source(node: PlanRef) -> SchedulerResult<Option<SourceScanInfo>> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            return Ok(None);
        }

        if let Some(batch_kafka_node) = node.as_batch_kafka_scan() {
            let batch_kafka_scan: &BatchKafkaScan = batch_kafka_node;
            let source_catalog = batch_kafka_scan.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                let timestamp_bound = batch_kafka_scan.kafka_timestamp_range_value();
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: batch_kafka_scan.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::KafkaTimebound {
                        lower: timestamp_bound.0,
                        upper: timestamp_bound.1,
                    },
                })));
            }
        } else if let Some(batch_iceberg_scan) = node.as_batch_iceberg_scan() {
            let batch_iceberg_scan: &BatchIcebergScan = batch_iceberg_scan;
            let source_catalog = batch_iceberg_scan.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: batch_iceberg_scan.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::IcebergSpecificInfo(
                        IcebergSpecificInfo {
                            predicate: batch_iceberg_scan.predicate.clone(),
                            iceberg_scan_type: batch_iceberg_scan.iceberg_scan_type(),
                            snapshot_id: batch_iceberg_scan.snapshot_id(),
                        },
                    ),
                })));
            }
        } else if let Some(source_node) = node.as_batch_source() {
            let source_node: &BatchSource = source_node;
            let source_catalog = source_node.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: source_node.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::Empty,
                })));
            }
        }

        node.inputs()
            .into_iter()
            .find_map(|n| Self::collect_stage_source(n).transpose())
            .transpose()
    }

    fn collect_stage_file_scan(node: PlanRef) -> SchedulerResult<Option<FileScanInfo>> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            return Ok(None);
        }

        if let Some(batch_file_scan) = node.as_batch_file_scan() {
            return Ok(Some(FileScanInfo {
                file_location: batch_file_scan.core.file_location(),
            }));
        }

        node.inputs()
            .into_iter()
            .find_map(|n| Self::collect_stage_file_scan(n).transpose())
            .transpose()
    }

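    /// Check whether this stage contains a table scan node and, if so, derive
    /// the partitions each worker slot should read. The search stops at
    /// exchange nodes, i.e. at the boundary of the stage.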
    fn collect_stage_table_scan(&self, node: PlanRef) -> SchedulerResult<Option<TableScanInfo>> {
        let build_table_scan_info = |name, table_catalog: &TableCatalog, scan_range| {
            let vnode_mapping = self
                .worker_node_manager
                .fragment_mapping(table_catalog.fragment_id)?;
            let partitions = derive_partitions(scan_range, table_catalog, &vnode_mapping)?;
            let info = TableScanInfo::new(name, partitions);
            Ok(Some(info))
        };
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            return Ok(None);
        }
        if let Some(scan_node) = node.as_batch_sys_seq_scan() {
            let name = scan_node.core().table.name.clone();
            Ok(Some(TableScanInfo::system_table(name)))
        } else if let Some(scan_node) = node.as_batch_log_seq_scan() {
            build_table_scan_info(
                scan_node.core().table_name.clone(),
                &scan_node.core().table,
                &[],
            )
        } else if let Some(scan_node) = node.as_batch_seq_scan() {
            build_table_scan_info(
                scan_node.core().table_name().to_owned(),
                &scan_node.core().table_catalog,
                scan_node.scan_ranges(),
            )
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| self.collect_stage_table_scan(n).transpose())
                .transpose()
        }
    }

    fn collect_dml_table_id(node: &PlanRef) -> Option<TableId> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            return None;
        }
        if let Some(insert) = node.as_batch_insert() {
            Some(insert.core.table_id)
        } else if let Some(update) = node.as_batch_update() {
            Some(update.core.table_id)
        } else if let Some(delete) = node.as_batch_delete() {
            Some(delete.core.table_id)
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| Self::collect_dml_table_id(&n))
        }
    }

    fn collect_stage_lookup_join_parallelism(
        &self,
        node: PlanRef,
    ) -> SchedulerResult<Option<usize>> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            return Ok(None);
        }
        if let Some(lookup_join) = node.as_batch_lookup_join() {
            let table_catalog = lookup_join.right_table();
            let vnode_mapping = self
                .worker_node_manager
                .fragment_mapping(table_catalog.fragment_id)?;
            let parallelism = vnode_mapping.iter().sorted().dedup().count();
            Ok(Some(parallelism))
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| self.collect_stage_lookup_join_parallelism(n).transpose())
                .transpose()
        }
    }
}

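/// Derive the partition (vnode bitmap plus scan ranges) each worker slot should
/// read. A scan range is pinned to a single vnode when its distribution-key
/// value is fully known; otherwise it is fanned out to all partitions.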
fn derive_partitions(
    scan_ranges: &[ScanRange],
    table_catalog: &TableCatalog,
    vnode_mapping: &WorkerSlotMapping,
) -> SchedulerResult<HashMap<WorkerSlotId, TablePartitionInfo>> {
    let vnode_mapping = if table_catalog.vnode_count.value() != vnode_mapping.len() {
        // A mismatch between the fragment's and the table's vnode count is only
        // possible when the table is a singleton (vnode count 1); reduce the
        // mapping to a single worker slot in that case.
        assert_eq!(
            table_catalog.vnode_count.value(),
            1,
            "fragment vnode count {} does not match table vnode count {}",
            vnode_mapping.len(),
            table_catalog.vnode_count.value(),
        );
        &WorkerSlotMapping::new_single(vnode_mapping.iter().next().unwrap())
    } else {
        vnode_mapping
    };
    let vnode_count = vnode_mapping.len();

    let mut partitions: HashMap<WorkerSlotId, (BitmapBuilder, Vec<_>)> = HashMap::new();

    if scan_ranges.is_empty() {
        // No predicate-derived ranges: every worker slot scans its full vnode
        // set.
        return Ok(vnode_mapping
            .to_bitmaps()
            .into_iter()
            .map(|(k, vnode_bitmap)| {
                (
                    k,
                    TablePartitionInfo {
                        vnode_bitmap,
                        scan_ranges: vec![],
                    },
                )
            })
            .collect());
    }

    let table_distribution = TableDistribution::new_from_storage_table_desc(
        Some(Bitmap::ones(vnode_count).into()),
        &table_catalog.table_desc().try_to_protobuf()?,
    );

    for scan_range in scan_ranges {
        let vnode = scan_range.try_compute_vnode(&table_distribution);
        match vnode {
            None => {
                // The scan range does not pin down a vnode: add it to every
                // partition, each with its full vnode bitmap.
                vnode_mapping.to_bitmaps().into_iter().for_each(
                    |(worker_slot_id, vnode_bitmap)| {
                        let (bitmap, scan_ranges) = partitions
                            .entry(worker_slot_id)
                            .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
                        vnode_bitmap
                            .iter()
                            .enumerate()
                            .for_each(|(vnode, b)| bitmap.set(vnode, b));
                        scan_ranges.push(scan_range.to_protobuf());
                    },
                );
            }
            Some(vnode) => {
                // The scan range maps to exactly one vnode: add it only to the
                // partition owning that vnode.
                let worker_slot_id = vnode_mapping[vnode];
                let (bitmap, scan_ranges) = partitions
                    .entry(worker_slot_id)
                    .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
                bitmap.set(vnode.to_index(), true);
                scan_ranges.push(scan_range.to_protobuf());
            }
        }
    }

    Ok(partitions
        .into_iter()
        .map(|(k, (bitmap, scan_ranges))| {
            (
                k,
                TablePartitionInfo {
                    vnode_bitmap: bitmap.finish(),
                    scan_ranges,
                },
            )
        })
        .collect())
}

#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use risingwave_pb::batch_plan::plan_node::NodeBody;

    use crate::optimizer::plan_node::BatchPlanNodeType;
    use crate::scheduler::plan_fragmenter::StageId;

    #[tokio::test]
    async fn test_fragmenter() {
        let query = crate::scheduler::distributed::tests::create_query().await;

        assert_eq!(query.stage_graph.root_stage_id, 0.into());
        assert_eq!(query.stage_graph.stages.len(), 4);

        // Check the mappings of child edges.
        assert_eq!(
            query.stage_graph.child_edges[&0.into()],
            HashSet::from_iter([1.into()])
        );
        assert_eq!(
            query.stage_graph.child_edges[&1.into()],
            HashSet::from_iter([2.into(), 3.into()])
        );
        assert_eq!(query.stage_graph.child_edges[&2.into()], HashSet::new());
        assert_eq!(query.stage_graph.child_edges[&3.into()], HashSet::new());

        // Check the mappings of parent edges.
        assert_eq!(query.stage_graph.parent_edges[&0.into()], HashSet::new());
        assert_eq!(
            query.stage_graph.parent_edges[&1.into()],
            HashSet::from_iter([0.into()])
        );
        assert_eq!(
            query.stage_graph.parent_edges[&2.into()],
            HashSet::from_iter([1.into()])
        );
        assert_eq!(
            query.stage_graph.parent_edges[&3.into()],
            HashSet::from_iter([1.into()])
        );

        // Verify the topological order: each stage must appear after all of
        // its children.
        {
            let stage_id_to_pos: HashMap<StageId, usize> = query
                .stage_graph
                .stage_ids_by_topo_order()
                .enumerate()
                .map(|(pos, stage_id)| (stage_id, pos))
                .collect();

            for stage_id in query.stage_graph.stages.keys() {
                let stage_pos = stage_id_to_pos[stage_id];
                for child_stage_id in &query.stage_graph.child_edges[stage_id] {
                    let child_pos = stage_id_to_pos[child_stage_id];
                    assert!(stage_pos > child_pos);
                }
            }
        }

        // Check the plan node in each stage.
        let root_exchange = query.stage_graph.stages.get(&0.into()).unwrap();
        assert_eq!(
            root_exchange.root.node_type(),
            BatchPlanNodeType::BatchExchange
        );
        assert_eq!(root_exchange.root.source_stage_id, Some(1.into()));
        assert!(matches!(root_exchange.root.node, NodeBody::Exchange(_)));
        assert_eq!(root_exchange.parallelism, Some(1));
        assert!(!root_exchange.has_table_scan());

        let join_node = query.stage_graph.stages.get(&1.into()).unwrap();
        assert_eq!(join_node.root.node_type(), BatchPlanNodeType::BatchHashJoin);
        assert_eq!(join_node.parallelism, Some(24));

        assert!(matches!(join_node.root.node, NodeBody::HashJoin(_)));
        assert_eq!(join_node.root.source_stage_id, None);
        assert_eq!(2, join_node.root.children.len());

        assert!(matches!(
            join_node.root.children[0].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[0].source_stage_id, Some(2.into()));
        assert_eq!(0, join_node.root.children[0].children.len());

        assert!(matches!(
            join_node.root.children[1].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[1].source_stage_id, Some(3.into()));
        assert_eq!(0, join_node.root.children[1].children.len());
        assert!(!join_node.has_table_scan());

        let scan_node1 = query.stage_graph.stages.get(&2.into()).unwrap();
        assert_eq!(scan_node1.root.node_type(), BatchPlanNodeType::BatchSeqScan);
        assert_eq!(scan_node1.root.source_stage_id, None);
        assert_eq!(0, scan_node1.root.children.len());
        assert!(scan_node1.has_table_scan());

        let scan_node2 = query.stage_graph.stages.get(&3.into()).unwrap();
        assert_eq!(scan_node2.root.node_type(), BatchPlanNodeType::BatchFilter);
        assert_eq!(scan_node2.root.source_stage_id, None);
        assert_eq!(1, scan_node2.root.children.len());
        assert!(scan_node2.has_table_scan());
    }
}