Skip to main content

risingwave_frontend/scheduler/distributed/
stage.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::pin::pin;
17use std::sync::Arc;
18use std::time::Duration;
19use std::{assert_matches, mem};
20
21use StageEvent::Failed;
22use anyhow::anyhow;
23use arc_swap::ArcSwap;
24use futures::stream::Fuse;
25use futures::{StreamExt, TryStreamExt, stream};
26use futures_async_stream::for_await;
27use itertools::Itertools;
28use risingwave_batch::error::BatchError;
29use risingwave_batch::executor::ExecutorBuilder;
30use risingwave_batch::task::{ShutdownMsg, ShutdownSender, ShutdownToken, TaskId as TaskIdBatch};
31use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
32use risingwave_common::array::DataChunk;
33use risingwave_common::hash::WorkerSlotMapping;
34use risingwave_common::util::addr::HostAddr;
35use risingwave_common::util::iter_util::ZipEqFast;
36use risingwave_connector::source::SplitMetaData;
37use risingwave_expr::expr_context::expr_context_scope;
38use risingwave_pb::batch_plan::plan_node::NodeBody;
39use risingwave_pb::batch_plan::{
40    DistributedLookupJoinNode, ExchangeNode, ExchangeSource, MergeSortExchangeNode, PlanFragment,
41    PlanNode as PbPlanNode, PlanNode, TaskId as PbTaskId, TaskOutputId,
42};
43use risingwave_pb::common::{HostAddress, WorkerNode};
44use risingwave_pb::plan_common::ExprContext;
45use risingwave_pb::task_service::{CancelTaskRequest, TaskInfoResponse};
46use risingwave_rpc_client::ComputeClientPoolRef;
47use risingwave_rpc_client::error::RpcError;
48use rw_futures_util::select_all;
49use thiserror_ext::AsReport;
50use tokio::spawn;
51use tokio::sync::RwLock;
52use tokio::sync::mpsc::{Receiver, Sender};
53use tonic::Streaming;
54use tracing::{Instrument, debug, error, warn};
55
56use crate::catalog::catalog_service::CatalogReader;
57use crate::catalog::{FragmentId, TableId};
58use crate::optimizer::plan_node::BatchPlanNodeType;
59use crate::scheduler::SchedulerError::{TaskExecutionError, TaskRunningOutOfMemory};
60use crate::scheduler::distributed::QueryMessage;
61use crate::scheduler::distributed::stage::StageState::Pending;
62use crate::scheduler::plan_fragmenter::{
63    ExecutionPlanNode, PartitionInfo, Query, ROOT_TASK_ID, StageId, TaskId,
64};
65use crate::scheduler::{ExecutionContextRef, SchedulerError, SchedulerResult};
66
/// Upper bound on how many task-scheduling futures `StageRunner::schedule_tasks` drives
/// concurrently; it is passed as the `buffer_unordered` limit for those futures.
const TASK_SCHEDULING_PARALLELISM: usize = 10;
68
/// Lifecycle state of a stage, shared between `StageExecution` and its `StageRunner`.
#[derive(Debug)]
enum StageState {
    /// The stage has been created but `StageExecution::start` has not run yet.
    ///
    /// We put `msg_sender` in `Pending` state to avoid holding it in `StageExecution`. In this
    /// way, it could be efficiently moved into `StageRunner` instead of being cloned. This also
    /// ensures that the sender can get dropped once it is used up, preventing some issues caused
    /// by unnecessarily long lifetime.
    Pending {
        msg_sender: Sender<QueryMessage>,
    },
    /// Set by `StageExecution::start` right before the `StageRunner` is spawned.
    Started,
    /// Reported as "scheduled" by `StageExecution::is_scheduled`.
    Running,
    /// Also reported as "scheduled" by `StageExecution::is_scheduled`.
    Completed,
    /// The stage failed. This variant is also used as the temporary placeholder while
    /// `StageExecution::start` moves the `Pending` payload out of the lock.
    Failed,
}
83
/// Events a stage reports to the query runner, wrapped in `QueryMessage::Stage`.
#[derive(Debug)]
pub enum StageEvent {
    /// Every task of a non-root stage has reported `Running`, so the next stage can be
    /// scheduled.
    Scheduled(StageId),
    /// The root stage is about to be executed locally. Carries the receiving end of the channel
    /// through which its result chunks (or an error) are delivered.
    ScheduledRoot(Receiver<SchedulerResult<DataChunk>>),
    /// Stage failed.
    Failed {
        id: StageId,
        reason: SchedulerError,
    },
    /// All tasks in stage finished.
    Completed(#[expect(dead_code)] StageId),
}
96
/// Scheduling status of a single task of a stage.
#[derive(Clone)]
pub struct TaskStatus {
    _task_id: TaskId,

