risingwave_frontend/scheduler/plan_fragmenter.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cmp::min;
use std::collections::{HashMap, HashSet};
use std::fmt::{Debug, Display, Formatter};
use std::num::NonZeroU64;

use anyhow::anyhow;
use async_recursion::async_recursion;
use enum_as_inner::EnumAsInner;
use futures::TryStreamExt;
use iceberg::expr::Predicate as IcebergPredicate;
use itertools::Itertools;
use petgraph::{Directed, Graph};
use pgwire::pg_server::SessionId;
use risingwave_batch::error::BatchError;
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::bitmap::{Bitmap, BitmapBuilder};
use risingwave_common::catalog::Schema;
use risingwave_common::hash::table_distribution::TableDistribution;
use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
use risingwave_common::util::scan_range::ScanRange;
use risingwave_connector::source::filesystem::opendal_source::opendal_enumerator::OpendalEnumerator;
use risingwave_connector::source::filesystem::opendal_source::{
    BatchPosixFsEnumerator, OpendalAzblob, OpendalGcs, OpendalS3,
};
use risingwave_connector::source::iceberg::IcebergSplitEnumerator;
use risingwave_connector::source::kafka::KafkaSplitEnumerator;
use risingwave_connector::source::prelude::DatagenSplitEnumerator;
use risingwave_connector::source::reader::reader::build_opendal_fs_list_for_batch;
use risingwave_connector::source::{
    ConnectorProperties, SourceEnumeratorContext, SplitEnumerator, SplitImpl,
};
use risingwave_pb::batch_plan::iceberg_scan_node::IcebergScanType;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::{ExchangeInfo, ScanRange as ScanRangeProto};
use risingwave_pb::plan_common::Field as PbField;
use risingwave_sqlparser::ast::AsOf;
use serde::ser::SerializeStruct;
use serde::{Serialize, Serializer};
use uuid::Uuid;

use super::SchedulerError;
use crate::TableCatalog;
use crate::catalog::TableId;
use crate::catalog::catalog_service::CatalogReader;
use crate::optimizer::plan_node::generic::{GenericPlanRef, PhysicalPlanRef};
use crate::optimizer::plan_node::utils::to_iceberg_time_travel_as_of;
use crate::optimizer::plan_node::{
    BatchIcebergScan, BatchKafkaScan, BatchPlanNodeType, BatchPlanRef as PlanRef, BatchSource,
    PlanNodeId,
};
use crate::optimizer::property::Distribution;
use crate::scheduler::SchedulerResult;

#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct QueryId {
    pub id: String,
}

impl std::fmt::Display for QueryId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "QueryId:{}", self.id)
    }
}

#[derive(Copy, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub struct StageId(u32);

impl Display for StageId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl Debug for StageId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self.0)
    }
}

impl From<StageId> for u32 {
    fn from(value: StageId) -> Self {
        value.0
    }
}

impl From<u32> for StageId {
    fn from(value: u32) -> Self {
        StageId(value)
    }
}

impl Serialize for StageId {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        self.0.serialize(serializer)
    }
}

impl StageId {
    pub fn inc(&mut self) {
        self.0 += 1;
    }
}
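
// A minimal, test-only sketch showing how `StageId` round-trips through `u32`
// and increments; nothing here is load-bearing, it just illustrates the
// conversions and `Display` impl defined above.
#[cfg(test)]
mod stage_id_sketch {
    use super::StageId;

    #[test]
    fn round_trip_and_inc() {
        let mut id: StageId = 0.into();
        id.inc();
        assert_eq!(u32::from(id), 1);
        assert_eq!(id.to_string(), "1");
    }
}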

// Root stage always has only one task.
pub const ROOT_TASK_ID: u64 = 0;
// Root task has only one output.
pub const ROOT_TASK_OUTPUT_ID: u64 = 0;
pub type TaskId = u64;

/// Generated by [`BatchPlanFragmenter`] and used in the query execution graph.
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct ExecutionPlanNode {
    pub plan_node_id: PlanNodeId,
    pub plan_node_type: BatchPlanNodeType,
    pub node: NodeBody,
    pub schema: Vec<PbField>,

    pub children: Vec<ExecutionPlanNode>,

    /// The stage id of the source of `BatchExchange`.
    /// Used to find `ExchangeSource` from scheduler when creating `PlanNode`.
    ///
    /// `None` when this node is not `BatchExchange`.
    pub source_stage_id: Option<StageId>,
}

impl Serialize for ExecutionPlanNode {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut state = serializer.serialize_struct("QueryStage", 5)?;
        state.serialize_field("plan_node_id", &self.plan_node_id)?;
        state.serialize_field("plan_node_type", &self.plan_node_type)?;
        state.serialize_field("schema", &self.schema)?;
        state.serialize_field("children", &self.children)?;
        state.serialize_field("source_stage_id", &self.source_stage_id)?;
        state.end()
    }
}
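
// Illustrative note (assuming rendering through a JSON serializer such as
// `serde_json`): a node serializes as `{"plan_node_id": …, "plan_node_type": …,
// "schema": […], "children": […], "source_stage_id": null}`. The prost `node`
// body is intentionally omitted, and the struct name passed to
// `serialize_struct` reuses "QueryStage".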

impl TryFrom<PlanRef> for ExecutionPlanNode {
    type Error = SchedulerError;

    fn try_from(plan_node: PlanRef) -> Result<Self, Self::Error> {
        Ok(Self {
            plan_node_id: plan_node.plan_base().id(),
            plan_node_type: plan_node.node_type(),
            node: plan_node.try_to_batch_prost_body()?,
            children: vec![],
            schema: plan_node.schema().to_prost(),
            source_stage_id: None,
        })
    }
}

impl ExecutionPlanNode {
    pub fn node_type(&self) -> BatchPlanNodeType {
        self.plan_node_type
    }
}

/// `BatchPlanFragmenter` splits a query plan into fragments.
pub struct BatchPlanFragmenter {
    query_id: QueryId,
    next_stage_id: StageId,
    worker_node_manager: WorkerNodeSelector,
    catalog_reader: CatalogReader,

    batch_parallelism: usize,
    timezone: String,

    stage_graph_builder: Option<StageGraphBuilder>,
    stage_graph: Option<StageGraph>,
}

impl Default for QueryId {
    fn default() -> Self {
        Self {
            id: Uuid::new_v4().to_string(),
        }
    }
}

impl BatchPlanFragmenter {
    pub fn new(
        worker_node_manager: WorkerNodeSelector,
        catalog_reader: CatalogReader,
        batch_parallelism: Option<NonZeroU64>,
        timezone: String,
        batch_node: PlanRef,
    ) -> SchedulerResult<Self> {
        // If `batch_parallelism` is `None`, there is no limit and we use the number of
        // available schedule units as the parallelism.
        // If it is `Some(num)`, we use `min(num, available schedule units)` as the
        // parallelism.
        let batch_parallelism = if let Some(num) = batch_parallelism {
            // can be 0 if no available serving worker
            min(
                num.get() as usize,
                worker_node_manager.schedule_unit_count(),
            )
        } else {
            // can be 0 if no available serving worker
            worker_node_manager.schedule_unit_count()
        };
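
        // Worked example (illustrative numbers): with 8 serving schedule units and
        // `batch_parallelism = Some(16)`, the effective parallelism is `min(16, 8) = 8`;
        // with `None` it is 8; with no serving workers it is 0, which later surfaces as
        // `BatchError::EmptyWorkerNodes` for stages that need workers.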

        let mut plan_fragmenter = Self {
            query_id: Default::default(),
            next_stage_id: 0.into(),
            worker_node_manager,
            catalog_reader,
            batch_parallelism,
            timezone,
            stage_graph_builder: Some(StageGraphBuilder::new(batch_parallelism)),
            stage_graph: None,
        };
        plan_fragmenter.split_into_stage(batch_node)?;
        Ok(plan_fragmenter)
    }

    /// Split the plan into stages, cutting at each exchange node.
    fn split_into_stage(&mut self, batch_node: PlanRef) -> SchedulerResult<()> {
        let root_stage_id = self.new_stage(
            batch_node,
            Some(Distribution::Single.to_prost(
                1,
                &self.catalog_reader,
                &self.worker_node_manager,
            )?),
        )?;
        self.stage_graph = Some(
            self.stage_graph_builder
                .take()
                .unwrap()
                .build(root_stage_id),
        );
        Ok(())
    }
}

/// The fragmented query generated by [`BatchPlanFragmenter`].
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct Query {
    /// Query id should always be unique.
    pub query_id: QueryId,
    pub stage_graph: StageGraph,
}

impl Query {
    pub fn leaf_stages(&self) -> Vec<StageId> {
        let mut ret_leaf_stages = Vec::new();
        for stage_id in self.stage_graph.stages.keys() {
            if self
                .stage_graph
                .get_child_stages_unchecked(stage_id)
                .is_empty()
            {
                ret_leaf_stages.push(*stage_id);
            }
        }
        ret_leaf_stages
    }

    pub fn get_parents(&self, stage_id: &StageId) -> &HashSet<StageId> {
        self.stage_graph.parent_edges.get(stage_id).unwrap()
    }

    pub fn root_stage_id(&self) -> StageId {
        self.stage_graph.root_stage_id
    }

    pub fn query_id(&self) -> &QueryId {
        &self.query_id
    }

    pub fn stages_with_table_scan(&self) -> HashSet<StageId> {
        self.stage_graph
            .stages
            .iter()
            .filter_map(|(stage_id, stage_query)| {
                if stage_query.has_table_scan() {
                    Some(*stage_id)
                } else {
                    None
                }
            })
            .collect()
    }

    pub fn has_lookup_join_stage(&self) -> bool {
        self.stage_graph
            .stages
            .iter()
            .any(|(_stage_id, stage_query)| stage_query.has_lookup_join())
    }

    pub fn stage(&self, stage_id: StageId) -> &QueryStage {
        &self.stage_graph.stages[&stage_id]
    }
}

#[derive(Debug, Clone)]
pub enum SourceFetchParameters {
    IcebergSpecificInfo(IcebergSpecificInfo),
    KafkaTimebound {
        lower: Option<i64>,
        upper: Option<i64>,
    },
    Empty,
}

#[derive(Debug, Clone)]
pub struct SourceFetchInfo {
    pub schema: Schema,
    /// These are user-configured connector properties,
    /// e.g. host, username, etc...
    pub connector: ConnectorProperties,
    /// These parameters are internally derived by the plan node,
    /// e.g. predicate pushdown for Iceberg, timebound for Kafka.
    pub fetch_parameters: SourceFetchParameters,
    pub as_of: Option<AsOf>,
}

#[derive(Debug, Clone)]
pub struct IcebergSpecificInfo {
    pub iceberg_scan_type: IcebergScanType,
    pub predicate: IcebergPredicate,
}

#[derive(Clone, Debug)]
pub enum SourceScanInfo {
    /// Split info that has not been fetched from the external source yet.
    Incomplete(SourceFetchInfo),
    Complete(Vec<SplitImpl>),
}

impl SourceScanInfo {
    pub fn new(fetch_info: SourceFetchInfo) -> Self {
        Self::Incomplete(fetch_info)
    }

    pub async fn complete(
        self,
        batch_parallelism: usize,
        timezone: String,
    ) -> SchedulerResult<Self> {
        let fetch_info = match self {
            SourceScanInfo::Incomplete(fetch_info) => fetch_info,
            SourceScanInfo::Complete(_) => {
                unreachable!("Never call complete when SourceScanInfo is already complete")
            }
        };
        match (fetch_info.connector, fetch_info.fetch_parameters) {
            (
                ConnectorProperties::Kafka(prop),
                SourceFetchParameters::KafkaTimebound { lower, upper },
            ) => {
                let mut kafka_enumerator =
                    KafkaSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;
                let split_info = kafka_enumerator
                    .list_splits_batch(lower, upper)
                    .await?
                    .into_iter()
                    .map(SplitImpl::Kafka)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(split_info))
            }
            (ConnectorProperties::Datagen(prop), SourceFetchParameters::Empty) => {
                let mut datagen_enumerator =
                    DatagenSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;
                let split_info = datagen_enumerator.list_splits().await?;
                let res = split_info.into_iter().map(SplitImpl::Datagen).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::OpendalS3(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalS3> = OpendalEnumerator::new_s3_source(
                    &prop.s3_properties,
                    prop.assume_role,
                    prop.fs_common.compression_format,
                )?;
                let stream = build_opendal_fs_list_for_batch(lister);

                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res
                    .into_iter()
                    .map(SplitImpl::OpendalS3)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::Gcs(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalGcs> =
                    OpendalEnumerator::new_gcs_source(*prop)?;
                let stream = build_opendal_fs_list_for_batch(lister);
                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res.into_iter().map(SplitImpl::Gcs).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::Azblob(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalAzblob> =
                    OpendalEnumerator::new_azblob_source(*prop)?;
                let stream = build_opendal_fs_list_for_batch(lister);
                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res.into_iter().map(SplitImpl::Azblob).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::BatchPosixFs(prop), SourceFetchParameters::Empty) => {
                let mut enumerator = BatchPosixFsEnumerator::new(
                    *prop,
                    SourceEnumeratorContext::dummy().into(),
                )
                .await?;
                let splits = enumerator.list_splits().await?;
                let res = splits
                    .into_iter()
                    .map(SplitImpl::BatchPosixFs)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (
                ConnectorProperties::Iceberg(prop),
                SourceFetchParameters::IcebergSpecificInfo(iceberg_specific_info),
            ) => {
                let iceberg_enumerator =
                    IcebergSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;

                let time_travel_info = to_iceberg_time_travel_as_of(&fetch_info.as_of, &timezone)?;

                let split_info = iceberg_enumerator
                    .list_splits_batch(
                        fetch_info.schema,
                        time_travel_info,
                        batch_parallelism,
                        iceberg_specific_info.iceberg_scan_type,
                        iceberg_specific_info.predicate,
                    )
                    .await?
                    .into_iter()
                    .map(SplitImpl::Iceberg)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(split_info))
            }
            (connector, _) => Err(SchedulerError::Internal(anyhow!(
                "Querying directly from the {} source is not supported, \
                 please create a table or streaming job from it",
                connector.kind()
            ))),
        }
    }

    pub fn split_info(&self) -> SchedulerResult<&Vec<SplitImpl>> {
        match self {
            Self::Incomplete(_) => Err(SchedulerError::Internal(anyhow!(
                "Should not get split info from incomplete source scan info"
            ))),
            Self::Complete(split_info) => Ok(split_info),
        }
    }
}
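
// Usage sketch: the fragmenter records `SourceScanInfo::Incomplete(..)` while
// splitting the plan (a synchronous step), and `generate_complete_query` later
// awaits `complete(..)` to turn it into `Complete(splits)`; calling
// `split_info()` before that point returns an internal error.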

#[derive(Clone, Debug)]
pub struct TableScanInfo {
    /// The name of the table to scan.
    name: String,

    /// Indicates the table partitions to be read by scan tasks. Unnecessary partitions are already
    /// pruned.
    ///
    /// For a singleton table, this field is still `Some` and contains a single partition with the
    /// full vnode bitmap, since we need to know where to schedule the singleton scan task.
    ///
    /// `None` iff the table is a system table.
    partitions: Option<HashMap<WorkerSlotId, TablePartitionInfo>>,
}

impl TableScanInfo {
    /// For normal tables, `partitions` should always be `Some`.
    pub fn new(name: String, partitions: HashMap<WorkerSlotId, TablePartitionInfo>) -> Self {
        Self {
            name,
            partitions: Some(partitions),
        }
    }

    /// For a system table, there's no partition info.
    pub fn system_table(name: String) -> Self {
        Self {
            name,
            partitions: None,
        }
    }

    pub fn name(&self) -> &str {
        self.name.as_ref()
    }

    pub fn partitions(&self) -> Option<&HashMap<WorkerSlotId, TablePartitionInfo>> {
        self.partitions.as_ref()
    }
}

#[derive(Clone, Debug)]
pub struct TablePartitionInfo {
    pub vnode_bitmap: Bitmap,
    pub scan_ranges: Vec<ScanRangeProto>,
}

#[derive(Clone, Debug, EnumAsInner)]
pub enum PartitionInfo {
    Table(TablePartitionInfo),
    Source(Vec<SplitImpl>),
    File(Vec<String>),
}

#[derive(Clone, Debug)]
pub struct FileScanInfo {
    pub file_location: Vec<String>,
}

/// Fragment part of `Query`.
#[cfg_attr(test, derive(Clone))]
pub struct QueryStage {
    pub id: StageId,
    pub root: ExecutionPlanNode,
    pub exchange_info: Option<ExchangeInfo>,
    pub parallelism: Option<u32>,
    /// Indicates whether this stage contains a table scan node, and the table's information if so.
    pub table_scan_info: Option<TableScanInfo>,
    pub source_info: Option<SourceScanInfo>,
    pub file_scan_info: Option<FileScanInfo>,
    pub has_lookup_join: bool,
    pub dml_table_id: Option<TableId>,
    pub session_id: SessionId,
    pub batch_enable_distributed_dml: bool,

    /// Used to generate exchange information when completing the source scan information.
    children_exchange_distribution: Option<HashMap<StageId, Distribution>>,
}

impl QueryStage {
    /// If true, this stage contains a table scan executor that creates
    /// Hummock iterators to read data from the table. The iterator is initialized during
    /// the executor building process on the batch execution engine.
    pub fn has_table_scan(&self) -> bool {
        self.table_scan_info.is_some()
    }

    /// If true, this stage contains a lookup join executor.
    /// We need to delay the epoch unpin until the end of the query.
    pub fn has_lookup_join(&self) -> bool {
        self.has_lookup_join
    }

    pub fn with_exchange_info(
        self,
        exchange_info: Option<ExchangeInfo>,
        parallelism: Option<u32>,
    ) -> Self {
        if let Some(exchange_info) = exchange_info {
            Self {
                id: self.id,
                root: self.root,
                exchange_info: Some(exchange_info),
                parallelism,
                table_scan_info: self.table_scan_info,
                source_info: self.source_info,
                file_scan_info: self.file_scan_info,
                has_lookup_join: self.has_lookup_join,
                dml_table_id: self.dml_table_id,
                session_id: self.session_id,
                batch_enable_distributed_dml: self.batch_enable_distributed_dml,
                children_exchange_distribution: self.children_exchange_distribution,
            }
        } else {
            self
        }
    }

    pub fn with_exchange_info_and_complete_source_info(
        self,
        exchange_info: Option<ExchangeInfo>,
        source_info: SourceScanInfo,
        task_parallelism: u32,
    ) -> Self {
        assert!(matches!(source_info, SourceScanInfo::Complete(_)));
        let exchange_info = exchange_info.or(self.exchange_info);
        Self {
            id: self.id,
            root: self.root,
            exchange_info,
            parallelism: Some(task_parallelism),
            table_scan_info: self.table_scan_info,
            source_info: Some(source_info),
            file_scan_info: self.file_scan_info,
            has_lookup_join: self.has_lookup_join,
            dml_table_id: self.dml_table_id,
            session_id: self.session_id,
            batch_enable_distributed_dml: self.batch_enable_distributed_dml,
            children_exchange_distribution: None,
        }
    }
}
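
// Completion paths, in brief: a stage whose parallelism is already known at
// split time is finalized via `with_exchange_info`, while a stage carrying an
// incomplete source scan is finalized via
// `with_exchange_info_and_complete_source_info` once its splits have been
// listed; see `StageGraph::complete_stage` below.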

impl Debug for QueryStage {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QueryStage")
            .field("id", &self.id)
            .field("parallelism", &self.parallelism)
            .field("exchange_info", &self.exchange_info)
            .field("has_table_scan", &self.has_table_scan())
            .finish()
    }
}

impl Serialize for QueryStage {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut state = serializer.serialize_struct("QueryStage", 3)?;
        state.serialize_field("root", &self.root)?;
        state.serialize_field("parallelism", &self.parallelism)?;
        state.serialize_field("exchange_info", &self.exchange_info)?;
        state.end()
    }
}

struct QueryStageBuilder {
    id: StageId,
    root: Option<ExecutionPlanNode>,
    parallelism: Option<u32>,
    exchange_info: Option<ExchangeInfo>,

    children_stages: Vec<StageId>,
    /// See also [`QueryStage::table_scan_info`].
    table_scan_info: Option<TableScanInfo>,
    source_info: Option<SourceScanInfo>,
    file_scan_info: Option<FileScanInfo>,
    has_lookup_join: bool,
    dml_table_id: Option<TableId>,
    session_id: SessionId,
    batch_enable_distributed_dml: bool,

    children_exchange_distribution: HashMap<StageId, Distribution>,
}

impl QueryStageBuilder {
    #[allow(clippy::too_many_arguments)]
    fn new(
        id: StageId,
        parallelism: Option<u32>,
        exchange_info: Option<ExchangeInfo>,
        table_scan_info: Option<TableScanInfo>,
        source_info: Option<SourceScanInfo>,
        file_scan_info: Option<FileScanInfo>,
        has_lookup_join: bool,
        dml_table_id: Option<TableId>,
        session_id: SessionId,
        batch_enable_distributed_dml: bool,
    ) -> Self {
        Self {
            id,
            root: None,
            parallelism,
            exchange_info,
            children_stages: vec![],
            table_scan_info,
            source_info,
            file_scan_info,
            has_lookup_join,
            dml_table_id,
            session_id,
            batch_enable_distributed_dml,
            children_exchange_distribution: HashMap::new(),
        }
    }

    fn finish(self, stage_graph_builder: &mut StageGraphBuilder) -> StageId {
        let children_exchange_distribution = if self.parallelism.is_none() {
            Some(self.children_exchange_distribution)
        } else {
            None
        };
        let stage = QueryStage {
            id: self.id,
            root: self.root.unwrap(),
            exchange_info: self.exchange_info,
            parallelism: self.parallelism,
            table_scan_info: self.table_scan_info,
            source_info: self.source_info,
            file_scan_info: self.file_scan_info,
            has_lookup_join: self.has_lookup_join,
            dml_table_id: self.dml_table_id,
            session_id: self.session_id,
            batch_enable_distributed_dml: self.batch_enable_distributed_dml,
            children_exchange_distribution,
        };

        let stage_id = stage.id;
        stage_graph_builder.add_node(stage);
        for child_stage_id in self.children_stages {
            stage_graph_builder.link_to_child(self.id, child_stage_id);
        }
        stage_id
    }
}

/// Maintains how the stages are connected.
#[derive(Debug, Serialize)]
#[cfg_attr(test, derive(Clone))]
pub struct StageGraph {
    pub root_stage_id: StageId,
    pub stages: HashMap<StageId, QueryStage>,
    /// Traverses from top to bottom. Used when splitting the plan into stages.
    child_edges: HashMap<StageId, HashSet<StageId>>,
    /// Traverses from bottom to top. Used when scheduling each stage.
    parent_edges: HashMap<StageId, HashSet<StageId>>,

    batch_parallelism: usize,
}

enum StageCompleteInfo {
    ExchangeInfo((Option<ExchangeInfo>, Option<u32>)),
    ExchangeWithSourceInfo((Option<ExchangeInfo>, SourceScanInfo, u32)),
}

impl StageGraph {
    pub fn get_child_stages_unchecked(&self, stage_id: &StageId) -> &HashSet<StageId> {
        self.child_edges.get(stage_id).unwrap()
    }

    pub fn get_child_stages(&self, stage_id: &StageId) -> Option<&HashSet<StageId>> {
        self.child_edges.get(stage_id)
    }

    /// Returns stage ids in topological order, s.t. a child stage always appears before its parent.
    pub fn stage_ids_by_topo_order(&self) -> impl Iterator<Item = StageId> {
        let mut stack = Vec::with_capacity(self.stages.len());
        stack.push(self.root_stage_id);
        let mut ret = Vec::with_capacity(self.stages.len());
        let mut existing = HashSet::with_capacity(self.stages.len());

        while let Some(s) = stack.pop() {
            if !existing.contains(&s) {
                ret.push(s);
                existing.insert(s);
                stack.extend(&self.child_edges[&s]);
            }
        }

        ret.into_iter().rev()
    }

    async fn complete(
        self,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
        timezone: String,
    ) -> SchedulerResult<StageGraph> {
        let mut complete_stages = HashMap::new();
        self.complete_stage(
            self.root_stage_id,
            None,
            &mut complete_stages,
            catalog_reader,
            worker_node_manager,
            timezone,
        )
        .await?;
        let mut stages = self.stages;
        Ok(StageGraph {
            root_stage_id: self.root_stage_id,
            stages: complete_stages
                .into_iter()
                .map(|(stage_id, info)| {
                    let stage = stages.remove(&stage_id).expect("should exist");
                    let stage = match info {
                        StageCompleteInfo::ExchangeInfo((exchange_info, parallelism)) => {
                            stage.with_exchange_info(exchange_info, parallelism)
                        }
                        StageCompleteInfo::ExchangeWithSourceInfo((
                            exchange_info,
                            source_info,
                            parallelism,
                        )) => stage.with_exchange_info_and_complete_source_info(
                            exchange_info,
                            source_info,
                            parallelism,
                        ),
                    };
                    (stage_id, stage)
                })
                .collect(),
            child_edges: self.child_edges,
            parent_edges: self.parent_edges,
            batch_parallelism: self.batch_parallelism,
        })
    }

    #[async_recursion]
    async fn complete_stage(
        &self,
        stage_id: StageId,
        exchange_info: Option<ExchangeInfo>,
        complete_stages: &mut HashMap<StageId, StageCompleteInfo>,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
        timezone: String,
    ) -> SchedulerResult<()> {
        let stage = &self.stages[&stage_id];
        let parallelism = if stage.parallelism.is_some() {
            // If the stage already has a parallelism, it is a complete stage.
            complete_stages.insert(
                stage.id,
                StageCompleteInfo::ExchangeInfo((exchange_info, stage.parallelism)),
            );
            None
        } else if matches!(stage.source_info, Some(SourceScanInfo::Incomplete(_))) {
            let complete_source_info = stage
                .source_info
                .as_ref()
                .unwrap()
                .clone()
                .complete(self.batch_parallelism, timezone.clone())
                .await?;

            // For batch reads of a file source, the number of files involved is typically
            // large. To avoid generating a task per file, the task parallelism is limited
            // here: for file sources, we take the minimum of the number of files and
            // (self.batch_parallelism / 2), clamped to at least 1 so that an empty listing
            // still yields one task. Hence `task_parallelism` never exceeds the number of
            // files unless the listing is empty.
            let task_parallelism = match &stage.source_info {
                Some(SourceScanInfo::Incomplete(source_fetch_info)) => {
                    match source_fetch_info.connector {
                        ConnectorProperties::Gcs(_)
                        | ConnectorProperties::OpendalS3(_)
                        | ConnectorProperties::Azblob(_) => (min(
                            complete_source_info.split_info().unwrap().len() as u32,
                            (self.batch_parallelism / 2) as u32,
                        ))
                        .max(1),
                        _ => complete_source_info.split_info().unwrap().len() as u32,
                    }
                }
                _ => unreachable!(),
            };
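            // Worked example (illustrative numbers): with `batch_parallelism = 8` and
            // 100 files on S3, the task parallelism is `max(min(100, 4), 1) = 4`; with an
            // empty listing it falls back to 1. Non-file sources get one task per split.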
            // For file source batch reads, the files are divided into several groups to
            // prevent the tasks from taking up too many resources.
            // todo(wcy-fdu): currently this is half of batch_parallelism groups; make it
            // configurable later.
            let complete_stage_info = StageCompleteInfo::ExchangeWithSourceInfo((
                exchange_info,
                complete_source_info,
                task_parallelism,
            ));
            complete_stages.insert(stage.id, complete_stage_info);
            Some(task_parallelism)
        } else {
            assert!(stage.file_scan_info.is_some());
            let parallelism = min(
                self.batch_parallelism / 2,
                stage.file_scan_info.as_ref().unwrap().file_location.len(),
            );
            complete_stages.insert(
                stage.id,
                StageCompleteInfo::ExchangeInfo((exchange_info, Some(parallelism as u32))),
            );
            None
        };

        for child_stage_id in self
            .child_edges
            .get(&stage.id)
            .map(|edges| edges.iter())
            .into_iter()
            .flatten()
        {
            let exchange_info = if let Some(parallelism) = parallelism {
                let exchange_distribution = stage
                    .children_exchange_distribution
                    .as_ref()
                    .unwrap()
                    .get(child_stage_id)
                    .expect("Exchange distribution is not consistent with the stage graph");
                Some(exchange_distribution.to_prost(
                    parallelism,
                    catalog_reader,
                    worker_node_manager,
                )?)
            } else {
                None
            };
            self.complete_stage(
                *child_stage_id,
                exchange_info,
                complete_stages,
                catalog_reader,
                worker_node_manager,
                timezone.clone(),
            )
            .await?;
        }

        Ok(())
    }

    /// Converts the `StageGraph` into a `petgraph::graph::Graph<String, String>`.
    pub fn to_petgraph(&self) -> Graph<String, String, Directed> {
        let mut graph = Graph::<String, String, Directed>::new();

        let mut node_indices = HashMap::new();

        // Add all stages as nodes.
        for (&stage_id, stage_ref) in self.stages.iter().sorted_by_key(|(id, _)| **id) {
            let node_label = format!("Stage {}: {:?}", stage_id, stage_ref);
            let node_index = graph.add_node(node_label);
            node_indices.insert(stage_id, node_index);
        }

        // Add edges between stages based on child_edges.
        for (&parent_id, children) in &self.child_edges {
            if let Some(&parent_index) = node_indices.get(&parent_id) {
                for &child_id in children {
                    if let Some(&child_index) = node_indices.get(&child_id) {
                        // Add an edge from parent to child.
                        graph.add_edge(parent_index, child_index, "".to_owned());
                    }
                }
            }
        }

        graph
    }
}
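
// A minimal, self-contained sketch (test-only) of the traversal behind
// `stage_ids_by_topo_order`: iterative DFS from the root followed by a
// reversal, so every child appears before its parent. The toy adjacency map
// below is an illustrative assumption, not planner output.
#[cfg(test)]
mod topo_order_sketch {
    use std::collections::{HashMap, HashSet};

    #[test]
    fn children_precede_parents() {
        let child_edges: HashMap<u32, Vec<u32>> =
            HashMap::from([(0, vec![1]), (1, vec![2, 3]), (2, vec![]), (3, vec![])]);
        let mut stack = vec![0u32];
        let mut dfs_order = Vec::new();
        let mut seen = HashSet::new();
        while let Some(s) = stack.pop() {
            if seen.insert(s) {
                dfs_order.push(s);
                stack.extend(&child_edges[&s]);
            }
        }
        // Reversing the DFS visit order puts children before parents.
        let topo: Vec<u32> = dfs_order.into_iter().rev().collect();
        let pos = |id: u32| topo.iter().position(|&x| x == id).unwrap();
        assert!(pos(2) < pos(1) && pos(3) < pos(1) && pos(1) < pos(0));
    }
}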

struct StageGraphBuilder {
    stages: HashMap<StageId, QueryStage>,
    child_edges: HashMap<StageId, HashSet<StageId>>,
    parent_edges: HashMap<StageId, HashSet<StageId>>,
    batch_parallelism: usize,
}

impl StageGraphBuilder {
    pub fn new(batch_parallelism: usize) -> Self {
        Self {
            stages: HashMap::new(),
            child_edges: HashMap::new(),
            parent_edges: HashMap::new(),
            batch_parallelism,
        }
    }

    pub fn build(self, root_stage_id: StageId) -> StageGraph {
        StageGraph {
            root_stage_id,
            stages: self.stages,
            child_edges: self.child_edges,
            parent_edges: self.parent_edges,
            batch_parallelism: self.batch_parallelism,
        }
    }

    /// Links a parent stage and a child stage, maintaining both the parent -> child and the
    /// child -> parent mappings.
    pub fn link_to_child(&mut self, parent_id: StageId, child_id: StageId) {
        self.child_edges
            .get_mut(&parent_id)
            .unwrap()
            .insert(child_id);
        self.parent_edges
            .get_mut(&child_id)
            .unwrap()
            .insert(parent_id);
    }

    pub fn add_node(&mut self, stage: QueryStage) {
        // Insert here so that leaf/root stages also have linkage.
        self.child_edges.insert(stage.id, HashSet::new());
        self.parent_edges.insert(stage.id, HashSet::new());
        self.stages.insert(stage.id, stage);
    }
}
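
// Ordering invariant, for the record: `add_node` must have been called for both
// endpoints before `link_to_child(parent, child)`, since the latter unwraps the
// edge sets that `add_node` creates. The fragmenter upholds this because child
// stages are built (and finished) recursively before the parent's
// `QueryStageBuilder::finish` links to them.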

impl BatchPlanFragmenter {
    /// After splitting, the `stage_graph` in the fragmenter may contain stages with incomplete
    /// source info; this function fetches the source info to complete those stages.
    /// Why are the two steps (`split()` and `generate_complete_query()`) separated?
    /// Fetching source info is an async operation, so it cannot be done in the split step.
    pub async fn generate_complete_query(self) -> SchedulerResult<Query> {
        let stage_graph = self.stage_graph.unwrap();
        let new_stage_graph = stage_graph
            .complete(
                &self.catalog_reader,
                &self.worker_node_manager,
                self.timezone.clone(),
            )
            .await?;
        Ok(Query {
            query_id: self.query_id,
            stage_graph: new_stage_graph,
        })
    }

    fn new_stage(
        &mut self,
        root: PlanRef,
        exchange_info: Option<ExchangeInfo>,
    ) -> SchedulerResult<StageId> {
        let next_stage_id = self.next_stage_id;
        self.next_stage_id.inc();

        let mut table_scan_info = None;
        let mut source_info = None;
        let mut file_scan_info = None;

        // For the current implementation, we can guarantee that each stage has only one table
        // scan (except for system tables) or one source.
        if let Some(info) = self.collect_stage_table_scan(root.clone())? {
            table_scan_info = Some(info);
        } else if let Some(info) = Self::collect_stage_source(root.clone())? {
            source_info = Some(info);
        } else if let Some(info) = Self::collect_stage_file_scan(root.clone())? {
            file_scan_info = Some(info);
        }

        let mut has_lookup_join = false;
        let parallelism = match root.distribution() {
            Distribution::Single => {
                if let Some(info) = &mut table_scan_info {
                    if let Some(partitions) = &mut info.partitions {
                        if partitions.len() != 1 {
                            // This is a rare case, but it's possible for scans on the internal
                            // state of the `Source` operator.
                            tracing::warn!(
                                "The stage has single distribution, but contains a scan of table `{}` with {} partitions. A single random worker will be assigned",
                                info.name,
                                partitions.len()
                            );

                            *partitions = partitions
                                .drain()
                                .take(1)
                                .update(|(_, info)| {
                                    info.vnode_bitmap = Bitmap::ones(info.vnode_bitmap.len());
                                })
                                .collect();
                        }
                    } else {
                        // System table.
                    }
                } else if source_info.is_some() {
                    return Err(SchedulerError::Internal(anyhow!(
                        "The stage has single distribution, but contains a source operator"
                    )));
                }
                1
            }
            _ => {
                if let Some(table_scan_info) = &table_scan_info {
                    table_scan_info
                        .partitions
                        .as_ref()
                        .map(|m| m.len())
                        .unwrap_or(1)
                } else if let Some(lookup_join_parallelism) =
                    self.collect_stage_lookup_join_parallelism(root.clone())?
                {
                    has_lookup_join = true;
                    lookup_join_parallelism
                } else if source_info.is_some() {
                    0
                } else if file_scan_info.is_some() {
                    1
                } else {
                    self.batch_parallelism
                }
            }
        };
        if source_info.is_none() && file_scan_info.is_none() && parallelism == 0 {
            return Err(BatchError::EmptyWorkerNodes.into());
        }
        let parallelism = if parallelism == 0 {
            None
        } else {
            Some(parallelism as u32)
        };
        let dml_table_id = Self::collect_dml_table_id(&root);
        let mut builder = QueryStageBuilder::new(
            next_stage_id,
            parallelism,
            exchange_info,
            table_scan_info,
            source_info,
            file_scan_info,
            has_lookup_join,
            dml_table_id,
            root.ctx().session_ctx().session_id(),
            root.ctx()
                .session_ctx()
                .config()
                .batch_enable_distributed_dml(),
        );

        self.visit_node(root, &mut builder, None)?;

        Ok(builder.finish(self.stage_graph_builder.as_mut().unwrap()))
    }

    fn visit_node(
        &mut self,
        node: PlanRef,
        builder: &mut QueryStageBuilder,
        parent_exec_node: Option<&mut ExecutionPlanNode>,
    ) -> SchedulerResult<()> {
        match node.node_type() {
            BatchPlanNodeType::BatchExchange => {
                self.visit_exchange(node, builder, parent_exec_node)?;
            }
            _ => {
                let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;

                for child in node.inputs() {
                    self.visit_node(child, builder, Some(&mut execution_plan_node))?;
                }

                if let Some(parent) = parent_exec_node {
                    parent.children.push(execution_plan_node);
                } else {
                    builder.root = Some(execution_plan_node);
                }
            }
        }
        Ok(())
    }

    fn visit_exchange(
        &mut self,
        node: PlanRef,
        builder: &mut QueryStageBuilder,
        parent_exec_node: Option<&mut ExecutionPlanNode>,
    ) -> SchedulerResult<()> {
        let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;
        let child_exchange_info = if let Some(parallelism) = builder.parallelism {
            Some(node.distribution().to_prost(
                parallelism,
                &self.catalog_reader,
                &self.worker_node_manager,
            )?)
        } else {
            None
        };
        let child_stage_id = self.new_stage(node.inputs()[0].clone(), child_exchange_info)?;
        execution_plan_node.source_stage_id = Some(child_stage_id);
        if builder.parallelism.is_none() {
            builder
                .children_exchange_distribution
                .insert(child_stage_id, node.distribution().clone());
        }

        if let Some(parent) = parent_exec_node {
            parent.children.push(execution_plan_node);
        } else {
            builder.root = Some(execution_plan_node);
        }

        builder.children_stages.push(child_stage_id);
        Ok(())
    }
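
    // Note on the traversal above: each `BatchExchange` becomes a stage
    // boundary. The exchange node itself stays in the parent stage and records
    // `source_stage_id`, which the scheduler later uses to wire up the
    // corresponding `ExchangeSource`.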

    /// Checks whether this stage contains a source node.
    /// If so, uses `SplitEnumeratorImpl` to get the split info from the external source.
    ///
    /// For the current implementation, we can guarantee that each stage has only one source.
    fn collect_stage_source(node: PlanRef) -> SchedulerResult<Option<SourceScanInfo>> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            // Do not visit the next stage.
            return Ok(None);
        }

        if let Some(batch_kafka_node) = node.as_batch_kafka_scan() {
            let batch_kafka_scan: &BatchKafkaScan = batch_kafka_node;
            let source_catalog = batch_kafka_scan.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                let timestamp_bound = batch_kafka_scan.kafka_timestamp_range_value();
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: batch_kafka_scan.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::KafkaTimebound {
                        lower: timestamp_bound.0,
                        upper: timestamp_bound.1,
                    },
                    as_of: None,
                })));
            }
        } else if let Some(batch_iceberg_scan) = node.as_batch_iceberg_scan() {
            let batch_iceberg_scan: &BatchIcebergScan = batch_iceberg_scan;
            let source_catalog = batch_iceberg_scan.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                let as_of = batch_iceberg_scan.as_of();
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: batch_iceberg_scan.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::IcebergSpecificInfo(
                        IcebergSpecificInfo {
                            predicate: batch_iceberg_scan.predicate.clone(),
                            iceberg_scan_type: batch_iceberg_scan.iceberg_scan_type(),
                        },
                    ),
                    as_of,
                })));
            }
        } else if let Some(source_node) = node.as_batch_source() {
            // TODO: use a specific batch operator instead of batch source.
            let source_node: &BatchSource = source_node;
            let source_catalog = source_node.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                let as_of = source_node.as_of();
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: source_node.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::Empty,
                    as_of,
                })));
            }
        }

        node.inputs()
            .into_iter()
            .find_map(|n| Self::collect_stage_source(n).transpose())
            .transpose()
    }
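
    // Traversal idiom used by the `collect_*` helpers: descend into the inputs
    // with `find_map(|n| f(n).transpose()).transpose()`, which stops at the
    // first `Some` result or the first error, and never crosses a
    // `BatchExchange` (i.e. a stage boundary).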

    fn collect_stage_file_scan(node: PlanRef) -> SchedulerResult<Option<FileScanInfo>> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            // Do not visit the next stage.
            return Ok(None);
        }

        if let Some(batch_file_scan) = node.as_batch_file_scan() {
            return Ok(Some(FileScanInfo {
                file_location: batch_file_scan.core.file_location(),
            }));
        }

        node.inputs()
            .into_iter()
            .find_map(|n| Self::collect_stage_file_scan(n).transpose())
            .transpose()
    }

    /// Checks whether this stage contains a table scan node, and returns the table's information
    /// if so.
    ///
    /// If there are multiple scan nodes in this stage, they must have the same distribution, but
    /// may have different vnode partitions. We just use the same partition for all the scan nodes.
    fn collect_stage_table_scan(&self, node: PlanRef) -> SchedulerResult<Option<TableScanInfo>> {
        let build_table_scan_info = |name, table_catalog: &TableCatalog, scan_range| {
            let vnode_mapping = self
                .worker_node_manager
                .fragment_mapping(table_catalog.fragment_id)?;
            let partitions = derive_partitions(scan_range, table_catalog, &vnode_mapping)?;
            let info = TableScanInfo::new(name, partitions);
            Ok(Some(info))
        };
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            // Do not visit the next stage.
            return Ok(None);
        }
        if let Some(scan_node) = node.as_batch_sys_seq_scan() {
            let name = scan_node.core().table.name.clone();
            Ok(Some(TableScanInfo::system_table(name)))
        } else if let Some(scan_node) = node.as_batch_log_seq_scan() {
            build_table_scan_info(
                scan_node.core().table_name.clone(),
                &scan_node.core().table,
                &[],
            )
        } else if let Some(scan_node) = node.as_batch_seq_scan() {
            build_table_scan_info(
                scan_node.core().table_name().to_owned(),
                &scan_node.core().table_catalog,
                scan_node.scan_ranges(),
            )
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| self.collect_stage_table_scan(n).transpose())
                .transpose()
        }
    }

    /// Returns the DML table id if any.
    fn collect_dml_table_id(node: &PlanRef) -> Option<TableId> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            return None;
        }
        if let Some(insert) = node.as_batch_insert() {
            Some(insert.core.table_id)
        } else if let Some(update) = node.as_batch_update() {
            Some(update.core.table_id)
        } else if let Some(delete) = node.as_batch_delete() {
            Some(delete.core.table_id)
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| Self::collect_dml_table_id(&n))
        }
    }

    fn collect_stage_lookup_join_parallelism(
        &self,
        node: PlanRef,
    ) -> SchedulerResult<Option<usize>> {
        if node.node_type() == BatchPlanNodeType::BatchExchange {
            // Do not visit the next stage.
            return Ok(None);
        }
        if let Some(lookup_join) = node.as_batch_lookup_join() {
            let table_catalog = lookup_join.right_table();
            let vnode_mapping = self
                .worker_node_manager
                .fragment_mapping(table_catalog.fragment_id)?;
            let parallelism = vnode_mapping.iter().sorted().dedup().count();
            Ok(Some(parallelism))
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| self.collect_stage_lookup_join_parallelism(n).transpose())
                .transpose()
        }
    }
}

/// Tries to derive the partitions to read from the scan ranges.
/// A single partition can be derived when the value of the distribution key is already known.
fn derive_partitions(
    scan_ranges: &[ScanRange],
    table_catalog: &TableCatalog,
    vnode_mapping: &WorkerSlotMapping,
) -> SchedulerResult<HashMap<WorkerSlotId, TablePartitionInfo>> {
    let vnode_mapping = if table_catalog.vnode_count.value() != vnode_mapping.len() {
        // The vnode count mismatch occurs only in special cases where a hash-distributed fragment
        // contains singleton internal tables, e.g., the state table of `Source` executors.
        // In this case, we reduce the vnode mapping to a single vnode, as only `SINGLETON_VNODE`
        // is used.
        assert_eq!(
            table_catalog.vnode_count.value(),
            1,
            "fragment vnode count {} does not match table vnode count {}",
            vnode_mapping.len(),
            table_catalog.vnode_count.value(),
        );
        &WorkerSlotMapping::new_single(vnode_mapping.iter().next().unwrap())
    } else {
        vnode_mapping
    };
    let vnode_count = vnode_mapping.len();

    let mut partitions: HashMap<WorkerSlotId, (BitmapBuilder, Vec<_>)> = HashMap::new();

    if scan_ranges.is_empty() {
        return Ok(vnode_mapping
            .to_bitmaps()
            .into_iter()
            .map(|(k, vnode_bitmap)| {
                (
                    k,
                    TablePartitionInfo {
                        vnode_bitmap,
                        scan_ranges: vec![],
                    },
                )
            })
            .collect());
    }

    let table_distribution = TableDistribution::new_from_storage_table_desc(
        Some(Bitmap::ones(vnode_count).into()),
        &table_catalog.table_desc().try_to_protobuf()?,
    );

    for scan_range in scan_ranges {
        let vnode = scan_range.try_compute_vnode(&table_distribution);
        match vnode {
            None => {
                // The vnode is unknown: put this scan range into all partitions.
                vnode_mapping.to_bitmaps().into_iter().for_each(
                    |(worker_slot_id, vnode_bitmap)| {
                        let (bitmap, scan_ranges) = partitions
                            .entry(worker_slot_id)
                            .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
                        vnode_bitmap
                            .iter()
                            .enumerate()
                            .for_each(|(vnode, b)| bitmap.set(vnode, b));
                        scan_ranges.push(scan_range.to_protobuf());
                    },
                );
            }
            // The vnode is known: scan a single partition.
            Some(vnode) => {
                let worker_slot_id = vnode_mapping[vnode];
                let (bitmap, scan_ranges) = partitions
                    .entry(worker_slot_id)
                    .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
                bitmap.set(vnode.to_index(), true);
                scan_ranges.push(scan_range.to_protobuf());
            }
        }
    }

    Ok(partitions
        .into_iter()
        .map(|(k, (bitmap, scan_ranges))| {
            (
                k,
                TablePartitionInfo {
                    vnode_bitmap: bitmap.finish(),
                    scan_ranges,
                },
            )
        })
        .collect())
}
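
// A minimal, test-only sketch of the vnode-bitmap bookkeeping in
// `derive_partitions` above: a scan range with a known vnode sets exactly one
// bit in its partition's bitmap, while an unknown vnode fans the range out to
// every partition. The vnode count of 4 is an illustrative assumption.
#[cfg(test)]
mod derive_partitions_sketch {
    use risingwave_common::bitmap::BitmapBuilder;

    #[test]
    fn known_vnode_sets_a_single_bit() {
        let vnode_count = 4;
        let mut builder = BitmapBuilder::zeroed(vnode_count);
        // Suppose `try_compute_vnode` resolved the scan range to vnode 2.
        builder.set(2, true);
        let bitmap = builder.finish();
        assert_eq!(bitmap.count_ones(), 1);
        assert!(bitmap.is_set(2));
    }
}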

#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use risingwave_pb::batch_plan::plan_node::NodeBody;

    use crate::optimizer::plan_node::BatchPlanNodeType;
    use crate::scheduler::plan_fragmenter::StageId;

    #[tokio::test]
    async fn test_fragmenter() {
        let query = crate::scheduler::distributed::tests::create_query().await;

        assert_eq!(query.stage_graph.root_stage_id, 0.into());
        assert_eq!(query.stage_graph.stages.len(), 4);

        // Check the mappings of child edges.
        assert_eq!(
            query.stage_graph.child_edges[&0.into()],
            HashSet::from_iter([1.into()])
        );
        assert_eq!(
            query.stage_graph.child_edges[&1.into()],
            HashSet::from_iter([2.into(), 3.into()])
        );
        assert_eq!(query.stage_graph.child_edges[&2.into()], HashSet::new());
        assert_eq!(query.stage_graph.child_edges[&3.into()], HashSet::new());

        // Check the mappings of parent edges.
        assert_eq!(query.stage_graph.parent_edges[&0.into()], HashSet::new());
        assert_eq!(
            query.stage_graph.parent_edges[&1.into()],
            HashSet::from_iter([0.into()])
        );
        assert_eq!(
            query.stage_graph.parent_edges[&2.into()],
            HashSet::from_iter([1.into()])
        );
        assert_eq!(
            query.stage_graph.parent_edges[&3.into()],
            HashSet::from_iter([1.into()])
        );

        // Verify the topological order.
        {
            let stage_id_to_pos: HashMap<StageId, usize> = query
                .stage_graph
                .stage_ids_by_topo_order()
                .enumerate()
                .map(|(pos, stage_id)| (stage_id, pos))
                .collect();

            for stage_id in query.stage_graph.stages.keys() {
                let stage_pos = stage_id_to_pos[stage_id];
                for child_stage_id in &query.stage_graph.child_edges[stage_id] {
                    let child_pos = stage_id_to_pos[child_stage_id];
                    assert!(stage_pos > child_pos);
                }
            }
        }

        // Check the plan node in each stage.
        let root_exchange = query.stage_graph.stages.get(&0.into()).unwrap();
        assert_eq!(
            root_exchange.root.node_type(),
            BatchPlanNodeType::BatchExchange
        );
        assert_eq!(root_exchange.root.source_stage_id, Some(1.into()));
        assert!(matches!(root_exchange.root.node, NodeBody::Exchange(_)));
        assert_eq!(root_exchange.parallelism, Some(1));
        assert!(!root_exchange.has_table_scan());

        let join_node = query.stage_graph.stages.get(&1.into()).unwrap();
        assert_eq!(join_node.root.node_type(), BatchPlanNodeType::BatchHashJoin);
        assert_eq!(join_node.parallelism, Some(24));

        assert!(matches!(join_node.root.node, NodeBody::HashJoin(_)));
        assert_eq!(join_node.root.source_stage_id, None);
        assert_eq!(2, join_node.root.children.len());

        assert!(matches!(
            join_node.root.children[0].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[0].source_stage_id, Some(2.into()));
        assert_eq!(0, join_node.root.children[0].children.len());

        assert!(matches!(
            join_node.root.children[1].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[1].source_stage_id, Some(3.into()));
        assert_eq!(0, join_node.root.children[1].children.len());
        assert!(!join_node.has_table_scan());

        let scan_node1 = query.stage_graph.stages.get(&2.into()).unwrap();
        assert_eq!(scan_node1.root.node_type(), BatchPlanNodeType::BatchSeqScan);
        assert_eq!(scan_node1.root.source_stage_id, None);
        assert_eq!(0, scan_node1.root.children.len());
        assert!(scan_node1.has_table_scan());

        let scan_node2 = query.stage_graph.stages.get(&3.into()).unwrap();
        assert_eq!(scan_node2.root.node_type(), BatchPlanNodeType::BatchFilter);
        assert_eq!(scan_node2.root.source_stage_id, None);
        assert_eq!(1, scan_node2.root.children.len());
        assert!(scan_node2.has_table_scan());
    }
}