use std::collections::{HashMap, hash_map};
use std::net::SocketAddr;
use std::sync::Arc;

use anyhow::Context;
use parking_lot::Mutex;
use risingwave_common::bail;
use risingwave_common::config::BatchConfig;
use risingwave_common::memory::MemoryContext;
use risingwave_common::util::runtime::BackgroundShutdownRuntime;
use risingwave_common::util::tracing::TracingContext;
use risingwave_pb::batch_plan::{PbTaskId, PbTaskOutputId, PlanFragment};
use risingwave_pb::common::BatchQueryEpoch;
use risingwave_pb::plan_common::ExprContext;
use risingwave_pb::task_service::task_info_response::TaskStatus;
use risingwave_pb::task_service::{GetDataResponse, TaskInfoResponse};
use tokio::sync::mpsc::Sender;
use tonic::Status;

use super::BatchTaskContext;
use crate::error::Result;
use crate::monitor::BatchManagerMetrics;
use crate::rpc::service::exchange::GrpcExchangeWriter;
use crate::task::{BatchTaskExecution, StateReporter, TaskId, TaskOutput, TaskOutputId};

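/// `BatchManager` manages all batch tasks on this node. It owns the task registry,
/// a dedicated runtime for batch execution, and the root memory context that
/// accounts for batch memory usage.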
#[derive(Clone)]
pub struct BatchManager {
    /// Every task id has a corresponding task execution.
    tasks: Arc<Mutex<HashMap<TaskId, Arc<BatchTaskExecution>>>>,

    /// Runtime for batch task execution.
    runtime: Arc<BackgroundShutdownRuntime>,

    /// Batch configuration.
    config: BatchConfig,

    /// Memory context used for batch tasks on this compute node.
    mem_context: MemoryContext,

    /// Metrics for the batch manager.
    metrics: Arc<BatchManagerMetrics>,
}

impl BatchManager {
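    /// Creates a `BatchManager` with a dedicated multi-threaded Tokio runtime
    /// (named "rw-batch", sized by `worker_threads_num` if configured) and a root
    /// memory context bounded by `mem_limit`.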
    pub fn new(config: BatchConfig, metrics: Arc<BatchManagerMetrics>, mem_limit: u64) -> Self {
        let runtime = {
            let mut builder = tokio::runtime::Builder::new_multi_thread();
            if let Some(worker_threads_num) = config.worker_threads_num {
                builder.worker_threads(worker_threads_num);
            }
            builder
                .thread_name("rw-batch")
                .enable_all()
                .build()
                .unwrap()
        };

        let mem_context = MemoryContext::root(metrics.batch_total_mem.clone(), mem_limit);
        BatchManager {
            tasks: Arc::new(Mutex::new(HashMap::new())),
            runtime: Arc::new(runtime.into()),
            config,
            metrics,
            mem_context,
        }
    }

    pub(crate) fn metrics(&self) -> Arc<BatchManagerMetrics> {
        self.metrics.clone()
    }

    pub fn memory_context_ref(&self) -> MemoryContext {
        self.mem_context.clone()
    }

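    /// Creates and starts execution of a batch task. Fails if a task with the same
    /// id is already registered; otherwise a heartbeat worker is spawned to keep
    /// reporting the task as alive, and the plan is executed asynchronously.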
    pub async fn fire_task(
        self: &Arc<Self>,
        tid: &PbTaskId,
        plan: PlanFragment,
        epoch: BatchQueryEpoch,
        context: Arc<dyn BatchTaskContext>,
        state_reporter: StateReporter,
        tracing_context: TracingContext,
        expr_context: ExprContext,
    ) -> Result<()> {
        tracing::trace!("Received task id: {:?}, plan: {:?}", tid, plan);
        let task = BatchTaskExecution::new(tid, plan, context, epoch, self.runtime())?;
        let task_id = task.get_task_id().clone();
        let task = Arc::new(task);
        // Register the task only if no task with the same id exists; firing a
        // duplicate task would clobber the existing execution.
        let ret = if let hash_map::Entry::Vacant(e) = self.tasks.lock().entry(task_id.clone()) {
            e.insert(task.clone());

            // Spawn a heartbeat worker that periodically reports this task as alive.
            let this = self.clone();
            let task_id = task_id.clone();
            let state_reporter = state_reporter.clone();
            let heartbeat_join_handle = self.runtime.spawn(async move {
                this.start_task_heartbeat(state_reporter, task_id).await;
            });
            task.set_heartbeat_join_handle(heartbeat_join_handle);

            Ok(())
        } else {
            bail!(
                "can not create duplicate task with the same id: {:?}",
                task_id,
            );
        };
        task.async_execute(Some(state_reporter), tracing_context, expr_context)
            .await
            .inspect_err(|_| {
                self.cancel_task(&task_id.to_prost());
            })?;
        ret
    }

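    /// Test-only wrapper around `fire_task` that supplies a test epoch, context,
    /// and state reporter.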
    #[cfg(test)]
    async fn fire_task_for_test(
        self: &Arc<Self>,
        tid: &PbTaskId,
        plan: PlanFragment,
    ) -> Result<()> {
        use risingwave_hummock_sdk::test_batch_query_epoch;

        use crate::task::ComputeNodeContext;

        self.fire_task(
            tid,
            plan,
            test_batch_query_epoch(),
            ComputeNodeContext::for_test(),
            StateReporter::new_with_test(),
            TracingContext::none(),
            ExprContext {
                time_zone: "UTC".to_owned(),
                strict_mode: false,
            },
        )
        .await
    }

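    /// Heartbeat worker for a single task: every 60 seconds it sends a `Ping`
    /// status through the state reporter. It stops once the task is gone and
    /// cancels the task if the report channel is closed.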
    async fn start_task_heartbeat(&self, mut state_reporter: StateReporter, task_id: TaskId) {
        let _metric_guard = scopeguard::guard((), |_| {
            tracing::debug!("heartbeat worker for task {:?} stopped", task_id);
            self.metrics.batch_heartbeat_worker_num.dec();
        });
        tracing::debug!("heartbeat worker for task {:?} started", task_id);
        self.metrics.batch_heartbeat_worker_num.inc();
        // Report liveness every 60 seconds. `reset()` skips the immediate first tick,
        // and `Delay` avoids a burst of catch-up ticks after a long pause.
        let mut heartbeat_interval = tokio::time::interval(core::time::Duration::from_secs(60));
        heartbeat_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        heartbeat_interval.reset();
        loop {
            heartbeat_interval.tick().await;
            if !self.tasks.lock().contains_key(&task_id) {
                break;
            }
            if state_reporter
                .send(TaskInfoResponse {
                    task_id: Some(task_id.to_prost()),
                    task_status: TaskStatus::Ping.into(),
                    error_message: "".to_owned(),
                })
                .await
                .is_err()
            {
                tracing::warn!("try to cancel task {:?} due to heartbeat", task_id);
                // The receiver is gone, so the task can no longer be observed;
                // cancel it to release resources.
                self.cancel_task(&task_id.to_prost());
                break;
            }
        }
    }

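    /// Serves an exchange request: takes the requested task output and streams its
    /// chunks to `tx` on the batch runtime, forwarding any execution error.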
    pub fn get_data(
        &self,
        tx: Sender<std::result::Result<GetDataResponse, Status>>,
        peer_addr: SocketAddr,
        pb_task_output_id: &PbTaskOutputId,
    ) -> Result<()> {
        let task_id = TaskOutputId::try_from(pb_task_output_id)?;
        tracing::debug!(target: "events::compute::exchange", peer_addr = %peer_addr, from = ?task_id, "serve exchange RPC");
        let mut task_output = self.take_output(pb_task_output_id)?;
        self.runtime.spawn(async move {
            let mut writer = GrpcExchangeWriter::new(tx.clone());
            match task_output.take_data(&mut writer).await {
                Ok(_) => {
                    tracing::trace!(
                        from = ?task_id,
                        "exchanged {} chunks",
                        writer.written_chunks(),
                    );
                    Ok(())
                }
                // Forward the execution error to the client as a gRPC status.
                Err(e) => tx.send(Err(e.into())).await,
            }
        });
        Ok(())
    }

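    /// Looks up the task and takes the requested output from it, erroring if the
    /// task is not found.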
    pub fn take_output(&self, output_id: &PbTaskOutputId) -> Result<TaskOutput> {
        let task_id = TaskId::from(output_id.get_task_id()?);
        self.tasks
            .lock()
            .get(&task_id)
            .with_context(|| format!("task {:?} not found", task_id))?
            .get_task_output(output_id)
    }

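    /// Removes the task from the manager, cancels its execution, and aborts its
    /// heartbeat worker. Logs a warning if the task is not found.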
    pub fn cancel_task(&self, sid: &PbTaskId) {
        let sid = TaskId::from(sid);
        match self.tasks.lock().remove(&sid) {
            Some(task) => {
                tracing::trace!("Removed task: {:?}", task.get_task_id());
                task.cancel();
                if let Some(heartbeat_join_handle) = task.heartbeat_join_handle() {
                    heartbeat_join_handle.abort();
                }
            }
            None => {
                tracing::warn!("Task {:?} not found for cancel", sid)
            }
        };
    }

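    /// Forwards to the task's own running-state check; errors if the task is not found.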
    pub fn check_if_task_running(&self, task_id: &TaskId) -> Result<()> {
        match self.tasks.lock().get(task_id) {
            Some(task) => task.check_if_running(),
            None => bail!("task {:?} not found", task_id),
        }
    }

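    /// Returns whether the task has been aborted; errors if the task is not found.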
    pub fn check_if_task_aborted(&self, task_id: &TaskId) -> Result<bool> {
        match self.tasks.lock().get(task_id) {
            Some(task) => task.check_if_aborted(),
            None => bail!("task {:?} not found", task_id),
        }
    }

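    /// Test-only helper that polls every 100 ms until the task reports itself aborted.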
    #[cfg(test)]
    async fn wait_until_task_aborted(&self, task_id: &TaskId) -> Result<()> {
        use std::time::Duration;
        loop {
            match self.tasks.lock().get(task_id) {
                Some(task) => {
                    let ret = task.check_if_aborted();
                    match ret {
                        Ok(true) => return Ok(()),
                        Ok(false) => {}
                        Err(err) => return Err(err),
                    }
                }
                None => bail!("task {:?} not found", task_id),
            }
            tokio::time::sleep(Duration::from_millis(100)).await
        }
    }

    pub fn runtime(&self) -> Arc<BackgroundShutdownRuntime> {
        self.runtime.clone()
    }

    pub fn config(&self) -> &BatchConfig {
        &self.config
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use risingwave_common::config::BatchConfig;
    use risingwave_pb::batch_plan::exchange_info::DistributionMode;
    use risingwave_pb::batch_plan::plan_node::NodeBody;
    use risingwave_pb::batch_plan::{
        ExchangeInfo, PbTaskId, PbTaskOutputId, PlanFragment, PlanNode,
    };

    use crate::monitor::BatchManagerMetrics;
    use crate::task::{BatchManager, TaskId};

    #[tokio::test]
    async fn test_task_not_found() {
        let manager = Arc::new(BatchManager::new(
            BatchConfig::default(),
            BatchManagerMetrics::for_test(),
            u64::MAX,
        ));
        let task_id = TaskId {
            task_id: 0,
            stage_id: 0,
            query_id: "abc".to_owned(),
        };

        let error = manager.check_if_task_running(&task_id).unwrap_err();
        assert!(error.to_string().contains("not found"), "{:?}", error);

        let output_id = PbTaskOutputId {
            task_id: Some(risingwave_pb::batch_plan::TaskId {
                stage_id: 0,
                task_id: 0,
                query_id: "".to_owned(),
            }),
            output_id: 0,
        };
        let error = manager.take_output(&output_id).unwrap_err();
        assert!(error.to_string().contains("not found"), "{:?}", error);
    }

    #[tokio::test]
    #[ignore]
    async fn test_task_cancel_for_busy_loop() {
        let manager = Arc::new(BatchManager::new(
            BatchConfig::default(),
            BatchManagerMetrics::for_test(),
            u64::MAX,
        ));
        let plan = PlanFragment {
            root: Some(PlanNode {
                children: vec![],
                identity: "".to_owned(),
                node_body: Some(NodeBody::BusyLoopExecutor(true)),
            }),
            exchange_info: Some(ExchangeInfo {
                mode: DistributionMode::Single as i32,
                distribution: None,
            }),
        };
        let task_id = PbTaskId {
            query_id: "".to_owned(),
            stage_id: 0,
            task_id: 0,
        };
        manager.fire_task_for_test(&task_id, plan).await.unwrap();
        manager.cancel_task(&task_id);
        let task_id = TaskId::from(&task_id);
        assert!(!manager.tasks.lock().contains_key(&task_id));
    }

    #[tokio::test]
    #[ignore]
    async fn test_task_abort_for_busy_loop() {
        let manager = Arc::new(BatchManager::new(
            BatchConfig::default(),
            BatchManagerMetrics::for_test(),
            u64::MAX,
        ));
        let plan = PlanFragment {
            root: Some(PlanNode {
                children: vec![],
                identity: "".to_owned(),
                node_body: Some(NodeBody::BusyLoopExecutor(true)),
            }),
            exchange_info: Some(ExchangeInfo {
                mode: DistributionMode::Single as i32,
                distribution: None,
            }),
        };
        let task_id = PbTaskId {
            query_id: "".to_owned(),
            stage_id: 0,
            task_id: 0,
        };
        manager.fire_task_for_test(&task_id, plan).await.unwrap();
        let task_id = TaskId::from(&task_id);
        manager
            .tasks
            .lock()
            .get(&task_id)
            .unwrap()
            .abort("Abort Test".to_owned());
        assert!(manager.wait_until_task_aborted(&task_id).await.is_ok());
    }
}