    // None before task is scheduled.
    // Once set, it is used as the `host` of the `ExchangeSource`s built by
    // `StageExecution::all_exchange_sources_for`.
    location: Option<HostAddress>,
}
104
/// Holds the current `TaskStatus` of one task behind an `ArcSwap`, so that readers can take a
/// shared snapshot of it (see `get_status`).
///
/// NOTE(review): the place where the status is replaced is not in this part of the file;
/// presumably it is swapped once the task has been placed on a worker — confirm.
struct TaskStatusHolder {
    inner: ArcSwap<TaskStatus>,
}
108
/// Handle of one stage of a query in the distributed scheduler.
///
/// It owns the stage's shared state and task table and hands clones of them to a `StageRunner`
/// when `start` is called.
pub struct StageExecution {
    stage_id: StageId,
    query: Arc<Query>,
    worker_node_manager: WorkerNodeSelector,
    /// Status holders of this stage's tasks, keyed by task id (`0..parallelism`, see `new`).
    tasks: Arc<HashMap<TaskId, TaskStatusHolder>>,
    /// Current lifecycle state; the same `Arc` is handed to the `StageRunner` in `start`.
    state: Arc<RwLock<StageState>>,
    /// Sender half of the shutdown channel. `None` until `start` fills it; `stop` takes it out.
    shutdown_tx: RwLock<Option<ShutdownSender>>,
    /// Children stage executions.
    ///
    /// We use `Vec` here since children's size is usually small.
    children: Vec<Arc<StageExecution>>,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,

    /// Execution context ref
    ctx: ExecutionContextRef,
}
126
/// The worker side of a stage.
///
/// It is built by `StageExecution::start` from clones of the execution's handles plus the
/// `msg_sender` taken out of `StageState::Pending`, and is consumed by `run` on the compute
/// runtime.
struct StageRunner {
    /// Lifecycle state shared with the owning `StageExecution`.
    state: Arc<RwLock<StageState>>,
    stage_id: StageId,
    query: Arc<Query>,
    worker_node_manager: WorkerNodeSelector,
    /// Task table shared with the owning `StageExecution`.
    tasks: Arc<HashMap<TaskId, TaskStatusHolder>>,
    // Send message to `QueryRunner` to notify stage state change.
    msg_sender: Sender<QueryMessage>,
    children: Vec<Arc<StageExecution>>,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,

