use std::cmp::min;
use std::collections::{HashMap, HashSet};
use std::fmt::{Debug, Formatter};
use std::num::NonZeroU64;
use std::sync::Arc;
use anyhow::anyhow;
use async_recursion::async_recursion;
use enum_as_inner::EnumAsInner;
use futures::TryStreamExt;
use iceberg::expr::Predicate as IcebergPredicate;
use itertools::Itertools;
use pgwire::pg_server::SessionId;
use risingwave_batch::error::BatchError;
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::bail;
use risingwave_common::bitmap::{Bitmap, BitmapBuilder};
use risingwave_common::catalog::{Schema, TableDesc};
use risingwave_common::hash::table_distribution::TableDistribution;
use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
use risingwave_common::util::scan_range::ScanRange;
use risingwave_connector::source::filesystem::opendal_source::opendal_enumerator::OpendalEnumerator;
use risingwave_connector::source::filesystem::opendal_source::{
OpendalAzblob, OpendalGcs, OpendalS3,
};
use risingwave_connector::source::iceberg::{IcebergSplitEnumerator, IcebergTimeTravelInfo};
use risingwave_connector::source::kafka::KafkaSplitEnumerator;
use risingwave_connector::source::reader::reader::build_opendal_fs_list_for_batch;
use risingwave_connector::source::{
ConnectorProperties, SourceEnumeratorContext, SplitEnumerator, SplitImpl,
};
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::{ExchangeInfo, ScanRange as ScanRangeProto};
use risingwave_pb::plan_common::Field as PbField;
use risingwave_sqlparser::ast::AsOf;
use serde::ser::SerializeStruct;
use serde::Serialize;
use uuid::Uuid;
use super::SchedulerError;
use crate::catalog::catalog_service::CatalogReader;
use crate::catalog::TableId;
use crate::error::RwError;
use crate::optimizer::plan_node::generic::{GenericPlanRef, PhysicalPlanRef};
use crate::optimizer::plan_node::{
BatchIcebergScan, BatchKafkaScan, BatchSource, PlanNodeId, PlanNodeType,
};
use crate::optimizer::property::Distribution;
use crate::optimizer::PlanRef;
use crate::scheduler::SchedulerResult;
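/// The unique identifier of a batch query. A fresh UUID is generated for each
/// query via the `Default` impl.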
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct QueryId {
pub id: String,
}
impl std::fmt::Display for QueryId {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "QueryId:{}", self.id)
}
}
/// Stage id, unique within a query.
pub type StageId = u32;
// The root stage of a query always has a single task,
pub const ROOT_TASK_ID: u64 = 0;
// and that task has a single output.
pub const ROOT_TASK_OUTPUT_ID: u64 = 0;
pub type TaskId = u64;
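/// A plan node in the query execution graph, converted from the optimizer's
/// [`PlanRef`]. Children are owned directly; a `BatchExchange` node instead
/// records the child stage it reads from in `source_stage_id`.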
#[derive(Clone, Debug)]
pub struct ExecutionPlanNode {
pub plan_node_id: PlanNodeId,
pub plan_node_type: PlanNodeType,
pub node: NodeBody,
pub schema: Vec<PbField>,
pub children: Vec<Arc<ExecutionPlanNode>>,
pub source_stage_id: Option<StageId>,
}
impl Serialize for ExecutionPlanNode {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut state = serializer.serialize_struct("QueryStage", 5)?;
state.serialize_field("plan_node_id", &self.plan_node_id)?;
state.serialize_field("plan_node_type", &self.plan_node_type)?;
state.serialize_field("schema", &self.schema)?;
state.serialize_field("children", &self.children)?;
state.serialize_field("source_stage_id", &self.source_stage_id)?;
state.end()
}
}
impl TryFrom<PlanRef> for ExecutionPlanNode {
type Error = SchedulerError;
fn try_from(plan_node: PlanRef) -> Result<Self, Self::Error> {
Ok(Self {
plan_node_id: plan_node.plan_base().id(),
plan_node_type: plan_node.node_type(),
node: plan_node.try_to_batch_prost_body()?,
children: vec![],
schema: plan_node.schema().to_prost(),
source_stage_id: None,
})
}
}
impl ExecutionPlanNode {
pub fn node_type(&self) -> PlanNodeType {
self.plan_node_type
}
}
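/// Splits a batch query plan into stages, cutting at each `BatchExchange`,
/// and builds the [`StageGraph`] that the scheduler executes.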
pub struct BatchPlanFragmenter {
query_id: QueryId,
next_stage_id: StageId,
worker_node_manager: WorkerNodeSelector,
catalog_reader: CatalogReader,
batch_parallelism: usize,
stage_graph_builder: Option<StageGraphBuilder>,
stage_graph: Option<StageGraph>,
}
impl Default for QueryId {
fn default() -> Self {
Self {
id: Uuid::new_v4().to_string(),
}
}
}
impl BatchPlanFragmenter {
pub fn new(
worker_node_manager: WorkerNodeSelector,
catalog_reader: CatalogReader,
batch_parallelism: Option<NonZeroU64>,
batch_node: PlanRef,
) -> SchedulerResult<Self> {
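// Cap the requested parallelism at the number of schedulable units the
// cluster currently provides.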
let batch_parallelism = if let Some(num) = batch_parallelism {
min(
num.get() as usize,
worker_node_manager.schedule_unit_count(),
)
} else {
worker_node_manager.schedule_unit_count()
};
let mut plan_fragmenter = Self {
query_id: Default::default(),
next_stage_id: 0,
worker_node_manager,
catalog_reader,
batch_parallelism,
stage_graph_builder: Some(StageGraphBuilder::new(batch_parallelism)),
stage_graph: None,
};
plan_fragmenter.split_into_stage(batch_node)?;
Ok(plan_fragmenter)
}
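/// Splits the plan into stages, starting from the root. The root stage always
/// has `Single` distribution.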
fn split_into_stage(&mut self, batch_node: PlanRef) -> SchedulerResult<()> {
let root_stage = self.new_stage(
batch_node,
Some(Distribution::Single.to_prost(
1,
&self.catalog_reader,
&self.worker_node_manager,
)?),
)?;
self.stage_graph = Some(
self.stage_graph_builder
.take()
.unwrap()
.build(root_stage.id),
);
Ok(())
}
}
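/// The fragmented form of a batch query: a [`StageGraph`] identified by a
/// [`QueryId`].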
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct Query {
pub query_id: QueryId,
pub stage_graph: StageGraph,
}
impl Query {
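/// Returns the stages that have no children, i.e. the stages that can be
/// scheduled first.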
pub fn leaf_stages(&self) -> Vec<StageId> {
let mut ret_leaf_stages = Vec::new();
for stage_id in self.stage_graph.stages.keys() {
if self
.stage_graph
.get_child_stages_unchecked(stage_id)
.is_empty()
{
ret_leaf_stages.push(*stage_id);
}
}
ret_leaf_stages
}
pub fn get_parents(&self, stage_id: &StageId) -> &HashSet<StageId> {
self.stage_graph.parent_edges.get(stage_id).unwrap()
}
pub fn root_stage_id(&self) -> StageId {
self.stage_graph.root_stage_id
}
pub fn query_id(&self) -> &QueryId {
&self.query_id
}
pub fn stages_with_table_scan(&self) -> HashSet<StageId> {
self.stage_graph
.stages
.iter()
.filter_map(|(stage_id, stage_query)| {
if stage_query.has_table_scan() {
Some(*stage_id)
} else {
None
}
})
.collect()
}
pub fn has_lookup_join_stage(&self) -> bool {
self.stage_graph
.stages
.iter()
.any(|(_stage_id, stage_query)| stage_query.has_lookup_join())
}
}
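/// Connector-specific parameters needed to list splits for a batch source
/// scan. `Empty` is used by connectors that need no extra filtering, e.g.
/// the file-based sources.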
#[derive(Debug, Clone)]
pub enum SourceFetchParameters {
IcebergPredicate(IcebergPredicate),
KafkaTimebound {
lower: Option<i64>,
upper: Option<i64>,
},
Empty,
}
#[derive(Debug, Clone)]
pub struct SourceFetchInfo {
pub schema: Schema,
pub connector: ConnectorProperties,
pub fetch_parameters: SourceFetchParameters,
pub as_of: Option<AsOf>,
}
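/// Split information for a source scan. It starts as `Incomplete` at
/// fragmentation time and becomes `Complete` (a concrete split list) via
/// [`SourceScanInfo::complete`].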
#[derive(Clone, Debug)]
pub enum SourceScanInfo {
Incomplete(SourceFetchInfo),
Complete(Vec<SplitImpl>),
}
impl SourceScanInfo {
pub fn new(fetch_info: SourceFetchInfo) -> Self {
Self::Incomplete(fetch_info)
}
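/// Lists the concrete splits through the connector's enumerator, turning an
/// `Incomplete` info into a `Complete` one. Must only be called once.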
pub async fn complete(self, batch_parallelism: usize) -> SchedulerResult<Self> {
let fetch_info = match self {
SourceScanInfo::Incomplete(fetch_info) => fetch_info,
SourceScanInfo::Complete(_) => {
unreachable!("Never call complete when SourceScanInfo is already complete")
}
};
match (fetch_info.connector, fetch_info.fetch_parameters) {
(
ConnectorProperties::Kafka(prop),
SourceFetchParameters::KafkaTimebound { lower, upper },
) => {
let mut kafka_enumerator =
KafkaSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
.await?;
let split_info = kafka_enumerator
.list_splits_batch(lower, upper)
.await?
.into_iter()
.map(SplitImpl::Kafka)
.collect_vec();
Ok(SourceScanInfo::Complete(split_info))
}
(ConnectorProperties::OpendalS3(prop), SourceFetchParameters::Empty) => {
let lister: OpendalEnumerator<OpendalS3> =
OpendalEnumerator::new_s3_source(prop.s3_properties, prop.assume_role)?;
let stream = build_opendal_fs_list_for_batch(lister);
let batch_res: Vec<_> = stream.try_collect().await?;
let res = batch_res
.into_iter()
.map(SplitImpl::OpendalS3)
.collect_vec();
Ok(SourceScanInfo::Complete(res))
}
(ConnectorProperties::Gcs(prop), SourceFetchParameters::Empty) => {
let lister: OpendalEnumerator<OpendalGcs> =
OpendalEnumerator::new_gcs_source(*prop)?;
let stream = build_opendal_fs_list_for_batch(lister);
let batch_res: Vec<_> = stream.try_collect().await?;
let res = batch_res.into_iter().map(SplitImpl::Gcs).collect_vec();
Ok(SourceScanInfo::Complete(res))
}
(ConnectorProperties::Azblob(prop), SourceFetchParameters::Empty) => {
let lister: OpendalEnumerator<OpendalAzblob> =
OpendalEnumerator::new_azblob_source(*prop)?;
let stream = build_opendal_fs_list_for_batch(lister);
let batch_res: Vec<_> = stream.try_collect().await?;
let res = batch_res.into_iter().map(SplitImpl::Azblob).collect_vec();
Ok(SourceScanInfo::Complete(res))
}
(
ConnectorProperties::Iceberg(prop),
SourceFetchParameters::IcebergPredicate(predicate),
) => {
let iceberg_enumerator =
IcebergSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
.await?;
let time_travel_info = match fetch_info.as_of {
Some(AsOf::VersionNum(v)) => Some(IcebergTimeTravelInfo::Version(v)),
Some(AsOf::TimestampNum(ts)) => {
// `ts` is in epoch seconds; Iceberg time travel expects milliseconds.
Some(IcebergTimeTravelInfo::TimestampMs(ts * 1000))
}
Some(AsOf::VersionString(_)) => {
bail!("Unsupported version string in iceberg time travel")
}
Some(AsOf::TimestampString(ts)) => Some(
// Parse an RFC 3339 timestamp and convert it to epoch milliseconds.
speedate::DateTime::parse_str_rfc3339(&ts)
.map(|t| {
IcebergTimeTravelInfo::TimestampMs(
t.timestamp_tz() * 1000 + t.time.microsecond as i64 / 1000,
)
})
.map_err(|_e| anyhow!("failed to parse timestamp"))?,
),
Some(AsOf::ProcessTime) | Some(AsOf::ProcessTimeWithInterval(_)) => {
// Process-time AS OF should have been rejected during planning.
unreachable!()
}
None => None,
};
let split_info = iceberg_enumerator
.list_splits_batch(
fetch_info.schema,
time_travel_info,
batch_parallelism,
predicate,
)
.await?
.into_iter()
.map(SplitImpl::Iceberg)
.collect_vec();
Ok(SourceScanInfo::Complete(split_info))
}
_ => Err(SchedulerError::Internal(anyhow!(
"batch query directly against this source is not supported"
))),
}
}
pub fn split_info(&self) -> SchedulerResult<&Vec<SplitImpl>> {
match self {
Self::Incomplete(_) => Err(SchedulerError::Internal(anyhow!(
"Should not get split info from incomplete source scan info"
))),
Self::Complete(split_info) => Ok(split_info),
}
}
}
#[derive(Clone, Debug)]
pub struct TableScanInfo {
/// Name of the table to scan.
name: String,
/// Worker-slot partitions to be read by the scan tasks.
/// `None` if and only if the table is a system table.
partitions: Option<HashMap<WorkerSlotId, TablePartitionInfo>>,
}
impl TableScanInfo {
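/// Constructs a `TableScanInfo` for a regular table; `partitions` are
/// typically produced by `derive_partitions`.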
pub fn new(name: String, partitions: HashMap<WorkerSlotId, TablePartitionInfo>) -> Self {
Self {
name,
partitions: Some(partitions),
}
}
pub fn system_table(name: String) -> Self {
Self {
name,
partitions: None,
}
}
pub fn name(&self) -> &str {
self.name.as_ref()
}
pub fn partitions(&self) -> Option<&HashMap<WorkerSlotId, TablePartitionInfo>> {
self.partitions.as_ref()
}
}
#[derive(Clone, Debug)]
pub struct TablePartitionInfo {
pub vnode_bitmap: Bitmap,
pub scan_ranges: Vec<ScanRangeProto>,
}
#[derive(Clone, Debug, EnumAsInner)]
pub enum PartitionInfo {
Table(TablePartitionInfo),
Source(Vec<SplitImpl>),
File(Vec<String>),
}
#[derive(Clone, Debug)]
pub struct FileScanInfo {
pub file_location: Vec<String>,
}
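/// A fragment of the query plan between exchanges, executed as a set of
/// parallel tasks on worker nodes.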
#[derive(Clone)]
pub struct QueryStage {
pub query_id: QueryId,
pub id: StageId,
pub root: Arc<ExecutionPlanNode>,
pub exchange_info: Option<ExchangeInfo>,
/// `None` until the parallelism can be derived from listed splits or file
/// locations; filled in by `StageGraph::complete`.
pub parallelism: Option<u32>,
pub table_scan_info: Option<TableScanInfo>,
pub source_info: Option<SourceScanInfo>,
pub file_scan_info: Option<FileScanInfo>,
pub has_lookup_join: bool,
pub dml_table_id: Option<TableId>,
pub session_id: SessionId,
pub batch_enable_distributed_dml: bool,
/// Distributions of the children exchanges, recorded only while this stage's
/// parallelism is still unknown (see `QueryStageBuilder::finish`).
children_exchange_distribution: Option<HashMap<StageId, Distribution>>,
}
impl QueryStage {
pub fn has_table_scan(&self) -> bool {
self.table_scan_info.is_some()
}
pub fn has_lookup_join(&self) -> bool {
self.has_lookup_join
}
pub fn clone_with_exchange_info(
&self,
exchange_info: Option<ExchangeInfo>,
parallelism: Option<u32>,
) -> Self {
if let Some(exchange_info) = exchange_info {
return Self {
query_id: self.query_id.clone(),
id: self.id,
root: self.root.clone(),
exchange_info: Some(exchange_info),
parallelism,
table_scan_info: self.table_scan_info.clone(),
source_info: self.source_info.clone(),
file_scan_info: self.file_scan_info.clone(),
has_lookup_join: self.has_lookup_join,
dml_table_id: self.dml_table_id,
session_id: self.session_id,
batch_enable_distributed_dml: self.batch_enable_distributed_dml,
children_exchange_distribution: self.children_exchange_distribution.clone(),
};
}
self.clone()
}
pub fn clone_with_exchange_info_and_complete_source_info(
&self,
exchange_info: Option<ExchangeInfo>,
source_info: SourceScanInfo,
task_parallelism: u32,
) -> Self {
assert!(matches!(source_info, SourceScanInfo::Complete(_)));
let exchange_info = exchange_info.or_else(|| self.exchange_info.clone());
Self {
query_id: self.query_id.clone(),
id: self.id,
root: self.root.clone(),
exchange_info,
parallelism: Some(task_parallelism),
table_scan_info: self.table_scan_info.clone(),
source_info: Some(source_info),
file_scan_info: self.file_scan_info.clone(),
has_lookup_join: self.has_lookup_join,
dml_table_id: self.dml_table_id,
session_id: self.session_id,
batch_enable_distributed_dml: self.batch_enable_distributed_dml,
children_exchange_distribution: None,
}
}
}
impl Debug for QueryStage {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("QueryStage")
.field("id", &self.id)
.field("parallelism", &self.parallelism)
.field("exchange_info", &self.exchange_info)
.field("has_table_scan", &self.has_table_scan())
.finish()
}
}
impl Serialize for QueryStage {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut state = serializer.serialize_struct("QueryStage", 3)?;
state.serialize_field("root", &self.root)?;
state.serialize_field("parallelism", &self.parallelism)?;
state.serialize_field("exchange_info", &self.exchange_info)?;
state.end()
}
}
pub type QueryStageRef = Arc<QueryStage>;
struct QueryStageBuilder {
query_id: QueryId,
id: StageId,
root: Option<Arc<ExecutionPlanNode>>,
parallelism: Option<u32>,
exchange_info: Option<ExchangeInfo>,
children_stages: Vec<QueryStageRef>,
table_scan_info: Option<TableScanInfo>,
source_info: Option<SourceScanInfo>,
file_scan_info: Option<FileScanInfo>,
has_lookup_join: bool,
dml_table_id: Option<TableId>,
session_id: SessionId,
batch_enable_distributed_dml: bool,
children_exchange_distribution: HashMap<StageId, Distribution>,
}
impl QueryStageBuilder {
#[allow(clippy::too_many_arguments)]
fn new(
id: StageId,
query_id: QueryId,
parallelism: Option<u32>,
exchange_info: Option<ExchangeInfo>,
table_scan_info: Option<TableScanInfo>,
source_info: Option<SourceScanInfo>,
file_scan_info: Option<FileScanInfo>,
has_lookup_join: bool,
dml_table_id: Option<TableId>,
session_id: SessionId,
batch_enable_distributed_dml: bool,
) -> Self {
Self {
query_id,
id,
root: None,
parallelism,
exchange_info,
children_stages: vec![],
table_scan_info,
source_info,
file_scan_info,
has_lookup_join,
dml_table_id,
session_id,
batch_enable_distributed_dml,
children_exchange_distribution: HashMap::new(),
}
}
fn finish(self, stage_graph_builder: &mut StageGraphBuilder) -> QueryStageRef {
// A stage with still-unknown parallelism (a source stage) keeps its
// children's exchange distributions so that exchange info can be built
// once the parallelism is decided in `StageGraph::complete`.
let children_exchange_distribution = if self.parallelism.is_none() {
Some(self.children_exchange_distribution)
} else {
None
};
let stage = Arc::new(QueryStage {
query_id: self.query_id,
id: self.id,
root: self.root.unwrap(),
exchange_info: self.exchange_info,
parallelism: self.parallelism,
table_scan_info: self.table_scan_info,
source_info: self.source_info,
file_scan_info: self.file_scan_info,
has_lookup_join: self.has_lookup_join,
dml_table_id: self.dml_table_id,
session_id: self.session_id,
batch_enable_distributed_dml: self.batch_enable_distributed_dml,
children_exchange_distribution,
});
stage_graph_builder.add_node(stage.clone());
for child_stage in self.children_stages {
stage_graph_builder.link_to_child(self.id, child_stage.id);
}
stage
}
}
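/// The DAG of stages produced by the fragmenter. Edges are kept in both
/// directions: `child_edges` maps a stage to its children, `parent_edges`
/// to its parents.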
#[derive(Debug, Serialize)]
#[cfg_attr(test, derive(Clone))]
pub struct StageGraph {
pub root_stage_id: StageId,
pub stages: HashMap<StageId, QueryStageRef>,
child_edges: HashMap<StageId, HashSet<StageId>>,
parent_edges: HashMap<StageId, HashSet<StageId>>,
batch_parallelism: usize,
}
impl StageGraph {
pub fn get_child_stages_unchecked(&self, stage_id: &StageId) -> &HashSet<StageId> {
self.child_edges.get(stage_id).unwrap()
}
pub fn get_child_stages(&self, stage_id: &StageId) -> Option<&HashSet<StageId>> {
self.child_edges.get(stage_id)
}
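/// Returns stage ids in topological order: every stage appears after all of
/// its child stages, so the root stage comes last.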
pub fn stage_ids_by_topo_order(&self) -> impl Iterator<Item = StageId> {
let mut stack = Vec::with_capacity(self.stages.len());
stack.push(self.root_stage_id);
let mut ret = Vec::with_capacity(self.stages.len());
let mut existing = HashSet::with_capacity(self.stages.len());
while let Some(s) = stack.pop() {
if !existing.contains(&s) {
ret.push(s);
existing.insert(s);
stack.extend(&self.child_edges[&s]);
}
}
ret.into_iter().rev()
}
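/// Resolves every `Incomplete` source info by listing its splits, and fills
/// in the parallelism of stages that depended on that information.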
async fn complete(
self,
catalog_reader: &CatalogReader,
worker_node_manager: &WorkerNodeSelector,
) -> SchedulerResult<StageGraph> {
let mut complete_stages = HashMap::new();
self.complete_stage(
self.stages.get(&self.root_stage_id).unwrap().clone(),
None,
&mut complete_stages,
catalog_reader,
worker_node_manager,
)
.await?;
Ok(StageGraph {
root_stage_id: self.root_stage_id,
stages: complete_stages,
child_edges: self.child_edges,
parent_edges: self.parent_edges,
batch_parallelism: self.batch_parallelism,
})
}
#[async_recursion]
async fn complete_stage(
&self,
stage: QueryStageRef,
exchange_info: Option<ExchangeInfo>,
complete_stages: &mut HashMap<StageId, QueryStageRef>,
catalog_reader: &CatalogReader,
worker_node_manager: &WorkerNodeSelector,
) -> SchedulerResult<()> {
// Case 1: parallelism is already known; just propagate the exchange info.
let parallelism = if stage.parallelism.is_some() {
complete_stages.insert(
stage.id,
Arc::new(stage.clone_with_exchange_info(exchange_info, stage.parallelism)),
);
None
// Case 2: a source stage; list the splits and derive the parallelism
// from them.
} else if matches!(stage.source_info, Some(SourceScanInfo::Incomplete(_))) {
let complete_source_info = stage
.source_info
.as_ref()
.unwrap()
.clone()
.complete(self.batch_parallelism)
.await?;
// File-backed sources cap task parallelism at half the batch parallelism
// (but at least 1); other sources get one task per split.
let task_parallelism = match &stage.source_info {
Some(SourceScanInfo::Incomplete(source_fetch_info)) => {
match source_fetch_info.connector {
ConnectorProperties::Gcs(_)
| ConnectorProperties::OpendalS3(_)
| ConnectorProperties::Azblob(_) => (min(
complete_source_info.split_info().unwrap().len() as u32,
(self.batch_parallelism / 2) as u32,
))
.max(1),
_ => complete_source_info.split_info().unwrap().len() as u32,
}
}
_ => unreachable!(),
};
let complete_stage = Arc::new(stage.clone_with_exchange_info_and_complete_source_info(
exchange_info,
complete_source_info,
task_parallelism,
));
let parallelism = complete_stage.parallelism;
complete_stages.insert(stage.id, complete_stage);
parallelism
} else {
// Case 3: a file scan stage; parallelism is bounded by half the batch
// parallelism and by the number of files.
assert!(stage.file_scan_info.is_some());
let parallelism = min(
self.batch_parallelism / 2,
stage.file_scan_info.as_ref().unwrap().file_location.len(),
);
complete_stages.insert(
stage.id,
Arc::new(stage.clone_with_exchange_info(exchange_info, Some(parallelism as u32))),
);
None
};
for child_stage_id in self.child_edges.get(&stage.id).unwrap_or(&HashSet::new()) {
// If this stage's parallelism was just derived, build each child's
// exchange info from the distribution recorded at fragmentation time.
let exchange_info = if let Some(parallelism) = parallelism {
let exchange_distribution = stage
.children_exchange_distribution
.as_ref()
.unwrap()
.get(child_stage_id)
.expect("Exchange distribution is not consistent with the stage graph");
Some(exchange_distribution.to_prost(
parallelism,
catalog_reader,
worker_node_manager,
)?)
} else {
None
};
self.complete_stage(
self.stages.get(child_stage_id).unwrap().clone(),
exchange_info,
complete_stages,
catalog_reader,
worker_node_manager,
)
.await?;
}
Ok(())
}
}
struct StageGraphBuilder {
stages: HashMap<StageId, QueryStageRef>,
child_edges: HashMap<StageId, HashSet<StageId>>,
parent_edges: HashMap<StageId, HashSet<StageId>>,
batch_parallelism: usize,
}
impl StageGraphBuilder {
pub fn new(batch_parallelism: usize) -> Self {
Self {
stages: HashMap::new(),
child_edges: HashMap::new(),
parent_edges: HashMap::new(),
batch_parallelism,
}
}
pub fn build(self, root_stage_id: StageId) -> StageGraph {
StageGraph {
root_stage_id,
stages: self.stages,
child_edges: self.child_edges,
parent_edges: self.parent_edges,
batch_parallelism: self.batch_parallelism,
}
}
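/// Links a parent stage to a child stage in both edge maps. Both stages must
/// already have been registered via [`Self::add_node`].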
pub fn link_to_child(&mut self, parent_id: StageId, child_id: StageId) {
self.child_edges
.get_mut(&parent_id)
.unwrap()
.insert(child_id);
self.parent_edges
.get_mut(&child_id)
.unwrap()
.insert(parent_id);
}
pub fn add_node(&mut self, stage: QueryStageRef) {
self.child_edges.insert(stage.id, HashSet::new());
self.parent_edges.insert(stage.id, HashSet::new());
self.stages.insert(stage.id, stage);
}
}
impl BatchPlanFragmenter {
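/// Finalizes fragmentation: completes the stage graph (split listing and
/// pending parallelism) and returns the schedulable [`Query`].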
pub async fn generate_complete_query(self) -> SchedulerResult<Query> {
let stage_graph = self.stage_graph.unwrap();
let new_stage_graph = stage_graph
.complete(&self.catalog_reader, &self.worker_node_manager)
.await?;
Ok(Query {
query_id: self.query_id,
stage_graph: new_stage_graph,
})
}
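/// Creates a new stage rooted at `root`, recursively creating child stages
/// whenever a `BatchExchange` is encountered.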
fn new_stage(
&mut self,
root: PlanRef,
exchange_info: Option<ExchangeInfo>,
) -> SchedulerResult<QueryStageRef> {
let next_stage_id = self.next_stage_id;
self.next_stage_id += 1;
let mut table_scan_info = self.collect_stage_table_scan(root.clone())?;
// A stage scans at most one of: a table, a source, or files; check them
// in that order of precedence.
let source_info = if table_scan_info.is_none() {
Self::collect_stage_source(root.clone())?
} else {
None
};
let file_scan_info = if table_scan_info.is_none() && source_info.is_none() {
Self::collect_stage_file_scan(root.clone())?
} else {
None
};
let mut has_lookup_join = false;
let parallelism = match root.distribution() {
Distribution::Single => {
if let Some(info) = &mut table_scan_info {
if let Some(partitions) = &mut info.partitions {
if partitions.len() != 1 {
tracing::warn!(
"The stage has single distribution, but contains a scan of table `{}` with {} partitions. A single random worker will be assigned",
info.name,
partitions.len()
);
*partitions = partitions
.drain()
.take(1)
.update(|(_, info)| {
info.vnode_bitmap = Bitmap::ones(info.vnode_bitmap.len());
})
.collect();
}
} else {
// System table scan: there is no partition info to adjust.
}
} else if source_info.is_some() {
return Err(SchedulerError::Internal(anyhow!(
"The stage has single distribution, but contains a source operator"
)));
}
1
}
_ => {
if let Some(table_scan_info) = &table_scan_info {
table_scan_info
.partitions
.as_ref()
.map(|m| m.len())
.unwrap_or(1)
} else if let Some(lookup_join_parallelism) =
self.collect_stage_lookup_join_parallelism(root.clone())?
{
has_lookup_join = true;
lookup_join_parallelism
} else if source_info.is_some() {
// Placeholder: the actual parallelism is derived from the listed
// splits in `StageGraph::complete`.
0
} else if file_scan_info.is_some() {
1
} else {
self.batch_parallelism
}
}
};
if source_info.is_none() && file_scan_info.is_none() && parallelism == 0 {
return Err(BatchError::EmptyWorkerNodes.into());
}
// At this point `parallelism == 0` can only mean "not yet known"
// (a source stage), so map it to `None`.
let parallelism = if parallelism == 0 {
None
} else {
Some(parallelism as u32)
};
let dml_table_id = Self::collect_dml_table_id(&root);
let mut builder = QueryStageBuilder::new(
next_stage_id,
self.query_id.clone(),
parallelism,
exchange_info,
table_scan_info,
source_info,
file_scan_info,
has_lookup_join,
dml_table_id,
root.ctx().session_ctx().session_id(),
root.ctx()
.session_ctx()
.config()
.batch_enable_distributed_dml(),
);
self.visit_node(root, &mut builder, None)?;
Ok(builder.finish(self.stage_graph_builder.as_mut().unwrap()))
}
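/// Converts `node` and its subtree into [`ExecutionPlanNode`]s within the
/// current stage, delegating to [`Self::visit_exchange`] when an exchange
/// marks a stage boundary.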
fn visit_node(
&mut self,
node: PlanRef,
builder: &mut QueryStageBuilder,
parent_exec_node: Option<&mut ExecutionPlanNode>,
) -> SchedulerResult<()> {
match node.node_type() {
PlanNodeType::BatchExchange => {
self.visit_exchange(node.clone(), builder, parent_exec_node)?;
}
_ => {
let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;
for child in node.inputs() {
self.visit_node(child, builder, Some(&mut execution_plan_node))?;
}
if let Some(parent) = parent_exec_node {
parent.children.push(Arc::new(execution_plan_node));
} else {
builder.root = Some(Arc::new(execution_plan_node));
}
}
}
Ok(())
}
fn visit_exchange(
&mut self,
node: PlanRef,
builder: &mut QueryStageBuilder,
parent_exec_node: Option<&mut ExecutionPlanNode>,
) -> SchedulerResult<()> {
let mut execution_plan_node = ExecutionPlanNode::try_from(node.clone())?;
let child_exchange_info = if let Some(parallelism) = builder.parallelism {
Some(node.distribution().to_prost(
parallelism,
&self.catalog_reader,
&self.worker_node_manager,
)?)
} else {
None
};
let child_stage = self.new_stage(node.inputs()[0].clone(), child_exchange_info)?;
execution_plan_node.source_stage_id = Some(child_stage.id);
if builder.parallelism.is_none() {
builder
.children_exchange_distribution
.insert(child_stage.id, node.distribution().clone());
}
if let Some(parent) = parent_exec_node {
parent.children.push(Arc::new(execution_plan_node));
} else {
builder.root = Some(Arc::new(execution_plan_node));
}
builder.children_stages.push(child_stage);
Ok(())
}
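/// Finds the source scan reachable within this stage (i.e. without crossing
/// a `BatchExchange`) and captures what is needed to list its splits later.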
fn collect_stage_source(node: PlanRef) -> SchedulerResult<Option<SourceScanInfo>> {
if node.node_type() == PlanNodeType::BatchExchange {
return Ok(None);
}
if let Some(batch_kafka_node) = node.as_batch_kafka_scan() {
let batch_kafka_scan: &BatchKafkaScan = batch_kafka_node;
let source_catalog = batch_kafka_scan.source_catalog();
if let Some(source_catalog) = source_catalog {
let property =
ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
let timestamp_bound = batch_kafka_scan.kafka_timestamp_range_value();
return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
schema: batch_kafka_scan.base.schema().clone(),
connector: property,
fetch_parameters: SourceFetchParameters::KafkaTimebound {
lower: timestamp_bound.0,
upper: timestamp_bound.1,
},
as_of: None,
})));
}
} else if let Some(batch_iceberg_scan) = node.as_batch_iceberg_scan() {
let batch_iceberg_scan: &BatchIcebergScan = batch_iceberg_scan;
let source_catalog = batch_iceberg_scan.source_catalog();
if let Some(source_catalog) = source_catalog {
let property =
ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
let as_of = batch_iceberg_scan.as_of();
return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
schema: batch_iceberg_scan.base.schema().clone(),
connector: property,
fetch_parameters: SourceFetchParameters::IcebergPredicate(
batch_iceberg_scan.predicate.clone(),
),
as_of,
})));
}
} else if let Some(source_node) = node.as_batch_source() {
let source_node: &BatchSource = source_node;
let source_catalog = source_node.source_catalog();
if let Some(source_catalog) = source_catalog {
let property =
ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?;
let as_of = source_node.as_of();
return Ok(Some(SourceScanInfo::new(SourceFetchInfo {
schema: source_node.base.schema().clone(),
connector: property,
fetch_parameters: SourceFetchParameters::Empty,
as_of,
})));
}
}
node.inputs()
.into_iter()
.find_map(|n| Self::collect_stage_source(n).transpose())
.transpose()
}
fn collect_stage_file_scan(node: PlanRef) -> SchedulerResult<Option<FileScanInfo>> {
if node.node_type() == PlanNodeType::BatchExchange {
return Ok(None);
}
if let Some(batch_file_scan) = node.as_batch_file_scan() {
return Ok(Some(FileScanInfo {
file_location: batch_file_scan.core.file_location.clone(),
}));
}
node.inputs()
.into_iter()
.find_map(|n| Self::collect_stage_file_scan(n).transpose())
.transpose()
}
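/// Finds the table scan reachable within this stage, along with its
/// per-worker-slot partitions. Returns `None` if the stage scans no table.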
fn collect_stage_table_scan(&self, node: PlanRef) -> SchedulerResult<Option<TableScanInfo>> {
let build_table_scan_info = |name, table_desc: &TableDesc, scan_range| {
let table_catalog = self
.catalog_reader
.read_guard()
.get_any_table_by_id(&table_desc.table_id)
.cloned()
.map_err(RwError::from)?;
let vnode_mapping = self
.worker_node_manager
.fragment_mapping(table_catalog.fragment_id)?;
let partitions = derive_partitions(scan_range, table_desc, &vnode_mapping)?;
let info = TableScanInfo::new(name, partitions);
Ok(Some(info))
};
if node.node_type() == PlanNodeType::BatchExchange {
return Ok(None);
}
if let Some(scan_node) = node.as_batch_sys_seq_scan() {
let name = scan_node.core().table_name.to_owned();
Ok(Some(TableScanInfo::system_table(name)))
} else if let Some(scan_node) = node.as_batch_log_seq_scan() {
build_table_scan_info(
scan_node.core().table_name.to_owned(),
&scan_node.core().table_desc,
&[],
)
} else if let Some(scan_node) = node.as_batch_seq_scan() {
build_table_scan_info(
scan_node.core().table_name.to_owned(),
&scan_node.core().table_desc,
scan_node.scan_ranges(),
)
} else {
node.inputs()
.into_iter()
.find_map(|n| self.collect_stage_table_scan(n).transpose())
.transpose()
}
}
fn collect_dml_table_id(node: &PlanRef) -> Option<TableId> {
if node.node_type() == PlanNodeType::BatchExchange {
return None;
}
if let Some(insert) = node.as_batch_insert() {
Some(insert.core.table_id)
} else if let Some(update) = node.as_batch_update() {
Some(update.core.table_id)
} else if let Some(delete) = node.as_batch_delete() {
Some(delete.core.table_id)
} else {
node.inputs()
.into_iter()
.find_map(|n| Self::collect_dml_table_id(&n))
}
}
fn collect_stage_lookup_join_parallelism(
&self,
node: PlanRef,
) -> SchedulerResult<Option<usize>> {
if node.node_type() == PlanNodeType::BatchExchange {
return Ok(None);
}
if let Some(lookup_join) = node.as_batch_lookup_join() {
let table_desc = lookup_join.right_table_desc();
let table_catalog = self
.catalog_reader
.read_guard()
.get_any_table_by_id(&table_desc.table_id)
.cloned()
.map_err(RwError::from)?;
let vnode_mapping = self
.worker_node_manager
.fragment_mapping(table_catalog.fragment_id)?;
// One task per distinct worker slot that serves the inner-side table.
let parallelism = vnode_mapping.iter().sorted().dedup().count();
Ok(Some(parallelism))
} else {
node.inputs()
.into_iter()
.find_map(|n| self.collect_stage_lookup_join_parallelism(n).transpose())
.transpose()
}
}
}
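/// Derives which vnodes and scan ranges each worker slot should read. A scan
/// range whose distribution key is fully bound maps to a single vnode and thus
/// a single worker slot; otherwise it is replicated to all partitions.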
fn derive_partitions(
scan_ranges: &[ScanRange],
table_desc: &TableDesc,
vnode_mapping: &WorkerSlotMapping,
) -> SchedulerResult<HashMap<WorkerSlotId, TablePartitionInfo>> {
let vnode_mapping = if table_desc.vnode_count != vnode_mapping.len() {
// A singleton table (vnode count 1) may be served by a fragment whose
// mapping is larger; pin the scan to a single worker slot in that case.
assert!(
table_desc.vnode_count == 1,
"fragment vnode count {} does not match table vnode count {}",
vnode_mapping.len(),
table_desc.vnode_count,
);
&WorkerSlotMapping::new_single(vnode_mapping.iter().next().unwrap())
} else {
vnode_mapping
};
let vnode_count = vnode_mapping.len();
let mut partitions: HashMap<WorkerSlotId, (BitmapBuilder, Vec<_>)> = HashMap::new();
if scan_ranges.is_empty() {
// Full table scan: every worker slot reads exactly the vnodes it serves.
return Ok(vnode_mapping
.to_bitmaps()
.into_iter()
.map(|(k, vnode_bitmap)| {
(
k,
TablePartitionInfo {
vnode_bitmap,
scan_ranges: vec![],
},
)
})
.collect());
}
let table_distribution = TableDistribution::new_from_storage_table_desc(
Some(Bitmap::ones(vnode_count).into()),
&table_desc.try_to_protobuf()?,
);
for scan_range in scan_ranges {
let vnode = scan_range.try_compute_vnode(&table_distribution);
match vnode {
None => {
// The distribution key is not fully bound, so the range must be
// scanned on every partition.
vnode_mapping.to_bitmaps().into_iter().for_each(
|(worker_slot_id, vnode_bitmap)| {
let (bitmap, scan_ranges) = partitions
.entry(worker_slot_id)
.or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
vnode_bitmap
.iter()
.enumerate()
.for_each(|(vnode, b)| bitmap.set(vnode, b));
scan_ranges.push(scan_range.to_protobuf());
},
);
}
Some(vnode) => {
// The distribution key is fully bound: route the range to the one
// worker slot that owns this vnode.
let worker_slot_id = vnode_mapping[vnode];
let (bitmap, scan_ranges) = partitions
.entry(worker_slot_id)
.or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
bitmap.set(vnode.to_index(), true);
scan_ranges.push(scan_range.to_protobuf());
}
}
}
Ok(partitions
.into_iter()
.map(|(k, (bitmap, scan_ranges))| {
(
k,
TablePartitionInfo {
vnode_bitmap: bitmap.finish(),
scan_ranges,
},
)
})
.collect())
}
#[cfg(test)]
mod tests {
use std::collections::{HashMap, HashSet};
use risingwave_pb::batch_plan::plan_node::NodeBody;
use crate::optimizer::plan_node::PlanNodeType;
use crate::scheduler::plan_fragmenter::StageId;
#[tokio::test]
async fn test_fragmenter() {
let query = crate::scheduler::distributed::tests::create_query().await;
assert_eq!(query.stage_graph.root_stage_id, 0);
assert_eq!(query.stage_graph.stages.len(), 4);
assert_eq!(query.stage_graph.child_edges[&0], [1].into());
assert_eq!(query.stage_graph.child_edges[&1], [2, 3].into());
assert_eq!(query.stage_graph.child_edges[&2], HashSet::new());
assert_eq!(query.stage_graph.child_edges[&3], HashSet::new());
assert_eq!(query.stage_graph.parent_edges[&0], HashSet::new());
assert_eq!(query.stage_graph.parent_edges[&1], [0].into());
assert_eq!(query.stage_graph.parent_edges[&2], [1].into());
assert_eq!(query.stage_graph.parent_edges[&3], [1].into());
{
let stage_id_to_pos: HashMap<StageId, usize> = query
.stage_graph
.stage_ids_by_topo_order()
.enumerate()
.map(|(pos, stage_id)| (stage_id, pos))
.collect();
for stage_id in query.stage_graph.stages.keys() {
let stage_pos = stage_id_to_pos[stage_id];
for child_stage_id in &query.stage_graph.child_edges[stage_id] {
let child_pos = stage_id_to_pos[child_stage_id];
assert!(stage_pos > child_pos);
}
}
}
let root_exchange = query.stage_graph.stages.get(&0).unwrap();
assert_eq!(root_exchange.root.node_type(), PlanNodeType::BatchExchange);
assert_eq!(root_exchange.root.source_stage_id, Some(1));
assert!(matches!(root_exchange.root.node, NodeBody::Exchange(_)));
assert_eq!(root_exchange.parallelism, Some(1));
assert!(!root_exchange.has_table_scan());
let join_node = query.stage_graph.stages.get(&1).unwrap();
assert_eq!(join_node.root.node_type(), PlanNodeType::BatchHashJoin);
assert_eq!(join_node.parallelism, Some(24));
assert!(matches!(join_node.root.node, NodeBody::HashJoin(_)));
assert_eq!(join_node.root.source_stage_id, None);
assert_eq!(2, join_node.root.children.len());
assert!(matches!(
join_node.root.children[0].node,
NodeBody::Exchange(_)
));
assert_eq!(join_node.root.children[0].source_stage_id, Some(2));
assert_eq!(0, join_node.root.children[0].children.len());
assert!(matches!(
join_node.root.children[1].node,
NodeBody::Exchange(_)
));
assert_eq!(join_node.root.children[1].source_stage_id, Some(3));
assert_eq!(0, join_node.root.children[1].children.len());
assert!(!join_node.has_table_scan());
let scan_node1 = query.stage_graph.stages.get(&2).unwrap();
assert_eq!(scan_node1.root.node_type(), PlanNodeType::BatchSeqScan);
assert_eq!(scan_node1.root.source_stage_id, None);
assert_eq!(0, scan_node1.root.children.len());
assert!(scan_node1.has_table_scan());
let scan_node2 = query.stage_graph.stages.get(&3).unwrap();
assert_eq!(scan_node2.root.node_type(), PlanNodeType::BatchFilter);
assert_eq!(scan_node2.root.source_stage_id, None);
assert_eq!(1, scan_node2.root.children.len());
assert!(scan_node2.has_table_scan());
}
}