Skip to main content

risingwave_frontend/scheduler/
plan_fragmenter.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::min;
16use std::collections::{HashMap, HashSet};
17use std::fmt::{Debug, Display, Formatter};
18use std::num::NonZeroU64;
19
20use anyhow::anyhow;
21use async_recursion::async_recursion;
22use enum_as_inner::EnumAsInner;
23use futures::TryStreamExt;
24use itertools::Itertools;
25use petgraph::{Directed, Graph};
26use pgwire::pg_server::SessionId;
27use risingwave_batch::error::BatchError;
28use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
29use risingwave_common::bitmap::{Bitmap, BitmapBuilder};
30use risingwave_common::catalog::Schema;
31use risingwave_common::hash::table_distribution::TableDistribution;
32use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
33use risingwave_common::util::scan_range::ScanRange;
34use risingwave_connector::source::filesystem::opendal_source::opendal_enumerator::OpendalEnumerator;
35use risingwave_connector::source::filesystem::opendal_source::{
36    BatchPosixFsEnumerator, OpendalAzblob, OpendalGcs, OpendalS3,
37};
38use risingwave_connector::source::iceberg::{
39    IcebergFileScanTask, IcebergSplit, IcebergSplitEnumerator,
40};
41use risingwave_connector::source::kafka::KafkaSplitEnumerator;
42use risingwave_connector::source::prelude::DatagenSplitEnumerator;
43use risingwave_connector::source::reader::reader::build_opendal_fs_list_for_batch;
44use risingwave_connector::source::{
45    ConnectorProperties, SourceEnumeratorContext, SplitEnumerator, SplitImpl,
46};
47use risingwave_pb::batch_plan::plan_node::NodeBody;
48use risingwave_pb::batch_plan::{ExchangeInfo, ScanRange as ScanRangeProto};
49use risingwave_pb::plan_common::Field as PbField;
50use serde::ser::SerializeStruct;
51use serde::{Serialize, Serializer};
52use uuid::Uuid;
53
54use super::SchedulerError;
55use crate::TableCatalog;
56use crate::catalog::TableId;
57use crate::catalog::catalog_service::CatalogReader;
58use crate::optimizer::plan_node::generic::{GenericPlanRef, PhysicalPlanRef};
59use crate::optimizer::plan_node::{
60    BatchIcebergScan, BatchKafkaScan, BatchPlanNodeType, BatchPlanRef as PlanRef, BatchSource,
61    PlanNodeId,
62};
63use crate::optimizer::property::Distribution;
64use crate::scheduler::SchedulerResult;
65
/// Identifier of a batch query. Its `Display` form is `QueryId:<id>`.
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct QueryId {
    pub id: String,
}

impl std::fmt::Display for QueryId {
    /// Writes the fixed `QueryId:` prefix followed by the raw id string.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self { id } = self;
        write!(f, "QueryId:{id}")
    }
}
76
/// Identifier of a stage within a query.
///
/// Both the `Display` and the `Debug` representation print only the raw number.
#[derive(Copy, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub struct StageId(u32);

impl Display for StageId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let StageId(raw) = *self;
        write!(f, "{raw}")
    }
}

impl Debug for StageId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let StageId(raw) = *self;
        write!(f, "{raw:?}")
    }
}

impl From<StageId> for u32 {
    /// Unwraps the raw stage number.
    fn from(value: StageId) -> Self {
        let StageId(raw) = value;
        raw
    }
}

impl From<u32> for StageId {
    /// Wraps a raw stage number.
    fn from(value: u32) -> Self {
        Self(value)
    }
}
103
104impl Serialize for StageId {
105    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
106    where
107        S: Serializer,
108    {
109        self.0.serialize(serializer)
110    }
111}
112
113impl StageId {
114    pub fn inc(&mut self) {
115        self.0 += 1;
116    }
117}
118
/// Id of the only task of the root stage (the root stage always has exactly one task).
pub const ROOT_TASK_ID: u64 = 0;
/// Id of the only output of the root task (the root task has exactly one output).
pub const ROOT_TASK_OUTPUT_ID: u64 = 0;
/// Identifier of a task.
pub type TaskId = u64;
124
/// A node of an execution plan tree.
///
/// Generated by [`BatchPlanFragmenter`] and used in query execution graph.
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct ExecutionPlanNode {
    /// Id of the batch plan node this execution node was built from.
    pub plan_node_id: PlanNodeId,
    /// Type of the batch plan node.
    pub plan_node_type: BatchPlanNodeType,
    /// Protobuf body of the batch plan node.
    pub node: NodeBody,
    /// Output schema of the plan node, in protobuf form.
    pub schema: Vec<PbField>,

    /// Child nodes of this node.
    pub children: Vec<ExecutionPlanNode>,

    /// The stage id of the source of `BatchExchange`.
    /// Used to find `ExchangeSource` from scheduler when creating `PlanNode`.
    ///
    /// `None` when this node is not `BatchExchange`.
    pub source_stage_id: Option<StageId>,
}
142
143impl Serialize for ExecutionPlanNode {
144    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
145    where
146        S: serde::Serializer,
147    {
148        let mut state = serializer.serialize_struct("QueryStage", 5)?;
149        state.serialize_field("plan_node_id", &self.plan_node_id)?;
150        state.serialize_field("plan_node_type", &self.plan_node_type)?;
151        state.serialize_field("schema", &self.schema)?;
152        state.serialize_field("children", &self.children)?;
153        state.serialize_field("source_stage_id", &self.source_stage_id)?;
154        state.end()
155    }
156}
157
158impl TryFrom<PlanRef> for ExecutionPlanNode {
159    type Error = SchedulerError;
160
161    fn try_from(plan_node: PlanRef) -> Result<Self, Self::Error> {
162        Ok(Self {
163            plan_node_id: plan_node.plan_base().id(),
164            plan_node_type: plan_node.node_type(),
165            node: plan_node.try_to_batch_prost_body()?,
166            children: vec![],
167            schema: plan_node.schema().to_prost(),
168            source_stage_id: None,
169        })
170    }
171}
172
173impl ExecutionPlanNode {
174    pub fn node_type(&self) -> BatchPlanNodeType {
175        self.plan_node_type
176    }
177}
178
/// `BatchPlanFragmenter` splits a query plan into fragments.
pub struct BatchPlanFragmenter {
    /// Id of the query being fragmented; `new` generates a fresh one via `QueryId::default`.
    query_id: QueryId,
    /// Id to assign to the next created stage; starts at 0.
    next_stage_id: StageId,
    /// Worker selector; also provides the schedule-unit count that bounds `batch_parallelism`.
    worker_node_manager: WorkerNodeSelector,
    catalog_reader: CatalogReader,

    /// Effective batch parallelism: the available schedule-unit count, capped by the configured
    /// limit if one was given. Can be 0 when no serving worker is available.
    batch_parallelism: usize,

