risingwave_batch/task/
task_execution.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::{Debug, Formatter};
16use std::panic::AssertUnwindSafe;
17use std::sync::Arc;
18
19use anyhow::Context;
20use futures::StreamExt;
21use parking_lot::Mutex;
22use risingwave_common::array::DataChunk;
23use risingwave_common::util::panic::FutureCatchUnwindExt;
24use risingwave_common::util::runtime::BackgroundShutdownRuntime;
25use risingwave_common::util::tracing::TracingContext;
26use risingwave_expr::expr_context::expr_context_scope;
27use risingwave_pb::PbFieldNotFound;
28use risingwave_pb::batch_plan::{PbTaskId, PbTaskOutputId, PlanFragment};
29use risingwave_pb::plan_common::ExprContext;
30use risingwave_pb::task_service::task_info_response::TaskStatus;
31use risingwave_pb::task_service::{GetDataResponse, TaskInfoResponse};
32use thiserror_ext::AsReport;
33use tokio::select;
34use tokio::task::JoinHandle;
35use tracing::Instrument;
36
37use crate::error::BatchError::SenderError;
38use crate::error::{BatchError, Result, SharedResult};
39use crate::executor::{BoxedExecutor, ExecutorBuilder};
40use crate::rpc::service::exchange::ExchangeWriter;
41use crate::rpc::service::task_service::TaskInfoResponseResult;
42use crate::task::BatchTaskContext;
43use crate::task::channel::{ChanReceiverImpl, ChanSenderImpl, create_output_channel};
44
// Each status channel carries at most 2 updates: Running, then either Failed or Finished.
pub const TASK_STATUS_BUFFER_SIZE: usize = 2;
47
/// Sends batch task status (local/distributed) to the frontend.
///
/// Distributed mode uses `StateReporter::Distributed` to send status (Failed/Finished) updates,
/// while `StateReporter::Mock` is only used in tests and has no effect.
/// NOTE(review): the doc previously referenced a `StateReporter::Local` variant that does not
/// exist in this enum — confirm whether local mode still routes through this type.
#[derive(Clone)]
pub enum StateReporter {
    /// Streams status updates back to the frontend over an mpsc channel.
    Distributed(tokio::sync::mpsc::Sender<TaskInfoResponseResult>),
    /// No-op reporter used only in tests.
    Mock(),
}
60
61impl StateReporter {
62    pub async fn send(&mut self, val: TaskInfoResponse) -> Result<()> {
63        match self {
64            Self::Distributed(s) => s.send(Ok(val)).await.map_err(|_| SenderError),
65            Self::Mock() => Ok(()),
66        }
67    }
68
69    pub fn new_with_dist_sender(s: tokio::sync::mpsc::Sender<TaskInfoResponseResult>) -> Self {
70        Self::Distributed(s)
71    }
72
73    pub fn new_with_test() -> Self {
74        Self::Mock()
75    }
76}
77
/// Identifier of a batch task: `(query_id, stage_id, task_id)`.
#[derive(PartialEq, Eq, Hash, Clone, Debug, Default)]
pub struct TaskId {
    pub task_id: u64,
    pub stage_id: u32,
    pub query_id: String,
}

/// Identifier of a single output channel of a batch task.
#[derive(PartialEq, Eq, Hash, Clone, Default)]
pub struct TaskOutputId {
    pub task_id: TaskId,
    pub output_id: u64,
}

