1use std::cmp::min;
16use std::collections::{HashMap, HashSet};
17use std::fmt::{Debug, Formatter};
18use std::num::NonZeroU64;
19use std::sync::Arc;
20
21use anyhow::anyhow;
22use async_recursion::async_recursion;
23use enum_as_inner::EnumAsInner;
24use futures::TryStreamExt;
25use iceberg::expr::Predicate as IcebergPredicate;
26use itertools::Itertools;
27use petgraph::{Directed, Graph};
28use pgwire::pg_server::SessionId;
29use risingwave_batch::error::BatchError;
30use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
31use risingwave_common::bitmap::{Bitmap, BitmapBuilder};
32use risingwave_common::catalog::{Schema, TableDesc};
33use risingwave_common::hash::table_distribution::TableDistribution;
34use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
35use risingwave_common::util::scan_range::ScanRange;
36use risingwave_connector::source::filesystem::opendal_source::opendal_enumerator::OpendalEnumerator;
37use risingwave_connector::source::filesystem::opendal_source::{
38 OpendalAzblob, OpendalGcs, OpendalS3,
39};
40use risingwave_connector::source::iceberg::IcebergSplitEnumerator;
41use risingwave_connector::source::kafka::KafkaSplitEnumerator;
42use risingwave_connector::source::prelude::DatagenSplitEnumerator;
43use risingwave_connector::source::reader::reader::build_opendal_fs_list_for_batch;
44use risingwave_connector::source::{
45 ConnectorProperties, SourceEnumeratorContext, SplitEnumerator, SplitImpl,
46};
47use risingwave_pb::batch_plan::iceberg_scan_node::IcebergScanType;
48use risingwave_pb::batch_plan::plan_node::NodeBody;
49use risingwave_pb::batch_plan::{ExchangeInfo, ScanRange as ScanRangeProto};
50use risingwave_pb::plan_common::Field as PbField;
51use risingwave_sqlparser::ast::AsOf;
52use serde::Serialize;
53use serde::ser::SerializeStruct;
54use uuid::Uuid;
55
56use super::SchedulerError;
57use crate::catalog::TableId;
58use crate::catalog::catalog_service::CatalogReader;
59use crate::error::RwError;
60use crate::optimizer::PlanRef;
61use crate::optimizer::plan_node::generic::{GenericPlanRef, PhysicalPlanRef};
62use crate::optimizer::plan_node::utils::to_iceberg_time_travel_as_of;
63use crate::optimizer::plan_node::{
64 BatchIcebergScan, BatchKafkaScan, BatchSource, PlanNodeId, PlanNodeType,
65};
66use crate::optimizer::property::Distribution;
67use crate::scheduler::SchedulerResult;
68
/// Unique identifier of a batch query; by default a freshly generated UUID
/// string (see the `Default` impl below).
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct QueryId {
    // The underlying identifier string.
    pub id: String,
}
73
74impl std::fmt::Display for QueryId {
75 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
76 write!(f, "QueryId:{}", self.id)
77 }
78}
79
/// Stage identifier, unique within a single query.
pub type StageId = u32;

// The root task (the one returning results to the frontend) always uses task
// id 0 and a single output buffer with output id 0.
pub const ROOT_TASK_ID: u64 = 0;
pub const ROOT_TASK_OUTPUT_ID: u64 = 0;
/// Task identifier within a stage.
pub type TaskId = u64;
87
/// One node of a stage's execution plan tree, converted from the optimizer's
/// plan node into its protobuf body.
#[derive(Clone, Debug)]
pub struct ExecutionPlanNode {
    pub plan_node_id: PlanNodeId,
    pub plan_node_type: PlanNodeType,
    // Protobuf body produced by `try_to_batch_prost_body`.
    pub node: NodeBody,
    pub schema: Vec<PbField>,

    // Attached while walking the plan tree (see `visit_node`).
    pub children: Vec<Arc<ExecutionPlanNode>>,

    /// For an exchange node, the id of the child stage feeding it; `None`
    /// for every other node type (set in `visit_exchange`).
    pub source_stage_id: Option<StageId>,
}
104
105impl Serialize for ExecutionPlanNode {
106 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
107 where
108 S: serde::Serializer,
109 {
110 let mut state = serializer.serialize_struct("QueryStage", 5)?;
111 state.serialize_field("plan_node_id", &self.plan_node_id)?;
112 state.serialize_field("plan_node_type", &self.plan_node_type)?;
113 state.serialize_field("schema", &self.schema)?;
114 state.serialize_field("children", &self.children)?;
115 state.serialize_field("source_stage_id", &self.source_stage_id)?;
116 state.end()
117 }
118}
119
impl TryFrom<PlanRef> for ExecutionPlanNode {
    type Error = SchedulerError;

    /// Converts a single optimizer plan node (not its subtree) into an
    /// execution plan node. Fails if the node cannot be serialized into a
    /// batch protobuf body.
    fn try_from(plan_node: PlanRef) -> Result<Self, Self::Error> {
        Ok(Self {
            plan_node_id: plan_node.plan_base().id(),
            plan_node_type: plan_node.node_type(),
            node: plan_node.try_to_batch_prost_body()?,
            // Children are attached later while walking the plan tree.
            children: vec![],
            schema: plan_node.schema().to_prost(),
            // Filled in by `visit_exchange` for exchange nodes.
            source_stage_id: None,
        })
    }
}
134
impl ExecutionPlanNode {
    /// Returns the type of the wrapped plan node.
    pub fn node_type(&self) -> PlanNodeType {
        self.plan_node_type
    }
}
140
/// Splits a batch physical plan into a graph of stages, cutting at exchange
/// boundaries.
pub struct BatchPlanFragmenter {
    query_id: QueryId,
    // Monotonically increasing id handed to each newly created stage.
    next_stage_id: StageId,
    worker_node_manager: WorkerNodeSelector,
    catalog_reader: CatalogReader,

    // Upper bound on per-stage parallelism, already clamped to the cluster's
    // schedule-unit count in `new`.
    batch_parallelism: usize,
    timezone: String,

