1use std::fmt::{Debug, Formatter};
16use std::panic::AssertUnwindSafe;
17use std::sync::Arc;
18
19use anyhow::Context;
20use futures::StreamExt;
21use parking_lot::Mutex;
22use risingwave_common::array::DataChunk;
23use risingwave_common::util::panic::FutureCatchUnwindExt;
24use risingwave_common::util::runtime::BackgroundShutdownRuntime;
25use risingwave_common::util::tracing::TracingContext;
26use risingwave_expr::expr_context::expr_context_scope;
27use risingwave_pb::PbFieldNotFound;
28use risingwave_pb::batch_plan::{PbTaskId, PbTaskOutputId, PlanFragment};
29use risingwave_pb::plan_common::ExprContext;
30use risingwave_pb::task_service::task_info_response::TaskStatus;
31use risingwave_pb::task_service::{GetDataResponse, TaskInfoResponse};
32use thiserror_ext::AsReport;
33use tokio::select;
34use tokio::task::JoinHandle;
35use tracing::Instrument;
36
37use crate::error::BatchError::SenderError;
38use crate::error::{BatchError, Result, SharedResult};
39use crate::executor::{BoxedExecutor, ExecutorBuilder};
40use crate::rpc::service::exchange::ExchangeWriter;
41use crate::rpc::service::task_service::TaskInfoResponseResult;
42use crate::task::BatchTaskContext;
43use crate::task::channel::{ChanReceiverImpl, ChanSenderImpl, create_output_channel};
44
/// Buffer size associated with task status reporting.
///
/// NOTE(review): this constant is not referenced in this file. Presumably it is the capacity
/// of the channel behind [`StateReporter::Distributed`]; confirm at the use site.
pub const TASK_STATUS_BUFFER_SIZE: usize = 2;
47
48#[derive(Clone)]
56pub enum StateReporter {
57 Distributed(tokio::sync::mpsc::Sender<TaskInfoResponseResult>),
58 Mock(),
59}
60
61impl StateReporter {
62 pub async fn send(&mut self, val: TaskInfoResponse) -> Result<()> {
63 match self {
64 Self::Distributed(s) => s.send(Ok(val)).await.map_err(|_| SenderError),
65 Self::Mock() => Ok(()),
66 }
67 }
68
69 pub fn new_with_dist_sender(s: tokio::sync::mpsc::Sender<TaskInfoResponseResult>) -> Self {
70 Self::Distributed(s)
71 }
72
73 pub fn new_with_test() -> Self {
74 Self::Mock()
75 }
76}
77
/// Identifier of a batch task: the query it belongs to, the stage within that query and the
/// task number.
#[derive(PartialEq, Eq, Hash, Clone, Debug, Default)]
pub struct TaskId {
    pub task_id: u64,
    pub stage_id: u32,
    pub query_id: String,
}

/// Identifier of one output of a task: the owning [`TaskId`] plus the output number.
#[derive(PartialEq, Eq, Hash, Clone, Default)]
pub struct TaskOutputId {
    pub task_id: TaskId,
    pub output_id: u64,
}

