risingwave_frontend/scheduler/distributed/
stage.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::assert_matches::assert_matches;
16use std::cell::RefCell;
17use std::collections::HashMap;
18use std::mem;
19use std::pin::pin;
20use std::rc::Rc;
21use std::sync::Arc;
22use std::time::Duration;
23
24use StageEvent::Failed;
25use anyhow::anyhow;
26use arc_swap::ArcSwap;
27use futures::stream::Fuse;
28use futures::{StreamExt, TryStreamExt, stream};
29use futures_async_stream::for_await;
30use itertools::Itertools;
31use risingwave_batch::error::BatchError;
32use risingwave_batch::executor::ExecutorBuilder;
33use risingwave_batch::task::{ShutdownMsg, ShutdownSender, ShutdownToken, TaskId as TaskIdBatch};
34use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
35use risingwave_common::array::DataChunk;
36use risingwave_common::hash::WorkerSlotMapping;
37use risingwave_common::util::addr::HostAddr;
38use risingwave_common::util::iter_util::ZipEqFast;
39use risingwave_connector::source::SplitMetaData;
40use risingwave_expr::expr_context::expr_context_scope;
41use risingwave_pb::batch_plan::plan_node::NodeBody;
42use risingwave_pb::batch_plan::{
43    DistributedLookupJoinNode, ExchangeNode, ExchangeSource, MergeSortExchangeNode, PlanFragment,
44    PlanNode as PbPlanNode, PlanNode, TaskId as PbTaskId, TaskOutputId,
45};
46use risingwave_pb::common::{BatchQueryEpoch, HostAddress, WorkerNode};
47use risingwave_pb::plan_common::ExprContext;
48use risingwave_pb::task_service::{CancelTaskRequest, TaskInfoResponse};
49use risingwave_rpc_client::ComputeClientPoolRef;
50use risingwave_rpc_client::error::RpcError;
51use rw_futures_util::select_all;
52use thiserror_ext::AsReport;
53use tokio::spawn;
54use tokio::sync::RwLock;
55use tokio::sync::mpsc::{Receiver, Sender};
56use tonic::Streaming;
57use tracing::{Instrument, debug, error, warn};
58
59use crate::catalog::catalog_service::CatalogReader;
60use crate::catalog::{FragmentId, TableId};
61use crate::optimizer::plan_node::PlanNodeType;
62use crate::scheduler::SchedulerError::{TaskExecutionError, TaskRunningOutOfMemory};
63use crate::scheduler::distributed::QueryMessage;
64use crate::scheduler::distributed::stage::StageState::Pending;
65use crate::scheduler::plan_fragmenter::{
66    ExecutionPlanNode, PartitionInfo, QueryStageRef, ROOT_TASK_ID, StageId, TaskId,
67};
68use crate::scheduler::{ExecutionContextRef, SchedulerError, SchedulerResult};
69
/// Maximum number of task-scheduling futures polled concurrently by `StageRunner::schedule_tasks`
/// (passed to `buffer_unordered`).
const TASK_SCHEDULING_PARALLELISM: usize = 10;
71
/// Lifecycle state of a stage, shared (behind an `Arc<RwLock<_>>`) between a
/// [`StageExecution`] and the [`StageRunner`] it spawns.
#[derive(Debug)]
enum StageState {
    /// We put `msg_sender` in `Pending` state to avoid holding it in `StageExecution`. In this
    /// way, it could be efficiently moved into `StageRunner` instead of being cloned. This also
    /// ensures that the sender can get dropped once it is used up, preventing some issues caused
    /// by unnecessarily long lifetime.
    Pending {
        msg_sender: Sender<QueryMessage>,
    },
    /// Set by `StageExecution::start` right before the runner is spawned.
    Started,
    /// NOTE(review): not assigned in the visible code; presumably set once all tasks are
    /// scheduled (`notify_stage_scheduled`) — confirm.
    Running,
    /// NOTE(review): presumably set by `notify_stage_completed` — confirm.
    Completed,
    /// Set when a task or RPC fails (via `notify_stage_state_changed`), and used as the
    /// placeholder while `start` moves `msg_sender` out of `Pending`.
    Failed,
}
86
/// Events a stage runner reports to the `QueryRunner` (wrapped in `QueryMessage::Stage`).
#[derive(Debug)]
pub enum StageEvent {
    /// Every task of the stage reported `Running`; the next stage may be scheduled.
    Scheduled(StageId),
    /// The root stage is executing locally; carries the receiver of its result chunks.
    ScheduledRoot(Receiver<SchedulerResult<DataChunk>>),
    /// Stage failed.
    Failed {
        id: StageId,
        reason: SchedulerError,
    },
    /// All tasks in stage finished.
    Completed(#[allow(dead_code)] StageId),
}
99
/// Scheduling status of a single task in a stage.
#[derive(Clone)]
pub struct TaskStatus {
    _task_id: TaskId,