    // Present while stages are being built; consumed by `split_into_stage`.
    stage_graph_builder: Option<StageGraphBuilder>,
    // Present after `split_into_stage` finishes.
    stage_graph: Option<StageGraph>,
}
154
155impl Default for QueryId {
156 fn default() -> Self {
157 Self {
158 id: Uuid::new_v4().to_string(),
159 }
160 }
161}
162
impl BatchPlanFragmenter {
    /// Creates a fragmenter and immediately splits `batch_node` into stages.
    ///
    /// `batch_parallelism = None` means "use every schedulable unit"; an
    /// explicit value is clamped to the cluster's schedule-unit count.
    pub fn new(
        worker_node_manager: WorkerNodeSelector,
        catalog_reader: CatalogReader,
        batch_parallelism: Option<NonZeroU64>,
        timezone: String,
        batch_node: PlanRef,
    ) -> SchedulerResult<Self> {
        let batch_parallelism = if let Some(num) = batch_parallelism {
            // Never exceed the number of schedulable units in the cluster.
            min(
                num.get() as usize,
                worker_node_manager.schedule_unit_count(),
            )
        } else {
            worker_node_manager.schedule_unit_count()
        };

        let mut plan_fragmenter = Self {
            query_id: Default::default(),
            next_stage_id: 0,
            worker_node_manager,
            catalog_reader,
            batch_parallelism,
            timezone,
            stage_graph_builder: Some(StageGraphBuilder::new(batch_parallelism)),
            stage_graph: None,
        };
        plan_fragmenter.split_into_stage(batch_node)?;
        Ok(plan_fragmenter)
    }

    /// Recursively cuts the plan at exchange nodes, then freezes the builder
    /// into `self.stage_graph`, rooted at the stage created for `batch_node`.
    fn split_into_stage(&mut self, batch_node: PlanRef) -> SchedulerResult<()> {
        // The root stage returns results to the frontend, so it always runs
        // with `Single` distribution and parallelism 1.
        let root_stage = self.new_stage(
            batch_node,
            Some(Distribution::Single.to_prost(
                1,
                &self.catalog_reader,
                &self.worker_node_manager,
            )?),
        )?;
        self.stage_graph = Some(
            self.stage_graph_builder
                .take()
                .unwrap()
                .build(root_stage.id),
        );
        Ok(())
    }
}
219
/// A complete fragmented query: the stage graph plus the query's id.
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct Query {
    pub query_id: QueryId,
    pub stage_graph: StageGraph,
}
228
229impl Query {
230 pub fn leaf_stages(&self) -> Vec<StageId> {
231 let mut ret_leaf_stages = Vec::new();
232 for stage_id in self.stage_graph.stages.keys() {
233 if self
234 .stage_graph
235 .get_child_stages_unchecked(stage_id)
236 .is_empty()
237 {
238 ret_leaf_stages.push(*stage_id);
239 }
240 }
241 ret_leaf_stages
242 }
243
244 pub fn get_parents(&self, stage_id: &StageId) -> &HashSet<StageId> {
245 self.stage_graph.parent_edges.get(stage_id).unwrap()
246 }
247
248 pub fn root_stage_id(&self) -> StageId {
249 self.stage_graph.root_stage_id
250 }
251
252 pub fn query_id(&self) -> &QueryId {
253 &self.query_id
254 }
255
256 pub fn stages_with_table_scan(&self) -> HashSet<StageId> {
257 self.stage_graph
258 .stages
259 .iter()
260 .filter_map(|(stage_id, stage_query)| {
261 if stage_query.has_table_scan() {
262 Some(*stage_id)
263 } else {
264 None
265 }
266 })
267 .collect()
268 }
269
270 pub fn has_lookup_join_stage(&self) -> bool {
271 self.stage_graph
272 .stages
273 .iter()
274 .any(|(_stage_id, stage_query)| stage_query.has_lookup_join())
275 }
276}
277
/// Connector-specific parameters needed to enumerate splits for a batch scan.
#[derive(Debug, Clone)]
pub enum SourceFetchParameters {
    IcebergSpecificInfo(IcebergSpecificInfo),
    /// Kafka scan bounded by timestamps; either bound may be open.
    KafkaTimebound {
        lower: Option<i64>,
        upper: Option<i64>,
    },
    /// Connectors that need no extra parameters (datagen, file sources, ...).
    Empty,
}
287
/// Everything needed to turn an unresolved source scan into concrete splits
/// (see [`SourceScanInfo::complete`]).
#[derive(Debug, Clone)]
pub struct SourceFetchInfo {
    pub schema: Schema,
    // User-configured connector properties extracted from the source catalog.
    pub connector: ConnectorProperties,
    // Per-connector scan parameters (e.g. kafka time bounds, iceberg predicate).
    pub fetch_parameters: SourceFetchParameters,
    pub as_of: Option<AsOf>,
}
299
/// Iceberg-only scan parameters: the kind of scan and the pushed-down predicate.
#[derive(Debug, Clone)]
pub struct IcebergSpecificInfo {
    pub iceberg_scan_type: IcebergScanType,
    // Predicate handed to the iceberg split enumerator.
    pub predicate: IcebergPredicate,
}
305
/// Source scan info: starts as `Incomplete` at fragmentation time and is
/// resolved into concrete `Complete` splits by [`SourceScanInfo::complete`].
#[derive(Clone, Debug)]
pub enum SourceScanInfo {
    Incomplete(SourceFetchInfo),
    Complete(Vec<SplitImpl>),
}
312
impl SourceScanInfo {
    /// Wraps the fetch info as an unresolved (`Incomplete`) scan.
    pub fn new(fetch_info: SourceFetchInfo) -> Self {
        Self::Incomplete(fetch_info)
    }

