risingwave_batch/task/
task_execution.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::{Debug, Formatter};
16use std::panic::AssertUnwindSafe;
17use std::sync::Arc;
18
19use anyhow::Context;
20use futures::StreamExt;
21use parking_lot::Mutex;
22use risingwave_common::array::DataChunk;
23use risingwave_common::util::panic::FutureCatchUnwindExt;
24use risingwave_common::util::runtime::BackgroundShutdownRuntime;
25use risingwave_common::util::tracing::TracingContext;
26use risingwave_expr::expr_context::expr_context_scope;
27use risingwave_pb::PbFieldNotFound;
28use risingwave_pb::batch_plan::{PbTaskId, PbTaskOutputId, PlanFragment};
29use risingwave_pb::common::BatchQueryEpoch;
30use risingwave_pb::plan_common::ExprContext;
31use risingwave_pb::task_service::task_info_response::TaskStatus;
32use risingwave_pb::task_service::{GetDataResponse, TaskInfoResponse};
33use thiserror_ext::AsReport;
34use tokio::select;
35use tokio::task::JoinHandle;
36use tracing::Instrument;
37
38use crate::error::BatchError::SenderError;
39use crate::error::{BatchError, Result, SharedResult};
40use crate::executor::{BoxedExecutor, ExecutorBuilder};
41use crate::rpc::service::exchange::ExchangeWriter;
42use crate::rpc::service::task_service::TaskInfoResponseResult;
43use crate::task::BatchTaskContext;
44use crate::task::channel::{ChanReceiverImpl, ChanSenderImpl, create_output_channel};
45
/// Buffer size of a task's status channel.
///
/// A task reports at most two statuses: `Running` once its executor has been built, followed by a
/// single terminal status (`Finished`, `Failed`, `Aborted` or `Cancelled`).
pub const TASK_STATUS_BUFFER_SIZE: usize = 2;
48
49/// Send batch task status (local/distributed) to frontend.
50///
51///
52/// Local mode use `StateReporter::Local`, Distributed mode use `StateReporter::Distributed` to send
53/// status (Failed/Finished) update. `StateReporter::Mock` is only used in test and do not takes any
54/// effect. Local sender only report Failed update, Distributed sender will also report
55/// Finished/Pending/Starting/Aborted etc.
56#[derive(Clone)]
57pub enum StateReporter {
58    Distributed(tokio::sync::mpsc::Sender<TaskInfoResponseResult>),
59    Mock(),
60}
61
62impl StateReporter {
63    pub async fn send(&mut self, val: TaskInfoResponse) -> Result<()> {
64        match self {
65            Self::Distributed(s) => s.send(Ok(val)).await.map_err(|_| SenderError),
66            Self::Mock() => Ok(()),
67        }
68    }
69
70    pub fn new_with_dist_sender(s: tokio::sync::mpsc::Sender<TaskInfoResponseResult>) -> Self {
71        Self::Distributed(s)
72    }
73
74    pub fn new_with_test() -> Self {
75        Self::Mock()
76    }
77}
78
/// Identifies a batch task by its task, stage and query ids.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct TaskId {
    pub task_id: u64,
    pub stage_id: u32,
    pub query_id: String,
}

/// Identifies one output of a batch task.
#[derive(Clone, Default, PartialEq, Eq, Hash)]
pub struct TaskOutputId {
    pub task_id: TaskId,
    pub output_id: u64,
}

