use std::cmp::min;
use std::collections::{HashMap, HashSet};
use std::fmt::{Debug, Display, Formatter};
use std::num::NonZeroU64;

use anyhow::anyhow;
use async_recursion::async_recursion;
use enum_as_inner::EnumAsInner;
use futures::TryStreamExt;
use itertools::Itertools;
use petgraph::{Directed, Graph};
use pgwire::pg_server::SessionId;
use risingwave_batch::error::BatchError;
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::bitmap::{Bitmap, BitmapBuilder};
use risingwave_common::catalog::Schema;
use risingwave_common::hash::table_distribution::TableDistribution;
use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
use risingwave_common::util::scan_range::ScanRange;
use risingwave_connector::source::filesystem::opendal_source::opendal_enumerator::OpendalEnumerator;
use risingwave_connector::source::filesystem::opendal_source::{
    BatchPosixFsEnumerator, OpendalAzblob, OpendalGcs, OpendalS3,
};
use risingwave_connector::source::iceberg::{
    IcebergFileScanTask, IcebergSplit, IcebergSplitEnumerator,
};
use risingwave_connector::source::kafka::KafkaSplitEnumerator;
use risingwave_connector::source::prelude::DatagenSplitEnumerator;
use risingwave_connector::source::reader::reader::build_opendal_fs_list_for_batch;
use risingwave_connector::source::{
    ConnectorProperties, SourceEnumeratorContext, SplitEnumerator, SplitImpl,
};
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::{ExchangeInfo, ScanRange as ScanRangeProto};
use risingwave_pb::plan_common::Field as PbField;
use serde::ser::SerializeStruct;
use serde::{Serialize, Serializer};
use uuid::Uuid;

use super::SchedulerError;
use crate::TableCatalog;
use crate::catalog::TableId;
use crate::catalog::catalog_service::CatalogReader;
use crate::optimizer::plan_node::generic::{GenericPlanRef, PhysicalPlanRef};
use crate::optimizer::plan_node::{
    BatchIcebergScan, BatchKafkaScan, BatchPlanNodeType, BatchPlanRef as PlanRef, BatchSource,
    PlanNodeId,
};
use crate::optimizer::property::Distribution;
use crate::scheduler::SchedulerResult;

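/// The (globally unique) ID of a query. Defaults to a fresh UUIDv4 (see the `Default` impl
/// below).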
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct QueryId {
    pub id: String,
}

impl std::fmt::Display for QueryId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "QueryId:{}", self.id)
    }
}

#[derive(Copy, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub struct StageId(u32);

impl Display for StageId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl Debug for StageId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self.0)
    }
}

impl From<StageId> for u32 {
    fn from(value: StageId) -> Self {
        value.0
    }
}

impl From<u32> for StageId {
    fn from(value: u32) -> Self {
        StageId(value)
    }
}

impl Serialize for StageId {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        self.0.serialize(serializer)
    }
}

impl StageId {
    pub fn inc(&mut self) {
        self.0 += 1;
    }
}

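/// IDs of the root stage's task and its single output. The root stage is scheduled as a
/// single task, so a fixed ID of 0 suffices for both.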
pub const ROOT_TASK_ID: u64 = 0;
pub const ROOT_TASK_OUTPUT_ID: u64 = 0;
pub type TaskId = u64;

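/// A serializable, self-contained snapshot of a plan node within one stage. Children are owned
/// directly; for an exchange node, `source_stage_id` instead records the child stage it pulls
/// data from.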
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct ExecutionPlanNode {
    pub plan_node_id: PlanNodeId,
    pub plan_node_type: BatchPlanNodeType,
    pub node: NodeBody,
    pub schema: Vec<PbField>,

    pub children: Vec<ExecutionPlanNode>,

    pub source_stage_id: Option<StageId>,
}

impl Serialize for ExecutionPlanNode {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut state = serializer.serialize_struct("ExecutionPlanNode", 5)?;
        state.serialize_field("plan_node_id", &self.plan_node_id)?;
        state.serialize_field("plan_node_type", &self.plan_node_type)?;
        state.serialize_field("schema", &self.schema)?;
        state.serialize_field("children", &self.children)?;
        state.serialize_field("source_stage_id", &self.source_stage_id)?;
        state.end()
    }
}

impl TryFrom<PlanRef> for ExecutionPlanNode {
    type Error = SchedulerError;

    fn try_from(plan_node: PlanRef) -> Result<Self, Self::Error> {
        Ok(Self {
            plan_node_id: plan_node.plan_base().id(),
            plan_node_type: plan_node.node_type(),
            node: plan_node.try_to_batch_prost_body()?,
            children: vec![],
            schema: plan_node.schema().to_prost(),
            source_stage_id: None,
        })
    }
}

impl ExecutionPlanNode {
    pub fn node_type(&self) -> BatchPlanNodeType {
        self.plan_node_type
    }
}

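/// Splits a batch query plan into stages at `BatchExchange` boundaries.
///
/// Stage building happens eagerly in [`BatchPlanFragmenter::new`]; the async work of listing
/// source splits and filling in exchange info is deferred to
/// [`BatchPlanFragmenter::generate_complete_query`].
///
/// A minimal usage sketch (illustrative only; `worker_node_manager`, `catalog_reader` and
/// `batch_plan` are assumed to be prepared by the caller):
///
/// ```ignore
/// let fragmenter = BatchPlanFragmenter::new(
///     worker_node_manager,
///     catalog_reader,
///     None, // fall back to the cluster's schedule unit count
///     batch_plan,
/// )?;
/// let query = fragmenter.generate_complete_query().await?;
/// ```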
pub struct BatchPlanFragmenter {
    query_id: QueryId,
    next_stage_id: StageId,
    worker_node_manager: WorkerNodeSelector,
    catalog_reader: CatalogReader,

    batch_parallelism: usize,

    stage_graph_builder: Option<StageGraphBuilder>,
    stage_graph: Option<StageGraph>,
}

impl Default for QueryId {
    fn default() -> Self {
        Self {
            id: Uuid::new_v4().to_string(),
        }
    }
}

impl BatchPlanFragmenter {
    pub fn new(
        worker_node_manager: WorkerNodeSelector,
        catalog_reader: CatalogReader,
        batch_parallelism: Option<NonZeroU64>,
        batch_node: PlanRef,
    ) -> SchedulerResult<Self> {
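        // If unspecified, fall back to the number of schedule units in the cluster; otherwise
        // clamp the requested parallelism to what the cluster can actually offer.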
        let batch_parallelism = if let Some(num) = batch_parallelism {
            min(
                num.get() as usize,
                worker_node_manager.schedule_unit_count(),
            )
        } else {
            worker_node_manager.schedule_unit_count()
        };

        let mut plan_fragmenter = Self {
            query_id: Default::default(),
            next_stage_id: 0.into(),
            worker_node_manager,
            catalog_reader,
            batch_parallelism,
            stage_graph_builder: Some(StageGraphBuilder::new(batch_parallelism)),
            stage_graph: None,
        };
        plan_fragmenter.split_into_stage(batch_node)?;
        Ok(plan_fragmenter)
    }

    fn split_into_stage(&mut self, batch_node: PlanRef) -> SchedulerResult<()> {
        let root_stage_id = self.new_stage(
            batch_node,
            Some(Distribution::Single.to_prost(
                1,
                &self.catalog_reader,
                &self.worker_node_manager,
            )?),
        )?;
        self.stage_graph = Some(
            self.stage_graph_builder
                .take()
                .unwrap()
                .build(root_stage_id),
        );
        Ok(())
    }
}

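/// A fragmented query: its ID plus the completed stage graph, ready for scheduling.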
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct Query {
    pub query_id: QueryId,
    pub stage_graph: StageGraph,
}

impl Query {
    pub fn leaf_stages(&self) -> Vec<StageId> {
        let mut ret_leaf_stages = Vec::new();
        for stage_id in self.stage_graph.stages.keys() {
            if self
                .stage_graph
                .get_child_stages_unchecked(stage_id)
                .is_empty()
            {
                ret_leaf_stages.push(*stage_id);
            }
        }
        ret_leaf_stages
    }

    pub fn get_parents(&self, stage_id: &StageId) -> &HashSet<StageId> {
        self.stage_graph.parent_edges.get(stage_id).unwrap()
    }

    pub fn root_stage_id(&self) -> StageId {
        self.stage_graph.root_stage_id
    }

    pub fn query_id(&self) -> &QueryId {
        &self.query_id
    }

    pub fn stages_with_table_scan(&self) -> HashSet<StageId> {
        self.stage_graph
            .stages
            .iter()
            .filter_map(|(stage_id, stage_query)| {
                if stage_query.has_table_scan() {
                    Some(*stage_id)
                } else {
                    None
                }
            })
            .collect()
    }

    pub fn has_lookup_join_stage(&self) -> bool {
        self.stage_graph
            .stages
            .iter()
            .any(|(_stage_id, stage_query)| stage_query.has_lookup_join())
    }

    pub fn stage(&self, stage_id: StageId) -> &QueryStage {
        &self.stage_graph.stages[&stage_id]
    }
}

#[derive(Debug, Clone)]
pub enum SourceFetchParameters {
    KafkaTimebound {
        lower: Option<i64>,
        upper: Option<i64>,
    },
    Empty,
}

#[derive(Debug, Clone)]
pub enum UnpartitionedData {
    Iceberg(IcebergFileScanTask),
}

#[derive(Debug, Clone)]
pub struct SourceFetchInfo {
    pub schema: Schema,
    pub connector: ConnectorProperties,
    pub fetch_parameters: SourceFetchParameters,
}

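/// Split information of a source scan. Starts as `Incomplete` (or `Unpartitioned` for Iceberg)
/// and is turned into `Complete`, carrying the actual splits, by [`SourceScanInfo::complete`].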
#[derive(Clone)]
pub enum SourceScanInfo {
    Incomplete(SourceFetchInfo),
    Unpartitioned(UnpartitionedData),
    Complete(Vec<SplitImpl>),
}

impl SourceScanInfo {
    pub fn new(fetch_info: SourceFetchInfo) -> Self {
        Self::Incomplete(fetch_info)
    }

    pub async fn complete(self, batch_parallelism: usize) -> SchedulerResult<Self> {
        match self {
            SourceScanInfo::Incomplete(fetch_info) => fetch_info.complete(batch_parallelism).await,
            SourceScanInfo::Unpartitioned(data) => data.complete(batch_parallelism),
            SourceScanInfo::Complete(_) => {
                unreachable!("Never call complete when SourceScanInfo is already complete")
            }
        }
    }

    pub fn split_info(&self) -> SchedulerResult<&Vec<SplitImpl>> {
        match self {
            Self::Incomplete(_) => Err(SchedulerError::Internal(anyhow!(
                "Should not get split info from incomplete source scan info"
            ))),
            Self::Unpartitioned(_) => Err(SchedulerError::Internal(anyhow!(
                "Should not get split info from unpartitioned source scan info"
            ))),
            Self::Complete(split_info) => Ok(split_info),
        }
    }
}

impl UnpartitionedData {
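    /// Distributes the Iceberg file scan tasks into at most `batch_parallelism` splits,
    /// preserving the task kind (data / equality delete / position delete).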
    fn complete(self, batch_parallelism: usize) -> SchedulerResult<SourceScanInfo> {
        macro_rules! split_iceberg_tasks {
            ($tasks:expr, $variant:ident) => {
                IcebergSplitEnumerator::split_n_vecs($tasks, batch_parallelism)
                    .into_iter()
                    .enumerate()
                    .map(|(id, tasks)| {
                        SplitImpl::Iceberg(IcebergSplit {
                            split_id: id.try_into().unwrap(),
                            task: IcebergFileScanTask::$variant(tasks),
                        })
                    })
                    .collect()
            };
        }

        let splits = match self {
            UnpartitionedData::Iceberg(task) => match task {
                IcebergFileScanTask::Data(tasks) => split_iceberg_tasks!(tasks, Data),
                IcebergFileScanTask::EqualityDelete(tasks) => {
                    split_iceberg_tasks!(tasks, EqualityDelete)
                }
                IcebergFileScanTask::PositionDelete(tasks) => {
                    split_iceberg_tasks!(tasks, PositionDelete)
                }
            },
        };
        Ok(SourceScanInfo::Complete(splits))
    }
}

impl SourceFetchInfo {
    async fn complete(self, _batch_parallelism: usize) -> SchedulerResult<SourceScanInfo> {
        match (self.connector, self.fetch_parameters) {
            (
                ConnectorProperties::Kafka(prop),
                SourceFetchParameters::KafkaTimebound { lower, upper },
            ) => {
                let mut kafka_enumerator =
                    KafkaSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;
                let split_info = kafka_enumerator
                    .list_splits_batch(lower, upper)
                    .await?
                    .into_iter()
                    .map(SplitImpl::Kafka)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(split_info))
            }
            (ConnectorProperties::Datagen(prop), SourceFetchParameters::Empty) => {
                let mut datagen_enumerator =
                    DatagenSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;
                let split_info = datagen_enumerator.list_splits().await?;
                let res = split_info.into_iter().map(SplitImpl::Datagen).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::OpendalS3(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalS3> = OpendalEnumerator::new_s3_source(
                    &prop.s3_properties,
                    prop.assume_role,
                    prop.fs_common.compression_format,
                )?;
                let stream = build_opendal_fs_list_for_batch(lister);

                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res
                    .into_iter()
                    .map(SplitImpl::OpendalS3)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::Gcs(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalGcs> =
                    OpendalEnumerator::new_gcs_source(*prop)?;
                let stream = build_opendal_fs_list_for_batch(lister);
                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res.into_iter().map(SplitImpl::Gcs).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::Azblob(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalAzblob> =
                    OpendalEnumerator::new_azblob_source(*prop)?;
                let stream = build_opendal_fs_list_for_batch(lister);
                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res.into_iter().map(SplitImpl::Azblob).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::BatchPosixFs(prop), SourceFetchParameters::Empty) => {
                let mut enumerator =
                    BatchPosixFsEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;
                let splits = enumerator.list_splits().await?;
                let res = splits
                    .into_iter()
                    .map(SplitImpl::BatchPosixFs)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (connector, _) => Err(SchedulerError::Internal(anyhow!(
                "Querying directly from the {} source is not supported, \
                please create a table or streaming job from it",
                connector.kind()
            ))),
        }
    }
}

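/// Scan info of a table scan stage. `partitions` is `None` for system tables, whose scan is
/// not partitioned by vnode.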
#[derive(Clone, Debug)]
pub struct TableScanInfo {
    name: String,

    partitions: Option<HashMap<WorkerSlotId, TablePartitionInfo>>,
}

impl TableScanInfo {
    pub fn new(name: String, partitions: HashMap<WorkerSlotId, TablePartitionInfo>) -> Self {
        Self {
            name,
            partitions: Some(partitions),
        }
    }

    pub fn system_table(name: String) -> Self {
        Self {
            name,
            partitions: None,
        }
    }

    pub fn name(&self) -> &str {
        self.name.as_ref()
    }

    pub fn partitions(&self) -> Option<&HashMap<WorkerSlotId, TablePartitionInfo>> {
        self.partitions.as_ref()
    }
}

#[derive(Clone, Debug)]
pub struct TablePartitionInfo {
    pub vnode_bitmap: Bitmap,
    pub scan_ranges: Vec<ScanRangeProto>,
}

#[derive(Clone, Debug, EnumAsInner)]
pub enum PartitionInfo {
    Table(TablePartitionInfo),
    Source(Vec<SplitImpl>),
    File(Vec<String>),
}

#[derive(Clone, Debug)]
pub struct FileScanInfo {
    pub file_location: Vec<String>,
}

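/// A single stage of the fragmented query: a plan subtree bounded by exchanges, executed by
/// `parallelism` tasks. `parallelism` stays `None` for stages whose parallelism is only known
/// after their source splits or files have been enumerated (see [`StageGraph::complete`]).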
#[cfg_attr(test, derive(Clone))]
pub struct QueryStage {
    pub id: StageId,
    pub root: ExecutionPlanNode,
    pub exchange_info: Option<ExchangeInfo>,
    pub parallelism: Option<u32>,
    pub table_scan_info: Option<TableScanInfo>,
    pub source_info: Option<SourceScanInfo>,
    pub file_scan_info: Option<FileScanInfo>,
    pub has_lookup_join: bool,
    pub dml_table_id: Option<TableId>,
    pub session_id: SessionId,
    pub batch_enable_distributed_dml: bool,

    children_exchange_distribution: Option<HashMap<StageId, Distribution>>,
}

impl QueryStage {
    pub fn has_table_scan(&self) -> bool {
        self.table_scan_info.is_some()
    }

    pub fn has_lookup_join(&self) -> bool {
        self.has_lookup_join
    }

    pub fn with_exchange_info(
        self,
        exchange_info: Option<ExchangeInfo>,
        parallelism: Option<u32>,
    ) -> Self {
        if let Some(exchange_info) = exchange_info {
            Self {
                exchange_info: Some(exchange_info),
                parallelism,
                ..self
            }
        } else {
            self
        }
    }

    pub fn with_exchange_info_and_complete_source_info(
        mut self,
        exchange_info: Option<ExchangeInfo>,
        source_info: SourceScanInfo,
        task_parallelism: u32,
    ) -> Self {
        assert!(matches!(source_info, SourceScanInfo::Complete(_)));
        let exchange_info = exchange_info.or(self.exchange_info.take());
        Self {
            exchange_info,
            parallelism: Some(task_parallelism),
            source_info: Some(source_info),
            children_exchange_distribution: None,
            ..self
        }
    }
}

impl Debug for QueryStage {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QueryStage")
            .field("id", &self.id)
            .field("parallelism", &self.parallelism)
            .field("exchange_info", &self.exchange_info)
            .field("has_table_scan", &self.has_table_scan())
            .finish()
    }
}

impl Serialize for QueryStage {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut state = serializer.serialize_struct("QueryStage", 3)?;
        state.serialize_field("root", &self.root)?;
        state.serialize_field("parallelism", &self.parallelism)?;
        state.serialize_field("exchange_info", &self.exchange_info)?;
        state.end()
    }
}

struct QueryStageBuilder {
    id: StageId,
    root: Option<ExecutionPlanNode>,
    parallelism: Option<u32>,
    exchange_info: Option<ExchangeInfo>,

    children_stages: Vec<StageId>,
    table_scan_info: Option<TableScanInfo>,
    source_info: Option<SourceScanInfo>,
    file_scan_info: Option<FileScanInfo>,
    has_lookup_join: bool,
    dml_table_id: Option<TableId>,
    session_id: SessionId,
    batch_enable_distributed_dml: bool,

    children_exchange_distribution: HashMap<StageId, Distribution>,
}

impl QueryStageBuilder {
    #[allow(clippy::too_many_arguments)]
    fn new(
        id: StageId,
        parallelism: Option<u32>,
        exchange_info: Option<ExchangeInfo>,
        table_scan_info: Option<TableScanInfo>,
        source_info: Option<SourceScanInfo>,
        file_scan_info: Option<FileScanInfo>,
        has_lookup_join: bool,
        dml_table_id: Option<TableId>,
        session_id: SessionId,
        batch_enable_distributed_dml: bool,
    ) -> Self {
        Self {
            id,
            root: None,
            parallelism,
            exchange_info,
            children_stages: vec![],
            table_scan_info,
            source_info,
            file_scan_info,
            has_lookup_join,
            dml_table_id,
            session_id,
            batch_enable_distributed_dml,
            children_exchange_distribution: HashMap::new(),
        }
    }

    fn finish(self, stage_graph_builder: &mut StageGraphBuilder) -> StageId {
        let children_exchange_distribution = if self.parallelism.is_none() {
            Some(self.children_exchange_distribution)
        } else {
            None
        };
        let stage = QueryStage {
            id: self.id,
            root: self.root.unwrap(),
            exchange_info: self.exchange_info,
            parallelism: self.parallelism,
            table_scan_info: self.table_scan_info,
            source_info: self.source_info,
            file_scan_info: self.file_scan_info,
            has_lookup_join: self.has_lookup_join,
            dml_table_id: self.dml_table_id,
            session_id: self.session_id,
            batch_enable_distributed_dml: self.batch_enable_distributed_dml,
            children_exchange_distribution,
        };

        let stage_id = stage.id;
        stage_graph_builder.add_node(stage);
        for child_stage_id in self.children_stages {
            stage_graph_builder.link_to_child(self.id, child_stage_id);
        }
        stage_id
    }
}

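/// The DAG of query stages. `child_edges` maps a stage to the stages it consumes through
/// exchanges; `parent_edges` is the reverse mapping.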
#[derive(Debug, Serialize)]
#[cfg_attr(test, derive(Clone))]
pub struct StageGraph {
    pub root_stage_id: StageId,
    pub stages: HashMap<StageId, QueryStage>,
    child_edges: HashMap<StageId, HashSet<StageId>>,
    parent_edges: HashMap<StageId, HashSet<StageId>>,

    batch_parallelism: usize,
}

enum StageCompleteInfo {
    ExchangeInfo((Option<ExchangeInfo>, Option<u32>)),
    ExchangeWithSourceInfo((Option<ExchangeInfo>, SourceScanInfo, u32)),
}

impl StageGraph {
    pub fn get_child_stages_unchecked(&self, stage_id: &StageId) -> &HashSet<StageId> {
        self.child_edges.get(stage_id).unwrap()
    }

    pub fn get_child_stages(&self, stage_id: &StageId) -> Option<&HashSet<StageId>> {
        self.child_edges.get(stage_id)
    }

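    /// Returns stage ids in topological order: each stage appears after all of its children,
    /// so iterating the result visits leaf stages first. Implemented as a DFS from the root
    /// stage followed by a reversal, which is sufficient here because every stage has a
    /// single parent.
    ///
    /// Illustrative only (`schedule` is a placeholder for whatever the caller does per stage):
    ///
    /// ```ignore
    /// for stage_id in query.stage_graph.stage_ids_by_topo_order() {
    ///     schedule(stage_id); // children are always visited before their parents
    /// }
    /// ```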
    pub fn stage_ids_by_topo_order(&self) -> impl Iterator<Item = StageId> {
        let mut stack = Vec::with_capacity(self.stages.len());
        stack.push(self.root_stage_id);
        let mut ret = Vec::with_capacity(self.stages.len());
        let mut existing = HashSet::with_capacity(self.stages.len());

        while let Some(s) = stack.pop() {
            if !existing.contains(&s) {
                ret.push(s);
                existing.insert(s);
                stack.extend(&self.child_edges[&s]);
            }
        }

        ret.into_iter().rev()
    }

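    /// Completes the graph: lists source splits for stages whose parallelism is still unknown
    /// and fills in exchange info, returning a fully specified `StageGraph`.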
    async fn complete(
        self,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
    ) -> SchedulerResult<StageGraph> {
        let mut complete_stages = HashMap::new();
        self.complete_stage(
            self.root_stage_id,
            None,
            &mut complete_stages,
            catalog_reader,
            worker_node_manager,
        )
        .await?;
        let mut stages = self.stages;
        Ok(StageGraph {
            root_stage_id: self.root_stage_id,
            stages: complete_stages
                .into_iter()
                .map(|(stage_id, info)| {
                    let stage = stages.remove(&stage_id).expect("should exist");
                    let stage = match info {
                        StageCompleteInfo::ExchangeInfo((exchange_info, parallelism)) => {
                            stage.with_exchange_info(exchange_info, parallelism)
                        }
                        StageCompleteInfo::ExchangeWithSourceInfo((
                            exchange_info,
                            source_info,
                            parallelism,
                        )) => stage.with_exchange_info_and_complete_source_info(
                            exchange_info,
                            source_info,
                            parallelism,
                        ),
                    };
                    (stage_id, stage)
                })
                .collect(),
            child_edges: self.child_edges,
            parent_edges: self.parent_edges,
            batch_parallelism: self.batch_parallelism,
        })
    }

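    /// Recursively completes one stage, then its children. A stage falls into one of three
    /// cases: its parallelism is already known; it scans a source, deriving its parallelism
    /// from the number of listed splits; or it scans files, deriving its parallelism from the
    /// number of files.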
    #[async_recursion]
    async fn complete_stage(
        &self,
        stage_id: StageId,
        exchange_info: Option<ExchangeInfo>,
        complete_stages: &mut HashMap<StageId, StageCompleteInfo>,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
    ) -> SchedulerResult<()> {
        let stage = &self.stages[&stage_id];
        let parallelism = if stage.parallelism.is_some() {
            complete_stages.insert(
                stage.id,
                StageCompleteInfo::ExchangeInfo((exchange_info, stage.parallelism)),
            );
            None
        } else if matches!(
            stage.source_info,
            Some(SourceScanInfo::Incomplete(_)) | Some(SourceScanInfo::Unpartitioned(_))
        ) {
            let complete_source_info = stage
                .source_info
                .as_ref()
                .unwrap()
                .clone()
                .complete(self.batch_parallelism)
                .await?;

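            // One task per split by default. For file sources (S3 / GCS / Azblob), cap the
            // parallelism at half of `batch_parallelism` (but at least 1), to avoid spawning
            // one task per file when the listing is large.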
            let task_parallelism = match &stage.source_info {
                Some(SourceScanInfo::Incomplete(source_fetch_info)) => {
                    match source_fetch_info.connector {
                        ConnectorProperties::Gcs(_)
                        | ConnectorProperties::OpendalS3(_)
                        | ConnectorProperties::Azblob(_) => min(
                            complete_source_info.split_info().unwrap().len() as u32,
                            (self.batch_parallelism / 2) as u32,
                        )
                        .max(1),
                        _ => complete_source_info.split_info().unwrap().len() as u32,
                    }
                }
                _ => complete_source_info.split_info().unwrap().len() as u32,
            };
            let complete_stage_info = StageCompleteInfo::ExchangeWithSourceInfo((
                exchange_info,
                complete_source_info,
                task_parallelism,
            ));
            complete_stages.insert(stage.id, complete_stage_info);
            Some(task_parallelism)
        } else {
            assert!(stage.file_scan_info.is_some());
            let parallelism = min(
                self.batch_parallelism / 2,
                stage.file_scan_info.as_ref().unwrap().file_location.len(),
            );
            complete_stages.insert(
                stage.id,
                StageCompleteInfo::ExchangeInfo((exchange_info, Some(parallelism as u32))),
            );
            None
        };

        for child_stage_id in self
            .child_edges
            .get(&stage.id)
            .map(|edges| edges.iter())
            .into_iter()
            .flatten()
        {
            let exchange_info = if let Some(parallelism) = parallelism {
                let exchange_distribution = stage
                    .children_exchange_distribution
                    .as_ref()
                    .unwrap()
                    .get(child_stage_id)
                    .expect("Exchange distribution is not consistent with the stage graph");
                Some(exchange_distribution.to_prost(
                    parallelism,
                    catalog_reader,
                    worker_node_manager,
                )?)
            } else {
                None
            };
            self.complete_stage(
                *child_stage_id,
                exchange_info,
                complete_stages,
                catalog_reader,
                worker_node_manager,
            )
            .await?;
        }

        Ok(())
    }

    pub fn to_petgraph(&self) -> Graph<String, String, Directed> {
        let mut graph = Graph::<String, String, Directed>::new();

        let mut node_indices = HashMap::new();

        for (&stage_id, stage_ref) in self.stages.iter().sorted_by_key(|(id, _)| **id) {
            let node_label = format!("Stage {}: {:?}", stage_id, stage_ref);
            let node_index = graph.add_node(node_label);
            node_indices.insert(stage_id, node_index);
        }

        for (&parent_id, children) in &self.child_edges {
            if let Some(&parent_index) = node_indices.get(&parent_id) {
                for &child_id in children {
                    if let Some(&child_index) = node_indices.get(&child_id) {
                        graph.add_edge(parent_index, child_index, "".to_owned());
                    }
                }
            }
        }

        graph
    }
}

struct StageGraphBuilder {
    stages: HashMap<StageId, QueryStage>,
    child_edges: HashMap<StageId, HashSet<StageId>>,
    parent_edges: HashMap<StageId, HashSet<StageId>>,
    batch_parallelism: usize,
}

impl StageGraphBuilder {
    pub fn new(batch_parallelism: usize) -> Self {
        Self {
            stages: HashMap::new(),
            child_edges: HashMap::new(),
            parent_edges: HashMap::new(),
            batch_parallelism,
        }
    }

    pub fn build(self, root_stage_id: StageId) -> StageGraph {
        StageGraph {
            root_stage_id,
            stages: self.stages,
            child_edges: self.child_edges,
            parent_edges: self.parent_edges,
            batch_parallelism: self.batch_parallelism,
        }
    }

    pub fn link_to_child(&mut self, parent_id: StageId, child_id: StageId) {
        self.child_edges
            .get_mut(&parent_id)
            .unwrap()
            .insert(child_id);
        self.parent_edges
            .get_mut(&child_id)
            .unwrap()
            .insert(parent_id);
    }

    pub fn add_node(&mut self, stage: QueryStage) {
        self.child_edges.insert(stage.id, HashSet::new());
        self.parent_edges.insert(stage.id, HashSet::new());
        self.stages.insert(stage.id, stage);
    }
}

impl BatchPlanFragmenter {
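    /// Finishes fragmentation: asynchronously completes the stage graph (listing source splits
    /// and filling in exchange info) and produces a [`Query`] ready to be scheduled.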
    pub async fn generate_complete_query(self) -> SchedulerResult<Query> {
        let stage_graph = self.stage_graph.unwrap();
        let new_stage_graph = stage_graph
            .complete(&self.catalog_reader, &self.worker_node_manager)
            .await?;
        Ok(Query {
            query_id: self.query_id,
            stage_graph: new_stage_graph,
        })
    }

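    /// Creates a new stage rooted at `root`, collecting table/source/file scan info and
    /// deriving the stage's parallelism from its distribution, then visits the subtree.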
    fn new_stage(
        &mut self,
        root: PlanRef,
        exchange_info: Option<ExchangeInfo>,
    ) -> SchedulerResult<StageId> {
        let next_stage_id = self.next_stage_id;
        self.next_stage_id.inc();

        let mut table_scan_info = None;
        let mut source_info = None;
        let mut file_scan_info = None;

        if let Some(info) = self.collect_stage_table_scan(root.clone())? {
            table_scan_info = Some(info);
        } else if let Some(info) = Self::collect_stage_source(root.clone())? {
            source_info = Some(info);
        } else if let Some(info) = Self::collect_stage_file_scan(root.clone())? {
            file_scan_info = Some(info);
        }

        let mut has_lookup_join = false;
        let parallelism = match root.distribution() {
            Distribution::Single => {
                if let Some(info) = &mut table_scan_info {
                    if let Some(partitions) = &mut info.partitions {
                        if partitions.len() != 1 {
                            tracing::warn!(
                                "The stage has single distribution, but contains a scan of table `{}` with {} partitions. A single random worker will be assigned",
                                info.name,
                                partitions.len()
                            );

                            *partitions = partitions
                                .drain()
                                .take(1)
                                .update(|(_, info)| {
                                    info.vnode_bitmap = Bitmap::ones(info.vnode_bitmap.len());
                                })
                                .collect();
                        }
                    }
                } else if source_info.is_some() {
                    return Err(SchedulerError::Internal(anyhow!(
                        "The stage has single distribution, but contains a source operator"
                    )));
                }
                1
            }
            _ => {
                if let Some(table_scan_info) = &table_scan_info {
                    table_scan_info
                        .partitions
                        .as_ref()
                        .map(|m| m.len())
                        .unwrap_or(1)
                } else if let Some(lookup_join_parallelism) =
                    self.collect_stage_lookup_join_parallelism(root.clone())?
                {
                    has_lookup_join = true;
                    lookup_join_parallelism
                } else if source_info.is_some() {
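                    // Placeholder: the real parallelism of a source stage is only known once
                    // its splits have been listed; 0 is mapped to `None` below.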
                    0
                } else if file_scan_info.is_some() {
                    1
                } else {
                    self.batch_parallelism
                }
            }
        };
        if source_info.is_none() && file_scan_info.is_none() && parallelism == 0 {
            return Err(BatchError::EmptyWorkerNodes.into());
        }
        let parallelism = if parallelism == 0 {
            None
        } else {
            Some(parallelism as u32)
        };
        let dml_table_id = Self::collect_dml_table_id(&root);
        let mut builder = QueryStageBuilder::new(
            next_stage_id,
            parallelism,
            exchange_info,
            table_scan_info,
            source_info,
            file_scan_info,
            has_lookup_join,
            dml_table_id,
            root.ctx().session_ctx().session_id(),
            root.ctx()
                .session_ctx()
                .config()
                .batch_enable_distributed_dml(),
        );

        self.visit_node(root, &mut builder, None)?;

        Ok(builder.finish(self.stage_graph_builder.as_mut().unwrap()))
    }

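    /// Recursively builds the execution plan tree of the current stage. Hitting a
    /// `BatchExchange` cuts the plan: the exchange's input becomes a new child stage.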
    fn visit_node(
        &mut self,
        node: PlanRef,
        builder: &mut QueryStageBuilder,
        parent_exec_node: Option<&mut ExecutionPlanNode>,
    ) -> SchedulerResult<()> {
        match node.node_type() {
            BatchPlanNodeType::BatchExchange => {
                self.visit_exchange(node, builder, parent_exec_node)?;
            }
            _ => {
                let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;

                for child in node.inputs() {
                    self.visit_node(child, builder, Some(&mut execution_plan_node))?;
                }

                if let Some(parent) = parent_exec_node {
                    parent.children.push(execution_plan_node);
                } else {
                    builder.root = Some(execution_plan_node);
                }
            }
        }
        Ok(())
    }

    fn visit_exchange(
        &mut self,
        node: PlanRef,
        builder: &mut QueryStageBuilder,
        parent_exec_node: Option<&mut ExecutionPlanNode>,
    ) -> SchedulerResult<()> {
        let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;
        let child_exchange_info = if let Some(parallelism) = builder.parallelism {
            Some(node.distribution().to_prost(
                parallelism,
                &self.catalog_reader,
                &self.worker_node_manager,
            )?)
        } else {
            None
        };
        let child_stage_id = self.new_stage(node.inputs()[0].clone(), child_exchange_info)?;
        execution_plan_node.source_stage_id = Some(child_stage_id);
        if builder.parallelism.is_none() {
            builder
                .children_exchange_distribution
                .insert(child_stage_id, node.distribution().clone());
        }

        if let Some(parent) = parent_exec_node {
            parent.children.push(execution_plan_node);
        } else {
            builder.root = Some(execution_plan_node);
        }

        builder.children_stages.push(child_stage_id);
        Ok(())
    }

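    /// Checks whether the stage rooted at `node` contains a source scan, stopping at exchange
    /// boundaries. Returns the (not yet completed) source scan info if so.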
    fn collect_stage_source(node: PlanRef) -> SchedulerResult<Option<SourceScanInfo>> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            return Ok(None);
        }

        if let Some(batch_kafka_node) = node.as_batch_kafka_scan() {
            let batch_kafka_scan: &BatchKafkaScan = batch_kafka_node;
            let source_catalog = batch_kafka_scan.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                let timestamp_bound = batch_kafka_scan.kafka_timestamp_range_value();
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: batch_kafka_scan.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::KafkaTimebound {
                        lower: timestamp_bound.0,
                        upper: timestamp_bound.1,
                    },
                })));
            }
        } else if let Some(batch_iceberg_scan) = node.as_batch_iceberg_scan() {
            let batch_iceberg_scan: &BatchIcebergScan = batch_iceberg_scan;
            let task = batch_iceberg_scan.task.clone();
            return Ok(Some(SourceScanInfo::Unpartitioned(
                UnpartitionedData::Iceberg(task),
            )));
        } else if let Some(source_node) = node.as_batch_source() {
            let source_node: &BatchSource = source_node;
            let source_catalog = source_node.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: source_node.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::Empty,
                })));
            }
        }

        node.inputs()
            .into_iter()
            .find_map(|n| Self::collect_stage_source(n).transpose())
            .transpose()
    }

    fn collect_stage_file_scan(node: PlanRef) -> SchedulerResult<Option<FileScanInfo>> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            return Ok(None);
        }

        if let Some(batch_file_scan) = node.as_batch_file_scan() {
            return Ok(Some(FileScanInfo {
                file_location: batch_file_scan.core.file_location(),
            }));
        }

        node.inputs()
            .into_iter()
            .find_map(|n| Self::collect_stage_file_scan(n).transpose())
            .transpose()
    }

    fn collect_stage_table_scan(&self, node: PlanRef) -> SchedulerResult<Option<TableScanInfo>> {
        let build_table_scan_info = |name, table_catalog: &TableCatalog, scan_range| {
            let vnode_mapping = self
                .worker_node_manager
                .fragment_mapping(table_catalog.fragment_id)?;
            let partitions = derive_partitions(scan_range, table_catalog, &vnode_mapping)?;
            let info = TableScanInfo::new(name, partitions);
            Ok(Some(info))
        };
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            return Ok(None);
        }
        if let Some(scan_node) = node.as_batch_sys_seq_scan() {
            let name = scan_node.core().table.name.clone();
            Ok(Some(TableScanInfo::system_table(name)))
        } else if let Some(scan_node) = node.as_batch_log_seq_scan() {
            build_table_scan_info(
                scan_node.core().table_name.clone(),
                &scan_node.core().table,
                &[],
            )
        } else if let Some(scan_node) = node.as_batch_seq_scan() {
            build_table_scan_info(
                scan_node.core().table_name().to_owned(),
                &scan_node.core().table_catalog,
                scan_node.scan_ranges(),
            )
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| self.collect_stage_table_scan(n).transpose())
                .transpose()
        }
    }

    fn collect_dml_table_id(node: &PlanRef) -> Option<TableId> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            return None;
        }
        if let Some(insert) = node.as_batch_insert() {
            Some(insert.core.table_id)
        } else if let Some(update) = node.as_batch_update() {
            Some(update.core.table_id)
        } else if let Some(delete) = node.as_batch_delete() {
            Some(delete.core.table_id)
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| Self::collect_dml_table_id(&n))
        }
    }

    fn collect_stage_lookup_join_parallelism(
        &self,
        node: PlanRef,
    ) -> SchedulerResult<Option<usize>> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            return Ok(None);
        }
        if let Some(lookup_join) = node.as_batch_lookup_join() {
            let table_catalog = lookup_join.right_table();
            let vnode_mapping = self
                .worker_node_manager
                .fragment_mapping(table_catalog.fragment_id)?;
            let parallelism = vnode_mapping.iter().sorted().dedup().count();
            Ok(Some(parallelism))
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| self.collect_stage_lookup_join_parallelism(n).transpose())
                .transpose()
        }
    }
}

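/// Maps scan ranges onto the worker slots owning the relevant vnodes. With no scan range,
/// every worker slot in the vnode mapping gets an (empty-range) partition; a scan range that
/// pins a single vnode goes only to the slot owning it, while a range whose vnode cannot be
/// computed is replicated to all slots.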
fn derive_partitions(
    scan_ranges: &[ScanRange],
    table_catalog: &TableCatalog,
    vnode_mapping: &WorkerSlotMapping,
) -> SchedulerResult<HashMap<WorkerSlotId, TablePartitionInfo>> {
    let vnode_mapping = if table_catalog.vnode_count.value() != vnode_mapping.len() {
        assert_eq!(
            table_catalog.vnode_count.value(),
            1,
            "fragment vnode count {} does not match table vnode count {}",
            vnode_mapping.len(),
            table_catalog.vnode_count.value(),
        );
        &WorkerSlotMapping::new_single(vnode_mapping.iter().next().unwrap())
    } else {
        vnode_mapping
    };
    let vnode_count = vnode_mapping.len();

    let mut partitions: HashMap<WorkerSlotId, (BitmapBuilder, Vec<_>)> = HashMap::new();

    if scan_ranges.is_empty() {
        return Ok(vnode_mapping
            .to_bitmaps()
            .into_iter()
            .map(|(k, vnode_bitmap)| {
                (
                    k,
                    TablePartitionInfo {
                        vnode_bitmap,
                        scan_ranges: vec![],
                    },
                )
            })
            .collect());
    }

    let table_distribution = TableDistribution::new_from_storage_table_desc(
        Some(Bitmap::ones(vnode_count).into()),
        &table_catalog.table_desc().try_to_protobuf()?,
    );

    for scan_range in scan_ranges {
        let vnode = scan_range.try_compute_vnode(&table_distribution);
        match vnode {
            None => {
                vnode_mapping.to_bitmaps().into_iter().for_each(
                    |(worker_slot_id, vnode_bitmap)| {
                        let (bitmap, scan_ranges) = partitions
                            .entry(worker_slot_id)
                            .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
                        vnode_bitmap
                            .iter()
                            .enumerate()
                            .for_each(|(vnode, b)| bitmap.set(vnode, b));
                        scan_ranges.push(scan_range.to_protobuf());
                    },
                );
            }
            Some(vnode) => {
                let worker_slot_id = vnode_mapping[vnode];
                let (bitmap, scan_ranges) = partitions
                    .entry(worker_slot_id)
                    .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
                bitmap.set(vnode.to_index(), true);
                scan_ranges.push(scan_range.to_protobuf());
            }
        }
    }

    Ok(partitions
        .into_iter()
        .map(|(k, (bitmap, scan_ranges))| {
            (
                k,
                TablePartitionInfo {
                    vnode_bitmap: bitmap.finish(),
                    scan_ranges,
                },
            )
        })
        .collect())
}

#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use risingwave_pb::batch_plan::plan_node::NodeBody;

    use crate::optimizer::plan_node::BatchPlanNodeType;
    use crate::scheduler::plan_fragmenter::StageId;

    #[tokio::test]
    async fn test_fragmenter() {
        let query = crate::scheduler::distributed::tests::create_query().await;

        assert_eq!(query.stage_graph.root_stage_id, 0.into());
        assert_eq!(query.stage_graph.stages.len(), 4);

        assert_eq!(
            query.stage_graph.child_edges[&0.into()],
            HashSet::from_iter([1.into()])
        );
        assert_eq!(
            query.stage_graph.child_edges[&1.into()],
            HashSet::from_iter([2.into(), 3.into()])
        );
        assert_eq!(query.stage_graph.child_edges[&2.into()], HashSet::new());
        assert_eq!(query.stage_graph.child_edges[&3.into()], HashSet::new());

        assert_eq!(query.stage_graph.parent_edges[&0.into()], HashSet::new());
        assert_eq!(
            query.stage_graph.parent_edges[&1.into()],
            HashSet::from_iter([0.into()])
        );
        assert_eq!(
            query.stage_graph.parent_edges[&2.into()],
            HashSet::from_iter([1.into()])
        );
        assert_eq!(
            query.stage_graph.parent_edges[&3.into()],
            HashSet::from_iter([1.into()])
        );

        {
            let stage_id_to_pos: HashMap<StageId, usize> = query
                .stage_graph
                .stage_ids_by_topo_order()
                .enumerate()
                .map(|(pos, stage_id)| (stage_id, pos))
                .collect();

            for stage_id in query.stage_graph.stages.keys() {
                let stage_pos = stage_id_to_pos[stage_id];
                for child_stage_id in &query.stage_graph.child_edges[stage_id] {
                    let child_pos = stage_id_to_pos[child_stage_id];
                    assert!(stage_pos > child_pos);
                }
            }
        }

        let root_exchange = query.stage_graph.stages.get(&0.into()).unwrap();
        assert_eq!(
            root_exchange.root.node_type(),
            BatchPlanNodeType::BatchExchange
        );
        assert_eq!(root_exchange.root.source_stage_id, Some(1.into()));
        assert!(matches!(root_exchange.root.node, NodeBody::Exchange(_)));
        assert_eq!(root_exchange.parallelism, Some(1));
        assert!(!root_exchange.has_table_scan());

        let join_node = query.stage_graph.stages.get(&1.into()).unwrap();
        assert_eq!(join_node.root.node_type(), BatchPlanNodeType::BatchHashJoin);
        assert_eq!(join_node.parallelism, Some(24));

        assert!(matches!(join_node.root.node, NodeBody::HashJoin(_)));
        assert_eq!(join_node.root.source_stage_id, None);
        assert_eq!(2, join_node.root.children.len());

        assert!(matches!(
            join_node.root.children[0].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[0].source_stage_id, Some(2.into()));
        assert_eq!(0, join_node.root.children[0].children.len());

        assert!(matches!(
            join_node.root.children[1].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[1].source_stage_id, Some(3.into()));
        assert_eq!(0, join_node.root.children[1].children.len());
        assert!(!join_node.has_table_scan());

        let scan_node1 = query.stage_graph.stages.get(&2.into()).unwrap();
        assert_eq!(scan_node1.root.node_type(), BatchPlanNodeType::BatchSeqScan);
        assert_eq!(scan_node1.root.source_stage_id, None);
        assert_eq!(0, scan_node1.root.children.len());
        assert!(scan_node1.has_table_scan());

        let scan_node2 = query.stage_graph.stages.get(&3.into()).unwrap();
        assert_eq!(scan_node2.root.node_type(), BatchPlanNodeType::BatchFilter);
        assert_eq!(scan_node2.root.source_stage_id, None);
        assert_eq!(1, scan_node2.root.children.len());
        assert!(scan_node2.has_table_scan());
    }
}