    /// Asks the connector to enumerate the actual splits for this scan.
    ///
    /// Must only be called on the `Incomplete` variant (panics otherwise).
    /// Returns an error for connectors that cannot be queried directly from a
    /// batch query.
    pub async fn complete(
        self,
        batch_parallelism: usize,
        timezone: String,
    ) -> SchedulerResult<Self> {
        let fetch_info = match self {
            SourceScanInfo::Incomplete(fetch_info) => fetch_info,
            SourceScanInfo::Complete(_) => {
                unreachable!("Never call complete when SourceScanInfo is already complete")
            }
        };
        // Dispatch on the (connector, parameters) pair; mismatched pairs fall
        // through to the error arm at the bottom.
        match (fetch_info.connector, fetch_info.fetch_parameters) {
            (
                ConnectorProperties::Kafka(prop),
                SourceFetchParameters::KafkaTimebound { lower, upper },
            ) => {
                let mut kafka_enumerator =
                    KafkaSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;
                let split_info = kafka_enumerator
                    .list_splits_batch(lower, upper)
                    .await?
                    .into_iter()
                    .map(SplitImpl::Kafka)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(split_info))
            }
            (ConnectorProperties::Datagen(prop), SourceFetchParameters::Empty) => {
                let mut datagen_enumerator =
                    DatagenSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;
                let split_info = datagen_enumerator.list_splits().await?;
                let res = split_info.into_iter().map(SplitImpl::Datagen).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::OpendalS3(prop), SourceFetchParameters::Empty) => {
                // File-system-style sources: list the files via opendal and
                // treat each listed entry as a split.
                let lister: OpendalEnumerator<OpendalS3> =
                    OpendalEnumerator::new_s3_source(prop.s3_properties, prop.assume_role)?;
                let stream = build_opendal_fs_list_for_batch(lister);

                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res
                    .into_iter()
                    .map(SplitImpl::OpendalS3)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::Gcs(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalGcs> =
                    OpendalEnumerator::new_gcs_source(*prop)?;
                let stream = build_opendal_fs_list_for_batch(lister);
                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res.into_iter().map(SplitImpl::Gcs).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::Azblob(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalAzblob> =
                    OpendalEnumerator::new_azblob_source(*prop)?;
                let stream = build_opendal_fs_list_for_batch(lister);
                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res.into_iter().map(SplitImpl::Azblob).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (
                ConnectorProperties::Iceberg(prop),
                SourceFetchParameters::IcebergSpecificInfo(iceberg_specific_info),
            ) => {
                let iceberg_enumerator =
                    IcebergSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;

                // Resolve the optional AS OF clause into a time-travel point.
                let time_travel_info = to_iceberg_time_travel_as_of(&fetch_info.as_of, &timezone)?;

                let split_info = iceberg_enumerator
                    .list_splits_batch(
                        fetch_info.schema,
                        time_travel_info,
                        batch_parallelism,
                        iceberg_specific_info.iceberg_scan_type,
                        iceberg_specific_info.predicate,
                    )
                    .await?
                    .into_iter()
                    .map(SplitImpl::Iceberg)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(split_info))
            }
            (connector, _) => Err(SchedulerError::Internal(anyhow!(
                "Unsupported to query directly from this {} source, \
                please create a table or streaming job from it",
                connector.kind()
            ))),
        }
    }

    /// Concrete splits of a completed scan; errors on the `Incomplete` variant.
    pub fn split_info(&self) -> SchedulerResult<&Vec<SplitImpl>> {
        match self {
            Self::Incomplete(_) => Err(SchedulerError::Internal(anyhow!(
                "Should not get split info from incomplete source scan info"
            ))),
            Self::Complete(split_info) => Ok(split_info),
        }
    }
}
428
#[derive(Clone, Debug)]
pub struct TableScanInfo {
    /// Name of the table being scanned.
    name: String,

    /// Per-worker-slot partitions to be read by the scan tasks.
    ///
    /// `None` iff this is a system table scan (see [`TableScanInfo::system_table`]).
    partitions: Option<HashMap<WorkerSlotId, TablePartitionInfo>>,
}
443
444impl TableScanInfo {
445 pub fn new(name: String, partitions: HashMap<WorkerSlotId, TablePartitionInfo>) -> Self {
447 Self {
448 name,
449 partitions: Some(partitions),
450 }
451 }
452
453 pub fn system_table(name: String) -> Self {
455 Self {
456 name,
457 partitions: None,
458 }
459 }
460
461 pub fn name(&self) -> &str {
462 self.name.as_ref()
463 }
464
465 pub fn partitions(&self) -> Option<&HashMap<WorkerSlotId, TablePartitionInfo>> {
466 self.partitions.as_ref()
467 }
468}
469
/// The part of a table assigned to one worker slot: the vnodes it owns and
/// the scan ranges to apply.
#[derive(Clone, Debug)]
pub struct TablePartitionInfo {
    pub vnode_bitmap: Bitmap,
    pub scan_ranges: Vec<ScanRangeProto>,
}
475
/// Partition assignment of a single task, depending on what the stage scans.
#[derive(Clone, Debug, EnumAsInner)]
pub enum PartitionInfo {
    Table(TablePartitionInfo),
    Source(Vec<SplitImpl>),
    // File locations assigned to a file-scan task.
    File(Vec<String>),
}
482
/// Locations of the files read by a file-scan stage.
#[derive(Clone, Debug)]
pub struct FileScanInfo {
    pub file_location: Vec<String>,
}
487
/// One fragment (stage) of a [`Query`].
#[derive(Clone)]
pub struct QueryStage {
    pub query_id: QueryId,
    pub id: StageId,
    // Root of this stage's execution plan tree.
    pub root: Arc<ExecutionPlanNode>,
    pub exchange_info: Option<ExchangeInfo>,
    // `None` while the parallelism is still undetermined (e.g. a source stage
    // whose splits are not yet enumerated); fixed in `StageGraph::complete`.
    pub parallelism: Option<u32>,
    // Set iff this stage contains a table scan.
    pub table_scan_info: Option<TableScanInfo>,
    pub source_info: Option<SourceScanInfo>,
    pub file_scan_info: Option<FileScanInfo>,
    pub has_lookup_join: bool,
    pub dml_table_id: Option<TableId>,
    pub session_id: SessionId,
    pub batch_enable_distributed_dml: bool,

    /// Distribution of each child exchange, kept only while this stage's
    /// parallelism is undetermined so the exchange info can be derived later.
    children_exchange_distribution: Option<HashMap<StageId, Distribution>>,
}
508
509impl QueryStage {
510 pub fn has_table_scan(&self) -> bool {
514 self.table_scan_info.is_some()
515 }
516
517 pub fn has_lookup_join(&self) -> bool {
520 self.has_lookup_join
521 }
522
523 pub fn clone_with_exchange_info(
524 &self,
525 exchange_info: Option<ExchangeInfo>,
526 parallelism: Option<u32>,
527 ) -> Self {
528 if let Some(exchange_info) = exchange_info {
529 return Self {
530 query_id: self.query_id.clone(),
531 id: self.id,
532 root: self.root.clone(),
533 exchange_info: Some(exchange_info),
534 parallelism,
535 table_scan_info: self.table_scan_info.clone(),
536 source_info: self.source_info.clone(),
537 file_scan_info: self.file_scan_info.clone(),
538 has_lookup_join: self.has_lookup_join,
539 dml_table_id: self.dml_table_id,
540 session_id: self.session_id,
541 batch_enable_distributed_dml: self.batch_enable_distributed_dml,
542 children_exchange_distribution: self.children_exchange_distribution.clone(),
543 };
544 }
545 self.clone()
546 }
547
548 pub fn clone_with_exchange_info_and_complete_source_info(
549 &self,
550 exchange_info: Option<ExchangeInfo>,
551 source_info: SourceScanInfo,
552 task_parallelism: u32,
553 ) -> Self {
554 assert!(matches!(source_info, SourceScanInfo::Complete(_)));
555 let exchange_info = if let Some(exchange_info) = exchange_info {
556 Some(exchange_info)
557 } else {
558 self.exchange_info.clone()
559 };
560 Self {
561 query_id: self.query_id.clone(),
562 id: self.id,
563 root: self.root.clone(),
564 exchange_info,
565 parallelism: Some(task_parallelism),
566 table_scan_info: self.table_scan_info.clone(),
567 source_info: Some(source_info),
568 file_scan_info: self.file_scan_info.clone(),
569 has_lookup_join: self.has_lookup_join,
570 dml_table_id: self.dml_table_id,
571 session_id: self.session_id,
572 batch_enable_distributed_dml: self.batch_enable_distributed_dml,
573 children_exchange_distribution: None,
574 }
575 }
576}
577
impl Debug for QueryStage {
    // Compact debug output: omits the plan tree and per-partition details.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QueryStage")
            .field("id", &self.id)
            .field("parallelism", &self.parallelism)
            .field("exchange_info", &self.exchange_info)
            .field("has_table_scan", &self.has_table_scan())
            .finish()
    }
}
588
impl Serialize for QueryStage {
    /// Serializes a summary of the stage: plan tree, parallelism and
    /// exchange info only.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut state = serializer.serialize_struct("QueryStage", 3)?;
        state.serialize_field("root", &self.root)?;
        state.serialize_field("parallelism", &self.parallelism)?;
        state.serialize_field("exchange_info", &self.exchange_info)?;
        state.end()
    }
}
601
/// Shared, immutable reference to a [`QueryStage`].
pub type QueryStageRef = Arc<QueryStage>;
603
/// Accumulates one stage while the plan tree is being walked; turned into an
/// immutable [`QueryStage`] by [`QueryStageBuilder::finish`].
struct QueryStageBuilder {
    query_id: QueryId,
    id: StageId,
    // Set when the stage's root plan node has been visited.
    root: Option<Arc<ExecutionPlanNode>>,
    parallelism: Option<u32>,
    exchange_info: Option<ExchangeInfo>,