/// Hand-written `Debug` that flattens the nested [`TaskId`] onto one level, which is more compact
/// than the derived output.
impl Debug for TaskOutputId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let TaskId {
            task_id,
            stage_id,
            query_id,
        } = &self.task_id;
        write!(
            f,
            "TaskOutputId {{ query_id: \"{}\", stage_id: {}, task_id: {}, output_id: {} }}",
            query_id, stage_id, task_id, self.output_id
        )
    }
}
101
102impl From<&PbTaskId> for TaskId {
103    fn from(prost: &PbTaskId) -> Self {
104        TaskId {
105            task_id: prost.task_id,
106            stage_id: prost.stage_id,
107            query_id: prost.query_id.clone(),
108        }
109    }
110}
111
112impl TaskId {
113    pub fn to_prost(&self) -> PbTaskId {
114        PbTaskId {
115            task_id: self.task_id,
116            stage_id: self.stage_id,
117            query_id: self.query_id.clone(),
118        }
119    }
120}
121
122impl TryFrom<&PbTaskOutputId> for TaskOutputId {
123    type Error = PbFieldNotFound;
124
125    fn try_from(prost: &PbTaskOutputId) -> std::result::Result<Self, PbFieldNotFound> {
126        Ok(TaskOutputId {
127            task_id: TaskId::from(prost.get_task_id()?),
128            output_id: prost.get_output_id(),
129        })
130    }
131}
132
133impl TaskOutputId {
134    pub fn to_prost(&self) -> PbTaskOutputId {
135        PbTaskOutputId {
136            task_id: Some(self.task_id.to_prost()),
137            output_id: self.output_id,
138        }
139    }
140}
141
142pub struct TaskOutput {
143    receiver: ChanReceiverImpl,
144    output_id: TaskOutputId,
145    failure: Arc<Mutex<Option<Arc<BatchError>>>>,
146}
147
148impl std::fmt::Debug for TaskOutput {
149    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
150        f.debug_struct("TaskOutput")
151            .field("output_id", &self.output_id)
152            .field("failure", &self.failure)
153            .finish_non_exhaustive()
154    }
155}
156
157impl TaskOutput {
158    /// Write the data in serialized format to `ExchangeWriter`.
159    /// Return whether the data stream is finished.
160    async fn take_data_inner(
161        &mut self,
162        writer: &mut impl ExchangeWriter,
163        at_most_num: Option<usize>,
164    ) -> Result<bool> {
165        let mut cnt: usize = 0;
166        let limited = at_most_num.is_some();
167        let at_most_num = at_most_num.unwrap_or(usize::MAX);
168        loop {
169            if limited && cnt >= at_most_num {
170                return Ok(false);
171            }
172            match self.receiver.recv().await {
173                // Received some data
174                Ok(Some(chunk)) => {
175                    trace!(
176                        "Task output id: {:?}, data len: {:?}",
177                        self.output_id,
178                        chunk.cardinality()
179                    );
180                    let pb = chunk.to_protobuf().await;
181                    let resp = GetDataResponse {
182                        record_batch: Some(pb),
183                    };
184                    writer.write(Ok(resp)).await?;
185                }
186                // Reached EOF
187                Ok(None) => {
188                    break;
189                }
190                // Error happened
191                Err(e) => {
192                    writer.write(Err(tonic::Status::from(&*e))).await?;
193                    break;
194                }
195            }
196            cnt += 1;
197        }
198        Ok(true)
199    }
200
201    /// Take at most num data and write the data in serialized format to `ExchangeWriter`.
202    /// Return whether the data stream is finished.
203    pub async fn take_data_with_num(
204        &mut self,
205        writer: &mut impl ExchangeWriter,
206        num: usize,
207    ) -> Result<bool> {
208        self.take_data_inner(writer, Some(num)).await
209    }
210
211    /// Take all data and write the data in serialized format to `ExchangeWriter`.
212    pub async fn take_data(&mut self, writer: &mut impl ExchangeWriter) -> Result<()> {
213        let finish = self.take_data_inner(writer, None).await?;
214        assert!(finish);
215        Ok(())
216    }
217
218    /// Directly takes data without serialization.
219    pub async fn direct_take_data(&mut self) -> SharedResult<Option<DataChunk>> {
220        Ok(self.receiver.recv().await?.map(|c| c.into_data_chunk()))
221    }
222
223    pub fn id(&self) -> &TaskOutputId {
224        &self.output_id
225    }
226}
227
/// Message carried by a task's shutdown channel.
#[derive(Debug, Clone)]
pub enum ShutdownMsg {
    /// Initial value of the channel. It is never sent afterwards, so observing it means no
    /// shutdown has been requested.
    Init,
    /// Abort the task, with the given reason.
    Abort(String),
    /// Cancel the task.
    Cancel,
}
235
236/// A token which can be used to signal a shutdown request.
237pub struct ShutdownSender(tokio::sync::watch::Sender<ShutdownMsg>);
238
239impl ShutdownSender {
240    /// Send a cancel message. Return true if the message is sent successfully.
241    pub fn cancel(&self) -> bool {
242        self.0.send(ShutdownMsg::Cancel).is_ok()
243    }
244
245    /// Send an abort message. Return true if the message is sent successfully.
246    pub fn abort(&self, msg: impl Into<String>) -> bool {
247        self.0.send(ShutdownMsg::Abort(msg.into())).is_ok()
248    }
249}
250
251/// A token which can be used to receive a shutdown signal.
252#[derive(Clone)]
253pub struct ShutdownToken(tokio::sync::watch::Receiver<ShutdownMsg>);
254
255impl ShutdownToken {
256    /// Create an empty token.
257    pub fn empty() -> Self {
258        Self::new().1
259    }
260
261    /// Create a new token.
262    pub fn new() -> (ShutdownSender, Self) {
263        let (tx, rx) = tokio::sync::watch::channel(ShutdownMsg::Init);
264        (ShutdownSender(tx), ShutdownToken(rx))
265    }
266
267    /// Return error if the shutdown token has been triggered.
268    pub fn check(&self) -> Result<()> {
269        match &*self.0.borrow() {
270            ShutdownMsg::Init => Ok(()),
271            msg => bail!("Receive shutdown msg: {msg:?}"),
272        }
273    }
274
275    /// Wait until cancellation is requested.
276    ///
277    /// # Cancel safety
278    /// This method is cancel safe.
279    pub async fn cancelled(&mut self) {
280        if matches!(*self.0.borrow(), ShutdownMsg::Init) {
281            if let Err(_err) = self.0.changed().await {
282                std::future::pending::<()>().await;
283            }
284        }
285    }
286
287    /// Return true if the shutdown token has been triggered.
288    pub fn is_cancelled(&self) -> bool {
289        !matches!(*self.0.borrow(), ShutdownMsg::Init)
290    }
291
292    /// Return the current shutdown message.
293    pub fn message(&self) -> ShutdownMsg {
294        self.0.borrow().clone()
295    }
296}
297
298/// `BatchTaskExecution` represents a single task execution.
299pub struct BatchTaskExecution {
300    /// Task id.
301    task_id: TaskId,
302
303    /// Inner plan to execute.
304    plan: PlanFragment,
305
306    /// Task state.
307    state: Mutex<TaskStatus>,
308
309    /// Receivers data of the task.
310    receivers: Mutex<Vec<Option<ChanReceiverImpl>>>,
311
312    /// Sender for sending chunks between different executors.
313    sender: ChanSenderImpl,
314
315    /// Context for task execution
316    context: Arc<dyn BatchTaskContext>,
317
318    /// The execution failure.
319    failure: Arc<Mutex<Option<Arc<BatchError>>>>,
320
321    epoch: BatchQueryEpoch,
322
323    /// Runtime for the batch tasks.
324    runtime: Arc<BackgroundShutdownRuntime>,
325
326    shutdown_tx: ShutdownSender,
327    shutdown_rx: ShutdownToken,
328    heartbeat_join_handle: Mutex<Option<JoinHandle<()>>>,
329}
330
331impl BatchTaskExecution {
332    pub fn new(
333        prost_tid: &PbTaskId,
334        plan: PlanFragment,
335        context: Arc<dyn BatchTaskContext>,
336        epoch: BatchQueryEpoch,
337        runtime: Arc<BackgroundShutdownRuntime>,
338    ) -> Result<Self> {
339        let task_id = TaskId::from(prost_tid);
340
341        let (sender, receivers) = create_output_channel(
342            plan.get_exchange_info()?,
343            context.get_config().developer.output_channel_size,
344        )?;
345
346        let mut rts = Vec::new();
347        rts.extend(receivers.into_iter().map(Some));
348
349        let (shutdown_tx, shutdown_rx) = ShutdownToken::new();
350        Ok(Self {
351            task_id,
352            plan,
353            state: Mutex::new(TaskStatus::Pending),
354            receivers: Mutex::new(rts),
355            failure: Arc::new(Mutex::new(None)),
356            epoch,
357            context,
358            runtime,
359            sender,
360            shutdown_tx,
361            shutdown_rx,
362            heartbeat_join_handle: Mutex::new(None),
363        })
364    }
365
366    pub fn get_task_id(&self) -> &TaskId {
367        &self.task_id
368    }
369
370    /// `async_execute` executes the task in background, it spawns a tokio coroutine and returns
371    /// immediately. The result produced by the task will be sent to one or more channels, according
372    /// to a particular shuffling strategy. For example, in hash shuffling, the result will be
373    /// hash partitioned across multiple channels.
374    /// To obtain the result, one must pick one of the channels to consume via [`TaskOutputId`]. As
375    /// such, parallel consumers are able to consume the result independently.
376    pub async fn async_execute(
377        self: Arc<Self>,
378        state_tx: Option<StateReporter>,
379        tracing_context: TracingContext,
380        expr_context: ExprContext,
381    ) -> Result<()> {
382        let mut state_tx = state_tx;
383        trace!(
384            "Prepare executing plan [{:?}]: {}",
385            self.task_id,
386            serde_json::to_string_pretty(self.plan.get_root()?).unwrap()
387        );
388
389        let exec = expr_context_scope(
390            expr_context.clone(),
391            ExecutorBuilder::new(
392                self.plan.root.as_ref().unwrap(),
393                &self.task_id,
394                self.context.clone(),
395                self.epoch,
396                self.shutdown_rx.clone(),
397            )
398            .build(),
399        )
400        .await?;
401
402        let sender = self.sender.clone();
403        let _failure = self.failure.clone();
404        let task_id = self.task_id.clone();
405
406        // After we init the output receivers, it's must safe to schedule next stage -- able to send
407        // TaskStatus::Running here.
408        // Init the state receivers. Swap out later.
409        self.change_state_notify(TaskStatus::Running, state_tx.as_mut(), None)
410            .await?;
411
412        // Clone `self` to make compiler happy because of the move block.
413        let t_1 = self.clone();
414        let this = self.clone();
415        async fn notify_panic(
416            this: &BatchTaskExecution,
417            state_tx: Option<&mut StateReporter>,
418            message: Option<&str>,
419        ) {
420            let err_str = if let Some(message) = message {
421                format!("execution panic: {}", message)
422            } else {
423                "execution panic".into()
424            };
425
426            if let Err(e) = this
427                .change_state_notify(TaskStatus::Failed, state_tx, Some(err_str))
428                .await
429            {
430                warn!(
431                    error = %e.as_report(),
432                    "The status receiver in FE has closed so the status push is failed",
433                );
434            }
435        }
436        // Spawn task for real execution.
437        let fut = async move {
438            trace!("Executing plan [{:?}]", task_id);
439            let sender = sender;
440            let mut state_tx_1 = state_tx.clone();
441
442            let task = |task_id: TaskId| async move {
443                let span = tracing_context.attach(tracing::info_span!(
444                    "batch_execute",
445                    task_id = task_id.task_id,
446                    stage_id = task_id.stage_id,
447                    query_id = task_id.query_id,
448                ));
449
450                // We should only pass a reference of sender to execution because we should only
451                // close it after task error has been set.
452                expr_context_scope(
453                    expr_context,
454                    t_1.run(exec, sender, state_tx_1.as_mut()).instrument(span),
455                )
456                .await;
457            };
458
459            if let Err(error) = AssertUnwindSafe(task(task_id.clone()))
460                .rw_catch_unwind()
461                .await
462            {
463                let message = panic_message::get_panic_message(&error);
464                error!(?task_id, error = message, "Batch task panic");
465                notify_panic(&this, state_tx.as_mut(), message).await;
466            }
467        };
468
469        self.runtime.spawn(fut);
470
471        Ok(())
472    }
473
474    /// Change state and notify frontend for task status via streaming GRPC.
475    pub async fn change_state_notify(
476        &self,
477        task_status: TaskStatus,
478        state_tx: Option<&mut StateReporter>,
479        err_str: Option<String>,
480    ) -> Result<()> {
481        self.change_state(task_status);
482        // Notify frontend the task status.
483        if let Some(reporter) = state_tx {
484            reporter
485                .send(TaskInfoResponse {
486                    task_id: Some(self.task_id.to_prost()),
487                    task_status: task_status.into(),
488                    error_message: err_str.unwrap_or("".to_owned()),
489                })
490                .await
491        } else {
492            Ok(())
493        }
494    }
495
496    pub fn change_state(&self, task_status: TaskStatus) {
497        *self.state.lock() = task_status;
498        tracing::debug!(
499            "Task {:?} state changed to {:?}",
500            &self.task_id,
501            task_status
502        );
503    }
504
505    async fn run(
506        &self,
507        root: BoxedExecutor,
508        mut sender: ChanSenderImpl,
509        state_tx: Option<&mut StateReporter>,
510    ) {
511        self.context
512            .batch_metrics()
513            .as_ref()
514            .inspect(|m| m.batch_manager_metrics().task_num.inc());
515        let mut data_chunk_stream = root.execute();
516        let mut state;
517        let mut error = None;
518
519        let mut shutdown_rx = self.shutdown_rx.clone();
520        loop {
521            select! {
522                biased;
523                // `shutdown_rx` can't be removed here to avoid `sender.send(data_chunk)` blocked whole execution.
524                _ = shutdown_rx.cancelled() => {
525                    match self.shutdown_rx.message() {
526                        ShutdownMsg::Abort(e) => {
527                            error = Some(BatchError::Aborted(e));
528                            state = TaskStatus::Aborted;
529                            break;
530                        }
531                        ShutdownMsg::Cancel => {
532                            state = TaskStatus::Cancelled;
533                            break;
534                        }
535                        ShutdownMsg::Init => {
536                            unreachable!("Init message should not be received here!")
537                        }
538                    }
539                }
540                data_chunk = data_chunk_stream.next()=> {
541                    match data_chunk {
542                        Some(Ok(data_chunk)) => {
543                            if let Err(e) = sender.send(data_chunk).await {
544                                match e {
545                                    BatchError::SenderError => {
546                                        // This is possible since when we have limit executor in parent
547                                        // stage, it may early stop receiving data from downstream, which
548                                        // leads to close of channel.
549                                        warn!("Task receiver closed!");
550                                        state = TaskStatus::Finished;
551                                        break;
552                                    }
553                                    x => {
554                                        error!("Failed to send data!");
555                                        error = Some(x);
556                                        state = TaskStatus::Failed;
557                                        break;
558                                    }
559                                }
560                            }
561                        }
562                        Some(Err(e)) => match self.shutdown_rx.message() {
563                            ShutdownMsg::Init => {
564                                // There is no message received from shutdown channel, which means it caused
565                                // task failed.
566                                error!(error = %e.as_report(), "Batch task failed");
567                                error = Some(e);
568                                state = TaskStatus::Failed;
569                                break;
570                            }
571                            ShutdownMsg::Abort(_) => {
572                                error = Some(e);
573                                state = TaskStatus::Aborted;
574                                break;
575                            }
576                            ShutdownMsg::Cancel => {
577                                state = TaskStatus::Cancelled;
578                                break;
579                            }
580                        },
581                        None => {
582                            debug!("Batch task {:?} finished successfully.", self.task_id);
583                            state = TaskStatus::Finished;
584                            break;
585                        }
586                    }
587                }
588            }
589        }
590
591        let error = error.map(Arc::new);
592        self.failure.lock().clone_from(&error);
593        let err_str = error.as_ref().map(|e| e.to_report_string());
594        if let Err(e) = sender.close(error).await {
595            match e {
596                SenderError => {
597                    // This is possible since when we have limit executor in parent
598                    // stage, it may early stop receiving data from downstream, which
599                    // leads to close of channel.
600                    warn!("Task receiver closed when sending None!");
601                }
602                _x => {
603                    error!("Failed to close task output channel: {:?}", self.task_id);
604                    state = TaskStatus::Failed;
605                }
606            }
607        }
608
609        if let Err(e) = self.change_state_notify(state, state_tx, err_str).await {
610            warn!(
611                error = %e.as_report(),
612                "The status receiver in FE has closed so the status push is failed",
613            );
614        }
615
616        self.context
617            .batch_metrics()
618            .as_ref()
619            .inspect(|m| m.batch_manager_metrics().task_num.dec());
620    }
621
622    pub fn abort(&self, err_msg: String) {
623        // No need to set state to be Aborted here cuz it will be set by shutdown receiver.
624        // Stop task execution.
625        if self.shutdown_tx.abort(err_msg) {
626            info!("Abort task {:?} done", self.task_id);
627        } else {
628            debug!("The task has already died before this request.")
629        }
630    }
631
632    pub fn cancel(&self) {
633        if !self.shutdown_tx.cancel() {
634            debug!("The task has already died before this request.");
635        }
636    }
637
638    pub fn get_task_output(&self, output_id: &PbTaskOutputId) -> Result<TaskOutput> {
639        let task_id = TaskId::from(output_id.get_task_id()?);
640        let receiver = self.receivers.lock()[output_id.get_output_id() as usize]
641            .take()
642            .with_context(|| {
643                format!(
644                    "Task{:?}'s output{} has already been taken.",
645                    task_id,
646                    output_id.get_output_id(),
647                )
648            })?;
649        let task_output = TaskOutput {
650            receiver,
651            output_id: output_id.try_into()?,
652            failure: self.failure.clone(),
653        };
654        Ok(task_output)
655    }
656
657    pub fn check_if_running(&self) -> Result<()> {
658        if *self.state.lock() != TaskStatus::Running {
659            bail!("task {:?} is not running", self.get_task_id());
660        }
661        Ok(())
662    }
663
664    pub fn check_if_aborted(&self) -> Result<bool> {
665        match *self.state.lock() {
666            TaskStatus::Aborted => Ok(true),
667            TaskStatus::Finished => bail!("task {:?} has been finished", self.get_task_id()),
668            _ => Ok(false),
669        }
670    }
671
672    /// Check the task status: whether has ended.
673    pub fn is_end(&self) -> bool {
674        let guard = self.state.lock();
675        !(*guard == TaskStatus::Running || *guard == TaskStatus::Pending)
676    }
677}
678
679impl BatchTaskExecution {
680    pub(crate) fn set_heartbeat_join_handle(&self, join_handle: JoinHandle<()>) {
681        *self.heartbeat_join_handle.lock() = Some(join_handle);
682    }
683
684    pub(crate) fn heartbeat_join_handle(&self) -> Option<JoinHandle<()>> {
685        self.heartbeat_join_handle.lock().take()
686    }
687}
688
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_output_id_debug() {
        let task_id = TaskId {
            task_id: 1,
            stage_id: 2,
            query_id: "abc".to_owned(),
        };
        let task_output_id = TaskOutputId {
            task_id,
            output_id: 3,
        };
        assert_eq!(
            format!("{:?}", task_output_id),
            "TaskOutputId { query_id: \"abc\", stage_id: 2, task_id: 1, output_id: 3 }"
        );
    }

    #[test]
    fn test_task_id_prost_roundtrip() {
        let task_id = TaskId {
            task_id: 7,
            stage_id: 8,
            query_id: "q".to_owned(),
        };
        assert_eq!(TaskId::from(&task_id.to_prost()), task_id);

        let output_id = TaskOutputId {
            task_id,
            output_id: 9,
        };
        let decoded = TaskOutputId::try_from(&output_id.to_prost()).unwrap();
        assert_eq!(decoded, output_id);

        // A message without a task id cannot be converted.
        let missing = PbTaskOutputId {
            task_id: None,
            output_id: 1,
        };
        assert!(TaskOutputId::try_from(&missing).is_err());
    }

    #[test]
    fn test_shutdown_token() {
        let (tx, token) = ShutdownToken::new();
        assert!(!token.is_cancelled());
        assert!(token.check().is_ok());
        assert!(matches!(token.message(), ShutdownMsg::Init));

        assert!(tx.abort("oops"));
        assert!(token.is_cancelled());
        assert!(token.check().is_err());
        assert!(matches!(token.message(), ShutdownMsg::Abort(msg) if msg == "oops"));

        // The sender of an empty token is already gone, so it is never triggered.
        assert!(!ShutdownToken::empty().is_cancelled());

        // Sending fails once every token has been dropped.
        let (tx, token) = ShutdownToken::new();
        drop(token);
        assert!(!tx.cancel());
    }
}