risingwave_frontend/scheduler/
plan_fragmenter.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::min;
16use std::collections::{HashMap, HashSet};
17use std::fmt::{Debug, Display, Formatter};
18use std::num::NonZeroU64;
19
20use anyhow::anyhow;
21use async_recursion::async_recursion;
22use enum_as_inner::EnumAsInner;
23use futures::TryStreamExt;
24use itertools::Itertools;
25use petgraph::{Directed, Graph};
26use pgwire::pg_server::SessionId;
27use risingwave_batch::error::BatchError;
28use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
29use risingwave_common::bitmap::{Bitmap, BitmapBuilder};
30use risingwave_common::catalog::Schema;
31use risingwave_common::hash::table_distribution::TableDistribution;
32use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
33use risingwave_common::util::scan_range::ScanRange;
34use risingwave_connector::source::filesystem::opendal_source::opendal_enumerator::OpendalEnumerator;
35use risingwave_connector::source::filesystem::opendal_source::{
36    BatchPosixFsEnumerator, OpendalAzblob, OpendalGcs, OpendalS3,
37};
38use risingwave_connector::source::iceberg::{
39    IcebergFileScanTask, IcebergSplit, IcebergSplitEnumerator,
40};
41use risingwave_connector::source::kafka::KafkaSplitEnumerator;
42use risingwave_connector::source::prelude::DatagenSplitEnumerator;
43use risingwave_connector::source::reader::reader::build_opendal_fs_list_for_batch;
44use risingwave_connector::source::{
45    ConnectorProperties, SourceEnumeratorContext, SplitEnumerator, SplitImpl,
46};
47use risingwave_pb::batch_plan::plan_node::NodeBody;
48use risingwave_pb::batch_plan::{ExchangeInfo, ScanRange as ScanRangeProto};
49use risingwave_pb::plan_common::Field as PbField;
50use serde::ser::SerializeStruct;
51use serde::{Serialize, Serializer};
52use uuid::Uuid;
53
54use super::SchedulerError;
55use crate::TableCatalog;
56use crate::catalog::TableId;
57use crate::catalog::catalog_service::CatalogReader;
58use crate::optimizer::plan_node::generic::{GenericPlanRef, PhysicalPlanRef};
59use crate::optimizer::plan_node::{
60    BatchIcebergScan, BatchKafkaScan, BatchPlanNodeType, BatchPlanRef as PlanRef, BatchSource,
61    PlanNodeId,
62};
63use crate::optimizer::property::Distribution;
64use crate::scheduler::SchedulerResult;
65
/// Identifier of a batch query, stored as a string.
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct QueryId {
    pub id: String,
}

/// Renders as `QueryId:<id>`; width/alignment flags of the outer formatter are not applied.
impl Display for QueryId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("QueryId:")?;
        f.write_str(&self.id)
    }
}
76
/// Identifier of a stage; a thin wrapper around a `u32`.
#[derive(Copy, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub struct StageId(u32);

/// Prints the wrapped number only.
impl Display for StageId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self(raw) = self;
        write!(f, "{raw}")
    }
}

/// Prints the wrapped number only, without the type name.
impl Debug for StageId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self(raw) = self;
        write!(f, "{raw:?}")
    }
}

/// Unwraps the numeric id.
impl From<StageId> for u32 {
    fn from(StageId(raw): StageId) -> Self {
        raw
    }
}

/// Wraps a numeric id.
impl From<u32> for StageId {
    fn from(value: u32) -> Self {
        Self(value)
    }
}
103
104impl Serialize for StageId {
105    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
106    where
107        S: Serializer,
108    {
109        self.0.serialize(serializer)
110    }
111}
112
113impl StageId {
114    pub fn inc(&mut self) {
115        self.0 += 1;
116    }
117}
118
/// Task id used for the root stage. Root stage always has only one task.
pub const ROOT_TASK_ID: u64 = 0;
/// Output id used for the root task. Root task has only one output.
pub const ROOT_TASK_OUTPUT_ID: u64 = 0;
/// Numeric task identifier.
pub type TaskId = u64;
124
/// Generated by [`BatchPlanFragmenter`] and used in query execution graph.
///
/// Nodes form a tree through [`ExecutionPlanNode::children`].
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct ExecutionPlanNode {
    /// Id of the plan node this execution node was created from.
    pub plan_node_id: PlanNodeId,
    /// Kind of the originating batch plan node.
    pub plan_node_type: BatchPlanNodeType,
    /// Protobuf body of the plan node. Not included in the `Serialize` output.
    pub node: NodeBody,
    /// Output schema of the node, in protobuf form.
    pub schema: Vec<PbField>,

    /// Child nodes. Empty right after conversion from a `PlanRef`.
    pub children: Vec<ExecutionPlanNode>,

    /// The stage id of the source of `BatchExchange`.
    /// Used to find `ExchangeSource` from scheduler when creating `PlanNode`.
    ///
    /// `None` when this node is not `BatchExchange`.
    pub source_stage_id: Option<StageId>,
}
142
143impl Serialize for ExecutionPlanNode {
144    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
145    where
146        S: serde::Serializer,
147    {
148        let mut state = serializer.serialize_struct("QueryStage", 5)?;
149        state.serialize_field("plan_node_id", &self.plan_node_id)?;
150        state.serialize_field("plan_node_type", &self.plan_node_type)?;
151        state.serialize_field("schema", &self.schema)?;
152        state.serialize_field("children", &self.children)?;
153        state.serialize_field("source_stage_id", &self.source_stage_id)?;
154        state.end()
155    }
156}
157
158impl TryFrom<PlanRef> for ExecutionPlanNode {
159    type Error = SchedulerError;
160
161    fn try_from(plan_node: PlanRef) -> Result<Self, Self::Error> {
162        Ok(Self {
163            plan_node_id: plan_node.plan_base().id(),
164            plan_node_type: plan_node.node_type(),
165            node: plan_node.try_to_batch_prost_body()?,
166            children: vec![],
167            schema: plan_node.schema().to_prost(),
168            source_stage_id: None,
169        })
170    }
171}
172
173impl ExecutionPlanNode {
174    pub fn node_type(&self) -> BatchPlanNodeType {
175        self.plan_node_type
176    }
177}
178
/// `BatchPlanFragmenter` splits a query plan into fragments.
pub struct BatchPlanFragmenter {
    /// Id of the query being fragmented; freshly generated in `new`.
    query_id: QueryId,
    /// Stage id counter; starts at 0 in `new`.
    next_stage_id: StageId,
    worker_node_manager: WorkerNodeSelector,
    catalog_reader: CatalogReader,

    /// Effective parallelism computed in `new`; can be 0 if no serving worker is available.
    batch_parallelism: usize,

