1use std::collections::{HashMap, hash_map};
16use std::net::SocketAddr;
17use std::sync::Arc;
18
19use anyhow::Context;
20use parking_lot::Mutex;
21use risingwave_common::config::BatchConfig;
22use risingwave_common::memory::MemoryContext;
23use risingwave_common::util::runtime::BackgroundShutdownRuntime;
24use risingwave_common::util::tracing::TracingContext;
25use risingwave_pb::batch_plan::{PbTaskId, PbTaskOutputId, PlanFragment};
26use risingwave_pb::plan_common::ExprContext;
27use risingwave_pb::task_service::task_info_response::TaskStatus;
28use risingwave_pb::task_service::{GetDataResponse, TaskInfoResponse};
29use tokio::sync::mpsc::Sender;
30use tonic::Status;
31
32use super::BatchTaskContext;
33use crate::error::Result;
34use crate::monitor::BatchManagerMetrics;
35use crate::rpc::service::exchange::GrpcExchangeWriter;
36use crate::task::{BatchTaskExecution, StateReporter, TaskId, TaskOutput, TaskOutputId};
37
/// Registry and launcher for batch task executions.
///
/// `Clone` is derived; `tasks`, `runtime` and `metrics` are `Arc`s, so clones of a manager share
/// the same task registry, runtime and metrics.
#[derive(Clone)]
pub struct BatchManager {
    /// Registered task executions keyed by task id. `fire_task` inserts entries and
    /// `cancel_task` removes them.
    tasks: Arc<Mutex<HashMap<TaskId, Arc<BatchTaskExecution>>>>,

    /// Multi-threaded tokio runtime (threads named `rw-batch`) on which heartbeat workers and
    /// exchange data writers are spawned.
    runtime: Arc<BackgroundShutdownRuntime>,

    /// Batch configuration supplied to `new`.
    config: BatchConfig,

    /// Root memory context, created in `new` from `metrics.batch_total_mem` and `mem_limit`.
    mem_context: MemoryContext,