impl Debug for TaskOutputId {
    /// Renders the id as a single flat struct, with the task id fields inlined and the query
    /// id wrapped in (unescaped) double quotes.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let TaskId {
            task_id,
            stage_id,
            query_id,
        } = &self.task_id;
        let output_id = self.output_id;
        write!(
            f,
            "TaskOutputId {{ query_id: \"{}\", stage_id: {}, task_id: {}, output_id: {} }}",
            query_id, stage_id, task_id, output_id
        )
    }
}
100
101impl From<&PbTaskId> for TaskId {
102 fn from(prost: &PbTaskId) -> Self {
103 TaskId {
104 task_id: prost.task_id,
105 stage_id: prost.stage_id,
106 query_id: prost.query_id.clone(),
107 }
108 }
109}
110
111impl TaskId {
112 pub fn to_prost(&self) -> PbTaskId {
113 PbTaskId {
114 task_id: self.task_id,
115 stage_id: self.stage_id,
116 query_id: self.query_id.clone(),
117 }
118 }
119}
120
121impl TryFrom<&PbTaskOutputId> for TaskOutputId {
122 type Error = PbFieldNotFound;
123
124 fn try_from(prost: &PbTaskOutputId) -> std::result::Result<Self, PbFieldNotFound> {
125 Ok(TaskOutputId {
126 task_id: TaskId::from(prost.get_task_id()?),
127 output_id: prost.get_output_id(),
128 })
129 }
130}
131
132impl TaskOutputId {
133 pub fn to_prost(&self) -> PbTaskOutputId {
134 PbTaskOutputId {
135 task_id: Some(self.task_id.to_prost()),
136 output_id: self.output_id,
137 }
138 }
139}
140
141pub struct TaskOutput {
142 receiver: ChanReceiverImpl,
143 output_id: TaskOutputId,
144 failure: Arc<Mutex<Option<Arc<BatchError>>>>,
145}
146
147impl std::fmt::Debug for TaskOutput {
148 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
149 f.debug_struct("TaskOutput")
150 .field("output_id", &self.output_id)
151 .field("failure", &self.failure)
152 .finish_non_exhaustive()
153 }
154}
155
156impl TaskOutput {
157 async fn take_data_inner(
160 &mut self,
161 writer: &mut impl ExchangeWriter,
162 at_most_num: Option<usize>,
163 ) -> Result<bool> {
164 let mut cnt: usize = 0;
165 let limited = at_most_num.is_some();
166 let at_most_num = at_most_num.unwrap_or(usize::MAX);
167 loop {
168 if limited && cnt >= at_most_num {
169 return Ok(false);
170 }
171 match self.receiver.recv().await {
172 Ok(Some(chunk)) => {
174 trace!(
175 "Task output id: {:?}, data len: {:?}",
176 self.output_id,
177 chunk.cardinality()
178 );
179 let pb = chunk.to_protobuf().await;
180 let resp = GetDataResponse {
181 record_batch: Some(pb),
182 };
183 writer.write(Ok(resp)).await?;
184 }
185 Ok(None) => {
187 break;
188 }
189 Err(e) => {
191 writer.write(Err(tonic::Status::from(&*e))).await?;
192 break;
193 }
194 }
195 cnt += 1;
196 }
197 Ok(true)
198 }
199
200 pub async fn take_data_with_num(
203 &mut self,
204 writer: &mut impl ExchangeWriter,
205 num: usize,
206 ) -> Result<bool> {
207 self.take_data_inner(writer, Some(num)).await
208 }
209
210 pub async fn take_data(&mut self, writer: &mut impl ExchangeWriter) -> Result<()> {
212 let finish = self.take_data_inner(writer, None).await?;
213 assert!(finish);
214 Ok(())
215 }
216
217 pub async fn direct_take_data(&mut self) -> SharedResult<Option<DataChunk>> {
219 Ok(self.receiver.recv().await?.map(|c| c.into_data_chunk()))
220 }
221
222 pub fn id(&self) -> &TaskOutputId {
223 &self.output_id
224 }
225}
226
/// Value carried by the shutdown watch channel.
#[derive(Clone, Debug)]
pub enum ShutdownMsg {
    /// Initial value of the channel: no shutdown has been requested.
    Init,
    /// Abort request carrying a message (sent by [`ShutdownSender::abort`]).
    Abort(String),
    /// Cancel request (sent by [`ShutdownSender::cancel`]).
    Cancel,
}
234
235pub struct ShutdownSender(tokio::sync::watch::Sender<ShutdownMsg>);
237
238impl ShutdownSender {
239 pub fn cancel(&self) -> bool {
241 self.0.send(ShutdownMsg::Cancel).is_ok()
242 }
243
244 pub fn abort(&self, msg: impl Into<String>) -> bool {
246 self.0.send(ShutdownMsg::Abort(msg.into())).is_ok()
247 }
248}
249
250#[derive(Clone)]
252pub struct ShutdownToken(tokio::sync::watch::Receiver<ShutdownMsg>);
253
254impl ShutdownToken {
255 pub fn empty() -> Self {
257 Self::new().1
258 }
259
260 pub fn new() -> (ShutdownSender, Self) {
262 let (tx, rx) = tokio::sync::watch::channel(ShutdownMsg::Init);
263 (ShutdownSender(tx), ShutdownToken(rx))
264 }
265
266 pub fn check(&self) -> Result<()> {
268 match &*self.0.borrow() {
269 ShutdownMsg::Init => Ok(()),
270 msg => bail!("Receive shutdown msg: {msg:?}"),
271 }
272 }
273
274 pub async fn cancelled(&mut self) {
279 if matches!(*self.0.borrow(), ShutdownMsg::Init)
280 && let Err(_err) = self.0.changed().await
281 {
282 std::future::pending::<()>().await;
283 }
284 }
285
286 pub fn is_cancelled(&self) -> bool {
288 !matches!(*self.0.borrow(), ShutdownMsg::Init)
289 }
290
291 pub fn message(&self) -> ShutdownMsg {
293 self.0.borrow().clone()
294 }
295}
296
/// State of one batch task: its plan, status, output channels and shutdown handles.
pub struct BatchTaskExecution {
    /// Identifier of this task.
    task_id: TaskId,