    /// Builder used while splitting; taken (left as `None`) once the graph is built.
    stage_graph_builder: Option<StageGraphBuilder>,
    /// The built stage graph; `None` until splitting has finished.
    stage_graph: Option<StageGraph>,
}
191
192impl Default for QueryId {
193    fn default() -> Self {
194        Self {
195            id: Uuid::new_v4().to_string(),
196        }
197    }
198}
199
200impl BatchPlanFragmenter {
201    pub fn new(
202        worker_node_manager: WorkerNodeSelector,
203        catalog_reader: CatalogReader,
204        batch_parallelism: Option<NonZeroU64>,
205        batch_node: PlanRef,
206    ) -> SchedulerResult<Self> {
207        // if batch_parallelism is None, it means no limit, we will use the available nodes count as
208        // parallelism.
209        // if batch_parallelism is Some(num), we will use the min(num, the available
210        // nodes count) as parallelism.
211        let batch_parallelism = if let Some(num) = batch_parallelism {
212            // can be 0 if no available serving worker
213            min(
214                num.get() as usize,
215                worker_node_manager.schedule_unit_count(),
216            )
217        } else {
218            // can be 0 if no available serving worker
219            worker_node_manager.schedule_unit_count()
220        };
221
222        let mut plan_fragmenter = Self {
223            query_id: Default::default(),
224            next_stage_id: 0.into(),
225            worker_node_manager,
226            catalog_reader,
227            batch_parallelism,
228            stage_graph_builder: Some(StageGraphBuilder::new(batch_parallelism)),
229            stage_graph: None,
230        };
231        plan_fragmenter.split_into_stage(batch_node)?;
232        Ok(plan_fragmenter)
233    }
234
235    /// Split the plan node into each stages, based on exchange node.
236    fn split_into_stage(&mut self, batch_node: PlanRef) -> SchedulerResult<()> {
237        let root_stage_id = self.new_stage(
238            batch_node,
239            Some(Distribution::Single.to_prost(
240                1,
241                &self.catalog_reader,
242                &self.worker_node_manager,
243            )?),
244        )?;
245        self.stage_graph = Some(
246            self.stage_graph_builder
247                .take()
248                .unwrap()
249                .build(root_stage_id),
250        );
251        Ok(())
252    }
253}
254
/// The fragmented query generated by [`BatchPlanFragmenter`].
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct Query {
    /// Query id should always be unique.
    pub query_id: QueryId,
    /// The stages of this query and the edges connecting them.
    pub stage_graph: StageGraph,
}
263
264impl Query {
265    pub fn leaf_stages(&self) -> Vec<StageId> {
266        let mut ret_leaf_stages = Vec::new();
267        for stage_id in self.stage_graph.stages.keys() {
268            if self
269                .stage_graph
270                .get_child_stages_unchecked(stage_id)
271                .is_empty()
272            {
273                ret_leaf_stages.push(*stage_id);
274            }
275        }
276        ret_leaf_stages
277    }
278
279    pub fn get_parents(&self, stage_id: &StageId) -> &HashSet<StageId> {
280        self.stage_graph.parent_edges.get(stage_id).unwrap()
281    }
282
283    pub fn root_stage_id(&self) -> StageId {
284        self.stage_graph.root_stage_id
285    }
286
287    pub fn query_id(&self) -> &QueryId {
288        &self.query_id
289    }
290
291    pub fn stages_with_table_scan(&self) -> HashSet<StageId> {
292        self.stage_graph
293            .stages
294            .iter()
295            .filter_map(|(stage_id, stage_query)| {
296                if stage_query.has_table_scan() {
297                    Some(*stage_id)
298                } else {
299                    None
300                }
301            })
302            .collect()
303    }
304
305    pub fn has_lookup_join_stage(&self) -> bool {
306        self.stage_graph
307            .stages
308            .iter()
309            .any(|(_stage_id, stage_query)| stage_query.has_lookup_join())
310    }
311
312    pub fn stage(&self, stage_id: StageId) -> &QueryStage {
313        &self.stage_graph.stages[&stage_id]
314    }
315}
316
/// Connector-specific parameters used when listing the splits of a source.
#[derive(Debug, Clone)]
pub enum SourceFetchParameters {
    /// Optional bounds forwarded to the Kafka enumerator's batch split listing.
    // NOTE(review): presumably timestamps; the unit is not visible here, confirm.
    KafkaTimebound {
        lower: Option<i64>,
        upper: Option<i64>,
    },
    /// No extra parameters.
    Empty,
}
325
/// Source data that is already known but not yet divided into splits.
#[derive(Debug, Clone)]
pub enum UnpartitionedData {
    /// An Iceberg file scan task whose files are later divided into splits
    /// according to the batch parallelism.
    Iceberg(IcebergFileScanTask),
}
330
/// Everything needed to list the splits of a source at scheduling time.
#[derive(Debug, Clone)]
pub struct SourceFetchInfo {
    /// Schema of the source.
    pub schema: Schema,
    /// These are user-configured connector properties.
    /// e.g. host, username, etc...
    pub connector: ConnectorProperties,
    /// These parameters are internally derived by the plan node.
    /// e.g. predicate pushdown for iceberg, timebound for kafka.
    pub fetch_parameters: SourceFetchParameters,
}
341
/// Split information of a source scan, possibly not resolved yet.
#[derive(Clone)]
pub enum SourceScanInfo {
    /// Splits are not listed yet; holds what is needed to fetch them.
    Incomplete(SourceFetchInfo),
    /// Data is known but still has to be divided into splits.
    Unpartitioned(UnpartitionedData),
    /// Splits are fully resolved.
    Complete(Vec<SplitImpl>),
}
349
350impl SourceScanInfo {
351    pub fn new(fetch_info: SourceFetchInfo) -> Self {
352        Self::Incomplete(fetch_info)
353    }
354
355    pub async fn complete(self, batch_parallelism: usize) -> SchedulerResult<Self> {
356        match self {
357            SourceScanInfo::Incomplete(fetch_info) => fetch_info.complete(batch_parallelism).await,
358            SourceScanInfo::Unpartitioned(data) => data.complete(batch_parallelism),
359            SourceScanInfo::Complete(_) => {
360                unreachable!("Never call complete when SourceScanInfo is already complete")
361            }
362        }
363    }
364
365    pub fn split_info(&self) -> SchedulerResult<&Vec<SplitImpl>> {
366        match self {
367            Self::Incomplete(_) => Err(SchedulerError::Internal(anyhow!(
368                "Should not get split info from incomplete source scan info"
369            ))),
370            Self::Unpartitioned(_) => Err(SchedulerError::Internal(anyhow!(
371                "Should not get split info from unpartitioned source scan info"
372            ))),
373            Self::Complete(split_info) => Ok(split_info),
374        }
375    }
376}
377
378impl UnpartitionedData {
379    fn complete(self, batch_parallelism: usize) -> SchedulerResult<SourceScanInfo> {
380        macro_rules! split_iceberg_tasks {
381            ($tasks:expr, $variant:ident) => {
382                IcebergSplitEnumerator::split_n_vecs($tasks, batch_parallelism)
383                    .into_iter()
384                    .enumerate()
385                    .map(|(id, tasks)| {
386                        SplitImpl::Iceberg(IcebergSplit {
387                            split_id: id.try_into().unwrap(),
388                            task: IcebergFileScanTask::$variant(tasks),
389                        })
390                    })
391                    .collect()
392            };
393        }
394
395        let splits = match self {
396            UnpartitionedData::Iceberg(task) => match task {
397                IcebergFileScanTask::Data(tasks) => split_iceberg_tasks!(tasks, Data),
398                IcebergFileScanTask::EqualityDelete(tasks) => {
399                    split_iceberg_tasks!(tasks, EqualityDelete)
400                }
401                IcebergFileScanTask::PositionDelete(tasks) => {
402                    split_iceberg_tasks!(tasks, PositionDelete)
403                }
404            },
405        };
406        Ok(SourceScanInfo::Complete(splits))
407    }
408}
409
410impl SourceFetchInfo {
411    async fn complete(self, _batch_parallelism: usize) -> SchedulerResult<SourceScanInfo> {
412        match (self.connector, self.fetch_parameters) {
413            (
414                ConnectorProperties::Kafka(prop),
415                SourceFetchParameters::KafkaTimebound { lower, upper },
416            ) => {
417                let mut kafka_enumerator =
418                    KafkaSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
419                        .await?;
420                let split_info = kafka_enumerator
421                    .list_splits_batch(lower, upper)
422                    .await?
423                    .into_iter()
424                    .map(SplitImpl::Kafka)
425                    .collect_vec();
426
427                Ok(SourceScanInfo::Complete(split_info))
428            }
429            (ConnectorProperties::Datagen(prop), SourceFetchParameters::Empty) => {
430                let mut datagen_enumerator =
431                    DatagenSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
432                        .await?;
433                let split_info = datagen_enumerator.list_splits().await?;
434                let res = split_info.into_iter().map(SplitImpl::Datagen).collect_vec();
435
436                Ok(SourceScanInfo::Complete(res))
437            }
438            (ConnectorProperties::OpendalS3(prop), SourceFetchParameters::Empty) => {
439                let lister: OpendalEnumerator<OpendalS3> = OpendalEnumerator::new_s3_source(
440                    &prop.s3_properties,
441                    prop.assume_role,
442                    prop.fs_common.compression_format,
443                )?;
444                let stream = build_opendal_fs_list_for_batch(lister);
445
446                let batch_res: Vec<_> = stream.try_collect().await?;
447                let res = batch_res
448                    .into_iter()
449                    .map(SplitImpl::OpendalS3)
450                    .collect_vec();
451
452                Ok(SourceScanInfo::Complete(res))
453            }
454            (ConnectorProperties::Gcs(prop), SourceFetchParameters::Empty) => {
455                let lister: OpendalEnumerator<OpendalGcs> =
456                    OpendalEnumerator::new_gcs_source(*prop)?;
457                let stream = build_opendal_fs_list_for_batch(lister);
458                let batch_res: Vec<_> = stream.try_collect().await?;
459                let res = batch_res.into_iter().map(SplitImpl::Gcs).collect_vec();
460
461                Ok(SourceScanInfo::Complete(res))
462            }
463            (ConnectorProperties::Azblob(prop), SourceFetchParameters::Empty) => {
464                let lister: OpendalEnumerator<OpendalAzblob> =
465                    OpendalEnumerator::new_azblob_source(*prop)?;
466                let stream = build_opendal_fs_list_for_batch(lister);
467                let batch_res: Vec<_> = stream.try_collect().await?;
468                let res = batch_res.into_iter().map(SplitImpl::Azblob).collect_vec();
469
470                Ok(SourceScanInfo::Complete(res))
471            }
472            (ConnectorProperties::BatchPosixFs(prop), SourceFetchParameters::Empty) => {
473                use risingwave_connector::source::SplitEnumerator;
474                let mut enumerator = BatchPosixFsEnumerator::new(
475                    *prop,
476                    risingwave_connector::source::SourceEnumeratorContext::dummy().into(),
477                )
478                .await?;
479                let splits = enumerator.list_splits().await?;
480                let res = splits
481                    .into_iter()
482                    .map(SplitImpl::BatchPosixFs)
483                    .collect_vec();
484
485                Ok(SourceScanInfo::Complete(res))
486            }
487            (connector, _) => Err(SchedulerError::Internal(anyhow!(
488                "Unsupported to query directly from this {} source, \
489                 please create a table or streaming job from it",
490                connector.kind()
491            ))),
492        }
493    }
494}
495
/// Describes the table scanned by a stage and which of its partitions to read.
#[derive(Clone, Debug)]
pub struct TableScanInfo {
    /// The name of the table to scan.
    name: String,

