1use std::fmt::{Debug, Formatter};
16use std::panic::AssertUnwindSafe;
17use std::sync::Arc;
18
19use anyhow::Context;
20use futures::StreamExt;
21use parking_lot::Mutex;
22use risingwave_common::array::DataChunk;
23use risingwave_common::util::panic::FutureCatchUnwindExt;
24use risingwave_common::util::runtime::BackgroundShutdownRuntime;
25use risingwave_common::util::tracing::TracingContext;
26use risingwave_expr::expr_context::expr_context_scope;
27use risingwave_pb::PbFieldNotFound;
28use risingwave_pb::batch_plan::{PbTaskId, PbTaskOutputId, PlanFragment};
29use risingwave_pb::common::BatchQueryEpoch;
30use risingwave_pb::plan_common::ExprContext;
31use risingwave_pb::task_service::task_info_response::TaskStatus;
32use risingwave_pb::task_service::{GetDataResponse, TaskInfoResponse};
33use thiserror_ext::AsReport;
34use tokio::select;
35use tokio::task::JoinHandle;
36use tracing::Instrument;
37
38use crate::error::BatchError::SenderError;
39use crate::error::{BatchError, Result, SharedResult};
40use crate::executor::{BoxedExecutor, ExecutorBuilder};
41use crate::rpc::service::exchange::ExchangeWriter;
42use crate::rpc::service::task_service::TaskInfoResponseResult;
43use crate::task::BatchTaskContext;
44use crate::task::channel::{ChanReceiverImpl, ChanSenderImpl, create_output_channel};
45
/// Buffer size associated with task status reporting.
///
/// NOTE(review): the channel that uses this capacity is created outside this file. Within
/// this file a task pushes `Running` and later one terminal status, which is presumably why
/// a capacity of 2 is sufficient — confirm at the call site.
pub const TASK_STATUS_BUFFER_SIZE: usize = 2;
48
49#[derive(Clone)]
57pub enum StateReporter {
58 Distributed(tokio::sync::mpsc::Sender<TaskInfoResponseResult>),
59 Mock(),
60}
61
62impl StateReporter {
63 pub async fn send(&mut self, val: TaskInfoResponse) -> Result<()> {
64 match self {
65 Self::Distributed(s) => s.send(Ok(val)).await.map_err(|_| SenderError),
66 Self::Mock() => Ok(()),
67 }
68 }
69
70 pub fn new_with_dist_sender(s: tokio::sync::mpsc::Sender<TaskInfoResponseResult>) -> Self {
71 Self::Distributed(s)
72 }
73
74 pub fn new_with_test() -> Self {
75 Self::Mock()
76 }
77}
78
/// Identifies a single batch task by query, stage and task number.
///
/// Converted from / to `PbTaskId` field by field.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct TaskId {
    /// Task number.
    pub task_id: u64,
    /// Stage number.
    pub stage_id: u32,
    /// Identifier of the owning query.
    pub query_id: String,
}
85
86#[derive(PartialEq, Eq, Hash, Clone, Default)]
87pub struct TaskOutputId {
88 pub task_id: TaskId,
89 pub output_id: u64,
90}
91
92impl Debug for TaskOutputId {
94 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
95 f.write_fmt(format_args!(
96 "TaskOutputId {{ query_id: \"{}\", stage_id: {}, task_id: {}, output_id: {} }}",
97 self.task_id.query_id, self.task_id.stage_id, self.task_id.task_id, self.output_id
98 ))
99 }
100}
101
102impl From<&PbTaskId> for TaskId {
103 fn from(prost: &PbTaskId) -> Self {
104 TaskId {
105 task_id: prost.task_id,
106 stage_id: prost.stage_id,
107 query_id: prost.query_id.clone(),
108 }
109 }
110}
111
112impl TaskId {
113 pub fn to_prost(&self) -> PbTaskId {
114 PbTaskId {
115 task_id: self.task_id,
116 stage_id: self.stage_id,
117 query_id: self.query_id.clone(),
118 }
119 }
120}
121
122impl TryFrom<&PbTaskOutputId> for TaskOutputId {
123 type Error = PbFieldNotFound;
124
125 fn try_from(prost: &PbTaskOutputId) -> std::result::Result<Self, PbFieldNotFound> {
126 Ok(TaskOutputId {
127 task_id: TaskId::from(prost.get_task_id()?),
128 output_id: prost.get_output_id(),
129 })
130 }
131}
132
133impl TaskOutputId {
134 pub fn to_prost(&self) -> PbTaskOutputId {
135 PbTaskOutputId {
136 task_id: Some(self.task_id.to_prost()),
137 output_id: self.output_id,
138 }
139 }
140}
141
142pub struct TaskOutput {
143 receiver: ChanReceiverImpl,
144 output_id: TaskOutputId,
145 failure: Arc<Mutex<Option<Arc<BatchError>>>>,
146}
147
148impl std::fmt::Debug for TaskOutput {
149 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
150 f.debug_struct("TaskOutput")
151 .field("output_id", &self.output_id)
152 .field("failure", &self.failure)
153 .finish_non_exhaustive()
154 }
155}
156
157impl TaskOutput {
158 async fn take_data_inner(
161 &mut self,
162 writer: &mut impl ExchangeWriter,
163 at_most_num: Option<usize>,
164 ) -> Result<bool> {
165 let mut cnt: usize = 0;
166 let limited = at_most_num.is_some();
167 let at_most_num = at_most_num.unwrap_or(usize::MAX);
168 loop {
169 if limited && cnt >= at_most_num {
170 return Ok(false);
171 }
172 match self.receiver.recv().await {
173 Ok(Some(chunk)) => {
175 trace!(
176 "Task output id: {:?}, data len: {:?}",
177 self.output_id,
178 chunk.cardinality()
179 );
180 let pb = chunk.to_protobuf().await;
181 let resp = GetDataResponse {
182 record_batch: Some(pb),
183 };
184 writer.write(Ok(resp)).await?;
185 }
186 Ok(None) => {
188 break;
189 }
190 Err(e) => {
192 writer.write(Err(tonic::Status::from(&*e))).await?;
193 break;
194 }
195 }
196 cnt += 1;
197 }
198 Ok(true)
199 }
200
201 pub async fn take_data_with_num(
204 &mut self,
205 writer: &mut impl ExchangeWriter,
206 num: usize,
207 ) -> Result<bool> {
208 self.take_data_inner(writer, Some(num)).await
209 }
210
211 pub async fn take_data(&mut self, writer: &mut impl ExchangeWriter) -> Result<()> {
213 let finish = self.take_data_inner(writer, None).await?;
214 assert!(finish);
215 Ok(())
216 }
217
218 pub async fn direct_take_data(&mut self) -> SharedResult<Option<DataChunk>> {
220 Ok(self.receiver.recv().await?.map(|c| c.into_data_chunk()))
221 }
222
223 pub fn id(&self) -> &TaskOutputId {
224 &self.output_id
225 }
226}
227
/// Value carried by the shutdown watch channel.
#[derive(Debug, Clone)]
pub enum ShutdownMsg {
    /// Initial channel value: no shutdown has been requested.
    Init,
    /// The task is aborted; the string is the abort message.
    Abort(String),
    /// The task is cancelled.
    Cancel,
}
235
236pub struct ShutdownSender(tokio::sync::watch::Sender<ShutdownMsg>);
238
239impl ShutdownSender {
240 pub fn cancel(&self) -> bool {
242 self.0.send(ShutdownMsg::Cancel).is_ok()
243 }
244
245 pub fn abort(&self, msg: impl Into<String>) -> bool {
247 self.0.send(ShutdownMsg::Abort(msg.into())).is_ok()
248 }
249}
250
251#[derive(Clone)]
253pub struct ShutdownToken(tokio::sync::watch::Receiver<ShutdownMsg>);
254
255impl ShutdownToken {
256 pub fn empty() -> Self {
258 Self::new().1
259 }
260
261 pub fn new() -> (ShutdownSender, Self) {
263 let (tx, rx) = tokio::sync::watch::channel(ShutdownMsg::Init);
264 (ShutdownSender(tx), ShutdownToken(rx))
265 }
266
267 pub fn check(&self) -> Result<()> {
269 match &*self.0.borrow() {
270 ShutdownMsg::Init => Ok(()),
271 msg => bail!("Receive shutdown msg: {msg:?}"),
272 }
273 }
274
275 pub async fn cancelled(&mut self) {
280 if matches!(*self.0.borrow(), ShutdownMsg::Init) {
281 if let Err(_err) = self.0.changed().await {
282 std::future::pending::<()>().await;
283 }
284 }
285 }
286
287 pub fn is_cancelled(&self) -> bool {
289 !matches!(*self.0.borrow(), ShutdownMsg::Init)
290 }
291
292 pub fn message(&self) -> ShutdownMsg {
294 self.0.borrow().clone()
295 }
296}
297
298pub struct BatchTaskExecution {
300 task_id: TaskId,
302
303 plan: PlanFragment,
305
306 state: Mutex<TaskStatus>,
308
309 receivers: Mutex<Vec<Option<ChanReceiverImpl>>>,
311
312 sender: ChanSenderImpl,
314
315 context: Arc<dyn BatchTaskContext>,
317
318 failure: Arc<Mutex<Option<Arc<BatchError>>>>,
320
321 epoch: BatchQueryEpoch,
322
323 runtime: Arc<BackgroundShutdownRuntime>,
325
326 shutdown_tx: ShutdownSender,
327 shutdown_rx: ShutdownToken,
328 heartbeat_join_handle: Mutex<Option<JoinHandle<()>>>,
329}
330
331impl BatchTaskExecution {
332 pub fn new(
333 prost_tid: &PbTaskId,
334 plan: PlanFragment,
335 context: Arc<dyn BatchTaskContext>,
336 epoch: BatchQueryEpoch,
337 runtime: Arc<BackgroundShutdownRuntime>,
338 ) -> Result<Self> {
339 let task_id = TaskId::from(prost_tid);
340
341 let (sender, receivers) = create_output_channel(
342 plan.get_exchange_info()?,
343 context.get_config().developer.output_channel_size,
344 )?;
345
346 let mut rts = Vec::new();
347 rts.extend(receivers.into_iter().map(Some));
348
349 let (shutdown_tx, shutdown_rx) = ShutdownToken::new();
350 Ok(Self {
351 task_id,
352 plan,
353 state: Mutex::new(TaskStatus::Pending),
354 receivers: Mutex::new(rts),
355 failure: Arc::new(Mutex::new(None)),
356 epoch,
357 context,
358 runtime,
359 sender,
360 shutdown_tx,
361 shutdown_rx,
362 heartbeat_join_handle: Mutex::new(None),
363 })
364 }
365
366 pub fn get_task_id(&self) -> &TaskId {
367 &self.task_id
368 }
369
370 pub async fn async_execute(
377 self: Arc<Self>,
378 state_tx: Option<StateReporter>,
379 tracing_context: TracingContext,
380 expr_context: ExprContext,
381 ) -> Result<()> {
382 let mut state_tx = state_tx;
383 trace!(
384 "Prepare executing plan [{:?}]: {}",
385 self.task_id,
386 serde_json::to_string_pretty(self.plan.get_root()?).unwrap()
387 );
388
389 let exec = expr_context_scope(
390 expr_context.clone(),
391 ExecutorBuilder::new(
392 self.plan.root.as_ref().unwrap(),
393 &self.task_id,
394 self.context.clone(),
395 self.epoch,
396 self.shutdown_rx.clone(),
397 )
398 .build(),
399 )
400 .await?;
401
402 let sender = self.sender.clone();
403 let _failure = self.failure.clone();
404 let task_id = self.task_id.clone();
405
406 self.change_state_notify(TaskStatus::Running, state_tx.as_mut(), None)
410 .await?;
411
412 let t_1 = self.clone();
414 let this = self.clone();
415 async fn notify_panic(
416 this: &BatchTaskExecution,
417 state_tx: Option<&mut StateReporter>,
418 message: Option<&str>,
419 ) {
420 let err_str = if let Some(message) = message {
421 format!("execution panic: {}", message)
422 } else {
423 "execution panic".into()
424 };
425
426 if let Err(e) = this
427 .change_state_notify(TaskStatus::Failed, state_tx, Some(err_str))
428 .await
429 {
430 warn!(
431 error = %e.as_report(),
432 "The status receiver in FE has closed so the status push is failed",
433 );
434 }
435 }
436 let fut = async move {
438 trace!("Executing plan [{:?}]", task_id);
439 let sender = sender;
440 let mut state_tx_1 = state_tx.clone();
441
442 let task = |task_id: TaskId| async move {
443 let span = tracing_context.attach(tracing::info_span!(
444 "batch_execute",
445 task_id = task_id.task_id,
446 stage_id = task_id.stage_id,
447 query_id = task_id.query_id,
448 ));
449
450 expr_context_scope(
453 expr_context,
454 t_1.run(exec, sender, state_tx_1.as_mut()).instrument(span),
455 )
456 .await;
457 };
458
459 if let Err(error) = AssertUnwindSafe(task(task_id.clone()))
460 .rw_catch_unwind()
461 .await
462 {
463 let message = panic_message::get_panic_message(&error);
464 error!(?task_id, error = message, "Batch task panic");
465 notify_panic(&this, state_tx.as_mut(), message).await;
466 }
467 };
468
469 self.runtime.spawn(fut);
470
471 Ok(())
472 }
473
474 pub async fn change_state_notify(
476 &self,
477 task_status: TaskStatus,
478 state_tx: Option<&mut StateReporter>,
479 err_str: Option<String>,
480 ) -> Result<()> {
481 self.change_state(task_status);
482 if let Some(reporter) = state_tx {
484 reporter
485 .send(TaskInfoResponse {
486 task_id: Some(self.task_id.to_prost()),
487 task_status: task_status.into(),
488 error_message: err_str.unwrap_or("".to_owned()),
489 })
490 .await
491 } else {
492 Ok(())
493 }
494 }
495
496 pub fn change_state(&self, task_status: TaskStatus) {
497 *self.state.lock() = task_status;
498 tracing::debug!(
499 "Task {:?} state changed to {:?}",
500 &self.task_id,
501 task_status
502 );
503 }
504
505 async fn run(
506 &self,
507 root: BoxedExecutor,
508 mut sender: ChanSenderImpl,
509 state_tx: Option<&mut StateReporter>,
510 ) {
511 self.context
512 .batch_metrics()
513 .as_ref()
514 .inspect(|m| m.batch_manager_metrics().task_num.inc());
515 let mut data_chunk_stream = root.execute();
516 let mut state;
517 let mut error = None;
518
519 let mut shutdown_rx = self.shutdown_rx.clone();
520 loop {
521 select! {
522 biased;
523 _ = shutdown_rx.cancelled() => {
525 match self.shutdown_rx.message() {
526 ShutdownMsg::Abort(e) => {
527 error = Some(BatchError::Aborted(e));
528 state = TaskStatus::Aborted;
529 break;
530 }
531 ShutdownMsg::Cancel => {
532 state = TaskStatus::Cancelled;
533 break;
534 }
535 ShutdownMsg::Init => {
536 unreachable!("Init message should not be received here!")
537 }
538 }
539 }
540 data_chunk = data_chunk_stream.next()=> {
541 match data_chunk {
542 Some(Ok(data_chunk)) => {
543 if let Err(e) = sender.send(data_chunk).await {
544 match e {
545 BatchError::SenderError => {
546 warn!("Task receiver closed!");
550 state = TaskStatus::Finished;
551 break;
552 }
553 x => {
554 error!("Failed to send data!");
555 error = Some(x);
556 state = TaskStatus::Failed;
557 break;
558 }
559 }
560 }
561 }
562 Some(Err(e)) => match self.shutdown_rx.message() {
563 ShutdownMsg::Init => {
564 error!(error = %e.as_report(), "Batch task failed");
567 error = Some(e);
568 state = TaskStatus::Failed;
569 break;
570 }
571 ShutdownMsg::Abort(_) => {
572 error = Some(e);
573 state = TaskStatus::Aborted;
574 break;
575 }
576 ShutdownMsg::Cancel => {
577 state = TaskStatus::Cancelled;
578 break;
579 }
580 },
581 None => {
582 debug!("Batch task {:?} finished successfully.", self.task_id);
583 state = TaskStatus::Finished;
584 break;
585 }
586 }
587 }
588 }
589 }
590
591 let error = error.map(Arc::new);
592 self.failure.lock().clone_from(&error);
593 let err_str = error.as_ref().map(|e| e.to_report_string());
594 if let Err(e) = sender.close(error).await {
595 match e {
596 SenderError => {
597 warn!("Task receiver closed when sending None!");
601 }
602 _x => {
603 error!("Failed to close task output channel: {:?}", self.task_id);
604 state = TaskStatus::Failed;
605 }
606 }
607 }
608
609 if let Err(e) = self.change_state_notify(state, state_tx, err_str).await {
610 warn!(
611 error = %e.as_report(),
612 "The status receiver in FE has closed so the status push is failed",
613 );
614 }
615
616 self.context
617 .batch_metrics()
618 .as_ref()
619 .inspect(|m| m.batch_manager_metrics().task_num.dec());
620 }
621
622 pub fn abort(&self, err_msg: String) {
623 if self.shutdown_tx.abort(err_msg) {
626 info!("Abort task {:?} done", self.task_id);
627 } else {
628 debug!("The task has already died before this request.")
629 }
630 }
631
632 pub fn cancel(&self) {
633 if !self.shutdown_tx.cancel() {
634 debug!("The task has already died before this request.");
635 }
636 }
637
638 pub fn get_task_output(&self, output_id: &PbTaskOutputId) -> Result<TaskOutput> {
639 let task_id = TaskId::from(output_id.get_task_id()?);
640 let receiver = self.receivers.lock()[output_id.get_output_id() as usize]
641 .take()
642 .with_context(|| {
643 format!(
644 "Task{:?}'s output{} has already been taken.",
645 task_id,
646 output_id.get_output_id(),
647 )
648 })?;
649 let task_output = TaskOutput {
650 receiver,
651 output_id: output_id.try_into()?,
652 failure: self.failure.clone(),
653 };
654 Ok(task_output)
655 }
656
657 pub fn check_if_running(&self) -> Result<()> {
658 if *self.state.lock() != TaskStatus::Running {
659 bail!("task {:?} is not running", self.get_task_id());
660 }
661 Ok(())
662 }
663
664 pub fn check_if_aborted(&self) -> Result<bool> {
665 match *self.state.lock() {
666 TaskStatus::Aborted => Ok(true),
667 TaskStatus::Finished => bail!("task {:?} has been finished", self.get_task_id()),
668 _ => Ok(false),
669 }
670 }
671
672 pub fn is_end(&self) -> bool {
674 let guard = self.state.lock();
675 !(*guard == TaskStatus::Running || *guard == TaskStatus::Pending)
676 }
677}
678
679impl BatchTaskExecution {
680 pub(crate) fn set_heartbeat_join_handle(&self, join_handle: JoinHandle<()>) {
681 *self.heartbeat_join_handle.lock() = Some(join_handle);
682 }
683
684 pub(crate) fn heartbeat_join_handle(&self) -> Option<JoinHandle<()>> {
685 self.heartbeat_join_handle.lock().take()
686 }
687}
688
#[cfg(test)]
mod tests {
    use super::*;

    /// The hand-written `Debug` output of `TaskOutputId` lists all ids on one level.
    #[test]
    fn test_task_output_id_debug() {
        let task_output_id = TaskOutputId {
            task_id: TaskId {
                task_id: 1,
                stage_id: 2,
                query_id: "abc".to_owned(),
            },
            output_id: 3,
        };
        let rendered = format!("{:?}", task_output_id);
        assert_eq!(
            rendered,
            "TaskOutputId { query_id: \"abc\", stage_id: 2, task_id: 1, output_id: 3 }"
        );
    }
}