risingwave_frontend/scheduler/plan_fragmenter.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cmp::min;
use std::collections::{HashMap, HashSet};
use std::fmt::{Debug, Formatter};
use std::num::NonZeroU64;
use std::sync::Arc;

use anyhow::anyhow;
use async_recursion::async_recursion;
use enum_as_inner::EnumAsInner;
use futures::TryStreamExt;
use iceberg::expr::Predicate as IcebergPredicate;
use itertools::Itertools;
use petgraph::{Directed, Graph};
use pgwire::pg_server::SessionId;
use risingwave_batch::error::BatchError;
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::bitmap::{Bitmap, BitmapBuilder};
use risingwave_common::catalog::{Schema, TableDesc};
use risingwave_common::hash::table_distribution::TableDistribution;
use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
use risingwave_common::util::scan_range::ScanRange;
use risingwave_connector::source::filesystem::opendal_source::opendal_enumerator::OpendalEnumerator;
use risingwave_connector::source::filesystem::opendal_source::{
    OpendalAzblob, OpendalGcs, OpendalS3,
};
use risingwave_connector::source::iceberg::IcebergSplitEnumerator;
use risingwave_connector::source::kafka::KafkaSplitEnumerator;
use risingwave_connector::source::reader::reader::build_opendal_fs_list_for_batch;
use risingwave_connector::source::{
    ConnectorProperties, SourceEnumeratorContext, SplitEnumerator, SplitImpl,
};
use risingwave_pb::batch_plan::iceberg_scan_node::IcebergScanType;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::{ExchangeInfo, ScanRange as ScanRangeProto};
use risingwave_pb::plan_common::Field as PbField;
use risingwave_sqlparser::ast::AsOf;
use serde::Serialize;
use serde::ser::SerializeStruct;
use uuid::Uuid;

use super::SchedulerError;
use crate::catalog::TableId;
use crate::catalog::catalog_service::CatalogReader;
use crate::error::RwError;
use crate::optimizer::PlanRef;
use crate::optimizer::plan_node::generic::{GenericPlanRef, PhysicalPlanRef};
use crate::optimizer::plan_node::utils::to_iceberg_time_travel_as_of;
use crate::optimizer::plan_node::{
    BatchIcebergScan, BatchKafkaScan, BatchSource, PlanNodeId, PlanNodeType,
};
use crate::optimizer::property::Distribution;
use crate::scheduler::SchedulerResult;

#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct QueryId {
    pub id: String,
}

impl std::fmt::Display for QueryId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "QueryId:{}", self.id)
    }
}

pub type StageId = u32;

// Root stage always has only one task.
pub const ROOT_TASK_ID: u64 = 0;
// Root task has only one output.
pub const ROOT_TASK_OUTPUT_ID: u64 = 0;
pub type TaskId = u64;

/// Generated by [`BatchPlanFragmenter`] and used in query execution graph.
#[derive(Clone, Debug)]
pub struct ExecutionPlanNode {
    pub plan_node_id: PlanNodeId,
    pub plan_node_type: PlanNodeType,
    pub node: NodeBody,
    pub schema: Vec<PbField>,

    pub children: Vec<Arc<ExecutionPlanNode>>,

    /// The stage id of the source of the `BatchExchange`.
    /// Used to find the `ExchangeSource` in the scheduler when creating the `PlanNode`.
    ///
    /// `None` when this node is not a `BatchExchange`.
    pub source_stage_id: Option<StageId>,
}

impl Serialize for ExecutionPlanNode {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut state = serializer.serialize_struct("QueryStage", 5)?;
        state.serialize_field("plan_node_id", &self.plan_node_id)?;
        state.serialize_field("plan_node_type", &self.plan_node_type)?;
        state.serialize_field("schema", &self.schema)?;
        state.serialize_field("children", &self.children)?;
        state.serialize_field("source_stage_id", &self.source_stage_id)?;
        state.end()
    }
}

impl TryFrom<PlanRef> for ExecutionPlanNode {
    type Error = SchedulerError;

    fn try_from(plan_node: PlanRef) -> Result<Self, Self::Error> {
        Ok(Self {
            plan_node_id: plan_node.plan_base().id(),
            plan_node_type: plan_node.node_type(),
            node: plan_node.try_to_batch_prost_body()?,
            children: vec![],
            schema: plan_node.schema().to_prost(),
            source_stage_id: None,
        })
    }
}

impl ExecutionPlanNode {
    pub fn node_type(&self) -> PlanNodeType {
        self.plan_node_type
    }
}

/// `BatchPlanFragmenter` splits a query plan into fragments.
pub struct BatchPlanFragmenter {
    query_id: QueryId,
    next_stage_id: StageId,
    worker_node_manager: WorkerNodeSelector,
    catalog_reader: CatalogReader,

    batch_parallelism: usize,
    timezone: String,

    stage_graph_builder: Option<StageGraphBuilder>,
    stage_graph: Option<StageGraph>,
}

impl Default for QueryId {
    fn default() -> Self {
        Self {
            id: Uuid::new_v4().to_string(),
        }
    }
}

impl BatchPlanFragmenter {
    pub fn new(
        worker_node_manager: WorkerNodeSelector,
        catalog_reader: CatalogReader,
        batch_parallelism: Option<NonZeroU64>,
        timezone: String,
        batch_node: PlanRef,
    ) -> SchedulerResult<Self> {
        // If `batch_parallelism` is `None`, it means no limit: use the number of available
        // worker slots as the parallelism.
        // If `batch_parallelism` is `Some(num)`, use min(num, available worker slots)
        // as the parallelism.
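        // For example (illustrative numbers only): with `batch_parallelism = Some(8)` and
        // 4 schedulable worker slots, the effective parallelism is min(8, 4) = 4; with
        // `batch_parallelism = None`, it is simply 4.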
        let batch_parallelism = if let Some(num) = batch_parallelism {
            // can be 0 if no available serving worker
            min(
                num.get() as usize,
                worker_node_manager.schedule_unit_count(),
            )
        } else {
            // can be 0 if no available serving worker
            worker_node_manager.schedule_unit_count()
        };

        let mut plan_fragmenter = Self {
            query_id: Default::default(),
            next_stage_id: 0,
            worker_node_manager,
            catalog_reader,
            batch_parallelism,
            timezone,
            stage_graph_builder: Some(StageGraphBuilder::new(batch_parallelism)),
            stage_graph: None,
        };
        plan_fragmenter.split_into_stage(batch_node)?;
        Ok(plan_fragmenter)
    }

    /// Split the plan into stages, using exchange nodes as the stage boundaries.
    fn split_into_stage(&mut self, batch_node: PlanRef) -> SchedulerResult<()> {
        let root_stage = self.new_stage(
            batch_node,
            Some(Distribution::Single.to_prost(
                1,
                &self.catalog_reader,
                &self.worker_node_manager,
            )?),
        )?;
        self.stage_graph = Some(
            self.stage_graph_builder
                .take()
                .unwrap()
                .build(root_stage.id),
        );
        Ok(())
    }
}

/// The fragmented query generated by [`BatchPlanFragmenter`].
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct Query {
    /// Query id should always be unique.
    pub query_id: QueryId,
    pub stage_graph: StageGraph,
}

impl Query {
    pub fn leaf_stages(&self) -> Vec<StageId> {
        let mut ret_leaf_stages = Vec::new();
        for stage_id in self.stage_graph.stages.keys() {
            if self
                .stage_graph
                .get_child_stages_unchecked(stage_id)
                .is_empty()
            {
                ret_leaf_stages.push(*stage_id);
            }
        }
        ret_leaf_stages
    }

    pub fn get_parents(&self, stage_id: &StageId) -> &HashSet<StageId> {
        self.stage_graph.parent_edges.get(stage_id).unwrap()
    }

    pub fn root_stage_id(&self) -> StageId {
        self.stage_graph.root_stage_id
    }

    pub fn query_id(&self) -> &QueryId {
        &self.query_id
    }

    pub fn stages_with_table_scan(&self) -> HashSet<StageId> {
        self.stage_graph
            .stages
            .iter()
            .filter_map(|(stage_id, stage_query)| {
                if stage_query.has_table_scan() {
                    Some(*stage_id)
                } else {
                    None
                }
            })
            .collect()
    }

    pub fn has_lookup_join_stage(&self) -> bool {
        self.stage_graph
            .stages
            .iter()
            .any(|(_stage_id, stage_query)| stage_query.has_lookup_join())
    }
}

#[derive(Debug, Clone)]
pub enum SourceFetchParameters {
    IcebergSpecificInfo(IcebergSpecificInfo),
    KafkaTimebound {
        lower: Option<i64>,
        upper: Option<i64>,
    },
    Empty,
}

#[derive(Debug, Clone)]
pub struct SourceFetchInfo {
    pub schema: Schema,
    /// These are user-configured connector properties.
    /// e.g. host, username, etc...
    pub connector: ConnectorProperties,
    /// These parameters are internally derived by the plan node.
    /// e.g. predicate pushdown for iceberg, timebound for kafka.
    pub fetch_parameters: SourceFetchParameters,
    pub as_of: Option<AsOf>,
}

#[derive(Debug, Clone)]
pub struct IcebergSpecificInfo {
    pub iceberg_scan_type: IcebergScanType,
    pub predicate: IcebergPredicate,
}

#[derive(Clone, Debug)]
pub enum SourceScanInfo {
    /// Split info that has not been fetched from the external source yet.
    Incomplete(SourceFetchInfo),
    /// Splits that have been fetched from the external source.
    Complete(Vec<SplitImpl>),
}

impl SourceScanInfo {
    pub fn new(fetch_info: SourceFetchInfo) -> Self {
        Self::Incomplete(fetch_info)
    }

    pub async fn complete(
        self,
        batch_parallelism: usize,
        timezone: String,
    ) -> SchedulerResult<Self> {
        let fetch_info = match self {
            SourceScanInfo::Incomplete(fetch_info) => fetch_info,
            SourceScanInfo::Complete(_) => {
                unreachable!("Never call complete when SourceScanInfo is already complete")
            }
        };
        match (fetch_info.connector, fetch_info.fetch_parameters) {
            (
                ConnectorProperties::Kafka(prop),
                SourceFetchParameters::KafkaTimebound { lower, upper },
            ) => {
                let mut kafka_enumerator =
                    KafkaSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;
                let split_info = kafka_enumerator
                    .list_splits_batch(lower, upper)
                    .await?
                    .into_iter()
                    .map(SplitImpl::Kafka)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(split_info))
            }
            (ConnectorProperties::OpendalS3(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalS3> =
                    OpendalEnumerator::new_s3_source(prop.s3_properties, prop.assume_role)?;
                let stream = build_opendal_fs_list_for_batch(lister);

                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res
                    .into_iter()
                    .map(SplitImpl::OpendalS3)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::Gcs(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalGcs> =
                    OpendalEnumerator::new_gcs_source(*prop)?;
                let stream = build_opendal_fs_list_for_batch(lister);
                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res.into_iter().map(SplitImpl::Gcs).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::Azblob(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalAzblob> =
                    OpendalEnumerator::new_azblob_source(*prop)?;
                let stream = build_opendal_fs_list_for_batch(lister);
                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res.into_iter().map(SplitImpl::Azblob).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (
                ConnectorProperties::Iceberg(prop),
                SourceFetchParameters::IcebergSpecificInfo(iceberg_specific_info),
            ) => {
                let iceberg_enumerator =
                    IcebergSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;

                let time_travel_info = to_iceberg_time_travel_as_of(&fetch_info.as_of, &timezone)?;

                let split_info = iceberg_enumerator
                    .list_splits_batch(
                        fetch_info.schema,
                        time_travel_info,
                        batch_parallelism,
                        iceberg_specific_info.iceberg_scan_type,
                        iceberg_specific_info.predicate,
                    )
                    .await?
                    .into_iter()
                    .map(SplitImpl::Iceberg)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(split_info))
            }
            _ => Err(SchedulerError::Internal(anyhow!(
                "Querying this source directly is not supported"
            ))),
        }
    }

    pub fn split_info(&self) -> SchedulerResult<&Vec<SplitImpl>> {
        match self {
            Self::Incomplete(_) => Err(SchedulerError::Internal(anyhow!(
                "Should not get split info from incomplete source scan info"
            ))),
            Self::Complete(split_info) => Ok(split_info),
        }
    }
}

#[derive(Clone, Debug)]
pub struct TableScanInfo {
    /// The name of the table to scan.
    name: String,

    /// Indicates the table partitions to be read by scan tasks. Unnecessary partitions are
    /// already pruned.
    ///
    /// For a singleton table, this field is still `Some`, containing a single partition with
    /// the full vnode bitmap, since we need to know where to schedule the singleton scan task.
    ///
    /// `None` iff the table is a system table.
    partitions: Option<HashMap<WorkerSlotId, TablePartitionInfo>>,
}

impl TableScanInfo {
    /// For normal tables, `partitions` should always be `Some`.
    pub fn new(name: String, partitions: HashMap<WorkerSlotId, TablePartitionInfo>) -> Self {
        Self {
            name,
            partitions: Some(partitions),
        }
    }

    /// For system tables, there is no partition info.
    pub fn system_table(name: String) -> Self {
        Self {
            name,
            partitions: None,
        }
    }

    pub fn name(&self) -> &str {
        self.name.as_ref()
    }

    pub fn partitions(&self) -> Option<&HashMap<WorkerSlotId, TablePartitionInfo>> {
        self.partitions.as_ref()
    }
}

#[derive(Clone, Debug)]
pub struct TablePartitionInfo {
    pub vnode_bitmap: Bitmap,
    pub scan_ranges: Vec<ScanRangeProto>,
}

#[derive(Clone, Debug, EnumAsInner)]
pub enum PartitionInfo {
    Table(TablePartitionInfo),
    Source(Vec<SplitImpl>),
    File(Vec<String>),
}

#[derive(Clone, Debug)]
pub struct FileScanInfo {
    pub file_location: Vec<String>,
}

/// Fragment part of `Query`.
#[derive(Clone)]
pub struct QueryStage {
    pub query_id: QueryId,
    pub id: StageId,
    pub root: Arc<ExecutionPlanNode>,
    pub exchange_info: Option<ExchangeInfo>,
    pub parallelism: Option<u32>,
    /// Indicates whether this stage contains a table scan node and the table's information if so.
    pub table_scan_info: Option<TableScanInfo>,
    pub source_info: Option<SourceScanInfo>,
    pub file_scan_info: Option<FileScanInfo>,
    pub has_lookup_join: bool,
    pub dml_table_id: Option<TableId>,
    pub session_id: SessionId,
    pub batch_enable_distributed_dml: bool,
    /// Used to generate exchange information when completing the source scan information.
    children_exchange_distribution: Option<HashMap<StageId, Distribution>>,
}

impl QueryStage {
    /// If true, this stage contains a table scan executor that creates
    /// Hummock iterators to read data from the table. The iterator is initialized during
    /// the executor building process on the batch execution engine.
    pub fn has_table_scan(&self) -> bool {
        self.table_scan_info.is_some()
    }

    /// If true, this stage contains a lookup join executor.
    /// We need to delay the epoch unpin until the end of the query.
    pub fn has_lookup_join(&self) -> bool {
        self.has_lookup_join
    }

    pub fn clone_with_exchange_info(
        &self,
        exchange_info: Option<ExchangeInfo>,
        parallelism: Option<u32>,
    ) -> Self {
        if let Some(exchange_info) = exchange_info {
            return Self {
                query_id: self.query_id.clone(),
                id: self.id,
                root: self.root.clone(),
                exchange_info: Some(exchange_info),
                parallelism,
                table_scan_info: self.table_scan_info.clone(),
                source_info: self.source_info.clone(),
                file_scan_info: self.file_scan_info.clone(),
                has_lookup_join: self.has_lookup_join,
                dml_table_id: self.dml_table_id,
                session_id: self.session_id,
                batch_enable_distributed_dml: self.batch_enable_distributed_dml,
                children_exchange_distribution: self.children_exchange_distribution.clone(),
            };
        }
        self.clone()
    }

    pub fn clone_with_exchange_info_and_complete_source_info(
        &self,
        exchange_info: Option<ExchangeInfo>,
        source_info: SourceScanInfo,
        task_parallelism: u32,
    ) -> Self {
        assert!(matches!(source_info, SourceScanInfo::Complete(_)));
        let exchange_info = if let Some(exchange_info) = exchange_info {
            Some(exchange_info)
        } else {
            self.exchange_info.clone()
        };
        Self {
            query_id: self.query_id.clone(),
            id: self.id,
            root: self.root.clone(),
            exchange_info,
            parallelism: Some(task_parallelism),
            table_scan_info: self.table_scan_info.clone(),
            source_info: Some(source_info),
            file_scan_info: self.file_scan_info.clone(),
            has_lookup_join: self.has_lookup_join,
            dml_table_id: self.dml_table_id,
            session_id: self.session_id,
            batch_enable_distributed_dml: self.batch_enable_distributed_dml,
            children_exchange_distribution: None,
        }
    }
}

impl Debug for QueryStage {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QueryStage")
            .field("id", &self.id)
            .field("parallelism", &self.parallelism)
            .field("exchange_info", &self.exchange_info)
            .field("has_table_scan", &self.has_table_scan())
            .finish()
    }
}

impl Serialize for QueryStage {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut state = serializer.serialize_struct("QueryStage", 3)?;
        state.serialize_field("root", &self.root)?;
        state.serialize_field("parallelism", &self.parallelism)?;
        state.serialize_field("exchange_info", &self.exchange_info)?;
        state.end()
    }
}

pub type QueryStageRef = Arc<QueryStage>;

struct QueryStageBuilder {
    query_id: QueryId,
    id: StageId,
    root: Option<Arc<ExecutionPlanNode>>,
    parallelism: Option<u32>,
    exchange_info: Option<ExchangeInfo>,

    children_stages: Vec<QueryStageRef>,
    /// See also [`QueryStage::table_scan_info`].
    table_scan_info: Option<TableScanInfo>,
    source_info: Option<SourceScanInfo>,
    file_scan_file: Option<FileScanInfo>,
    has_lookup_join: bool,
    dml_table_id: Option<TableId>,
    session_id: SessionId,
    batch_enable_distributed_dml: bool,

    children_exchange_distribution: HashMap<StageId, Distribution>,
}

impl QueryStageBuilder {
    #[allow(clippy::too_many_arguments)]
    fn new(
        id: StageId,
        query_id: QueryId,
        parallelism: Option<u32>,
        exchange_info: Option<ExchangeInfo>,
        table_scan_info: Option<TableScanInfo>,
        source_info: Option<SourceScanInfo>,
        file_scan_file: Option<FileScanInfo>,
        has_lookup_join: bool,
        dml_table_id: Option<TableId>,
        session_id: SessionId,
        batch_enable_distributed_dml: bool,
    ) -> Self {
        Self {
            query_id,
            id,
            root: None,
            parallelism,
            exchange_info,
            children_stages: vec![],
            table_scan_info,
            source_info,
            file_scan_file,
            has_lookup_join,
            dml_table_id,
            session_id,
            batch_enable_distributed_dml,
            children_exchange_distribution: HashMap::new(),
        }
    }

    fn finish(self, stage_graph_builder: &mut StageGraphBuilder) -> QueryStageRef {
        let children_exchange_distribution = if self.parallelism.is_none() {
            Some(self.children_exchange_distribution)
        } else {
            None
        };
        let stage = Arc::new(QueryStage {
            query_id: self.query_id,
            id: self.id,
            root: self.root.unwrap(),
            exchange_info: self.exchange_info,
            parallelism: self.parallelism,
            table_scan_info: self.table_scan_info,
            source_info: self.source_info,
            file_scan_info: self.file_scan_file,
            has_lookup_join: self.has_lookup_join,
            dml_table_id: self.dml_table_id,
            session_id: self.session_id,
            batch_enable_distributed_dml: self.batch_enable_distributed_dml,
            children_exchange_distribution,
        });

        stage_graph_builder.add_node(stage.clone());
        for child_stage in self.children_stages {
            stage_graph_builder.link_to_child(self.id, child_stage.id);
        }
        stage
    }
}

/// Maintains how stages are connected.
#[derive(Debug, Serialize)]
#[cfg_attr(test, derive(Clone))]
pub struct StageGraph {
    pub root_stage_id: StageId,
    pub stages: HashMap<StageId, QueryStageRef>,
    /// Traversed from top to bottom. Used when splitting the plan into stages.
    child_edges: HashMap<StageId, HashSet<StageId>>,
    /// Traversed from bottom to top. Used when scheduling each stage.
    parent_edges: HashMap<StageId, HashSet<StageId>>,

    batch_parallelism: usize,
}

impl StageGraph {
    pub fn get_child_stages_unchecked(&self, stage_id: &StageId) -> &HashSet<StageId> {
        self.child_edges.get(stage_id).unwrap()
    }

    pub fn get_child_stages(&self, stage_id: &StageId) -> Option<&HashSet<StageId>> {
        self.child_edges.get(stage_id)
    }

    /// Returns stage ids in topological order, s.t. a child stage always appears before its parent.
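    ///
    /// For example (hypothetical stage graph): with edges `0 -> 1` and `1 -> {2, 3}`, the
    /// DFS below pushes parents before children, so the final reversal yields an order
    /// such as `[2, 3, 1, 0]`, where every child comes before its parent.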
    pub fn stage_ids_by_topo_order(&self) -> impl Iterator<Item = StageId> {
        let mut stack = Vec::with_capacity(self.stages.len());
        stack.push(self.root_stage_id);
        let mut ret = Vec::with_capacity(self.stages.len());
        let mut existing = HashSet::with_capacity(self.stages.len());

        while let Some(s) = stack.pop() {
            if !existing.contains(&s) {
                ret.push(s);
                existing.insert(s);
                stack.extend(&self.child_edges[&s]);
            }
        }

        ret.into_iter().rev()
    }

    async fn complete(
        self,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
        timezone: String,
    ) -> SchedulerResult<StageGraph> {
        let mut complete_stages = HashMap::new();
        self.complete_stage(
            self.stages.get(&self.root_stage_id).unwrap().clone(),
            None,
            &mut complete_stages,
            catalog_reader,
            worker_node_manager,
            timezone,
        )
        .await?;
        Ok(StageGraph {
            root_stage_id: self.root_stage_id,
            stages: complete_stages,
            child_edges: self.child_edges,
            parent_edges: self.parent_edges,
            batch_parallelism: self.batch_parallelism,
        })
    }

    #[async_recursion]
    async fn complete_stage(
        &self,
        stage: QueryStageRef,
        exchange_info: Option<ExchangeInfo>,
        complete_stages: &mut HashMap<StageId, QueryStageRef>,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
        timezone: String,
    ) -> SchedulerResult<()> {
        let parallelism = if stage.parallelism.is_some() {
            // If the stage has parallelism, it means it's a complete stage.
            complete_stages.insert(
                stage.id,
                Arc::new(stage.clone_with_exchange_info(exchange_info, stage.parallelism)),
            );
            None
        } else if matches!(stage.source_info, Some(SourceScanInfo::Incomplete(_))) {
            let complete_source_info = stage
                .source_info
                .as_ref()
                .unwrap()
                .clone()
                .complete(self.batch_parallelism, timezone.to_owned())
                .await?;

            // When batch-reading a file source, the number of files involved is typically
            // large. To avoid generating one task per file, the task parallelism is limited
            // here. The minimum `task_parallelism` is 1, and apart from that minimum,
            // `task_parallelism` should not exceed the number of files to read. Therefore,
            // we first take the minimum of the number of files and `self.batch_parallelism / 2`;
            // if the number of files is 0, we set `task_parallelism` to 1.
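            // For example (illustrative numbers only): with `batch_parallelism = 8` and
            // 100 files, task_parallelism = max(min(100, 8 / 2), 1) = 4; with 0 files,
            // task_parallelism = max(min(0, 4), 1) = 1.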

            let task_parallelism = match &stage.source_info {
                Some(SourceScanInfo::Incomplete(source_fetch_info)) => {
                    match source_fetch_info.connector {
                        ConnectorProperties::Gcs(_)
                        | ConnectorProperties::OpendalS3(_)
                        | ConnectorProperties::Azblob(_) => (min(
                            complete_source_info.split_info().unwrap().len() as u32,
                            (self.batch_parallelism / 2) as u32,
                        ))
                        .max(1),
                        _ => complete_source_info.split_info().unwrap().len() as u32,
                    }
                }
                _ => unreachable!(),
            };
            // For file source batch reads, the files to be read are divided into several
            // groups to prevent a single task from taking up too many resources.
            // TODO(wcy-fdu): currently the files are divided into `batch_parallelism / 2`
            // groups; this will be made configurable later.
            let complete_stage = Arc::new(stage.clone_with_exchange_info_and_complete_source_info(
                exchange_info,
                complete_source_info,
                task_parallelism,
            ));
            let parallelism = complete_stage.parallelism;
            complete_stages.insert(stage.id, complete_stage);
            parallelism
        } else {
            assert!(stage.file_scan_info.is_some());
            let parallelism = min(
                self.batch_parallelism / 2,
                stage.file_scan_info.as_ref().unwrap().file_location.len(),
            );
            complete_stages.insert(
                stage.id,
                Arc::new(stage.clone_with_exchange_info(exchange_info, Some(parallelism as u32))),
            );
            None
        };

        for child_stage_id in self.child_edges.get(&stage.id).unwrap_or(&HashSet::new()) {
            let exchange_info = if let Some(parallelism) = parallelism {
                let exchange_distribution = stage
                    .children_exchange_distribution
                    .as_ref()
                    .unwrap()
                    .get(child_stage_id)
                    .expect("Exchange distribution is not consistent with the stage graph");
                Some(exchange_distribution.to_prost(
                    parallelism,
                    catalog_reader,
                    worker_node_manager,
                )?)
            } else {
                None
            };
            self.complete_stage(
                self.stages.get(child_stage_id).unwrap().clone(),
                exchange_info,
                complete_stages,
                catalog_reader,
                worker_node_manager,
                timezone.to_owned(),
            )
            .await?;
        }

        Ok(())
    }

    /// Converts the `StageGraph` into a `petgraph::graph::Graph<String, String>`.
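    ///
    /// A usage sketch (not compiled here; assumes petgraph's `Dot` helper, which renders
    /// graphs whose node and edge weights implement `Display`):
    ///
    /// ```ignore
    /// use petgraph::dot::Dot;
    /// let graph = stage_graph.to_petgraph();
    /// println!("{}", Dot::new(&graph)); // prints the stage graph in Graphviz DOT format
    /// ```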
    pub fn to_petgraph(&self) -> Graph<String, String, Directed> {
        let mut graph = Graph::<String, String, Directed>::new();

        let mut node_indices = HashMap::new();

        // Add all stages as nodes
        for (&stage_id, stage_ref) in self.stages.iter().sorted_by_key(|(id, _)| **id) {
            let node_label = format!("Stage {}: {:?}", stage_id, stage_ref);
            let node_index = graph.add_node(node_label);
            node_indices.insert(stage_id, node_index);
        }

        // Add edges between stages based on child_edges
        for (&parent_id, children) in &self.child_edges {
            if let Some(&parent_index) = node_indices.get(&parent_id) {
                for &child_id in children {
                    if let Some(&child_index) = node_indices.get(&child_id) {
                        // Add an edge from parent to child
                        graph.add_edge(parent_index, child_index, "".to_owned());
                    }
                }
            }
        }

        graph
    }
}

struct StageGraphBuilder {
    stages: HashMap<StageId, QueryStageRef>,
    child_edges: HashMap<StageId, HashSet<StageId>>,
    parent_edges: HashMap<StageId, HashSet<StageId>>,
    batch_parallelism: usize,
}

impl StageGraphBuilder {
    pub fn new(batch_parallelism: usize) -> Self {
        Self {
            stages: HashMap::new(),
            child_edges: HashMap::new(),
            parent_edges: HashMap::new(),
            batch_parallelism,
        }
    }

    pub fn build(self, root_stage_id: StageId) -> StageGraph {
        StageGraph {
            root_stage_id,
            stages: self.stages,
            child_edges: self.child_edges,
            parent_edges: self.parent_edges,
            batch_parallelism: self.batch_parallelism,
        }
    }

    /// Link parent stage and child stage. Maintain the mappings of parent -> child and child ->
    /// parent.
    pub fn link_to_child(&mut self, parent_id: StageId, child_id: StageId) {
        self.child_edges
            .get_mut(&parent_id)
            .unwrap()
            .insert(child_id);
        self.parent_edges
            .get_mut(&child_id)
            .unwrap()
            .insert(parent_id);
    }

    pub fn add_node(&mut self, stage: QueryStageRef) {
        // Insert here so that leaf/root stages also have linkage.
        self.child_edges.insert(stage.id, HashSet::new());
        self.parent_edges.insert(stage.id, HashSet::new());
        self.stages.insert(stage.id, stage);
    }
}

impl BatchPlanFragmenter {
    /// After splitting, the `stage_graph` in the fragmenter may contain stages with incomplete
    /// source info; this function fetches that source info to complete those stages.
    /// Why are the two steps (splitting and `generate_complete_query()`) separate?
    /// Fetching source info is an async operation, so it cannot be done in the
    /// (synchronous) split step.
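    ///
    /// A minimal usage sketch (not compiled here; assumes a fully set-up frontend context
    /// providing the worker manager, catalog reader, and batch plan):
    ///
    /// ```ignore
    /// // Synchronous: splits the plan into stages.
    /// let fragmenter = BatchPlanFragmenter::new(
    ///     worker_node_manager,
    ///     catalog_reader,
    ///     batch_parallelism,
    ///     timezone,
    ///     batch_plan,
    /// )?;
    /// // Async: fetches source split info to complete the stage graph.
    /// let query = fragmenter.generate_complete_query().await?;
    /// ```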
    pub async fn generate_complete_query(self) -> SchedulerResult<Query> {
        let stage_graph = self.stage_graph.unwrap();
        let new_stage_graph = stage_graph
            .complete(
                &self.catalog_reader,
                &self.worker_node_manager,
                self.timezone.to_owned(),
            )
            .await?;
        Ok(Query {
            query_id: self.query_id,
            stage_graph: new_stage_graph,
        })
    }

    fn new_stage(
        &mut self,
        root: PlanRef,
        exchange_info: Option<ExchangeInfo>,
    ) -> SchedulerResult<QueryStageRef> {
        let next_stage_id = self.next_stage_id;
        self.next_stage_id += 1;

        let mut table_scan_info = self.collect_stage_table_scan(root.clone())?;
        // In the current implementation, we can guarantee that each stage has only one table
        // scan (except system tables) or one source.
        let source_info = if table_scan_info.is_none() {
            Self::collect_stage_source(root.clone())?
        } else {
            None
        };

        let file_scan_info = if table_scan_info.is_none() && source_info.is_none() {
            Self::collect_stage_file_scan(root.clone())?
        } else {
            None
        };

        let mut has_lookup_join = false;
        let parallelism = match root.distribution() {
            Distribution::Single => {
                if let Some(info) = &mut table_scan_info {
                    if let Some(partitions) = &mut info.partitions {
                        if partitions.len() != 1 {
                            // This is a rare case, but it is possible for the internal state
                            // of the `Source` operator.
                            tracing::warn!(
                                "The stage has single distribution, but contains a scan of table `{}` with {} partitions. A single random worker will be assigned",
                                info.name,
                                partitions.len()
                            );

                            *partitions = partitions
                                .drain()
                                .take(1)
                                .update(|(_, info)| {
                                    info.vnode_bitmap = Bitmap::ones(info.vnode_bitmap.len());
                                })
                                .collect();
                        }
                    } else {
                        // System table
                    }
                } else if source_info.is_some() {
                    return Err(SchedulerError::Internal(anyhow!(
                        "The stage has single distribution, but contains a source operator"
                    )));
                }
                1
            }
            _ => {
                if let Some(table_scan_info) = &table_scan_info {
                    table_scan_info
                        .partitions
                        .as_ref()
                        .map(|m| m.len())
                        .unwrap_or(1)
                } else if let Some(lookup_join_parallelism) =
                    self.collect_stage_lookup_join_parallelism(root.clone())?
                {
                    has_lookup_join = true;
                    lookup_join_parallelism
                } else if source_info.is_some() {
                    0
                } else if file_scan_info.is_some() {
                    1
                } else {
                    self.batch_parallelism
                }
            }
        };
        if source_info.is_none() && file_scan_info.is_none() && parallelism == 0 {
            return Err(BatchError::EmptyWorkerNodes.into());
        }
        let parallelism = if parallelism == 0 {
            None
        } else {
            Some(parallelism as u32)
        };
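        // A `None` parallelism marks this stage as incomplete; it is filled in later by
        // `StageGraph::complete` once the source split info has been fetched.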
        let dml_table_id = Self::collect_dml_table_id(&root);
        let mut builder = QueryStageBuilder::new(
            next_stage_id,
            self.query_id.clone(),
            parallelism,
            exchange_info,
            table_scan_info,
            source_info,
            file_scan_info,
            has_lookup_join,
            dml_table_id,
            root.ctx().session_ctx().session_id(),
            root.ctx()
                .session_ctx()
                .config()
                .batch_enable_distributed_dml(),
        );

        self.visit_node(root, &mut builder, None)?;

        Ok(builder.finish(self.stage_graph_builder.as_mut().unwrap()))
    }

    fn visit_node(
        &mut self,
        node: PlanRef,
        builder: &mut QueryStageBuilder,
        parent_exec_node: Option<&mut ExecutionPlanNode>,
    ) -> SchedulerResult<()> {
        match node.node_type() {
            PlanNodeType::BatchExchange => {
                self.visit_exchange(node.clone(), builder, parent_exec_node)?;
            }
            _ => {
                let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;

                for child in node.inputs() {
                    self.visit_node(child, builder, Some(&mut execution_plan_node))?;
                }

                if let Some(parent) = parent_exec_node {
                    parent.children.push(Arc::new(execution_plan_node));
                } else {
                    builder.root = Some(Arc::new(execution_plan_node));
                }
            }
        }
        Ok(())
    }

    fn visit_exchange(
        &mut self,
        node: PlanRef,
        builder: &mut QueryStageBuilder,
        parent_exec_node: Option<&mut ExecutionPlanNode>,
    ) -> SchedulerResult<()> {
        let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;
        let child_exchange_info = if let Some(parallelism) = builder.parallelism {
            Some(node.distribution().to_prost(
                parallelism,
                &self.catalog_reader,
                &self.worker_node_manager,
            )?)
        } else {
            None
        };
        let child_stage = self.new_stage(node.inputs()[0].clone(), child_exchange_info)?;
        execution_plan_node.source_stage_id = Some(child_stage.id);
        if builder.parallelism.is_none() {
            builder
                .children_exchange_distribution
                .insert(child_stage.id, node.distribution().clone());
        }

        if let Some(parent) = parent_exec_node {
            parent.children.push(Arc::new(execution_plan_node));
        } else {
            builder.root = Some(Arc::new(execution_plan_node));
        }

        builder.children_stages.push(child_stage);
        Ok(())
    }

    /// Check whether this stage contains a source node.
    /// If so, use a split enumerator to fetch the split info from the external source.
    ///
    /// In the current implementation, we can guarantee that each stage has only one source.
    fn collect_stage_source(node: PlanRef) -> SchedulerResult<Option<SourceScanInfo>> {
        if node.node_type() == PlanNodeType::BatchExchange {
            // Do not visit next stage.
            return Ok(None);
        }

        if let Some(batch_kafka_node) = node.as_batch_kafka_scan() {
            let batch_kafka_scan: &BatchKafkaScan = batch_kafka_node;
            let source_catalog = batch_kafka_scan.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                let timestamp_bound = batch_kafka_scan.kafka_timestamp_range_value();
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: batch_kafka_scan.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::KafkaTimebound {
                        lower: timestamp_bound.0,
                        upper: timestamp_bound.1,
                    },
                    as_of: None,
                })));
            }
        } else if let Some(batch_iceberg_scan) = node.as_batch_iceberg_scan() {
            let batch_iceberg_scan: &BatchIcebergScan = batch_iceberg_scan;
            let source_catalog = batch_iceberg_scan.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                let as_of = batch_iceberg_scan.as_of();
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: batch_iceberg_scan.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::IcebergSpecificInfo(
                        IcebergSpecificInfo {
                            predicate: batch_iceberg_scan.predicate.clone(),
                            iceberg_scan_type: batch_iceberg_scan.iceberg_scan_type(),
                        },
                    ),
                    as_of,
                })));
            }
        } else if let Some(source_node) = node.as_batch_source() {
            // TODO: use specific batch operator instead of batch source.
            let source_node: &BatchSource = source_node;
            let source_catalog = source_node.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                let as_of = source_node.as_of();
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: source_node.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::Empty,
                    as_of,
                })));
            }
        }

        node.inputs()
            .into_iter()
            .find_map(|n| Self::collect_stage_source(n).transpose())
            .transpose()
    }

    fn collect_stage_file_scan(node: PlanRef) -> SchedulerResult<Option<FileScanInfo>> {
        if node.node_type() == PlanNodeType::BatchExchange {
            // Do not visit next stage.
            return Ok(None);
        }

        if let Some(batch_file_scan) = node.as_batch_file_scan() {
            return Ok(Some(FileScanInfo {
                file_location: batch_file_scan.core.file_location().clone(),
            }));
        }

        node.inputs()
            .into_iter()
            .find_map(|n| Self::collect_stage_file_scan(n).transpose())
            .transpose()
    }

    /// Check whether this stage contains a table scan node and the table's information if so.
    ///
    /// If there are multiple scan nodes in this stage, they must have the same distribution but
    /// may have different vnode partitions. We just use the same partition for all the scan nodes.
    fn collect_stage_table_scan(&self, node: PlanRef) -> SchedulerResult<Option<TableScanInfo>> {
        let build_table_scan_info = |name, table_desc: &TableDesc, scan_range| {
            let table_catalog = self
                .catalog_reader
                .read_guard()
                .get_any_table_by_id(&table_desc.table_id)
                .cloned()
                .map_err(RwError::from)?;
            let vnode_mapping = self
                .worker_node_manager
                .fragment_mapping(table_catalog.fragment_id)?;
            let partitions = derive_partitions(scan_range, table_desc, &vnode_mapping)?;
            let info = TableScanInfo::new(name, partitions);
            Ok(Some(info))
        };
        if node.node_type() == PlanNodeType::BatchExchange {
            // Do not visit next stage.
            return Ok(None);
        }
        if let Some(scan_node) = node.as_batch_sys_seq_scan() {
            let name = scan_node.core().table_name.to_owned();
            Ok(Some(TableScanInfo::system_table(name)))
        } else if let Some(scan_node) = node.as_batch_log_seq_scan() {
            build_table_scan_info(
                scan_node.core().table_name.to_owned(),
                &scan_node.core().table_desc,
                &[],
            )
        } else if let Some(scan_node) = node.as_batch_seq_scan() {
            build_table_scan_info(
                scan_node.core().table_name.to_owned(),
                &scan_node.core().table_desc,
                scan_node.scan_ranges(),
            )
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| self.collect_stage_table_scan(n).transpose())
                .transpose()
        }
    }

    /// Returns the dml table id if any.
    fn collect_dml_table_id(node: &PlanRef) -> Option<TableId> {
        if node.node_type() == PlanNodeType::BatchExchange {
            return None;
        }
        if let Some(insert) = node.as_batch_insert() {
            Some(insert.core.table_id)
        } else if let Some(update) = node.as_batch_update() {
            Some(update.core.table_id)
        } else if let Some(delete) = node.as_batch_delete() {
            Some(delete.core.table_id)
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| Self::collect_dml_table_id(&n))
        }
    }

    fn collect_stage_lookup_join_parallelism(
        &self,
        node: PlanRef,
    ) -> SchedulerResult<Option<usize>> {
        if node.node_type() == PlanNodeType::BatchExchange {
            // Do not visit next stage.
            return Ok(None);
        }
        if let Some(lookup_join) = node.as_batch_lookup_join() {
            let table_desc = lookup_join.right_table_desc();
            let table_catalog = self
                .catalog_reader
                .read_guard()
                .get_any_table_by_id(&table_desc.table_id)
                .cloned()
                .map_err(RwError::from)?;
            let vnode_mapping = self
                .worker_node_manager
                .fragment_mapping(table_catalog.fragment_id)?;
            let parallelism = vnode_mapping.iter().sorted().dedup().count();
            Ok(Some(parallelism))
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| self.collect_stage_lookup_join_parallelism(n).transpose())
                .transpose()
        }
    }
}

/// Try to derive the partition to read from the scan range.
/// It can be derived if the value of the distribution key is already known.
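///
/// For example (hypothetical numbers): with 4 vnodes mapped as `[slot0, slot0, slot1, slot1]`,
/// a scan range that fixes the distribution key to a value hashing to vnode 2 yields a single
/// partition on `slot1` with only bit 2 set in its vnode bitmap, while a scan range that does
/// not fix the distribution key is pushed to every worker slot.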
fn derive_partitions(
    scan_ranges: &[ScanRange],
    table_desc: &TableDesc,
    vnode_mapping: &WorkerSlotMapping,
) -> SchedulerResult<HashMap<WorkerSlotId, TablePartitionInfo>> {
    let vnode_mapping = if table_desc.vnode_count != vnode_mapping.len() {
        // The vnode count mismatch occurs only in special cases where a hash-distributed fragment
        // contains singleton internal tables. e.g., the state table of `Source` executors.
        // In this case, we reduce the vnode mapping to a single vnode as only `SINGLETON_VNODE` is used.
        assert!(
            table_desc.vnode_count == 1,
            "fragment vnode count {} does not match table vnode count {}",
            vnode_mapping.len(),
            table_desc.vnode_count,
        );
        &WorkerSlotMapping::new_single(vnode_mapping.iter().next().unwrap())
    } else {
        vnode_mapping
    };
    let vnode_count = vnode_mapping.len();

    let mut partitions: HashMap<WorkerSlotId, (BitmapBuilder, Vec<_>)> = HashMap::new();

    if scan_ranges.is_empty() {
        return Ok(vnode_mapping
            .to_bitmaps()
            .into_iter()
            .map(|(k, vnode_bitmap)| {
                (
                    k,
                    TablePartitionInfo {
                        vnode_bitmap,
                        scan_ranges: vec![],
                    },
                )
            })
            .collect());
    }

    let table_distribution = TableDistribution::new_from_storage_table_desc(
        Some(Bitmap::ones(vnode_count).into()),
        &table_desc.try_to_protobuf()?,
    );

    for scan_range in scan_ranges {
        let vnode = scan_range.try_compute_vnode(&table_distribution);
        match vnode {
            None => {
                // put this scan_range to all partitions
                vnode_mapping.to_bitmaps().into_iter().for_each(
                    |(worker_slot_id, vnode_bitmap)| {
                        let (bitmap, scan_ranges) = partitions
                            .entry(worker_slot_id)
                            .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
                        vnode_bitmap
                            .iter()
                            .enumerate()
                            .for_each(|(vnode, b)| bitmap.set(vnode, b));
                        scan_ranges.push(scan_range.to_protobuf());
                    },
                );
            }
            // scan a single partition
            Some(vnode) => {
                let worker_slot_id = vnode_mapping[vnode];
                let (bitmap, scan_ranges) = partitions
                    .entry(worker_slot_id)
                    .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
                bitmap.set(vnode.to_index(), true);
                scan_ranges.push(scan_range.to_protobuf());
            }
        }
    }

    Ok(partitions
        .into_iter()
        .map(|(k, (bitmap, scan_ranges))| {
            (
                k,
                TablePartitionInfo {
                    vnode_bitmap: bitmap.finish(),
                    scan_ranges,
                },
            )
        })
        .collect())
}

#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use risingwave_pb::batch_plan::plan_node::NodeBody;

    use crate::optimizer::plan_node::PlanNodeType;
    use crate::scheduler::plan_fragmenter::StageId;

    #[tokio::test]
    async fn test_fragmenter() {
        let query = crate::scheduler::distributed::tests::create_query().await;

        assert_eq!(query.stage_graph.root_stage_id, 0);
        assert_eq!(query.stage_graph.stages.len(), 4);

        // Check the mappings of child edges.
        assert_eq!(query.stage_graph.child_edges[&0], [1].into());
        assert_eq!(query.stage_graph.child_edges[&1], [2, 3].into());
        assert_eq!(query.stage_graph.child_edges[&2], HashSet::new());
        assert_eq!(query.stage_graph.child_edges[&3], HashSet::new());

        // Check the mappings of parent edges.
        assert_eq!(query.stage_graph.parent_edges[&0], HashSet::new());
        assert_eq!(query.stage_graph.parent_edges[&1], [0].into());
        assert_eq!(query.stage_graph.parent_edges[&2], [1].into());
        assert_eq!(query.stage_graph.parent_edges[&3], [1].into());

        // Verify topology order
        {
            let stage_id_to_pos: HashMap<StageId, usize> = query
                .stage_graph
                .stage_ids_by_topo_order()
                .enumerate()
                .map(|(pos, stage_id)| (stage_id, pos))
                .collect();

            for stage_id in query.stage_graph.stages.keys() {
                let stage_pos = stage_id_to_pos[stage_id];
                for child_stage_id in &query.stage_graph.child_edges[stage_id] {
                    let child_pos = stage_id_to_pos[child_stage_id];
                    assert!(stage_pos > child_pos);
                }
            }
        }

        // Check the plan nodes in each stage.
        let root_exchange = query.stage_graph.stages.get(&0).unwrap();
        assert_eq!(root_exchange.root.node_type(), PlanNodeType::BatchExchange);
        assert_eq!(root_exchange.root.source_stage_id, Some(1));
        assert!(matches!(root_exchange.root.node, NodeBody::Exchange(_)));
        assert_eq!(root_exchange.parallelism, Some(1));
        assert!(!root_exchange.has_table_scan());

        let join_node = query.stage_graph.stages.get(&1).unwrap();
        assert_eq!(join_node.root.node_type(), PlanNodeType::BatchHashJoin);
        assert_eq!(join_node.parallelism, Some(24));

        assert!(matches!(join_node.root.node, NodeBody::HashJoin(_)));
        assert_eq!(join_node.root.source_stage_id, None);
        assert_eq!(2, join_node.root.children.len());

        assert!(matches!(
            join_node.root.children[0].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[0].source_stage_id, Some(2));
        assert_eq!(0, join_node.root.children[0].children.len());

        assert!(matches!(
            join_node.root.children[1].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[1].source_stage_id, Some(3));
        assert_eq!(0, join_node.root.children[1].children.len());
        assert!(!join_node.has_table_scan());

        let scan_node1 = query.stage_graph.stages.get(&2).unwrap();
        assert_eq!(scan_node1.root.node_type(), PlanNodeType::BatchSeqScan);
        assert_eq!(scan_node1.root.source_stage_id, None);
        assert_eq!(0, scan_node1.root.children.len());
        assert!(scan_node1.has_table_scan());

        let scan_node2 = query.stage_graph.stages.get(&3).unwrap();
        assert_eq!(scan_node2.root.node_type(), PlanNodeType::BatchFilter);
        assert_eq!(scan_node2.root.source_stage_id, None);
        assert_eq!(1, scan_node2.root.children.len());
        assert!(scan_node2.has_table_scan());
    }
}
1468}