/// More compact formatter compared to derived `fmt::Debug`:
/// the nested `TaskId` is flattened onto one line.
impl Debug for TaskOutputId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "TaskOutputId {{ query_id: \"{}\", stage_id: {}, task_id: {}, output_id: {} }}",
            self.task_id.query_id, self.task_id.stage_id, self.task_id.task_id, self.output_id
        )
    }
}
100
101impl From<&PbTaskId> for TaskId {
102    fn from(prost: &PbTaskId) -> Self {
103        TaskId {
104            task_id: prost.task_id,
105            stage_id: prost.stage_id,
106            query_id: prost.query_id.clone(),
107        }
108    }
109}
110
111impl TaskId {
112    pub fn to_prost(&self) -> PbTaskId {
113        PbTaskId {
114            task_id: self.task_id,
115            stage_id: self.stage_id,
116            query_id: self.query_id.clone(),
117        }
118    }
119}
120
121impl TryFrom<&PbTaskOutputId> for TaskOutputId {
122    type Error = PbFieldNotFound;
123
124    fn try_from(prost: &PbTaskOutputId) -> std::result::Result<Self, PbFieldNotFound> {
125        Ok(TaskOutputId {
126            task_id: TaskId::from(prost.get_task_id()?),
127            output_id: prost.get_output_id(),
128        })
129    }
130}
131
132impl TaskOutputId {
133    pub fn to_prost(&self) -> PbTaskOutputId {
134        PbTaskOutputId {
135            task_id: Some(self.task_id.to_prost()),
136            output_id: self.output_id,
137        }
138    }
139}
140
/// One consumer-facing output channel of a batch task.
pub struct TaskOutput {
    /// Receiving end of the data channel written by the task's executors.
    receiver: ChanReceiverImpl,
    /// Identifier of this output within its task.
    output_id: TaskOutputId,
    /// Shared slot recording the task's execution failure, if any.
    failure: Arc<Mutex<Option<Arc<BatchError>>>>,
}
146
147impl std::fmt::Debug for TaskOutput {
148    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
149        f.debug_struct("TaskOutput")
150            .field("output_id", &self.output_id)
151            .field("failure", &self.failure)
152            .finish_non_exhaustive()
153    }
154}
155
impl TaskOutput {
    /// Write the data in serialized format to `ExchangeWriter`.
    /// Return whether the data stream is finished.
    ///
    /// `at_most_num = None` means unbounded: drain until EOF or error.
    /// With `Some(n)`, returns `Ok(false)` once `n` chunks have been forwarded
    /// (i.e. the stream is not yet finished); `cnt` only counts chunks that were
    /// successfully written, since EOF/error `break` before the increment.
    async fn take_data_inner(
        &mut self,
        writer: &mut impl ExchangeWriter,
        at_most_num: Option<usize>,
    ) -> Result<bool> {
        let mut cnt: usize = 0;
        let limited = at_most_num.is_some();
        let at_most_num = at_most_num.unwrap_or(usize::MAX);
        loop {
            if limited && cnt >= at_most_num {
                return Ok(false);
            }
            match self.receiver.recv().await {
                // Received some data
                Ok(Some(chunk)) => {
                    trace!(
                        "Task output id: {:?}, data len: {:?}",
                        self.output_id,
                        chunk.cardinality()
                    );
                    let pb = chunk.to_protobuf().await;
                    let resp = GetDataResponse {
                        record_batch: Some(pb),
                    };
                    writer.write(Ok(resp)).await?;
                }
                // Reached EOF
                Ok(None) => {
                    break;
                }
                // Error happened: forward it to the consumer as a gRPC status,
                // then treat the stream as finished.
                Err(e) => {
                    writer.write(Err(tonic::Status::from(&*e))).await?;
                    break;
                }
            }
            cnt += 1;
        }
        Ok(true)
    }

    /// Take at most num data and write the data in serialized format to `ExchangeWriter`.
    /// Return whether the data stream is finished.
    pub async fn take_data_with_num(
        &mut self,
        writer: &mut impl ExchangeWriter,
        num: usize,
    ) -> Result<bool> {
        self.take_data_inner(writer, Some(num)).await
    }

    /// Take all data and write the data in serialized format to `ExchangeWriter`.
    pub async fn take_data(&mut self, writer: &mut impl ExchangeWriter) -> Result<()> {
        let finish = self.take_data_inner(writer, None).await?;
        // Unbounded draining must run to EOF/error, which both report "finished".
        assert!(finish);
        Ok(())
    }

    /// Directly takes data without serialization. Returns `None` at end of stream.
    pub async fn direct_take_data(&mut self) -> SharedResult<Option<DataChunk>> {
        Ok(self.receiver.recv().await?.map(|c| c.into_data_chunk()))
    }