    /// Builder used while splitting the plan; taken (left as `None`) once the graph is built.
    stage_graph_builder: Option<StageGraphBuilder>,
    /// The built stage graph; `Some` once `split_into_stage` has succeeded.
    stage_graph: Option<StageGraph>,
}
191
192impl Default for QueryId {
193    fn default() -> Self {
194        Self {
195            id: Uuid::new_v4().to_string(),
196        }
197    }
198}
199
200impl BatchPlanFragmenter {
201    pub fn new(
202        worker_node_manager: WorkerNodeSelector,
203        catalog_reader: CatalogReader,
204        batch_parallelism: Option<NonZeroU64>,
205        batch_node: PlanRef,
206    ) -> SchedulerResult<Self> {
207        // if batch_parallelism is None, it means no limit, we will use the available nodes count as
208        // parallelism.
209        // if batch_parallelism is Some(num), we will use the min(num, the available
210        // nodes count) as parallelism.
211        let batch_parallelism = if let Some(num) = batch_parallelism {
212            // can be 0 if no available serving worker
213            min(
214                num.get() as usize,
215                worker_node_manager.schedule_unit_count(),
216            )
217        } else {
218            // can be 0 if no available serving worker
219            worker_node_manager.schedule_unit_count()
220        };
221
222        let mut plan_fragmenter = Self {
223            query_id: Default::default(),
224            next_stage_id: 0.into(),
225            worker_node_manager,
226            catalog_reader,
227            batch_parallelism,
228            stage_graph_builder: Some(StageGraphBuilder::new(batch_parallelism)),
229            stage_graph: None,
230        };
231        plan_fragmenter.split_into_stage(batch_node)?;
232        Ok(plan_fragmenter)
233    }
234
235    /// Split the plan node into each stages, based on exchange node.
236    fn split_into_stage(&mut self, batch_node: PlanRef) -> SchedulerResult<()> {
237        let root_stage_id = self.new_stage(
238            batch_node,
239            Some(Distribution::Single.to_prost(
240                1,
241                &self.catalog_reader,
242                &self.worker_node_manager,
243            )?),
244        )?;
245        self.stage_graph = Some(
246            self.stage_graph_builder
247                .take()
248                .unwrap()
249                .build(root_stage_id),
250        );
251        Ok(())
252    }
253}
254
/// The fragmented query generated by [`BatchPlanFragmenter`].
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct Query {
    /// Query id should always be unique.
    pub query_id: QueryId,
    /// The stages of the query and the edges between them.
    pub stage_graph: StageGraph,
}
263
264impl Query {
265    pub fn leaf_stages(&self) -> Vec<StageId> {
266        let mut ret_leaf_stages = Vec::new();
267        for stage_id in self.stage_graph.stages.keys() {
268            if self
269                .stage_graph
270                .get_child_stages_unchecked(stage_id)
271                .is_empty()
272            {
273                ret_leaf_stages.push(*stage_id);
274            }
275        }
276        ret_leaf_stages
277    }
278
279    pub fn get_parents(&self, stage_id: &StageId) -> &HashSet<StageId> {
280        self.stage_graph.parent_edges.get(stage_id).unwrap()
281    }
282
283    pub fn root_stage_id(&self) -> StageId {
284        self.stage_graph.root_stage_id
285    }
286
287    pub fn query_id(&self) -> &QueryId {
288        &self.query_id
289    }
290
291    pub fn stages_with_table_scan(&self) -> HashSet<StageId> {
292        self.stage_graph
293            .stages
294            .iter()
295            .filter_map(|(stage_id, stage_query)| {
296                if stage_query.has_table_scan() {
297                    Some(*stage_id)
298                } else {
299                    None
300                }
301            })
302            .collect()
303    }
304
305    pub fn has_lookup_join_stage(&self) -> bool {
306        self.stage_graph
307            .stages
308            .iter()
309            .any(|(_stage_id, stage_query)| stage_query.has_lookup_join())
310    }
311
312    pub fn stage(&self, stage_id: StageId) -> &QueryStage {
313        &self.stage_graph.stages[&stage_id]
314    }
315}
316
/// Connector-specific parameters, derived from the plan node, that are used when listing the
/// splits of a source.
#[derive(Debug, Clone)]
pub enum SourceFetchParameters {
    /// Bounds passed as `lower`/`upper` to `KafkaSplitEnumerator::list_splits_batch`.
    KafkaTimebound {
        lower: Option<i64>,
        upper: Option<i64>,
    },
    /// No extra parameters.
    Empty,
}
325
/// Scan data that has not been divided into splits yet.
#[derive(Debug, Clone)]
pub enum UnpartitionedData {
    /// Iceberg file scan tasks. When `limit` is set (and the tasks are data tasks), all tasks are
    /// placed into a single split that carries the limit.
    Iceberg {
        task: IcebergFileScanTask,
        limit: Option<u64>,
    },
}
333
/// What is needed to enumerate the splits of a source when the stage is completed.
#[derive(Debug, Clone)]
pub struct SourceFetchInfo {
    /// Schema of the source.
    pub schema: Schema,
    /// These are user-configured connector properties.
    /// e.g. host, username, etc...
    pub connector: ConnectorProperties,
    /// These parameters are internally derived by the plan node.
    /// e.g. predicate pushdown for iceberg, timebound for kafka.
    pub fetch_parameters: SourceFetchParameters,
}
344
/// Source split information of a stage, which may not be resolved yet.
#[derive(Clone)]
pub enum SourceScanInfo {
    /// Splits not enumerated yet; holds what is needed to list them.
    Incomplete(SourceFetchInfo),
    /// Data not yet divided into splits (e.g. Iceberg file scan tasks).
    Unpartitioned(UnpartitionedData),
    /// Fully resolved splits.
    Complete(Vec<SplitImpl>),
}
352
353impl SourceScanInfo {
354    pub fn new(fetch_info: SourceFetchInfo) -> Self {
355        Self::Incomplete(fetch_info)
356    }
357
358    pub async fn complete(self, batch_parallelism: usize) -> SchedulerResult<Self> {
359        match self {
360            SourceScanInfo::Incomplete(fetch_info) => fetch_info.complete(batch_parallelism).await,
361            SourceScanInfo::Unpartitioned(data) => data.complete(batch_parallelism),
362            SourceScanInfo::Complete(_) => {
363                unreachable!("Never call complete when SourceScanInfo is already complete")
364            }
365        }
366    }
367
368    pub fn split_info(&self) -> SchedulerResult<&Vec<SplitImpl>> {
369        match self {
370            Self::Incomplete(_) => Err(SchedulerError::Internal(anyhow!(
371                "Should not get split info from incomplete source scan info"
372            ))),
373            Self::Unpartitioned(_) => Err(SchedulerError::Internal(anyhow!(
374                "Should not get split info from unpartitioned source scan info"
375            ))),
376            Self::Complete(split_info) => Ok(split_info),
377        }
378    }
379}
380
impl UnpartitionedData {
    /// Turns unpartitioned Iceberg scan tasks into `SourceScanInfo::Complete` splits.
    ///
    /// - `Data` tasks with a `limit`: all tasks go into a single split (id 0) that carries the
    ///   limit.
    /// - Otherwise the tasks are divided by `IcebergSplitEnumerator::split_n_vecs` according to
    ///   `batch_parallelism`, and each group becomes one split numbered by its position.
    ///
    /// The limit is only honoured for `Data` tasks. Equality-delete and position-delete tasks are
    /// always split without a limit.
    ///
    /// NOTE(review): `batch_parallelism` can be 0 when no serving worker is available (see
    /// `BatchPlanFragmenter::new`); confirm that `split_n_vecs` handles that case.
    fn complete(self, batch_parallelism: usize) -> SchedulerResult<SourceScanInfo> {
        // Builds the splits for one task kind. A macro is used so the same logic can wrap each
        // group back into whichever `IcebergFileScanTask` variant (`$variant`) it came from.
        macro_rules! iceberg_tasks {
            ($tasks:expr, $variant:ident, $limit:expr) => {
                if let Some(limit) = $limit {
                    // With a limit, every task stays in one split (presumably so the limit
                    // applies across all tasks — confirm).
                    vec![SplitImpl::Iceberg(IcebergSplit {
                        split_id: 0,
                        task: IcebergFileScanTask::$variant($tasks),
                        limit: Some(limit),
                    })]
                } else {
                    IcebergSplitEnumerator::split_n_vecs($tasks, batch_parallelism)
                        .into_iter()
                        .enumerate()
                        .map(|(id, tasks)| {
                            SplitImpl::Iceberg(IcebergSplit {
                                // Position of the group, converted to the split-id type.
                                split_id: id.try_into().unwrap(),
                                task: IcebergFileScanTask::$variant(tasks),
                                limit: None,
                            })
                        })
                        .collect()
                }
            };
        }

        let splits = match self {
            UnpartitionedData::Iceberg { task, limit } => match task {
                IcebergFileScanTask::Data(tasks) => iceberg_tasks!(tasks, Data, limit),
                // Delete tasks are never limited.
                IcebergFileScanTask::EqualityDelete(tasks) => {
                    iceberg_tasks!(tasks, EqualityDelete, None)
                }
                IcebergFileScanTask::PositionDelete(tasks) => {
                    iceberg_tasks!(tasks, PositionDelete, None)
                }
            },
        };
        Ok(SourceScanInfo::Complete(splits))
    }
}
421
422impl SourceFetchInfo {
423    async fn complete(self, _batch_parallelism: usize) -> SchedulerResult<SourceScanInfo> {
424        match (self.connector, self.fetch_parameters) {
425            (
426                ConnectorProperties::Kafka(prop),
427                SourceFetchParameters::KafkaTimebound { lower, upper },
428            ) => {
429                let mut kafka_enumerator =
430                    KafkaSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
431                        .await?;
432                let split_info = kafka_enumerator
433                    .list_splits_batch(lower, upper)
434                    .await?
435                    .into_iter()
436                    .map(SplitImpl::Kafka)
437                    .collect_vec();
438
439                Ok(SourceScanInfo::Complete(split_info))
440            }
441            (ConnectorProperties::Datagen(prop), SourceFetchParameters::Empty) => {
442                let mut datagen_enumerator =
443                    DatagenSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
444                        .await?;
445                let split_info = datagen_enumerator.list_splits().await?;
446                let res = split_info.into_iter().map(SplitImpl::Datagen).collect_vec();
447
448                Ok(SourceScanInfo::Complete(res))
449            }
450            (ConnectorProperties::OpendalS3(prop), SourceFetchParameters::Empty) => {
451                let lister: OpendalEnumerator<OpendalS3> = OpendalEnumerator::new_s3_source(
452                    &prop.s3_properties,
453                    prop.assume_role,
454                    prop.fs_common.compression_format,
455                )?;
456                let stream = build_opendal_fs_list_for_batch(lister);
457
458                let batch_res: Vec<_> = stream.try_collect().await?;
459                let res = batch_res
460                    .into_iter()
461                    .map(SplitImpl::OpendalS3)
462                    .collect_vec();
463
464                Ok(SourceScanInfo::Complete(res))
465            }
466            (ConnectorProperties::Gcs(prop), SourceFetchParameters::Empty) => {
467                let lister: OpendalEnumerator<OpendalGcs> =
468                    OpendalEnumerator::new_gcs_source(*prop)?;
469                let stream = build_opendal_fs_list_for_batch(lister);
470                let batch_res: Vec<_> = stream.try_collect().await?;
471                let res = batch_res.into_iter().map(SplitImpl::Gcs).collect_vec();
472
473                Ok(SourceScanInfo::Complete(res))
474            }
475            (ConnectorProperties::Azblob(prop), SourceFetchParameters::Empty) => {
476                let lister: OpendalEnumerator<OpendalAzblob> =
477                    OpendalEnumerator::new_azblob_source(*prop)?;
478                let stream = build_opendal_fs_list_for_batch(lister);
479                let batch_res: Vec<_> = stream.try_collect().await?;
480                let res = batch_res.into_iter().map(SplitImpl::Azblob).collect_vec();
481
482                Ok(SourceScanInfo::Complete(res))
483            }
484            (ConnectorProperties::BatchPosixFs(prop), SourceFetchParameters::Empty) => {
485                use risingwave_connector::source::SplitEnumerator;
486                let mut enumerator = BatchPosixFsEnumerator::new(
487                    *prop,
488                    risingwave_connector::source::SourceEnumeratorContext::dummy().into(),
489                )
490                .await?;
491                let splits = enumerator.list_splits().await?;
492                let res = splits
493                    .into_iter()
494                    .map(SplitImpl::BatchPosixFs)
495                    .collect_vec();
496
497                Ok(SourceScanInfo::Complete(res))
498            }
499            (connector, _) => Err(SchedulerError::Internal(anyhow!(
500                "Unsupported to query directly from this {} source, \
501                 please create a table or streaming job from it",
502                connector.kind()
503            ))),
504        }
505    }
506}
507
/// Table scan information of a stage: the scanned table's name and the partitions to read.
#[derive(Clone, Debug)]
pub struct TableScanInfo {
    /// The name of the table to scan.
    name: String,