    // None before task is scheduled.
    location: Option<HostAddress>,
}
107
/// Holds a task's [`TaskStatus`] in an `ArcSwap` so that readers can load a snapshot
/// without locking.
struct TaskStatusHolder {
    inner: ArcSwap<TaskStatus>,
}
111
/// Handle the query scheduler keeps for one stage: it owns the shared state and task table,
/// spawns the [`StageRunner`] on `start`, and signals it on `stop`.
pub struct StageExecution {
    epoch: BatchQueryEpoch,
    stage: QueryStageRef,
    worker_node_manager: WorkerNodeSelector,
    /// One status holder per task, keyed by task id (`0..parallelism`, built in `new`).
    tasks: Arc<HashMap<TaskId, TaskStatusHolder>>,
    /// Stage lifecycle state, shared with the runner.
    state: Arc<RwLock<StageState>>,
    /// Installed by `start` and taken by `stop`; `None` before start and after stop.
    shutdown_tx: RwLock<Option<ShutdownSender>>,
    /// Children stage executions.
    ///
    /// We use `Vec` here since children's size is usually small.
    children: Vec<Arc<StageExecution>>,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,

    /// Execution context ref
    ctx: ExecutionContextRef,
}
129
/// Task spawned by `StageExecution::start` that schedules the stage's tasks and tracks their
/// statuses until the stage completes, fails, or is shut down.
struct StageRunner {
    epoch: BatchQueryEpoch,
    /// Same state object as the owning `StageExecution::state`.
    state: Arc<RwLock<StageState>>,
    stage: QueryStageRef,
    worker_node_manager: WorkerNodeSelector,
    /// Same task table as the owning `StageExecution::tasks`.
    tasks: Arc<HashMap<TaskId, TaskStatusHolder>>,
    // Send message to `QueryRunner` to notify stage state change.
    msg_sender: Sender<QueryMessage>,
    children: Vec<Arc<StageExecution>>,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,