    /// Identifier of this output channel.
    pub fn id(&self) -> &TaskOutputId {
        &self.output_id
    }
}
226
/// Message carried on a task's shutdown watch channel.
#[derive(Clone, Debug)]
pub enum ShutdownMsg {
    /// Initial value at channel creation; it is never observed by receivers after
    /// a real shutdown signal has been sent.
    Init,
    /// Abort the task, carrying the reason as a message.
    Abort(String),
    /// Cancel the task.
    Cancel,
}
234
235/// A token which can be used to signal a shutdown request.
236pub struct ShutdownSender(tokio::sync::watch::Sender<ShutdownMsg>);
237
238impl ShutdownSender {
239    /// Send a cancel message. Return true if the message is sent successfully.
240    pub fn cancel(&self) -> bool {
241        self.0.send(ShutdownMsg::Cancel).is_ok()
242    }
243
244    /// Send an abort message. Return true if the message is sent successfully.
245    pub fn abort(&self, msg: impl Into<String>) -> bool {
246        self.0.send(ShutdownMsg::Abort(msg.into())).is_ok()
247    }
248}
249
/// A token which can be used to receive a shutdown signal.
#[derive(Clone)]
pub struct ShutdownToken(tokio::sync::watch::Receiver<ShutdownMsg>);

impl ShutdownToken {
    /// Create an empty token whose sender half is immediately dropped,
    /// so `cancelled` on it pends forever.
    pub fn empty() -> Self {
        Self::new().1
    }

    /// Create a new token.
    pub fn new() -> (ShutdownSender, Self) {
        let (tx, rx) = tokio::sync::watch::channel(ShutdownMsg::Init);
        (ShutdownSender(tx), ShutdownToken(rx))
    }

    /// Return error if the shutdown token has been triggered.
    pub fn check(&self) -> Result<()> {
        match &*self.0.borrow() {
            ShutdownMsg::Init => Ok(()),
            msg => bail!("Receive shutdown msg: {msg:?}"),
        }
    }

    /// Wait until cancellation is requested.
    ///
    /// # Cancel safety
    /// This method is cancel safe.
    pub async fn cancelled(&mut self) {
        // If the sender was dropped without ever sending a real signal,
        // pend forever rather than spuriously reporting cancellation.
        if matches!(*self.0.borrow(), ShutdownMsg::Init)
            && let Err(_err) = self.0.changed().await
        {
            std::future::pending::<()>().await;
        }
    }

    /// Return true if the shutdown token has been triggered.
    pub fn is_cancelled(&self) -> bool {
        !matches!(*self.0.borrow(), ShutdownMsg::Init)
    }

    /// Return the current shutdown message.
    pub fn message(&self) -> ShutdownMsg {
        self.0.borrow().clone()
    }
}
296
/// `BatchTaskExecution` represents a single task execution.
pub struct BatchTaskExecution {
    /// Task id.
    task_id: TaskId,

    /// Inner plan to execute.
    plan: PlanFragment,

    /// Task state (Pending/Running/Finished/Failed/Aborted/Cancelled).
    state: Mutex<TaskStatus>,

    /// Receiver ends of the output channels; each slot is `take`n at most once
    /// by `get_task_output`.
    receivers: Mutex<Vec<Option<ChanReceiverImpl>>>,

    /// Sender for sending chunks between different executors.
    sender: ChanSenderImpl,

    /// Context for task execution
    context: Arc<dyn BatchTaskContext>,

    /// The execution failure, shared with every `TaskOutput` handed out.
    failure: Arc<Mutex<Option<Arc<BatchError>>>>,

    /// Runtime for the batch tasks.
    runtime: Arc<BackgroundShutdownRuntime>,

