risingwave_frontend/scheduler/plan_fragmenter.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cmp::min;
use std::collections::{HashMap, HashSet};
use std::fmt::{Debug, Display, Formatter};
use std::num::NonZeroU64;

use anyhow::anyhow;
use async_recursion::async_recursion;
use enum_as_inner::EnumAsInner;
use futures::TryStreamExt;
use iceberg::expr::Predicate as IcebergPredicate;
use itertools::Itertools;
use petgraph::{Directed, Graph};
use pgwire::pg_server::SessionId;
use risingwave_batch::error::BatchError;
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::bitmap::{Bitmap, BitmapBuilder};
use risingwave_common::catalog::Schema;
use risingwave_common::hash::table_distribution::TableDistribution;
use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
use risingwave_common::util::scan_range::ScanRange;
use risingwave_connector::source::filesystem::opendal_source::opendal_enumerator::OpendalEnumerator;
use risingwave_connector::source::filesystem::opendal_source::{
    BatchPosixFsEnumerator, OpendalAzblob, OpendalGcs, OpendalS3,
};
use risingwave_connector::source::iceberg::IcebergSplitEnumerator;
use risingwave_connector::source::kafka::KafkaSplitEnumerator;
use risingwave_connector::source::prelude::DatagenSplitEnumerator;
use risingwave_connector::source::reader::reader::build_opendal_fs_list_for_batch;
use risingwave_connector::source::{
    ConnectorProperties, SourceEnumeratorContext, SplitEnumerator, SplitImpl,
};
use risingwave_pb::batch_plan::iceberg_scan_node::IcebergScanType;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::{ExchangeInfo, ScanRange as ScanRangeProto};
use risingwave_pb::plan_common::Field as PbField;
use serde::ser::SerializeStruct;
use serde::{Serialize, Serializer};
use uuid::Uuid;

use super::SchedulerError;
use crate::TableCatalog;
use crate::catalog::TableId;
use crate::catalog::catalog_service::CatalogReader;
use crate::optimizer::plan_node::generic::{GenericPlanRef, PhysicalPlanRef};
use crate::optimizer::plan_node::{
    BatchIcebergScan, BatchKafkaScan, BatchPlanNodeType, BatchPlanRef as PlanRef, BatchSource,
    PlanNodeId,
};
use crate::optimizer::property::Distribution;
use crate::scheduler::SchedulerResult;

#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct QueryId {
    pub id: String,
}

impl std::fmt::Display for QueryId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "QueryId:{}", self.id)
    }
}

#[derive(Copy, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub struct StageId(u32);

impl Display for StageId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl Debug for StageId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self.0)
    }
}

impl From<StageId> for u32 {
    fn from(value: StageId) -> Self {
        value.0
    }
}

impl From<u32> for StageId {
    fn from(value: u32) -> Self {
        StageId(value)
    }
}

impl Serialize for StageId {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        self.0.serialize(serializer)
    }
}

impl StageId {
    pub fn inc(&mut self) {
        self.0 += 1;
    }
}
// Root stage always has only one task.
pub const ROOT_TASK_ID: u64 = 0;
// Root task has only one output.
pub const ROOT_TASK_OUTPUT_ID: u64 = 0;
pub type TaskId = u64;

/// Generated by [`BatchPlanFragmenter`] and used in query execution graph.
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct ExecutionPlanNode {
    pub plan_node_id: PlanNodeId,
    pub plan_node_type: BatchPlanNodeType,
    pub node: NodeBody,
    pub schema: Vec<PbField>,

    pub children: Vec<ExecutionPlanNode>,

    /// The stage id of the source of `BatchExchange`.
    /// Used to find `ExchangeSource` from scheduler when creating `PlanNode`.
    ///
    /// `None` when this node is not `BatchExchange`.
    pub source_stage_id: Option<StageId>,
}

impl Serialize for ExecutionPlanNode {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut state = serializer.serialize_struct("QueryStage", 5)?;
        state.serialize_field("plan_node_id", &self.plan_node_id)?;
        state.serialize_field("plan_node_type", &self.plan_node_type)?;
        state.serialize_field("schema", &self.schema)?;
        state.serialize_field("children", &self.children)?;
        state.serialize_field("source_stage_id", &self.source_stage_id)?;
        state.end()
    }
}

impl TryFrom<PlanRef> for ExecutionPlanNode {
    type Error = SchedulerError;

    fn try_from(plan_node: PlanRef) -> Result<Self, Self::Error> {
        Ok(Self {
            plan_node_id: plan_node.plan_base().id(),
            plan_node_type: plan_node.node_type(),
            node: plan_node.try_to_batch_prost_body()?,
            children: vec![],
            schema: plan_node.schema().to_prost(),
            source_stage_id: None,
        })
    }
}

impl ExecutionPlanNode {
    pub fn node_type(&self) -> BatchPlanNodeType {
        self.plan_node_type
    }
}

/// `BatchPlanFragmenter` splits a query plan into fragments.
pub struct BatchPlanFragmenter {
    query_id: QueryId,
    next_stage_id: StageId,
    worker_node_manager: WorkerNodeSelector,
    catalog_reader: CatalogReader,

    batch_parallelism: usize,

    stage_graph_builder: Option<StageGraphBuilder>,
    stage_graph: Option<StageGraph>,
}

impl Default for QueryId {
    fn default() -> Self {
        Self {
            id: Uuid::new_v4().to_string(),
        }
    }
}
impl BatchPlanFragmenter {
    pub fn new(
        worker_node_manager: WorkerNodeSelector,
        catalog_reader: CatalogReader,
        batch_parallelism: Option<NonZeroU64>,
        batch_node: PlanRef,
    ) -> SchedulerResult<Self> {
        // If `batch_parallelism` is `None`, there is no limit, and we use the available node
        // count as the parallelism.
        // If `batch_parallelism` is `Some(num)`, we use min(num, available node count) as the
        // parallelism.
        let batch_parallelism = if let Some(num) = batch_parallelism {
            // can be 0 if no available serving worker
            min(
                num.get() as usize,
                worker_node_manager.schedule_unit_count(),
            )
        } else {
            // can be 0 if no available serving worker
            worker_node_manager.schedule_unit_count()
        };

        let mut plan_fragmenter = Self {
            query_id: Default::default(),
            next_stage_id: 0.into(),
            worker_node_manager,
            catalog_reader,
            batch_parallelism,
            stage_graph_builder: Some(StageGraphBuilder::new(batch_parallelism)),
            stage_graph: None,
        };
        plan_fragmenter.split_into_stage(batch_node)?;
        Ok(plan_fragmenter)
    }

    /// Split the plan into stages, cutting at each exchange node.
    fn split_into_stage(&mut self, batch_node: PlanRef) -> SchedulerResult<()> {
        let root_stage_id = self.new_stage(
            batch_node,
            Some(Distribution::Single.to_prost(
                1,
                &self.catalog_reader,
                &self.worker_node_manager,
            )?),
        )?;
        self.stage_graph = Some(
            self.stage_graph_builder
                .take()
                .unwrap()
                .build(root_stage_id),
        );
        Ok(())
    }
}
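
// End-to-end usage sketch (hedged; `worker_selector`, `catalog_reader`,
// `session_parallelism`, and `plan` are hypothetical values taken from the
// session and optimizer context). Construction splits the plan into stages,
// and the async step afterwards completes any pending source info:
//
//     let fragmenter = BatchPlanFragmenter::new(
//         worker_selector, catalog_reader, session_parallelism, plan)?;
//     let query = fragmenter.generate_complete_query().await?;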

/// The fragmented query generated by [`BatchPlanFragmenter`].
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct Query {
    /// Query id should always be unique.
    pub query_id: QueryId,
    pub stage_graph: StageGraph,
}

impl Query {
    pub fn leaf_stages(&self) -> Vec<StageId> {
        let mut ret_leaf_stages = Vec::new();
        for stage_id in self.stage_graph.stages.keys() {
            if self
                .stage_graph
                .get_child_stages_unchecked(stage_id)
                .is_empty()
            {
                ret_leaf_stages.push(*stage_id);
            }
        }
        ret_leaf_stages
    }

    pub fn get_parents(&self, stage_id: &StageId) -> &HashSet<StageId> {
        self.stage_graph.parent_edges.get(stage_id).unwrap()
    }

    pub fn root_stage_id(&self) -> StageId {
        self.stage_graph.root_stage_id
    }

    pub fn query_id(&self) -> &QueryId {
        &self.query_id
    }

    pub fn stages_with_table_scan(&self) -> HashSet<StageId> {
        self.stage_graph
            .stages
            .iter()
            .filter_map(|(stage_id, stage_query)| {
                if stage_query.has_table_scan() {
                    Some(*stage_id)
                } else {
                    None
                }
            })
            .collect()
    }

    pub fn has_lookup_join_stage(&self) -> bool {
        self.stage_graph
            .stages
            .iter()
            .any(|(_stage_id, stage_query)| stage_query.has_lookup_join())
    }

    pub fn stage(&self, stage_id: StageId) -> &QueryStage {
        &self.stage_graph.stages[&stage_id]
    }
}
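
// Scheduling sketch (hedged; the actual scheduler lives in the distributed
// scheduler module): execution typically starts from the leaf stages and
// walks the parent edges upward as stages finish:
//
//     for leaf in query.leaf_stages() {
//         let parents = query.get_parents(&leaf);
//         // schedule `leaf`, then notify `parents` once it completes
//     }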

#[derive(Debug, Clone)]
pub enum SourceFetchParameters {
    IcebergSpecificInfo(IcebergSpecificInfo),
    KafkaTimebound {
        lower: Option<i64>,
        upper: Option<i64>,
    },
    Empty,
}

#[derive(Debug, Clone)]
pub struct SourceFetchInfo {
    pub schema: Schema,
    /// These are user-configured connector properties,
    /// e.g. host, username, etc...
    pub connector: ConnectorProperties,
    /// These parameters are internally derived by the plan node,
    /// e.g. predicate pushdown for iceberg, timebound for kafka.
    pub fetch_parameters: SourceFetchParameters,
}

#[derive(Debug, Clone)]
pub struct IcebergSpecificInfo {
    pub iceberg_scan_type: IcebergScanType,
    pub predicate: IcebergPredicate,
    pub snapshot_id: Option<i64>,
}

#[derive(Clone, Debug)]
pub enum SourceScanInfo {
    /// Split info not yet fetched from the external source.
    Incomplete(SourceFetchInfo),
    Complete(Vec<SplitImpl>),
}

impl SourceScanInfo {
    pub fn new(fetch_info: SourceFetchInfo) -> Self {
        Self::Incomplete(fetch_info)
    }

    pub async fn complete(self, batch_parallelism: usize) -> SchedulerResult<Self> {
        let fetch_info = match self {
            SourceScanInfo::Incomplete(fetch_info) => fetch_info,
            SourceScanInfo::Complete(_) => {
                unreachable!("Never call complete when SourceScanInfo is already complete")
            }
        };
        match (fetch_info.connector, fetch_info.fetch_parameters) {
            (
                ConnectorProperties::Kafka(prop),
                SourceFetchParameters::KafkaTimebound { lower, upper },
            ) => {
                let mut kafka_enumerator =
                    KafkaSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;
                let split_info = kafka_enumerator
                    .list_splits_batch(lower, upper)
                    .await?
                    .into_iter()
                    .map(SplitImpl::Kafka)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(split_info))
            }
            (ConnectorProperties::Datagen(prop), SourceFetchParameters::Empty) => {
                let mut datagen_enumerator =
                    DatagenSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;
                let split_info = datagen_enumerator.list_splits().await?;
                let res = split_info.into_iter().map(SplitImpl::Datagen).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::OpendalS3(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalS3> = OpendalEnumerator::new_s3_source(
                    &prop.s3_properties,
                    prop.assume_role,
                    prop.fs_common.compression_format,
                )?;
                let stream = build_opendal_fs_list_for_batch(lister);

                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res
                    .into_iter()
                    .map(SplitImpl::OpendalS3)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::Gcs(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalGcs> =
                    OpendalEnumerator::new_gcs_source(*prop)?;
                let stream = build_opendal_fs_list_for_batch(lister);
                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res.into_iter().map(SplitImpl::Gcs).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::Azblob(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalAzblob> =
                    OpendalEnumerator::new_azblob_source(*prop)?;
                let stream = build_opendal_fs_list_for_batch(lister);
                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res.into_iter().map(SplitImpl::Azblob).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::BatchPosixFs(prop), SourceFetchParameters::Empty) => {
                let mut enumerator = BatchPosixFsEnumerator::new(
                    *prop,
                    SourceEnumeratorContext::dummy().into(),
                )
                .await?;
                let splits = enumerator.list_splits().await?;
                let res = splits
                    .into_iter()
                    .map(SplitImpl::BatchPosixFs)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (
                ConnectorProperties::Iceberg(prop),
                SourceFetchParameters::IcebergSpecificInfo(iceberg_specific_info),
            ) => {
                let iceberg_enumerator =
                    IcebergSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;

                let split_info = iceberg_enumerator
                    .list_splits_batch(
                        fetch_info.schema,
                        iceberg_specific_info.snapshot_id,
                        batch_parallelism,
                        iceberg_specific_info.iceberg_scan_type,
                        iceberg_specific_info.predicate,
                    )
                    .await?
                    .into_iter()
                    .map(SplitImpl::Iceberg)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(split_info))
            }
            (connector, _) => Err(SchedulerError::Internal(anyhow!(
                "Querying directly from the {} source is not supported, \
                 please create a table or streaming job from it",
                connector.kind()
            ))),
        }
    }

    pub fn split_info(&self) -> SchedulerResult<&Vec<SplitImpl>> {
        match self {
            Self::Incomplete(_) => Err(SchedulerError::Internal(anyhow!(
                "Should not get split info from incomplete source scan info"
            ))),
            Self::Complete(split_info) => Ok(split_info),
        }
    }
}
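
// A minimal usage sketch (hedged; `fetch_info` is a hypothetical
// `SourceFetchInfo` built by a plan node). The enum enforces a one-way
// Incomplete -> Complete transition, after which the splits are available:
//
//     let scan = SourceScanInfo::new(fetch_info);
//     let scan = scan.complete(batch_parallelism).await?;
//     let splits: &Vec<SplitImpl> = scan.split_info()?;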

#[derive(Clone, Debug)]
pub struct TableScanInfo {
    /// The name of the table to scan.
    name: String,

    /// Indicates the table partitions to be read by scan tasks. Unnecessary partitions are already
    /// pruned.
    ///
    /// For a singleton table, this field is still `Some` and contains a single partition with
    /// the full vnode bitmap, since we need to know where to schedule the singleton scan task.
    ///
    /// `None` iff the table is a system table.
    partitions: Option<HashMap<WorkerSlotId, TablePartitionInfo>>,
}

impl TableScanInfo {
    /// For normal tables, `partitions` should always be `Some`.
    pub fn new(name: String, partitions: HashMap<WorkerSlotId, TablePartitionInfo>) -> Self {
        Self {
            name,
            partitions: Some(partitions),
        }
    }

    /// For a system table, there's no partition info.
    pub fn system_table(name: String) -> Self {
        Self {
            name,
            partitions: None,
        }
    }

    pub fn name(&self) -> &str {
        self.name.as_ref()
    }

    pub fn partitions(&self) -> Option<&HashMap<WorkerSlotId, TablePartitionInfo>> {
        self.partitions.as_ref()
    }
}

#[derive(Clone, Debug)]
pub struct TablePartitionInfo {
    pub vnode_bitmap: Bitmap,
    pub scan_ranges: Vec<ScanRangeProto>,
}

#[derive(Clone, Debug, EnumAsInner)]
pub enum PartitionInfo {
    Table(TablePartitionInfo),
    Source(Vec<SplitImpl>),
    File(Vec<String>),
}

#[derive(Clone, Debug)]
pub struct FileScanInfo {
    pub file_location: Vec<String>,
}

/// Fragment part of `Query`.
#[cfg_attr(test, derive(Clone))]
pub struct QueryStage {
    pub id: StageId,
    pub root: ExecutionPlanNode,
    pub exchange_info: Option<ExchangeInfo>,
    pub parallelism: Option<u32>,
    /// Indicates whether this stage contains a table scan node, and the table's information if so.
    pub table_scan_info: Option<TableScanInfo>,
    pub source_info: Option<SourceScanInfo>,
    pub file_scan_info: Option<FileScanInfo>,
    pub has_lookup_join: bool,
    pub dml_table_id: Option<TableId>,
    pub session_id: SessionId,
    pub batch_enable_distributed_dml: bool,

    /// Used to generate exchange information when completing the source scan information.
    children_exchange_distribution: Option<HashMap<StageId, Distribution>>,
}

impl QueryStage {
    /// If true, this stage contains a table scan executor that creates
    /// Hummock iterators to read data from the table. The iterator is initialized during
    /// the executor building process on the batch execution engine.
    pub fn has_table_scan(&self) -> bool {
        self.table_scan_info.is_some()
    }

    /// If true, this stage contains a lookup join executor.
    /// We need to delay the epoch unpin until the end of the query.
    pub fn has_lookup_join(&self) -> bool {
        self.has_lookup_join
    }

    pub fn with_exchange_info(
        self,
        exchange_info: Option<ExchangeInfo>,
        parallelism: Option<u32>,
    ) -> Self {
        if let Some(exchange_info) = exchange_info {
            Self {
                id: self.id,
                root: self.root,
                exchange_info: Some(exchange_info),
                parallelism,
                table_scan_info: self.table_scan_info,
                source_info: self.source_info,
                file_scan_info: self.file_scan_info,
                has_lookup_join: self.has_lookup_join,
                dml_table_id: self.dml_table_id,
                session_id: self.session_id,
                batch_enable_distributed_dml: self.batch_enable_distributed_dml,
                children_exchange_distribution: self.children_exchange_distribution,
            }
        } else {
            self
        }
    }

    pub fn with_exchange_info_and_complete_source_info(
        self,
        exchange_info: Option<ExchangeInfo>,
        source_info: SourceScanInfo,
        task_parallelism: u32,
    ) -> Self {
        assert!(matches!(source_info, SourceScanInfo::Complete(_)));
        let exchange_info = exchange_info.or(self.exchange_info);
        Self {
            id: self.id,
            root: self.root,
            exchange_info,
            parallelism: Some(task_parallelism),
            table_scan_info: self.table_scan_info,
            source_info: Some(source_info),
            file_scan_info: self.file_scan_info,
            has_lookup_join: self.has_lookup_join,
            dml_table_id: self.dml_table_id,
            session_id: self.session_id,
            batch_enable_distributed_dml: self.batch_enable_distributed_dml,
            children_exchange_distribution: None,
        }
    }
}

impl Debug for QueryStage {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QueryStage")
            .field("id", &self.id)
            .field("parallelism", &self.parallelism)
            .field("exchange_info", &self.exchange_info)
            .field("has_table_scan", &self.has_table_scan())
            .finish()
    }
}

impl Serialize for QueryStage {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut state = serializer.serialize_struct("QueryStage", 3)?;
        state.serialize_field("root", &self.root)?;
        state.serialize_field("parallelism", &self.parallelism)?;
        state.serialize_field("exchange_info", &self.exchange_info)?;
        state.end()
    }
}

struct QueryStageBuilder {
    id: StageId,
    root: Option<ExecutionPlanNode>,
    parallelism: Option<u32>,
    exchange_info: Option<ExchangeInfo>,

    children_stages: Vec<StageId>,
    /// See also [`QueryStage::table_scan_info`].
    table_scan_info: Option<TableScanInfo>,
    source_info: Option<SourceScanInfo>,
    file_scan_info: Option<FileScanInfo>,
    has_lookup_join: bool,
    dml_table_id: Option<TableId>,
    session_id: SessionId,
    batch_enable_distributed_dml: bool,

    children_exchange_distribution: HashMap<StageId, Distribution>,
}

impl QueryStageBuilder {
    #[allow(clippy::too_many_arguments)]
    fn new(
        id: StageId,
        parallelism: Option<u32>,
        exchange_info: Option<ExchangeInfo>,
        table_scan_info: Option<TableScanInfo>,
        source_info: Option<SourceScanInfo>,
        file_scan_info: Option<FileScanInfo>,
        has_lookup_join: bool,
        dml_table_id: Option<TableId>,
        session_id: SessionId,
        batch_enable_distributed_dml: bool,
    ) -> Self {
        Self {
            id,
            root: None,
            parallelism,
            exchange_info,
            children_stages: vec![],
            table_scan_info,
            source_info,
            file_scan_info,
            has_lookup_join,
            dml_table_id,
            session_id,
            batch_enable_distributed_dml,
            children_exchange_distribution: HashMap::new(),
        }
    }

    fn finish(self, stage_graph_builder: &mut StageGraphBuilder) -> StageId {
        let children_exchange_distribution = if self.parallelism.is_none() {
            Some(self.children_exchange_distribution)
        } else {
            None
        };
        let stage = QueryStage {
            id: self.id,
            root: self.root.unwrap(),
            exchange_info: self.exchange_info,
            parallelism: self.parallelism,
            table_scan_info: self.table_scan_info,
            source_info: self.source_info,
            file_scan_info: self.file_scan_info,
            has_lookup_join: self.has_lookup_join,
            dml_table_id: self.dml_table_id,
            session_id: self.session_id,
            batch_enable_distributed_dml: self.batch_enable_distributed_dml,
            children_exchange_distribution,
        };

        let stage_id = stage.id;
        stage_graph_builder.add_node(stage);
        for child_stage_id in self.children_stages {
            stage_graph_builder.link_to_child(self.id, child_stage_id);
        }
        stage_id
    }
}

/// Maintains how the stages are connected.
#[derive(Debug, Serialize)]
#[cfg_attr(test, derive(Clone))]
pub struct StageGraph {
    pub root_stage_id: StageId,
    pub stages: HashMap<StageId, QueryStage>,
    /// Traversed from top down; used when splitting the plan into stages.
    child_edges: HashMap<StageId, HashSet<StageId>>,
    /// Traversed from bottom up; used when scheduling each stage.
    parent_edges: HashMap<StageId, HashSet<StageId>>,

    batch_parallelism: usize,
}

enum StageCompleteInfo {
    ExchangeInfo((Option<ExchangeInfo>, Option<u32>)),
    ExchangeWithSourceInfo((Option<ExchangeInfo>, SourceScanInfo, u32)),
}

impl StageGraph {
    pub fn get_child_stages_unchecked(&self, stage_id: &StageId) -> &HashSet<StageId> {
        self.child_edges.get(stage_id).unwrap()
    }

    pub fn get_child_stages(&self, stage_id: &StageId) -> Option<&HashSet<StageId>> {
        self.child_edges.get(stage_id)
    }

    /// Returns stage ids in topological order, s.t. a child stage always appears before its parent.
    pub fn stage_ids_by_topo_order(&self) -> impl Iterator<Item = StageId> {
        let mut stack = Vec::with_capacity(self.stages.len());
        stack.push(self.root_stage_id);
        let mut ret = Vec::with_capacity(self.stages.len());
        let mut existing = HashSet::with_capacity(self.stages.len());

        while let Some(s) = stack.pop() {
            if !existing.contains(&s) {
                ret.push(s);
                existing.insert(s);
                stack.extend(&self.child_edges[&s]);
            }
        }

        ret.into_iter().rev()
    }
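
    // Worked example (matches the graph built in the test at the bottom of
    // this file): for root stage 0 with child 1, and stage 1 with children
    // 2 and 3, the returned order is e.g. [2, 3, 1, 0]. The order between
    // sibling stages 2 and 3 is unspecified, but every child precedes its
    // parent.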

    async fn complete(
        self,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
    ) -> SchedulerResult<StageGraph> {
        let mut complete_stages = HashMap::new();
        self.complete_stage(
            self.root_stage_id,
            None,
            &mut complete_stages,
            catalog_reader,
            worker_node_manager,
        )
        .await?;
        let mut stages = self.stages;
        Ok(StageGraph {
            root_stage_id: self.root_stage_id,
            stages: complete_stages
                .into_iter()
                .map(|(stage_id, info)| {
                    let stage = stages.remove(&stage_id).expect("should exist");
                    let stage = match info {
                        StageCompleteInfo::ExchangeInfo((exchange_info, parallelism)) => {
                            stage.with_exchange_info(exchange_info, parallelism)
                        }
                        StageCompleteInfo::ExchangeWithSourceInfo((
                            exchange_info,
                            source_info,
                            parallelism,
                        )) => stage.with_exchange_info_and_complete_source_info(
                            exchange_info,
                            source_info,
                            parallelism,
                        ),
                    };
                    (stage_id, stage)
                })
                .collect(),
            child_edges: self.child_edges,
            parent_edges: self.parent_edges,
            batch_parallelism: self.batch_parallelism,
        })
    }

    #[async_recursion]
    async fn complete_stage(
        &self,
        stage_id: StageId,
        exchange_info: Option<ExchangeInfo>,
        complete_stages: &mut HashMap<StageId, StageCompleteInfo>,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
    ) -> SchedulerResult<()> {
        let stage = &self.stages[&stage_id];
        let parallelism = if stage.parallelism.is_some() {
            // If the stage already has a parallelism, it is a complete stage.
            complete_stages.insert(
                stage.id,
                StageCompleteInfo::ExchangeInfo((exchange_info, stage.parallelism)),
            );
            None
        } else if matches!(stage.source_info, Some(SourceScanInfo::Incomplete(_))) {
            let complete_source_info = stage
                .source_info
                .as_ref()
                .unwrap()
                .clone()
                .complete(self.batch_parallelism)
                .await?;

            // For batch reads of a file source, the number of files involved is typically large.
            // To avoid generating a task per file, the task parallelism is limited here.
            // The minimum `task_parallelism` is 1, and it must not exceed the number of files
            // to read. Therefore we take the minimum of the number of files and
            // (self.batch_parallelism / 2), then clamp the result to at least 1 (which also
            // covers the case where the number of files is 0).

            let task_parallelism = match &stage.source_info {
                Some(SourceScanInfo::Incomplete(source_fetch_info)) => {
                    match source_fetch_info.connector {
                        ConnectorProperties::Gcs(_)
                        | ConnectorProperties::OpendalS3(_)
                        | ConnectorProperties::Azblob(_) => (min(
                            complete_source_info.split_info().unwrap().len() as u32,
                            (self.batch_parallelism / 2) as u32,
                        ))
                        .max(1),
                        _ => complete_source_info.split_info().unwrap().len() as u32,
                    }
                }
                _ => unreachable!(),
            };
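            // Worked example (hypothetical numbers): with `batch_parallelism = 8`
            // and 100 files to read, task_parallelism = max(min(100, 8 / 2), 1) = 4;
            // with 0 files, task_parallelism = max(min(0, 4), 1) = 1.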
            // For file source batch reads, the files to be read are divided into several groups
            // to prevent a task from taking up too many resources.
            // todo(wcy-fdu): Currently they are divided into half of batch_parallelism groups;
            // this will be made configurable later.
            let complete_stage_info = StageCompleteInfo::ExchangeWithSourceInfo((
                exchange_info,
                complete_source_info,
                task_parallelism,
            ));
            complete_stages.insert(stage.id, complete_stage_info);
            Some(task_parallelism)
        } else {
            assert!(stage.file_scan_info.is_some());
            let parallelism = min(
                self.batch_parallelism / 2,
                stage.file_scan_info.as_ref().unwrap().file_location.len(),
            );
            complete_stages.insert(
                stage.id,
                StageCompleteInfo::ExchangeInfo((exchange_info, Some(parallelism as u32))),
            );
            None
        };

        for child_stage_id in self
            .child_edges
            .get(&stage.id)
            .map(|edges| edges.iter())
            .into_iter()
            .flatten()
        {
            let exchange_info = if let Some(parallelism) = parallelism {
                let exchange_distribution = stage
                    .children_exchange_distribution
                    .as_ref()
                    .unwrap()
                    .get(child_stage_id)
                    .expect("Exchange distribution is not consistent with the stage graph");
                Some(exchange_distribution.to_prost(
                    parallelism,
                    catalog_reader,
                    worker_node_manager,
                )?)
            } else {
                None
            };
            self.complete_stage(
                *child_stage_id,
                exchange_info,
                complete_stages,
                catalog_reader,
                worker_node_manager,
            )
            .await?;
        }

        Ok(())
    }

    /// Converts the `StageGraph` into a `petgraph::graph::Graph<String, String>`.
    pub fn to_petgraph(&self) -> Graph<String, String, Directed> {
        let mut graph = Graph::<String, String, Directed>::new();

        let mut node_indices = HashMap::new();

        // Add all stages as nodes
        for (&stage_id, stage_ref) in self.stages.iter().sorted_by_key(|(id, _)| **id) {
            let node_label = format!("Stage {}: {:?}", stage_id, stage_ref);
            let node_index = graph.add_node(node_label);
            node_indices.insert(stage_id, node_index);
        }

        // Add edges between stages based on child_edges
        for (&parent_id, children) in &self.child_edges {
            if let Some(&parent_index) = node_indices.get(&parent_id) {
                for &child_id in children {
                    if let Some(&child_index) = node_indices.get(&child_id) {
                        // Add an edge from parent to child
                        graph.add_edge(parent_index, child_index, "".to_owned());
                    }
                }
            }
        }

        graph
    }
}
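
// For debugging, the petgraph form can be rendered as Graphviz DOT. A sketch
// using `petgraph::dot`, which is part of the petgraph crate:
//
//     use petgraph::dot::{Config, Dot};
//     let g = stage_graph.to_petgraph();
//     println!("{:?}", Dot::with_config(&g, &[Config::EdgeNoLabel]));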

struct StageGraphBuilder {
    stages: HashMap<StageId, QueryStage>,
    child_edges: HashMap<StageId, HashSet<StageId>>,
    parent_edges: HashMap<StageId, HashSet<StageId>>,
    batch_parallelism: usize,
}

impl StageGraphBuilder {
    pub fn new(batch_parallelism: usize) -> Self {
        Self {
            stages: HashMap::new(),
            child_edges: HashMap::new(),
            parent_edges: HashMap::new(),
            batch_parallelism,
        }
    }

    pub fn build(self, root_stage_id: StageId) -> StageGraph {
        StageGraph {
            root_stage_id,
            stages: self.stages,
            child_edges: self.child_edges,
            parent_edges: self.parent_edges,
            batch_parallelism: self.batch_parallelism,
        }
    }

    /// Link parent stage and child stage. Maintain the mappings of parent -> child and child ->
    /// parent.
    pub fn link_to_child(&mut self, parent_id: StageId, child_id: StageId) {
        self.child_edges
            .get_mut(&parent_id)
            .unwrap()
            .insert(child_id);
        self.parent_edges
            .get_mut(&child_id)
            .unwrap()
            .insert(parent_id);
    }
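
    // Usage sketch (hedged; `parent_stage` / `child_stage` are hypothetical):
    // both stages must be added via `add_node` before calling `link_to_child`,
    // since `link_to_child` unwraps the edge sets that `add_node` creates:
    //
    //     builder.add_node(parent_stage);
    //     builder.add_node(child_stage);
    //     builder.link_to_child(parent_id, child_id);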

    pub fn add_node(&mut self, stage: QueryStage) {
        // Insert here so that leaf/root stages also have linkage.
        self.child_edges.insert(stage.id, HashSet::new());
        self.parent_edges.insert(stage.id, HashSet::new());
        self.stages.insert(stage.id, stage);
    }
}

impl BatchPlanFragmenter {
    /// After splitting, the `stage_graph` in the fragmenter may contain stages with incomplete
    /// source info; this function fetches the source info to complete those stages.
    /// Why are the two steps (`split()` and `generate_complete_query()`) separate?
    /// Fetching source info is an async operation, so it cannot be done in the split step.
    pub async fn generate_complete_query(self) -> SchedulerResult<Query> {
        let stage_graph = self.stage_graph.unwrap();
        let new_stage_graph = stage_graph
            .complete(&self.catalog_reader, &self.worker_node_manager)
            .await?;
        Ok(Query {
            query_id: self.query_id,
            stage_graph: new_stage_graph,
        })
    }

    fn new_stage(
        &mut self,
        root: PlanRef,
        exchange_info: Option<ExchangeInfo>,
    ) -> SchedulerResult<StageId> {
        let next_stage_id = self.next_stage_id;
        self.next_stage_id.inc();

        let mut table_scan_info = None;
        let mut source_info = None;
        let mut file_scan_info = None;

        // For the current implementation, we can guarantee that each stage has only one table
        // scan (except for system tables) or one source.
        if let Some(info) = self.collect_stage_table_scan(root.clone())? {
            table_scan_info = Some(info);
        } else if let Some(info) = Self::collect_stage_source(root.clone())? {
            source_info = Some(info);
        } else if let Some(info) = Self::collect_stage_file_scan(root.clone())? {
            file_scan_info = Some(info);
        }

        let mut has_lookup_join = false;
        let parallelism = match root.distribution() {
            Distribution::Single => {
                if let Some(info) = &mut table_scan_info {
                    if let Some(partitions) = &mut info.partitions {
                        if partitions.len() != 1 {
                            // This is a rare case, but it's possible for the internal state of the
                            // `Source` operator.
                            tracing::warn!(
                                "The stage has single distribution, but contains a scan of table `{}` with {} partitions. A single random worker will be assigned",
                                info.name,
                                partitions.len()
                            );

                            *partitions = partitions
                                .drain()
                                .take(1)
                                .update(|(_, info)| {
                                    info.vnode_bitmap = Bitmap::ones(info.vnode_bitmap.len());
                                })
                                .collect();
                        }
                    } else {
                        // System table
                    }
                } else if source_info.is_some() {
                    return Err(SchedulerError::Internal(anyhow!(
                        "The stage has single distribution, but contains a source operator"
                    )));
                }
                1
            }
            _ => {
                if let Some(table_scan_info) = &table_scan_info {
                    table_scan_info
                        .partitions
                        .as_ref()
                        .map(|m| m.len())
                        .unwrap_or(1)
                } else if let Some(lookup_join_parallelism) =
                    self.collect_stage_lookup_join_parallelism(root.clone())?
                {
                    has_lookup_join = true;
                    lookup_join_parallelism
                } else if source_info.is_some() {
                    0
                } else if file_scan_info.is_some() {
                    1
                } else {
                    self.batch_parallelism
                }
            }
        };
        if source_info.is_none() && file_scan_info.is_none() && parallelism == 0 {
            return Err(BatchError::EmptyWorkerNodes.into());
        }
        let parallelism = if parallelism == 0 {
            None
        } else {
            Some(parallelism as u32)
        };
        let dml_table_id = Self::collect_dml_table_id(&root);
        let mut builder = QueryStageBuilder::new(
            next_stage_id,
            parallelism,
            exchange_info,
            table_scan_info,
            source_info,
            file_scan_info,
            has_lookup_join,
            dml_table_id,
            root.ctx().session_ctx().session_id(),
            root.ctx()
                .session_ctx()
                .config()
                .batch_enable_distributed_dml(),
        );

        self.visit_node(root, &mut builder, None)?;

        Ok(builder.finish(self.stage_graph_builder.as_mut().unwrap()))
    }

    fn visit_node(
        &mut self,
        node: PlanRef,
        builder: &mut QueryStageBuilder,
        parent_exec_node: Option<&mut ExecutionPlanNode>,
    ) -> SchedulerResult<()> {
        match node.node_type() {
            BatchPlanNodeType::BatchExchange => {
                self.visit_exchange(node, builder, parent_exec_node)?;
            }
            _ => {
                let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;

                for child in node.inputs() {
                    self.visit_node(child, builder, Some(&mut execution_plan_node))?;
                }

                if let Some(parent) = parent_exec_node {
                    parent.children.push(execution_plan_node);
                } else {
                    builder.root = Some(execution_plan_node);
                }
            }
        }
        Ok(())
    }

    fn visit_exchange(
        &mut self,
        node: PlanRef,
        builder: &mut QueryStageBuilder,
        parent_exec_node: Option<&mut ExecutionPlanNode>,
    ) -> SchedulerResult<()> {
        let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;
        let child_exchange_info = if let Some(parallelism) = builder.parallelism {
            Some(node.distribution().to_prost(
                parallelism,
                &self.catalog_reader,
                &self.worker_node_manager,
            )?)
        } else {
            None
        };
        let child_stage_id = self.new_stage(node.inputs()[0].clone(), child_exchange_info)?;
        execution_plan_node.source_stage_id = Some(child_stage_id);
        if builder.parallelism.is_none() {
            builder
                .children_exchange_distribution
                .insert(child_stage_id, node.distribution().clone());
        }

        if let Some(parent) = parent_exec_node {
            parent.children.push(execution_plan_node);
        } else {
            builder.root = Some(execution_plan_node);
        }

        builder.children_stages.push(child_stage_id);
        Ok(())
    }

    /// Check whether this stage contains a source node.
    /// If so, use `SplitEnumeratorImpl` to get the split info from the external source.
    ///
    /// For the current implementation, we can guarantee that each stage has only one source.
    fn collect_stage_source(node: PlanRef) -> SchedulerResult<Option<SourceScanInfo>> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            // Do not visit next stage.
            return Ok(None);
        }

        if let Some(batch_kafka_node) = node.as_batch_kafka_scan() {
            let batch_kafka_scan: &BatchKafkaScan = batch_kafka_node;
            let source_catalog = batch_kafka_scan.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                let timestamp_bound = batch_kafka_scan.kafka_timestamp_range_value();
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: batch_kafka_scan.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::KafkaTimebound {
                        lower: timestamp_bound.0,
                        upper: timestamp_bound.1,
                    },
                })));
            }
        } else if let Some(batch_iceberg_scan) = node.as_batch_iceberg_scan() {
            let batch_iceberg_scan: &BatchIcebergScan = batch_iceberg_scan;
            let source_catalog = batch_iceberg_scan.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: batch_iceberg_scan.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::IcebergSpecificInfo(
                        IcebergSpecificInfo {
                            predicate: batch_iceberg_scan.predicate.clone(),
                            iceberg_scan_type: batch_iceberg_scan.iceberg_scan_type(),
                            snapshot_id: batch_iceberg_scan.snapshot_id(),
                        },
                    ),
                })));
            }
        } else if let Some(source_node) = node.as_batch_source() {
            // TODO: use a specific batch operator instead of batch source.
            let source_node: &BatchSource = source_node;
            let source_catalog = source_node.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: source_node.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::Empty,
                })));
            }
        }

        node.inputs()
            .into_iter()
            .find_map(|n| Self::collect_stage_source(n).transpose())
            .transpose()
    }

    fn collect_stage_file_scan(node: PlanRef) -> SchedulerResult<Option<FileScanInfo>> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            // Do not visit next stage.
            return Ok(None);
        }

        if let Some(batch_file_scan) = node.as_batch_file_scan() {
            return Ok(Some(FileScanInfo {
                file_location: batch_file_scan.core.file_location(),
            }));
        }

        node.inputs()
            .into_iter()
            .find_map(|n| Self::collect_stage_file_scan(n).transpose())
            .transpose()
    }

    /// Check whether this stage contains a table scan node, and return the table's information
    /// if so.
    ///
    /// If there are multiple scan nodes in this stage, they must have the same distribution but
    /// may have different vnode partitions. We just use the same partition for all the scan
    /// nodes.
    fn collect_stage_table_scan(&self, node: PlanRef) -> SchedulerResult<Option<TableScanInfo>> {
        let build_table_scan_info = |name, table_catalog: &TableCatalog, scan_range| {
            let vnode_mapping = self
                .worker_node_manager
                .fragment_mapping(table_catalog.fragment_id)?;
            let partitions = derive_partitions(scan_range, table_catalog, &vnode_mapping)?;
            let info = TableScanInfo::new(name, partitions);
            Ok(Some(info))
        };
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            // Do not visit next stage.
            return Ok(None);
        }
        if let Some(scan_node) = node.as_batch_sys_seq_scan() {
            let name = scan_node.core().table.name.clone();
            Ok(Some(TableScanInfo::system_table(name)))
        } else if let Some(scan_node) = node.as_batch_log_seq_scan() {
            build_table_scan_info(
                scan_node.core().table_name.clone(),
                &scan_node.core().table,
                &[],
            )
        } else if let Some(scan_node) = node.as_batch_seq_scan() {
            build_table_scan_info(
                scan_node.core().table_name().to_owned(),
                &scan_node.core().table_catalog,
                scan_node.scan_ranges(),
            )
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| self.collect_stage_table_scan(n).transpose())
                .transpose()
        }
    }

    /// Returns the dml table id if any.
    fn collect_dml_table_id(node: &PlanRef) -> Option<TableId> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            return None;
        }
        if let Some(insert) = node.as_batch_insert() {
            Some(insert.core.table_id)
        } else if let Some(update) = node.as_batch_update() {
            Some(update.core.table_id)
        } else if let Some(delete) = node.as_batch_delete() {
            Some(delete.core.table_id)
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| Self::collect_dml_table_id(&n))
        }
    }

    fn collect_stage_lookup_join_parallelism(
        &self,
        node: PlanRef,
    ) -> SchedulerResult<Option<usize>> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            // Do not visit next stage.
            return Ok(None);
        }
        if let Some(lookup_join) = node.as_batch_lookup_join() {
            let table_catalog = lookup_join.right_table();
            let vnode_mapping = self
                .worker_node_manager
                .fragment_mapping(table_catalog.fragment_id)?;
            let parallelism = vnode_mapping.iter().sorted().dedup().count();
            Ok(Some(parallelism))
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| self.collect_stage_lookup_join_parallelism(n).transpose())
                .transpose()
        }
    }
}

/// Try to derive the partitions to read from the scan ranges.
/// A partition can be derived if the value of the distribution key is already known.
fn derive_partitions(
    scan_ranges: &[ScanRange],
    table_catalog: &TableCatalog,
    vnode_mapping: &WorkerSlotMapping,
) -> SchedulerResult<HashMap<WorkerSlotId, TablePartitionInfo>> {
    let vnode_mapping = if table_catalog.vnode_count.value() != vnode_mapping.len() {
        // The vnode count mismatch occurs only in special cases where a hash-distributed fragment
        // contains singleton internal tables, e.g., the state table of `Source` executors.
        // In this case, we reduce the vnode mapping to a single vnode as only `SINGLETON_VNODE` is used.
        assert_eq!(
            table_catalog.vnode_count.value(),
            1,
            "fragment vnode count {} does not match table vnode count {}",
            vnode_mapping.len(),
            table_catalog.vnode_count.value(),
        );
        &WorkerSlotMapping::new_single(vnode_mapping.iter().next().unwrap())
    } else {
        vnode_mapping
    };
    let vnode_count = vnode_mapping.len();

    let mut partitions: HashMap<WorkerSlotId, (BitmapBuilder, Vec<_>)> = HashMap::new();

    if scan_ranges.is_empty() {
        return Ok(vnode_mapping
            .to_bitmaps()
            .into_iter()
            .map(|(k, vnode_bitmap)| {
                (
                    k,
                    TablePartitionInfo {
                        vnode_bitmap,
                        scan_ranges: vec![],
                    },
                )
            })
            .collect());
    }

    let table_distribution = TableDistribution::new_from_storage_table_desc(
        Some(Bitmap::ones(vnode_count).into()),
        &table_catalog.table_desc().try_to_protobuf()?,
    );

    for scan_range in scan_ranges {
        let vnode = scan_range.try_compute_vnode(&table_distribution);
        match vnode {
            None => {
                // The vnode is unknown: push this scan_range to all partitions.
                vnode_mapping.to_bitmaps().into_iter().for_each(
                    |(worker_slot_id, vnode_bitmap)| {
                        let (bitmap, scan_ranges) = partitions
                            .entry(worker_slot_id)
                            .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
                        vnode_bitmap
                            .iter()
                            .enumerate()
                            .for_each(|(vnode, b)| bitmap.set(vnode, b));
                        scan_ranges.push(scan_range.to_protobuf());
                    },
                );
            }
            // The vnode is known: scan a single partition.
            Some(vnode) => {
                let worker_slot_id = vnode_mapping[vnode];
                let (bitmap, scan_ranges) = partitions
                    .entry(worker_slot_id)
                    .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
                bitmap.set(vnode.to_index(), true);
                scan_ranges.push(scan_range.to_protobuf());
            }
        }
    }

    Ok(partitions
        .into_iter()
        .map(|(k, (bitmap, scan_ranges))| {
            (
                k,
                TablePartitionInfo {
                    vnode_bitmap: bitmap.finish(),
                    scan_ranges,
                },
            )
        })
        .collect())
}
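
// Worked example for `derive_partitions` (hypothetical numbers): with 4 vnodes
// mapped as {vnodes 0, 1 -> worker slot A; vnodes 2, 3 -> worker slot B}, a
// scan range that pins vnode 2 (e.g. an equality condition on the full
// distribution key) yields the single partition {B: bitmap 0010, [range]},
// while a range whose vnode cannot be computed (e.g. a pure range condition)
// is pushed to both A and B with their full vnode bitmaps.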

#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use risingwave_pb::batch_plan::plan_node::NodeBody;

    use crate::optimizer::plan_node::BatchPlanNodeType;
    use crate::scheduler::plan_fragmenter::StageId;

    #[tokio::test]
    async fn test_fragmenter() {
        let query = crate::scheduler::distributed::tests::create_query().await;

        assert_eq!(query.stage_graph.root_stage_id, 0.into());
        assert_eq!(query.stage_graph.stages.len(), 4);

        // Check the mappings of child edges.
        assert_eq!(
            query.stage_graph.child_edges[&0.into()],
            HashSet::from_iter([1.into()])
        );
        assert_eq!(
            query.stage_graph.child_edges[&1.into()],
            HashSet::from_iter([2.into(), 3.into()])
        );
        assert_eq!(query.stage_graph.child_edges[&2.into()], HashSet::new());
        assert_eq!(query.stage_graph.child_edges[&3.into()], HashSet::new());

        // Check the mappings of parent edges.
        assert_eq!(query.stage_graph.parent_edges[&0.into()], HashSet::new());
        assert_eq!(
            query.stage_graph.parent_edges[&1.into()],
            HashSet::from_iter([0.into()])
        );
        assert_eq!(
            query.stage_graph.parent_edges[&2.into()],
            HashSet::from_iter([1.into()])
        );
        assert_eq!(
            query.stage_graph.parent_edges[&3.into()],
            HashSet::from_iter([1.into()])
        );

        // Verify topological order.
        {
            let stage_id_to_pos: HashMap<StageId, usize> = query
                .stage_graph
                .stage_ids_by_topo_order()
                .enumerate()
                .map(|(pos, stage_id)| (stage_id, pos))
                .collect();

            for stage_id in query.stage_graph.stages.keys() {
                let stage_pos = stage_id_to_pos[stage_id];
                for child_stage_id in &query.stage_graph.child_edges[stage_id] {
                    let child_pos = stage_id_to_pos[child_stage_id];
                    assert!(stage_pos > child_pos);
                }
            }
        }

        // Check the plan nodes in each stage.
        let root_exchange = query.stage_graph.stages.get(&0.into()).unwrap();
        assert_eq!(
            root_exchange.root.node_type(),
            BatchPlanNodeType::BatchExchange
        );
        assert_eq!(root_exchange.root.source_stage_id, Some(1.into()));
        assert!(matches!(root_exchange.root.node, NodeBody::Exchange(_)));
        assert_eq!(root_exchange.parallelism, Some(1));
        assert!(!root_exchange.has_table_scan());

        let join_node = query.stage_graph.stages.get(&1.into()).unwrap();
        assert_eq!(join_node.root.node_type(), BatchPlanNodeType::BatchHashJoin);
        assert_eq!(join_node.parallelism, Some(24));

        assert!(matches!(join_node.root.node, NodeBody::HashJoin(_)));
        assert_eq!(join_node.root.source_stage_id, None);
        assert_eq!(2, join_node.root.children.len());

        assert!(matches!(
            join_node.root.children[0].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[0].source_stage_id, Some(2.into()));
        assert_eq!(0, join_node.root.children[0].children.len());

        assert!(matches!(
            join_node.root.children[1].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[1].source_stage_id, Some(3.into()));
        assert_eq!(0, join_node.root.children[1].children.len());
        assert!(!join_node.has_table_scan());

        let scan_node1 = query.stage_graph.stages.get(&2.into()).unwrap();
        assert_eq!(scan_node1.root.node_type(), BatchPlanNodeType::BatchSeqScan);
        assert_eq!(scan_node1.root.source_stage_id, None);
        assert_eq!(0, scan_node1.root.children.len());
        assert!(scan_node1.has_table_scan());

        let scan_node2 = query.stage_graph.stages.get(&3.into()).unwrap();
        assert_eq!(scan_node2.root.node_type(), BatchPlanNodeType::BatchFilter);
        assert_eq!(scan_node2.root.source_stage_id, None);
        assert_eq!(1, scan_node2.root.children.len());
        assert!(scan_node2.has_table_scan());
    }
}