use std::collections::{hash_map, HashMap};
use std::net::SocketAddr;
use std::sync::Arc;
use anyhow::Context;
use parking_lot::Mutex;
use risingwave_common::config::BatchConfig;
use risingwave_common::memory::MemoryContext;
use risingwave_common::util::runtime::BackgroundShutdownRuntime;
use risingwave_common::util::tracing::TracingContext;
use risingwave_pb::batch_plan::{PbTaskId, PbTaskOutputId, PlanFragment};
use risingwave_pb::common::BatchQueryEpoch;
use risingwave_pb::plan_common::ExprContext;
use risingwave_pb::task_service::task_info_response::TaskStatus;
use risingwave_pb::task_service::{GetDataResponse, TaskInfoResponse};
use tokio::sync::mpsc::Sender;
use tonic::Status;
use crate::error::Result;
use crate::monitor::BatchManagerMetrics;
use crate::rpc::service::exchange::GrpcExchangeWriter;
use crate::task::{
BatchTaskExecution, ComputeNodeContext, StateReporter, TaskId, TaskOutput, TaskOutputId,
};
#[derive(Clone)]
pub struct BatchManager {
tasks: Arc<Mutex<HashMap<TaskId, Arc<BatchTaskExecution<ComputeNodeContext>>>>>,
runtime: Arc<BackgroundShutdownRuntime>,
config: BatchConfig,
mem_context: MemoryContext,
metrics: Arc<BatchManagerMetrics>,
}
impl BatchManager {
pub fn new(config: BatchConfig, metrics: Arc<BatchManagerMetrics>, mem_limit: u64) -> Self {
let runtime = {
let mut builder = tokio::runtime::Builder::new_multi_thread();
if let Some(worker_threads_num) = config.worker_threads_num {
builder.worker_threads(worker_threads_num);
}
builder
.thread_name("rw-batch")
.enable_all()
.build()
.unwrap()
};
let mem_context = MemoryContext::root(metrics.batch_total_mem.clone(), mem_limit);
BatchManager {
tasks: Arc::new(Mutex::new(HashMap::new())),
runtime: Arc::new(runtime.into()),
config,
metrics,
mem_context,
}
}
pub(crate) fn metrics(&self) -> Arc<BatchManagerMetrics> {
self.metrics.clone()
}
pub fn memory_context_ref(&self) -> MemoryContext {
self.mem_context.clone()
}
pub async fn fire_task(
self: &Arc<Self>,
tid: &PbTaskId,
plan: PlanFragment,
epoch: BatchQueryEpoch,
context: ComputeNodeContext,
state_reporter: StateReporter,
tracing_context: TracingContext,
expr_context: ExprContext,
) -> Result<()> {
trace!("Received task id: {:?}, plan: {:?}", tid, plan);
let task = BatchTaskExecution::new(tid, plan, context, epoch, self.runtime())?;
let task_id = task.get_task_id().clone();
let task = Arc::new(task);
let ret = if let hash_map::Entry::Vacant(e) = self.tasks.lock().entry(task_id.clone()) {
e.insert(task.clone());
let this = self.clone();
let task_id = task_id.clone();
let state_reporter = state_reporter.clone();
let heartbeat_join_handle = self.runtime.spawn(async move {
this.start_task_heartbeat(state_reporter, task_id).await;
});
task.set_heartbeat_join_handle(heartbeat_join_handle);
Ok(())
} else {
bail!(
"can not create duplicate task with the same id: {:?}",
task_id,
);
};
task.async_execute(Some(state_reporter), tracing_context, expr_context)
.await
.inspect_err(|_| {
self.cancel_task(&task_id.to_prost());
})?;
ret
}
#[cfg(test)]
async fn fire_task_for_test(
self: &Arc<Self>,
tid: &PbTaskId,
plan: PlanFragment,
) -> Result<()> {
use risingwave_hummock_sdk::test_batch_query_epoch;
self.fire_task(
tid,
plan,
test_batch_query_epoch(),
ComputeNodeContext::for_test(),
StateReporter::new_with_test(),
TracingContext::none(),
ExprContext {
time_zone: "UTC".to_string(),
},
)
.await
}
async fn start_task_heartbeat(&self, mut state_reporter: StateReporter, task_id: TaskId) {
let _metric_guard = scopeguard::guard((), |_| {
tracing::debug!("heartbeat worker for task {:?} stopped", task_id);
self.metrics.batch_heartbeat_worker_num.dec();
});
tracing::debug!("heartbeat worker for task {:?} started", task_id);
self.metrics.batch_heartbeat_worker_num.inc();
let mut heartbeat_interval = tokio::time::interval(core::time::Duration::from_secs(60));
heartbeat_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
heartbeat_interval.reset();
loop {
heartbeat_interval.tick().await;
if !self.tasks.lock().contains_key(&task_id) {
break;
}
if state_reporter
.send(TaskInfoResponse {
task_id: Some(task_id.to_prost()),
task_status: TaskStatus::Ping.into(),
error_message: "".to_string(),
})
.await
.is_err()
{
tracing::warn!("try to cancel task {:?} due to heartbeat", task_id);
self.cancel_task(&task_id.to_prost());
break;
}
}
}
pub fn get_data(
&self,
tx: Sender<std::result::Result<GetDataResponse, Status>>,
peer_addr: SocketAddr,
pb_task_output_id: &PbTaskOutputId,
) -> Result<()> {
let task_id = TaskOutputId::try_from(pb_task_output_id)?;
tracing::debug!(target: "events::compute::exchange", peer_addr = %peer_addr, from = ?task_id, "serve exchange RPC");
let mut task_output = self.take_output(pb_task_output_id)?;
self.runtime.spawn(async move {
let mut writer = GrpcExchangeWriter::new(tx.clone());
match task_output.take_data(&mut writer).await {
Ok(_) => {
tracing::trace!(
from = ?task_id,
"exchanged {} chunks",
writer.written_chunks(),
);
Ok(())
}
Err(e) => tx.send(Err(e.into())).await,
}
});
Ok(())
}
pub fn take_output(&self, output_id: &PbTaskOutputId) -> Result<TaskOutput> {
let task_id = TaskId::from(output_id.get_task_id()?);
self.tasks
.lock()
.get(&task_id)
.with_context(|| format!("task {:?} not found", task_id))?
.get_task_output(output_id)
}
pub fn cancel_task(&self, sid: &PbTaskId) {
let sid = TaskId::from(sid);
match self.tasks.lock().remove(&sid) {
Some(task) => {
tracing::trace!("Removed task: {:?}", task.get_task_id());
task.cancel();
if let Some(heartbeat_join_handle) = task.heartbeat_join_handle() {
heartbeat_join_handle.abort();
}
}
None => {
warn!("Task {:?} not found for cancel", sid)
}
};
}
pub fn check_if_task_running(&self, task_id: &TaskId) -> Result<()> {
match self.tasks.lock().get(task_id) {
Some(task) => task.check_if_running(),
None => bail!("task {:?} not found", task_id),
}
}
pub fn check_if_task_aborted(&self, task_id: &TaskId) -> Result<bool> {
match self.tasks.lock().get(task_id) {
Some(task) => task.check_if_aborted(),
None => bail!("task {:?} not found", task_id),
}
}
#[cfg(test)]
async fn wait_until_task_aborted(&self, task_id: &TaskId) -> Result<()> {
use std::time::Duration;
loop {
match self.tasks.lock().get(task_id) {
Some(task) => {
let ret = task.check_if_aborted();
match ret {
Ok(true) => return Ok(()),
Ok(false) => {}
Err(err) => return Err(err),
}
}
None => bail!("task {:?} not found", task_id),
}
tokio::time::sleep(Duration::from_millis(100)).await
}
}
pub fn runtime(&self) -> Arc<BackgroundShutdownRuntime> {
self.runtime.clone()
}
pub fn config(&self) -> &BatchConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use risingwave_common::config::BatchConfig;
use risingwave_pb::batch_plan::exchange_info::DistributionMode;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::{
ExchangeInfo, PbTaskId, PbTaskOutputId, PlanFragment, PlanNode, ValuesNode,
};
use crate::monitor::BatchManagerMetrics;
use crate::task::{BatchManager, TaskId};
#[tokio::test]
async fn test_task_not_found() {
let manager = Arc::new(BatchManager::new(
BatchConfig::default(),
BatchManagerMetrics::for_test(),
u64::MAX,
));
let task_id = TaskId {
task_id: 0,
stage_id: 0,
query_id: "abc".to_string(),
};
let error = manager.check_if_task_running(&task_id).unwrap_err();
assert!(error.to_string().contains("not found"), "{:?}", error);
let output_id = PbTaskOutputId {
task_id: Some(risingwave_pb::batch_plan::TaskId {
stage_id: 0,
task_id: 0,
query_id: "".to_owned(),
}),
output_id: 0,
};
let error = manager.take_output(&output_id).unwrap_err();
assert!(error.to_string().contains("not found"), "{:?}", error);
}
#[tokio::test]
async fn test_task_id_conflict() {
let manager = Arc::new(BatchManager::new(
BatchConfig::default(),
BatchManagerMetrics::for_test(),
u64::MAX,
));
let plan = PlanFragment {
root: Some(PlanNode {
children: vec![],
identity: "".to_string(),
node_body: Some(NodeBody::Values(ValuesNode {
tuples: vec![],
fields: vec![],
})),
}),
exchange_info: Some(ExchangeInfo {
mode: DistributionMode::Single as i32,
distribution: None,
}),
};
let task_id = PbTaskId {
query_id: "".to_string(),
stage_id: 0,
task_id: 0,
};
manager
.fire_task_for_test(&task_id, plan.clone())
.await
.unwrap();
let err = manager
.fire_task_for_test(&task_id, plan)
.await
.unwrap_err();
assert!(err
.to_string()
.contains("can not create duplicate task with the same id"));
}
#[tokio::test]
#[ignore]
async fn test_task_cancel_for_busy_loop() {
let manager = Arc::new(BatchManager::new(
BatchConfig::default(),
BatchManagerMetrics::for_test(),
u64::MAX,
));
let plan = PlanFragment {
root: Some(PlanNode {
children: vec![],
identity: "".to_string(),
node_body: Some(NodeBody::BusyLoopExecutor(true)),
}),
exchange_info: Some(ExchangeInfo {
mode: DistributionMode::Single as i32,
distribution: None,
}),
};
let task_id = PbTaskId {
query_id: "".to_string(),
stage_id: 0,
task_id: 0,
};
manager.fire_task_for_test(&task_id, plan).await.unwrap();
manager.cancel_task(&task_id);
let task_id = TaskId::from(&task_id);
assert!(!manager.tasks.lock().contains_key(&task_id));
}
#[tokio::test]
#[ignore]
async fn test_task_abort_for_busy_loop() {
let manager = Arc::new(BatchManager::new(
BatchConfig::default(),
BatchManagerMetrics::for_test(),
u64::MAX,
));
let plan = PlanFragment {
root: Some(PlanNode {
children: vec![],
identity: "".to_string(),
node_body: Some(NodeBody::BusyLoopExecutor(true)),
}),
exchange_info: Some(ExchangeInfo {
mode: DistributionMode::Single as i32,
distribution: None,
}),
};
let task_id = PbTaskId {
query_id: "".to_string(),
stage_id: 0,
task_id: 0,
};
manager.fire_task_for_test(&task_id, plan).await.unwrap();
let task_id = TaskId::from(&task_id);
manager
.tasks
.lock()
.get(&task_id)
.unwrap()
.abort("Abort Test".to_owned());
assert!(manager.wait_until_task_aborted(&task_id).await.is_ok());
}
}