    ctx: ExecutionContextRef,
}
144
145impl TaskStatusHolder {
146    fn new(task_id: TaskId) -> Self {
147        let task_status = TaskStatus {
148            _task_id: task_id,
149            location: None,
150        };
151
152        Self {
153            inner: ArcSwap::new(Arc::new(task_status)),
154        }
155    }
156
157    fn get_status(&self) -> Arc<TaskStatus> {
158        self.inner.load_full()
159    }
160}
161
162impl StageExecution {
163    #[allow(clippy::too_many_arguments)]
164    pub fn new(
165        epoch: BatchQueryEpoch,
166        stage: QueryStageRef,
167        worker_node_manager: WorkerNodeSelector,
168        msg_sender: Sender<QueryMessage>,
169        children: Vec<Arc<StageExecution>>,
170        compute_client_pool: ComputeClientPoolRef,
171        catalog_reader: CatalogReader,
172        ctx: ExecutionContextRef,
173    ) -> Self {
174        let tasks = (0..stage.parallelism.unwrap())
175            .map(|task_id| (task_id as u64, TaskStatusHolder::new(task_id as u64)))
176            .collect();
177
178        Self {
179            epoch,
180            stage,
181            worker_node_manager,
182            tasks: Arc::new(tasks),
183            state: Arc::new(RwLock::new(Pending { msg_sender })),
184            shutdown_tx: RwLock::new(None),
185            children,
186            compute_client_pool,
187            catalog_reader,
188            ctx,
189        }
190    }
191
192    /// Starts execution of this stage, returns error if already started.
193    pub async fn start(&self) {
194        let mut s = self.state.write().await;
195        let cur_state = mem::replace(&mut *s, StageState::Failed);
196        match cur_state {
197            Pending { msg_sender } => {
198                let runner = StageRunner {
199                    epoch: self.epoch,
200                    stage: self.stage.clone(),
201                    worker_node_manager: self.worker_node_manager.clone(),
202                    tasks: self.tasks.clone(),
203                    msg_sender,
204                    children: self.children.clone(),
205                    state: self.state.clone(),
206                    compute_client_pool: self.compute_client_pool.clone(),
207                    catalog_reader: self.catalog_reader.clone(),
208                    ctx: self.ctx.clone(),
209                };
210
211                // The channel used for shutdown signal messaging.
212                let (sender, receiver) = ShutdownToken::new();
213                // Fill the shutdown sender.
214                let mut holder = self.shutdown_tx.write().await;
215                *holder = Some(sender);
216
217                // Change state before spawn runner.
218                *s = StageState::Started;
219
220                let span = tracing::info_span!(
221                    "stage",
222                    "otel.name" = format!("Stage {}-{}", self.stage.query_id.id, self.stage.id),
223                    query_id = self.stage.query_id.id,
224                    stage_id = self.stage.id,
225                );
226                self.ctx
227                    .session()
228                    .env()
229                    .compute_runtime()
230                    .spawn(async move { runner.run(receiver).instrument(span).await });
231
232                tracing::trace!(
233                    "Stage {:?}-{:?} started.",
234                    self.stage.query_id.id,
235                    self.stage.id
236                )
237            }
238            _ => {
239                unreachable!("Only expect to schedule stage once");
240            }
241        }
242    }
243
244    pub async fn stop(&self, error: Option<String>) {
245        // Send message to tell Stage Runner stop.
246        if let Some(shutdown_tx) = self.shutdown_tx.write().await.take() {
247            // It's possible that the stage has not been scheduled, so the channel sender is
248            // None.
249
250            if !if let Some(error) = error {
251                shutdown_tx.abort(error)
252            } else {
253                shutdown_tx.cancel()
254            } {
255                // The stage runner handle has already closed. so do no-op.
256                tracing::trace!(
257                    "Failed to send stop message stage: {:?}-{:?}",
258                    self.stage.query_id,
259                    self.stage.id
260                );
261            }
262        }
263    }
264
265    pub async fn is_scheduled(&self) -> bool {
266        let s = self.state.read().await;
267        matches!(*s, StageState::Running | StageState::Completed)
268    }
269
270    pub async fn is_pending(&self) -> bool {
271        let s = self.state.read().await;
272        matches!(*s, StageState::Pending { .. })
273    }
274
275    pub async fn state(&self) -> &'static str {
276        let s = self.state.read().await;
277        match *s {
278            Pending { .. } => "Pending",
279            StageState::Started => "Started",
280            StageState::Running => "Running",
281            StageState::Completed => "Completed",
282            StageState::Failed => "Failed",
283        }
284    }
285
286    /// Returns all exchange sources for `output_id`. Each `ExchangeSource` is identified by
287    /// producer's `TaskId` and `output_id` (consumer's `TaskId`), since each task may produce
288    /// output to several channels.
289    ///
290    /// When this method is called, all tasks should have been scheduled, and their `worker_node`
291    /// should have been set.
292    pub fn all_exchange_sources_for(&self, output_id: u64) -> Vec<ExchangeSource> {
293        self.tasks
294            .iter()
295            .map(|(task_id, status_holder)| {
296                let task_output_id = TaskOutputId {
297                    task_id: Some(PbTaskId {
298                        query_id: self.stage.query_id.id.clone(),
299                        stage_id: self.stage.id,
300                        task_id: *task_id,
301                    }),
302                    output_id,
303                };
304
305                ExchangeSource {
306                    task_output_id: Some(task_output_id),
307                    host: Some(status_holder.inner.load_full().location.clone().unwrap()),
308                    local_execute_plan: None,
309                }
310            })
311            .collect()
312    }
313}
314
315impl StageRunner {
316    async fn run(mut self, shutdown_rx: ShutdownToken) {
317        if let Err(e) = self.schedule_tasks_for_all(shutdown_rx).await {
318            error!(
319                error = %e.as_report(),
320                query_id = ?self.stage.query_id,
321                stage_id = ?self.stage.id,
322                "Failed to schedule tasks"
323            );
324            self.send_event(QueryMessage::Stage(Failed {
325                id: self.stage.id,
326                reason: e,
327            }))
328            .await;
329        }
330    }
331
332    /// Send stage event to listener.
333    async fn send_event(&self, event: QueryMessage) {
334        if let Err(_e) = self.msg_sender.send(event).await {
335            warn!("Failed to send event to Query Runner, may be killed by previous failed event");
336        }
337    }
338
339    /// Schedule all tasks to CN and wait process all status messages from RPC. Note that when all
340    /// task is created, it should tell `QueryRunner` to schedule next.
341    async fn schedule_tasks(
342        &mut self,
343        mut shutdown_rx: ShutdownToken,
344        expr_context: ExprContext,
345    ) -> SchedulerResult<()> {
346        let mut futures = vec![];
347
348        if let Some(table_scan_info) = self.stage.table_scan_info.as_ref()
349            && let Some(vnode_bitmaps) = table_scan_info.partitions()
350        {
351            // If the stage has table scan nodes, we create tasks according to the data distribution
352            // and partition of the table.
353            // We let each task read one partition by setting the `vnode_ranges` of the scan node in
354            // the task.
355            // We schedule the task to the worker node that owns the data partition.
356            let worker_slot_ids = vnode_bitmaps.keys().cloned().collect_vec();
357            let workers = self
358                .worker_node_manager
359                .manager
360                .get_workers_by_worker_slot_ids(&worker_slot_ids)?;
361
362            for (i, (worker_slot_id, worker)) in worker_slot_ids
363                .into_iter()
364                .zip_eq_fast(workers.into_iter())
365                .enumerate()
366            {
367                let task_id = PbTaskId {
368                    query_id: self.stage.query_id.id.clone(),
369                    stage_id: self.stage.id,
370                    task_id: i as u64,
371                };
372                let vnode_ranges = vnode_bitmaps[&worker_slot_id].clone();
373                let plan_fragment =
374                    self.create_plan_fragment(i as u64, Some(PartitionInfo::Table(vnode_ranges)));
375                futures.push(self.schedule_task(
376                    task_id,
377                    plan_fragment,
378                    Some(worker),
379                    expr_context.clone(),
380                ));
381            }
382        } else if let Some(source_info) = self.stage.source_info.as_ref() {
383            // If there is no file in source, the `chunk_size` is set to 1.
384            let chunk_size = ((source_info.split_info().unwrap().len() as f32
385                / self.stage.parallelism.unwrap() as f32)
386                .ceil() as usize)
387                .max(1);
388            if source_info.split_info().unwrap().is_empty() {
389                // No file in source, schedule an empty task.
390                const EMPTY_TASK_ID: u64 = 0;
391                let task_id = PbTaskId {
392                    query_id: self.stage.query_id.id.clone(),
393                    stage_id: self.stage.id,
394                    task_id: EMPTY_TASK_ID,
395                };
396                let plan_fragment =
397                    self.create_plan_fragment(EMPTY_TASK_ID, Some(PartitionInfo::Source(vec![])));
398                let worker = self.choose_worker(
399                    &plan_fragment,
400                    EMPTY_TASK_ID as u32,
401                    self.stage.dml_table_id,
402                )?;
403                futures.push(self.schedule_task(
404                    task_id,
405                    plan_fragment,
406                    worker,
407                    expr_context.clone(),
408                ));
409            } else {
410                for (id, split) in source_info
411                    .split_info()
412                    .unwrap()
413                    .chunks(chunk_size)
414                    .enumerate()
415                {
416                    let task_id = PbTaskId {
417                        query_id: self.stage.query_id.id.clone(),
418                        stage_id: self.stage.id,
419                        task_id: id as u64,
420                    };
421                    let plan_fragment = self.create_plan_fragment(
422                        id as u64,
423                        Some(PartitionInfo::Source(split.to_vec())),
424                    );
425                    let worker =
426                        self.choose_worker(&plan_fragment, id as u32, self.stage.dml_table_id)?;
427                    futures.push(self.schedule_task(
428                        task_id,
429                        plan_fragment,
430                        worker,
431                        expr_context.clone(),
432                    ));
433                }
434            }
435        } else if let Some(file_scan_info) = self.stage.file_scan_info.as_ref() {
436            let chunk_size = (file_scan_info.file_location.len() as f32
437                / self.stage.parallelism.unwrap() as f32)
438                .ceil() as usize;
439            for (id, files) in file_scan_info.file_location.chunks(chunk_size).enumerate() {
440                let task_id = PbTaskId {
441                    query_id: self.stage.query_id.id.clone(),
442                    stage_id: self.stage.id,
443                    task_id: id as u64,
444                };
445                let plan_fragment =
446                    self.create_plan_fragment(id as u64, Some(PartitionInfo::File(files.to_vec())));
447                let worker =
448                    self.choose_worker(&plan_fragment, id as u32, self.stage.dml_table_id)?;
449                futures.push(self.schedule_task(
450                    task_id,
451                    plan_fragment,
452                    worker,
453                    expr_context.clone(),
454                ));
455            }
456        } else {
457            for id in 0..self.stage.parallelism.unwrap() {
458                let task_id = PbTaskId {
459                    query_id: self.stage.query_id.id.clone(),
460                    stage_id: self.stage.id,
461                    task_id: id as u64,
462                };
463                let plan_fragment = self.create_plan_fragment(id as u64, None);
464                let worker = self.choose_worker(&plan_fragment, id, self.stage.dml_table_id)?;
465                futures.push(self.schedule_task(
466                    task_id,
467                    plan_fragment,
468                    worker,
469                    expr_context.clone(),
470                ));
471            }
472        }
473
474        // Await each future and convert them into a set of streams.
475        let buffered = stream::iter(futures).buffer_unordered(TASK_SCHEDULING_PARALLELISM);
476        let buffered_streams = buffered.try_collect::<Vec<_>>().await?;
477
478        // Merge different task streams into a single stream.
479        let cancelled = pin!(shutdown_rx.cancelled());
480        let mut all_streams = select_all(buffered_streams).take_until(cancelled);
481
482        // Process the stream until finished.
483        let mut running_task_cnt = 0;
484        let mut finished_task_cnt = 0;
485        let mut sent_signal_to_next = false;
486
487        while let Some(status_res_inner) = all_streams.next().await {
488            match status_res_inner {
489                Ok(status) => {
490                    use risingwave_pb::task_service::task_info_response::TaskStatus as PbTaskStatus;
491                    match PbTaskStatus::try_from(status.task_status).unwrap() {
492                        PbTaskStatus::Running => {
493                            running_task_cnt += 1;
494                            // The task running count should always less or equal than the
495                            // registered tasks number.
496                            assert!(running_task_cnt <= self.tasks.keys().len());
497                            // All tasks in this stage have been scheduled. Notify query runner to
498                            // schedule next stage.
499                            if running_task_cnt == self.tasks.keys().len() {
500                                self.notify_stage_scheduled(QueryMessage::Stage(
501                                    StageEvent::Scheduled(self.stage.id),
502                                ))
503                                .await;
504                                sent_signal_to_next = true;
505                            }
506                        }
507
508                        PbTaskStatus::Finished => {
509                            finished_task_cnt += 1;
510                            assert!(finished_task_cnt <= self.tasks.keys().len());
511                            assert!(running_task_cnt >= finished_task_cnt);
512                            if finished_task_cnt == self.tasks.keys().len() {
513                                // All tasks finished without failure, we should not break
514                                // this loop
515                                self.notify_stage_completed().await;
516                                sent_signal_to_next = true;
517                                break;
518                            }
519                        }
520                        PbTaskStatus::Aborted => {
521                            // Currently, the only reason that we receive an abort status is that
522                            // the task's memory usage is too high so
523                            // it's aborted.
524                            error!(
525                                "Abort task {:?} because of excessive memory usage. Please try again later.",
526                                status.task_id.unwrap()
527                            );
528                            self.notify_stage_state_changed(
529                                |_| StageState::Failed,
530                                QueryMessage::Stage(Failed {
531                                    id: self.stage.id,
532                                    reason: TaskRunningOutOfMemory,
533                                }),
534                            )
535                            .await;
536                            sent_signal_to_next = true;
537                            break;
538                        }
539                        PbTaskStatus::Failed => {
540                            // Task failed, we should fail whole query
541                            error!(
542                                "Task {:?} failed, reason: {:?}",
543                                status.task_id.unwrap(),
544                                status.error_message,
545                            );
546                            self.notify_stage_state_changed(
547                                |_| StageState::Failed,
548                                QueryMessage::Stage(Failed {
549                                    id: self.stage.id,
550                                    reason: TaskExecutionError(status.error_message),
551                                }),
552                            )
553                            .await;
554                            sent_signal_to_next = true;
555                            break;
556                        }
557                        PbTaskStatus::Ping => {
558                            debug!("Receive ping from task {:?}", status.task_id.unwrap());
559                        }
560                        status => {
561                            // The remain possible variant is Failed, but now they won't be pushed
562                            // from CN.
563                            unreachable!("Unexpected task status {:?}", status);
564                        }
565                    }
566                }
567                Err(e) => {
568                    // rpc error here, we should also notify stage failure
569                    error!(
570                        "Fetching task status in stage {:?} failed, reason: {:?}",
571                        self.stage.id,
572                        e.message()
573                    );
574                    self.notify_stage_state_changed(
575                        |_| StageState::Failed,
576                        QueryMessage::Stage(Failed {
577                            id: self.stage.id,
578                            reason: RpcError::from_batch_status(e).into(),
579                        }),
580                    )
581                    .await;
582                    sent_signal_to_next = true;
583                    break;
584                }
585            }
586        }
587
588        tracing::trace!(
589            "Stage [{:?}-{:?}], running task count: {}, finished task count: {}, sent signal to next: {}",
590            self.stage.query_id,
591            self.stage.id,
592            running_task_cnt,
593            finished_task_cnt,
594            sent_signal_to_next,
595        );
596
597        if let Some(shutdown) = all_streams.take_future() {
598            tracing::trace!(
599                "Stage [{:?}-{:?}] waiting for stopping signal.",
600                self.stage.query_id,
601                self.stage.id
602            );
603            // Waiting for shutdown signal.
604            shutdown.await;
605        }
606
607        // Received shutdown signal from query runner, should send abort RPC to all CNs.
608        // change state to aborted. Note that the task cancel can only happen after schedule
609        // all these tasks to CN. This can be an optimization for future:
610        // How to stop before schedule tasks.
611        tracing::trace!(
612            "Stopping stage: {:?}-{:?}, task_num: {}",
613            self.stage.query_id,
614            self.stage.id,
615            self.tasks.len()
616        );
617        self.cancel_all_scheducancled_tasks().await?;
618
619        tracing::trace!(
620            "Stage runner [{:?}-{:?}] exited.",
621            self.stage.query_id,
622            self.stage.id
623        );
624        Ok(())
625    }
626
627    async fn schedule_tasks_for_root(
628        &mut self,
629        mut shutdown_rx: ShutdownToken,
630        expr_context: ExprContext,
631    ) -> SchedulerResult<()> {
632        let root_stage_id = self.stage.id;
633        // Currently, the dml or table scan should never be root fragment, so the partition is None.
634        // And root fragment only contain one task.
635        let plan_fragment = self.create_plan_fragment(ROOT_TASK_ID, None);
636        let plan_node = plan_fragment.root.unwrap();
637        let task_id = TaskIdBatch {
638            query_id: self.stage.query_id.id.clone(),
639            stage_id: root_stage_id,
640            task_id: 0,
641        };
642
643        // Notify QueryRunner to poll chunk from result_rx.
644        let (result_tx, result_rx) = tokio::sync::mpsc::channel(
645            self.ctx
646                .session
647                .env()
648                .batch_config()
649                .developer
650                .root_stage_channel_size,
651        );
652        self.notify_stage_scheduled(QueryMessage::Stage(StageEvent::ScheduledRoot(result_rx)))
653            .await;
654
655        let executor = ExecutorBuilder::new(
656            &plan_node,
657            &task_id,
658            self.ctx.to_batch_task_context(),
659            self.epoch,
660            shutdown_rx.clone(),
661        );
662
663        let shutdown_rx0 = shutdown_rx.clone();
664
665        let result = expr_context_scope(expr_context, async {
666            let executor = executor.build().await?;
667            let chunk_stream = executor.execute();
668            let cancelled = pin!(shutdown_rx.cancelled());
669            #[for_await]
670            for chunk in chunk_stream.take_until(cancelled) {
671                if let Err(ref e) = chunk {
672                    if shutdown_rx0.is_cancelled() {
673                        break;
674                    }
675                    let err_str = e.to_report_string();
676                    // This is possible if The Query Runner drop early before schedule the root
677                    // executor. Detail described in https://github.com/risingwavelabs/risingwave/issues/6883#issuecomment-1348102037.
678                    // The error format is just channel closed so no care.
679                    if let Err(_e) = result_tx.send(chunk.map_err(|e| e.into())).await {
680                        warn!("Root executor has been dropped before receive any events so the send is failed");
681                    }
682                    // Different from below, return this function and report error.
683                    return Err(TaskExecutionError(err_str));
684                } else {
685                    // Same for below.
686                    if let Err(_e) = result_tx.send(chunk.map_err(|e| e.into())).await {
687                        warn!("Root executor has been dropped before receive any events so the send is failed");
688                    }
689                }
690            }
691            Ok(())
692        }).await;
693
694        if let Err(err) = &result {
695            // If we encountered error when executing root stage locally, we have to notify the result fetcher, which is
696            // returned by `distribute_execute` and being listened by the FE handler task. Otherwise the FE handler cannot
697            // properly throw the error to the PG client.
698            if let Err(_e) = result_tx
699                .send(Err(TaskExecutionError(err.to_report_string())))
700                .await
701            {
702                warn!("Send task execution failed");
703            }
704        }
705
706        // Terminated by other tasks execution error, so no need to return error here.
707        match shutdown_rx0.message() {
708            ShutdownMsg::Abort(err_str) => {
709                // Tell Query Result Fetcher to stop polling and attach failure reason as str.
710                if let Err(_e) = result_tx.send(Err(TaskExecutionError(err_str))).await {
711                    warn!("Send task execution failed");
712                }
713            }
714            _ => self.notify_stage_completed().await,
715        }
716
717        tracing::trace!(
718            "Stage runner [{:?}-{:?}] existed. ",
719            self.stage.query_id,
720            self.stage.id
721        );
722
723        // We still have to throw the error in this current task, so that `StageRunner::run` can further
724        // send `Failed` event to stop other stages.
725        result.map(|_| ())
726    }
727
728    async fn schedule_tasks_for_all(&mut self, shutdown_rx: ShutdownToken) -> SchedulerResult<()> {
729        let expr_context = ExprContext {
730            time_zone: self.ctx.session().config().timezone().to_owned(),
731            strict_mode: self.ctx.session().config().batch_expr_strict_mode(),
732        };
733        // If root, we execute it locally.
734        if !self.is_root_stage() {
735            self.schedule_tasks(shutdown_rx, expr_context).await?;
736        } else {
737            self.schedule_tasks_for_root(shutdown_rx, expr_context)
738                .await?;
739        }
740        Ok(())
741    }
742
743    #[inline(always)]
744    fn get_fragment_id(&self, table_id: &TableId) -> SchedulerResult<FragmentId> {
745        self.catalog_reader
746            .read_guard()
747            .get_any_table_by_id(table_id)
748            .map(|table| table.fragment_id)
749            .map_err(|e| SchedulerError::Internal(anyhow!(e)))
750    }
751
752    #[inline(always)]
753    fn get_table_dml_vnode_mapping(
754        &self,
755        table_id: &TableId,
756    ) -> SchedulerResult<WorkerSlotMapping> {
757        let guard = self.catalog_reader.read_guard();
758
759        let table = guard
760            .get_any_table_by_id(table_id)
761            .map_err(|e| SchedulerError::Internal(anyhow!(e)))?;
762
763        let fragment_id = match table.dml_fragment_id.as_ref() {
764            Some(dml_fragment_id) => dml_fragment_id,
765            // Backward compatibility for those table without `dml_fragment_id`.
766            None => &table.fragment_id,
767        };
768
769        self.worker_node_manager
770            .manager
771            .get_streaming_fragment_mapping(fragment_id)
772            .map_err(|e| e.into())
773    }
774
775    fn choose_worker(
776        &self,
777        plan_fragment: &PlanFragment,
778        task_id: u32,
779        dml_table_id: Option<TableId>,
780    ) -> SchedulerResult<Option<WorkerNode>> {
781        let plan_node = plan_fragment.root.as_ref().expect("fail to get plan node");
782
783        if let Some(table_id) = dml_table_id {
784            let vnode_mapping = self.get_table_dml_vnode_mapping(&table_id)?;
785            let worker_slot_ids = vnode_mapping.iter_unique().collect_vec();
786            let candidates = self
787                .worker_node_manager
788                .manager
789                .get_workers_by_worker_slot_ids(&worker_slot_ids)?;
790            if candidates.is_empty() {
791                return Err(BatchError::EmptyWorkerNodes.into());
792            }
793            let candidate = if self.stage.batch_enable_distributed_dml {
794                // If distributed dml is enabled, we need to try our best to distribute dml tasks evenly to each worker.
795                // Using a `task_id` could be helpful in this case.
796                candidates[task_id as usize % candidates.len()].clone()
797            } else {
798                // If distributed dml is disabled, we need to guarantee that dml from the same session would be sent to a fixed worker/channel to provide a order guarantee.
799                candidates[self.stage.session_id.0 as usize % candidates.len()].clone()
800            };
801            return Ok(Some(candidate));
802        };
803
804        if let Some(distributed_lookup_join_node) =
805            Self::find_distributed_lookup_join_node(plan_node)
806        {
807            let fragment_id = self.get_fragment_id(
808                &distributed_lookup_join_node
809                    .inner_side_table_desc
810                    .as_ref()
811                    .unwrap()
812                    .table_id
813                    .into(),
814            )?;
815            let id_to_worker_slots = self
816                .worker_node_manager
817                .fragment_mapping(fragment_id)?
818                .iter_unique()
819                .collect_vec();
820
821            let worker_slot_id = id_to_worker_slots[task_id as usize];
822            let candidates = self
823                .worker_node_manager
824                .manager
825                .get_workers_by_worker_slot_ids(&[worker_slot_id])?;
826            if candidates.is_empty() {
827                return Err(BatchError::EmptyWorkerNodes.into());
828            }
829            Ok(Some(candidates[0].clone()))
830        } else {
831            Ok(None)
832        }
833    }
834
835    fn find_distributed_lookup_join_node(
836        plan_node: &PlanNode,
837    ) -> Option<&DistributedLookupJoinNode> {
838        let node_body = plan_node.node_body.as_ref().expect("fail to get node body");
839
840        match node_body {
841            NodeBody::DistributedLookupJoin(distributed_lookup_join_node) => {
842                Some(distributed_lookup_join_node)
843            }
844            _ => plan_node
845                .children
846                .iter()
847                .find_map(Self::find_distributed_lookup_join_node),
848        }
849    }
850
851    /// Write message into channel to notify query runner current stage have been scheduled.
852    async fn notify_stage_scheduled(&self, msg: QueryMessage) {
853        self.notify_stage_state_changed(
854            |old_state| {
855                assert_matches!(old_state, StageState::Started);
856                StageState::Running
857            },
858            msg,
859        )
860        .await
861    }
862
863    /// Notify query execution that this stage completed.
864    async fn notify_stage_completed(&self) {
865        self.notify_stage_state_changed(
866            |old_state| {
867                assert_matches!(old_state, StageState::Running);
868                StageState::Completed
869            },
870            QueryMessage::Stage(StageEvent::Completed(self.stage.id)),
871        )
872        .await
873    }
874
875    async fn notify_stage_state_changed<F>(&self, new_state: F, msg: QueryMessage)
876    where
877        F: FnOnce(StageState) -> StageState,
878    {
879        {
880            let mut s = self.state.write().await;
881            let old_state = mem::replace(&mut *s, StageState::Failed);
882            *s = new_state(old_state);
883        }
884
885        self.send_event(msg).await;
886    }
887
888    /// Abort all registered tasks. Note that here we do not care which part of tasks has already
889    /// failed or completed, cuz the abort task will not fail if the task has already die.
890    /// See PR (#4560).
891    async fn cancel_all_scheducancled_tasks(&self) -> SchedulerResult<()> {
892        // Set state to failed.
893        // {
894        //     let mut state = self.state.write().await;
895        //     // Ignore if already finished.
896        //     if let &StageState::Completed = &*state {
897        //         return Ok(());
898        //     }
899        //     // FIXME: Be careful for state jump back.
900        //     *state = StageState::Failed
901        // }
902
903        for (task, task_status) in &*self.tasks {
904            // 1. Collect task info and client.
905            let loc = &task_status.get_status().location;
906            let addr = loc.as_ref().expect("Get address should not fail");
907            let client = self
908                .compute_client_pool
909                .get_by_addr(HostAddr::from(addr))
910                .await
911                .map_err(|e| anyhow!(e))?;
912
913            // 2. Send RPC to each compute node for each task asynchronously.
914            let query_id = self.stage.query_id.id.clone();
915            let stage_id = self.stage.id;
916            let task_id = *task;
917            spawn(async move {
918                if let Err(e) = client
919                    .cancel(CancelTaskRequest {
920                        task_id: Some(risingwave_pb::batch_plan::TaskId {
921                            query_id: query_id.clone(),
922                            stage_id,
923                            task_id,
924                        }),
925                    })
926                    .await
927                {
928                    error!(
929                        error = %e.as_report(),
930                        ?task_id,
931                        ?query_id,
932                        ?stage_id,
933                        "Abort task failed",
934                    );
935                };
936            });
937        }
938        Ok(())
939    }
940
941    async fn schedule_task(
942        &self,
943        task_id: PbTaskId,
944        plan_fragment: PlanFragment,
945        worker: Option<WorkerNode>,
946        expr_context: ExprContext,
947    ) -> SchedulerResult<Fuse<Streaming<TaskInfoResponse>>> {
948        let mut worker = worker.unwrap_or(self.worker_node_manager.next_random_worker()?);
949        let worker_node_addr = worker.host.take().unwrap();
950        let compute_client = self
951            .compute_client_pool
952            .get_by_addr((&worker_node_addr).into())
953            .await
954            .inspect_err(|_| self.mask_failed_serving_worker(&worker))
955            .map_err(|e| anyhow!(e))?;
956
957        let t_id = task_id.task_id;
958
959        let stream_status: Fuse<Streaming<TaskInfoResponse>> = compute_client
960            .create_task(task_id, plan_fragment, self.epoch, expr_context)
961            .await
962            .inspect_err(|_| self.mask_failed_serving_worker(&worker))
963            .map_err(|e| anyhow!(e))?
964            .fuse();
965
966        self.tasks[&t_id].inner.store(Arc::new(TaskStatus {
967            _task_id: t_id,
968            location: Some(worker_node_addr),
969        }));
970
971        Ok(stream_status)
972    }
973
974    pub fn create_plan_fragment(
975        &self,
976        task_id: TaskId,
977        partition: Option<PartitionInfo>,
978    ) -> PlanFragment {
979        // Used to maintain auto-increment identity_id of a task.
980        let identity_id: Rc<RefCell<u64>> = Rc::new(RefCell::new(0));
981
982        let plan_node_prost =
983            self.convert_plan_node(&self.stage.root, task_id, partition, identity_id);
984        let exchange_info = self.stage.exchange_info.clone().unwrap();
985
986        PlanFragment {
987            root: Some(plan_node_prost),
988            exchange_info: Some(exchange_info),
989        }
990    }
991
992    fn convert_plan_node(
993        &self,
994        execution_plan_node: &ExecutionPlanNode,
995        task_id: TaskId,
996        partition: Option<PartitionInfo>,
997        identity_id: Rc<RefCell<u64>>,
998    ) -> PbPlanNode {
999        // Generate identity
1000        let identity = {
1001            let identity_type = execution_plan_node.plan_node_type;
1002            let id = *identity_id.borrow();
1003            identity_id.replace(id + 1);
1004            format!("{:?}-{}", identity_type, id)
1005        };
1006
1007        match execution_plan_node.plan_node_type {
1008            PlanNodeType::BatchExchange => {
1009                // Find the stage this exchange node should fetch from and get all exchange sources.
1010                let child_stage = self
1011                    .children
1012                    .iter()
1013                    .find(|child_stage| {
1014                        child_stage.stage.id == execution_plan_node.source_stage_id.unwrap()
1015                    })
1016                    .unwrap();
1017                let exchange_sources = child_stage.all_exchange_sources_for(task_id);
1018
1019                match &execution_plan_node.node {
1020                    NodeBody::Exchange(exchange_node) => PbPlanNode {
1021                        children: vec![],
1022                        identity,
1023                        node_body: Some(NodeBody::Exchange(ExchangeNode {
1024                            sources: exchange_sources,
1025                            sequential: exchange_node.sequential,
1026                            input_schema: execution_plan_node.schema.clone(),
1027                        })),
1028                    },
1029                    NodeBody::MergeSortExchange(sort_merge_exchange_node) => PbPlanNode {
1030                        children: vec![],
1031                        identity,
1032                        node_body: Some(NodeBody::MergeSortExchange(MergeSortExchangeNode {
1033                            exchange: Some(ExchangeNode {
1034                                sources: exchange_sources,
1035                                sequential: false,
1036                                input_schema: execution_plan_node.schema.clone(),
1037                            }),
1038                            column_orders: sort_merge_exchange_node.column_orders.clone(),
1039                        })),
1040                    },
1041                    _ => unreachable!(),
1042                }
1043            }
1044            PlanNodeType::BatchSeqScan => {
1045                let node_body = execution_plan_node.node.clone();
1046                let NodeBody::RowSeqScan(mut scan_node) = node_body else {
1047                    unreachable!();
1048                };
1049                let partition = partition
1050                    .expect("no partition info for seq scan")
1051                    .into_table()
1052                    .expect("PartitionInfo should be TablePartitionInfo");
1053                scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
1054                scan_node.scan_ranges = partition.scan_ranges;
1055                PbPlanNode {
1056                    children: vec![],
1057                    identity,
1058                    node_body: Some(NodeBody::RowSeqScan(scan_node)),
1059                }
1060            }
1061            PlanNodeType::BatchLogSeqScan => {
1062                let node_body = execution_plan_node.node.clone();
1063                let NodeBody::LogRowSeqScan(mut scan_node) = node_body else {
1064                    unreachable!();
1065                };
1066                let partition = partition
1067                    .expect("no partition info for seq scan")
1068                    .into_table()
1069                    .expect("PartitionInfo should be TablePartitionInfo");
1070                scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
1071                PbPlanNode {
1072                    children: vec![],
1073                    identity,
1074                    node_body: Some(NodeBody::LogRowSeqScan(scan_node)),
1075                }
1076            }
1077            PlanNodeType::BatchSource | PlanNodeType::BatchKafkaScan => {
1078                let node_body = execution_plan_node.node.clone();
1079                let NodeBody::Source(mut source_node) = node_body else {
1080                    unreachable!();
1081                };
1082
1083                let partition = partition
1084                    .expect("no partition info for seq scan")
1085                    .into_source()
1086                    .expect("PartitionInfo should be SourcePartitionInfo");
1087                source_node.split = partition
1088                    .into_iter()
1089                    .map(|split| split.encode_to_bytes().into())
1090                    .collect_vec();
1091                PbPlanNode {
1092                    children: vec![],
1093                    identity,
1094                    node_body: Some(NodeBody::Source(source_node)),
1095                }
1096            }
1097            PlanNodeType::BatchIcebergScan => {
1098                let node_body = execution_plan_node.node.clone();
1099                let NodeBody::IcebergScan(mut iceberg_scan_node) = node_body else {
1100                    unreachable!();
1101                };
1102
1103                let partition = partition
1104                    .expect("no partition info for seq scan")
1105                    .into_source()
1106                    .expect("PartitionInfo should be SourcePartitionInfo");
1107                iceberg_scan_node.split = partition
1108                    .into_iter()
1109                    .map(|split| split.encode_to_bytes().into())
1110                    .collect_vec();
1111                PbPlanNode {
1112                    children: vec![],
1113                    identity,
1114                    node_body: Some(NodeBody::IcebergScan(iceberg_scan_node)),
1115                }
1116            }
1117            _ => {
1118                let children = execution_plan_node
1119                    .children
1120                    .iter()
1121                    .map(|e| {
1122                        self.convert_plan_node(e, task_id, partition.clone(), identity_id.clone())
1123                    })
1124                    .collect();
1125
1126                PbPlanNode {
1127                    children,
1128                    identity,
1129                    node_body: Some(execution_plan_node.node.clone()),
1130                }
1131            }
1132        }
1133    }
1134
1135    fn is_root_stage(&self) -> bool {
1136        self.stage.id == 0
1137    }
1138
1139    fn mask_failed_serving_worker(&self, worker: &WorkerNode) {
1140        if !worker.property.as_ref().is_some_and(|p| p.is_serving) {
1141            return;
1142        }
1143        let duration = Duration::from_secs(std::cmp::max(
1144            self.ctx
1145                .session
1146                .env()
1147                .batch_config()
1148                .mask_worker_temporary_secs as u64,
1149            1,
1150        ));
1151        self.worker_node_manager
1152            .manager
1153            .mask_worker_node(worker.id, duration);
1154    }
1155}