risingwave_frontend/scheduler/distributed/stage.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
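
//! Distributed scheduling and execution of a single stage of a batch query.
//!
//! [`StageExecution`] tracks the stage's state machine, while [`StageRunner`]
//! creates tasks on compute nodes (or runs the root stage locally on the
//! frontend) and relays status changes to the query runner.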

use std::assert_matches::assert_matches;
use std::collections::HashMap;
use std::mem;
use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;

use StageEvent::Failed;
use anyhow::anyhow;
use arc_swap::ArcSwap;
use futures::stream::Fuse;
use futures::{StreamExt, TryStreamExt, stream};
use futures_async_stream::for_await;
use itertools::Itertools;
use risingwave_batch::error::BatchError;
use risingwave_batch::executor::ExecutorBuilder;
use risingwave_batch::task::{ShutdownMsg, ShutdownSender, ShutdownToken, TaskId as TaskIdBatch};
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::array::DataChunk;
use risingwave_common::hash::WorkerSlotMapping;
use risingwave_common::util::addr::HostAddr;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_connector::source::SplitMetaData;
use risingwave_expr::expr_context::expr_context_scope;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::{
    DistributedLookupJoinNode, ExchangeNode, ExchangeSource, MergeSortExchangeNode, PlanFragment,
    PlanNode as PbPlanNode, PlanNode, TaskId as PbTaskId, TaskOutputId,
};
use risingwave_pb::common::{HostAddress, WorkerNode};
use risingwave_pb::plan_common::ExprContext;
use risingwave_pb::task_service::{CancelTaskRequest, TaskInfoResponse};
use risingwave_rpc_client::ComputeClientPoolRef;
use risingwave_rpc_client::error::RpcError;
use rw_futures_util::select_all;
use thiserror_ext::AsReport;
use tokio::spawn;
use tokio::sync::RwLock;
use tokio::sync::mpsc::{Receiver, Sender};
use tonic::Streaming;
use tracing::{Instrument, debug, error, warn};

use crate::catalog::catalog_service::CatalogReader;
use crate::catalog::{FragmentId, TableId};
use crate::optimizer::plan_node::BatchPlanNodeType;
use crate::scheduler::SchedulerError::{TaskExecutionError, TaskRunningOutOfMemory};
use crate::scheduler::distributed::QueryMessage;
use crate::scheduler::distributed::stage::StageState::Pending;
use crate::scheduler::plan_fragmenter::{
    ExecutionPlanNode, PartitionInfo, Query, ROOT_TASK_ID, StageId, TaskId,
};
use crate::scheduler::{ExecutionContextRef, SchedulerError, SchedulerResult};

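/// Maximum number of task-creation RPCs issued concurrently when scheduling
/// the tasks of a stage (see the `buffer_unordered` call in `schedule_tasks`).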
const TASK_SCHEDULING_PARALLELISM: usize = 10;

#[derive(Debug)]
enum StageState {
    /// We put `msg_sender` in the `Pending` state to avoid holding it in `StageExecution`. This
    /// way, it can be moved into `StageRunner` efficiently instead of being cloned, and it
    /// ensures the sender gets dropped once it is used up, preventing issues caused by an
    /// unnecessarily long lifetime.
    Pending {
        msg_sender: Sender<QueryMessage>,
    },
    Started,
    Running,
    Completed,
    Failed,
}

#[derive(Debug)]
pub enum StageEvent {
    Scheduled(StageId),
    ScheduledRoot(Receiver<SchedulerResult<DataChunk>>),
    /// Stage failed.
    Failed {
        id: StageId,
        reason: SchedulerError,
    },
    /// All tasks in the stage finished.
    Completed(#[allow(dead_code)] StageId),
}

#[derive(Clone)]
pub struct TaskStatus {
    _task_id: TaskId,

    // `None` before the task is scheduled.
    location: Option<HostAddress>,
}

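/// Lock-free holder of the latest [`TaskStatus`]: `ArcSwap` lets `schedule_task`
/// publish a task's location while readers load it without locking.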
struct TaskStatusHolder {
    inner: ArcSwap<TaskStatus>,
}

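/// Handle to a scheduled stage, shared between the query runner and the
/// spawned [`StageRunner`].
///
/// A minimal lifecycle sketch, assuming the surrounding scheduler wiring
/// (`query`, `msg_sender`, `children`, etc.) is already set up; it is not a
/// runnable doctest:
///
/// ```ignore
/// let stage = Arc::new(StageExecution::new(
///     stage_id, query, worker_node_manager, msg_sender,
///     children, compute_client_pool, catalog_reader, ctx,
/// ));
/// stage.start().await;  // spawns a `StageRunner` on the compute runtime
/// // ... the runner reports Scheduled/Completed/Failed via `msg_sender` ...
/// stage.stop(None).await;  // cancels the runner and all scheduled tasks
/// ```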
pub struct StageExecution {
    stage_id: StageId,
    query: Arc<Query>,
    worker_node_manager: WorkerNodeSelector,
    tasks: Arc<HashMap<TaskId, TaskStatusHolder>>,
    state: Arc<RwLock<StageState>>,
    shutdown_tx: RwLock<Option<ShutdownSender>>,
    /// Children stage executions.
    ///
    /// We use `Vec` here since the number of children is usually small.
    children: Vec<Arc<StageExecution>>,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,

    /// Reference to the execution context.
    ctx: ExecutionContextRef,
}

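/// Background task spawned by [`StageExecution::start`]; owns the scheduling
/// loop for one stage and reports state changes through `msg_sender`.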
struct StageRunner {
    state: Arc<RwLock<StageState>>,
    stage_id: StageId,
    query: Arc<Query>,
    worker_node_manager: WorkerNodeSelector,
    tasks: Arc<HashMap<TaskId, TaskStatusHolder>>,
    // Sends messages to the `QueryRunner` to notify stage state changes.
    msg_sender: Sender<QueryMessage>,
    children: Vec<Arc<StageExecution>>,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,

    ctx: ExecutionContextRef,
}

impl TaskStatusHolder {
    fn new(task_id: TaskId) -> Self {
        let task_status = TaskStatus {
            _task_id: task_id,
            location: None,
        };

        Self {
            inner: ArcSwap::new(Arc::new(task_status)),
        }
    }

    fn get_status(&self) -> Arc<TaskStatus> {
        self.inner.load_full()
    }
}

impl StageExecution {
    pub fn new(
        stage_id: StageId,
        query: Arc<Query>,
        worker_node_manager: WorkerNodeSelector,
        msg_sender: Sender<QueryMessage>,
        children: Vec<Arc<StageExecution>>,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
        ctx: ExecutionContextRef,
    ) -> Self {
        let tasks = (0..query.stage(stage_id).parallelism.unwrap())
            .map(|task_id| (task_id as u64, TaskStatusHolder::new(task_id as u64)))
            .collect();

        Self {
            stage_id,
            query,
            worker_node_manager,
            tasks: Arc::new(tasks),
            state: Arc::new(RwLock::new(Pending { msg_sender })),
            shutdown_tx: RwLock::new(None),
            children,
            compute_client_pool,
            catalog_reader,
            ctx,
        }
    }

    /// Starts execution of this stage. Panics if the stage has already been
    /// started (a stage must be scheduled exactly once).
    pub async fn start(&self) {
        let mut s = self.state.write().await;
        let cur_state = mem::replace(&mut *s, StageState::Failed);
        match cur_state {
            Pending { msg_sender } => {
                let runner = StageRunner {
                    stage_id: self.stage_id,
                    query: self.query.clone(),
                    worker_node_manager: self.worker_node_manager.clone(),
                    tasks: self.tasks.clone(),
                    msg_sender,
                    children: self.children.clone(),
                    state: self.state.clone(),
                    compute_client_pool: self.compute_client_pool.clone(),
                    catalog_reader: self.catalog_reader.clone(),
                    ctx: self.ctx.clone(),
                };

                // The channel used for shutdown signal messaging.
                let (sender, receiver) = ShutdownToken::new();
                // Fill the shutdown sender.
                let mut holder = self.shutdown_tx.write().await;
                *holder = Some(sender);

                // Change the state before spawning the runner.
                *s = StageState::Started;

                let span = tracing::info_span!(
                    "stage",
                    "otel.name" = format!("Stage {}-{}", self.query.query_id.id, self.stage_id),
                    query_id = self.query.query_id.id,
                    stage_id = %self.stage_id,
                );
                self.ctx
                    .session()
                    .env()
                    .compute_runtime()
                    .spawn(async move { runner.run(receiver).instrument(span).await });

                tracing::trace!(
                    "Stage {:?}-{:?} started.",
                    self.query.query_id.id,
                    self.stage_id
                )
            }
            _ => {
                unreachable!("Only expect to schedule stage once");
            }
        }
    }

    pub async fn stop(&self, error: Option<String>) {
        // Send a message to tell the stage runner to stop. It's possible that the stage has
        // not been scheduled yet, in which case the channel sender is `None`.
        if let Some(shutdown_tx) = self.shutdown_tx.write().await.take() {
            let sent = if let Some(error) = error {
                shutdown_tx.abort(error)
            } else {
                shutdown_tx.cancel()
            };
            if !sent {
                // The stage runner handle has already been closed, so this is a no-op.
                tracing::trace!(
                    "Failed to send stop message to stage: {:?}-{:?}",
                    self.query.query_id,
                    self.stage_id
                );
            }
        }
    }

    pub async fn is_scheduled(&self) -> bool {
        let s = self.state.read().await;
        matches!(*s, StageState::Running | StageState::Completed)
    }

    pub async fn is_pending(&self) -> bool {
        let s = self.state.read().await;
        matches!(*s, StageState::Pending { .. })
    }

    pub async fn state(&self) -> &'static str {
        let s = self.state.read().await;
        match *s {
            Pending { .. } => "Pending",
            StageState::Started => "Started",
            StageState::Running => "Running",
            StageState::Completed => "Completed",
            StageState::Failed => "Failed",
        }
    }

    /// Returns all exchange sources for `output_id`. Each `ExchangeSource` is identified by the
    /// producer's `TaskId` and `output_id` (the consumer's `TaskId`), since each task may produce
    /// output to several channels.
    ///
    /// When this method is called, all tasks should have been scheduled, and their `worker_node`
    /// should have been set.
    pub fn all_exchange_sources_for(&self, output_id: u64) -> Vec<ExchangeSource> {
        self.tasks
            .iter()
            .map(|(task_id, status_holder)| {
                let task_output_id = TaskOutputId {
                    task_id: Some(PbTaskId {
                        query_id: self.query.query_id.id.clone(),
                        stage_id: self.stage_id.into(),
                        task_id: *task_id,
                    }),
                    output_id,
                };

                ExchangeSource {
                    task_output_id: Some(task_output_id),
                    host: Some(status_holder.inner.load_full().location.clone().unwrap()),
                    local_execute_plan: None,
                }
            })
            .collect()
    }
}

impl StageRunner {
    async fn run(mut self, shutdown_rx: ShutdownToken) {
        if let Err(e) = self.schedule_tasks_for_all(shutdown_rx).await {
            error!(
                error = %e.as_report(),
                query_id = ?self.query.query_id,
                stage_id = ?self.stage_id,
                "Failed to schedule tasks"
            );
            self.send_event(QueryMessage::Stage(Failed {
                id: self.stage_id,
                reason: e,
            }))
            .await;
        }
    }

    /// Sends a stage event to the listener.
    async fn send_event(&self, event: QueryMessage) {
        if let Err(_e) = self.msg_sender.send(event).await {
            warn!("Failed to send event to the query runner; it may have been dropped by a previous failure event");
        }
    }

    /// Schedules all tasks to compute nodes and processes the status messages from their RPC
    /// streams. Note that once all tasks are running, the `QueryRunner` is notified so that it
    /// can schedule the next stage.
    async fn schedule_tasks(
        &mut self,
        mut shutdown_rx: ShutdownToken,
        expr_context: ExprContext,
    ) -> SchedulerResult<()> {
        let mut futures = vec![];
        let stage = &self.query.stage(self.stage_id);

        if let Some(table_scan_info) = stage.table_scan_info.as_ref()
            && let Some(vnode_bitmaps) = table_scan_info.partitions()
        {
            // If the stage has table scan nodes, we create tasks according to the data
            // distribution and partitioning of the table. Each task reads one partition,
            // configured via the `vnode_ranges` of the scan node in the task, and is scheduled
            // to the worker node that owns that partition.
            let worker_slot_ids = vnode_bitmaps.keys().cloned().collect_vec();
            let workers = self
                .worker_node_manager
                .manager
                .get_workers_by_worker_slot_ids(&worker_slot_ids)?;

            for (i, (worker_slot_id, worker)) in worker_slot_ids
                .into_iter()
                .zip_eq_fast(workers.into_iter())
                .enumerate()
            {
                let task_id = PbTaskId {
                    query_id: self.query.query_id.id.clone(),
                    stage_id: self.stage_id.into(),
                    task_id: i as u64,
                };
                let vnode_ranges = vnode_bitmaps[&worker_slot_id].clone();
                let plan_fragment =
                    self.create_plan_fragment(i as u64, Some(PartitionInfo::Table(vnode_ranges)));
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    Some(worker),
                    expr_context.clone(),
                ));
            }
        } else if let Some(source_info) = stage.source_info.as_ref() {
            // Distribute the splits evenly across tasks; `.max(1)` keeps the chunk size at
            // least 1 when the source has no splits.
            let chunk_size = ((source_info.split_info().unwrap().len() as f32
                / stage.parallelism.unwrap() as f32)
                .ceil() as usize)
                .max(1);
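            // e.g. 10 splits at parallelism 4 => chunk_size = ceil(10 / 4) = 3,
            // yielding chunks of sizes 3, 3, 3 and 1 (4 tasks in total).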
            if source_info.split_info().unwrap().is_empty() {
                // No file in the source; schedule an empty task.
                const EMPTY_TASK_ID: u64 = 0;
                let task_id = PbTaskId {
                    query_id: self.query.query_id.id.clone(),
                    stage_id: self.stage_id.into(),
                    task_id: EMPTY_TASK_ID,
                };
                let plan_fragment =
                    self.create_plan_fragment(EMPTY_TASK_ID, Some(PartitionInfo::Source(vec![])));
                let worker =
                    self.choose_worker(&plan_fragment, EMPTY_TASK_ID as u32, stage.dml_table_id)?;
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    worker,
                    expr_context.clone(),
                ));
            } else {
                for (id, split) in source_info
                    .split_info()
                    .unwrap()
                    .chunks(chunk_size)
                    .enumerate()
                {
                    let task_id = PbTaskId {
                        query_id: self.query.query_id.id.clone(),
                        stage_id: self.stage_id.into(),
                        task_id: id as u64,
                    };
                    let plan_fragment = self.create_plan_fragment(
                        id as u64,
                        Some(PartitionInfo::Source(split.to_vec())),
                    );
                    let worker =
                        self.choose_worker(&plan_fragment, id as u32, stage.dml_table_id)?;
                    futures.push(self.schedule_task(
                        task_id,
                        plan_fragment,
                        worker,
                        expr_context.clone(),
                    ));
                }
            }
        } else if let Some(file_scan_info) = stage.file_scan_info.as_ref() {
            let chunk_size = (file_scan_info.file_location.len() as f32
                / stage.parallelism.unwrap() as f32)
                .ceil() as usize;
            for (id, files) in file_scan_info.file_location.chunks(chunk_size).enumerate() {
                let task_id = PbTaskId {
                    query_id: self.query.query_id.id.clone(),
                    stage_id: self.stage_id.into(),
                    task_id: id as u64,
                };
                let plan_fragment =
                    self.create_plan_fragment(id as u64, Some(PartitionInfo::File(files.to_vec())));
                let worker = self.choose_worker(&plan_fragment, id as u32, stage.dml_table_id)?;
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    worker,
                    expr_context.clone(),
                ));
            }
        } else {
            for id in 0..stage.parallelism.unwrap() {
                let task_id = PbTaskId {
                    query_id: self.query.query_id.id.clone(),
                    stage_id: self.stage_id.into(),
                    task_id: id as u64,
                };
                let plan_fragment = self.create_plan_fragment(id as u64, None);
                let worker = self.choose_worker(&plan_fragment, id, stage.dml_table_id)?;
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    worker,
                    expr_context.clone(),
                ));
            }
        }

        // Await each future and convert them into a set of streams.
        let buffered = stream::iter(futures).buffer_unordered(TASK_SCHEDULING_PARALLELISM);
        let buffered_streams = buffered.try_collect::<Vec<_>>().await?;

        // Merge the different task streams into a single stream.
        let cancelled = pin!(shutdown_rx.cancelled());
        let mut all_streams = select_all(buffered_streams).take_until(cancelled);
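        // `take_until` ends the merged stream early if the shutdown token fires.
        // If the loop below exits for another reason, `take_future` later recovers
        // the still-pending cancellation future so the runner can await the signal.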

        // Process the stream until finished.
        let mut running_task_cnt = 0;
        let mut finished_task_cnt = 0;
        let mut sent_signal_to_next = false;

        while let Some(status_res_inner) = all_streams.next().await {
            match status_res_inner {
                Ok(status) => {
                    use risingwave_pb::task_service::task_info_response::TaskStatus as PbTaskStatus;
                    match PbTaskStatus::try_from(status.task_status).unwrap() {
                        PbTaskStatus::Running => {
                            running_task_cnt += 1;
                            // The number of running tasks should always be less than or equal
                            // to the number of registered tasks.
                            assert!(running_task_cnt <= self.tasks.keys().len());
                            // All tasks in this stage have been scheduled. Notify the query
                            // runner to schedule the next stage.
                            if running_task_cnt == self.tasks.keys().len() {
                                self.notify_stage_scheduled(QueryMessage::Stage(
                                    StageEvent::Scheduled(self.stage_id),
                                ))
                                .await;
                                sent_signal_to_next = true;
                            }
                        }

                        PbTaskStatus::Finished => {
                            finished_task_cnt += 1;
                            assert!(finished_task_cnt <= self.tasks.keys().len());
                            assert!(running_task_cnt >= finished_task_cnt);
                            if finished_task_cnt == self.tasks.keys().len() {
                                // All tasks finished without failure; notify the query runner
                                // and exit the loop.
                                self.notify_stage_completed().await;
                                sent_signal_to_next = true;
                                break;
                            }
                        }
                        PbTaskStatus::Aborted => {
                            // Currently, the only reason we receive an abort status is that
                            // the task's memory usage was too high, so it was aborted.
                            error!(
                                "Abort task {:?} because of excessive memory usage. Please try again later.",
                                status.task_id.unwrap()
                            );
                            self.notify_stage_state_changed(
                                |_| StageState::Failed,
                                QueryMessage::Stage(Failed {
                                    id: self.stage_id,
                                    reason: TaskRunningOutOfMemory,
                                }),
                            )
                            .await;
                            sent_signal_to_next = true;
                            break;
                        }
                        PbTaskStatus::Failed => {
                            // The task failed; fail the whole query.
                            error!(
                                "Task {:?} failed, reason: {:?}",
                                status.task_id.unwrap(),
                                status.error_message,
                            );
                            self.notify_stage_state_changed(
                                |_| StageState::Failed,
                                QueryMessage::Stage(Failed {
                                    id: self.stage_id,
                                    reason: TaskExecutionError(status.error_message),
                                }),
                            )
                            .await;
                            sent_signal_to_next = true;
                            break;
                        }
                        PbTaskStatus::Ping => {
                            debug!("Received ping from task {:?}", status.task_id.unwrap());
                        }
                        status => {
                            // The remaining variants are not expected to be pushed from the
                            // compute node.
                            unreachable!("Unexpected task status {:?}", status);
                        }
                    }
                }
                Err(e) => {
                    // RPC error; we should also notify the stage failure.
                    error!(
                        "Fetching task status in stage {:?} failed, reason: {:?}",
                        self.stage_id,
                        e.message()
                    );
                    self.notify_stage_state_changed(
                        |_| StageState::Failed,
                        QueryMessage::Stage(Failed {
                            id: self.stage_id,
                            reason: RpcError::from_batch_status(e).into(),
                        }),
                    )
                    .await;
                    sent_signal_to_next = true;
                    break;
                }
            }
        }

        tracing::trace!(
            "Stage [{:?}-{:?}], running task count: {}, finished task count: {}, sent signal to next: {}",
            self.query.query_id,
            self.stage_id,
            running_task_cnt,
            finished_task_cnt,
            sent_signal_to_next,
        );

        if let Some(shutdown) = all_streams.take_future() {
            tracing::trace!(
                "Stage [{:?}-{:?}] waiting for stopping signal.",
                self.query.query_id,
                self.stage_id
            );
            // Wait for the shutdown signal.
            shutdown.await;
        }

        // Received a shutdown signal from the query runner, so send cancel RPCs to all compute
        // nodes. Note that task cancellation can only happen after all tasks have been
        // scheduled to the compute nodes; stopping before the tasks are scheduled is left as a
        // future optimization.
        tracing::trace!(
            "Stopping stage: {:?}-{:?}, task_num: {}",
            self.query.query_id,
            self.stage_id,
            self.tasks.len()
        );
        self.cancel_all_scheduled_tasks().await?;

        tracing::trace!(
            "Stage runner [{:?}-{:?}] exited.",
            self.query.query_id,
            self.stage_id
        );
        Ok(())
    }

    async fn schedule_tasks_for_root(
        &mut self,
        mut shutdown_rx: ShutdownToken,
        expr_context: ExprContext,
    ) -> SchedulerResult<()> {
        let root_stage_id = self.stage_id;
        // Currently, DML and table scans should never be in the root fragment, so the partition
        // is `None`. Also, the root fragment only contains one task.
        let plan_fragment = self.create_plan_fragment(ROOT_TASK_ID, None);
        let plan_node = plan_fragment.root.unwrap();
        let task_id = TaskIdBatch {
            query_id: self.query.query_id.id.clone(),
            stage_id: root_stage_id.into(),
            task_id: 0,
        };

        // Notify the `QueryRunner` to poll chunks from `result_rx`.
        let (result_tx, result_rx) = tokio::sync::mpsc::channel(
            self.ctx
                .session
                .env()
                .batch_config()
                .developer
                .root_stage_channel_size,
        );
        self.notify_stage_scheduled(QueryMessage::Stage(StageEvent::ScheduledRoot(result_rx)))
            .await;

        let executor = ExecutorBuilder::new(
            &plan_node,
            &task_id,
            self.ctx.to_batch_task_context(),
            shutdown_rx.clone(),
        );

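        // Keep a second handle to the shutdown token: `cancelled()` below borrows
        // `shutdown_rx` for the lifetime of the stream, while `shutdown_rx0` remains
        // available to check for cancellation inside and after the loop.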
        let shutdown_rx0 = shutdown_rx.clone();

        let result = expr_context_scope(expr_context, async {
            let executor = executor.build().await?;
            let chunk_stream = executor.execute();
            let cancelled = pin!(shutdown_rx.cancelled());
            #[for_await]
            for chunk in chunk_stream.take_until(cancelled) {
                if let Err(ref e) = chunk {
                    if shutdown_rx0.is_cancelled() {
                        break;
                    }
                    let err_str = e.to_report_string();
                    // This can happen if the query runner is dropped early, before the root
                    // executor is scheduled. See
                    // https://github.com/risingwavelabs/risingwave/issues/6883#issuecomment-1348102037
                    // for details. The error is just "channel closed", so it can be ignored.
                    if let Err(_e) = result_tx.send(chunk.map_err(|e| e.into())).await {
                        warn!("Root executor result receiver was dropped before receiving any events, so the send failed");
                    }
                    // Unlike the case below, return from this function and report the error.
                    return Err(TaskExecutionError(err_str));
                } else {
                    // Same as above: ignore the send failure.
                    if let Err(_e) = result_tx.send(chunk.map_err(|e| e.into())).await {
                        warn!("Root executor result receiver was dropped before receiving any events, so the send failed");
                    }
                }
            }
            Ok(())
        }).await;

        if let Err(err) = &result {
            // If we encounter an error when executing the root stage locally, we have to notify
            // the result fetcher, which is returned by `distribute_execute` and listened to by
            // the frontend handler task. Otherwise, the handler cannot properly surface the
            // error to the Postgres client.
            if let Err(_e) = result_tx
                .send(Err(TaskExecutionError(err.to_report_string())))
                .await
            {
                warn!("Failed to send the task execution error");
            }
        }

        // If terminated by another task's execution error, there is no need to return an error
        // here.
        match shutdown_rx0.message() {
            ShutdownMsg::Abort(err_str) => {
                // Tell the query result fetcher to stop polling, attaching the failure reason.
                if let Err(_e) = result_tx.send(Err(TaskExecutionError(err_str))).await {
                    warn!("Failed to send the task execution error");
                }
            }
            _ => self.notify_stage_completed().await,
        }

        tracing::trace!(
            "Stage runner [{:?}-{:?}] exited.",
            self.query.query_id,
            self.stage_id
        );

        // We still need to propagate the error from the current task so that `StageRunner::run`
        // can send a `Failed` event to stop the other stages.
        result.map(|_| ())
    }

    async fn schedule_tasks_for_all(&mut self, shutdown_rx: ShutdownToken) -> SchedulerResult<()> {
        let expr_context = ExprContext {
            time_zone: self.ctx.session().config().timezone(),
            strict_mode: self.ctx.session().config().batch_expr_strict_mode(),
        };
        // Non-root stages are scheduled to compute nodes; the root stage is executed locally.
        if !self.is_root_stage() {
            self.schedule_tasks(shutdown_rx, expr_context).await?;
        } else {
            self.schedule_tasks_for_root(shutdown_rx, expr_context)
                .await?;
        }
        Ok(())
    }

    #[inline(always)]
    fn get_fragment_id(&self, table_id: &TableId) -> SchedulerResult<FragmentId> {
        self.catalog_reader
            .read_guard()
            .get_any_table_by_id(table_id)
            .map(|table| table.fragment_id)
            .map_err(|e| SchedulerError::Internal(anyhow!(e)))
    }

    #[inline(always)]
    fn get_table_dml_vnode_mapping(
        &self,
        table_id: &TableId,
    ) -> SchedulerResult<WorkerSlotMapping> {
        let guard = self.catalog_reader.read_guard();

        let table = guard
            .get_any_table_by_id(table_id)
            .map_err(|e| SchedulerError::Internal(anyhow!(e)))?;

        let fragment_id = match table.dml_fragment_id.as_ref() {
            Some(dml_fragment_id) => dml_fragment_id,
            // Backward compatibility for tables without a `dml_fragment_id`.
            None => &table.fragment_id,
        };

        self.worker_node_manager
            .manager
            .get_streaming_fragment_mapping(fragment_id)
            .map_err(|e| e.into())
    }

    fn choose_worker(
        &self,
        plan_fragment: &PlanFragment,
        task_id: u32,
        dml_table_id: Option<TableId>,
    ) -> SchedulerResult<Option<WorkerNode>> {
        let plan_node = plan_fragment.root.as_ref().expect("fail to get plan node");

        if let Some(table_id) = dml_table_id {
            let vnode_mapping = self.get_table_dml_vnode_mapping(&table_id)?;
            let worker_slot_ids = vnode_mapping.iter_unique().collect_vec();
            let candidates = self
                .worker_node_manager
                .manager
                .get_workers_by_worker_slot_ids(&worker_slot_ids)?;
            if candidates.is_empty() {
                return Err(BatchError::EmptyWorkerNodes.into());
            }
            let stage = &self.query.stage(self.stage_id);
            let candidate = if stage.batch_enable_distributed_dml {
                // If distributed DML is enabled, distribute the DML tasks as evenly as
                // possible across workers by indexing with the `task_id`.
                candidates[task_id as usize % candidates.len()].clone()
            } else {
                // If distributed DML is disabled, DML from the same session must be sent to a
                // fixed worker/channel to provide an ordering guarantee.
                candidates[stage.session_id.0 as usize % candidates.len()].clone()
            };
            return Ok(Some(candidate));
        }

        if let Some(distributed_lookup_join_node) =
            Self::find_distributed_lookup_join_node(plan_node)
        {
            let fragment_id = self.get_fragment_id(
                &distributed_lookup_join_node
                    .inner_side_table_desc
                    .as_ref()
                    .unwrap()
                    .table_id
                    .into(),
            )?;
            let id_to_worker_slots = self
                .worker_node_manager
                .fragment_mapping(fragment_id)?
                .iter_unique()
                .collect_vec();

            let worker_slot_id = id_to_worker_slots[task_id as usize];
            let candidates = self
                .worker_node_manager
                .manager
                .get_workers_by_worker_slot_ids(&[worker_slot_id])?;
            if candidates.is_empty() {
                return Err(BatchError::EmptyWorkerNodes.into());
            }
            Ok(Some(candidates[0].clone()))
        } else {
            Ok(None)
        }
    }

    fn find_distributed_lookup_join_node(
        plan_node: &PlanNode,
    ) -> Option<&DistributedLookupJoinNode> {
        let node_body = plan_node.node_body.as_ref().expect("fail to get node body");

        match node_body {
            NodeBody::DistributedLookupJoin(distributed_lookup_join_node) => {
                Some(distributed_lookup_join_node)
            }
            _ => plan_node
                .children
                .iter()
                .find_map(Self::find_distributed_lookup_join_node),
        }
    }

    /// Writes a message into the channel to notify the query runner that the current stage has
    /// been scheduled.
    async fn notify_stage_scheduled(&self, msg: QueryMessage) {
        self.notify_stage_state_changed(
            |old_state| {
                assert_matches!(old_state, StageState::Started);
                StageState::Running
            },
            msg,
        )
        .await
    }

    /// Notifies the query execution that this stage has completed.
    async fn notify_stage_completed(&self) {
        self.notify_stage_state_changed(
            |old_state| {
                assert_matches!(old_state, StageState::Running);
                StageState::Completed
            },
            QueryMessage::Stage(StageEvent::Completed(self.stage_id)),
        )
        .await
    }

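    /// Applies `new_state` to the stage state under the write lock, then forwards `msg` to the
    /// query runner.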
    async fn notify_stage_state_changed<F>(&self, new_state: F, msg: QueryMessage)
    where
        F: FnOnce(StageState) -> StageState,
    {
        {
            let mut s = self.state.write().await;
            let old_state = mem::replace(&mut *s, StageState::Failed);
            *s = new_state(old_state);
        }

        self.send_event(msg).await;
    }

    /// Aborts all registered tasks. Note that we do not care which tasks have already failed or
    /// completed, because aborting a task does not fail if the task has already died.
    /// See PR (#4560).
    async fn cancel_all_scheduled_tasks(&self) -> SchedulerResult<()> {
        // Set state to failed.
        // {
        //     let mut state = self.state.write().await;
        //     // Ignore if already finished.
        //     if let &StageState::Completed = &*state {
        //         return Ok(());
        //     }
        //     // FIXME: Be careful for state jump back.
        //     *state = StageState::Failed
        // }

        for (task, task_status) in &*self.tasks {
            // 1. Collect task info and client.
            let loc = &task_status.get_status().location;
            let addr = loc.as_ref().expect("Get address should not fail");
            let client = self
                .compute_client_pool
                .get_by_addr(HostAddr::from(addr))
                .await
                .map_err(|e| anyhow!(e))?;

            // 2. Send an RPC to the compute node for each task asynchronously.
            let query_id = self.query.query_id.id.clone();
            let stage_id = self.stage_id;
            let task_id = *task;
            spawn(async move {
                if let Err(e) = client
                    .cancel(CancelTaskRequest {
                        task_id: Some(risingwave_pb::batch_plan::TaskId {
                            query_id: query_id.clone(),
                            stage_id: stage_id.into(),
                            task_id,
                        }),
                    })
                    .await
                {
                    error!(
                        error = %e.as_report(),
                        ?task_id,
                        ?query_id,
                        ?stage_id,
                        "Abort task failed",
                    );
                };
            });
        }
        Ok(())
    }

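    /// Creates the task on the chosen compute node (or on a random worker if `worker` is
    /// `None`), records the task's location, and returns its status-report stream.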
    async fn schedule_task(
        &self,
        task_id: PbTaskId,
        plan_fragment: PlanFragment,
        worker: Option<WorkerNode>,
        expr_context: ExprContext,
    ) -> SchedulerResult<Fuse<Streaming<TaskInfoResponse>>> {
        let mut worker = worker.unwrap_or(self.worker_node_manager.next_random_worker()?);
        let worker_node_addr = worker.host.take().unwrap();
        let compute_client = self
            .compute_client_pool
            .get_by_addr((&worker_node_addr).into())
            .await
            .inspect_err(|_| self.mask_failed_serving_worker(&worker))
            .map_err(|e| anyhow!(e))?;

        let t_id = task_id.task_id;

        let stream_status: Fuse<Streaming<TaskInfoResponse>> = compute_client
            .create_task(task_id, plan_fragment, expr_context)
            .await
            .inspect_err(|_| self.mask_failed_serving_worker(&worker))
            .map_err(|e| anyhow!(e))?
            .fuse();

        self.tasks[&t_id].inner.store(Arc::new(TaskStatus {
            _task_id: t_id,
            location: Some(worker_node_addr),
        }));

        Ok(stream_status)
    }

    fn create_plan_fragment(
        &self,
        task_id: TaskId,
        partition: Option<PartitionInfo>,
    ) -> PlanFragment {
        // Auto-incrementing counter used to generate a unique identity for each plan node
        // within the task.
        let mut identity_id = 0;

        let stage = &self.query.stage(self.stage_id);

        let plan_node_prost =
            self.convert_plan_node(&stage.root, task_id, partition, &mut identity_id);
        let exchange_info = stage.exchange_info.clone().unwrap();

        PlanFragment {
            root: Some(plan_node_prost),
            exchange_info: Some(exchange_info),
        }
    }

    fn convert_plan_node(
        &self,
        execution_plan_node: &ExecutionPlanNode,
        task_id: TaskId,
        partition: Option<PartitionInfo>,
        identity_id: &mut u64,
    ) -> PbPlanNode {
        // Generate the identity.
        let identity = {
            let identity_type = execution_plan_node.plan_node_type;
            let id = *identity_id;
            *identity_id += 1;
            format!("{:?}-{}", identity_type, id)
        };

        match execution_plan_node.plan_node_type {
            BatchPlanNodeType::BatchExchange => {
                // Find the stage this exchange node should fetch from and get all exchange sources.
                let child_stage = self
                    .children
                    .iter()
                    .find(|child_stage| {
                        child_stage.stage_id == execution_plan_node.source_stage_id.unwrap()
                    })
                    .unwrap();
                let exchange_sources = child_stage.all_exchange_sources_for(task_id);

                match &execution_plan_node.node {
                    NodeBody::Exchange(exchange_node) => PbPlanNode {
                        children: vec![],
                        identity,
                        node_body: Some(NodeBody::Exchange(ExchangeNode {
                            sources: exchange_sources,
                            sequential: exchange_node.sequential,
                            input_schema: execution_plan_node.schema.clone(),
                        })),
                    },
                    NodeBody::MergeSortExchange(sort_merge_exchange_node) => PbPlanNode {
                        children: vec![],
                        identity,
                        node_body: Some(NodeBody::MergeSortExchange(MergeSortExchangeNode {
                            exchange: Some(ExchangeNode {
                                sources: exchange_sources,
                                sequential: false,
                                input_schema: execution_plan_node.schema.clone(),
                            }),
                            column_orders: sort_merge_exchange_node.column_orders.clone(),
                        })),
                    },
                    _ => unreachable!(),
                }
            }
            BatchPlanNodeType::BatchSeqScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::RowSeqScan(mut scan_node) = node_body else {
                    unreachable!();
                };
                let partition = partition
                    .expect("no partition info for seq scan")
                    .into_table()
                    .expect("PartitionInfo should be TablePartitionInfo");
                scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
                scan_node.scan_ranges = partition.scan_ranges;
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::RowSeqScan(scan_node)),
                }
            }
            BatchPlanNodeType::BatchLogSeqScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::LogRowSeqScan(mut scan_node) = node_body else {
                    unreachable!();
                };
                let partition = partition
                    .expect("no partition info for log seq scan")
                    .into_table()
                    .expect("PartitionInfo should be TablePartitionInfo");
                scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::LogRowSeqScan(scan_node)),
                }
            }
            BatchPlanNodeType::BatchSource | BatchPlanNodeType::BatchKafkaScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::Source(mut source_node) = node_body else {
                    unreachable!();
                };

                let partition = partition
                    .expect("no partition info for source scan")
                    .into_source()
                    .expect("PartitionInfo should be SourcePartitionInfo");
                source_node.split = partition
                    .into_iter()
                    .map(|split| split.encode_to_bytes().into())
                    .collect_vec();
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::Source(source_node)),
                }
            }
            BatchPlanNodeType::BatchIcebergScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::IcebergScan(mut iceberg_scan_node) = node_body else {
                    unreachable!();
                };

                let partition = partition
                    .expect("no partition info for iceberg scan")
                    .into_source()
                    .expect("PartitionInfo should be SourcePartitionInfo");
                iceberg_scan_node.split = partition
                    .into_iter()
                    .map(|split| split.encode_to_bytes().into())
                    .collect_vec();
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::IcebergScan(iceberg_scan_node)),
                }
            }
            _ => {
                let children = execution_plan_node
                    .children
                    .iter()
                    .map(|e| self.convert_plan_node(e, task_id, partition.clone(), identity_id))
                    .collect();

                PbPlanNode {
                    children,
                    identity,
                    node_body: Some(execution_plan_node.node.clone()),
                }
            }
        }
    }
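
    /// The plan fragmenter assigns id 0 to the root stage of a query, so the check is a plain
    /// comparison.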
    fn is_root_stage(&self) -> bool {
        self.stage_id == 0.into()
    }
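
    /// Temporarily masks a serving worker after an RPC failure so that subsequent scheduling
    /// avoids it; workers without the serving property are left untouched.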
1132
1133    fn mask_failed_serving_worker(&self, worker: &WorkerNode) {
1134        if !worker.property.as_ref().is_some_and(|p| p.is_serving) {
1135            return;
1136        }
1137        let duration = Duration::from_secs(std::cmp::max(
1138            self.ctx
1139                .session
1140                .env()
1141                .batch_config()
1142                .mask_worker_temporary_secs as u64,
1143            1,
1144        ));
1145        self.worker_node_manager
1146            .manager
1147            .mask_worker_node(worker.id, duration);
1148    }
1149}