    ctx: ExecutionContextRef,
}
141
142impl TaskStatusHolder {
143    fn new(task_id: TaskId) -> Self {
144        let task_status = TaskStatus {
145            _task_id: task_id,
146            location: None,
147        };
148
149        Self {
150            inner: ArcSwap::new(Arc::new(task_status)),
151        }
152    }
153
154    fn get_status(&self) -> Arc<TaskStatus> {
155        self.inner.load_full()
156    }
157}
158
159impl StageExecution {
160    pub fn new(
161        stage_id: StageId,
162        query: Arc<Query>,
163        worker_node_manager: WorkerNodeSelector,
164        msg_sender: Sender<QueryMessage>,
165        children: Vec<Arc<StageExecution>>,
166        compute_client_pool: ComputeClientPoolRef,
167        catalog_reader: CatalogReader,
168        ctx: ExecutionContextRef,
169    ) -> Self {
170        let tasks = (0..query.stage(stage_id).parallelism.unwrap())
171            .map(|task_id| (task_id as u64, TaskStatusHolder::new(task_id as u64)))
172            .collect();
173
174        Self {
175            stage_id,
176            query,
177            worker_node_manager,
178            tasks: Arc::new(tasks),
179            state: Arc::new(RwLock::new(Pending { msg_sender })),
180            shutdown_tx: RwLock::new(None),
181            children,
182            compute_client_pool,
183            catalog_reader,
184            ctx,
185        }
186    }
187
188    /// Starts execution of this stage, returns error if already started.
189    pub async fn start(&self) {
190        let mut s = self.state.write().await;
191        let cur_state = mem::replace(&mut *s, StageState::Failed);
192        match cur_state {
193            Pending { msg_sender } => {
194                let runner = StageRunner {
195                    stage_id: self.stage_id,
196                    query: self.query.clone(),
197                    worker_node_manager: self.worker_node_manager.clone(),
198                    tasks: self.tasks.clone(),
199                    msg_sender,
200                    children: self.children.clone(),
201                    state: self.state.clone(),
202                    compute_client_pool: self.compute_client_pool.clone(),
203                    catalog_reader: self.catalog_reader.clone(),
204                    ctx: self.ctx.clone(),
205                };
206
207                // The channel used for shutdown signal messaging.
208                let (sender, receiver) = ShutdownToken::new();
209                // Fill the shutdown sender.
210                let mut holder = self.shutdown_tx.write().await;
211                *holder = Some(sender);
212
213                // Change state before spawn runner.
214                *s = StageState::Started;
215
216                let span = tracing::info_span!(
217                    "stage",
218                    "otel.name" = format!("Stage {}-{}", self.query.query_id.id, self.stage_id),
219                    query_id = self.query.query_id.id,
220                    stage_id = %self.stage_id,
221                );
222                self.ctx
223                    .session()
224                    .env()
225                    .compute_runtime()
226                    .spawn(async move { runner.run(receiver).instrument(span).await });
227
228                tracing::trace!(
229                    "Stage {:?}-{:?} started.",
230                    self.query.query_id.id,
231                    self.stage_id
232                )
233            }
234            _ => {
235                unreachable!("Only expect to schedule stage once");
236            }
237        }
238    }
239
240    pub async fn stop(&self, error: Option<String>) {
241        // Send message to tell Stage Runner stop.
242        if let Some(shutdown_tx) = self.shutdown_tx.write().await.take() {
243            // It's possible that the stage has not been scheduled, so the channel sender is
244            // None.
245
246            if !if let Some(error) = error {
247                shutdown_tx.abort(error)
248            } else {
249                shutdown_tx.cancel()
250            } {
251                // The stage runner handle has already closed. so do no-op.
252                tracing::trace!(
253                    "Failed to send stop message stage: {:?}-{:?}",
254                    self.query.query_id,
255                    self.stage_id
256                );
257            }
258        }
259    }
260
261    pub async fn is_scheduled(&self) -> bool {
262        let s = self.state.read().await;
263        matches!(*s, StageState::Running | StageState::Completed)
264    }
265
266    pub async fn is_pending(&self) -> bool {
267        let s = self.state.read().await;
268        matches!(*s, StageState::Pending { .. })
269    }
270
271    pub async fn state(&self) -> &'static str {
272        let s = self.state.read().await;
273        match *s {
274            Pending { .. } => "Pending",
275            StageState::Started => "Started",
276            StageState::Running => "Running",
277            StageState::Completed => "Completed",
278            StageState::Failed => "Failed",
279        }
280    }
281
282    /// Returns all exchange sources for `output_id`. Each `ExchangeSource` is identified by
283    /// producer's `TaskId` and `output_id` (consumer's `TaskId`), since each task may produce
284    /// output to several channels.
285    ///
286    /// When this method is called, all tasks should have been scheduled, and their `worker_node`
287    /// should have been set.
288    pub fn all_exchange_sources_for(&self, output_id: u64) -> Vec<ExchangeSource> {
289        self.tasks
290            .iter()
291            .map(|(task_id, status_holder)| {
292                let task_output_id = TaskOutputId {
293                    task_id: Some(PbTaskId {
294                        query_id: self.query.query_id.id.clone(),
295                        stage_id: self.stage_id.into(),
296                        task_id: *task_id,
297                    }),
298                    output_id,
299                };
300
301                ExchangeSource {
302                    task_output_id: Some(task_output_id),
303                    host: Some(status_holder.inner.load_full().location.clone().unwrap()),
304                    local_execute_plan: None,
305                }
306            })
307            .collect()
308    }
309}
310
311impl StageRunner {
312    async fn run(mut self, shutdown_rx: ShutdownToken) {
313        if let Err(e) = self.schedule_tasks_for_all(shutdown_rx).await {
314            error!(
315                error = %e.as_report(),
316                query_id = ?self.query.query_id,
317                stage_id = ?self.stage_id,
318                "Failed to schedule tasks"
319            );
320            self.send_event(QueryMessage::Stage(Failed {
321                id: self.stage_id,
322                reason: e,
323            }))
324            .await;
325        }
326    }
327
328    /// Send stage event to listener.
329    async fn send_event(&self, event: QueryMessage) {
330        if let Err(_e) = self.msg_sender.send(event).await {
331            warn!("Failed to send event to Query Runner, may be killed by previous failed event");
332        }
333    }
334
335    /// Schedule all tasks to CN and wait process all status messages from RPC. Note that when all
336    /// task is created, it should tell `QueryRunner` to schedule next.
337    async fn schedule_tasks(
338        &mut self,
339        mut shutdown_rx: ShutdownToken,
340        expr_context: ExprContext,
341    ) -> SchedulerResult<()> {
342        let mut futures = vec![];
343        let stage = &self.query.stage(self.stage_id);
344
345        if let Some(table_scan_info) = stage.table_scan_info.as_ref()
346            && let Some(vnode_bitmaps) = table_scan_info.partitions()
347        {
348            // If the stage has table scan nodes, we create tasks according to the data distribution
349            // and partition of the table.
350            // We let each task read one partition by setting the `vnode_ranges` of the scan node in
351            // the task.
352            // We schedule the task to the worker node that owns the data partition.
353            let worker_slot_ids = vnode_bitmaps.keys().cloned().collect_vec();
354            let workers = self
355                .worker_node_manager
356                .manager
357                .get_workers_by_worker_slot_ids(&worker_slot_ids)?;
358
359            for (i, (worker_slot_id, worker)) in worker_slot_ids
360                .into_iter()
361                .zip_eq_fast(workers.into_iter())
362                .enumerate()
363            {
364                let task_id = PbTaskId {
365                    query_id: self.query.query_id.id.clone(),
366                    stage_id: self.stage_id.into(),
367                    task_id: i as u64,
368                };
369                let vnode_ranges = vnode_bitmaps[&worker_slot_id].clone();
370                let plan_fragment =
371                    self.create_plan_fragment(i as u64, Some(PartitionInfo::Table(vnode_ranges)));
372                futures.push(self.schedule_task(
373                    task_id,
374                    plan_fragment,
375                    Some(worker),
376                    expr_context.clone(),
377                ));
378            }
379        } else if let Some(source_info) = stage.source_info.as_ref() {
380            // If there is no file in source, the `chunk_size` is set to 1.
381            let chunk_size = ((source_info.split_info().unwrap().len() as f32
382                / stage.parallelism.unwrap() as f32)
383                .ceil() as usize)
384                .max(1);
385            if source_info.split_info().unwrap().is_empty() {
386                // No file in source, schedule an empty task.
387                const EMPTY_TASK_ID: u64 = 0;
388                let task_id = PbTaskId {
389                    query_id: self.query.query_id.id.clone(),
390                    stage_id: self.stage_id.into(),
391                    task_id: EMPTY_TASK_ID,
392                };
393                let plan_fragment =
394                    self.create_plan_fragment(EMPTY_TASK_ID, Some(PartitionInfo::Source(vec![])));
395                let worker =
396                    self.choose_worker(&plan_fragment, EMPTY_TASK_ID as u32, stage.dml_table_id)?;
397                futures.push(self.schedule_task(
398                    task_id,
399                    plan_fragment,
400                    worker,
401                    expr_context.clone(),
402                ));
403            } else {
404                for (id, split) in source_info
405                    .split_info()
406                    .unwrap()
407                    .chunks(chunk_size)
408                    .enumerate()
409                {
410                    let task_id = PbTaskId {
411                        query_id: self.query.query_id.id.clone(),
412                        stage_id: self.stage_id.into(),
413                        task_id: id as u64,
414                    };
415                    let plan_fragment = self.create_plan_fragment(
416                        id as u64,
417                        Some(PartitionInfo::Source(split.to_vec())),
418                    );
419                    let worker =
420                        self.choose_worker(&plan_fragment, id as u32, stage.dml_table_id)?;
421                    futures.push(self.schedule_task(
422                        task_id,
423                        plan_fragment,
424                        worker,
425                        expr_context.clone(),
426                    ));
427                }
428            }
429        } else if let Some(file_scan_info) = stage.file_scan_info.as_ref() {
430            let chunk_size = (file_scan_info.file_location.len() as f32
431                / stage.parallelism.unwrap() as f32)
432                .ceil() as usize;
433            for (id, files) in file_scan_info.file_location.chunks(chunk_size).enumerate() {
434                let task_id = PbTaskId {
435                    query_id: self.query.query_id.id.clone(),
436                    stage_id: self.stage_id.into(),
437                    task_id: id as u64,
438                };
439                let plan_fragment =
440                    self.create_plan_fragment(id as u64, Some(PartitionInfo::File(files.to_vec())));
441                let worker = self.choose_worker(&plan_fragment, id as u32, stage.dml_table_id)?;
442                futures.push(self.schedule_task(
443                    task_id,
444                    plan_fragment,
445                    worker,
446                    expr_context.clone(),
447                ));
448            }
449        } else {
450            for id in 0..stage.parallelism.unwrap() {
451                let task_id = PbTaskId {
452                    query_id: self.query.query_id.id.clone(),
453                    stage_id: self.stage_id.into(),
454                    task_id: id as u64,
455                };
456                let plan_fragment = self.create_plan_fragment(id as u64, None);
457                let worker = self.choose_worker(&plan_fragment, id, stage.dml_table_id)?;
458                futures.push(self.schedule_task(
459                    task_id,
460                    plan_fragment,
461                    worker,
462                    expr_context.clone(),
463                ));
464            }
465        }
466
467        // Await each future and convert them into a set of streams.
468        let buffered = stream::iter(futures).buffer_unordered(TASK_SCHEDULING_PARALLELISM);
469        let buffered_streams = buffered.try_collect::<Vec<_>>().await?;
470
471        // Merge different task streams into a single stream.
472        let cancelled = pin!(shutdown_rx.cancelled());
473        let mut all_streams = select_all(buffered_streams).take_until(cancelled);
474
475        // Process the stream until finished.
476        let mut running_task_cnt = 0;
477        let mut finished_task_cnt = 0;
478        let mut sent_signal_to_next = false;
479
480        while let Some(status_res_inner) = all_streams.next().await {
481            match status_res_inner {
482                Ok(status) => {
483                    use risingwave_pb::task_service::task_info_response::TaskStatus as PbTaskStatus;
484                    match PbTaskStatus::try_from(status.task_status).unwrap() {
485                        PbTaskStatus::Running => {
486                            running_task_cnt += 1;
487                            // The task running count should always less or equal than the
488                            // registered tasks number.
489                            assert!(running_task_cnt <= self.tasks.keys().len());
490                            // All tasks in this stage have been scheduled. Notify query runner to
491                            // schedule next stage.
492                            if running_task_cnt == self.tasks.keys().len() {
493                                self.notify_stage_scheduled(QueryMessage::Stage(
494                                    StageEvent::Scheduled(self.stage_id),
495                                ))
496                                .await;
497                                sent_signal_to_next = true;
498                            }
499                        }
500
501                        PbTaskStatus::Finished => {
502                            finished_task_cnt += 1;
503                            assert!(finished_task_cnt <= self.tasks.keys().len());
504                            assert!(running_task_cnt >= finished_task_cnt);
505                            if finished_task_cnt == self.tasks.keys().len() {
506                                // All tasks finished without failure, we should not break
507                                // this loop
508                                self.notify_stage_completed().await;
509                                sent_signal_to_next = true;
510                                break;
511                            }
512                        }
513                        PbTaskStatus::Aborted => {
514                            // Currently, the only reason that we receive an abort status is that
515                            // the task's memory usage is too high so
516                            // it's aborted.
517                            error!(
518                                "Abort task {:?} because of excessive memory usage. Please try again later.",
519                                status.task_id.unwrap()
520                            );
521                            self.notify_stage_state_changed(
522                                |_| StageState::Failed,
523                                QueryMessage::Stage(Failed {
524                                    id: self.stage_id,
525                                    reason: TaskRunningOutOfMemory,
526                                }),
527                            )
528                            .await;
529                            sent_signal_to_next = true;
530                            break;
531                        }
532                        PbTaskStatus::Failed => {
533                            // Task failed, we should fail whole query
534                            error!(
535                                "Task {:?} failed, reason: {:?}",
536                                status.task_id.unwrap(),
537                                status.error_message,
538                            );
539                            self.notify_stage_state_changed(
540                                |_| StageState::Failed,
541                                QueryMessage::Stage(Failed {
542                                    id: self.stage_id,
543                                    reason: TaskExecutionError(status.error_message),
544                                }),
545                            )
546                            .await;
547                            sent_signal_to_next = true;
548                            break;
549                        }
550                        PbTaskStatus::Ping => {
551                            debug!("Receive ping from task {:?}", status.task_id.unwrap());
552                        }
553                        status => {
554                            // The remain possible variant is Failed, but now they won't be pushed
555                            // from CN.
556                            unreachable!("Unexpected task status {:?}", status);
557                        }
558                    }
559                }
560                Err(e) => {
561                    // rpc error here, we should also notify stage failure
562                    error!(
563                        "Fetching task status in stage {:?} failed, reason: {:?}",
564                        self.stage_id,
565                        e.message()
566                    );
567                    self.notify_stage_state_changed(
568                        |_| StageState::Failed,
569                        QueryMessage::Stage(Failed {
570                            id: self.stage_id,
571                            reason: RpcError::from_batch_status(e).into(),
572                        }),
573                    )
574                    .await;
575                    sent_signal_to_next = true;
576                    break;
577                }
578            }
579        }
580
581        tracing::trace!(
582            "Stage [{:?}-{:?}], running task count: {}, finished task count: {}, sent signal to next: {}",
583            self.query.query_id,
584            self.stage_id,
585            running_task_cnt,
586            finished_task_cnt,
587            sent_signal_to_next,
588        );
589
590        if let Some(shutdown) = all_streams.take_future() {
591            tracing::trace!(
592                "Stage [{:?}-{:?}] waiting for stopping signal.",
593                self.query.query_id,
594                self.stage_id
595            );
596            // Waiting for shutdown signal.
597            shutdown.await;
598        }
599
600        // Received shutdown signal from query runner, should send abort RPC to all CNs.
601        // change state to aborted. Note that the task cancel can only happen after schedule
602        // all these tasks to CN. This can be an optimization for future:
603        // How to stop before schedule tasks.
604        tracing::trace!(
605            "Stopping stage: {:?}-{:?}, task_num: {}",
606            self.query.query_id,
607            self.stage_id,
608            self.tasks.len()
609        );
610        self.cancel_all_scheducancled_tasks().await?;
611
612        tracing::trace!(
613            "Stage runner [{:?}-{:?}] exited.",
614            self.query.query_id,
615            self.stage_id
616        );
617        Ok(())
618    }
619
    /// Executes the root stage locally instead of dispatching it to a compute node.
    ///
    /// The query runner is first handed the receiving end of a result channel through
    /// `StageEvent::ScheduledRoot`. The executor for the root plan fragment is then built and
    /// every chunk it yields is forwarded into that channel until the stream ends, yields an
    /// error, or the shutdown token is cancelled.
    ///
    /// Returns the execution error, if any, so that `StageRunner::run` can emit a `Failed`
    /// event for this stage.
    async fn schedule_tasks_for_root(
        &mut self,
        mut shutdown_rx: ShutdownToken,
        expr_context: ExprContext,
    ) -> SchedulerResult<()> {
        let root_stage_id = self.stage_id;
        // Currently, the dml or table scan should never be root fragment, so the partition is None.
        // And root fragment only contain one task.
        let plan_fragment = self.create_plan_fragment(ROOT_TASK_ID, None);
        let plan_node = plan_fragment.root.unwrap();
        let task_id = TaskIdBatch {
            query_id: self.query.query_id.id.clone(),
            stage_id: root_stage_id.into(),
            task_id: 0,
        };

        // Notify QueryRunner to poll chunk from result_rx.
        let (result_tx, result_rx) = tokio::sync::mpsc::channel(
            self.ctx
                .session
                .env()
                .batch_config()
                .developer
                .root_stage_channel_size,
        );
        self.notify_stage_scheduled(QueryMessage::Stage(StageEvent::ScheduledRoot(result_rx)))
            .await;

        let executor = ExecutorBuilder::new(
            &plan_node,
            &task_id,
            self.ctx.to_batch_task_context(),
            shutdown_rx.clone(),
        );

        // Second handle to the shutdown token: `shutdown_rx` itself is borrowed by the
        // `cancelled()` future below, so the clone is used for the checks.
        let shutdown_rx0 = shutdown_rx.clone();

        let result = expr_context_scope(expr_context, async {
            let executor = executor.build().await?;
            let chunk_stream = executor.execute();
            let cancelled = pin!(shutdown_rx.cancelled());
            #[for_await]
            for chunk in chunk_stream.take_until(cancelled) {
                if let Err(ref e) = chunk {
                    // An error observed after cancellation is not reported; just stop.
                    if shutdown_rx0.is_cancelled() {
                        break;
                    }
                    let err_str = e.to_report_string();
                    // This is possible if The Query Runner drop early before schedule the root
                    // executor. Detail described in https://github.com/risingwavelabs/risingwave/issues/6883#issuecomment-1348102037.
                    // The error format is just channel closed so no care.
                    if let Err(_e) = result_tx.send(chunk.map_err(|e| e.into())).await {
                        warn!("Root executor has been dropped before receive any events so the send is failed");
                    }
                    // Different from below, return this function and report error.
                    return Err(TaskExecutionError(err_str));
                } else {
                    // Same for below.
                    if let Err(_e) = result_tx.send(chunk.map_err(|e| e.into())).await {
                        warn!("Root executor has been dropped before receive any events so the send is failed");
                    }
                }
            }
            Ok(())
        }).await;

        if let Err(err) = &result {
            // If we encountered error when executing root stage locally, we have to notify the result fetcher, which is
            // returned by `distribute_execute` and being listened by the FE handler task. Otherwise the FE handler cannot
            // properly throw the error to the PG client.
            // NOTE(review): an error yielded by the chunk stream was already sent inside the loop
            // above, so in that case the channel receives the error a second time here. Only a
            // failure of `executor.build()` reaches the channel exclusively through this send.
            if let Err(_e) = result_tx
                .send(Err(TaskExecutionError(err.to_report_string())))
                .await
            {
                warn!("Send task execution failed");
            }
        }

        // Terminated by other tasks execution error, so no need to return error here.
        // NOTE(review): every shutdown message other than `Abort` reaches
        // `notify_stage_completed`, including the case where `result` is an error — confirm
        // this is intended.
        match shutdown_rx0.message() {
            ShutdownMsg::Abort(err_str) => {
                // Tell Query Result Fetcher to stop polling and attach failure reason as str.
                if let Err(_e) = result_tx.send(Err(TaskExecutionError(err_str))).await {
                    warn!("Send task execution failed");
                }
            }
            _ => self.notify_stage_completed().await,
        }

        tracing::trace!(
            "Stage runner [{:?}-{:?}] existed. ",
            self.query.query_id,
            self.stage_id
        );

        // We still have to throw the error in this current task, so that `StageRunner::run` can further
        // send `Failed` event to stop other stages.
        result.map(|_| ())
    }
719
720    async fn schedule_tasks_for_all(&mut self, shutdown_rx: ShutdownToken) -> SchedulerResult<()> {
721        let expr_context = ExprContext {
722            time_zone: self.ctx.session().config().timezone(),
723            strict_mode: self.ctx.session().config().batch_expr_strict_mode(),
724        };
725        // If root, we execute it locally.
726        if !self.is_root_stage() {
727            self.schedule_tasks(shutdown_rx, expr_context).await?;
728        } else {
729            self.schedule_tasks_for_root(shutdown_rx, expr_context)
730                .await?;
731        }
732        Ok(())
733    }
734
735    #[inline(always)]
736    fn get_fragment_id(&self, table_id: TableId) -> SchedulerResult<FragmentId> {
737        self.catalog_reader
738            .read_guard()
739            .get_any_table_by_id(table_id)
740            .map(|table| table.fragment_id)
741            .map_err(|e| SchedulerError::Internal(anyhow!(e)))
742    }
743
744    #[inline(always)]
745    fn get_table_dml_vnode_mapping(&self, table_id: TableId) -> SchedulerResult<WorkerSlotMapping> {
746        let guard = self.catalog_reader.read_guard();
747
748        let table = guard
749            .get_any_table_by_id(table_id)
750            .map_err(|e| SchedulerError::Internal(anyhow!(e)))?;
751
752        let fragment_id = match table.dml_fragment_id.as_ref() {
753            Some(dml_fragment_id) => dml_fragment_id,
754            // Backward compatibility for those table without `dml_fragment_id`.
755            None => &table.fragment_id,
756        };
757
758        self.worker_node_manager
759            .manager
760            .get_streaming_fragment_mapping(fragment_id)
761            .map_err(|e| e.into())
762    }
763
764    fn choose_worker(
765        &self,
766        plan_fragment: &PlanFragment,
767        task_id: u32,
768        dml_table_id: Option<TableId>,
769    ) -> SchedulerResult<Option<WorkerNode>> {
770        let plan_node = plan_fragment.root.as_ref().expect("fail to get plan node");
771
772        if let Some(table_id) = dml_table_id {
773            let vnode_mapping = self.get_table_dml_vnode_mapping(table_id)?;
774            let worker_slot_ids = vnode_mapping.iter_unique().collect_vec();
775            let candidates = self
776                .worker_node_manager
777                .manager
778                .get_workers_by_worker_slot_ids(&worker_slot_ids)?;
779            if candidates.is_empty() {
780                return Err(BatchError::EmptyWorkerNodes.into());
781            }
782            let stage = &self.query.stage(self.stage_id);
783            let candidate = if stage.batch_enable_distributed_dml {
784                // If distributed dml is enabled, we need to try our best to distribute dml tasks evenly to each worker.
785                // Using a `task_id` could be helpful in this case.
786                candidates[task_id as usize % candidates.len()].clone()
787            } else {
788                // If distributed dml is disabled, we need to guarantee that dml from the same session would be sent to a fixed worker/channel to provide a order guarantee.
789                candidates[stage.session_id.0 as usize % candidates.len()].clone()
790            };
791            return Ok(Some(candidate));
792        };
793
794        if let Some(distributed_lookup_join_node) =
795            Self::find_distributed_lookup_join_node(plan_node)
796        {
797            let fragment_id = self.get_fragment_id(
798                distributed_lookup_join_node
799                    .inner_side_table_desc
800                    .as_ref()
801                    .unwrap()
802                    .table_id,
803            )?;
804            let id_to_worker_slots = self
805                .worker_node_manager
806                .fragment_mapping(fragment_id)?
807                .iter_unique()
808                .collect_vec();
809
810            let worker_slot_id = id_to_worker_slots[task_id as usize];
811            let candidates = self
812                .worker_node_manager
813                .manager
814                .get_workers_by_worker_slot_ids(&[worker_slot_id])?;
815            if candidates.is_empty() {
816                return Err(BatchError::EmptyWorkerNodes.into());
817            }
818            Ok(Some(candidates[0].clone()))
819        } else {
820            Ok(None)
821        }
822    }
823
824    fn find_distributed_lookup_join_node(
825        plan_node: &PlanNode,
826    ) -> Option<&DistributedLookupJoinNode> {
827        let node_body = plan_node.node_body.as_ref().expect("fail to get node body");
828
829        match node_body {
830            NodeBody::DistributedLookupJoin(distributed_lookup_join_node) => {
831                Some(distributed_lookup_join_node)
832            }
833            _ => plan_node
834                .children
835                .iter()
836                .find_map(Self::find_distributed_lookup_join_node),
837        }
838    }
839
840    /// Write message into channel to notify query runner current stage have been scheduled.
841    async fn notify_stage_scheduled(&self, msg: QueryMessage) {
842        self.notify_stage_state_changed(
843            |old_state| {
844                assert_matches!(old_state, StageState::Started);
845                StageState::Running
846            },
847            msg,
848        )
849        .await
850    }
851
852    /// Notify query execution that this stage completed.
853    async fn notify_stage_completed(&self) {
854        self.notify_stage_state_changed(
855            |old_state| {
856                assert_matches!(old_state, StageState::Running);
857                StageState::Completed
858            },
859            QueryMessage::Stage(StageEvent::Completed(self.stage_id)),
860        )
861        .await
862    }
863
864    async fn notify_stage_state_changed<F>(&self, new_state: F, msg: QueryMessage)
865    where
866        F: FnOnce(StageState) -> StageState,
867    {
868        {
869            let mut s = self.state.write().await;
870            let old_state = mem::replace(&mut *s, StageState::Failed);
871            *s = new_state(old_state);
872        }
873
874        self.send_event(msg).await;
875    }
876
877    /// Abort all registered tasks. Note that here we do not care which part of tasks has already
878    /// failed or completed, cuz the abort task will not fail if the task has already die.
879    /// See PR (#4560).
880    async fn cancel_all_scheducancled_tasks(&self) -> SchedulerResult<()> {
881        // Set state to failed.
882        // {
883        //     let mut state = self.state.write().await;
884        //     // Ignore if already finished.
885        //     if let &StageState::Completed = &*state {
886        //         return Ok(());
887        //     }
888        //     // FIXME: Be careful for state jump back.
889        //     *state = StageState::Failed
890        // }
891
892        for (task, task_status) in &*self.tasks {
893            // 1. Collect task info and client.
894            let loc = &task_status.get_status().location;
895            let addr = loc.as_ref().expect("Get address should not fail");
896            let client = self
897                .compute_client_pool
898                .get_by_addr(HostAddr::from(addr))
899                .await
900                .map_err(|e| anyhow!(e))?;
901
902            // 2. Send RPC to each compute node for each task asynchronously.
903            let query_id = self.query.query_id.id.clone();
904            let stage_id = self.stage_id;
905            let task_id = *task;
906            spawn(async move {
907                if let Err(e) = client
908                    .cancel(CancelTaskRequest {
909                        task_id: Some(risingwave_pb::batch_plan::TaskId {
910                            query_id: query_id.clone(),
911                            stage_id: stage_id.into(),
912                            task_id,
913                        }),
914                    })
915                    .await
916                {
917                    error!(
918                        error = %e.as_report(),
919                        ?task_id,
920                        ?query_id,
921                        ?stage_id,
922                        "Abort task failed",
923                    );
924                };
925            });
926        }
927        Ok(())
928    }
929
930    async fn schedule_task(
931        &self,
932        task_id: PbTaskId,
933        plan_fragment: PlanFragment,
934        worker: Option<WorkerNode>,
935        expr_context: ExprContext,
936    ) -> SchedulerResult<Fuse<Streaming<TaskInfoResponse>>> {
937        let mut worker = worker.unwrap_or(self.worker_node_manager.next_random_worker()?);
938        let worker_node_addr = worker.host.take().unwrap();
939        let compute_client = self
940            .compute_client_pool
941            .get_by_addr((&worker_node_addr).into())
942            .await
943            .inspect_err(|_| self.mask_failed_serving_worker(&worker))
944            .map_err(|e| anyhow!(e))?;
945
946        let t_id = task_id.task_id;
947
948        let stream_status: Fuse<Streaming<TaskInfoResponse>> = compute_client
949            .create_task(task_id, plan_fragment, expr_context)
950            .await
951            .inspect_err(|_| self.mask_failed_serving_worker(&worker))
952            .map_err(|e| anyhow!(e))?
953            .fuse();
954
955        self.tasks[&t_id].inner.store(Arc::new(TaskStatus {
956            _task_id: t_id,
957            location: Some(worker_node_addr),
958        }));
959
960        Ok(stream_status)
961    }
962
963    fn create_plan_fragment(
964        &self,
965        task_id: TaskId,
966        partition: Option<PartitionInfo>,
967    ) -> PlanFragment {
968        // Used to maintain auto-increment identity_id of a task.
969        let mut identity_id = 0;
970
971        let stage = &self.query.stage(self.stage_id);
972
973        let plan_node_prost =
974            self.convert_plan_node(&stage.root, task_id, partition, &mut identity_id);
975        let exchange_info = stage.exchange_info.clone().unwrap();
976
977        PlanFragment {
978            root: Some(plan_node_prost),
979            exchange_info: Some(exchange_info),
980        }
981    }
982
983    fn convert_plan_node(
984        &self,
985        execution_plan_node: &ExecutionPlanNode,
986        task_id: TaskId,
987        partition: Option<PartitionInfo>,
988        identity_id: &mut u64,
989    ) -> PbPlanNode {
990        // Generate identity
991        let identity = {
992            let identity_type = execution_plan_node.plan_node_type;
993            let id = *identity_id;
994            *identity_id += 1;
995            format!("{:?}-{}", identity_type, id)
996        };
997
998        match execution_plan_node.plan_node_type {
999            BatchPlanNodeType::BatchExchange => {
1000                // Find the stage this exchange node should fetch from and get all exchange sources.
1001                let child_stage = self
1002                    .children
1003                    .iter()
1004                    .find(|child_stage| {
1005                        child_stage.stage_id == execution_plan_node.source_stage_id.unwrap()
1006                    })
1007                    .unwrap();
1008                let exchange_sources = child_stage.all_exchange_sources_for(task_id);
1009
1010                match &execution_plan_node.node {
1011                    NodeBody::Exchange(exchange_node) => PbPlanNode {
1012                        children: vec![],
1013                        identity,
1014                        node_body: Some(NodeBody::Exchange(ExchangeNode {
1015                            sources: exchange_sources,
1016                            sequential: exchange_node.sequential,
1017                            input_schema: execution_plan_node.schema.clone(),
1018                        })),
1019                    },
1020                    NodeBody::MergeSortExchange(sort_merge_exchange_node) => PbPlanNode {
1021                        children: vec![],
1022                        identity,
1023                        node_body: Some(NodeBody::MergeSortExchange(MergeSortExchangeNode {
1024                            exchange: Some(ExchangeNode {
1025                                sources: exchange_sources,
1026                                sequential: false,
1027                                input_schema: execution_plan_node.schema.clone(),
1028                            }),
1029                            column_orders: sort_merge_exchange_node.column_orders.clone(),
1030                        })),
1031                    },
1032                    _ => unreachable!(),
1033                }
1034            }
1035            BatchPlanNodeType::BatchSeqScan => {
1036                let node_body = execution_plan_node.node.clone();
1037                let NodeBody::RowSeqScan(mut scan_node) = node_body else {
1038                    unreachable!();
1039                };
1040                let partition = partition
1041                    .expect("no partition info for seq scan")
1042                    .into_table()
1043                    .expect("PartitionInfo should be TablePartitionInfo");
1044                scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
1045                scan_node.scan_ranges = partition.scan_ranges;
1046                PbPlanNode {
1047                    children: vec![],
1048                    identity,
1049                    node_body: Some(NodeBody::RowSeqScan(scan_node)),
1050                }
1051            }
1052            BatchPlanNodeType::BatchLogSeqScan => {
1053                let node_body = execution_plan_node.node.clone();
1054                let NodeBody::LogRowSeqScan(mut scan_node) = node_body else {
1055                    unreachable!();
1056                };
1057                let partition = partition
1058                    .expect("no partition info for seq scan")
1059                    .into_table()
1060                    .expect("PartitionInfo should be TablePartitionInfo");
1061                scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
1062                PbPlanNode {
1063                    children: vec![],
1064                    identity,
1065                    node_body: Some(NodeBody::LogRowSeqScan(scan_node)),
1066                }
1067            }
1068            BatchPlanNodeType::BatchSource | BatchPlanNodeType::BatchKafkaScan => {
1069                let node_body = execution_plan_node.node.clone();
1070                let NodeBody::Source(mut source_node) = node_body else {
1071                    unreachable!();
1072                };
1073
1074                let partition = partition
1075                    .expect("no partition info for seq scan")
1076                    .into_source()
1077                    .expect("PartitionInfo should be SourcePartitionInfo");
1078                source_node.split = partition
1079                    .into_iter()
1080                    .map(|split| split.encode_to_bytes().into())
1081                    .collect_vec();
1082                PbPlanNode {
1083                    children: vec![],
1084                    identity,
1085                    node_body: Some(NodeBody::Source(source_node)),
1086                }
1087            }
1088            BatchPlanNodeType::BatchIcebergScan => {
1089                let node_body = execution_plan_node.node.clone();
1090                let NodeBody::IcebergScan(mut iceberg_scan_node) = node_body else {
1091                    unreachable!();
1092                };
1093
1094                let partition = partition
1095                    .expect("no partition info for seq scan")
1096                    .into_source()
1097                    .expect("PartitionInfo should be SourcePartitionInfo");
1098                iceberg_scan_node.split = partition
1099                    .into_iter()
1100                    .map(|split| split.encode_to_bytes().into())
1101                    .collect_vec();
1102                PbPlanNode {
1103                    children: vec![],
1104                    identity,
1105                    node_body: Some(NodeBody::IcebergScan(iceberg_scan_node)),
1106                }
1107            }
1108            _ => {
1109                let children = execution_plan_node
1110                    .children
1111                    .iter()
1112                    .map(|e| self.convert_plan_node(e, task_id, partition.clone(), identity_id))
1113                    .collect();
1114
1115                PbPlanNode {
1116                    children,
1117                    identity,
1118                    node_body: Some(execution_plan_node.node.clone()),
1119                }
1120            }
1121        }
1122    }
1123
1124    fn is_root_stage(&self) -> bool {
1125        self.stage_id == 0.into()
1126    }
1127
1128    fn mask_failed_serving_worker(&self, worker: &WorkerNode) {
1129        if !worker.property.as_ref().is_some_and(|p| p.is_serving) {
1130            return;
1131        }
1132        let duration = Duration::from_secs(std::cmp::max(
1133            self.ctx
1134                .session
1135                .env()
1136                .batch_config()
1137                .mask_worker_temporary_secs as u64,
1138            1,
1139        ));
1140        self.worker_node_manager
1141            .manager
1142            .mask_worker_node(worker.id, duration);
1143    }
1144}