    children_stages: Vec<QueryStageRef>,
    table_scan_info: Option<TableScanInfo>,
    source_info: Option<SourceScanInfo>,
    file_scan_file: Option<FileScanInfo>,
    has_lookup_join: bool,
    dml_table_id: Option<TableId>,
    session_id: SessionId,
    batch_enable_distributed_dml: bool,

    // Distribution of each child exchange, recorded by `visit_exchange` while
    // the stage's parallelism is still undetermined.
    children_exchange_distribution: HashMap<StageId, Distribution>,
}
623
624impl QueryStageBuilder {
625 #[allow(clippy::too_many_arguments)]
626 fn new(
627 id: StageId,
628 query_id: QueryId,
629 parallelism: Option<u32>,
630 exchange_info: Option<ExchangeInfo>,
631 table_scan_info: Option<TableScanInfo>,
632 source_info: Option<SourceScanInfo>,
633 file_scan_file: Option<FileScanInfo>,
634 has_lookup_join: bool,
635 dml_table_id: Option<TableId>,
636 session_id: SessionId,
637 batch_enable_distributed_dml: bool,
638 ) -> Self {
639 Self {
640 query_id,
641 id,
642 root: None,
643 parallelism,
644 exchange_info,
645 children_stages: vec![],
646 table_scan_info,
647 source_info,
648 file_scan_file,
649 has_lookup_join,
650 dml_table_id,
651 session_id,
652 batch_enable_distributed_dml,
653 children_exchange_distribution: HashMap::new(),
654 }
655 }
656
657 fn finish(self, stage_graph_builder: &mut StageGraphBuilder) -> QueryStageRef {
658 let children_exchange_distribution = if self.parallelism.is_none() {
659 Some(self.children_exchange_distribution)
660 } else {
661 None
662 };
663 let stage = Arc::new(QueryStage {
664 query_id: self.query_id,
665 id: self.id,
666 root: self.root.unwrap(),
667 exchange_info: self.exchange_info,
668 parallelism: self.parallelism,
669 table_scan_info: self.table_scan_info,
670 source_info: self.source_info,
671 file_scan_info: self.file_scan_file,
672 has_lookup_join: self.has_lookup_join,
673 dml_table_id: self.dml_table_id,
674 session_id: self.session_id,
675 batch_enable_distributed_dml: self.batch_enable_distributed_dml,
676 children_exchange_distribution,
677 });
678
679 stage_graph_builder.add_node(stage.clone());
680 for child_stage in self.children_stages {
681 stage_graph_builder.link_to_child(self.id, child_stage.id);
682 }
683 stage
684 }
685}
686
/// The DAG of stages making up one query, with edges kept in both directions.
#[derive(Debug, Serialize)]
#[cfg_attr(test, derive(Clone))]
pub struct StageGraph {
    pub root_stage_id: StageId,
    pub stages: HashMap<StageId, QueryStageRef>,
    // Top-down edges: stage -> stages feeding it via exchanges.
    child_edges: HashMap<StageId, HashSet<StageId>>,
    // Bottom-up edges: stage -> stages it feeds.
    parent_edges: HashMap<StageId, HashSet<StageId>>,

    batch_parallelism: usize,
}
700
impl StageGraph {
    /// Children of `stage_id`; panics if the stage is not in the graph.
    pub fn get_child_stages_unchecked(&self, stage_id: &StageId) -> &HashSet<StageId> {
        self.child_edges.get(stage_id).unwrap()
    }

    /// Children of `stage_id`, or `None` if the stage is unknown.
    pub fn get_child_stages(&self, stage_id: &StageId) -> Option<&HashSet<StageId>> {
        self.child_edges.get(stage_id)
    }

    /// Returns the stage ids reachable from the root so that every child
    /// appears before any of its parents (reversed DFS preorder).
    pub fn stage_ids_by_topo_order(&self) -> impl Iterator<Item = StageId> {
        let mut stack = Vec::with_capacity(self.stages.len());
        stack.push(self.root_stage_id);
        let mut ret = Vec::with_capacity(self.stages.len());
        let mut existing = HashSet::with_capacity(self.stages.len());

        while let Some(s) = stack.pop() {
            if !existing.contains(&s) {
                ret.push(s);
                existing.insert(s);
                stack.extend(&self.child_edges[&s]);
            }
        }

        // Reversing the visit order puts children before parents.
        ret.into_iter().rev()
    }

    /// Resolves every stage whose parallelism was still undetermined
    /// (source / file-scan stages), producing a fully specified graph.
    async fn complete(
        self,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
        timezone: String,
    ) -> SchedulerResult<StageGraph> {
        let mut complete_stages = HashMap::new();
        self.complete_stage(
            self.stages.get(&self.root_stage_id).unwrap().clone(),
            None,
            &mut complete_stages,
            catalog_reader,
            worker_node_manager,
            timezone,
        )
        .await?;
        Ok(StageGraph {
            root_stage_id: self.root_stage_id,
            stages: complete_stages,
            child_edges: self.child_edges,
            parent_edges: self.parent_edges,
            batch_parallelism: self.batch_parallelism,
        })
    }