    /// Plan fragment whose root is turned into an executor in `async_execute`.
    plan: PlanFragment,

    /// Current status of the task; updated through `change_state`.
    state: Mutex<TaskStatus>,

    /// Receivers of the output channels; each slot is taken at most once by
    /// `get_task_output`.
    receivers: Mutex<Vec<Option<ChanReceiverImpl>>>,

    /// Sender of the output channel; cloned into the running task.
    sender: ChanSenderImpl,

    /// Task context, used here for configuration and metrics and passed to the executor builder.
    context: Arc<dyn BatchTaskContext>,

    /// Error recorded by `run` when it ends; shared with every `TaskOutput` handed out.
    failure: Arc<Mutex<Option<Arc<BatchError>>>>,

    /// Runtime on which the execution future is spawned.
    runtime: Arc<BackgroundShutdownRuntime>,

    /// Sending half of the shutdown channel, used by `abort` and `cancel`.
    shutdown_tx: ShutdownSender,
    /// Receiving half of the shutdown channel, cloned for the executors and for `run`.
    shutdown_rx: ShutdownToken,
    /// Join handle stored by `set_heartbeat_join_handle` and taken by
    /// `heartbeat_join_handle`.
    heartbeat_join_handle: Mutex<Option<JoinHandle<()>>>,
}
327
328impl BatchTaskExecution {
329 pub fn new(
330 prost_tid: &PbTaskId,
331 plan: PlanFragment,
332 context: Arc<dyn BatchTaskContext>,
333 runtime: Arc<BackgroundShutdownRuntime>,
334 ) -> Result<Self> {
335 let task_id = TaskId::from(prost_tid);
336
337 let (sender, receivers) = create_output_channel(
338 plan.get_exchange_info()?,
339 context.get_config().developer.output_channel_size,
340 )?;
341
342 let mut rts = Vec::new();
343 rts.extend(receivers.into_iter().map(Some));
344
345 let (shutdown_tx, shutdown_rx) = ShutdownToken::new();
346 Ok(Self {
347 task_id,
348 plan,
349 state: Mutex::new(TaskStatus::Pending),
350 receivers: Mutex::new(rts),
351 failure: Arc::new(Mutex::new(None)),
352 context,
353 runtime,
354 sender,
355 shutdown_tx,
356 shutdown_rx,
357 heartbeat_join_handle: Mutex::new(None),
358 })
359 }
360
    /// Returns the identifier of this task.
    pub fn get_task_id(&self) -> &TaskId {
        &self.task_id
    }
364
365 pub async fn async_execute(
372 self: Arc<Self>,
373 state_tx: Option<StateReporter>,
374 tracing_context: TracingContext,
375 expr_context: ExprContext,
376 ) -> Result<()> {
377 let mut state_tx = state_tx;
378 trace!(
379 "Prepare executing plan [{:?}]: {}",
380 self.task_id,
381 serde_json::to_string_pretty(self.plan.get_root()?).unwrap()
382 );
383
384 let exec = expr_context_scope(
385 expr_context.clone(),
386 ExecutorBuilder::new(
387 self.plan.root.as_ref().unwrap(),
388 &self.task_id,
389 self.context.clone(),
390 self.shutdown_rx.clone(),
391 )
392 .build(),
393 )
394 .await?;
395
396 let sender = self.sender.clone();
397 let _failure = self.failure.clone();
398 let task_id = self.task_id.clone();
399
400 self.change_state_notify(TaskStatus::Running, state_tx.as_mut(), None)
404 .await?;
405
406 let t_1 = self.clone();
408 let this = self.clone();
409 async fn notify_panic(
410 this: &BatchTaskExecution,
411 state_tx: Option<&mut StateReporter>,
412 message: Option<&str>,
413 ) {
414 let err_str = if let Some(message) = message {
415 format!("execution panic: {}", message)
416 } else {
417 "execution panic".into()
418 };
419
420 if let Err(e) = this
421 .change_state_notify(TaskStatus::Failed, state_tx, Some(err_str))
422 .await
423 {
424 warn!(
425 error = %e.as_report(),
426 "The status receiver in FE has closed so the status push is failed",
427 );
428 }
429 }
430 let fut = async move {
432 trace!("Executing plan [{:?}]", task_id);
433 let sender = sender;
434 let mut state_tx_1 = state_tx.clone();
435
436 let task = |task_id: TaskId| async move {
437 let span = tracing_context.attach(tracing::info_span!(
438 "batch_execute",
439 task_id = task_id.task_id,
440 stage_id = task_id.stage_id,
441 query_id = task_id.query_id,
442 ));
443
444 expr_context_scope(
447 expr_context,
448 t_1.run(exec, sender, state_tx_1.as_mut()).instrument(span),
449 )
450 .await;
451 };
452
453 if let Err(error) = AssertUnwindSafe(task(task_id.clone()))
454 .rw_catch_unwind()
455 .await
456 {
457 let message = panic_message::get_panic_message(&error);
458 error!(?task_id, error = message, "Batch task panic");
459 notify_panic(&this, state_tx.as_mut(), message).await;
460 }
461 };
462
463 self.runtime.spawn(fut);
464
465 Ok(())
466 }
467
468 pub async fn change_state_notify(
470 &self,
471 task_status: TaskStatus,
472 state_tx: Option<&mut StateReporter>,
473 err_str: Option<String>,
474 ) -> Result<()> {
475 self.change_state(task_status);
476 if let Some(reporter) = state_tx {
478 reporter
479 .send(TaskInfoResponse {
480 task_id: Some(self.task_id.to_prost()),
481 task_status: task_status.into(),
482 error_message: err_str.unwrap_or("".to_owned()),
483 })
484 .await
485 } else {
486 Ok(())
487 }
488 }
489
490 pub fn change_state(&self, task_status: TaskStatus) {
491 *self.state.lock() = task_status;
492 tracing::debug!(
493 "Task {:?} state changed to {:?}",
494 &self.task_id,
495 task_status
496 );
497 }
498
    /// Drives the root executor to completion, forwarding every produced chunk to `sender`,
    /// then closes the output channel and reports the final task status.
    ///
    /// The loop ends on the first of: a shutdown message, an executor error, a send failure,
    /// or the end of the chunk stream. Any resulting error is stored in `self.failure` and
    /// passed to `sender.close`. The `task_num` metric is incremented on entry and decremented
    /// on normal exit.
    async fn run(
        &self,
        root: BoxedExecutor,
        mut sender: ChanSenderImpl,
        state_tx: Option<&mut StateReporter>,
    ) {
        self.context
            .batch_metrics()
            .as_ref()
            .inspect(|m| m.batch_manager_metrics().task_num.inc());
        let mut data_chunk_stream = root.execute();
        // Every path out of the loop below assigns `state` before breaking.
        let mut state;
        let mut error = None;

        // `cancelled()` needs `&mut`, so wait on a private clone of the token.
        let mut shutdown_rx = self.shutdown_rx.clone();
        loop {
            select! {
                // Branches are polled in order: shutdown is checked before the next chunk.
                biased;
                _ = shutdown_rx.cancelled() => {
                    match self.shutdown_rx.message() {
                        ShutdownMsg::Abort(e) => {
                            error = Some(BatchError::Aborted(e));
                            state = TaskStatus::Aborted;
                            break;
                        }
                        ShutdownMsg::Cancel => {
                            state = TaskStatus::Cancelled;
                            break;
                        }
                        ShutdownMsg::Init => {
                            unreachable!("Init message should not be received here!")
                        }
                    }
                }
                data_chunk = data_chunk_stream.next()=> {
                    match data_chunk {
                        Some(Ok(data_chunk)) => {
                            if let Err(e) = sender.send(data_chunk).await {
                                match e {
                                    // A closed receiver is not treated as a failure: the task
                                    // simply stops and is marked finished.
                                    BatchError::SenderError => {
                                        warn!("Task receiver closed!");
                                        state = TaskStatus::Finished;
                                        break;
                                    }
                                    x => {
                                        error!("Failed to send data!");
                                        error = Some(x);
                                        state = TaskStatus::Failed;
                                        break;
                                    }
                                }
                            }
                        }
                        // An executor error is classified by the shutdown message current at
                        // that moment: a genuine failure, an abort, or a cancel.
                        Some(Err(e)) => match self.shutdown_rx.message() {
                            ShutdownMsg::Init => {
                                error!(error = %e.as_report(), "Batch task failed");
                                error = Some(e);
                                state = TaskStatus::Failed;
                                break;
                            }
                            ShutdownMsg::Abort(_) => {
                                error = Some(e);
                                state = TaskStatus::Aborted;
                                break;
                            }
                            ShutdownMsg::Cancel => {
                                state = TaskStatus::Cancelled;
                                break;
                            }
                        },
                        None => {
                            debug!("Batch task {:?} finished successfully.", self.task_id);
                            state = TaskStatus::Finished;
                            break;
                        }
                    }
                }
            }
        }

        // Record the error (or `None`) in the shared failure slot, then close the channel
        // with the same error.
        let error = error.map(Arc::new);
        self.failure.lock().clone_from(&error);
        let err_str = error.as_ref().map(|e| e.to_report_string());
        if let Err(e) = sender.close(error).await {
            match e {
                // Receiver already gone: keep the state decided above.
                SenderError => {
                    warn!("Task receiver closed when sending None!");
                }
                _x => {
                    error!("Failed to close task output channel: {:?}", self.task_id);
                    state = TaskStatus::Failed;
                }
            }
        }

        // Failing to deliver the final status is only logged.
        if let Err(e) = self.change_state_notify(state, state_tx, err_str).await {
            warn!(
                error = %e.as_report(),
                "The status receiver in FE has closed so the status push is failed",
            );
        }

        self.context
            .batch_metrics()
            .as_ref()
            .inspect(|m| m.batch_manager_metrics().task_num.dec());
    }
615
616 pub fn abort(&self, err_msg: String) {
617 if self.shutdown_tx.abort(err_msg) {
620 info!("Abort task {:?} done", self.task_id);
621 } else {
622 debug!("The task has already died before this request.")
623 }
624 }
625
626 pub fn cancel(&self) {
627 if !self.shutdown_tx.cancel() {
628 debug!("The task has already died before this request.");
629 }
630 }
631
632 pub fn get_task_output(&self, output_id: &PbTaskOutputId) -> Result<TaskOutput> {
633 let task_id = TaskId::from(output_id.get_task_id()?);
634 let receiver = self.receivers.lock()[output_id.get_output_id() as usize]
635 .take()
636 .with_context(|| {
637 format!(
638 "Task{:?}'s output{} has already been taken.",
639 task_id,
640 output_id.get_output_id(),
641 )
642 })?;
643 let task_output = TaskOutput {
644 receiver,
645 output_id: output_id.try_into()?,
646 failure: self.failure.clone(),
647 };
648 Ok(task_output)
649 }
650
651 pub fn check_if_running(&self) -> Result<()> {
652 if *self.state.lock() != TaskStatus::Running {
653 bail!("task {:?} is not running", self.get_task_id());
654 }
655 Ok(())
656 }
657
658 pub fn check_if_aborted(&self) -> Result<bool> {
659 match *self.state.lock() {
660 TaskStatus::Aborted => Ok(true),
661 TaskStatus::Finished => bail!("task {:?} has been finished", self.get_task_id()),
662 _ => Ok(false),
663 }
664 }
665
666 pub fn is_end(&self) -> bool {
668 let guard = self.state.lock();
669 !(*guard == TaskStatus::Running || *guard == TaskStatus::Pending)
670 }
671}
672
673impl BatchTaskExecution {
674 pub(crate) fn set_heartbeat_join_handle(&self, join_handle: JoinHandle<()>) {
675 *self.heartbeat_join_handle.lock() = Some(join_handle);
676 }
677
678 pub(crate) fn heartbeat_join_handle(&self) -> Option<JoinHandle<()>> {
679 self.heartbeat_join_handle.lock().take()
680 }
681}
682
#[cfg(test)]
mod tests {
    use super::*;

    /// The hand-written `Debug` impl must print a flat struct with the query id first.
    #[test]
    fn test_task_output_id_debug() {
        let task_output_id = TaskOutputId {
            task_id: TaskId {
                task_id: 1,
                stage_id: 2,
                query_id: "abc".to_owned(),
            },
            output_id: 3,
        };

        let rendered = format!("{:?}", task_output_id);
        let expected =
            "TaskOutputId { query_id: \"abc\", stage_id: 2, task_id: 1, output_id: 3 }";
        assert_eq!(rendered, expected);
    }
}
703}