    /// Sender/receiver pair used to abort or cancel this task.
    shutdown_tx: ShutdownSender,
    shutdown_rx: ShutdownToken,
    /// Handle of the heartbeat task reporting liveness to the frontend, if spawned.
    heartbeat_join_handle: Mutex<Option<JoinHandle<()>>>,
}
327
impl BatchTaskExecution {
    /// Build a task execution from its protobuf id and plan fragment.
    ///
    /// Creates the output channels according to the plan's exchange info; the task
    /// starts in `Pending` state and runs only once `async_execute` is called.
    pub fn new(
        prost_tid: &PbTaskId,
        plan: PlanFragment,
        context: Arc<dyn BatchTaskContext>,
        runtime: Arc<BackgroundShutdownRuntime>,
    ) -> Result<Self> {
        let task_id = TaskId::from(prost_tid);

        let (sender, receivers) = create_output_channel(
            plan.get_exchange_info()?,
            context.get_config().developer.output_channel_size,
        )?;

        // Wrap each receiver in `Some` so `get_task_output` can `take` each slot once.
        let mut rts = Vec::new();
        rts.extend(receivers.into_iter().map(Some));

        let (shutdown_tx, shutdown_rx) = ShutdownToken::new();
        Ok(Self {
            task_id,
            plan,
            state: Mutex::new(TaskStatus::Pending),
            receivers: Mutex::new(rts),
            failure: Arc::new(Mutex::new(None)),
            context,
            runtime,
            sender,
            shutdown_tx,
            shutdown_rx,
            heartbeat_join_handle: Mutex::new(None),
        })
    }

    /// Identifier of this task.
    pub fn get_task_id(&self) -> &TaskId {
        &self.task_id
    }

    /// `async_execute` executes the task in background, it spawns a tokio coroutine and returns
    /// immediately. The result produced by the task will be sent to one or more channels, according
    /// to a particular shuffling strategy. For example, in hash shuffling, the result will be
    /// hash partitioned across multiple channels.
    /// To obtain the result, one must pick one of the channels to consume via [`TaskOutputId`]. As
    /// such, parallel consumers are able to consume the result independently.
    pub async fn async_execute(
        self: Arc<Self>,
        state_tx: Option<StateReporter>,
        tracing_context: TracingContext,
        expr_context: ExprContext,
    ) -> Result<()> {
        let mut state_tx = state_tx;
        trace!(
            "Prepare executing plan [{:?}]: {}",
            self.task_id,
            serde_json::to_string_pretty(self.plan.get_root()?).unwrap()
        );

        // Build the executor tree before spawning so build errors are reported
        // synchronously to the caller.
        let exec = expr_context_scope(
            expr_context.clone(),
            ExecutorBuilder::new(
                self.plan.root.as_ref().unwrap(),
                &self.task_id,
                self.context.clone(),
                self.shutdown_rx.clone(),
            )
            .build(),
        )
        .await?;

        let sender = self.sender.clone();
        let _failure = self.failure.clone();
        let task_id = self.task_id.clone();

        // After we init the output receivers, it's safe to schedule the next stage -- able to send
        // TaskStatus::Running here.
        // Init the state receivers. Swap out later.
        self.change_state_notify(TaskStatus::Running, state_tx.as_mut(), None)
            .await?;

        // Clone `self` to make compiler happy because of the move block.
        let t_1 = self.clone();
        let this = self.clone();
        // Report `Failed` to the frontend when the execution future panicked.
        async fn notify_panic(
            this: &BatchTaskExecution,
            state_tx: Option<&mut StateReporter>,
            message: Option<&str>,
        ) {
            let err_str = if let Some(message) = message {
                format!("execution panic: {}", message)
            } else {
                "execution panic".into()
            };

            if let Err(e) = this
                .change_state_notify(TaskStatus::Failed, state_tx, Some(err_str))
                .await
            {
                warn!(
                    error = %e.as_report(),
                    "The status receiver in FE has closed so the status push is failed",
                );
            }
        }
        // Spawn task for real execution.
        let fut = async move {
            trace!("Executing plan [{:?}]", task_id);
            let sender = sender;
            let mut state_tx_1 = state_tx.clone();

            let task = |task_id: TaskId| async move {
                let span = tracing_context.attach(tracing::info_span!(
                    "batch_execute",
                    task_id = task_id.task_id,
                    stage_id = task_id.stage_id,
                    query_id = task_id.query_id,
                ));

                // We should only pass a reference of sender to execution because we should only
                // close it after task error has been set.
                expr_context_scope(
                    expr_context,
                    t_1.run(exec, sender, state_tx_1.as_mut()).instrument(span),
                )
                .await;
            };

            // Catch panics from the execution so a crashing task still reports `Failed`
            // instead of silently dying on the runtime.
            if let Err(error) = AssertUnwindSafe(task(task_id.clone()))
                .rw_catch_unwind()
                .await
            {
                let message = panic_message::get_panic_message(&error);
                error!(?task_id, error = message, "Batch task panic");
                notify_panic(&this, state_tx.as_mut(), message).await;
            }
        };

        self.runtime.spawn(fut);

        Ok(())
    }