    /// Completes one stage (fixing its parallelism and, for source stages,
    /// enumerating its splits), then recurses into its children, deriving
    /// their exchange infos when this stage's parallelism was just determined.
    #[async_recursion]
    async fn complete_stage(
        &self,
        stage: QueryStageRef,
        exchange_info: Option<ExchangeInfo>,
        complete_stages: &mut HashMap<StageId, QueryStageRef>,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
        timezone: String,
    ) -> SchedulerResult<()> {
        // `parallelism` is `Some` only when it was determined right here, in
        // which case the children's exchange infos must be derived below.
        let parallelism = if stage.parallelism.is_some() {
            // Parallelism already known: keep the stage as-is.
            complete_stages.insert(
                stage.id,
                Arc::new(stage.clone_with_exchange_info(exchange_info, stage.parallelism)),
            );
            None
        } else if matches!(stage.source_info, Some(SourceScanInfo::Incomplete(_))) {
            // Source stage: enumerate the actual splits first.
            let complete_source_info = stage
                .source_info
                .as_ref()
                .unwrap()
                .clone()
                .complete(self.batch_parallelism, timezone.to_owned())
                .await?;

            // For file-backed sources (gcs/s3/azblob), cap the parallelism at
            // half of `batch_parallelism` (at least 1) so a task may read
            // several files; otherwise use one task per split.
            let task_parallelism = match &stage.source_info {
                Some(SourceScanInfo::Incomplete(source_fetch_info)) => {
                    match source_fetch_info.connector {
                        ConnectorProperties::Gcs(_)
                        | ConnectorProperties::OpendalS3(_)
                        | ConnectorProperties::Azblob(_) => (min(
                            complete_source_info.split_info().unwrap().len() as u32,
                            (self.batch_parallelism / 2) as u32,
                        ))
                        .max(1),
                        _ => complete_source_info.split_info().unwrap().len() as u32,
                    }
                }
                _ => unreachable!(),
            };
            let complete_stage = Arc::new(stage.clone_with_exchange_info_and_complete_source_info(
                exchange_info,
                complete_source_info,
                task_parallelism,
            ));
            let parallelism = complete_stage.parallelism;
            complete_stages.insert(stage.id, complete_stage);
            parallelism
        } else {
            // Otherwise this must be a file-scan stage: one task per file,
            // capped at half of the batch parallelism.
            assert!(stage.file_scan_info.is_some());
            let parallelism = min(
                self.batch_parallelism / 2,
                stage.file_scan_info.as_ref().unwrap().file_location.len(),
            );
            complete_stages.insert(
                stage.id,
                Arc::new(stage.clone_with_exchange_info(exchange_info, Some(parallelism as u32))),
            );
            None
        };

        // NOTE(review): `unwrap_or(&HashSet::new())` builds the fallback set
        // eagerly on every call, even when edges exist for this stage.
        for child_stage_id in self.child_edges.get(&stage.id).unwrap_or(&HashSet::new()) {
            // If this stage's parallelism was just fixed, derive the child's
            // exchange info from the recorded exchange distribution.
            let exchange_info = if let Some(parallelism) = parallelism {
                let exchange_distribution = stage
                    .children_exchange_distribution
                    .as_ref()
                    .unwrap()
                    .get(child_stage_id)
                    .expect("Exchange distribution is not consistent with the stage graph");
                Some(exchange_distribution.to_prost(
                    parallelism,
                    catalog_reader,
                    worker_node_manager,
                )?)
            } else {
                None
            };
            self.complete_stage(
                self.stages.get(child_stage_id).unwrap().clone(),
                exchange_info,
                complete_stages,
                catalog_reader,
                worker_node_manager,
                timezone.to_owned(),
            )
            .await?;
        }

        Ok(())
    }

