risingwave_frontend/scheduler/plan_fragmenter.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cmp::min;
use std::collections::{HashMap, HashSet};
use std::fmt::{Debug, Formatter};
use std::num::NonZeroU64;
use std::sync::Arc;

use anyhow::anyhow;
use async_recursion::async_recursion;
use enum_as_inner::EnumAsInner;
use futures::TryStreamExt;
use iceberg::expr::Predicate as IcebergPredicate;
use itertools::Itertools;
use petgraph::{Directed, Graph};
use pgwire::pg_server::SessionId;
use risingwave_batch::error::BatchError;
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::bitmap::{Bitmap, BitmapBuilder};
use risingwave_common::catalog::Schema;
use risingwave_common::hash::table_distribution::TableDistribution;
use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
use risingwave_common::util::scan_range::ScanRange;
use risingwave_connector::source::filesystem::opendal_source::opendal_enumerator::OpendalEnumerator;
use risingwave_connector::source::filesystem::opendal_source::{
    BatchPosixFsEnumerator, OpendalAzblob, OpendalGcs, OpendalS3,
};
use risingwave_connector::source::iceberg::IcebergSplitEnumerator;
use risingwave_connector::source::kafka::KafkaSplitEnumerator;
use risingwave_connector::source::prelude::DatagenSplitEnumerator;
use risingwave_connector::source::reader::reader::build_opendal_fs_list_for_batch;
use risingwave_connector::source::{
    ConnectorProperties, SourceEnumeratorContext, SplitEnumerator, SplitImpl,
};
use risingwave_pb::batch_plan::iceberg_scan_node::IcebergScanType;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::{ExchangeInfo, ScanRange as ScanRangeProto};
use risingwave_pb::plan_common::Field as PbField;
use risingwave_sqlparser::ast::AsOf;
use serde::Serialize;
use serde::ser::SerializeStruct;
use uuid::Uuid;

use super::SchedulerError;
use crate::TableCatalog;
use crate::catalog::TableId;
use crate::catalog::catalog_service::CatalogReader;
use crate::optimizer::plan_node::generic::{GenericPlanRef, PhysicalPlanRef};
use crate::optimizer::plan_node::utils::to_iceberg_time_travel_as_of;
use crate::optimizer::plan_node::{
    BatchIcebergScan, BatchKafkaScan, BatchPlanNodeType, BatchPlanRef as PlanRef, BatchSource,
    PlanNodeId,
};
use crate::optimizer::property::Distribution;
use crate::scheduler::SchedulerResult;

#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct QueryId {
    pub id: String,
}

impl std::fmt::Display for QueryId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "QueryId:{}", self.id)
    }
}

pub type StageId = u32;

// The root stage always has exactly one task.
pub const ROOT_TASK_ID: u64 = 0;
// The root task has exactly one output.
pub const ROOT_TASK_OUTPUT_ID: u64 = 0;
pub type TaskId = u64;

/// Generated by [`BatchPlanFragmenter`] and used in the query execution graph.
#[derive(Clone, Debug)]
pub struct ExecutionPlanNode {
    pub plan_node_id: PlanNodeId,
    pub plan_node_type: BatchPlanNodeType,
    pub node: NodeBody,
    pub schema: Vec<PbField>,

    pub children: Vec<Arc<ExecutionPlanNode>>,

    /// The stage id of the source of `BatchExchange`.
    /// Used to find `ExchangeSource` from the scheduler when creating `PlanNode`.
    ///
    /// `None` when this node is not `BatchExchange`.
    pub source_stage_id: Option<StageId>,
}

impl Serialize for ExecutionPlanNode {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut state = serializer.serialize_struct("QueryStage", 5)?;
        state.serialize_field("plan_node_id", &self.plan_node_id)?;
        state.serialize_field("plan_node_type", &self.plan_node_type)?;
        state.serialize_field("schema", &self.schema)?;
        state.serialize_field("children", &self.children)?;
        state.serialize_field("source_stage_id", &self.source_stage_id)?;
        state.end()
    }
}

impl TryFrom<PlanRef> for ExecutionPlanNode {
    type Error = SchedulerError;

    fn try_from(plan_node: PlanRef) -> Result<Self, Self::Error> {
        Ok(Self {
            plan_node_id: plan_node.plan_base().id(),
            plan_node_type: plan_node.node_type(),
            node: plan_node.try_to_batch_prost_body()?,
            children: vec![],
            schema: plan_node.schema().to_prost(),
            source_stage_id: None,
        })
    }
}

impl ExecutionPlanNode {
    pub fn node_type(&self) -> BatchPlanNodeType {
        self.plan_node_type
    }
}

/// `BatchPlanFragmenter` splits a query plan into fragments.
pub struct BatchPlanFragmenter {
    query_id: QueryId,
    next_stage_id: StageId,
    worker_node_manager: WorkerNodeSelector,
    catalog_reader: CatalogReader,

    batch_parallelism: usize,
    timezone: String,

    stage_graph_builder: Option<StageGraphBuilder>,
    stage_graph: Option<StageGraph>,
}

impl Default for QueryId {
    fn default() -> Self {
        Self {
            id: Uuid::new_v4().to_string(),
        }
    }
}

impl BatchPlanFragmenter {
    pub fn new(
        worker_node_manager: WorkerNodeSelector,
        catalog_reader: CatalogReader,
        batch_parallelism: Option<NonZeroU64>,
        timezone: String,
        batch_node: PlanRef,
    ) -> SchedulerResult<Self> {
        // If `batch_parallelism` is `None`, there is no limit and we use the available
        // node count as the parallelism.
        // If `batch_parallelism` is `Some(num)`, we use `min(num, available node count)`
        // as the parallelism.
        let batch_parallelism = if let Some(num) = batch_parallelism {
            // Can be 0 if there is no available serving worker.
            min(
                num.get() as usize,
                worker_node_manager.schedule_unit_count(),
            )
        } else {
            // Can be 0 if there is no available serving worker.
            worker_node_manager.schedule_unit_count()
        };

        let mut plan_fragmenter = Self {
            query_id: Default::default(),
            next_stage_id: 0,
            worker_node_manager,
            catalog_reader,
            batch_parallelism,
            timezone,
            stage_graph_builder: Some(StageGraphBuilder::new(batch_parallelism)),
            stage_graph: None,
        };
        plan_fragmenter.split_into_stage(batch_node)?;
        Ok(plan_fragmenter)
    }

    /// Split the plan into stages at each exchange node.
    fn split_into_stage(&mut self, batch_node: PlanRef) -> SchedulerResult<()> {
        let root_stage = self.new_stage(
            batch_node,
            Some(Distribution::Single.to_prost(
                1,
                &self.catalog_reader,
                &self.worker_node_manager,
            )?),
        )?;
        self.stage_graph = Some(
            self.stage_graph_builder
                .take()
                .unwrap()
                .build(root_stage.id),
        );
        Ok(())
    }
}
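
// A minimal usage sketch (hypothetical variable names, not from this file): fragmentation
// is a two-step process because splitting the plan is synchronous while fetching source
// split info is async, see `generate_complete_query` below.
//
// let fragmenter = BatchPlanFragmenter::new(
//     worker_node_selector,
//     catalog_reader,
//     NonZeroU64::new(4), // cap the parallelism at 4
//     "UTC".to_owned(),
//     batch_plan_root,
// )?; // eagerly splits the plan into a `StageGraph`
// let query = fragmenter.generate_complete_query().await?; // fills in source splits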

/// The fragmented query generated by [`BatchPlanFragmenter`].
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct Query {
    /// Query id should always be unique.
    pub query_id: QueryId,
    pub stage_graph: StageGraph,
}

impl Query {
    pub fn leaf_stages(&self) -> Vec<StageId> {
        let mut ret_leaf_stages = Vec::new();
        for stage_id in self.stage_graph.stages.keys() {
            if self
                .stage_graph
                .get_child_stages_unchecked(stage_id)
                .is_empty()
            {
                ret_leaf_stages.push(*stage_id);
            }
        }
        ret_leaf_stages
    }

    pub fn get_parents(&self, stage_id: &StageId) -> &HashSet<StageId> {
        self.stage_graph.parent_edges.get(stage_id).unwrap()
    }

    pub fn root_stage_id(&self) -> StageId {
        self.stage_graph.root_stage_id
    }

    pub fn query_id(&self) -> &QueryId {
        &self.query_id
    }

    pub fn stages_with_table_scan(&self) -> HashSet<StageId> {
        self.stage_graph
            .stages
            .iter()
            .filter_map(|(stage_id, stage_query)| {
                if stage_query.has_table_scan() {
                    Some(*stage_id)
                } else {
                    None
                }
            })
            .collect()
    }

    pub fn has_lookup_join_stage(&self) -> bool {
        self.stage_graph
            .stages
            .iter()
            .any(|(_stage_id, stage_query)| stage_query.has_lookup_join())
    }
}
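
// For example, in the stage graph built by the test at the bottom of this file
// (child edges 0 -> {1}, 1 -> {2, 3}), `leaf_stages()` returns stages 2 and 3 in
// arbitrary order, `get_parents(&2)` returns {1}, and `root_stage_id()` returns 0.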

#[derive(Debug, Clone)]
pub enum SourceFetchParameters {
    IcebergSpecificInfo(IcebergSpecificInfo),
    KafkaTimebound {
        lower: Option<i64>,
        upper: Option<i64>,
    },
    Empty,
}

#[derive(Debug, Clone)]
pub struct SourceFetchInfo {
    pub schema: Schema,
    /// These are user-configured connector properties,
    /// e.g. host, username, etc...
    pub connector: ConnectorProperties,
    /// These parameters are internally derived by the plan node,
    /// e.g. predicate pushdown for Iceberg, timebound for Kafka.
    pub fetch_parameters: SourceFetchParameters,
    pub as_of: Option<AsOf>,
}

#[derive(Debug, Clone)]
pub struct IcebergSpecificInfo {
    pub iceberg_scan_type: IcebergScanType,
    pub predicate: IcebergPredicate,
}

#[derive(Clone, Debug)]
pub enum SourceScanInfo {
    /// Split info has not been fetched yet.
    Incomplete(SourceFetchInfo),
    Complete(Vec<SplitImpl>),
}

impl SourceScanInfo {
    pub fn new(fetch_info: SourceFetchInfo) -> Self {
        Self::Incomplete(fetch_info)
    }

    pub async fn complete(
        self,
        batch_parallelism: usize,
        timezone: String,
    ) -> SchedulerResult<Self> {
        let fetch_info = match self {
            SourceScanInfo::Incomplete(fetch_info) => fetch_info,
            SourceScanInfo::Complete(_) => {
                unreachable!("Never call complete when SourceScanInfo is already complete")
            }
        };
        match (fetch_info.connector, fetch_info.fetch_parameters) {
            (
                ConnectorProperties::Kafka(prop),
                SourceFetchParameters::KafkaTimebound { lower, upper },
            ) => {
                let mut kafka_enumerator =
                    KafkaSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;
                let split_info = kafka_enumerator
                    .list_splits_batch(lower, upper)
                    .await?
                    .into_iter()
                    .map(SplitImpl::Kafka)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(split_info))
            }
            (ConnectorProperties::Datagen(prop), SourceFetchParameters::Empty) => {
                let mut datagen_enumerator =
                    DatagenSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;
                let split_info = datagen_enumerator.list_splits().await?;
                let res = split_info.into_iter().map(SplitImpl::Datagen).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::OpendalS3(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalS3> = OpendalEnumerator::new_s3_source(
                    &prop.s3_properties,
                    prop.assume_role,
                    prop.fs_common.compression_format,
                )?;
                let stream = build_opendal_fs_list_for_batch(lister);

                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res
                    .into_iter()
                    .map(SplitImpl::OpendalS3)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::Gcs(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalGcs> =
                    OpendalEnumerator::new_gcs_source(*prop)?;
                let stream = build_opendal_fs_list_for_batch(lister);
                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res.into_iter().map(SplitImpl::Gcs).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::Azblob(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalAzblob> =
                    OpendalEnumerator::new_azblob_source(*prop)?;
                let stream = build_opendal_fs_list_for_batch(lister);
                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res.into_iter().map(SplitImpl::Azblob).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::BatchPosixFs(prop), SourceFetchParameters::Empty) => {
                let mut enumerator = BatchPosixFsEnumerator::new(
                    *prop,
                    SourceEnumeratorContext::dummy().into(),
                )
                .await?;
                let splits = enumerator.list_splits().await?;
                let res = splits
                    .into_iter()
                    .map(SplitImpl::BatchPosixFs)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (
                ConnectorProperties::Iceberg(prop),
                SourceFetchParameters::IcebergSpecificInfo(iceberg_specific_info),
            ) => {
                let iceberg_enumerator =
                    IcebergSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;

                let time_travel_info = to_iceberg_time_travel_as_of(&fetch_info.as_of, &timezone)?;

                let split_info = iceberg_enumerator
                    .list_splits_batch(
                        fetch_info.schema,
                        time_travel_info,
                        batch_parallelism,
                        iceberg_specific_info.iceberg_scan_type,
                        iceberg_specific_info.predicate,
                    )
                    .await?
                    .into_iter()
                    .map(SplitImpl::Iceberg)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(split_info))
            }
            (connector, _) => Err(SchedulerError::Internal(anyhow!(
                "Querying directly from this {} source is not supported, \
                 please create a table or streaming job from it",
                connector.kind()
            ))),
        }
    }

    pub fn split_info(&self) -> SchedulerResult<&Vec<SplitImpl>> {
        match self {
            Self::Incomplete(_) => Err(SchedulerError::Internal(anyhow!(
                "Should not get split info from incomplete source scan info"
            ))),
            Self::Complete(split_info) => Ok(split_info),
        }
    }
}
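
// Lifecycle sketch (hypothetical variable names): a `SourceScanInfo` starts as
// `Incomplete` when the fragmenter walks the plan, and `StageGraph::complete` later
// calls `complete()` to fetch the actual splits from the external system:
//
// let info = SourceScanInfo::new(fetch_info);               // Incomplete
// let info = info.complete(parallelism, timezone).await?;   // Complete(Vec<SplitImpl>)
// let splits = info.split_info()?;                          // now accessible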

#[derive(Clone, Debug)]
pub struct TableScanInfo {
    /// The name of the table to scan.
    name: String,

    /// Indicates the table partitions to be read by scan tasks. Unnecessary partitions are
    /// already pruned.
    ///
    /// For a singleton table, this field is still `Some` and only contains a single partition
    /// with the full vnode bitmap, since we need to know where to schedule the singleton scan
    /// task.
    ///
    /// `None` iff the table is a system table.
    partitions: Option<HashMap<WorkerSlotId, TablePartitionInfo>>,
}

impl TableScanInfo {
    /// For normal tables, `partitions` should always be `Some`.
    pub fn new(name: String, partitions: HashMap<WorkerSlotId, TablePartitionInfo>) -> Self {
        Self {
            name,
            partitions: Some(partitions),
        }
    }

    /// For system tables, there's no partition info.
    pub fn system_table(name: String) -> Self {
        Self {
            name,
            partitions: None,
        }
    }

    pub fn name(&self) -> &str {
        self.name.as_ref()
    }

    pub fn partitions(&self) -> Option<&HashMap<WorkerSlotId, TablePartitionInfo>> {
        self.partitions.as_ref()
    }
}

#[derive(Clone, Debug)]
pub struct TablePartitionInfo {
    pub vnode_bitmap: Bitmap,
    pub scan_ranges: Vec<ScanRangeProto>,
}

#[derive(Clone, Debug, EnumAsInner)]
pub enum PartitionInfo {
    Table(TablePartitionInfo),
    Source(Vec<SplitImpl>),
    File(Vec<String>),
}

#[derive(Clone, Debug)]
pub struct FileScanInfo {
    pub file_location: Vec<String>,
}

/// Fragment part of `Query`.
#[derive(Clone)]
pub struct QueryStage {
    pub query_id: QueryId,
    pub id: StageId,
    pub root: Arc<ExecutionPlanNode>,
    pub exchange_info: Option<ExchangeInfo>,
    pub parallelism: Option<u32>,
    /// Indicates whether this stage contains a table scan node, and the table's information if so.
    pub table_scan_info: Option<TableScanInfo>,
    pub source_info: Option<SourceScanInfo>,
    pub file_scan_info: Option<FileScanInfo>,
    pub has_lookup_join: bool,
    pub has_vector_search: bool,
    pub dml_table_id: Option<TableId>,
    pub session_id: SessionId,
    pub batch_enable_distributed_dml: bool,

    /// Used to generate exchange information when completing the source scan information.
    children_exchange_distribution: Option<HashMap<StageId, Distribution>>,
}

impl QueryStage {
    /// If true, this stage contains a table scan executor that creates
    /// Hummock iterators to read data from the table. The iterators are initialized during
    /// the executor building process on the batch execution engine.
    pub fn has_table_scan(&self) -> bool {
        self.table_scan_info.is_some()
    }

    /// If true, this stage contains a lookup join executor.
    /// We need to delay the epoch unpin until the end of the query.
    pub fn has_lookup_join(&self) -> bool {
        self.has_lookup_join
    }

    pub fn clone_with_exchange_info(
        &self,
        exchange_info: Option<ExchangeInfo>,
        parallelism: Option<u32>,
    ) -> Self {
        if let Some(exchange_info) = exchange_info {
            return Self {
                query_id: self.query_id.clone(),
                id: self.id,
                root: self.root.clone(),
                exchange_info: Some(exchange_info),
                parallelism,
                table_scan_info: self.table_scan_info.clone(),
                source_info: self.source_info.clone(),
                file_scan_info: self.file_scan_info.clone(),
                has_lookup_join: self.has_lookup_join,
                has_vector_search: self.has_vector_search,
                dml_table_id: self.dml_table_id,
                session_id: self.session_id,
                batch_enable_distributed_dml: self.batch_enable_distributed_dml,
                children_exchange_distribution: self.children_exchange_distribution.clone(),
            };
        }
        self.clone()
    }

    pub fn clone_with_exchange_info_and_complete_source_info(
        &self,
        exchange_info: Option<ExchangeInfo>,
        source_info: SourceScanInfo,
        task_parallelism: u32,
    ) -> Self {
        assert!(matches!(source_info, SourceScanInfo::Complete(_)));
        let exchange_info = exchange_info.or_else(|| self.exchange_info.clone());
        Self {
            query_id: self.query_id.clone(),
            id: self.id,
            root: self.root.clone(),
            exchange_info,
            parallelism: Some(task_parallelism),
            table_scan_info: self.table_scan_info.clone(),
            source_info: Some(source_info),
            file_scan_info: self.file_scan_info.clone(),
            has_lookup_join: self.has_lookup_join,
            has_vector_search: self.has_vector_search,
            dml_table_id: self.dml_table_id,
            session_id: self.session_id,
            batch_enable_distributed_dml: self.batch_enable_distributed_dml,
            children_exchange_distribution: None,
        }
    }
}

impl Debug for QueryStage {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QueryStage")
            .field("id", &self.id)
            .field("parallelism", &self.parallelism)
            .field("exchange_info", &self.exchange_info)
            .field("has_table_scan", &self.has_table_scan())
            .finish()
    }
}

impl Serialize for QueryStage {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut state = serializer.serialize_struct("QueryStage", 3)?;
        state.serialize_field("root", &self.root)?;
        state.serialize_field("parallelism", &self.parallelism)?;
        state.serialize_field("exchange_info", &self.exchange_info)?;
        state.end()
    }
}

pub type QueryStageRef = Arc<QueryStage>;

struct QueryStageBuilder {
    query_id: QueryId,
    id: StageId,
    root: Option<Arc<ExecutionPlanNode>>,
    parallelism: Option<u32>,
    exchange_info: Option<ExchangeInfo>,

    children_stages: Vec<QueryStageRef>,
    /// See also [`QueryStage::table_scan_info`].
    table_scan_info: Option<TableScanInfo>,
    source_info: Option<SourceScanInfo>,
    file_scan_info: Option<FileScanInfo>,
    has_lookup_join: bool,
    has_vector_search: bool,
    dml_table_id: Option<TableId>,
    session_id: SessionId,
    batch_enable_distributed_dml: bool,

    children_exchange_distribution: HashMap<StageId, Distribution>,
}

impl QueryStageBuilder {
    #[allow(clippy::too_many_arguments)]
    fn new(
        id: StageId,
        query_id: QueryId,
        parallelism: Option<u32>,
        exchange_info: Option<ExchangeInfo>,
        table_scan_info: Option<TableScanInfo>,
        source_info: Option<SourceScanInfo>,
        file_scan_info: Option<FileScanInfo>,
        has_lookup_join: bool,
        has_vector_search: bool,
        dml_table_id: Option<TableId>,
        session_id: SessionId,
        batch_enable_distributed_dml: bool,
    ) -> Self {
        Self {
            query_id,
            id,
            root: None,
            parallelism,
            exchange_info,
            children_stages: vec![],
            table_scan_info,
            source_info,
            file_scan_info,
            has_lookup_join,
            has_vector_search,
            dml_table_id,
            session_id,
            batch_enable_distributed_dml,
            children_exchange_distribution: HashMap::new(),
        }
    }

    fn finish(self, stage_graph_builder: &mut StageGraphBuilder) -> QueryStageRef {
        let children_exchange_distribution = if self.parallelism.is_none() {
            Some(self.children_exchange_distribution)
        } else {
            None
        };
        let stage = Arc::new(QueryStage {
            query_id: self.query_id,
            id: self.id,
            root: self.root.unwrap(),
            exchange_info: self.exchange_info,
            parallelism: self.parallelism,
            table_scan_info: self.table_scan_info,
            source_info: self.source_info,
            file_scan_info: self.file_scan_info,
            has_lookup_join: self.has_lookup_join,
            has_vector_search: self.has_vector_search,
            dml_table_id: self.dml_table_id,
            session_id: self.session_id,
            batch_enable_distributed_dml: self.batch_enable_distributed_dml,
            children_exchange_distribution,
        });

        stage_graph_builder.add_node(stage.clone());
        for child_stage in self.children_stages {
            stage_graph_builder.link_to_child(self.id, child_stage.id);
        }
        stage
    }
}

/// Maintains how stages are connected.
#[derive(Debug, Serialize)]
#[cfg_attr(test, derive(Clone))]
pub struct StageGraph {
    pub root_stage_id: StageId,
    pub stages: HashMap<StageId, QueryStageRef>,
    /// Traversed from top to bottom. Used when splitting the plan into stages.
    child_edges: HashMap<StageId, HashSet<StageId>>,
    /// Traversed from bottom to top. Used when scheduling the stages.
    parent_edges: HashMap<StageId, HashSet<StageId>>,

    batch_parallelism: usize,
}

impl StageGraph {
    pub fn get_child_stages_unchecked(&self, stage_id: &StageId) -> &HashSet<StageId> {
        self.child_edges.get(stage_id).unwrap()
    }

    pub fn get_child_stages(&self, stage_id: &StageId) -> Option<&HashSet<StageId>> {
        self.child_edges.get(stage_id)
    }

    /// Returns stage ids in topological order, s.t. a child stage always appears before its parent.
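    ///
    /// For the graph built in the test at the bottom of this file (child edges
    /// 0 -> {1}, 1 -> {2, 3}), this yields e.g. `[2, 3, 1, 0]` or `[3, 2, 1, 0]`:
    /// siblings appear in arbitrary order (children are stored in a `HashSet`),
    /// but always before their parent.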
    pub fn stage_ids_by_topo_order(&self) -> impl Iterator<Item = StageId> {
        let mut stack = Vec::with_capacity(self.stages.len());
        stack.push(self.root_stage_id);
        let mut ret = Vec::with_capacity(self.stages.len());
        let mut existing = HashSet::with_capacity(self.stages.len());

        while let Some(s) = stack.pop() {
            if !existing.contains(&s) {
                ret.push(s);
                existing.insert(s);
                stack.extend(&self.child_edges[&s]);
            }
        }

        ret.into_iter().rev()
    }

    async fn complete(
        self,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
        timezone: String,
    ) -> SchedulerResult<StageGraph> {
        let mut complete_stages = HashMap::new();
        self.complete_stage(
            self.stages.get(&self.root_stage_id).unwrap().clone(),
            None,
            &mut complete_stages,
            catalog_reader,
            worker_node_manager,
            timezone,
        )
        .await?;
        Ok(StageGraph {
            root_stage_id: self.root_stage_id,
            stages: complete_stages,
            child_edges: self.child_edges,
            parent_edges: self.parent_edges,
            batch_parallelism: self.batch_parallelism,
        })
    }

    #[async_recursion]
    async fn complete_stage(
        &self,
        stage: QueryStageRef,
        exchange_info: Option<ExchangeInfo>,
        complete_stages: &mut HashMap<StageId, QueryStageRef>,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
        timezone: String,
    ) -> SchedulerResult<()> {
        let parallelism = if stage.parallelism.is_some() {
            // If the stage already has a parallelism, it is a complete stage.
            complete_stages.insert(
                stage.id,
                Arc::new(stage.clone_with_exchange_info(exchange_info, stage.parallelism)),
            );
            None
        } else if matches!(stage.source_info, Some(SourceScanInfo::Incomplete(_))) {
            let complete_source_info = stage
                .source_info
                .as_ref()
                .unwrap()
                .clone()
                .complete(self.batch_parallelism, timezone.to_owned())
                .await?;

            // For batch reads of file sources, the number of files involved is typically large.
            // To avoid generating one task per file, the task parallelism is capped here:
            // we take the minimum of the number of files and `self.batch_parallelism / 2`,
            // then clamp the result to at least 1, so that even with zero files to read,
            // one task is still scheduled.
            let task_parallelism = match &stage.source_info {
                Some(SourceScanInfo::Incomplete(source_fetch_info)) => {
                    match source_fetch_info.connector {
                        ConnectorProperties::Gcs(_)
                        | ConnectorProperties::OpendalS3(_)
                        | ConnectorProperties::Azblob(_) => (min(
                            complete_source_info.split_info().unwrap().len() as u32,
                            (self.batch_parallelism / 2) as u32,
                        ))
                        .max(1),
                        _ => complete_source_info.split_info().unwrap().len() as u32,
                    }
                }
                _ => unreachable!(),
            };
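
            // Illustrative numbers (not from the code above): with `batch_parallelism = 8`
            // and 100 S3 files, task_parallelism = min(100, 8 / 2).max(1) = 4; with 0
            // files, it is min(0, 4).max(1) = 1.
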
            // For file source batch reads, the files to be read are divided into several
            // groups to prevent a single task from taking up too many resources.
            // todo(wcy-fdu): currently the files are divided into `batch_parallelism / 2`
            // groups; this will be made configurable later.
            let complete_stage = Arc::new(stage.clone_with_exchange_info_and_complete_source_info(
                exchange_info,
                complete_source_info,
                task_parallelism,
            ));
            let parallelism = complete_stage.parallelism;
            complete_stages.insert(stage.id, complete_stage);
            parallelism
        } else {
            assert!(stage.file_scan_info.is_some());
            let parallelism = min(
                self.batch_parallelism / 2,
                stage.file_scan_info.as_ref().unwrap().file_location.len(),
            );
            complete_stages.insert(
                stage.id,
                Arc::new(stage.clone_with_exchange_info(exchange_info, Some(parallelism as u32))),
            );
            None
        };

        for child_stage_id in self.child_edges.get(&stage.id).unwrap_or(&HashSet::new()) {
            let exchange_info = if let Some(parallelism) = parallelism {
                let exchange_distribution = stage
                    .children_exchange_distribution
                    .as_ref()
                    .unwrap()
                    .get(child_stage_id)
                    .expect("Exchange distribution is not consistent with the stage graph");
                Some(exchange_distribution.to_prost(
                    parallelism,
                    catalog_reader,
                    worker_node_manager,
                )?)
            } else {
                None
            };
            self.complete_stage(
                self.stages.get(child_stage_id).unwrap().clone(),
                exchange_info,
                complete_stages,
                catalog_reader,
                worker_node_manager,
                timezone.to_owned(),
            )
            .await?;
        }

        Ok(())
    }

    /// Converts the `StageGraph` into a `petgraph::graph::Graph<String, String>`.
    pub fn to_petgraph(&self) -> Graph<String, String, Directed> {
        let mut graph = Graph::<String, String, Directed>::new();

        let mut node_indices = HashMap::new();

        // Add all stages as nodes.
        for (&stage_id, stage_ref) in self.stages.iter().sorted_by_key(|(id, _)| **id) {
            let node_label = format!("Stage {}: {:?}", stage_id, stage_ref);
            let node_index = graph.add_node(node_label);
            node_indices.insert(stage_id, node_index);
        }

        // Add edges between stages based on `child_edges`.
        for (&parent_id, children) in &self.child_edges {
            if let Some(&parent_index) = node_indices.get(&parent_id) {
                for &child_id in children {
                    if let Some(&child_index) = node_indices.get(&child_id) {
                        // Add an edge from parent to child.
                        graph.add_edge(parent_index, child_index, "".to_owned());
                    }
                }
            }
        }

        graph
    }
}
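
// A debugging sketch (uses the `petgraph` crate imported above): the graph returned by
// `to_petgraph` can be rendered as Graphviz DOT, e.g.
//
// let g = stage_graph.to_petgraph();
// println!("{}", petgraph::dot::Dot::new(&g));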

struct StageGraphBuilder {
    stages: HashMap<StageId, QueryStageRef>,
    child_edges: HashMap<StageId, HashSet<StageId>>,
    parent_edges: HashMap<StageId, HashSet<StageId>>,
    batch_parallelism: usize,
}

impl StageGraphBuilder {
    pub fn new(batch_parallelism: usize) -> Self {
        Self {
            stages: HashMap::new(),
            child_edges: HashMap::new(),
            parent_edges: HashMap::new(),
            batch_parallelism,
        }
    }

    pub fn build(self, root_stage_id: StageId) -> StageGraph {
        StageGraph {
            root_stage_id,
            stages: self.stages,
            child_edges: self.child_edges,
            parent_edges: self.parent_edges,
            batch_parallelism: self.batch_parallelism,
        }
    }

    /// Link a parent stage and a child stage. Maintains the mappings of parent -> child and
    /// child -> parent.
    pub fn link_to_child(&mut self, parent_id: StageId, child_id: StageId) {
        self.child_edges
            .get_mut(&parent_id)
            .unwrap()
            .insert(child_id);
        self.parent_edges
            .get_mut(&child_id)
            .unwrap()
            .insert(parent_id);
    }

    pub fn add_node(&mut self, stage: QueryStageRef) {
        // Insert empty edge sets here so that leaf/root stages also have linkage entries.
        self.child_edges.insert(stage.id, HashSet::new());
        self.parent_edges.insert(stage.id, HashSet::new());
        self.stages.insert(stage.id, stage);
    }
}

impl BatchPlanFragmenter {
    /// After splitting, the `stage_graph` in the fragmenter may contain stages with incomplete
    /// source info; this function fetches that source info to complete the stages.
    /// Why are the two steps (`split()` and `generate_complete_query()`) separate?
    /// Fetching source info is an async operation, so it cannot be done in the split step.
    pub async fn generate_complete_query(self) -> SchedulerResult<Query> {
        let stage_graph = self.stage_graph.unwrap();
        let new_stage_graph = stage_graph
            .complete(
                &self.catalog_reader,
                &self.worker_node_manager,
                self.timezone.to_owned(),
            )
            .await?;
        Ok(Query {
            query_id: self.query_id,
            stage_graph: new_stage_graph,
        })
    }

    fn new_stage(
        &mut self,
        root: PlanRef,
        exchange_info: Option<ExchangeInfo>,
    ) -> SchedulerResult<QueryStageRef> {
        let next_stage_id = self.next_stage_id;
        self.next_stage_id += 1;

        let mut table_scan_info = None;
        let mut source_info = None;
        let mut file_scan_info = None;
        let mut has_vector_search = false;

        // In the current implementation, each stage is guaranteed to have at most one table
        // scan (except for system tables) or one source.
        if let Some(info) = self.collect_stage_table_scan(root.clone())? {
            table_scan_info = Some(info);
        } else if let Some(info) = Self::collect_stage_source(root.clone())? {
            source_info = Some(info);
        } else if let Some(info) = Self::collect_stage_file_scan(root.clone())? {
            file_scan_info = Some(info);
        } else if Self::collect_stage_vector_search(root.clone()) {
            has_vector_search = true;
        }

        let mut has_lookup_join = false;
        let parallelism = match root.distribution() {
            Distribution::Single => {
                if let Some(info) = &mut table_scan_info {
                    if let Some(partitions) = &mut info.partitions {
                        if partitions.len() != 1 {
                            // This is a rare case, but it is possible for the internal state
                            // table of the `Source` operator.
                            tracing::warn!(
                                "The stage has single distribution, but contains a scan of table `{}` with {} partitions. A single random worker will be assigned",
                                info.name,
                                partitions.len()
                            );

                            *partitions = partitions
                                .drain()
                                .take(1)
                                .update(|(_, info)| {
                                    info.vnode_bitmap = Bitmap::ones(info.vnode_bitmap.len());
                                })
                                .collect();
                        }
                    } else {
                        // System table.
                    }
                } else if source_info.is_some() {
                    return Err(SchedulerError::Internal(anyhow!(
                        "The stage has single distribution, but contains a source operator"
                    )));
                }
                1
            }
            _ => {
                if let Some(table_scan_info) = &table_scan_info {
                    table_scan_info
                        .partitions
                        .as_ref()
                        .map(|m| m.len())
                        .unwrap_or(1)
                } else if let Some(lookup_join_parallelism) =
                    self.collect_stage_lookup_join_parallelism(root.clone())?
                {
                    has_lookup_join = true;
                    lookup_join_parallelism
                } else if source_info.is_some() {
                    0
                } else if file_scan_info.is_some() || has_vector_search {
                    1
                } else {
                    self.batch_parallelism
                }
            }
        };
        if source_info.is_none() && file_scan_info.is_none() && parallelism == 0 {
            return Err(BatchError::EmptyWorkerNodes.into());
        }
        let parallelism = if parallelism == 0 {
            None
        } else {
            Some(parallelism as u32)
        };
        let dml_table_id = Self::collect_dml_table_id(&root);
        let mut builder = QueryStageBuilder::new(
            next_stage_id,
            self.query_id.clone(),
            parallelism,
            exchange_info,
            table_scan_info,
            source_info,
            file_scan_info,
            has_lookup_join,
            has_vector_search,
            dml_table_id,
            root.ctx().session_ctx().session_id(),
            root.ctx()
                .session_ctx()
                .config()
                .batch_enable_distributed_dml(),
        );

        self.visit_node(root, &mut builder, None)?;

        Ok(builder.finish(self.stage_graph_builder.as_mut().unwrap()))
    }

    fn visit_node(
        &mut self,
        node: PlanRef,
        builder: &mut QueryStageBuilder,
        parent_exec_node: Option<&mut ExecutionPlanNode>,
    ) -> SchedulerResult<()> {
        match node.node_type() {
            BatchPlanNodeType::BatchExchange => {
                self.visit_exchange(node.clone(), builder, parent_exec_node)?;
            }
            _ => {
                let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;

                for child in node.inputs() {
                    self.visit_node(child, builder, Some(&mut execution_plan_node))?;
                }

                if let Some(parent) = parent_exec_node {
                    parent.children.push(Arc::new(execution_plan_node));
                } else {
                    builder.root = Some(Arc::new(execution_plan_node));
                }
            }
        }
        Ok(())
    }

    fn visit_exchange(
        &mut self,
        node: PlanRef,
        builder: &mut QueryStageBuilder,
        parent_exec_node: Option<&mut ExecutionPlanNode>,
    ) -> SchedulerResult<()> {
        let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;
        let child_exchange_info = if let Some(parallelism) = builder.parallelism {
            Some(node.distribution().to_prost(
                parallelism,
                &self.catalog_reader,
                &self.worker_node_manager,
            )?)
        } else {
            None
        };
        let child_stage = self.new_stage(node.inputs()[0].clone(), child_exchange_info)?;
        execution_plan_node.source_stage_id = Some(child_stage.id);
        if builder.parallelism.is_none() {
            builder
                .children_exchange_distribution
                .insert(child_stage.id, node.distribution().clone());
        }

        if let Some(parent) = parent_exec_node {
            parent.children.push(Arc::new(execution_plan_node));
        } else {
            builder.root = Some(Arc::new(execution_plan_node));
        }

        builder.children_stages.push(child_stage);
        Ok(())
    }

    /// Check whether this stage contains a source node.
    /// If so, use `SplitEnumeratorImpl` to get the split info from the external source.
    ///
    /// In the current implementation, each stage is guaranteed to have at most one source.
    fn collect_stage_source(node: PlanRef) -> SchedulerResult<Option<SourceScanInfo>> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            // Do not visit the next stage.
            return Ok(None);
        }

        if let Some(batch_kafka_node) = node.as_batch_kafka_scan() {
            let batch_kafka_scan: &BatchKafkaScan = batch_kafka_node;
            let source_catalog = batch_kafka_scan.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                let timestamp_bound = batch_kafka_scan.kafka_timestamp_range_value();
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: batch_kafka_scan.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::KafkaTimebound {
                        lower: timestamp_bound.0,
                        upper: timestamp_bound.1,
                    },
                    as_of: None,
                })));
            }
        } else if let Some(batch_iceberg_scan) = node.as_batch_iceberg_scan() {
            let batch_iceberg_scan: &BatchIcebergScan = batch_iceberg_scan;
            let source_catalog = batch_iceberg_scan.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                let as_of = batch_iceberg_scan.as_of();
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: batch_iceberg_scan.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::IcebergSpecificInfo(
                        IcebergSpecificInfo {
                            predicate: batch_iceberg_scan.predicate.clone(),
                            iceberg_scan_type: batch_iceberg_scan.iceberg_scan_type(),
                        },
                    ),
                    as_of,
                })));
            }
        } else if let Some(source_node) = node.as_batch_source() {
            // TODO: use a specific batch operator instead of batch source.
            let source_node: &BatchSource = source_node;
            let source_catalog = source_node.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                let as_of = source_node.as_of();
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: source_node.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::Empty,
                    as_of,
                })));
            }
        }

        node.inputs()
            .into_iter()
            .find_map(|n| Self::collect_stage_source(n).transpose())
            .transpose()
    }

    fn collect_stage_vector_search(node: PlanRef) -> bool {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            // Do not visit the next stage.
            return false;
        }

        if node.as_batch_vector_search().is_some() {
            return true;
        }

        node.inputs()
            .into_iter()
            .any(Self::collect_stage_vector_search)
    }

    fn collect_stage_file_scan(node: PlanRef) -> SchedulerResult<Option<FileScanInfo>> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            // Do not visit the next stage.
            return Ok(None);
        }

        if let Some(batch_file_scan) = node.as_batch_file_scan() {
            return Ok(Some(FileScanInfo {
                file_location: batch_file_scan.core.file_location().clone(),
            }));
        }

        node.inputs()
            .into_iter()
            .find_map(|n| Self::collect_stage_file_scan(n).transpose())
            .transpose()
    }

    /// Check whether this stage contains a table scan node, and return the table's information
    /// if so.
    ///
    /// If there are multiple scan nodes in this stage, they must have the same distribution, but
    /// may have different vnode partitions. We just use the same partition for all the scan nodes.
    fn collect_stage_table_scan(&self, node: PlanRef) -> SchedulerResult<Option<TableScanInfo>> {
        let build_table_scan_info = |name, table_catalog: &TableCatalog, scan_range| {
            let vnode_mapping = self
                .worker_node_manager
                .fragment_mapping(table_catalog.fragment_id)?;
            let partitions = derive_partitions(scan_range, table_catalog, &vnode_mapping)?;
            let info = TableScanInfo::new(name, partitions);
            Ok(Some(info))
        };
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            // Do not visit the next stage.
            return Ok(None);
        }
        if let Some(scan_node) = node.as_batch_sys_seq_scan() {
            let name = scan_node.core().table.name.clone();
            Ok(Some(TableScanInfo::system_table(name)))
        } else if let Some(scan_node) = node.as_batch_log_seq_scan() {
            build_table_scan_info(
                scan_node.core().table_name.to_owned(),
                &scan_node.core().table,
                &[],
            )
        } else if let Some(scan_node) = node.as_batch_seq_scan() {
            build_table_scan_info(
                scan_node.core().table_name().to_owned(),
                &scan_node.core().table_catalog,
                scan_node.scan_ranges(),
            )
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| self.collect_stage_table_scan(n).transpose())
                .transpose()
        }
    }

    /// Returns the DML table id, if any.
    fn collect_dml_table_id(node: &PlanRef) -> Option<TableId> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            return None;
        }
        if let Some(insert) = node.as_batch_insert() {
            Some(insert.core.table_id)
        } else if let Some(update) = node.as_batch_update() {
            Some(update.core.table_id)
        } else if let Some(delete) = node.as_batch_delete() {
            Some(delete.core.table_id)
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| Self::collect_dml_table_id(&n))
        }
    }

    fn collect_stage_lookup_join_parallelism(
        &self,
        node: PlanRef,
    ) -> SchedulerResult<Option<usize>> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            // Do not visit the next stage.
            return Ok(None);
        }
        if let Some(lookup_join) = node.as_batch_lookup_join() {
            let table_catalog = lookup_join.right_table();
            let vnode_mapping = self
                .worker_node_manager
                .fragment_mapping(table_catalog.fragment_id)?;
            let parallelism = vnode_mapping.iter().sorted().dedup().count();
            Ok(Some(parallelism))
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| self.collect_stage_lookup_join_parallelism(n).transpose())
                .transpose()
        }
    }
}

/// Try to derive the partitions to read from the scan ranges.
/// A partition can be derived if the value of the distribution key is already known.
fn derive_partitions(
    scan_ranges: &[ScanRange],
    table_catalog: &TableCatalog,
    vnode_mapping: &WorkerSlotMapping,
) -> SchedulerResult<HashMap<WorkerSlotId, TablePartitionInfo>> {
    let vnode_mapping = if table_catalog.vnode_count.value() != vnode_mapping.len() {
        // The vnode count mismatch occurs only in special cases where a hash-distributed fragment
        // contains singleton internal tables, e.g., the state table of `Source` executors.
        // In this case, we reduce the vnode mapping to a single vnode as only `SINGLETON_VNODE` is used.
        assert_eq!(
            table_catalog.vnode_count.value(),
            1,
            "fragment vnode count {} does not match table vnode count {}",
            vnode_mapping.len(),
            table_catalog.vnode_count.value(),
        );
        &WorkerSlotMapping::new_single(vnode_mapping.iter().next().unwrap())
    } else {
        vnode_mapping
    };
    let vnode_count = vnode_mapping.len();

    let mut partitions: HashMap<WorkerSlotId, (BitmapBuilder, Vec<_>)> = HashMap::new();

    if scan_ranges.is_empty() {
        return Ok(vnode_mapping
            .to_bitmaps()
            .into_iter()
            .map(|(k, vnode_bitmap)| {
                (
                    k,
                    TablePartitionInfo {
                        vnode_bitmap,
                        scan_ranges: vec![],
                    },
                )
            })
            .collect());
    }

    let table_distribution = TableDistribution::new_from_storage_table_desc(
        Some(Bitmap::ones(vnode_count).into()),
        &table_catalog.table_desc().try_to_protobuf()?,
    );

    for scan_range in scan_ranges {
        let vnode = scan_range.try_compute_vnode(&table_distribution);
        match vnode {
            None => {
                // Put this scan range to all partitions.
                vnode_mapping.to_bitmaps().into_iter().for_each(
                    |(worker_slot_id, vnode_bitmap)| {
                        let (bitmap, scan_ranges) = partitions
                            .entry(worker_slot_id)
                            .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
                        vnode_bitmap
                            .iter()
                            .enumerate()
                            .for_each(|(vnode, b)| bitmap.set(vnode, b));
                        scan_ranges.push(scan_range.to_protobuf());
                    },
                );
            }
            // Scan a single partition.
            Some(vnode) => {
                let worker_slot_id = vnode_mapping[vnode];
                let (bitmap, scan_ranges) = partitions
                    .entry(worker_slot_id)
                    .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
                bitmap.set(vnode.to_index(), true);
                scan_ranges.push(scan_range.to_protobuf());
            }
        }
    }

    Ok(partitions
        .into_iter()
        .map(|(k, (bitmap, scan_ranges))| {
            (
                k,
                TablePartitionInfo {
                    vnode_bitmap: bitmap.finish(),
                    scan_ranges,
                },
            )
        })
        .collect())
}
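
// A worked example (hypothetical numbers): suppose the table has 4 vnodes mapped to two
// worker slots (vnodes {0, 1} -> slot A, vnodes {2, 3} -> slot B). A scan range whose
// distribution key is fully bound and hashes to vnode 3 yields a single partition for
// slot B with only bit 3 set in its bitmap, while a range that cannot be pinned to one
// vnode is pushed to both slots with their full vnode bitmaps.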

#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use risingwave_pb::batch_plan::plan_node::NodeBody;

    use crate::optimizer::plan_node::BatchPlanNodeType;
    use crate::scheduler::plan_fragmenter::StageId;

    #[tokio::test]
    async fn test_fragmenter() {
        let query = crate::scheduler::distributed::tests::create_query().await;

        assert_eq!(query.stage_graph.root_stage_id, 0);
        assert_eq!(query.stage_graph.stages.len(), 4);

        // Check the mappings of child edges.
        assert_eq!(query.stage_graph.child_edges[&0], [1].into());
        assert_eq!(query.stage_graph.child_edges[&1], [2, 3].into());
        assert_eq!(query.stage_graph.child_edges[&2], HashSet::new());
        assert_eq!(query.stage_graph.child_edges[&3], HashSet::new());

        // Check the mappings of parent edges.
        assert_eq!(query.stage_graph.parent_edges[&0], HashSet::new());
        assert_eq!(query.stage_graph.parent_edges[&1], [0].into());
        assert_eq!(query.stage_graph.parent_edges[&2], [1].into());
        assert_eq!(query.stage_graph.parent_edges[&3], [1].into());

        // Verify the topological order.
        {
            let stage_id_to_pos: HashMap<StageId, usize> = query
                .stage_graph
                .stage_ids_by_topo_order()
                .enumerate()
                .map(|(pos, stage_id)| (stage_id, pos))
                .collect();

            for stage_id in query.stage_graph.stages.keys() {
                let stage_pos = stage_id_to_pos[stage_id];
                for child_stage_id in &query.stage_graph.child_edges[stage_id] {
                    let child_pos = stage_id_to_pos[child_stage_id];
                    assert!(stage_pos > child_pos);
                }
            }
        }

        // Check the plan node in each stage.
        let root_exchange = query.stage_graph.stages.get(&0).unwrap();
        assert_eq!(
            root_exchange.root.node_type(),
            BatchPlanNodeType::BatchExchange
        );
        assert_eq!(root_exchange.root.source_stage_id, Some(1));
        assert!(matches!(root_exchange.root.node, NodeBody::Exchange(_)));
        assert_eq!(root_exchange.parallelism, Some(1));
        assert!(!root_exchange.has_table_scan());

        let join_node = query.stage_graph.stages.get(&1).unwrap();
        assert_eq!(join_node.root.node_type(), BatchPlanNodeType::BatchHashJoin);
        assert_eq!(join_node.parallelism, Some(24));

        assert!(matches!(join_node.root.node, NodeBody::HashJoin(_)));
        assert_eq!(join_node.root.source_stage_id, None);
        assert_eq!(2, join_node.root.children.len());

        assert!(matches!(
            join_node.root.children[0].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[0].source_stage_id, Some(2));
        assert_eq!(0, join_node.root.children[0].children.len());

        assert!(matches!(
            join_node.root.children[1].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[1].source_stage_id, Some(3));
        assert_eq!(0, join_node.root.children[1].children.len());
        assert!(!join_node.has_table_scan());

        let scan_node1 = query.stage_graph.stages.get(&2).unwrap();
        assert_eq!(scan_node1.root.node_type(), BatchPlanNodeType::BatchSeqScan);
        assert_eq!(scan_node1.root.source_stage_id, None);
        assert_eq!(0, scan_node1.root.children.len());
        assert!(scan_node1.has_table_scan());

        let scan_node2 = query.stage_graph.stages.get(&3).unwrap();
        assert_eq!(scan_node2.root.node_type(), BatchPlanNodeType::BatchFilter);
        assert_eq!(scan_node2.root.source_stage_id, None);
        assert_eq!(1, scan_node2.root.children.len());
        assert!(scan_node2.has_table_scan());
    }
}