    /// Indicates the table partitions to be read by scan tasks. Unnecessary partitions are already
    /// pruned.
    ///
    /// For singleton table, this field is still `Some` and only contains a single partition with
    /// full vnode bitmap, since we need to know where to schedule the singleton scan task.
    ///
    /// `None` iff the table is a system table.
    partitions: Option<HashMap<WorkerSlotId, TablePartitionInfo>>,
}
522
523impl TableScanInfo {
524    /// For normal tables, `partitions` should always be `Some`.
525    pub fn new(name: String, partitions: HashMap<WorkerSlotId, TablePartitionInfo>) -> Self {
526        Self {
527            name,
528            partitions: Some(partitions),
529        }
530    }
531
532    /// For system table, there's no partition info.
533    pub fn system_table(name: String) -> Self {
534        Self {
535            name,
536            partitions: None,
537        }
538    }
539
540    pub fn name(&self) -> &str {
541        self.name.as_ref()
542    }
543
544    pub fn partitions(&self) -> Option<&HashMap<WorkerSlotId, TablePartitionInfo>> {
545        self.partitions.as_ref()
546    }
547}
548
/// One partition of a table to be read by a scan task (keyed by worker slot in
/// `TableScanInfo`).
#[derive(Clone, Debug)]
pub struct TablePartitionInfo {
    /// Virtual nodes covered by this partition.
    pub vnode_bitmap: Bitmap,
    /// Scan ranges, in protobuf form, to read within this partition.
    pub scan_ranges: Vec<ScanRangeProto>,
}
554
/// A partition of input data: a table partition, a set of source splits, or a list of files.
#[derive(Clone, Debug, EnumAsInner)]
pub enum PartitionInfo {
    Table(TablePartitionInfo),
    Source(Vec<SplitImpl>),
    /// Presumably file locations, as in `FileScanInfo::file_location` — confirm.
    File(Vec<String>),
}
561
/// Files read by a file-scan stage.
#[derive(Clone, Debug)]
pub struct FileScanInfo {
    /// Locations of the files to read.
    pub file_location: Vec<String>,
}
566
/// Fragment part of `Query`.
#[cfg_attr(test, derive(Clone))]
pub struct QueryStage {
    /// Id of this stage.
    pub id: StageId,
    /// Root of this stage's execution plan tree.
    pub root: ExecutionPlanNode,
    /// Exchange info of this stage; may be filled in when the stage graph is completed.
    pub exchange_info: Option<ExchangeInfo>,
    /// Task parallelism of this stage. `None` marks a stage that still has to be completed, e.g.
    /// because its parallelism depends on the number of source splits.
    pub parallelism: Option<u32>,
    /// Indicates whether this stage contains a table scan node and the table's information if so.
    pub table_scan_info: Option<TableScanInfo>,
    /// Source scan info if this stage reads from a source; `Complete` once resolved.
    pub source_info: Option<SourceScanInfo>,
    /// Files scanned by this stage, if any.
    pub file_scan_info: Option<FileScanInfo>,
    /// Whether this stage contains a lookup join executor.
    pub has_lookup_join: bool,
    pub dml_table_id: Option<TableId>,
    pub session_id: SessionId,
    pub batch_enable_distributed_dml: bool,