    /// Indicates the table partitions to be read by scan tasks. Unnecessary partitions are already
    /// pruned.
    ///
    /// For singleton table, this field is still `Some` and only contains a single partition with
    /// full vnode bitmap, since we need to know where to schedule the singleton scan task.
    ///
    /// `None` iff the table is a system table.
    partitions: Option<HashMap<WorkerSlotId, TablePartitionInfo>>,
}
510
511impl TableScanInfo {
512    /// For normal tables, `partitions` should always be `Some`.
513    pub fn new(name: String, partitions: HashMap<WorkerSlotId, TablePartitionInfo>) -> Self {
514        Self {
515            name,
516            partitions: Some(partitions),
517        }
518    }
519
520    /// For system table, there's no partition info.
521    pub fn system_table(name: String) -> Self {
522        Self {
523            name,
524            partitions: None,
525        }
526    }
527
528    pub fn name(&self) -> &str {
529        self.name.as_ref()
530    }
531
532    pub fn partitions(&self) -> Option<&HashMap<WorkerSlotId, TablePartitionInfo>> {
533        self.partitions.as_ref()
534    }
535}
536
/// What a single table partition has to read.
#[derive(Clone, Debug)]
pub struct TablePartitionInfo {
    /// Bitmap of the vnodes belonging to this partition.
    pub vnode_bitmap: Bitmap,
    /// Scan ranges (protobuf form) for this partition.
    pub scan_ranges: Vec<ScanRangeProto>,
}
542
/// Partition information, which differs by the kind of data being read.
#[derive(Clone, Debug, EnumAsInner)]
pub enum PartitionInfo {
    /// Partition of a table.
    Table(TablePartitionInfo),
    /// Splits of a source.
    Source(Vec<SplitImpl>),
    /// A list of strings identifying files.
    // NOTE(review): presumably file locations as in `FileScanInfo`; confirm against the caller.
    File(Vec<String>),
}
549
/// Files to be read by a file scan stage.
#[derive(Clone, Debug)]
pub struct FileScanInfo {
    /// Locations of the files; the number of entries bounds the parallelism of the stage.
    pub file_location: Vec<String>,
}
554
/// Fragment part of `Query`.
#[cfg_attr(test, derive(Clone))]
pub struct QueryStage {
    /// Id of this stage.
    pub id: StageId,
    /// Root of the execution plan tree of this stage.
    pub root: ExecutionPlanNode,
    /// Exchange information of this stage, if already known.
    pub exchange_info: Option<ExchangeInfo>,
    /// Number of tasks of this stage. `None` means the stage is not complete yet.
    pub parallelism: Option<u32>,
    /// Indicates whether this stage contains a table scan node and the table's information if so.
    pub table_scan_info: Option<TableScanInfo>,
    /// Source scan information, if any; may still be unresolved.
    pub source_info: Option<SourceScanInfo>,
    /// File scan information, if any.
    pub file_scan_info: Option<FileScanInfo>,
    /// Whether this stage contains a lookup join executor.
    pub has_lookup_join: bool,
    pub dml_table_id: Option<TableId>,
    pub session_id: SessionId,
    pub batch_enable_distributed_dml: bool,

