use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::mem;
use std::sync::Arc;

use anyhow::Context;
use futures::executor::block_on;
use petgraph::dot::{Config, Dot};
use petgraph::Graph;
use pgwire::pg_server::SessionId;
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::array::DataChunk;
use risingwave_pb::batch_plan::{TaskId as PbTaskId, TaskOutputId as PbTaskOutputId};
use risingwave_pb::common::{BatchQueryEpoch, HostAddress};
use risingwave_rpc_client::ComputeClientPoolRef;
use thiserror_ext::AsReport;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::{oneshot, RwLock};
use tokio::task::JoinHandle;
use tracing::{debug, error, info, warn, Instrument};

use super::{DistributedQueryMetrics, QueryExecutionInfoRef, QueryResultFetcher, StageEvent};
use crate::catalog::catalog_service::CatalogReader;
use crate::scheduler::distributed::query::QueryMessage::Stage;
use crate::scheduler::distributed::stage::StageEvent::ScheduledRoot;
use crate::scheduler::distributed::StageEvent::Scheduled;
use crate::scheduler::distributed::StageExecution;
use crate::scheduler::plan_fragmenter::{Query, StageId, ROOT_TASK_ID, ROOT_TASK_OUTPUT_ID};
use crate::scheduler::{ExecutionContextRef, SchedulerError, SchedulerResult};
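
/// Message sent to a query's `QueryRunner` while it executes.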
#[derive(Debug)]
pub enum QueryMessage {
    /// An event reported by one of the query's stage executions.
    Stage(StageEvent),
    /// Cancel the whole query, carrying a human-readable reason.
    CancelQuery(String),
}
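
/// Lifecycle states of a `QueryExecution`.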
enum QueryState {
    /// Not started yet; holds the message receiver that is handed over to the
    /// `QueryRunner` on start.
    Pending {
        msg_receiver: Receiver<QueryMessage>,
    },
    /// A `QueryRunner` has been spawned and is driving the stages.
    Running,
    /// Scheduling failed; also used as a placeholder while the state is taken
    /// during `start`.
    Failed,
}
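
/// Handle to the execution of a single distributed query: tracks its state and
/// owns the channel used to cancel it.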
pub struct QueryExecution {
    query: Arc<Query>,
    state: RwLock<QueryState>,
    /// Sender for `QueryMessage`s; cloned into stage executions so they can report
    /// events, and used by `abort` to cancel the query.
    shutdown_tx: Sender<QueryMessage>,
    /// The session that issued this query.
    pub session_id: SessionId,
    /// Permit limiting the number of concurrent distributed queries; released when
    /// this execution is dropped.
    pub permit: Option<tokio::sync::OwnedSemaphorePermit>,
}
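
/// Worker that actually drives a query: starts stages as their dependencies are
/// met and reports the root stage's results back to the caller.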
struct QueryRunner {
    query: Arc<Query>,
    stage_executions: HashMap<StageId, Arc<StageExecution>>,
    scheduled_stages_count: usize,
    /// Receives stage events and cancel requests for this query.
    msg_receiver: Receiver<QueryMessage>,
    /// Sends the root stage's result fetcher back to `QueryExecution::start`;
    /// wrapped in `Option` because it is consumed on first use.
    root_stage_sender: Option<oneshot::Sender<SchedulerResult<QueryResultFetcher>>>,
    query_execution_info: QueryExecutionInfoRef,
    query_metrics: Arc<DistributedQueryMetrics>,
    /// Handle of the spawned timeout task, aborted when the runner is dropped.
    timeout_abort_task_handle: Option<JoinHandle<()>>,
}

impl QueryExecution {
    /// Creates a new `QueryExecution` in the `Pending` state.
    pub fn new(
        query: Query,
        session_id: SessionId,
        permit: Option<tokio::sync::OwnedSemaphorePermit>,
    ) -> Self {
        let query = Arc::new(query);
        let (sender, receiver) = channel(100);
        let state = QueryState::Pending {
            msg_receiver: receiver,
        };

        Self {
            query,
            state: RwLock::new(state),
            shutdown_tx: sender,
            session_id,
            permit,
        }
    }

    /// Starts execution of this query. Spawns a `QueryRunner` that schedules the
    /// stages bottom-up and returns a fetcher for the root stage's results.
    #[allow(clippy::too_many_arguments)]
    pub async fn start(
        self: Arc<Self>,
        context: ExecutionContextRef,
        worker_node_manager: WorkerNodeSelector,
        batch_query_epoch: BatchQueryEpoch,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
        query_execution_info: QueryExecutionInfoRef,
        query_metrics: Arc<DistributedQueryMetrics>,
    ) -> SchedulerResult<QueryResultFetcher> {
        let mut state = self.state.write().await;
        // Take the current state, leaving `Failed` behind so that an early return
        // keeps the execution in a terminal state.
        let cur_state = mem::replace(&mut *state, QueryState::Failed);

        // Build a `StageExecution` for every stage of the query.
        let stage_executions = self.gen_stage_executions(
            batch_query_epoch,
            context.clone(),
            worker_node_manager,
            compute_client_pool.clone(),
            catalog_reader,
        );

        match cur_state {
            QueryState::Pending { msg_receiver } => {
                *state = QueryState::Running;

                // If a query timeout is configured, spawn a task that cancels the
                // query once the timeout elapses.
                let mut timeout_abort_task_handle: Option<JoinHandle<()>> = None;
                if let Some(timeout) = context.timeout() {
                    let this = self.clone();
                    timeout_abort_task_handle = Some(tokio::spawn(async move {
                        tokio::time::sleep(timeout).await;
                        warn!(
                            "Query {:?} timeout after {} seconds, sending cancel message.",
                            this.query.query_id,
                            timeout.as_secs(),
                        );
                        this.abort(format!("timeout after {} seconds", timeout.as_secs()))
                            .await;
                    }));
                }

                // Oneshot channel through which the runner hands back the root
                // stage's result fetcher.
                let (root_stage_sender, root_stage_receiver) =
                    oneshot::channel::<SchedulerResult<QueryResultFetcher>>();

                let runner = QueryRunner {
                    query: self.query.clone(),
                    stage_executions,
                    msg_receiver,
                    root_stage_sender: Some(root_stage_sender),
                    scheduled_stages_count: 0,
                    query_execution_info,
                    query_metrics,
                    timeout_abort_task_handle,
                };

                let span = tracing::info_span!(
                    "distributed_execute",
                    query_id = self.query.query_id.id,
                    epoch = ?batch_query_epoch,
                );

                tracing::trace!("Starting query: {:?}", self.query.query_id);

                // Drive the runner in a separate task so that `start` only waits
                // for the root stage to be scheduled, not for the whole query to
                // finish.
                tokio::spawn(async move { runner.run().instrument(span).await });

                let root_stage = root_stage_receiver
                    .await
                    .context("Starting query execution failed")??;

                tracing::trace!(
                    "Received root stage query result fetcher: {:?}, query id: {:?}",
                    root_stage,
                    self.query.query_id
                );
                tracing::trace!("Query {:?} started.", self.query.query_id);
                Ok(root_stage)
            }
            _ => {
                unreachable!("The query runner should not be scheduled twice");
            }
        }
    }

    /// Sends a cancel message to the running query, e.g. on user request or timeout.
    pub async fn abort(self: Arc<Self>, reason: String) {
        if self
            .shutdown_tx
            .send(QueryMessage::CancelQuery(reason))
            .await
            .is_err()
        {
            warn!("Send cancel query request failed: the query has ended");
        } else {
            info!("Send cancel request to query-{:?}", self.query.query_id);
        }
    }

    /// Creates a `StageExecution` for every stage of the query, visiting stages in
    /// topological order so that each stage's children exist before it is built.
    fn gen_stage_executions(
        &self,
        epoch: BatchQueryEpoch,
        context: ExecutionContextRef,
        worker_node_manager: WorkerNodeSelector,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
    ) -> HashMap<StageId, Arc<StageExecution>> {
        let mut stage_executions: HashMap<StageId, Arc<StageExecution>> =
            HashMap::with_capacity(self.query.stage_graph.stages.len());

        for stage_id in self.query.stage_graph.stage_ids_by_topo_order() {
            // Children are guaranteed to exist already thanks to the topological order.
            let children_stages = self
                .query
                .stage_graph
                .get_child_stages_unchecked(&stage_id)
                .iter()
                .map(|s| stage_executions[s].clone())
                .collect::<Vec<Arc<StageExecution>>>();

            let stage_exec = Arc::new(StageExecution::new(
                epoch,
                self.query.stage_graph.stages[&stage_id].clone(),
                worker_node_manager.clone(),
                self.shutdown_tx.clone(),
                children_stages,
                compute_client_pool.clone(),
                catalog_reader.clone(),
                context.clone(),
            ));
            stage_executions.insert(stage_id, stage_exec);
        }
        stage_executions
    }
}
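
/// On drop, decrement the running-query gauge and abort the timeout task so it
/// cannot fire after the runner is gone.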
impl Drop for QueryRunner {
    fn drop(&mut self) {
        self.query_metrics.running_query_num.dec();
        if let Some(handle) = &self.timeout_abort_task_handle {
            handle.abort();
        }
    }
}
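
/// Renders the query's stage DAG, with each stage's current state, in Graphviz
/// DOT format. Intended for debugging only.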
impl Debug for QueryRunner {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut graph = Graph::<String, String>::new();
        let mut stage_id_to_node_id = HashMap::new();
        // One node per stage, labeled with the stage id and its current state.
        // `block_on` is tolerable here since this is a debug-only path.
        for (stage_id, stage_execution) in &self.stage_executions {
            let node_id = graph.add_node(format!(
                "{} {}",
                stage_id,
                block_on(stage_execution.state())
            ));
            stage_id_to_node_id.insert(stage_id, node_id);
        }

        // One edge from each stage to each of its child stages.
        for stage_id in self.stage_executions.keys() {
            if let Some(child_stages) = self.query.stage_graph.get_child_stages(stage_id) {
                for child_stage in child_stages {
                    graph.add_edge(
                        *stage_id_to_node_id.get(stage_id).unwrap(),
                        *stage_id_to_node_id.get(child_stage).unwrap(),
                        "".to_string(),
                    );
                }
            }
        }

        writeln!(f, "{}", Dot::with_config(&graph, &[Config::EdgeNoLabel]))
    }
}

impl QueryRunner {
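    /// Main loop: start the leaf stages, then react to stage events until the
    /// query completes, fails, or is cancelled.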
    async fn run(mut self) {
        self.query_metrics.running_query_num.inc();
        // Leaf stages have no dependencies, so they can be started immediately.
        let leaf_stages = self.query.leaf_stages();
        for stage_id in &leaf_stages {
            self.stage_executions[stage_id].start().await;
            tracing::trace!(
                "Query stage {:?}-{:?} started.",
                self.query.query_id,
                stage_id
            );
        }
        let mut stages_with_table_scan = self.query.stages_with_table_scan();
        let has_lookup_join_stage = self.query.has_lookup_join_stage();
        let mut finished_stage_cnt = 0usize;

        while let Some(msg_inner) = self.msg_receiver.recv().await {
            match msg_inner {
                Stage(Scheduled(stage_id)) => {
                    tracing::trace!(
                        "Query stage {:?}-{:?} scheduled.",
                        self.query.query_id,
                        stage_id
                    );
                    self.scheduled_stages_count += 1;
                    stages_with_table_scan.remove(&stage_id);
                    if !has_lookup_join_stage && stages_with_table_scan.is_empty() {
                        tracing::trace!(
                            "Query {:?} has scheduled all of its stages that have table scan (iterator creation).",
                            self.query.query_id
                        );
                    }

                    // A parent stage may start once all of its children have been
                    // scheduled.
                    for parent in self.query.get_parents(&stage_id) {
                        if self.all_children_scheduled(parent).await
                            && self.stage_executions[parent].is_pending().await
                        {
                            self.stage_executions[parent].start().await;
                        }
                    }
                }
                Stage(ScheduledRoot(receiver)) => {
                    // The root stage is scheduled; hand its result receiver back to
                    // the caller of `start`.
                    self.send_root_stage_info(receiver);
                }
                Stage(StageEvent::Failed { id, reason }) => {
                    error!(
                        error = %reason.as_report(),
                        query_id = ?self.query.query_id,
                        stage_id = ?id,
                        "query stage failed"
                    );
                    self.clean_all_stages(Some(reason)).await;
                    break;
                }
                Stage(StageEvent::Completed(_)) => {
                    finished_stage_cnt += 1;
                    assert!(finished_stage_cnt <= self.stage_executions.len());
                    if finished_stage_cnt == self.stage_executions.len() {
                        tracing::trace!(
                            "Query {:?} completed, starting to clean stage tasks.",
                            &self.query.query_id
                        );
                        self.clean_all_stages(None).await;
                        break;
                    }
                }
                QueryMessage::CancelQuery(reason) => {
                    self.clean_all_stages(Some(SchedulerError::QueryCancelled(reason)))
                        .await;
                    break;
                }
            }
        }
    }
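
    /// Builds the `QueryResultFetcher` for the root task and sends it through the
    /// oneshot channel back to `QueryExecution::start`.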
    fn send_root_stage_info(&mut self, chunk_rx: Receiver<SchedulerResult<DataChunk>>) {
        let root_task_output_id = {
            let root_task_id_prost = PbTaskId {
                query_id: self.query.query_id.id.clone(),
                stage_id: self.query.root_stage_id(),
                task_id: ROOT_TASK_ID,
            };

            PbTaskOutputId {
                task_id: Some(root_task_id_prost),
                output_id: ROOT_TASK_OUTPUT_ID,
            }
        };
        let root_stage_result = QueryResultFetcher::new(
            root_task_output_id,
            // The root task runs locally, so no meaningful host address is needed.
            HostAddress::default(),
            chunk_rx,
            self.query.query_id.clone(),
            self.query_execution_info.clone(),
        );

        // The sender is consumed here; `take` leaves `None` behind so a later
        // failure path knows the result has already been delivered.
        let root_stage_sender = mem::take(&mut self.root_stage_sender);
        if let Err(e) = root_stage_sender.unwrap().send(Ok(root_stage_result)) {
            warn!("Query execution dropped: {:?}", e);
        } else {
            debug!("Root stage for {:?} sent.", self.query.query_id);
        }
    }
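
    /// Returns `true` iff every child stage of `stage_id` has been scheduled.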
    async fn all_children_scheduled(&self, stage_id: &StageId) -> bool {
        for child in self.query.stage_graph.get_child_stages_unchecked(stage_id) {
            if !self.stage_executions[child].is_scheduled().await {
                return false;
            }
        }
        true
    }
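
    /// Stops all stage executions and, on error, reports the failure through the
    /// root stage sender if the result has not been delivered yet.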
    async fn clean_all_stages(&mut self, error: Option<SchedulerError>) {
        // Keep a printable copy of the error before it is moved into the sender.
        let error_msg = error.as_ref().map(|e| e.to_report_string());
        if let Some(reason) = error {
            // If the root stage result has not been delivered yet, report the
            // failure to the caller of `start` instead.
            let root_stage_sender = mem::take(&mut self.root_stage_sender);
            if let Some(sender) = root_stage_sender {
                if let Err(e) = sender.send(Err(reason)) {
                    warn!("Query execution dropped: {:?}", e);
                } else {
                    debug!(
                        "Root stage failure event for {:?} sent.",
                        self.query.query_id
                    );
                }
            }
        }

        tracing::trace!("Cleaning stages in query [{:?}]", self.query.query_id);
        for stage_execution in self.stage_executions.values() {
            stage_execution.stop(error_msg.clone()).await;
        }
    }
}

#[cfg(test)]
pub(crate) mod tests {
    use std::collections::{HashMap, HashSet};
    use std::sync::{Arc, RwLock};

    use fixedbitset::FixedBitSet;
    use risingwave_batch::worker_manager::worker_node_manager::{
        WorkerNodeManager, WorkerNodeSelector,
    };
    use risingwave_common::catalog::{
        ColumnCatalog, ColumnDesc, ConflictBehavior, CreateType, StreamJobStatus,
        DEFAULT_SUPER_USER_ID,
    };
    use risingwave_common::hash::{VirtualNode, VnodeCount, WorkerSlotId, WorkerSlotMapping};
    use risingwave_common::types::DataType;
    use risingwave_pb::common::worker_node::Property;
    use risingwave_pb::common::{HostAddress, WorkerNode, WorkerType};
    use risingwave_pb::plan_common::JoinType;
    use risingwave_rpc_client::ComputeClientPool;

    use crate::catalog::catalog_service::CatalogReader;
    use crate::catalog::root_catalog::Catalog;
    use crate::catalog::table_catalog::TableType;
    use crate::expr::InputRef;
    use crate::optimizer::plan_node::{
        generic, BatchExchange, BatchFilter, BatchHashJoin, EqJoinPredicate, LogicalScan, ToBatch,
    };
    use crate::optimizer::property::{Cardinality, Distribution, Order};
    use crate::optimizer::{OptimizerContext, PlanRef};
    use crate::scheduler::distributed::QueryExecution;
    use crate::scheduler::plan_fragmenter::{BatchPlanFragmenter, Query};
    use crate::scheduler::{
        DistributedQueryMetrics, ExecutionContext, QueryExecutionInfo, ReadSnapshot,
    };
    use crate::session::SessionImpl;
    use crate::utils::Condition;
    use crate::TableCatalog;
    #[tokio::test]
    async fn test_query_should_not_hang_with_empty_worker() {
        let worker_node_manager = Arc::new(WorkerNodeManager::mock(vec![]));
        let worker_node_selector = WorkerNodeSelector::new(worker_node_manager.clone(), false);
        let compute_client_pool = Arc::new(ComputeClientPool::for_test());
        let catalog_reader =
            CatalogReader::new(Arc::new(parking_lot::RwLock::new(Catalog::default())));
        let query = create_query().await;
        let query_id = query.query_id().clone();
        let query_execution = Arc::new(QueryExecution::new(query, (0, 0), None));
        let query_execution_info = Arc::new(RwLock::new(QueryExecutionInfo::new_from_map(
            HashMap::from([(query_id, query_execution.clone())]),
        )));

        // Scheduling against zero workers should return an error, not block.
        assert!(query_execution
            .start(
                ExecutionContext::new(SessionImpl::mock().into(), None).into(),
                worker_node_selector,
                ReadSnapshot::ReadUncommitted
                    .batch_query_epoch(&HashSet::from_iter([0.into()]))
                    .unwrap(),
                compute_client_pool,
                catalog_reader,
                query_execution_info,
                Arc::new(DistributedQueryMetrics::for_test()),
            )
            .await
            .is_err());
    }
    pub async fn create_query() -> Query {
        let ctx = OptimizerContext::mock().await;
        let table_id = 0.into();
        let vnode_count = VirtualNode::COUNT_FOR_TEST;

        // A mock table `test (a int, b double, c bigint)` for the scans below.
        let table_catalog: TableCatalog = TableCatalog {
            id: table_id,
            associated_source_id: None,
            name: "test".to_string(),
            dependent_relations: vec![],
            columns: vec![
                ColumnCatalog {
                    column_desc: ColumnDesc::new_atomic(DataType::Int32, "a", 0),
                    is_hidden: false,
                },
                ColumnCatalog {
                    column_desc: ColumnDesc::new_atomic(DataType::Float64, "b", 1),
                    is_hidden: false,
                },
                ColumnCatalog {
                    column_desc: ColumnDesc::new_atomic(DataType::Int64, "c", 2),
                    is_hidden: false,
                },
            ],
            pk: vec![],
            stream_key: vec![],
            table_type: TableType::Table,
            distribution_key: vec![],
            append_only: false,
            owner: DEFAULT_SUPER_USER_ID,
            retention_seconds: None,
            fragment_id: 0,
            dml_fragment_id: None,
            vnode_col_index: None,
            row_id_index: None,
            value_indices: vec![0, 1, 2],
            definition: "".to_string(),
            conflict_behavior: ConflictBehavior::NoCheck,
            version_column_index: None,
            read_prefix_len_hint: 0,
            version: None,
            watermark_columns: FixedBitSet::with_capacity(3),
            dist_key_in_pk: vec![],
            cardinality: Cardinality::unknown(),
            cleaned_by_watermark: false,
            created_at_epoch: None,
            initialized_at_epoch: None,
            stream_job_status: StreamJobStatus::Creating,
            create_type: CreateType::Foreground,
            description: None,
            incoming_sinks: vec![],
            initialized_at_cluster_version: None,
            created_at_cluster_version: None,
            cdc_table_id: None,
            vnode_count: VnodeCount::set(vnode_count),
        };
        let batch_plan_node: PlanRef = LogicalScan::create(
            "".to_string(),
            table_catalog.into(),
            vec![],
            ctx,
            None,
            Cardinality::unknown(),
        )
        .to_batch()
        .unwrap()
        .to_distributed()
        .unwrap();
        let batch_filter = BatchFilter::new(generic::Filter::new(
            Condition {
                conjunctions: vec![],
            },
            batch_plan_node.clone(),
        ))
        .into();
        let batch_exchange_node1: PlanRef = BatchExchange::new(
            batch_plan_node.clone(),
            Order::default(),
            Distribution::HashShard(vec![0, 1]),
        )
        .into();
        let batch_exchange_node2: PlanRef = BatchExchange::new(
            batch_filter,
            Order::default(),
            Distribution::HashShard(vec![0, 1]),
        )
        .into();

        // Inner join of the two exchanged branches with a two-key equi-join predicate.
        let logical_join_node = generic::Join::with_full_output(
            batch_exchange_node1.clone(),
            batch_exchange_node2.clone(),
            JoinType::Inner,
            Condition::true_cond(),
        );
        let eq_key_1 = (
            InputRef {
                index: 0,
                data_type: DataType::Int32,
            },
            InputRef {
                index: 2,
                data_type: DataType::Int32,
            },
            false,
        );
        let eq_key_2 = (
            InputRef {
                index: 1,
                data_type: DataType::Float64,
            },
            InputRef {
                index: 3,
                data_type: DataType::Float64,
            },
            false,
        );
        let eq_join_predicate =
            EqJoinPredicate::new(Condition::true_cond(), vec![eq_key_1, eq_key_2], 2, 2);
        let hash_join_node: PlanRef =
            BatchHashJoin::new(logical_join_node, eq_join_predicate).into();
        // Root exchange gathering the join result onto a single node.
        let batch_exchange_node: PlanRef = BatchExchange::new(
            hash_join_node.clone(),
            Order::default(),
            Distribution::Single,
        )
        .into();

        // Three mock compute nodes for the fragmenter to place tasks on.
        let worker1 = WorkerNode {
            id: 0,
            r#type: WorkerType::ComputeNode as i32,
            host: Some(HostAddress {
                host: "127.0.0.1".to_string(),
                port: 5687,
            }),
            state: risingwave_pb::common::worker_node::State::Running as i32,
            parallelism: 8,
            property: Some(Property {
                is_unschedulable: false,
                is_serving: true,
                is_streaming: true,
                internal_rpc_host_addr: "".to_string(),
            }),
            transactional_id: Some(0),
            ..Default::default()
        };
        let worker2 = WorkerNode {
            id: 1,
            r#type: WorkerType::ComputeNode as i32,
            host: Some(HostAddress {
                host: "127.0.0.1".to_string(),
                port: 5688,
            }),
            state: risingwave_pb::common::worker_node::State::Running as i32,
            parallelism: 8,
            property: Some(Property {
                is_unschedulable: false,
                is_serving: true,
                is_streaming: true,
                internal_rpc_host_addr: "".to_string(),
            }),
            transactional_id: Some(1),
            ..Default::default()
        };
        let worker3 = WorkerNode {
            id: 2,
            r#type: WorkerType::ComputeNode as i32,
            host: Some(HostAddress {
                host: "127.0.0.1".to_string(),
                port: 5689,
            }),
            state: risingwave_pb::common::worker_node::State::Running as i32,
            parallelism: 8,
            property: Some(Property {
                is_unschedulable: false,
                is_serving: true,
                is_streaming: true,
                internal_rpc_host_addr: "".to_string(),
            }),
            transactional_id: Some(2),
            ..Default::default()
        };
        let workers = vec![worker1, worker2, worker3];
        let worker_node_manager = Arc::new(WorkerNodeManager::mock(workers));
        let worker_node_selector = WorkerNodeSelector::new(worker_node_manager.clone(), false);
        let mapping =
            WorkerSlotMapping::new_uniform(std::iter::once(WorkerSlotId::new(0, 0)), vnode_count);
        worker_node_manager.insert_streaming_fragment_mapping(0, mapping.clone());
        worker_node_manager.set_serving_fragment_mapping(vec![(0, mapping)].into_iter().collect());

        let catalog = Arc::new(parking_lot::RwLock::new(Catalog::default()));
        catalog.write().insert_table_id_mapping(table_id, 0);
        let catalog_reader = CatalogReader::new(catalog);

        // Fragment the plan into a `Query` of schedulable stages.
        let fragmenter = BatchPlanFragmenter::new(
            worker_node_selector,
            catalog_reader,
            None,
            batch_exchange_node.clone(),
        )
        .unwrap();
        fragmenter.generate_complete_query().await.unwrap()
    }
}