    /// Used to generate exchange information when completing source scan information.
    /// Only kept (`Some`) for stages whose `parallelism` was `None` when they were built.
    children_exchange_distribution: Option<HashMap<StageId, Distribution>>,
}
586
587impl QueryStage {
588    /// If true, this stage contains table scan executor that creates
589    /// Hummock iterators to read data from table. The iterator is initialized during
590    /// the executor building process on the batch execution engine.
591    pub fn has_table_scan(&self) -> bool {
592        self.table_scan_info.is_some()
593    }
594
595    /// If true, this stage contains lookup join executor.
596    /// We need to delay epoch unpin util the end of the query.
597    pub fn has_lookup_join(&self) -> bool {
598        self.has_lookup_join
599    }
600
601    pub fn with_exchange_info(
602        self,
603        exchange_info: Option<ExchangeInfo>,
604        parallelism: Option<u32>,
605    ) -> Self {
606        if let Some(exchange_info) = exchange_info {
607            Self {
608                id: self.id,
609                root: self.root,
610                exchange_info: Some(exchange_info),
611                parallelism,
612                table_scan_info: self.table_scan_info,
613                source_info: self.source_info,
614                file_scan_info: self.file_scan_info,
615                has_lookup_join: self.has_lookup_join,
616                dml_table_id: self.dml_table_id,
617                session_id: self.session_id,
618                batch_enable_distributed_dml: self.batch_enable_distributed_dml,
619                children_exchange_distribution: self.children_exchange_distribution,
620            }
621        } else {
622            self
623        }
624    }
625
626    pub fn with_exchange_info_and_complete_source_info(
627        self,
628        exchange_info: Option<ExchangeInfo>,
629        source_info: SourceScanInfo,
630        task_parallelism: u32,
631    ) -> Self {
632        assert!(matches!(source_info, SourceScanInfo::Complete(_)));
633        let exchange_info = if let Some(exchange_info) = exchange_info {
634            Some(exchange_info)
635        } else {
636            self.exchange_info
637        };
638        Self {
639            id: self.id,
640            root: self.root,
641            exchange_info,
642            parallelism: Some(task_parallelism),
643            table_scan_info: self.table_scan_info,
644            source_info: Some(source_info),
645            file_scan_info: self.file_scan_info,
646            has_lookup_join: self.has_lookup_join,
647            dml_table_id: self.dml_table_id,
648            session_id: self.session_id,
649            batch_enable_distributed_dml: self.batch_enable_distributed_dml,
650            children_exchange_distribution: None,
651        }
652    }
653}
654
655impl Debug for QueryStage {
656    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
657        f.debug_struct("QueryStage")
658            .field("id", &self.id)
659            .field("parallelism", &self.parallelism)
660            .field("exchange_info", &self.exchange_info)
661            .field("has_table_scan", &self.has_table_scan())
662            .finish()
663    }
664}
665
666impl Serialize for QueryStage {
667    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
668    where
669        S: serde::Serializer,
670    {
671        let mut state = serializer.serialize_struct("QueryStage", 3)?;
672        state.serialize_field("root", &self.root)?;
673        state.serialize_field("parallelism", &self.parallelism)?;
674        state.serialize_field("exchange_info", &self.exchange_info)?;
675        state.end()
676    }
677}
678
/// Mutable, partially built form of a [`QueryStage`]; `finish` turns it into one.
struct QueryStageBuilder {
    id: StageId,
    /// Root of the stage's plan tree; must be set before `finish`, which unwraps it.
    root: Option<ExecutionPlanNode>,
    parallelism: Option<u32>,
    exchange_info: Option<ExchangeInfo>,

    /// Ids of the child stages; `finish` links each of them to this stage.
    children_stages: Vec<StageId>,
    /// See also [`QueryStage::table_scan_info`].
    table_scan_info: Option<TableScanInfo>,
    source_info: Option<SourceScanInfo>,
    /// Becomes `QueryStage::file_scan_info` (note the differing field name).
    file_scan_file: Option<FileScanInfo>,
    has_lookup_join: bool,
    dml_table_id: Option<TableId>,
    session_id: SessionId,
    batch_enable_distributed_dml: bool,

    /// Exchange distribution of each child stage; kept in the finished stage only when
    /// `parallelism` is `None`.
    children_exchange_distribution: HashMap<StageId, Distribution>,
}
697
698impl QueryStageBuilder {
699    fn new(
700        id: StageId,
701        parallelism: Option<u32>,
702        exchange_info: Option<ExchangeInfo>,
703        table_scan_info: Option<TableScanInfo>,
704        source_info: Option<SourceScanInfo>,
705        file_scan_file: Option<FileScanInfo>,
706        has_lookup_join: bool,
707        dml_table_id: Option<TableId>,
708        session_id: SessionId,
709        batch_enable_distributed_dml: bool,
710    ) -> Self {
711        Self {
712            id,
713            root: None,
714            parallelism,
715            exchange_info,
716            children_stages: vec![],
717            table_scan_info,
718            source_info,
719            file_scan_file,
720            has_lookup_join,
721            dml_table_id,
722            session_id,
723            batch_enable_distributed_dml,
724            children_exchange_distribution: HashMap::new(),
725        }
726    }
727
728    fn finish(self, stage_graph_builder: &mut StageGraphBuilder) -> StageId {
729        let children_exchange_distribution = if self.parallelism.is_none() {
730            Some(self.children_exchange_distribution)
731        } else {
732            None
733        };
734        let stage = QueryStage {
735            id: self.id,
736            root: self.root.unwrap(),
737            exchange_info: self.exchange_info,
738            parallelism: self.parallelism,
739            table_scan_info: self.table_scan_info,
740            source_info: self.source_info,
741            file_scan_info: self.file_scan_file,
742            has_lookup_join: self.has_lookup_join,
743            dml_table_id: self.dml_table_id,
744            session_id: self.session_id,
745            batch_enable_distributed_dml: self.batch_enable_distributed_dml,
746            children_exchange_distribution,
747        };
748
749        let stage_id = stage.id;
750        stage_graph_builder.add_node(stage);
751        for child_stage_id in self.children_stages {
752            stage_graph_builder.link_to_child(self.id, child_stage_id);
753        }
754        stage_id
755    }
756}
757
/// Maintains the stages of a query and how they are connected.
#[derive(Debug, Serialize)]
#[cfg_attr(test, derive(Clone))]
pub struct StageGraph {
    /// Id of the root stage.
    pub root_stage_id: StageId,
    /// All stages, keyed by id.
    pub stages: HashMap<StageId, QueryStage>,
    /// Parent-to-child edges (top-down traversal). Used when splitting the plan into stages.
    child_edges: HashMap<StageId, HashSet<StageId>>,
    /// Child-to-parent edges (bottom-up traversal). Used when scheduling each stage.
    parent_edges: HashMap<StageId, HashSet<StageId>>,