    /// Used to generate exchange information when completing source scan information.
    ///
    /// Only kept while `parallelism` is still unknown; cleared once the source
    /// information has been completed.
    children_exchange_distribution: Option<HashMap<StageId, Distribution>>,
}
574
575impl QueryStage {
576    /// If true, this stage contains table scan executor that creates
577    /// Hummock iterators to read data from table. The iterator is initialized during
578    /// the executor building process on the batch execution engine.
579    pub fn has_table_scan(&self) -> bool {
580        self.table_scan_info.is_some()
581    }
582
583    /// If true, this stage contains lookup join executor.
584    /// We need to delay epoch unpin util the end of the query.
585    pub fn has_lookup_join(&self) -> bool {
586        self.has_lookup_join
587    }
588
589    pub fn with_exchange_info(
590        self,
591        exchange_info: Option<ExchangeInfo>,
592        parallelism: Option<u32>,
593    ) -> Self {
594        if let Some(exchange_info) = exchange_info {
595            Self {
596                id: self.id,
597                root: self.root,
598                exchange_info: Some(exchange_info),
599                parallelism,
600                table_scan_info: self.table_scan_info,
601                source_info: self.source_info,
602                file_scan_info: self.file_scan_info,
603                has_lookup_join: self.has_lookup_join,
604                dml_table_id: self.dml_table_id,
605                session_id: self.session_id,
606                batch_enable_distributed_dml: self.batch_enable_distributed_dml,
607                children_exchange_distribution: self.children_exchange_distribution,
608            }
609        } else {
610            self
611        }
612    }
613
614    pub fn with_exchange_info_and_complete_source_info(
615        self,
616        exchange_info: Option<ExchangeInfo>,
617        source_info: SourceScanInfo,
618        task_parallelism: u32,
619    ) -> Self {
620        assert!(matches!(source_info, SourceScanInfo::Complete(_)));
621        let exchange_info = if let Some(exchange_info) = exchange_info {
622            Some(exchange_info)
623        } else {
624            self.exchange_info
625        };
626        Self {
627            id: self.id,
628            root: self.root,
629            exchange_info,
630            parallelism: Some(task_parallelism),
631            table_scan_info: self.table_scan_info,
632            source_info: Some(source_info),
633            file_scan_info: self.file_scan_info,
634            has_lookup_join: self.has_lookup_join,
635            dml_table_id: self.dml_table_id,
636            session_id: self.session_id,
637            batch_enable_distributed_dml: self.batch_enable_distributed_dml,
638            children_exchange_distribution: None,
639        }
640    }
641}
642
643impl Debug for QueryStage {
644    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
645        f.debug_struct("QueryStage")
646            .field("id", &self.id)
647            .field("parallelism", &self.parallelism)
648            .field("exchange_info", &self.exchange_info)
649            .field("has_table_scan", &self.has_table_scan())
650            .finish()
651    }
652}
653
654impl Serialize for QueryStage {
655    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
656    where
657        S: serde::Serializer,
658    {
659        let mut state = serializer.serialize_struct("QueryStage", 3)?;
660        state.serialize_field("root", &self.root)?;
661        state.serialize_field("parallelism", &self.parallelism)?;
662        state.serialize_field("exchange_info", &self.exchange_info)?;
663        state.end()
664    }
665}
666
/// Collects the pieces of a [`QueryStage`] while the plan is being split.
struct QueryStageBuilder {
    /// Id of the stage being built.
    id: StageId,
    /// Root of the stage's plan; starts as `None` and must be set before `finish`.
    root: Option<ExecutionPlanNode>,
    /// Number of tasks, if already known.
    parallelism: Option<u32>,
    exchange_info: Option<ExchangeInfo>,

    /// Ids of the child stages; linked to this stage in `finish`.
    children_stages: Vec<StageId>,
    /// See also [`QueryStage::table_scan_info`].
    table_scan_info: Option<TableScanInfo>,
    source_info: Option<SourceScanInfo>,
    /// Becomes [`QueryStage::file_scan_info`] in `finish`.
    file_scan_file: Option<FileScanInfo>,
    has_lookup_join: bool,
    dml_table_id: Option<TableId>,
    session_id: SessionId,
    batch_enable_distributed_dml: bool,

    /// Distribution of the exchange towards each child stage; only carried over
    /// to the stage when `parallelism` is still unknown.
    children_exchange_distribution: HashMap<StageId, Distribution>,
}
685
686impl QueryStageBuilder {
687    #[allow(clippy::too_many_arguments)]
688    fn new(
689        id: StageId,
690        parallelism: Option<u32>,
691        exchange_info: Option<ExchangeInfo>,
692        table_scan_info: Option<TableScanInfo>,
693        source_info: Option<SourceScanInfo>,
694        file_scan_file: Option<FileScanInfo>,
695        has_lookup_join: bool,
696        dml_table_id: Option<TableId>,
697        session_id: SessionId,
698        batch_enable_distributed_dml: bool,
699    ) -> Self {
700        Self {
701            id,
702            root: None,
703            parallelism,
704            exchange_info,
705            children_stages: vec![],
706            table_scan_info,
707            source_info,
708            file_scan_file,
709            has_lookup_join,
710            dml_table_id,
711            session_id,
712            batch_enable_distributed_dml,
713            children_exchange_distribution: HashMap::new(),
714        }
715    }
716
717    fn finish(self, stage_graph_builder: &mut StageGraphBuilder) -> StageId {
718        let children_exchange_distribution = if self.parallelism.is_none() {
719            Some(self.children_exchange_distribution)
720        } else {
721            None
722        };
723        let stage = QueryStage {
724            id: self.id,
725            root: self.root.unwrap(),
726            exchange_info: self.exchange_info,
727            parallelism: self.parallelism,
728            table_scan_info: self.table_scan_info,
729            source_info: self.source_info,
730            file_scan_info: self.file_scan_file,
731            has_lookup_join: self.has_lookup_join,
732            dml_table_id: self.dml_table_id,
733            session_id: self.session_id,
734            batch_enable_distributed_dml: self.batch_enable_distributed_dml,
735            children_exchange_distribution,
736        };
737
738        let stage_id = stage.id;
739        stage_graph_builder.add_node(stage);
740        for child_stage_id in self.children_stages {
741            stage_graph_builder.link_to_child(self.id, child_stage_id);
742        }
743        stage_id
744    }
745}
746
/// Maintains how the stages are connected.
#[derive(Debug, Serialize)]
#[cfg_attr(test, derive(Clone))]
pub struct StageGraph {
    /// Id of the root stage; traversals start here.
    pub root_stage_id: StageId,
    /// All stages, keyed by their id.
    pub stages: HashMap<StageId, QueryStage>,
    /// Traverse from top to down. Used when splitting the plan into stages.
    child_edges: HashMap<StageId, HashSet<StageId>>,
    /// Traverse from down to top. Used when scheduling each stage.
    parent_edges: HashMap<StageId, HashSet<StageId>>,