    /// Metrics of this manager, including `batch_heartbeat_worker_num` and `batch_total_mem`.
    metrics: Arc<BatchManagerMetrics>,
}
56
57impl BatchManager {
58 pub fn new(config: BatchConfig, metrics: Arc<BatchManagerMetrics>, mem_limit: u64) -> Self {
59 let runtime = {
60 let mut builder = tokio::runtime::Builder::new_multi_thread();
61 if let Some(worker_threads_num) = config.worker_threads_num {
62 builder.worker_threads(worker_threads_num);
63 }
64 builder
65 .thread_name("rw-batch")
66 .enable_all()
67 .build()
68 .unwrap()
69 };
70
71 let mem_context = MemoryContext::root(metrics.batch_total_mem.clone(), mem_limit);
72 BatchManager {
73 tasks: Arc::new(Mutex::new(HashMap::new())),
74 runtime: Arc::new(runtime.into()),
75 config,
76 metrics,
77 mem_context,
78 }
79 }
80
81 pub(crate) fn metrics(&self) -> Arc<BatchManagerMetrics> {
82 self.metrics.clone()
83 }
84
85 pub fn memory_context_ref(&self) -> MemoryContext {
86 self.mem_context.clone()
87 }
88
89 pub async fn fire_task(
90 self: &Arc<Self>,
91 tid: &PbTaskId,
92 plan: PlanFragment,
93 context: Arc<dyn BatchTaskContext>, state_reporter: StateReporter,
95 tracing_context: TracingContext,
96 expr_context: ExprContext,
97 ) -> Result<()> {
98 trace!("Received task id: {:?}, plan: {:?}", tid, plan);
99 let task = BatchTaskExecution::new(tid, plan, context, self.runtime())?;
100 let task_id = task.get_task_id().clone();
101 let task = Arc::new(task);
102 let ret = if let hash_map::Entry::Vacant(e) = self.tasks.lock().entry(task_id.clone()) {
106 e.insert(task.clone());
107
108 let this = self.clone();
109 let task_id = task_id.clone();
110 let state_reporter = state_reporter.clone();
111 let heartbeat_join_handle = self.runtime.spawn(async move {
112 this.start_task_heartbeat(state_reporter, task_id).await;
113 });
114 task.set_heartbeat_join_handle(heartbeat_join_handle);
115
116 Ok(())
117 } else {
118 bail!(
119 "can not create duplicate task with the same id: {:?}",
120 task_id,
121 );
122 };
123 task.async_execute(Some(state_reporter), tracing_context, expr_context)
124 .await
125 .inspect_err(|_| {
126 self.cancel_task(&task_id.to_prost());
127 })?;
128 ret
129 }
130
131 #[cfg(test)]
132 async fn fire_task_for_test(
133 self: &Arc<Self>,
134 tid: &PbTaskId,
135 plan: PlanFragment,
136 ) -> Result<()> {
137 use crate::task::ComputeNodeContext;
138
139 self.fire_task(
140 tid,
141 plan,
142 ComputeNodeContext::for_test(),
143 StateReporter::new_with_test(),
144 TracingContext::none(),
145 ExprContext {
146 time_zone: "UTC".to_owned(),
147 strict_mode: false,
148 },
149 )
150 .await
151 }
152
153 async fn start_task_heartbeat(&self, mut state_reporter: StateReporter, task_id: TaskId) {
154 let _metric_guard = scopeguard::guard((), |_| {
155 tracing::debug!("heartbeat worker for task {:?} stopped", task_id);
156 self.metrics.batch_heartbeat_worker_num.dec();
157 });
158 tracing::debug!("heartbeat worker for task {:?} started", task_id);
159 self.metrics.batch_heartbeat_worker_num.inc();
160 let mut heartbeat_interval = tokio::time::interval(core::time::Duration::from_secs(60));
163 heartbeat_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
164 heartbeat_interval.reset();
165 loop {
166 heartbeat_interval.tick().await;
167 if !self.tasks.lock().contains_key(&task_id) {
168 break;
169 }
170 if state_reporter
171 .send(TaskInfoResponse {
172 task_id: Some(task_id.to_prost()),
173 task_status: TaskStatus::Ping.into(),
174 error_message: "".to_owned(),
175 })
176 .await
177 .is_err()
178 {
179 tracing::warn!("try to cancel task {:?} due to heartbeat", task_id);
180 self.cancel_task(&task_id.to_prost());
182 break;
183 }
184 }
185 }
186
187 pub fn get_data(
188 &self,
189 tx: Sender<std::result::Result<GetDataResponse, Status>>,
190 peer_addr: SocketAddr,
191 pb_task_output_id: &PbTaskOutputId,
192 ) -> Result<()> {
193 let task_id = TaskOutputId::try_from(pb_task_output_id)?;
194 tracing::debug!(target: "events::compute::exchange", peer_addr = %peer_addr, from = ?task_id, "serve exchange RPC");
195 let mut task_output = self.take_output(pb_task_output_id)?;
196 self.runtime.spawn(async move {
197 let mut writer = GrpcExchangeWriter::new(tx.clone());
198 match task_output.take_data(&mut writer).await {
199 Ok(_) => {
200 tracing::trace!(
201 from = ?task_id,
202 "exchanged {} chunks",
203 writer.written_chunks(),
204 );
205 Ok(())
206 }
207 Err(e) => tx.send(Err(e.into())).await,
208 }
209 });
210 Ok(())
211 }
212
213 pub fn take_output(&self, output_id: &PbTaskOutputId) -> Result<TaskOutput> {
214 let task_id = TaskId::from(output_id.get_task_id()?);
215 self.tasks
216 .lock()
217 .get(&task_id)
218 .with_context(|| format!("task {:?} not found", task_id))?
219 .get_task_output(output_id)
220 }
221
222 pub fn cancel_task(&self, sid: &PbTaskId) {
223 let sid = TaskId::from(sid);
224 match self.tasks.lock().remove(&sid) {
225 Some(task) => {
226 tracing::trace!("Removed task: {:?}", task.get_task_id());
227 task.cancel();
230 if let Some(heartbeat_join_handle) = task.heartbeat_join_handle() {
231 heartbeat_join_handle.abort();
232 }
233 }
234 None => {
235 warn!("Task {:?} not found for cancel", sid)
236 }
237 };
238 }
239
240 pub fn check_if_task_running(&self, task_id: &TaskId) -> Result<()> {
242 match self.tasks.lock().get(task_id) {
243 Some(task) => task.check_if_running(),
244 None => bail!("task {:?} not found", task_id),
245 }
246 }
247
248 pub fn check_if_task_aborted(&self, task_id: &TaskId) -> Result<bool> {
249 match self.tasks.lock().get(task_id) {
250 Some(task) => task.check_if_aborted(),
251 None => bail!("task {:?} not found", task_id),
252 }
253 }
254
255 #[cfg(test)]
256 async fn wait_until_task_aborted(&self, task_id: &TaskId) -> Result<()> {
257 use std::time::Duration;
258 loop {
259 match self.tasks.lock().get(task_id) {
260 Some(task) => {
261 let ret = task.check_if_aborted();
262 match ret {
263 Ok(true) => return Ok(()),
264 Ok(false) => {}
265 Err(err) => return Err(err),
266 }
267 }
268 None => bail!("task {:?} not found", task_id),
269 }
270 tokio::time::sleep(Duration::from_millis(100)).await
271 }
272 }
273
274 pub fn runtime(&self) -> Arc<BackgroundShutdownRuntime> {
275 self.runtime.clone()
276 }
277
278 pub fn config(&self) -> &BatchConfig {
279 &self.config
280 }
281}
282
283#[cfg(test)]
284mod tests {
285 use std::sync::Arc;
286
287 use risingwave_common::config::BatchConfig;
288 use risingwave_pb::batch_plan::exchange_info::DistributionMode;
289 use risingwave_pb::batch_plan::plan_node::NodeBody;
290 use risingwave_pb::batch_plan::{
291 ExchangeInfo, PbTaskId, PbTaskOutputId, PlanFragment, PlanNode,
292 };
293
294 use crate::monitor::BatchManagerMetrics;
295 use crate::task::{BatchManager, TaskId};
296
297 #[tokio::test]
298 async fn test_task_not_found() {
299 let manager = Arc::new(BatchManager::new(
300 BatchConfig::default(),
301 BatchManagerMetrics::for_test(),
302 u64::MAX,
303 ));
304 let task_id = TaskId {
305 task_id: 0,
306 stage_id: 0,
307 query_id: "abc".to_owned(),
308 };
309
310 let error = manager.check_if_task_running(&task_id).unwrap_err();
311 assert!(error.to_string().contains("not found"), "{:?}", error);
312
313 let output_id = PbTaskOutputId {
314 task_id: Some(risingwave_pb::batch_plan::TaskId {
315 stage_id: 0,
316 task_id: 0,
317 query_id: "".to_owned(),
318 }),
319 output_id: 0,
320 };
321 let error = manager.take_output(&output_id).unwrap_err();
322 assert!(error.to_string().contains("not found"), "{:?}", error);
323 }
324
325 #[tokio::test]
326 #[ignore]
328 async fn test_task_cancel_for_busy_loop() {
329 let manager = Arc::new(BatchManager::new(
330 BatchConfig::default(),
331 BatchManagerMetrics::for_test(),
332 u64::MAX,
333 ));
334 let plan = PlanFragment {
335 root: Some(PlanNode {
336 children: vec![],
337 identity: "".to_owned(),
338 node_body: Some(NodeBody::BusyLoopExecutor(true)),
339 }),
340 exchange_info: Some(ExchangeInfo {
341 mode: DistributionMode::Single as i32,
342 distribution: None,
343 }),
344 };
345 let task_id = PbTaskId {
346 query_id: "".to_owned(),
347 stage_id: 0,
348 task_id: 0,
349 };
350 manager.fire_task_for_test(&task_id, plan).await.unwrap();
351 manager.cancel_task(&task_id);
352 let task_id = TaskId::from(&task_id);
353 assert!(!manager.tasks.lock().contains_key(&task_id));
354 }
355
356 #[tokio::test]
357 #[ignore]
359 async fn test_task_abort_for_busy_loop() {
360 let manager = Arc::new(BatchManager::new(
361 BatchConfig::default(),
362 BatchManagerMetrics::for_test(),
363 u64::MAX,
364 ));
365 let plan = PlanFragment {
366 root: Some(PlanNode {
367 children: vec![],
368 identity: "".to_owned(),
369 node_body: Some(NodeBody::BusyLoopExecutor(true)),
370 }),
371 exchange_info: Some(ExchangeInfo {
372 mode: DistributionMode::Single as i32,
373 distribution: None,
374 }),
375 };
376 let task_id = PbTaskId {
377 query_id: "".to_owned(),
378 stage_id: 0,
379 task_id: 0,
380 };
381 manager.fire_task_for_test(&task_id, plan).await.unwrap();
382 let task_id = TaskId::from(&task_id);
383 manager
384 .tasks
385 .lock()
386 .get(&task_id)
387 .unwrap()
388 .abort("Abort Test".to_owned());
389 assert!(manager.wait_until_task_aborted(&task_id).await.is_ok());
390 }
391}