    /// Batch parallelism used when completing stages (e.g. when resolving source splits).
    batch_parallelism: usize,
}
771
/// Completion result for one stage, collected by `StageGraph::complete_stage` and applied in
/// `StageGraph::complete`.
enum StageCompleteInfo {
    /// Exchange info and parallelism, applied via `QueryStage::with_exchange_info`.
    ExchangeInfo((Option<ExchangeInfo>, Option<u32>)),
    /// Exchange info, completed source scan info and task parallelism, applied via
    /// `QueryStage::with_exchange_info_and_complete_source_info`.
    ExchangeWithSourceInfo((Option<ExchangeInfo>, SourceScanInfo, u32)),
}
776
777impl StageGraph {
778    pub fn get_child_stages_unchecked(&self, stage_id: &StageId) -> &HashSet<StageId> {
779        self.child_edges.get(stage_id).unwrap()
780    }
781
782    pub fn get_child_stages(&self, stage_id: &StageId) -> Option<&HashSet<StageId>> {
783        self.child_edges.get(stage_id)
784    }
785
786    /// Returns stage ids in topology order, s.t. child stage always appears before its parent.
787    pub fn stage_ids_by_topo_order(&self) -> impl Iterator<Item = StageId> {
788        let mut stack = Vec::with_capacity(self.stages.len());
789        stack.push(self.root_stage_id);
790        let mut ret = Vec::with_capacity(self.stages.len());
791        let mut existing = HashSet::with_capacity(self.stages.len());
792
793        while let Some(s) = stack.pop() {
794            if !existing.contains(&s) {
795                ret.push(s);
796                existing.insert(s);
797                stack.extend(&self.child_edges[&s]);
798            }
799        }
800
801        ret.into_iter().rev()
802    }
803
804    async fn complete(
805        self,
806        catalog_reader: &CatalogReader,
807        worker_node_manager: &WorkerNodeSelector,
808    ) -> SchedulerResult<StageGraph> {
809        let mut complete_stages = HashMap::new();
810        self.complete_stage(
811            self.root_stage_id,
812            None,
813            &mut complete_stages,
814            catalog_reader,
815            worker_node_manager,
816        )
817        .await?;
818        let mut stages = self.stages;
819        Ok(StageGraph {
820            root_stage_id: self.root_stage_id,
821            stages: complete_stages
822                .into_iter()
823                .map(|(stage_id, info)| {
824                    let stage = stages.remove(&stage_id).expect("should exist");
825                    let stage = match info {
826                        StageCompleteInfo::ExchangeInfo((exchange_info, parallelism)) => {
827                            stage.with_exchange_info(exchange_info, parallelism)
828                        }
829                        StageCompleteInfo::ExchangeWithSourceInfo((
830                            exchange_info,
831                            source_info,
832                            parallelism,
833                        )) => stage.with_exchange_info_and_complete_source_info(
834                            exchange_info,
835                            source_info,
836                            parallelism,
837                        ),
838                    };
839                    (stage_id, stage)
840                })
841                .collect(),
842            child_edges: self.child_edges,
843            parent_edges: self.parent_edges,
844            batch_parallelism: self.batch_parallelism,
845        })
846    }
847
848    #[async_recursion]
849    async fn complete_stage(
850        &self,
851        stage_id: StageId,
852        exchange_info: Option<ExchangeInfo>,
853        complete_stages: &mut HashMap<StageId, StageCompleteInfo>,
854        catalog_reader: &CatalogReader,
855        worker_node_manager: &WorkerNodeSelector,
856    ) -> SchedulerResult<()> {
857        let stage = &self.stages[&stage_id];
858        let parallelism = if stage.parallelism.is_some() {
859            // If the stage has parallelism, it means it's a complete stage.
860            complete_stages.insert(
861                stage.id,
862                StageCompleteInfo::ExchangeInfo((exchange_info, stage.parallelism)),
863            );
864            None
865        } else if matches!(
866            stage.source_info,
867            Some(SourceScanInfo::Incomplete(_)) | Some(SourceScanInfo::Unpartitioned(_))
868        ) {
869            let complete_source_info = stage
870                .source_info
871                .as_ref()
872                .unwrap()
873                .clone()
874                .complete(self.batch_parallelism)
875                .await?;
876
877            // For batch reading file source, the number of files involved is typically large.
878            // In order to avoid generating a task for each file, the parallelism of tasks is limited here.
879            // The minimum `task_parallelism` is 1. Additionally, `task_parallelism`
880            // must be greater than the number of files to read. Therefore, we first take the
881            // minimum of the number of files and (self.batch_parallelism / 2). If the number of
882            // files is 0, we set task_parallelism to 1.
883
884            let task_parallelism = match &stage.source_info {
885                Some(SourceScanInfo::Incomplete(source_fetch_info)) => {
886                    match source_fetch_info.connector {
887                        ConnectorProperties::Gcs(_)
888                        | ConnectorProperties::OpendalS3(_)
889                        | ConnectorProperties::Azblob(_) => (min(
890                            complete_source_info.split_info().unwrap().len() as u32,
891                            (self.batch_parallelism / 2) as u32,
892                        ))
893                        .max(1),
894                        _ => complete_source_info.split_info().unwrap().len() as u32,
895                    }
896                }
897                _ => complete_source_info.split_info().unwrap().len() as u32,
898            };
899            // For file source batch read, all the files  to be read are divide into several parts to prevent the task from taking up too many resources.
900            // todo(wcy-fdu): Currently it will be divided into half of batch_parallelism groups, and this will be changed to configurable later.
901            let complete_stage_info = StageCompleteInfo::ExchangeWithSourceInfo((
902                exchange_info,
903                complete_source_info,
904                task_parallelism,
905            ));
906            complete_stages.insert(stage.id, complete_stage_info);
907            Some(task_parallelism)
908        } else {
909            assert!(stage.file_scan_info.is_some());
910            let parallelism = min(
911                self.batch_parallelism / 2,
912                stage.file_scan_info.as_ref().unwrap().file_location.len(),
913            );
914            complete_stages.insert(
915                stage.id,
916                StageCompleteInfo::ExchangeInfo((exchange_info, Some(parallelism as u32))),
917            );
918            None
919        };
920
921        for child_stage_id in self
922            .child_edges
923            .get(&stage.id)
924            .map(|edges| edges.iter())
925            .into_iter()
926            .flatten()
927        {
928            let exchange_info = if let Some(parallelism) = parallelism {
929                let exchange_distribution = stage
930                    .children_exchange_distribution
931                    .as_ref()
932                    .unwrap()
933                    .get(child_stage_id)
934                    .expect("Exchange distribution is not consistent with the stage graph");
935                Some(exchange_distribution.to_prost(
936                    parallelism,
937                    catalog_reader,
938                    worker_node_manager,
939                )?)
940            } else {
941                None
942            };
943            self.complete_stage(
944                *child_stage_id,
945                exchange_info,
946                complete_stages,
947                catalog_reader,
948                worker_node_manager,
949            )
950            .await?;
951        }
952
953        Ok(())
954    }
955
956    /// Converts the `StageGraph` into a `petgraph::graph::Graph<String, String>`.
957    pub fn to_petgraph(&self) -> Graph<String, String, Directed> {
958        let mut graph = Graph::<String, String, Directed>::new();
959
960        let mut node_indices = HashMap::new();
961
962        // Add all stages as nodes
963        for (&stage_id, stage_ref) in self.stages.iter().sorted_by_key(|(id, _)| **id) {
964            let node_label = format!("Stage {}: {:?}", stage_id, stage_ref);
965            let node_index = graph.add_node(node_label);
966            node_indices.insert(stage_id, node_index);
967        }
968
969        // Add edges between stages based on child_edges
970        for (&parent_id, children) in &self.child_edges {
971            if let Some(&parent_index) = node_indices.get(&parent_id) {
972                for &child_id in children {
973                    if let Some(&child_index) = node_indices.get(&child_id) {
974                        // Add an edge from parent to child
975                        graph.add_edge(parent_index, child_index, "".to_owned());
976                    }
977                }
978            }
979        }
980
981        graph
982    }
983}
984
985struct StageGraphBuilder {
986    stages: HashMap<StageId, QueryStage>,
987    child_edges: HashMap<StageId, HashSet<StageId>>,
988    parent_edges: HashMap<StageId, HashSet<StageId>>,
989    batch_parallelism: usize,
990}
991
992impl StageGraphBuilder {
993    pub fn new(batch_parallelism: usize) -> Self {
994        Self {
995            stages: HashMap::new(),
996            child_edges: HashMap::new(),
997            parent_edges: HashMap::new(),
998            batch_parallelism,
999        }
1000    }
1001
1002    pub fn build(self, root_stage_id: StageId) -> StageGraph {
1003        StageGraph {
1004            root_stage_id,
1005            stages: self.stages,
1006            child_edges: self.child_edges,
1007            parent_edges: self.parent_edges,
1008            batch_parallelism: self.batch_parallelism,
1009        }
1010    }
1011
1012    /// Link parent stage and child stage. Maintain the mappings of parent -> child and child ->
1013    /// parent.
1014    pub fn link_to_child(&mut self, parent_id: StageId, child_id: StageId) {
1015        self.child_edges
1016            .get_mut(&parent_id)
1017            .unwrap()
1018            .insert(child_id);
1019        self.parent_edges
1020            .get_mut(&child_id)
1021            .unwrap()
1022            .insert(parent_id);
1023    }
1024
1025    pub fn add_node(&mut self, stage: QueryStage) {
1026        // Insert here so that left/root stages also has linkage.
1027        self.child_edges.insert(stage.id, HashSet::new());
1028        self.parent_edges.insert(stage.id, HashSet::new());
1029        self.stages.insert(stage.id, stage);
1030    }
1031}
1032
1033impl BatchPlanFragmenter {
1034    /// After split, the `stage_graph` in the framenter may has the stage with incomplete source
1035    /// info, we need to fetch the source info to complete the stage in this function.
1036    /// Why separate this two step(`split()` and `generate_complete_query()`)?
1037    /// The step of fetching source info is a async operation so that we can't do it in the split
1038    /// step.
1039    pub async fn generate_complete_query(self) -> SchedulerResult<Query> {
1040        let stage_graph = self.stage_graph.unwrap();
1041        let new_stage_graph = stage_graph
1042            .complete(&self.catalog_reader, &self.worker_node_manager)
1043            .await?;
1044        Ok(Query {
1045            query_id: self.query_id,
1046            stage_graph: new_stage_graph,
1047        })
1048    }
1049
1050    fn new_stage(
1051        &mut self,
1052        root: PlanRef,
1053        exchange_info: Option<ExchangeInfo>,
1054    ) -> SchedulerResult<StageId> {
1055        let next_stage_id = self.next_stage_id;
1056        self.next_stage_id.inc();
1057
1058        let mut table_scan_info = None;
1059        let mut source_info = None;
1060        let mut file_scan_info = None;
1061
1062        // For current implementation, we can guarantee that each stage has only one table
1063        // scan(except System table) or one source.
1064        if let Some(info) = self.collect_stage_table_scan(root.clone())? {
1065            table_scan_info = Some(info);
1066        } else if let Some(info) = Self::collect_stage_source(root.clone())? {
1067            source_info = Some(info);
1068        } else if let Some(info) = Self::collect_stage_file_scan(root.clone())? {
1069            file_scan_info = Some(info);
1070        }
1071
1072        let mut has_lookup_join = false;
1073        let parallelism = match root.distribution() {
1074            Distribution::Single => {
1075                if let Some(info) = &mut table_scan_info {
1076                    if let Some(partitions) = &mut info.partitions {
1077                        if partitions.len() != 1 {
1078                            // This is rare case, but it's possible on the internal state of the
1079                            // Source operator.
1080                            tracing::warn!(
1081                                "The stage has single distribution, but contains a scan of table `{}` with {} partitions. A single random worker will be assigned",
1082                                info.name,
1083                                partitions.len()
1084                            );
1085
1086                            *partitions = partitions
1087                                .drain()
1088                                .take(1)
1089                                .update(|(_, info)| {
1090                                    info.vnode_bitmap = Bitmap::ones(info.vnode_bitmap.len());
1091                                })
1092                                .collect();
1093                        }
1094                    } else {
1095                        // System table
1096                    }
1097                } else if source_info.is_some() {
1098                    return Err(SchedulerError::Internal(anyhow!(
1099                        "The stage has single distribution, but contains a source operator"
1100                    )));
1101                }
1102                1
1103            }
1104            _ => {
1105                if let Some(table_scan_info) = &table_scan_info {
1106                    table_scan_info
1107                        .partitions
1108                        .as_ref()
1109                        .map(|m| m.len())
1110                        .unwrap_or(1)
1111                } else if let Some(lookup_join_parallelism) =
1112                    self.collect_stage_lookup_join_parallelism(root.clone())?
1113                {
1114                    has_lookup_join = true;
1115                    lookup_join_parallelism
1116                } else if source_info.is_some() {
1117                    0
1118                } else if file_scan_info.is_some() {
1119                    1
1120                } else {
1121                    self.batch_parallelism
1122                }
1123            }
1124        };
1125        if source_info.is_none() && file_scan_info.is_none() && parallelism == 0 {
1126            return Err(BatchError::EmptyWorkerNodes.into());
1127        }
1128        let parallelism = if parallelism == 0 {
1129            None
1130        } else {
1131            Some(parallelism as u32)
1132        };
1133        let dml_table_id = Self::collect_dml_table_id(&root);
1134        let mut builder = QueryStageBuilder::new(
1135            next_stage_id,
1136            parallelism,
1137            exchange_info,
1138            table_scan_info,
1139            source_info,
1140            file_scan_info,
1141            has_lookup_join,
1142            dml_table_id,
1143            root.ctx().session_ctx().session_id(),
1144            root.ctx()
1145                .session_ctx()
1146                .config()
1147                .batch_enable_distributed_dml(),
1148        );
1149
1150        self.visit_node(root, &mut builder, None)?;
1151
1152        Ok(builder.finish(self.stage_graph_builder.as_mut().unwrap()))
1153    }
1154
1155    fn visit_node(
1156        &mut self,
1157        node: PlanRef,
1158        builder: &mut QueryStageBuilder,
1159        parent_exec_node: Option<&mut ExecutionPlanNode>,
1160    ) -> SchedulerResult<()> {
1161        match node.node_type() {
1162            BatchPlanNodeType::BatchExchange => {
1163                self.visit_exchange(node, builder, parent_exec_node)?;
1164            }
1165            _ => {
1166                let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;
1167
1168                for child in node.inputs() {
1169                    self.visit_node(child, builder, Some(&mut execution_plan_node))?;
1170                }
1171
1172                if let Some(parent) = parent_exec_node {
1173                    parent.children.push(execution_plan_node);
1174                } else {
1175                    builder.root = Some(execution_plan_node);
1176                }
1177            }
1178        }
1179        Ok(())
1180    }
1181
1182    fn visit_exchange(
1183        &mut self,
1184        node: PlanRef,
1185        builder: &mut QueryStageBuilder,
1186        parent_exec_node: Option<&mut ExecutionPlanNode>,
1187    ) -> SchedulerResult<()> {
1188        let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;
1189        let child_exchange_info = if let Some(parallelism) = builder.parallelism {
1190            Some(node.distribution().to_prost(
1191                parallelism,
1192                &self.catalog_reader,
1193                &self.worker_node_manager,
1194            )?)
1195        } else {
1196            None
1197        };
1198        let child_stage_id = self.new_stage(node.inputs()[0].clone(), child_exchange_info)?;
1199        execution_plan_node.source_stage_id = Some(child_stage_id);
1200        if builder.parallelism.is_none() {
1201            builder
1202                .children_exchange_distribution
1203                .insert(child_stage_id, node.distribution().clone());
1204        }
1205
1206        if let Some(parent) = parent_exec_node {
1207            parent.children.push(execution_plan_node);
1208        } else {
1209            builder.root = Some(execution_plan_node);
1210        }
1211
1212        builder.children_stages.push(child_stage_id);
1213        Ok(())
1214    }
1215
1216    /// Check whether this stage contains a source node.
1217    /// If so, use  `SplitEnumeratorImpl` to get the split info from exteneral source.
1218    ///
1219    /// For current implementation, we can guarantee that each stage has only one source.
1220    fn collect_stage_source(node: PlanRef) -> SchedulerResult<Option<SourceScanInfo>> {
1221        if node.node_type() == BatchPlanNodeType::BatchExchange {
1222            // Do not visit next stage.
1223            return Ok(None);
1224        }
1225
1226        if let Some(batch_kafka_node) = node.as_batch_kafka_scan() {
1227            let batch_kafka_scan: &BatchKafkaScan = batch_kafka_node;
1228            let source_catalog = batch_kafka_scan.source_catalog();
1229            if let Some(source_catalog) = source_catalog {
1230                let property =
1231                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
1232                let timestamp_bound = batch_kafka_scan.kafka_timestamp_range_value();
1233                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
1234                    schema: batch_kafka_scan.base.schema().clone(),
1235                    connector: property,
1236                    fetch_parameters: SourceFetchParameters::KafkaTimebound {
1237                        lower: timestamp_bound.0,
1238                        upper: timestamp_bound.1,
1239                    },
1240                })));
1241            }
1242        } else if let Some(batch_iceberg_scan) = node.as_batch_iceberg_scan() {
1243            let batch_iceberg_scan: &BatchIcebergScan = batch_iceberg_scan;
1244            let task = batch_iceberg_scan.task.clone();
1245            let limit = batch_iceberg_scan.limit();
1246            return Ok(Some(SourceScanInfo::Unpartitioned(
1247                UnpartitionedData::Iceberg { task, limit },
1248            )));
1249        } else if let Some(source_node) = node.as_batch_source() {
1250            // TODO: use specific batch operator instead of batch source.
1251            let source_node: &BatchSource = source_node;
1252            let source_catalog = source_node.source_catalog();
1253            if let Some(source_catalog) = source_catalog {
1254                let property =
1255                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
1256                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
1257                    schema: source_node.base.schema().clone(),
1258                    connector: property,
1259                    fetch_parameters: SourceFetchParameters::Empty,
1260                })));
1261            }
1262        }
1263
1264        node.inputs()
1265            .into_iter()
1266            .find_map(|n| Self::collect_stage_source(n).transpose())
1267            .transpose()
1268    }
1269
1270    fn collect_stage_file_scan(node: PlanRef) -> SchedulerResult<Option<FileScanInfo>> {
1271        if node.node_type() == BatchPlanNodeType::BatchExchange {
1272            // Do not visit next stage.
1273            return Ok(None);
1274        }
1275
1276        if let Some(batch_file_scan) = node.as_batch_file_scan() {
1277            return Ok(Some(FileScanInfo {
1278                file_location: batch_file_scan.core.file_location(),
1279            }));
1280        }
1281
1282        node.inputs()
1283            .into_iter()
1284            .find_map(|n| Self::collect_stage_file_scan(n).transpose())
1285            .transpose()
1286    }
1287
1288    /// Check whether this stage contains a table scan node and the table's information if so.
1289    ///
1290    /// If there are multiple scan nodes in this stage, they must have the same distribution, but
1291    /// maybe different vnodes partition. We just use the same partition for all the scan nodes.
1292    fn collect_stage_table_scan(&self, node: PlanRef) -> SchedulerResult<Option<TableScanInfo>> {
1293        let build_table_scan_info = |name, table_catalog: &TableCatalog, scan_range| {
1294            let vnode_mapping = self
1295                .worker_node_manager
1296                .fragment_mapping(table_catalog.fragment_id)?;
1297            let partitions = derive_partitions(scan_range, table_catalog, &vnode_mapping)?;
1298            let info = TableScanInfo::new(name, partitions);
1299            Ok(Some(info))
1300        };
1301        if node.node_type() == BatchPlanNodeType::BatchExchange {
1302            // Do not visit next stage.
1303            return Ok(None);
1304        }
1305        if let Some(scan_node) = node.as_batch_sys_seq_scan() {
1306            let name = scan_node.core().table.name.clone();
1307            Ok(Some(TableScanInfo::system_table(name)))
1308        } else if let Some(scan_node) = node.as_batch_log_seq_scan() {
1309            build_table_scan_info(
1310                scan_node.core().table_name.clone(),
1311                &scan_node.core().table,
1312                &[],
1313            )
1314        } else if let Some(scan_node) = node.as_batch_seq_scan() {
1315            build_table_scan_info(
1316                scan_node.core().table_name().to_owned(),
1317                &scan_node.core().table_catalog,
1318                scan_node.scan_ranges(),
1319            )
1320        } else {
1321            node.inputs()
1322                .into_iter()
1323                .find_map(|n| self.collect_stage_table_scan(n).transpose())
1324                .transpose()
1325        }
1326    }
1327
1328    /// Returns the dml table id if any.
1329    fn collect_dml_table_id(node: &PlanRef) -> Option<TableId> {
1330        if node.node_type() == BatchPlanNodeType::BatchExchange {
1331            return None;
1332        }
1333        if let Some(insert) = node.as_batch_insert() {
1334            Some(insert.core.table_id)
1335        } else if let Some(update) = node.as_batch_update() {
1336            Some(update.core.table_id)
1337        } else if let Some(delete) = node.as_batch_delete() {
1338            Some(delete.core.table_id)
1339        } else {
1340            node.inputs()
1341                .into_iter()
1342                .find_map(|n| Self::collect_dml_table_id(&n))
1343        }
1344    }
1345
1346    fn collect_stage_lookup_join_parallelism(
1347        &self,
1348        node: PlanRef,
1349    ) -> SchedulerResult<Option<usize>> {
1350        if node.node_type() == BatchPlanNodeType::BatchExchange {
1351            // Do not visit next stage.
1352            return Ok(None);
1353        }
1354        if let Some(lookup_join) = node.as_batch_lookup_join() {
1355            let table_catalog = lookup_join.right_table();
1356            let vnode_mapping = self
1357                .worker_node_manager
1358                .fragment_mapping(table_catalog.fragment_id)?;
1359            let parallelism = vnode_mapping.iter().sorted().dedup().count();
1360            Ok(Some(parallelism))
1361        } else {
1362            node.inputs()
1363                .into_iter()
1364                .find_map(|n| self.collect_stage_lookup_join_parallelism(n).transpose())
1365                .transpose()
1366        }
1367    }
1368}
1369
1370/// Try to derive the partition to read from the scan range.
1371/// It can be derived if the value of the distribution key is already known.
1372fn derive_partitions(
1373    scan_ranges: &[ScanRange],
1374    table_catalog: &TableCatalog,
1375    vnode_mapping: &WorkerSlotMapping,
1376) -> SchedulerResult<HashMap<WorkerSlotId, TablePartitionInfo>> {
1377    let vnode_mapping = if table_catalog.vnode_count.value() != vnode_mapping.len() {
1378        // The vnode count mismatch occurs only in special cases where a hash-distributed fragment
1379        // contains singleton internal tables. e.g., the state table of `Source` executors.
1380        // In this case, we reduce the vnode mapping to a single vnode as only `SINGLETON_VNODE` is used.
1381        assert_eq!(
1382            table_catalog.vnode_count.value(),
1383            1,
1384            "fragment vnode count {} does not match table vnode count {}",
1385            vnode_mapping.len(),
1386            table_catalog.vnode_count.value(),
1387        );
1388        &WorkerSlotMapping::new_single(vnode_mapping.iter().next().unwrap())
1389    } else {
1390        vnode_mapping
1391    };
1392    let vnode_count = vnode_mapping.len();
1393
1394    let mut partitions: HashMap<WorkerSlotId, (BitmapBuilder, Vec<_>)> = HashMap::new();
1395
1396    if scan_ranges.is_empty() {
1397        return Ok(vnode_mapping
1398            .to_bitmaps()
1399            .into_iter()
1400            .map(|(k, vnode_bitmap)| {
1401                (
1402                    k,
1403                    TablePartitionInfo {
1404                        vnode_bitmap,
1405                        scan_ranges: vec![],
1406                    },
1407                )
1408            })
1409            .collect());
1410    }
1411
1412    let table_distribution = TableDistribution::new_from_storage_table_desc(
1413        Some(Bitmap::ones(vnode_count).into()),
1414        &table_catalog.table_desc().try_to_protobuf()?,
1415    );
1416
1417    for scan_range in scan_ranges {
1418        let vnode = scan_range.try_compute_vnode(&table_distribution);
1419        match vnode {
1420            None => {
1421                // put this scan_range to all partitions
1422                vnode_mapping.to_bitmaps().into_iter().for_each(
1423                    |(worker_slot_id, vnode_bitmap)| {
1424                        let (bitmap, scan_ranges) = partitions
1425                            .entry(worker_slot_id)
1426                            .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
1427                        vnode_bitmap
1428                            .iter()
1429                            .enumerate()
1430                            .for_each(|(vnode, b)| bitmap.set(vnode, b));
1431                        scan_ranges.push(scan_range.to_protobuf());
1432                    },
1433                );
1434            }
1435            // scan a single partition
1436            Some(vnode) => {
1437                let worker_slot_id = vnode_mapping[vnode];
1438                let (bitmap, scan_ranges) = partitions
1439                    .entry(worker_slot_id)
1440                    .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
1441                bitmap.set(vnode.to_index(), true);
1442                scan_ranges.push(scan_range.to_protobuf());
1443            }
1444        }
1445    }
1446
1447    Ok(partitions
1448        .into_iter()
1449        .map(|(k, (bitmap, scan_ranges))| {
1450            (
1451                k,
1452                TablePartitionInfo {
1453                    vnode_bitmap: bitmap.finish(),
1454                    scan_ranges,
1455                },
1456            )
1457        })
1458        .collect())
1459}
1460
1461#[cfg(test)]
1462mod tests {
1463    use std::collections::{HashMap, HashSet};
1464
1465    use risingwave_pb::batch_plan::plan_node::NodeBody;
1466
1467    use crate::optimizer::plan_node::BatchPlanNodeType;
1468    use crate::scheduler::plan_fragmenter::StageId;
1469
1470    #[tokio::test]
1471    async fn test_fragmenter() {
1472        let query = crate::scheduler::distributed::tests::create_query().await;
1473
1474        assert_eq!(query.stage_graph.root_stage_id, 0.into());
1475        assert_eq!(query.stage_graph.stages.len(), 4);
1476
1477        // Check the mappings of child edges.
1478        assert_eq!(
1479            query.stage_graph.child_edges[&0.into()],
1480            HashSet::from_iter([1.into()])
1481        );
1482        assert_eq!(
1483            query.stage_graph.child_edges[&1.into()],
1484            HashSet::from_iter([2.into(), 3.into()])
1485        );
1486        assert_eq!(query.stage_graph.child_edges[&2.into()], HashSet::new());
1487        assert_eq!(query.stage_graph.child_edges[&3.into()], HashSet::new());
1488
1489        // Check the mappings of parent edges.
1490        assert_eq!(query.stage_graph.parent_edges[&0.into()], HashSet::new());
1491        assert_eq!(
1492            query.stage_graph.parent_edges[&1.into()],
1493            HashSet::from_iter([0.into()])
1494        );
1495        assert_eq!(
1496            query.stage_graph.parent_edges[&2.into()],
1497            HashSet::from_iter([1.into()])
1498        );
1499        assert_eq!(
1500            query.stage_graph.parent_edges[&3.into()],
1501            HashSet::from_iter([1.into()])
1502        );
1503
1504        // Verify topology order
1505        {
1506            let stage_id_to_pos: HashMap<StageId, usize> = query
1507                .stage_graph
1508                .stage_ids_by_topo_order()
1509                .enumerate()
1510                .map(|(pos, stage_id)| (stage_id, pos))
1511                .collect();
1512
1513            for stage_id in query.stage_graph.stages.keys() {
1514                let stage_pos = stage_id_to_pos[stage_id];
1515                for child_stage_id in &query.stage_graph.child_edges[stage_id] {
1516                    let child_pos = stage_id_to_pos[child_stage_id];
1517                    assert!(stage_pos > child_pos);
1518                }
1519            }
1520        }
1521
1522        // Check plan node in each stages.
1523        let root_exchange = query.stage_graph.stages.get(&0.into()).unwrap();
1524        assert_eq!(
1525            root_exchange.root.node_type(),
1526            BatchPlanNodeType::BatchExchange
1527        );
1528        assert_eq!(root_exchange.root.source_stage_id, Some(1.into()));
1529        assert!(matches!(root_exchange.root.node, NodeBody::Exchange(_)));
1530        assert_eq!(root_exchange.parallelism, Some(1));
1531        assert!(!root_exchange.has_table_scan());
1532
1533        let join_node = query.stage_graph.stages.get(&1.into()).unwrap();
1534        assert_eq!(join_node.root.node_type(), BatchPlanNodeType::BatchHashJoin);
1535        assert_eq!(join_node.parallelism, Some(24));
1536
1537        assert!(matches!(join_node.root.node, NodeBody::HashJoin(_)));
1538        assert_eq!(join_node.root.source_stage_id, None);
1539        assert_eq!(2, join_node.root.children.len());
1540
1541        assert!(matches!(
1542            join_node.root.children[0].node,
1543            NodeBody::Exchange(_)
1544        ));
1545        assert_eq!(join_node.root.children[0].source_stage_id, Some(2.into()));
1546        assert_eq!(0, join_node.root.children[0].children.len());
1547
1548        assert!(matches!(
1549            join_node.root.children[1].node,
1550            NodeBody::Exchange(_)
1551        ));
1552        assert_eq!(join_node.root.children[1].source_stage_id, Some(3.into()));
1553        assert_eq!(0, join_node.root.children[1].children.len());
1554        assert!(!join_node.has_table_scan());
1555
1556        let scan_node1 = query.stage_graph.stages.get(&2.into()).unwrap();
1557        assert_eq!(scan_node1.root.node_type(), BatchPlanNodeType::BatchSeqScan);
1558        assert_eq!(scan_node1.root.source_stage_id, None);
1559        assert_eq!(0, scan_node1.root.children.len());
1560        assert!(scan_node1.has_table_scan());
1561
1562        let scan_node2 = query.stage_graph.stages.get(&3.into()).unwrap();
1563        assert_eq!(scan_node2.root.node_type(), BatchPlanNodeType::BatchFilter);
1564        assert_eq!(scan_node2.root.source_stage_id, None);
1565        assert_eq!(1, scan_node2.root.children.len());
1566        assert!(scan_node2.has_table_scan());
1567    }
1568}