    /// Converts the stage graph into a `petgraph` graph (e.g. for
    /// visualization); each node is labeled with the stage's debug summary.
    pub fn to_petgraph(&self) -> Graph<String, String, Directed> {
        let mut graph = Graph::<String, String, Directed>::new();

        let mut node_indices = HashMap::new();

        // Add one node per stage, sorted by id for deterministic output.
        for (&stage_id, stage_ref) in self.stages.iter().sorted_by_key(|(id, _)| **id) {
            let node_label = format!("Stage {}: {:?}", stage_id, stage_ref);
            let node_index = graph.add_node(node_label);
            node_indices.insert(stage_id, node_index);
        }

        // Add parent -> child edges.
        for (&parent_id, children) in &self.child_edges {
            if let Some(&parent_index) = node_indices.get(&parent_id) {
                for &child_id in children {
                    if let Some(&child_index) = node_indices.get(&child_id) {
                        graph.add_edge(parent_index, child_index, "".to_owned());
                    }
                }
            }
        }

        graph
    }
}
882
/// Mutable counterpart of [`StageGraph`], used while stages are being created.
struct StageGraphBuilder {
    stages: HashMap<StageId, QueryStageRef>,
    child_edges: HashMap<StageId, HashSet<StageId>>,
    parent_edges: HashMap<StageId, HashSet<StageId>>,
    batch_parallelism: usize,
}
889
890impl StageGraphBuilder {
891 pub fn new(batch_parallelism: usize) -> Self {
892 Self {
893 stages: HashMap::new(),
894 child_edges: HashMap::new(),
895 parent_edges: HashMap::new(),
896 batch_parallelism,
897 }
898 }
899
900 pub fn build(self, root_stage_id: StageId) -> StageGraph {
901 StageGraph {
902 root_stage_id,
903 stages: self.stages,
904 child_edges: self.child_edges,
905 parent_edges: self.parent_edges,
906 batch_parallelism: self.batch_parallelism,
907 }
908 }
909
910 pub fn link_to_child(&mut self, parent_id: StageId, child_id: StageId) {
913 self.child_edges
914 .get_mut(&parent_id)
915 .unwrap()
916 .insert(child_id);
917 self.parent_edges
918 .get_mut(&child_id)
919 .unwrap()
920 .insert(parent_id);
921 }
922
923 pub fn add_node(&mut self, stage: QueryStageRef) {
924 self.child_edges.insert(stage.id, HashSet::new());
926 self.parent_edges.insert(stage.id, HashSet::new());
927 self.stages.insert(stage.id, stage);
928 }
929}
930
931impl BatchPlanFragmenter {
932 pub async fn generate_complete_query(self) -> SchedulerResult<Query> {
938 let stage_graph = self.stage_graph.unwrap();
939 let new_stage_graph = stage_graph
940 .complete(
941 &self.catalog_reader,
942 &self.worker_node_manager,
943 self.timezone.to_owned(),
944 )
945 .await?;
946 Ok(Query {
947 query_id: self.query_id,
948 stage_graph: new_stage_graph,
949 })
950 }
951
    /// Creates one stage rooted at `root`; `BatchExchange` nodes below it cut
    /// the tree, each exchange input becoming a child stage (via `visit_node`
    /// -> `visit_exchange`).
    ///
    /// The stage's parallelism is derived from what it scans; `None` means
    /// "unknown until source splits are enumerated".
    fn new_stage(
        &mut self,
        root: PlanRef,
        exchange_info: Option<ExchangeInfo>,
    ) -> SchedulerResult<QueryStageRef> {
        let next_stage_id = self.next_stage_id;
        self.next_stage_id += 1;

        let mut table_scan_info = self.collect_stage_table_scan(root.clone())?;
        // Table scan, source scan and file scan are collected mutually
        // exclusively, in that priority order.
        let source_info = if table_scan_info.is_none() {
            Self::collect_stage_source(root.clone())?
        } else {
            None
        };

        let file_scan_info = if table_scan_info.is_none() && source_info.is_none() {
            Self::collect_stage_file_scan(root.clone())?
        } else {
            None
        };

        let mut has_lookup_join = false;
        let parallelism = match root.distribution() {
            Distribution::Single => {
                if let Some(info) = &mut table_scan_info {
                    if let Some(partitions) = &mut info.partitions {
                        if partitions.len() != 1 {
                            // A Single-distribution stage is expected to have
                            // one partition; fall back to one arbitrary
                            // partition covering all vnodes.
                            tracing::warn!(
                                "The stage has single distribution, but contains a scan of table `{}` with {} partitions. A single random worker will be assigned",
                                info.name,
                                partitions.len()
                            );

                            *partitions = partitions
                                .drain()
                                .take(1)
                                .update(|(_, info)| {
                                    info.vnode_bitmap = Bitmap::ones(info.vnode_bitmap.len());
                                })
                                .collect();
                        }
                    } else {
                        // System table scan (`partitions` is `None`): nothing
                        // to prune.
                    }
                } else if source_info.is_some() {
                    return Err(SchedulerError::Internal(anyhow!(
                        "The stage has single distribution, but contains a source operator"
                    )));
                }
                1
            }
            _ => {
                if let Some(table_scan_info) = &table_scan_info {
                    table_scan_info
                        .partitions
                        .as_ref()
                        .map(|m| m.len())
                        .unwrap_or(1)
                } else if let Some(lookup_join_parallelism) =
                    self.collect_stage_lookup_join_parallelism(root.clone())?
                {
                    has_lookup_join = true;
                    lookup_join_parallelism
                } else if source_info.is_some() {
                    // Determined later, once the splits are enumerated
                    // (see `StageGraph::complete`).
                    0
                } else if file_scan_info.is_some() {
                    1
                } else {
                    self.batch_parallelism
                }
            }
        };
        if source_info.is_none() && file_scan_info.is_none() && parallelism == 0 {
            return Err(BatchError::EmptyWorkerNodes.into());
        }
        // 0 doubles as the "to be determined" marker above; store it as `None`.
        let parallelism = if parallelism == 0 {
            None
        } else {
            Some(parallelism as u32)
        };
        let dml_table_id = Self::collect_dml_table_id(&root);
        let mut builder = QueryStageBuilder::new(
            next_stage_id,
            self.query_id.clone(),
            parallelism,
            exchange_info,
            table_scan_info,
            source_info,
            file_scan_info,
            has_lookup_join,
            dml_table_id,
            root.ctx().session_ctx().session_id(),
            root.ctx()
                .session_ctx()
                .config()
                .batch_enable_distributed_dml(),
        );

        self.visit_node(root, &mut builder, None)?;

        Ok(builder.finish(self.stage_graph_builder.as_mut().unwrap()))
    }
1058
1059 fn visit_node(
1060 &mut self,
1061 node: PlanRef,
1062 builder: &mut QueryStageBuilder,
1063 parent_exec_node: Option<&mut ExecutionPlanNode>,
1064 ) -> SchedulerResult<()> {
1065 match node.node_type() {
1066 PlanNodeType::BatchExchange => {
1067 self.visit_exchange(node.clone(), builder, parent_exec_node)?;
1068 }
1069 _ => {
1070 let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;
1071
1072 for child in node.inputs() {
1073 self.visit_node(child, builder, Some(&mut execution_plan_node))?;
1074 }
1075
1076 if let Some(parent) = parent_exec_node {
1077 parent.children.push(Arc::new(execution_plan_node));
1078 } else {
1079 builder.root = Some(Arc::new(execution_plan_node));
1080 }
1081 }
1082 }
1083 Ok(())
1084 }
1085
    /// Handles a `BatchExchange` node: creates a new child stage for its input
    /// and records the link from the exchange node to that stage.
    fn visit_exchange(
        &mut self,
        node: PlanRef,
        builder: &mut QueryStageBuilder,
        parent_exec_node: Option<&mut ExecutionPlanNode>,
    ) -> SchedulerResult<()> {
        let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;
        // The child's exchange info can only be derived once this stage's
        // parallelism is known; otherwise it is completed later in
        // `StageGraph::complete_stage`.
        let child_exchange_info = if let Some(parallelism) = builder.parallelism {
            Some(node.distribution().to_prost(
                parallelism,
                &self.catalog_reader,
                &self.worker_node_manager,
            )?)
        } else {
            None
        };
        let child_stage = self.new_stage(node.inputs()[0].clone(), child_exchange_info)?;
        execution_plan_node.source_stage_id = Some(child_stage.id);
        if builder.parallelism.is_none() {
            // Remember the distribution so the exchange info can be derived
            // once the parallelism is determined.
            builder
                .children_exchange_distribution
                .insert(child_stage.id, node.distribution().clone());
        }

        if let Some(parent) = parent_exec_node {
            parent.children.push(Arc::new(execution_plan_node));
        } else {
            builder.root = Some(Arc::new(execution_plan_node));
        }

        builder.children_stages.push(child_stage);
        Ok(())
    }
1119
    /// Collects the source scan info of this stage, if any: the first
    /// kafka / iceberg / plain source node found below `node`, stopping at
    /// `BatchExchange` boundaries (anything below an exchange belongs to a
    /// different stage).
    fn collect_stage_source(node: PlanRef) -> SchedulerResult<Option<SourceScanInfo>> {
        if node.node_type() == PlanNodeType::BatchExchange {
            // Do not descend into the next stage.
            return Ok(None);
        }

        if let Some(batch_kafka_node) = node.as_batch_kafka_scan() {
            let batch_kafka_scan: &BatchKafkaScan = batch_kafka_node;
            let source_catalog = batch_kafka_scan.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                // Kafka scans carry an optional (lower, upper) timestamp bound.
                let timestamp_bound = batch_kafka_scan.kafka_timestamp_range_value();
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: batch_kafka_scan.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::KafkaTimebound {
                        lower: timestamp_bound.0,
                        upper: timestamp_bound.1,
                    },
                    as_of: None,
                })));
            }
        } else if let Some(batch_iceberg_scan) = node.as_batch_iceberg_scan() {
            let batch_iceberg_scan: &BatchIcebergScan = batch_iceberg_scan;
            let source_catalog = batch_iceberg_scan.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                let as_of = batch_iceberg_scan.as_of();
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: batch_iceberg_scan.base.schema().clone(),
                    connector: property,
                    // Iceberg additionally carries the scan type and the
                    // pushed-down predicate.
                    fetch_parameters: SourceFetchParameters::IcebergSpecificInfo(
                        IcebergSpecificInfo {
                            predicate: batch_iceberg_scan.predicate.clone(),
                            iceberg_scan_type: batch_iceberg_scan.iceberg_scan_type(),
                        },
                    ),
                    as_of,
                })));
            }
        } else if let Some(source_node) = node.as_batch_source() {
            let source_node: &BatchSource = source_node;
            let source_catalog = source_node.source_catalog();
            if let Some(source_catalog) = source_catalog {
                let property =
                    ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
                let as_of = source_node.as_of();
                return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
                    schema: source_node.base.schema().clone(),
                    connector: property,
                    fetch_parameters: SourceFetchParameters::Empty,
                    as_of,
                })));
            }
        }

        // Recurse into inputs, returning the first source found (or the first
        // error encountered).
        node.inputs()
            .into_iter()
            .find_map(|n| Self::collect_stage_source(n).transpose())
            .transpose()
    }