    /// Change state and notify frontend for task status via streaming GRPC.
    pub async fn change_state_notify(
        &self,
        task_status: TaskStatus,
        state_tx: Option<&mut StateReporter>,
        err_str: Option<String>,
    ) -> Result<()> {
        self.change_state(task_status);
        // Notify frontend the task status.
        if let Some(reporter) = state_tx {
            reporter
                .send(TaskInfoResponse {
                    task_id: Some(self.task_id.to_prost()),
                    task_status: task_status.into(),
                    error_message: err_str.unwrap_or("".to_owned()),
                })
                .await
        } else {
            Ok(())
        }
    }

    /// Update the locally-tracked task state (no frontend notification).
    pub fn change_state(&self, task_status: TaskStatus) {
        *self.state.lock() = task_status;
        tracing::debug!(
            "Task {:?} state changed to {:?}",
            &self.task_id,
            task_status
        );
    }

    /// Drive the executor tree to completion, forwarding chunks to the output
    /// channels and reacting to shutdown signals; finally record the failure,
    /// close the channel and report the terminal state.
    async fn run(
        &self,
        root: BoxedExecutor,
        mut sender: ChanSenderImpl,
        state_tx: Option<&mut StateReporter>,
    ) {
        self.context
            .batch_metrics()
            .as_ref()
            .inspect(|m| m.batch_manager_metrics().task_num.inc());
        let mut data_chunk_stream = root.execute();
        let mut state;
        let mut error = None;

        let mut shutdown_rx = self.shutdown_rx.clone();
        loop {
            select! {
                // `biased` makes shutdown checked before the data stream on each turn.
                biased;
                // `shutdown_rx` can't be removed here to avoid `sender.send(data_chunk)` blocked whole execution.
                _ = shutdown_rx.cancelled() => {
                    match self.shutdown_rx.message() {
                        ShutdownMsg::Abort(e) => {
                            error = Some(BatchError::Aborted(e));
                            state = TaskStatus::Aborted;
                            break;
                        }
                        ShutdownMsg::Cancel => {
                            state = TaskStatus::Cancelled;
                            break;
                        }
                        ShutdownMsg::Init => {
                            unreachable!("Init message should not be received here!")
                        }
                    }
                }
                data_chunk = data_chunk_stream.next()=> {
                    match data_chunk {
                        Some(Ok(data_chunk)) => {
                            if let Err(e) = sender.send(data_chunk).await {
                                match e {
                                    BatchError::SenderError => {
                                        // This is possible since when we have limit executor in parent
                                        // stage, it may early stop receiving data from downstream, which
                                        // leads to close of channel.
                                        warn!("Task receiver closed!");
                                        state = TaskStatus::Finished;
                                        break;
                                    }
                                    x => {
                                        error!("Failed to send data!");
                                        error = Some(x);
                                        state = TaskStatus::Failed;
                                        break;
                                    }
                                }
                            }
                        }
                        // Executor error: interpret it differently depending on whether a
                        // shutdown signal has already arrived.
                        Some(Err(e)) => match self.shutdown_rx.message() {
                            ShutdownMsg::Init => {
                                // There is no message received from shutdown channel, which means
                                // the error itself caused the task to fail.
                                error!(error = %e.as_report(), "Batch task failed");
                                error = Some(e);
                                state = TaskStatus::Failed;
                                break;
                            }
                            ShutdownMsg::Abort(_) => {
                                error = Some(e);
                                state = TaskStatus::Aborted;
                                break;
                            }
                            ShutdownMsg::Cancel => {
                                state = TaskStatus::Cancelled;
                                break;
                            }
                        },
                        None => {
                            debug!("Batch task {:?} finished successfully.", self.task_id);
                            state = TaskStatus::Finished;
                            break;
                        }
                    }
                }
            }
        }

        // Record the failure for `TaskOutput` consumers before closing the channel.
        let error = error.map(Arc::new);
        self.failure.lock().clone_from(&error);
        let err_str = error.as_ref().map(|e| e.to_report_string());
        if let Err(e) = sender.close(error).await {
            match e {
                SenderError => {
                    // This is possible since when we have limit executor in parent
                    // stage, it may early stop receiving data from downstream, which
                    // leads to close of channel.
                    warn!("Task receiver closed when sending None!");
                }
                _x => {
                    error!("Failed to close task output channel: {:?}", self.task_id);
                    state = TaskStatus::Failed;
                }
            }
        }

        if let Err(e) = self.change_state_notify(state, state_tx, err_str).await {
            warn!(
                error = %e.as_report(),
                "The status receiver in FE has closed so the status push is failed",
            );
        }

        self.context
            .batch_metrics()
            .as_ref()
            .inspect(|m| m.batch_manager_metrics().task_num.dec());
    }