    /// Parallelism used when completing source and file scan stages.
    batch_parallelism: usize,
}
760
/// Information gathered for a stage while completing the stage graph.
enum StageCompleteInfo {
    /// `(exchange info, parallelism)` to apply to the stage.
    ExchangeInfo((Option<ExchangeInfo>, Option<u32>)),
    /// `(exchange info, completed source info, task parallelism)` to apply to the stage.
    ExchangeWithSourceInfo((Option<ExchangeInfo>, SourceScanInfo, u32)),
}
765
766impl StageGraph {
767    pub fn get_child_stages_unchecked(&self, stage_id: &StageId) -> &HashSet<StageId> {
768        self.child_edges.get(stage_id).unwrap()
769    }
770
771    pub fn get_child_stages(&self, stage_id: &StageId) -> Option<&HashSet<StageId>> {
772        self.child_edges.get(stage_id)
773    }
774
775    /// Returns stage ids in topology order, s.t. child stage always appears before its parent.
776    pub fn stage_ids_by_topo_order(&self) -> impl Iterator<Item = StageId> {
777        let mut stack = Vec::with_capacity(self.stages.len());
778        stack.push(self.root_stage_id);
779        let mut ret = Vec::with_capacity(self.stages.len());
780        let mut existing = HashSet::with_capacity(self.stages.len());
781
782        while let Some(s) = stack.pop() {
783            if !existing.contains(&s) {
784                ret.push(s);
785                existing.insert(s);
786                stack.extend(&self.child_edges[&s]);
787            }
788        }
789
790        ret.into_iter().rev()
791    }
792
793    async fn complete(
794        self,
795        catalog_reader: &CatalogReader,
796        worker_node_manager: &WorkerNodeSelector,
797    ) -> SchedulerResult<StageGraph> {
798        let mut complete_stages = HashMap::new();
799        self.complete_stage(
800            self.root_stage_id,
801            None,
802            &mut complete_stages,
803            catalog_reader,
804            worker_node_manager,
805        )
806        .await?;
807        let mut stages = self.stages;
808        Ok(StageGraph {
809            root_stage_id: self.root_stage_id,
810            stages: complete_stages
811                .into_iter()
812                .map(|(stage_id, info)| {
813                    let stage = stages.remove(&stage_id).expect("should exist");
814                    let stage = match info {
815                        StageCompleteInfo::ExchangeInfo((exchange_info, parallelism)) => {
816                            stage.with_exchange_info(exchange_info, parallelism)
817                        }
818                        StageCompleteInfo::ExchangeWithSourceInfo((
819                            exchange_info,
820                            source_info,
821                            parallelism,
822                        )) => stage.with_exchange_info_and_complete_source_info(
823                            exchange_info,
824                            source_info,
825                            parallelism,
826                        ),
827                    };
828                    (stage_id, stage)
829                })
830                .collect(),
831            child_edges: self.child_edges,
832            parent_edges: self.parent_edges,
833            batch_parallelism: self.batch_parallelism,
834        })
835    }
836
837    #[async_recursion]
838    async fn complete_stage(
839        &self,
840        stage_id: StageId,
841        exchange_info: Option<ExchangeInfo>,
842        complete_stages: &mut HashMap<StageId, StageCompleteInfo>,
843        catalog_reader: &CatalogReader,
844        worker_node_manager: &WorkerNodeSelector,
845    ) -> SchedulerResult<()> {
846        let stage = &self.stages[&stage_id];
847        let parallelism = if stage.parallelism.is_some() {
848            // If the stage has parallelism, it means it's a complete stage.
849            complete_stages.insert(
850                stage.id,
851                StageCompleteInfo::ExchangeInfo((exchange_info, stage.parallelism)),
852            );
853            None
854        } else if matches!(
855            stage.source_info,
856            Some(SourceScanInfo::Incomplete(_)) | Some(SourceScanInfo::Unpartitioned(_))
857        ) {
858            let complete_source_info = stage
859                .source_info
860                .as_ref()
861                .unwrap()
862                .clone()
863                .complete(self.batch_parallelism)
864                .await?;
865
866            // For batch reading file source, the number of files involved is typically large.
867            // In order to avoid generating a task for each file, the parallelism of tasks is limited here.
            // The minimum `task_parallelism` is 1. Additionally, `task_parallelism`
            // must not exceed the number of files to read. Therefore, we first take the
            // minimum of the number of files and (self.batch_parallelism / 2). If that
            // minimum is 0, we set task_parallelism to 1.
872
873            let task_parallelism = match &stage.source_info {
874                Some(SourceScanInfo::Incomplete(source_fetch_info)) => {
875                    match source_fetch_info.connector {
876                        ConnectorProperties::Gcs(_)
877                        | ConnectorProperties::OpendalS3(_)
878                        | ConnectorProperties::Azblob(_) => (min(
879                            complete_source_info.split_info().unwrap().len() as u32,
880                            (self.batch_parallelism / 2) as u32,
881                        ))
882                        .max(1),
883                        _ => complete_source_info.split_info().unwrap().len() as u32,
884                    }
885                }
886                _ => complete_source_info.split_info().unwrap().len() as u32,
887            };
888            // For file source batch read, all the files  to be read are divide into several parts to prevent the task from taking up too many resources.
889            // todo(wcy-fdu): Currently it will be divided into half of batch_parallelism groups, and this will be changed to configurable later.
890            let complete_stage_info = StageCompleteInfo::ExchangeWithSourceInfo((
891                exchange_info,
892                complete_source_info,
893                task_parallelism,
894            ));
895            complete_stages.insert(stage.id, complete_stage_info);
896            Some(task_parallelism)
897        } else {
898            assert!(stage.file_scan_info.is_some());
899            let parallelism = min(
900                self.batch_parallelism / 2,
901                stage.file_scan_info.as_ref().unwrap().file_location.len(),
902            );
903            complete_stages.insert(
904                stage.id,
905                StageCompleteInfo::ExchangeInfo((exchange_info, Some(parallelism as u32))),
906            );
907            None
908        };
909
910        for child_stage_id in self
911            .child_edges
912            .get(&stage.id)
913            .map(|edges| edges.iter())
914            .into_iter()
915            .flatten()
916        {
917            let exchange_info = if let Some(parallelism) = parallelism {
918                let exchange_distribution = stage
919                    .children_exchange_distribution
920                    .as_ref()
921                    .unwrap()
922                    .get(child_stage_id)
923                    .expect("Exchange distribution is not consistent with the stage graph");
924                Some(exchange_distribution.to_prost(
925                    parallelism,
926                    catalog_reader,
927                    worker_node_manager,
928                )?)
929            } else {
930                None
931            };
932            self.complete_stage(
933                *child_stage_id,
934                exchange_info,
935                complete_stages,
936                catalog_reader,
937                worker_node_manager,
938            )
939            .await?;
940        }
941
942        Ok(())
943    }
944
945    /// Converts the `StageGraph` into a `petgraph::graph::Graph<String, String>`.
946    pub fn to_petgraph(&self) -> Graph<String, String, Directed> {
947        let mut graph = Graph::<String, String, Directed>::new();
948
949        let mut node_indices = HashMap::new();
950
951        // Add all stages as nodes
952        for (&stage_id, stage_ref) in self.stages.iter().sorted_by_key(|(id, _)| **id) {
953            let node_label = format!("Stage {}: {:?}", stage_id, stage_ref);
954            let node_index = graph.add_node(node_label);
955            node_indices.insert(stage_id, node_index);
956        }
957
958        // Add edges between stages based on child_edges
959        for (&parent_id, children) in &self.child_edges {
960            if let Some(&parent_index) = node_indices.get(&parent_id) {
961                for &child_id in children {
962                    if let Some(&child_index) = node_indices.get(&child_id) {
963                        // Add an edge from parent to child
964                        graph.add_edge(parent_index, child_index, "".to_owned());
965                    }
966                }
967            }
968        }
969
970        graph
971    }
972}
973
974struct StageGraphBuilder {
975    stages: HashMap<StageId, QueryStage>,
976    child_edges: HashMap<StageId, HashSet<StageId>>,
977    parent_edges: HashMap<StageId, HashSet<StageId>>,
978    batch_parallelism: usize,
979}
980
981impl StageGraphBuilder {
982    pub fn new(batch_parallelism: usize) -> Self {
983        Self {
984            stages: HashMap::new(),
985            child_edges: HashMap::new(),
986            parent_edges: HashMap::new(),
987            batch_parallelism,
988        }
989    }
990
991    pub fn build(self, root_stage_id: StageId) -> StageGraph {
992        StageGraph {
993            root_stage_id,
994            stages: self.stages,
995            child_edges: self.child_edges,
996            parent_edges: self.parent_edges,
997            batch_parallelism: self.batch_parallelism,
998        }
999    }
1000
1001    /// Link parent stage and child stage. Maintain the mappings of parent -> child and child ->
1002    /// parent.
1003    pub fn link_to_child(&mut self, parent_id: StageId, child_id: StageId) {
1004        self.child_edges
1005            .get_mut(&parent_id)
1006            .unwrap()
1007            .insert(child_id);
1008        self.parent_edges
1009            .get_mut(&child_id)
1010            .unwrap()
1011            .insert(parent_id);
1012    }
1013
1014    pub fn add_node(&mut self, stage: QueryStage) {
1015        // Insert here so that left/root stages also has linkage.
1016        self.child_edges.insert(stage.id, HashSet::new());
1017        self.parent_edges.insert(stage.id, HashSet::new());
1018        self.stages.insert(stage.id, stage);
1019    }
1020}
1021
1022impl BatchPlanFragmenter {
1023    /// After split, the `stage_graph` in the framenter may has the stage with incomplete source
1024    /// info, we need to fetch the source info to complete the stage in this function.
1025    /// Why separate this two step(`split()` and `generate_complete_query()`)?
1026    /// The step of fetching source info is a async operation so that we can't do it in the split
1027    /// step.
1028    pub async fn generate_complete_query(self) -> SchedulerResult<Query> {
1029        let stage_graph = self.stage_graph.unwrap();
1030        let new_stage_graph = stage_graph
1031            .complete(&self.catalog_reader, &self.worker_node_manager)
1032            .await?;
1033        Ok(Query {
1034            query_id: self.query_id,
1035            stage_graph: new_stage_graph,
1036        })
1037    }
1038
1039    fn new_stage(
1040        &mut self,
1041        root: PlanRef,
1042        exchange_info: Option<ExchangeInfo>,
1043    ) -> SchedulerResult<StageId> {
1044        let next_stage_id = self.next_stage_id;
1045        self.next_stage_id.inc();
1046
1047        let mut table_scan_info = None;
1048        let mut source_info = None;
1049        let mut file_scan_info = None;
1050
1051        // For current implementation, we can guarantee that each stage has only one table
1052        // scan(except System table) or one source.
1053        if let Some(info) = self.collect_stage_table_scan(root.clone())? {
1054            table_scan_info = Some(info);
1055        } else if let Some(info) = Self::collect_stage_source(root.clone())? {
1056            source_info = Some(info);
1057        } else if let Some(info) = Self::collect_stage_file_scan(root.clone())? {
1058            file_scan_info = Some(info);
1059        }
1060
1061        let mut has_lookup_join = false;
1062        let parallelism = match root.distribution() {
1063            Distribution::Single => {
1064                if let Some(info) = &mut table_scan_info {
1065                    if let Some(partitions) = &mut info.partitions {
1066                        if partitions.len() != 1 {
1067                            // This is rare case, but it's possible on the internal state of the
1068                            // Source operator.
1069                            tracing::warn!(
1070                                "The stage has single distribution, but contains a scan of table `{}` with {} partitions. A single random worker will be assigned",
1071                                info.name,
1072                                partitions.len()
1073                            );
1074
1075                            *partitions = partitions
1076                                .drain()
1077                                .take(1)
1078                                .update(|(_, info)| {
1079                                    info.vnode_bitmap = Bitmap::ones(info.vnode_bitmap.len());
1080                                })
1081                                .collect();
1082                        }
1083                    } else {
1084                        // System table
1085                    }
1086                } else if source_info.is_some() {
1087                    return Err(SchedulerError::Internal(anyhow!(
1088                        "The stage has single distribution, but contains a source operator"
1089                    )));
1090                }
1091                1
1092            }
1093            _ => {
1094                if let Some(table_scan_info) = &table_scan_info {
1095                    table_scan_info
1096                        .partitions
1097                        .as_ref()
1098                        .map(|m| m.len())
1099                        .unwrap_or(1)
1100                } else if let Some(lookup_join_parallelism) =
1101                    self.collect_stage_lookup_join_parallelism(root.clone())?
1102                {
1103                    has_lookup_join = true;
1104                    lookup_join_parallelism
1105                } else if source_info.is_some() {
1106                    0
1107                } else if file_scan_info.is_some() {
1108                    1
1109                } else {
1110                    self.batch_parallelism
1111                }
1112            }
1113        };
1114        if source_info.is_none() && file_scan_info.is_none() && parallelism == 0 {
1115            return Err(BatchError::EmptyWorkerNodes.into());
1116        }
1117        let parallelism = if parallelism == 0 {
1118            None
1119        } else {
1120            Some(parallelism as u32)
1121        };
1122        let dml_table_id = Self::collect_dml_table_id(&root);
1123        let mut builder = QueryStageBuilder::new(
1124            next_stage_id,
1125            parallelism,
1126            exchange_info,
1127            table_scan_info,
1128            source_info,
1129            file_scan_info,
1130            has_lookup_join,
1131            dml_table_id,
1132            root.ctx().session_ctx().session_id(),
1133            root.ctx()
1134                .session_ctx()
1135                .config()
1136                .batch_enable_distributed_dml(),
1137        );
1138
1139        self.visit_node(root, &mut builder, None)?;
1140
1141        Ok(builder.finish(self.stage_graph_builder.as_mut().unwrap()))
1142    }
1143
1144    fn visit_node(
1145        &mut self,
1146        node: PlanRef,
1147        builder: &mut QueryStageBuilder,
1148        parent_exec_node: Option<&mut ExecutionPlanNode>,
1149    ) -> SchedulerResult<()> {
1150        match node.node_type() {
1151            BatchPlanNodeType::BatchExchange => {
1152                self.visit_exchange(node, builder, parent_exec_node)?;
1153            }
1154            _ => {
1155                let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;
1156
1157                for child in node.inputs() {
1158                    self.visit_node(child, builder, Some(&mut execution_plan_node))?;
1159                }
1160
1161                if let Some(parent) = parent_exec_node {
1162                    parent.children.push(execution_plan_node);
1163                } else {
1164                    builder.root = Some(execution_plan_node);
1165                }
1166            }
1167        }
1168        Ok(())
1169    }
1170
1171    fn visit_exchange(
1172        &mut self,
1173        node: PlanRef,
1174        builder: &mut QueryStageBuilder,
1175        parent_exec_node: Option<&mut ExecutionPlanNode>,
1176    ) -> SchedulerResult<()> {
1177        let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;
1178        let child_exchange_info = if let Some(parallelism) = builder.parallelism {
1179            Some(node.distribution().to_prost(
1180                parallelism,
1181                &self.catalog_reader,
1182                &self.worker_node_manager,
1183            )?)
1184        } else {
1185            None
1186        };
1187        let child_stage_id = self.new_stage(node.inputs()[0].clone(), child_exchange_info)?;
1188        execution_plan_node.source_stage_id = Some(child_stage_id);
1189        if builder.parallelism.is_none() {
1190            builder
1191                .children_exchange_distribution
1192                .insert(child_stage_id, node.distribution().clone());
1193        }
1194
1195        if let Some(parent) = parent_exec_node {
1196            parent.children.push(execution_plan_node);
1197        } else {
1198            builder.root = Some(execution_plan_node);
1199        }
1200
1201        builder.children_stages.push(child_stage_id);
1202        Ok(())
1203    }
1204
1205    /// Check whether this stage contains a source node.
1206    /// If so, use  `SplitEnumeratorImpl` to get the split info from exteneral source.
1207    ///
1208    /// For current implementation, we can guarantee that each stage has only one source.
1209    fn collect_stage_source(node: PlanRef) -> SchedulerResult<Option<SourceScanInfo>> {
1210        if node.node_type() == BatchPlanNodeType::BatchExchange {
1211            // Do not visit next stage.
1212            return Ok(None);
1213        }
1214
1215        if let Some(batch_kafka_node) = node.as_batch_kafka_scan() {
1216            let batch_kafka_scan: &BatchKafkaScan = batch_kafka_node;
1217            let source_catalog = batch_kafka_scan.source_catalog();
1218            if let Some(source_catalog) = source_catalog {
1219                let property =
1220                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
1221                let timestamp_bound = batch_kafka_scan.kafka_timestamp_range_value();
1222                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
1223                    schema: batch_kafka_scan.base.schema().clone(),
1224                    connector: property,
1225                    fetch_parameters: SourceFetchParameters::KafkaTimebound {
1226                        lower: timestamp_bound.0,
1227                        upper: timestamp_bound.1,
1228                    },
1229                })));
1230            }
1231        } else if let Some(batch_iceberg_scan) = node.as_batch_iceberg_scan() {
1232            let batch_iceberg_scan: &BatchIcebergScan = batch_iceberg_scan;
1233            let task = batch_iceberg_scan.task.clone();
1234            return Ok(Some(SourceScanInfo::Unpartitioned(
1235                UnpartitionedData::Iceberg(task),
1236            )));
1237        } else if let Some(source_node) = node.as_batch_source() {
1238            // TODO: use specific batch operator instead of batch source.
1239            let source_node: &BatchSource = source_node;
1240            let source_catalog = source_node.source_catalog();
1241            if let Some(source_catalog) = source_catalog {
1242                let property =
1243                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
1244                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
1245                    schema: source_node.base.schema().clone(),
1246                    connector: property,
1247                    fetch_parameters: SourceFetchParameters::Empty,
1248                })));
1249            }
1250        }
1251
1252        node.inputs()
1253            .into_iter()
1254            .find_map(|n| Self::collect_stage_source(n).transpose())
1255            .transpose()
1256    }
1257
1258    fn collect_stage_file_scan(node: PlanRef) -> SchedulerResult<Option<FileScanInfo>> {
1259        if node.node_type() == BatchPlanNodeType::BatchExchange {
1260            // Do not visit next stage.
1261            return Ok(None);
1262        }
1263
1264        if let Some(batch_file_scan) = node.as_batch_file_scan() {
1265            return Ok(Some(FileScanInfo {
1266                file_location: batch_file_scan.core.file_location(),
1267            }));
1268        }
1269
1270        node.inputs()
1271            .into_iter()
1272            .find_map(|n| Self::collect_stage_file_scan(n).transpose())
1273            .transpose()
1274    }
1275
1276    /// Check whether this stage contains a table scan node and the table's information if so.
1277    ///
1278    /// If there are multiple scan nodes in this stage, they must have the same distribution, but
1279    /// maybe different vnodes partition. We just use the same partition for all the scan nodes.
1280    fn collect_stage_table_scan(&self, node: PlanRef) -> SchedulerResult<Option<TableScanInfo>> {
1281        let build_table_scan_info = |name, table_catalog: &TableCatalog, scan_range| {
1282            let vnode_mapping = self
1283                .worker_node_manager
1284                .fragment_mapping(table_catalog.fragment_id)?;
1285            let partitions = derive_partitions(scan_range, table_catalog, &vnode_mapping)?;
1286            let info = TableScanInfo::new(name, partitions);
1287            Ok(Some(info))
1288        };
1289        if node.node_type() == BatchPlanNodeType::BatchExchange {
1290            // Do not visit next stage.
1291            return Ok(None);
1292        }
1293        if let Some(scan_node) = node.as_batch_sys_seq_scan() {
1294            let name = scan_node.core().table.name.clone();
1295            Ok(Some(TableScanInfo::system_table(name)))
1296        } else if let Some(scan_node) = node.as_batch_log_seq_scan() {
1297            build_table_scan_info(
1298                scan_node.core().table_name.clone(),
1299                &scan_node.core().table,
1300                &[],
1301            )
1302        } else if let Some(scan_node) = node.as_batch_seq_scan() {
1303            build_table_scan_info(
1304                scan_node.core().table_name().to_owned(),
1305                &scan_node.core().table_catalog,
1306                scan_node.scan_ranges(),
1307            )
1308        } else {
1309            node.inputs()
1310                .into_iter()
1311                .find_map(|n| self.collect_stage_table_scan(n).transpose())
1312                .transpose()
1313        }
1314    }
1315
1316    /// Returns the dml table id if any.
1317    fn collect_dml_table_id(node: &PlanRef) -> Option<TableId> {
1318        if node.node_type() == BatchPlanNodeType::BatchExchange {
1319            return None;
1320        }
1321        if let Some(insert) = node.as_batch_insert() {
1322            Some(insert.core.table_id)
1323        } else if let Some(update) = node.as_batch_update() {
1324            Some(update.core.table_id)
1325        } else if let Some(delete) = node.as_batch_delete() {
1326            Some(delete.core.table_id)
1327        } else {
1328            node.inputs()
1329                .into_iter()
1330                .find_map(|n| Self::collect_dml_table_id(&n))
1331        }
1332    }
1333
1334    fn collect_stage_lookup_join_parallelism(
1335        &self,
1336        node: PlanRef,
1337    ) -> SchedulerResult<Option<usize>> {
1338        if node.node_type() == BatchPlanNodeType::BatchExchange {
1339            // Do not visit next stage.
1340            return Ok(None);
1341        }
1342        if let Some(lookup_join) = node.as_batch_lookup_join() {
1343            let table_catalog = lookup_join.right_table();
1344            let vnode_mapping = self
1345                .worker_node_manager
1346                .fragment_mapping(table_catalog.fragment_id)?;
1347            let parallelism = vnode_mapping.iter().sorted().dedup().count();
1348            Ok(Some(parallelism))
1349        } else {
1350            node.inputs()
1351                .into_iter()
1352                .find_map(|n| self.collect_stage_lookup_join_parallelism(n).transpose())
1353                .transpose()
1354        }
1355    }
1356}
1357
1358/// Try to derive the partition to read from the scan range.
1359/// It can be derived if the value of the distribution key is already known.
1360fn derive_partitions(
1361    scan_ranges: &[ScanRange],
1362    table_catalog: &TableCatalog,
1363    vnode_mapping: &WorkerSlotMapping,
1364) -> SchedulerResult<HashMap<WorkerSlotId, TablePartitionInfo>> {
1365    let vnode_mapping = if table_catalog.vnode_count.value() != vnode_mapping.len() {
1366        // The vnode count mismatch occurs only in special cases where a hash-distributed fragment
1367        // contains singleton internal tables. e.g., the state table of `Source` executors.
1368        // In this case, we reduce the vnode mapping to a single vnode as only `SINGLETON_VNODE` is used.
1369        assert_eq!(
1370            table_catalog.vnode_count.value(),
1371            1,
1372            "fragment vnode count {} does not match table vnode count {}",
1373            vnode_mapping.len(),
1374            table_catalog.vnode_count.value(),
1375        );
1376        &WorkerSlotMapping::new_single(vnode_mapping.iter().next().unwrap())
1377    } else {
1378        vnode_mapping
1379    };
1380    let vnode_count = vnode_mapping.len();
1381
1382    let mut partitions: HashMap<WorkerSlotId, (BitmapBuilder, Vec<_>)> = HashMap::new();
1383
1384    if scan_ranges.is_empty() {
1385        return Ok(vnode_mapping
1386            .to_bitmaps()
1387            .into_iter()
1388            .map(|(k, vnode_bitmap)| {
1389                (
1390                    k,
1391                    TablePartitionInfo {
1392                        vnode_bitmap,
1393                        scan_ranges: vec![],
1394                    },
1395                )
1396            })
1397            .collect());
1398    }
1399
1400    let table_distribution = TableDistribution::new_from_storage_table_desc(
1401        Some(Bitmap::ones(vnode_count).into()),
1402        &table_catalog.table_desc().try_to_protobuf()?,
1403    );
1404
1405    for scan_range in scan_ranges {
1406        let vnode = scan_range.try_compute_vnode(&table_distribution);
1407        match vnode {
1408            None => {
1409                // put this scan_range to all partitions
1410                vnode_mapping.to_bitmaps().into_iter().for_each(
1411                    |(worker_slot_id, vnode_bitmap)| {
1412                        let (bitmap, scan_ranges) = partitions
1413                            .entry(worker_slot_id)
1414                            .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
1415                        vnode_bitmap
1416                            .iter()
1417                            .enumerate()
1418                            .for_each(|(vnode, b)| bitmap.set(vnode, b));
1419                        scan_ranges.push(scan_range.to_protobuf());
1420                    },
1421                );
1422            }
1423            // scan a single partition
1424            Some(vnode) => {
1425                let worker_slot_id = vnode_mapping[vnode];
1426                let (bitmap, scan_ranges) = partitions
1427                    .entry(worker_slot_id)
1428                    .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
1429                bitmap.set(vnode.to_index(), true);
1430                scan_ranges.push(scan_range.to_protobuf());
1431            }
1432        }
1433    }
1434
1435    Ok(partitions
1436        .into_iter()
1437        .map(|(k, (bitmap, scan_ranges))| {
1438            (
1439                k,
1440                TablePartitionInfo {
1441                    vnode_bitmap: bitmap.finish(),
1442                    scan_ranges,
1443                },
1444            )
1445        })
1446        .collect())
1447}
1448
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use risingwave_pb::batch_plan::plan_node::NodeBody;

    use crate::optimizer::plan_node::BatchPlanNodeType;
    use crate::scheduler::plan_fragmenter::StageId;

    /// Fragments the shared test query and checks the resulting stage graph: stage count, edges
    /// in both directions, topological order, and the plan nodes of each of the four stages.
    #[tokio::test]
    async fn test_fragmenter() {
        let query = crate::scheduler::distributed::tests::create_query().await;
        let graph = &query.stage_graph;

        assert_eq!(graph.root_stage_id, 0.into());
        assert_eq!(graph.stages.len(), 4);

        // Parent -> children: 0 -> {1}, 1 -> {2, 3}, and 2, 3 are leaves.
        assert_eq!(graph.child_edges[&0.into()], HashSet::from_iter([1.into()]));
        assert_eq!(
            graph.child_edges[&1.into()],
            HashSet::from_iter([2.into(), 3.into()])
        );
        assert_eq!(graph.child_edges[&2.into()], HashSet::new());
        assert_eq!(graph.child_edges[&3.into()], HashSet::new());

        // Child -> parents: the mirror image of the mapping above.
        assert_eq!(graph.parent_edges[&0.into()], HashSet::new());
        assert_eq!(
            graph.parent_edges[&1.into()],
            HashSet::from_iter([0.into()])
        );
        assert_eq!(
            graph.parent_edges[&2.into()],
            HashSet::from_iter([1.into()])
        );
        assert_eq!(
            graph.parent_edges[&3.into()],
            HashSet::from_iter([1.into()])
        );

        // In topological order, every stage must come after all of its children.
        {
            let stage_id_to_pos: HashMap<StageId, usize> = graph
                .stage_ids_by_topo_order()
                .enumerate()
                .map(|(pos, stage_id)| (stage_id, pos))
                .collect();

            for (stage_id, children) in graph
                .stages
                .keys()
                .map(|stage_id| (stage_id, &graph.child_edges[stage_id]))
            {
                let stage_pos = stage_id_to_pos[stage_id];
                for child_stage_id in children {
                    let child_pos = stage_id_to_pos[child_stage_id];
                    assert!(stage_pos > child_pos);
                }
            }
        }

        // Stage 0: a single-task exchange reading from stage 1.
        let root_exchange = graph.stages.get(&0.into()).unwrap();
        assert_eq!(
            root_exchange.root.node_type(),
            BatchPlanNodeType::BatchExchange
        );
        assert_eq!(root_exchange.root.source_stage_id, Some(1.into()));
        assert!(matches!(root_exchange.root.node, NodeBody::Exchange(_)));
        assert_eq!(root_exchange.parallelism, Some(1));
        assert!(!root_exchange.has_table_scan());

        // Stage 1: a hash join whose two inputs are exchanges from stages 2 and 3.
        let join_node = graph.stages.get(&1.into()).unwrap();
        assert_eq!(join_node.root.node_type(), BatchPlanNodeType::BatchHashJoin);
        assert_eq!(join_node.parallelism, Some(24));

        assert!(matches!(join_node.root.node, NodeBody::HashJoin(_)));
        assert_eq!(join_node.root.source_stage_id, None);
        assert_eq!(2, join_node.root.children.len());

        assert!(matches!(
            join_node.root.children[0].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[0].source_stage_id, Some(2.into()));
        assert_eq!(0, join_node.root.children[0].children.len());

        assert!(matches!(
            join_node.root.children[1].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[1].source_stage_id, Some(3.into()));
        assert_eq!(0, join_node.root.children[1].children.len());
        assert!(!join_node.has_table_scan());

        // Stage 2: a bare sequential scan.
        let scan_node1 = graph.stages.get(&2.into()).unwrap();
        assert_eq!(scan_node1.root.node_type(), BatchPlanNodeType::BatchSeqScan);
        assert_eq!(scan_node1.root.source_stage_id, None);
        assert_eq!(0, scan_node1.root.children.len());
        assert!(scan_node1.has_table_scan());

        // Stage 3: a filter on top of a table scan.
        let scan_node2 = graph.stages.get(&3.into()).unwrap();
        assert_eq!(scan_node2.root.node_type(), BatchPlanNodeType::BatchFilter);
        assert_eq!(scan_node2.root.source_stage_id, None);
        assert_eq!(1, scan_node2.root.children.len());
        assert!(scan_node2.has_table_scan());
    }
}