use std::cmp::min;
use std::collections::{HashMap, HashSet};
use std::fmt::{Debug, Formatter};
use std::num::NonZeroU64;
use std::sync::Arc;

use anyhow::anyhow;
use async_recursion::async_recursion;
use enum_as_inner::EnumAsInner;
use futures::TryStreamExt;
use iceberg::expr::Predicate as IcebergPredicate;
use itertools::Itertools;
use petgraph::{Directed, Graph};
use pgwire::pg_server::SessionId;
use risingwave_batch::error::BatchError;
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::bitmap::{Bitmap, BitmapBuilder};
use risingwave_common::catalog::{Schema, TableDesc};
use risingwave_common::hash::table_distribution::TableDistribution;
use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
use risingwave_common::util::scan_range::ScanRange;
use risingwave_connector::source::filesystem::opendal_source::opendal_enumerator::OpendalEnumerator;
use risingwave_connector::source::filesystem::opendal_source::{
    OpendalAzblob, OpendalGcs, OpendalS3,
};
use risingwave_connector::source::iceberg::IcebergSplitEnumerator;
use risingwave_connector::source::kafka::KafkaSplitEnumerator;
use risingwave_connector::source::reader::reader::build_opendal_fs_list_for_batch;
use risingwave_connector::source::{
    ConnectorProperties, SourceEnumeratorContext, SplitEnumerator, SplitImpl,
};
use risingwave_pb::batch_plan::iceberg_scan_node::IcebergScanType;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::{ExchangeInfo, ScanRange as ScanRangeProto};
use risingwave_pb::plan_common::Field as PbField;
use risingwave_sqlparser::ast::AsOf;
use serde::Serialize;
use serde::ser::SerializeStruct;
use uuid::Uuid;

use super::SchedulerError;
use crate::catalog::TableId;
use crate::catalog::catalog_service::CatalogReader;
use crate::error::RwError;
use crate::optimizer::PlanRef;
use crate::optimizer::plan_node::generic::{GenericPlanRef, PhysicalPlanRef};
use crate::optimizer::plan_node::utils::to_iceberg_time_travel_as_of;
use crate::optimizer::plan_node::{
    BatchIcebergScan, BatchKafkaScan, BatchSource, PlanNodeId, PlanNodeType,
};
use crate::optimizer::property::Distribution;
use crate::scheduler::SchedulerResult;

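/// The unique identifier of a batch query, by default a freshly generated UUID string.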
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct QueryId {
    pub id: String,
}

impl std::fmt::Display for QueryId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "QueryId:{}", self.id)
    }
}

pub type StageId = u32;

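/// The root stage is always executed by a single task, with this reserved task id.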
pub const ROOT_TASK_ID: u64 = 0;
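/// The root task produces a single output, with this reserved output id.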
pub const ROOT_TASK_OUTPUT_ID: u64 = 0;
pub type TaskId = u64;

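/// A node in the execution plan tree of a stage, converted from the optimizer's
/// `PlanRef`.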
#[derive(Clone, Debug)]
pub struct ExecutionPlanNode {
    pub plan_node_id: PlanNodeId,
    pub plan_node_type: PlanNodeType,
    pub node: NodeBody,
    pub schema: Vec<PbField>,

    pub children: Vec<Arc<ExecutionPlanNode>>,

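    /// For a `BatchExchange` node, the id of the child stage it reads from; used to
    /// locate the exchange source when building the task plan.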
    pub source_stage_id: Option<StageId>,
}

impl Serialize for ExecutionPlanNode {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut state = serializer.serialize_struct("ExecutionPlanNode", 5)?;
        state.serialize_field("plan_node_id", &self.plan_node_id)?;
        state.serialize_field("plan_node_type", &self.plan_node_type)?;
        state.serialize_field("schema", &self.schema)?;
        state.serialize_field("children", &self.children)?;
        state.serialize_field("source_stage_id", &self.source_stage_id)?;
        state.end()
    }
}

impl TryFrom<PlanRef> for ExecutionPlanNode {
    type Error = SchedulerError;

    fn try_from(plan_node: PlanRef) -> Result<Self, Self::Error> {
        Ok(Self {
            plan_node_id: plan_node.plan_base().id(),
            plan_node_type: plan_node.node_type(),
            node: plan_node.try_to_batch_prost_body()?,
            children: vec![],
            schema: plan_node.schema().to_prost(),
            source_stage_id: None,
        })
    }
}

impl ExecutionPlanNode {
    pub fn node_type(&self) -> PlanNodeType {
        self.plan_node_type
    }
}

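/// Splits a batch physical plan into a DAG of stages, cutting the plan tree at each
/// `BatchExchange` node.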
pub struct BatchPlanFragmenter {
    query_id: QueryId,
    next_stage_id: StageId,
    worker_node_manager: WorkerNodeSelector,
    catalog_reader: CatalogReader,

    batch_parallelism: usize,
    timezone: String,

    stage_graph_builder: Option<StageGraphBuilder>,
    stage_graph: Option<StageGraph>,
}

impl Default for QueryId {
    fn default() -> Self {
        Self {
            id: Uuid::new_v4().to_string(),
        }
    }
}

impl BatchPlanFragmenter {
    pub fn new(
        worker_node_manager: WorkerNodeSelector,
        catalog_reader: CatalogReader,
        batch_parallelism: Option<NonZeroU64>,
        timezone: String,
        batch_node: PlanRef,
    ) -> SchedulerResult<Self> {
        let batch_parallelism = if let Some(num) = batch_parallelism {
            min(
                num.get() as usize,
                worker_node_manager.schedule_unit_count(),
            )
        } else {
            worker_node_manager.schedule_unit_count()
        };

        let mut plan_fragmenter = Self {
            query_id: Default::default(),
            next_stage_id: 0,
            worker_node_manager,
            catalog_reader,
            batch_parallelism,
            timezone,
            stage_graph_builder: Some(StageGraphBuilder::new(batch_parallelism)),
            stage_graph: None,
        };
        plan_fragmenter.split_into_stage(batch_node)?;
        Ok(plan_fragmenter)
    }

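    /// Splits the plan into stages, recursing from the root; the returned stage is the
    /// root stage of the graph.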
    fn split_into_stage(&mut self, batch_node: PlanRef) -> SchedulerResult<()> {
        let root_stage = self.new_stage(
            batch_node,
            Some(Distribution::Single.to_prost(
                1,
                &self.catalog_reader,
                &self.worker_node_manager,
            )?),
        )?;
        self.stage_graph = Some(
            self.stage_graph_builder
                .take()
                .unwrap()
                .build(root_stage.id),
        );
        Ok(())
    }
}

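/// A fragmented batch query: the stage DAG together with its query id.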
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct Query {
    pub query_id: QueryId,
    pub stage_graph: StageGraph,
}

impl Query {
    pub fn leaf_stages(&self) -> Vec<StageId> {
        let mut ret_leaf_stages = Vec::new();
        for stage_id in self.stage_graph.stages.keys() {
            if self
                .stage_graph
                .get_child_stages_unchecked(stage_id)
                .is_empty()
            {
                ret_leaf_stages.push(*stage_id);
            }
        }
        ret_leaf_stages
    }

    pub fn get_parents(&self, stage_id: &StageId) -> &HashSet<StageId> {
        self.stage_graph.parent_edges.get(stage_id).unwrap()
    }

    pub fn root_stage_id(&self) -> StageId {
        self.stage_graph.root_stage_id
    }

    pub fn query_id(&self) -> &QueryId {
        &self.query_id
    }

    pub fn stages_with_table_scan(&self) -> HashSet<StageId> {
        self.stage_graph
            .stages
            .iter()
            .filter_map(|(stage_id, stage_query)| {
                if stage_query.has_table_scan() {
                    Some(*stage_id)
                } else {
                    None
                }
            })
            .collect()
    }

    pub fn has_lookup_join_stage(&self) -> bool {
        self.stage_graph
            .stages
            .iter()
            .any(|(_stage_id, stage_query)| stage_query.has_lookup_join())
    }
}

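/// Connector-specific parameters needed to list the splits of a batch source scan.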
#[derive(Debug, Clone)]
pub enum SourceFetchParameters {
    IcebergSpecificInfo(IcebergSpecificInfo),
    KafkaTimebound {
        lower: Option<i64>,
        upper: Option<i64>,
    },
    Empty,
}

#[derive(Debug, Clone)]
pub struct SourceFetchInfo {
    pub schema: Schema,
    pub connector: ConnectorProperties,
    pub fetch_parameters: SourceFetchParameters,
    pub as_of: Option<AsOf>,
}

#[derive(Debug, Clone)]
pub struct IcebergSpecificInfo {
    pub iceberg_scan_type: IcebergScanType,
    pub predicate: IcebergPredicate,
}

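/// Split information of a source scan: `Incomplete` right after planning, `Complete`
/// once the actual splits have been listed from the connector.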
#[derive(Clone, Debug)]
pub enum SourceScanInfo {
    Incomplete(SourceFetchInfo),
    Complete(Vec<SplitImpl>),
}

impl SourceScanInfo {
    pub fn new(fetch_info: SourceFetchInfo) -> Self {
        Self::Incomplete(fetch_info)
    }

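    /// Lists the splits from the connector, turning an `Incomplete` scan info into a
    /// `Complete` one. Must not be called on an already complete scan info.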
    pub async fn complete(
        self,
        batch_parallelism: usize,
        timezone: String,
    ) -> SchedulerResult<Self> {
        let fetch_info = match self {
            SourceScanInfo::Incomplete(fetch_info) => fetch_info,
            SourceScanInfo::Complete(_) => {
                unreachable!("Never call complete when SourceScanInfo is already complete")
            }
        };
        match (fetch_info.connector, fetch_info.fetch_parameters) {
            (
                ConnectorProperties::Kafka(prop),
                SourceFetchParameters::KafkaTimebound { lower, upper },
            ) => {
                let mut kafka_enumerator =
                    KafkaSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;
                let split_info = kafka_enumerator
                    .list_splits_batch(lower, upper)
                    .await?
                    .into_iter()
                    .map(SplitImpl::Kafka)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(split_info))
            }
            (ConnectorProperties::OpendalS3(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalS3> =
                    OpendalEnumerator::new_s3_source(prop.s3_properties, prop.assume_role)?;
                let stream = build_opendal_fs_list_for_batch(lister);

                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res
                    .into_iter()
                    .map(SplitImpl::OpendalS3)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::Gcs(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalGcs> =
                    OpendalEnumerator::new_gcs_source(*prop)?;
                let stream = build_opendal_fs_list_for_batch(lister);
                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res.into_iter().map(SplitImpl::Gcs).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::Azblob(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalAzblob> =
                    OpendalEnumerator::new_azblob_source(*prop)?;
                let stream = build_opendal_fs_list_for_batch(lister);
                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res.into_iter().map(SplitImpl::Azblob).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (
                ConnectorProperties::Iceberg(prop),
                SourceFetchParameters::IcebergSpecificInfo(iceberg_specific_info),
            ) => {
                let iceberg_enumerator =
                    IcebergSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;

                let time_travel_info = to_iceberg_time_travel_as_of(&fetch_info.as_of, &timezone)?;

                let split_info = iceberg_enumerator
                    .list_splits_batch(
                        fetch_info.schema,
                        time_travel_info,
                        batch_parallelism,
                        iceberg_specific_info.iceberg_scan_type,
                        iceberg_specific_info.predicate,
                    )
                    .await?
                    .into_iter()
                    .map(SplitImpl::Iceberg)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(split_info))
            }
            _ => Err(SchedulerError::Internal(anyhow!(
                "Querying this source directly is not supported"
            ))),
        }
    }

    pub fn split_info(&self) -> SchedulerResult<&Vec<SplitImpl>> {
        match self {
            Self::Incomplete(_) => Err(SchedulerError::Internal(anyhow!(
                "Should not get split info from incomplete source scan info"
            ))),
            Self::Complete(split_info) => Ok(split_info),
        }
    }
}

#[derive(Clone, Debug)]
pub struct TableScanInfo {
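    /// The name of the table to scan.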
    name: String,

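    /// The partitions to be read, keyed by worker slot: `None` for system tables
    /// (which carry no partition info), `Some` for user tables.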
    partitions: Option<HashMap<WorkerSlotId, TablePartitionInfo>>,
}

impl TableScanInfo {
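    /// For a user table, partition info must always be provided.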
    pub fn new(name: String, partitions: HashMap<WorkerSlotId, TablePartitionInfo>) -> Self {
        Self {
            name,
            partitions: Some(partitions),
        }
    }

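    /// For a system table, there is no partition info.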
    pub fn system_table(name: String) -> Self {
        Self {
            name,
            partitions: None,
        }
    }

    pub fn name(&self) -> &str {
        self.name.as_ref()
    }

    pub fn partitions(&self) -> Option<&HashMap<WorkerSlotId, TablePartitionInfo>> {
        self.partitions.as_ref()
    }
}

#[derive(Clone, Debug)]
pub struct TablePartitionInfo {
    pub vnode_bitmap: Bitmap,
    pub scan_ranges: Vec<ScanRangeProto>,
}

#[derive(Clone, Debug, EnumAsInner)]
pub enum PartitionInfo {
    Table(TablePartitionInfo),
    Source(Vec<SplitImpl>),
    File(Vec<String>),
}

#[derive(Clone, Debug)]
pub struct FileScanInfo {
    pub file_location: Vec<String>,
}

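/// A stage of a fragmented query: the subtree of the plan between exchange boundaries,
/// executed with a fixed parallelism once scheduling information is complete.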
#[derive(Clone)]
pub struct QueryStage {
    pub query_id: QueryId,
    pub id: StageId,
    pub root: Arc<ExecutionPlanNode>,
    pub exchange_info: Option<ExchangeInfo>,
    pub parallelism: Option<u32>,
    pub table_scan_info: Option<TableScanInfo>,
    pub source_info: Option<SourceScanInfo>,
    pub file_scan_info: Option<FileScanInfo>,
    pub has_lookup_join: bool,
    pub dml_table_id: Option<TableId>,
    pub session_id: SessionId,
    pub batch_enable_distributed_dml: bool,

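    /// Exchange distributions of the child stages, kept only while this stage's
    /// parallelism is still undecided (i.e. `parallelism` is `None`).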
    children_exchange_distribution: Option<HashMap<StageId, Distribution>>,
}

impl QueryStage {
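    /// Whether this stage contains a table scan node; such stages must be scheduled
    /// onto the worker slots that own the table's partitions.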
    pub fn has_table_scan(&self) -> bool {
        self.table_scan_info.is_some()
    }

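    /// Whether this stage contains a lookup join, whose parallelism follows the
    /// distribution of the inner-side table.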
    pub fn has_lookup_join(&self) -> bool {
        self.has_lookup_join
    }

    pub fn clone_with_exchange_info(
        &self,
        exchange_info: Option<ExchangeInfo>,
        parallelism: Option<u32>,
    ) -> Self {
        if let Some(exchange_info) = exchange_info {
            return Self {
                query_id: self.query_id.clone(),
                id: self.id,
                root: self.root.clone(),
                exchange_info: Some(exchange_info),
                parallelism,
                table_scan_info: self.table_scan_info.clone(),
                source_info: self.source_info.clone(),
                file_scan_info: self.file_scan_info.clone(),
                has_lookup_join: self.has_lookup_join,
                dml_table_id: self.dml_table_id,
                session_id: self.session_id,
                batch_enable_distributed_dml: self.batch_enable_distributed_dml,
                children_exchange_distribution: self.children_exchange_distribution.clone(),
            };
        }
        self.clone()
    }

    pub fn clone_with_exchange_info_and_complete_source_info(
        &self,
        exchange_info: Option<ExchangeInfo>,
        source_info: SourceScanInfo,
        task_parallelism: u32,
    ) -> Self {
        assert!(matches!(source_info, SourceScanInfo::Complete(_)));
        let exchange_info = if let Some(exchange_info) = exchange_info {
            Some(exchange_info)
        } else {
            self.exchange_info.clone()
        };
        Self {
            query_id: self.query_id.clone(),
            id: self.id,
            root: self.root.clone(),
            exchange_info,
            parallelism: Some(task_parallelism),
            table_scan_info: self.table_scan_info.clone(),
            source_info: Some(source_info),
            file_scan_info: self.file_scan_info.clone(),
            has_lookup_join: self.has_lookup_join,
            dml_table_id: self.dml_table_id,
            session_id: self.session_id,
            batch_enable_distributed_dml: self.batch_enable_distributed_dml,
            children_exchange_distribution: None,
        }
    }
}

impl Debug for QueryStage {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QueryStage")
            .field("id", &self.id)
            .field("parallelism", &self.parallelism)
            .field("exchange_info", &self.exchange_info)
            .field("has_table_scan", &self.has_table_scan())
            .finish()
    }
}

impl Serialize for QueryStage {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut state = serializer.serialize_struct("QueryStage", 3)?;
        state.serialize_field("root", &self.root)?;
        state.serialize_field("parallelism", &self.parallelism)?;
        state.serialize_field("exchange_info", &self.exchange_info)?;
        state.end()
    }
}

pub type QueryStageRef = Arc<QueryStage>;

struct QueryStageBuilder {
    query_id: QueryId,
    id: StageId,
    root: Option<Arc<ExecutionPlanNode>>,
    parallelism: Option<u32>,
    exchange_info: Option<ExchangeInfo>,

    children_stages: Vec<QueryStageRef>,
    table_scan_info: Option<TableScanInfo>,
    source_info: Option<SourceScanInfo>,
    file_scan_info: Option<FileScanInfo>,
    has_lookup_join: bool,
    dml_table_id: Option<TableId>,
    session_id: SessionId,
    batch_enable_distributed_dml: bool,

    children_exchange_distribution: HashMap<StageId, Distribution>,
}

impl QueryStageBuilder {
    #[allow(clippy::too_many_arguments)]
    fn new(
        id: StageId,
        query_id: QueryId,
        parallelism: Option<u32>,
        exchange_info: Option<ExchangeInfo>,
        table_scan_info: Option<TableScanInfo>,
        source_info: Option<SourceScanInfo>,
        file_scan_info: Option<FileScanInfo>,
        has_lookup_join: bool,
        dml_table_id: Option<TableId>,
        session_id: SessionId,
        batch_enable_distributed_dml: bool,
    ) -> Self {
        Self {
            query_id,
            id,
            root: None,
            parallelism,
            exchange_info,
            children_stages: vec![],
            table_scan_info,
            source_info,
            file_scan_info,
            has_lookup_join,
            dml_table_id,
            session_id,
            batch_enable_distributed_dml,
            children_exchange_distribution: HashMap::new(),
        }
    }

    fn finish(self, stage_graph_builder: &mut StageGraphBuilder) -> QueryStageRef {
        let children_exchange_distribution = if self.parallelism.is_none() {
            Some(self.children_exchange_distribution)
        } else {
            None
        };
        let stage = Arc::new(QueryStage {
            query_id: self.query_id,
            id: self.id,
            root: self.root.unwrap(),
            exchange_info: self.exchange_info,
            parallelism: self.parallelism,
            table_scan_info: self.table_scan_info,
            source_info: self.source_info,
            file_scan_info: self.file_scan_info,
            has_lookup_join: self.has_lookup_join,
            dml_table_id: self.dml_table_id,
            session_id: self.session_id,
            batch_enable_distributed_dml: self.batch_enable_distributed_dml,
            children_exchange_distribution,
        });

        stage_graph_builder.add_node(stage.clone());
        for child_stage in self.children_stages {
            stage_graph_builder.link_to_child(self.id, child_stage.id);
        }
        stage
    }
}

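/// The DAG of stages constituting a query. `child_edges` maps a stage to the stages it
/// consumes data from; `parent_edges` is the reverse mapping.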
#[derive(Debug, Serialize)]
#[cfg_attr(test, derive(Clone))]
pub struct StageGraph {
    pub root_stage_id: StageId,
    pub stages: HashMap<StageId, QueryStageRef>,
    child_edges: HashMap<StageId, HashSet<StageId>>,
    parent_edges: HashMap<StageId, HashSet<StageId>>,

    batch_parallelism: usize,
}

impl StageGraph {
    pub fn get_child_stages_unchecked(&self, stage_id: &StageId) -> &HashSet<StageId> {
        self.child_edges.get(stage_id).unwrap()
    }

    pub fn get_child_stages(&self, stage_id: &StageId) -> Option<&HashSet<StageId>> {
        self.child_edges.get(stage_id)
    }

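    /// Returns stage ids in topological order: a stage is always yielded after all of
    /// its child stages.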
    pub fn stage_ids_by_topo_order(&self) -> impl Iterator<Item = StageId> {
        let mut stack = Vec::with_capacity(self.stages.len());
        stack.push(self.root_stage_id);
        let mut ret = Vec::with_capacity(self.stages.len());
        let mut existing = HashSet::with_capacity(self.stages.len());

        while let Some(s) = stack.pop() {
            if !existing.contains(&s) {
                ret.push(s);
                existing.insert(s);
                stack.extend(&self.child_edges[&s]);
            }
        }

        ret.into_iter().rev()
    }

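    /// Completes the graph for execution: lists source splits and decides the
    /// parallelism of every stage, returning a fully specified `StageGraph`.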
    async fn complete(
        self,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
        timezone: String,
    ) -> SchedulerResult<StageGraph> {
        let mut complete_stages = HashMap::new();
        self.complete_stage(
            self.stages.get(&self.root_stage_id).unwrap().clone(),
            None,
            &mut complete_stages,
            catalog_reader,
            worker_node_manager,
            timezone,
        )
        .await?;
        Ok(StageGraph {
            root_stage_id: self.root_stage_id,
            stages: complete_stages,
            child_edges: self.child_edges,
            parent_edges: self.parent_edges,
            batch_parallelism: self.batch_parallelism,
        })
    }

    #[async_recursion]
    async fn complete_stage(
        &self,
        stage: QueryStageRef,
        exchange_info: Option<ExchangeInfo>,
        complete_stages: &mut HashMap<StageId, QueryStageRef>,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
        timezone: String,
    ) -> SchedulerResult<()> {
        let parallelism = if stage.parallelism.is_some() {
            complete_stages.insert(
                stage.id,
                Arc::new(stage.clone_with_exchange_info(exchange_info, stage.parallelism)),
            );
            None
        } else if matches!(stage.source_info, Some(SourceScanInfo::Incomplete(_))) {
            let complete_source_info = stage
                .source_info
                .as_ref()
                .unwrap()
                .clone()
                .complete(self.batch_parallelism, timezone.to_owned())
                .await?;

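            // For file sources (S3 / GCS / Azblob), cap the task parallelism at half
            // of the batch parallelism (but at least 1) rather than spawning one task
            // per file; other sources get one task per split.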
            let task_parallelism = match &stage.source_info {
                Some(SourceScanInfo::Incomplete(source_fetch_info)) => {
                    match source_fetch_info.connector {
                        ConnectorProperties::Gcs(_)
                        | ConnectorProperties::OpendalS3(_)
                        | ConnectorProperties::Azblob(_) => (min(
                            complete_source_info.split_info().unwrap().len() as u32,
                            (self.batch_parallelism / 2) as u32,
                        ))
                        .max(1),
                        _ => complete_source_info.split_info().unwrap().len() as u32,
                    }
                }
                _ => unreachable!(),
            };
            let complete_stage = Arc::new(stage.clone_with_exchange_info_and_complete_source_info(
                exchange_info,
                complete_source_info,
                task_parallelism,
            ));
            let parallelism = complete_stage.parallelism;
            complete_stages.insert(stage.id, complete_stage);
            parallelism
        } else {
            assert!(stage.file_scan_info.is_some());
            let parallelism = min(
                self.batch_parallelism / 2,
                stage.file_scan_info.as_ref().unwrap().file_location.len(),
            );
            complete_stages.insert(
                stage.id,
                Arc::new(stage.clone_with_exchange_info(exchange_info, Some(parallelism as u32))),
            );
            None
        };

        for child_stage_id in self.child_edges.get(&stage.id).unwrap_or(&HashSet::new()) {
            let exchange_info = if let Some(parallelism) = parallelism {
                let exchange_distribution = stage
                    .children_exchange_distribution
                    .as_ref()
                    .unwrap()
                    .get(child_stage_id)
                    .expect("Exchange distribution is not consistent with the stage graph");
                Some(exchange_distribution.to_prost(
                    parallelism,
                    catalog_reader,
                    worker_node_manager,
                )?)
            } else {
                None
            };
            self.complete_stage(
                self.stages.get(child_stage_id).unwrap().clone(),
                exchange_info,
                complete_stages,
                catalog_reader,
                worker_node_manager,
                timezone.to_owned(),
            )
            .await?;
        }

        Ok(())
    }

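    /// Converts the stage graph into a `petgraph` graph, mainly for debugging and
    /// visualization.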
    pub fn to_petgraph(&self) -> Graph<String, String, Directed> {
        let mut graph = Graph::<String, String, Directed>::new();

        let mut node_indices = HashMap::new();

        for (&stage_id, stage_ref) in self.stages.iter().sorted_by_key(|(id, _)| **id) {
            let node_label = format!("Stage {}: {:?}", stage_id, stage_ref);
            let node_index = graph.add_node(node_label);
            node_indices.insert(stage_id, node_index);
        }

        for (&parent_id, children) in &self.child_edges {
            if let Some(&parent_index) = node_indices.get(&parent_id) {
                for &child_id in children {
                    if let Some(&child_index) = node_indices.get(&child_id) {
                        graph.add_edge(parent_index, child_index, "".to_owned());
                    }
                }
            }
        }

        graph
    }
}

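/// Mutable builder that accumulates stages and their edges into a `StageGraph`.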
struct StageGraphBuilder {
    stages: HashMap<StageId, QueryStageRef>,
    child_edges: HashMap<StageId, HashSet<StageId>>,
    parent_edges: HashMap<StageId, HashSet<StageId>>,
    batch_parallelism: usize,
}

impl StageGraphBuilder {
    pub fn new(batch_parallelism: usize) -> Self {
        Self {
            stages: HashMap::new(),
            child_edges: HashMap::new(),
            parent_edges: HashMap::new(),
            batch_parallelism,
        }
    }

    pub fn build(self, root_stage_id: StageId) -> StageGraph {
        StageGraph {
            root_stage_id,
            stages: self.stages,
            child_edges: self.child_edges,
            parent_edges: self.parent_edges,
            batch_parallelism: self.batch_parallelism,
        }
    }

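    /// Links a parent stage to a child stage, maintaining both the parent -> child and
    /// the child -> parent edge maps. Both stages must have been registered via
    /// `add_node` beforehand.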
    pub fn link_to_child(&mut self, parent_id: StageId, child_id: StageId) {
        self.child_edges
            .get_mut(&parent_id)
            .unwrap()
            .insert(child_id);
        self.parent_edges
            .get_mut(&child_id)
            .unwrap()
            .insert(parent_id);
    }

    pub fn add_node(&mut self, stage: QueryStageRef) {
        self.child_edges.insert(stage.id, HashSet::new());
        self.parent_edges.insert(stage.id, HashSet::new());
        self.stages.insert(stage.id, stage);
    }
}

impl BatchPlanFragmenter {
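    /// Finishes fragmentation: completes the stage graph by listing source splits and
    /// fixing the parallelism of every stage, and returns the schedulable `Query`.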
    pub async fn generate_complete_query(self) -> SchedulerResult<Query> {
        let stage_graph = self.stage_graph.unwrap();
        let new_stage_graph = stage_graph
            .complete(
                &self.catalog_reader,
                &self.worker_node_manager,
                self.timezone.to_owned(),
            )
            .await?;
        Ok(Query {
            query_id: self.query_id,
            stage_graph: new_stage_graph,
        })
    }

    fn new_stage(
        &mut self,
        root: PlanRef,
        exchange_info: Option<ExchangeInfo>,
    ) -> SchedulerResult<QueryStageRef> {
        let next_stage_id = self.next_stage_id;
        self.next_stage_id += 1;

        let mut table_scan_info = self.collect_stage_table_scan(root.clone())?;
        let source_info = if table_scan_info.is_none() {
            Self::collect_stage_source(root.clone())?
        } else {
            None
        };

        let file_scan_info = if table_scan_info.is_none() && source_info.is_none() {
            Self::collect_stage_file_scan(root.clone())?
        } else {
            None
        };

        let mut has_lookup_join = false;
        let parallelism = match root.distribution() {
            Distribution::Single => {
                if let Some(info) = &mut table_scan_info {
                    if let Some(partitions) = &mut info.partitions {
                        if partitions.len() != 1 {
                            tracing::warn!(
                                "The stage has single distribution, but contains a scan of table `{}` with {} partitions. A single random worker will be assigned",
                                info.name,
                                partitions.len()
                            );

                            *partitions = partitions
                                .drain()
                                .take(1)
                                .update(|(_, info)| {
                                    info.vnode_bitmap = Bitmap::ones(info.vnode_bitmap.len());
                                })
                                .collect();
                        }
                    }
                } else if source_info.is_some() {
                    return Err(SchedulerError::Internal(anyhow!(
                        "The stage has single distribution, but contains a source operator"
                    )));
                }
                1
            }
            _ => {
                if let Some(table_scan_info) = &table_scan_info {
                    table_scan_info
                        .partitions
                        .as_ref()
                        .map(|m| m.len())
                        .unwrap_or(1)
                } else if let Some(lookup_join_parallelism) =
                    self.collect_stage_lookup_join_parallelism(root.clone())?
                {
                    has_lookup_join = true;
                    lookup_join_parallelism
                } else if source_info.is_some() {
                    0
                } else if file_scan_info.is_some() {
                    1
                } else {
                    self.batch_parallelism
                }
            }
        };
        if source_info.is_none() && file_scan_info.is_none() && parallelism == 0 {
            return Err(BatchError::EmptyWorkerNodes.into());
        }
        let parallelism = if parallelism == 0 {
            None
        } else {
            Some(parallelism as u32)
        };
        let dml_table_id = Self::collect_dml_table_id(&root);
        let mut builder = QueryStageBuilder::new(
            next_stage_id,
            self.query_id.clone(),
            parallelism,
            exchange_info,
            table_scan_info,
            source_info,
            file_scan_info,
            has_lookup_join,
            dml_table_id,
            root.ctx().session_ctx().session_id(),
            root.ctx()
                .session_ctx()
                .config()
                .batch_enable_distributed_dml(),
        );

        self.visit_node(root, &mut builder, None)?;

        Ok(builder.finish(self.stage_graph_builder.as_mut().unwrap()))
    }

    fn visit_node(
        &mut self,
        node: PlanRef,
        builder: &mut QueryStageBuilder,
        parent_exec_node: Option<&mut ExecutionPlanNode>,
    ) -> SchedulerResult<()> {
        match node.node_type() {
            PlanNodeType::BatchExchange => {
                self.visit_exchange(node.clone(), builder, parent_exec_node)?;
            }
            _ => {
                let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;

                for child in node.inputs() {
                    self.visit_node(child, builder, Some(&mut execution_plan_node))?;
                }

                if let Some(parent) = parent_exec_node {
                    parent.children.push(Arc::new(execution_plan_node));
                } else {
                    builder.root = Some(Arc::new(execution_plan_node));
                }
            }
        }
        Ok(())
    }

    fn visit_exchange(
        &mut self,
        node: PlanRef,
        builder: &mut QueryStageBuilder,
        parent_exec_node: Option<&mut ExecutionPlanNode>,
    ) -> SchedulerResult<()> {
        let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;
        let child_exchange_info = if let Some(parallelism) = builder.parallelism {
            Some(node.distribution().to_prost(
                parallelism,
                &self.catalog_reader,
                &self.worker_node_manager,
            )?)
        } else {
            None
        };
        let child_stage = self.new_stage(node.inputs()[0].clone(), child_exchange_info)?;
        execution_plan_node.source_stage_id = Some(child_stage.id);
        if builder.parallelism.is_none() {
            builder
                .children_exchange_distribution
                .insert(child_stage.id, node.distribution().clone());
        }

        if let Some(parent) = parent_exec_node {
            parent.children.push(Arc::new(execution_plan_node));
        } else {
            builder.root = Some(Arc::new(execution_plan_node));
        }

        builder.children_stages.push(child_stage);
        Ok(())
    }

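    /// Checks whether the stage rooted at `node` (stopping at `BatchExchange`
    /// boundaries) contains a source scan, and collects its fetch info if so.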
    fn collect_stage_source(node: PlanRef) -> SchedulerResult<Option<SourceScanInfo>> {
        if node.node_type() == PlanNodeType::BatchExchange {
            return Ok(None);
        }

        if let Some(batch_kafka_node) = node.as_batch_kafka_scan() {
            let batch_kafka_scan: &BatchKafkaScan = batch_kafka_node;
            let source_catalog = batch_kafka_scan.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                let timestamp_bound = batch_kafka_scan.kafka_timestamp_range_value();
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: batch_kafka_scan.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::KafkaTimebound {
                        lower: timestamp_bound.0,
                        upper: timestamp_bound.1,
                    },
                    as_of: None,
                })));
            }
        } else if let Some(batch_iceberg_scan) = node.as_batch_iceberg_scan() {
            let batch_iceberg_scan: &BatchIcebergScan = batch_iceberg_scan;
            let source_catalog = batch_iceberg_scan.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                let as_of = batch_iceberg_scan.as_of();
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: batch_iceberg_scan.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::IcebergSpecificInfo(
                        IcebergSpecificInfo {
                            predicate: batch_iceberg_scan.predicate.clone(),
                            iceberg_scan_type: batch_iceberg_scan.iceberg_scan_type(),
                        },
                    ),
                    as_of,
                })));
            }
        } else if let Some(source_node) = node.as_batch_source() {
            let source_node: &BatchSource = source_node;
            let source_catalog = source_node.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                let as_of = source_node.as_of();
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: source_node.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::Empty,
                    as_of,
                })));
            }
        }

        node.inputs()
            .into_iter()
            .find_map(|n| Self::collect_stage_source(n).transpose())
            .transpose()
    }

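    /// Like `collect_stage_source`, but looks for a `BatchFileScan` node within the
    /// current stage.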
    fn collect_stage_file_scan(node: PlanRef) -> SchedulerResult<Option<FileScanInfo>> {
        if node.node_type() == PlanNodeType::BatchExchange {
            return Ok(None);
        }

        if let Some(batch_file_scan) = node.as_batch_file_scan() {
            return Ok(Some(FileScanInfo {
                file_location: batch_file_scan.core.file_location().clone(),
            }));
        }

        node.inputs()
            .into_iter()
            .find_map(|n| Self::collect_stage_file_scan(n).transpose())
            .transpose()
    }

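    /// Checks whether the stage contains a table scan node and, if so, derives the
    /// table's partition information. If a stage contains multiple scan nodes, they
    /// are assumed to share the same distribution.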
    fn collect_stage_table_scan(&self, node: PlanRef) -> SchedulerResult<Option<TableScanInfo>> {
        let build_table_scan_info = |name, table_desc: &TableDesc, scan_range| {
            let table_catalog = self
                .catalog_reader
                .read_guard()
                .get_any_table_by_id(&table_desc.table_id)
                .cloned()
                .map_err(RwError::from)?;
            let vnode_mapping = self
                .worker_node_manager
                .fragment_mapping(table_catalog.fragment_id)?;
            let partitions = derive_partitions(scan_range, table_desc, &vnode_mapping)?;
            let info = TableScanInfo::new(name, partitions);
            Ok(Some(info))
        };
        if node.node_type() == PlanNodeType::BatchExchange {
            return Ok(None);
        }
        if let Some(scan_node) = node.as_batch_sys_seq_scan() {
            let name = scan_node.core().table_name.to_owned();
            Ok(Some(TableScanInfo::system_table(name)))
        } else if let Some(scan_node) = node.as_batch_log_seq_scan() {
            build_table_scan_info(
                scan_node.core().table_name.to_owned(),
                &scan_node.core().table_desc,
                &[],
            )
        } else if let Some(scan_node) = node.as_batch_seq_scan() {
            build_table_scan_info(
                scan_node.core().table_name.to_owned(),
                &scan_node.core().table_desc,
                scan_node.scan_ranges(),
            )
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| self.collect_stage_table_scan(n).transpose())
                .transpose()
        }
    }

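    /// Returns the id of the table written by a DML node (insert, update or delete)
    /// within the current stage, if any.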
    fn collect_dml_table_id(node: &PlanRef) -> Option<TableId> {
        if node.node_type() == PlanNodeType::BatchExchange {
            return None;
        }
        if let Some(insert) = node.as_batch_insert() {
            Some(insert.core.table_id)
        } else if let Some(update) = node.as_batch_update() {
            Some(update.core.table_id)
        } else if let Some(delete) = node.as_batch_delete() {
            Some(delete.core.table_id)
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| Self::collect_dml_table_id(&n))
        }
    }

    fn collect_stage_lookup_join_parallelism(
        &self,
        node: PlanRef,
    ) -> SchedulerResult<Option<usize>> {
        if node.node_type() == PlanNodeType::BatchExchange {
            return Ok(None);
        }
        if let Some(lookup_join) = node.as_batch_lookup_join() {
            let table_desc = lookup_join.right_table_desc();
            let table_catalog = self
                .catalog_reader
                .read_guard()
                .get_any_table_by_id(&table_desc.table_id)
                .cloned()
                .map_err(RwError::from)?;
            let vnode_mapping = self
                .worker_node_manager
                .fragment_mapping(table_catalog.fragment_id)?;
            let parallelism = vnode_mapping.iter().sorted().dedup().count();
            Ok(Some(parallelism))
        } else {
            node.inputs()
                .into_iter()
                .find_map(|n| self.collect_stage_lookup_join_parallelism(n).transpose())
                .transpose()
        }
    }
}

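/// Derives the partitions (vnode bitmap plus scan ranges, per worker slot) to read.
/// A scan range is pinned to a single vnode when its value on the distribution key is
/// fully known; otherwise it is evaluated on all vnodes.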
fn derive_partitions(
    scan_ranges: &[ScanRange],
    table_desc: &TableDesc,
    vnode_mapping: &WorkerSlotMapping,
) -> SchedulerResult<HashMap<WorkerSlotId, TablePartitionInfo>> {
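    // If the fragment's mapping size disagrees with the table's vnode count, the table
    // must be a singleton (vnode count 1), so scan it from a single worker slot.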
    let vnode_mapping = if table_desc.vnode_count != vnode_mapping.len() {
        assert!(
            table_desc.vnode_count == 1,
            "fragment vnode count {} does not match table vnode count {}",
            vnode_mapping.len(),
            table_desc.vnode_count,
        );
        &WorkerSlotMapping::new_single(vnode_mapping.iter().next().unwrap())
    } else {
        vnode_mapping
    };
    let vnode_count = vnode_mapping.len();

    let mut partitions: HashMap<WorkerSlotId, (BitmapBuilder, Vec<_>)> = HashMap::new();

    if scan_ranges.is_empty() {
        return Ok(vnode_mapping
            .to_bitmaps()
            .into_iter()
            .map(|(k, vnode_bitmap)| {
                (
                    k,
                    TablePartitionInfo {
                        vnode_bitmap,
                        scan_ranges: vec![],
                    },
                )
            })
            .collect());
    }

    let table_distribution = TableDistribution::new_from_storage_table_desc(
        Some(Bitmap::ones(vnode_count).into()),
        &table_desc.try_to_protobuf()?,
    );

    for scan_range in scan_ranges {
        let vnode = scan_range.try_compute_vnode(&table_distribution);
        match vnode {
            None => {
                vnode_mapping.to_bitmaps().into_iter().for_each(
                    |(worker_slot_id, vnode_bitmap)| {
                        let (bitmap, scan_ranges) = partitions
                            .entry(worker_slot_id)
                            .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
                        vnode_bitmap
                            .iter()
                            .enumerate()
                            .for_each(|(vnode, b)| bitmap.set(vnode, b));
                        scan_ranges.push(scan_range.to_protobuf());
                    },
                );
            }
            Some(vnode) => {
                let worker_slot_id = vnode_mapping[vnode];
                let (bitmap, scan_ranges) = partitions
                    .entry(worker_slot_id)
                    .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
                bitmap.set(vnode.to_index(), true);
                scan_ranges.push(scan_range.to_protobuf());
            }
        }
    }

    Ok(partitions
        .into_iter()
        .map(|(k, (bitmap, scan_ranges))| {
            (
                k,
                TablePartitionInfo {
                    vnode_bitmap: bitmap.finish(),
                    scan_ranges,
                },
            )
        })
        .collect())
}

#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use risingwave_pb::batch_plan::plan_node::NodeBody;

    use crate::optimizer::plan_node::PlanNodeType;
    use crate::scheduler::plan_fragmenter::StageId;

    #[tokio::test]
    async fn test_fragmenter() {
        let query = crate::scheduler::distributed::tests::create_query().await;

        assert_eq!(query.stage_graph.root_stage_id, 0);
        assert_eq!(query.stage_graph.stages.len(), 4);

        assert_eq!(query.stage_graph.child_edges[&0], [1].into());
        assert_eq!(query.stage_graph.child_edges[&1], [2, 3].into());
        assert_eq!(query.stage_graph.child_edges[&2], HashSet::new());
        assert_eq!(query.stage_graph.child_edges[&3], HashSet::new());

        assert_eq!(query.stage_graph.parent_edges[&0], HashSet::new());
        assert_eq!(query.stage_graph.parent_edges[&1], [0].into());
        assert_eq!(query.stage_graph.parent_edges[&2], [1].into());
        assert_eq!(query.stage_graph.parent_edges[&3], [1].into());

        {
            let stage_id_to_pos: HashMap<StageId, usize> = query
                .stage_graph
                .stage_ids_by_topo_order()
                .enumerate()
                .map(|(pos, stage_id)| (stage_id, pos))
                .collect();

            for stage_id in query.stage_graph.stages.keys() {
                let stage_pos = stage_id_to_pos[stage_id];
                for child_stage_id in &query.stage_graph.child_edges[stage_id] {
                    let child_pos = stage_id_to_pos[child_stage_id];
                    assert!(stage_pos > child_pos);
                }
            }
        }

        let root_exchange = query.stage_graph.stages.get(&0).unwrap();
        assert_eq!(root_exchange.root.node_type(), PlanNodeType::BatchExchange);
        assert_eq!(root_exchange.root.source_stage_id, Some(1));
        assert!(matches!(root_exchange.root.node, NodeBody::Exchange(_)));
        assert_eq!(root_exchange.parallelism, Some(1));
        assert!(!root_exchange.has_table_scan());

        let join_node = query.stage_graph.stages.get(&1).unwrap();
        assert_eq!(join_node.root.node_type(), PlanNodeType::BatchHashJoin);
        assert_eq!(join_node.parallelism, Some(24));

        assert!(matches!(join_node.root.node, NodeBody::HashJoin(_)));
        assert_eq!(join_node.root.source_stage_id, None);
        assert_eq!(2, join_node.root.children.len());

        assert!(matches!(
            join_node.root.children[0].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[0].source_stage_id, Some(2));
        assert_eq!(0, join_node.root.children[0].children.len());

        assert!(matches!(
            join_node.root.children[1].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[1].source_stage_id, Some(3));
        assert_eq!(0, join_node.root.children[1].children.len());
        assert!(!join_node.has_table_scan());

        let scan_node1 = query.stage_graph.stages.get(&2).unwrap();
        assert_eq!(scan_node1.root.node_type(), PlanNodeType::BatchSeqScan);
        assert_eq!(scan_node1.root.source_stage_id, None);
        assert_eq!(0, scan_node1.root.children.len());
        assert!(scan_node1.has_table_scan());

        let scan_node2 = query.stage_graph.stages.get(&3).unwrap();
        assert_eq!(scan_node2.root.node_type(), PlanNodeType::BatchFilter);
        assert_eq!(scan_node2.root.source_stage_id, None);
        assert_eq!(1, scan_node2.root.children.len());
        assert!(scan_node2.has_table_scan());
    }
}