    /// Request the running task to abort with the given reason.
    pub fn abort(&self, err_msg: String) {
        // No need to set state to be Aborted here cuz it will be set by shutdown receiver.
        // Stop task execution.
        if self.shutdown_tx.abort(err_msg) {
            info!("Abort task {:?} done", self.task_id);
        } else {
            debug!("The task has already died before this request.")
        }
    }

    /// Request the running task to cancel.
    pub fn cancel(&self) {
        if !self.shutdown_tx.cancel() {
            debug!("The task has already died before this request.");
        }
    }

    /// Hand out the receiver for one output channel.
    ///
    /// Each output can only be taken once; a second request for the same output id fails.
    /// NOTE(review): an out-of-range `output_id` panics on the index below — presumably the
    /// caller guarantees it is within the plan's exchange fan-out; confirm.
    pub fn get_task_output(&self, output_id: &PbTaskOutputId) -> Result<TaskOutput> {
        let task_id = TaskId::from(output_id.get_task_id()?);
        let receiver = self.receivers.lock()[output_id.get_output_id() as usize]
            .take()
            .with_context(|| {
                format!(
                    "Task{:?}'s output{} has already been taken.",
                    task_id,
                    output_id.get_output_id(),
                )
            })?;
        let task_output = TaskOutput {
            receiver,
            output_id: output_id.try_into()?,
            failure: self.failure.clone(),
        };
        Ok(task_output)
    }

    /// Return error unless the task is currently in `Running` state.
    pub fn check_if_running(&self) -> Result<()> {
        if *self.state.lock() != TaskStatus::Running {
            bail!("task {:?} is not running", self.get_task_id());
        }
        Ok(())
    }

    /// Return whether the task was aborted; a `Finished` task is reported as an error.
    pub fn check_if_aborted(&self) -> Result<bool> {
        match *self.state.lock() {
            TaskStatus::Aborted => Ok(true),
            TaskStatus::Finished => bail!("task {:?} has been finished", self.get_task_id()),
            _ => Ok(false),
        }
    }

    /// Check the task status: whether has ended (i.e. neither running nor pending).
    pub fn is_end(&self) -> bool {
        let guard = self.state.lock();
        !(*guard == TaskStatus::Running || *guard == TaskStatus::Pending)
    }
}
672
673impl BatchTaskExecution {
674    pub(crate) fn set_heartbeat_join_handle(&self, join_handle: JoinHandle<()>) {
675        *self.heartbeat_join_handle.lock() = Some(join_handle);
676    }
677
678    pub(crate) fn heartbeat_join_handle(&self) -> Option<JoinHandle<()>> {
679        self.heartbeat_join_handle.lock().take()
680    }
681}
682
#[cfg(test)]
mod tests {
    use super::*;

    /// The custom `Debug` impl flattens the nested `TaskId` onto a single line.
    #[test]
    fn test_task_output_id_debug() {
        let output = TaskOutputId {
            task_id: TaskId {
                task_id: 1,
                stage_id: 2,
                query_id: "abc".to_owned(),
            },
            output_id: 3,
        };
        let rendered = format!("{:?}", output);
        assert_eq!(
            rendered,
            "TaskOutputId { query_id: \"abc\", stage_id: 2, task_id: 1, output_id: 3 }"
        );
    }
}
703}