1188
1189 fn collect_stage_file_scan(node: PlanRef) -> SchedulerResult<Option<FileScanInfo>> {
1190 if node.node_type() == PlanNodeType::BatchExchange {
1191 return Ok(None);
1193 }
1194
1195 if let Some(batch_file_scan) = node.as_batch_file_scan() {
1196 return Ok(Some(FileScanInfo {
1197 file_location: batch_file_scan.core.file_location().clone(),
1198 }));
1199 }
1200
1201 node.inputs()
1202 .into_iter()
1203 .find_map(|n| Self::collect_stage_file_scan(n).transpose())
1204 .transpose()
1205 }
1206
1207 fn collect_stage_table_scan(&self, node: PlanRef) -> SchedulerResult<Option<TableScanInfo>> {
1212 let build_table_scan_info = |name, table_desc: &TableDesc, scan_range| {
1213 let table_catalog = self
1214 .catalog_reader
1215 .read_guard()
1216 .get_any_table_by_id(&table_desc.table_id)
1217 .cloned()
1218 .map_err(RwError::from)?;
1219 let vnode_mapping = self
1220 .worker_node_manager
1221 .fragment_mapping(table_catalog.fragment_id)?;
1222 let partitions = derive_partitions(scan_range, table_desc, &vnode_mapping)?;
1223 let info = TableScanInfo::new(name, partitions);
1224 Ok(Some(info))
1225 };
1226 if node.node_type() == PlanNodeType::BatchExchange {
1227 return Ok(None);
1229 }
1230 if let Some(scan_node) = node.as_batch_sys_seq_scan() {
1231 let name = scan_node.core().table_name.to_owned();
1232 Ok(Some(TableScanInfo::system_table(name)))
1233 } else if let Some(scan_node) = node.as_batch_log_seq_scan() {
1234 build_table_scan_info(
1235 scan_node.core().table_name.to_owned(),
1236 &scan_node.core().table_desc,
1237 &[],
1238 )
1239 } else if let Some(scan_node) = node.as_batch_seq_scan() {
1240 build_table_scan_info(
1241 scan_node.core().table_name.to_owned(),
1242 &scan_node.core().table_desc,
1243 scan_node.scan_ranges(),
1244 )
1245 } else {
1246 node.inputs()
1247 .into_iter()
1248 .find_map(|n| self.collect_stage_table_scan(n).transpose())
1249 .transpose()
1250 }
1251 }
1252
1253 fn collect_dml_table_id(node: &PlanRef) -> Option<TableId> {
1255 if node.node_type() == PlanNodeType::BatchExchange {
1256 return None;
1257 }
1258 if let Some(insert) = node.as_batch_insert() {
1259 Some(insert.core.table_id)
1260 } else if let Some(update) = node.as_batch_update() {
1261 Some(update.core.table_id)
1262 } else if let Some(delete) = node.as_batch_delete() {
1263 Some(delete.core.table_id)
1264 } else {
1265 node.inputs()
1266 .into_iter()
1267 .find_map(|n| Self::collect_dml_table_id(&n))
1268 }
1269 }
1270
1271 fn collect_stage_lookup_join_parallelism(
1272 &self,
1273 node: PlanRef,
1274 ) -> SchedulerResult<Option<usize>> {
1275 if node.node_type() == PlanNodeType::BatchExchange {
1276 return Ok(None);
1278 }
1279 if let Some(lookup_join) = node.as_batch_lookup_join() {
1280 let table_desc = lookup_join.right_table_desc();
1281 let table_catalog = self
1282 .catalog_reader
1283 .read_guard()
1284 .get_any_table_by_id(&table_desc.table_id)
1285 .cloned()
1286 .map_err(RwError::from)?;
1287 let vnode_mapping = self
1288 .worker_node_manager
1289 .fragment_mapping(table_catalog.fragment_id)?;
1290 let parallelism = vnode_mapping.iter().sorted().dedup().count();
1291 Ok(Some(parallelism))
1292 } else {
1293 node.inputs()
1294 .into_iter()
1295 .find_map(|n| self.collect_stage_lookup_join_parallelism(n).transpose())
1296 .transpose()
1297 }
1298 }
1299}
1300
/// Derives, for each worker slot, the partition of the table it should scan:
/// a vnode bitmap plus the scan ranges that fall into those vnodes.
///
/// Worker slots that no scan range routes to are absent from the result —
/// except when `scan_ranges` is empty (full scan), in which case every slot
/// in the mapping gets an entry with its full vnode bitmap and no ranges.
fn derive_partitions(
    scan_ranges: &[ScanRange],
    table_desc: &TableDesc,
    vnode_mapping: &WorkerSlotMapping,
) -> SchedulerResult<HashMap<WorkerSlotId, TablePartitionInfo>> {
    // If the table's vnode count disagrees with the fragment mapping's, the
    // table must be a singleton (vnode count 1); collapse the mapping to a
    // single arbitrary worker slot in that case. Any other mismatch is a bug.
    let vnode_mapping = if table_desc.vnode_count != vnode_mapping.len() {
        assert!(
            table_desc.vnode_count == 1,
            "fragment vnode count {} does not match table vnode count {}",
            vnode_mapping.len(),
            table_desc.vnode_count,
        );
        &WorkerSlotMapping::new_single(vnode_mapping.iter().next().unwrap())
    } else {
        vnode_mapping
    };
    let vnode_count = vnode_mapping.len();

    // Accumulates per-slot (vnode bitmap builder, protobuf scan ranges).
    let mut partitions: HashMap<WorkerSlotId, (BitmapBuilder, Vec<_>)> = HashMap::new();

    if scan_ranges.is_empty() {
        // Full scan: every worker slot scans all vnodes it owns, with no
        // range restriction.
        return Ok(vnode_mapping
            .to_bitmaps()
            .into_iter()
            .map(|(k, vnode_bitmap)| {
                (
                    k,
                    TablePartitionInfo {
                        vnode_bitmap,
                        scan_ranges: vec![],
                    },
                )
            })
            .collect());
    }

    // Distribution over all vnodes, used to compute which vnode (if any) a
    // scan range pins down via its distribution-key prefix.
    let table_distribution = TableDistribution::new_from_storage_table_desc(
        Some(Bitmap::ones(vnode_count).into()),
        &table_desc.try_to_protobuf()?,
    );

    for scan_range in scan_ranges {
        let vnode = scan_range.try_compute_vnode(&table_distribution);
        match vnode {
            // The range does not fix a single vnode: scatter it to every
            // worker slot, each scanning the vnodes it owns.
            None => {
                vnode_mapping.to_bitmaps().into_iter().for_each(
                    |(worker_slot_id, vnode_bitmap)| {
                        let (bitmap, scan_ranges) = partitions
                            .entry(worker_slot_id)
                            .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
                        // Merge the slot's owned vnodes into its bitmap.
                        vnode_bitmap
                            .iter()
                            .enumerate()
                            .for_each(|(vnode, b)| bitmap.set(vnode, b));
                        scan_ranges.push(scan_range.to_protobuf());
                    },
                );
            }
            // The range maps to exactly one vnode: route it to the worker
            // slot owning that vnode only.
            Some(vnode) => {
                let worker_slot_id = vnode_mapping[vnode];
                let (bitmap, scan_ranges) = partitions
                    .entry(worker_slot_id)
                    .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
                bitmap.set(vnode.to_index(), true);
                scan_ranges.push(scan_range.to_protobuf());
            }
        }
    }

    // Finalize the bitmap builders into immutable bitmaps.
    Ok(partitions
        .into_iter()
        .map(|(k, (bitmap, scan_ranges))| {
            (
                k,
                TablePartitionInfo {
                    vnode_bitmap: bitmap.finish(),
                    scan_ranges,
                },
            )
        })
        .collect())
}
1390
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use risingwave_pb::batch_plan::plan_node::NodeBody;

    use crate::optimizer::plan_node::PlanNodeType;
    use crate::scheduler::plan_fragmenter::StageId;

    /// Fragments the query built by `create_query` and checks the resulting
    /// stage graph: edge topology, topological order, and each stage's plan
    /// shape. Expected topology: stage 0 (root exchange) -> stage 1 (hash
    /// join) -> stages 2 and 3 (scan sides).
    #[tokio::test]
    async fn test_fragmenter() {
        let query = crate::scheduler::distributed::tests::create_query().await;

        assert_eq!(query.stage_graph.root_stage_id, 0);
        assert_eq!(query.stage_graph.stages.len(), 4);

        // Child edges: root -> join -> two leaf scan stages.
        assert_eq!(query.stage_graph.child_edges[&0], [1].into());
        assert_eq!(query.stage_graph.child_edges[&1], [2, 3].into());
        assert_eq!(query.stage_graph.child_edges[&2], HashSet::new());
        assert_eq!(query.stage_graph.child_edges[&3], HashSet::new());

        // Parent edges mirror the child edges.
        assert_eq!(query.stage_graph.parent_edges[&0], HashSet::new());
        assert_eq!(query.stage_graph.parent_edges[&1], [0].into());
        assert_eq!(query.stage_graph.parent_edges[&2], [1].into());
        assert_eq!(query.stage_graph.parent_edges[&3], [1].into());

        // Topological order: every stage must come after all of its children.
        {
            let stage_id_to_pos: HashMap<StageId, usize> = query
                .stage_graph
                .stage_ids_by_topo_order()
                .enumerate()
                .map(|(pos, stage_id)| (stage_id, pos))
                .collect();

            for stage_id in query.stage_graph.stages.keys() {
                let stage_pos = stage_id_to_pos[stage_id];
                for child_stage_id in &query.stage_graph.child_edges[stage_id] {
                    let child_pos = stage_id_to_pos[child_stage_id];
                    assert!(stage_pos > child_pos);
                }
            }
        }

        // Stage 0: a single-parallelism root exchange pulling from stage 1.
        let root_exchange = query.stage_graph.stages.get(&0).unwrap();
        assert_eq!(root_exchange.root.node_type(), PlanNodeType::BatchExchange);
        assert_eq!(root_exchange.root.source_stage_id, Some(1));
        assert!(matches!(root_exchange.root.node, NodeBody::Exchange(_)));
        assert_eq!(root_exchange.parallelism, Some(1));
        assert!(!root_exchange.has_table_scan());

        // Stage 1: a hash join whose two children are exchanges from stages 2/3.
        let join_node = query.stage_graph.stages.get(&1).unwrap();
        assert_eq!(join_node.root.node_type(), PlanNodeType::BatchHashJoin);
        assert_eq!(join_node.parallelism, Some(24));

        assert!(matches!(join_node.root.node, NodeBody::HashJoin(_)));
        assert_eq!(join_node.root.source_stage_id, None);
        assert_eq!(2, join_node.root.children.len());

        assert!(matches!(
            join_node.root.children[0].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[0].source_stage_id, Some(2));
        assert_eq!(0, join_node.root.children[0].children.len());

        assert!(matches!(
            join_node.root.children[1].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[1].source_stage_id, Some(3));
        assert_eq!(0, join_node.root.children[1].children.len());
        assert!(!join_node.has_table_scan());

        // Stage 2: a bare sequential scan leaf.
        let scan_node1 = query.stage_graph.stages.get(&2).unwrap();
        assert_eq!(scan_node1.root.node_type(), PlanNodeType::BatchSeqScan);
        assert_eq!(scan_node1.root.source_stage_id, None);
        assert_eq!(0, scan_node1.root.children.len());
        assert!(scan_node1.has_table_scan());

        // Stage 3: a filter on top of a scan.
        let scan_node2 = query.stage_graph.stages.get(&3).unwrap();
        assert_eq!(scan_node2.root.node_type(), PlanNodeType::BatchFilter);
        assert_eq!(scan_node2.root.source_stage_id, None);
        assert_eq!(1, scan_node2.root.children.len());
        assert!(scan_node2.has_table_scan());
    }
}