risingwave_frontend/scheduler/distributed/stage.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::assert_matches::assert_matches;
use std::cell::RefCell;
use std::collections::HashMap;
use std::mem;
use std::pin::pin;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Duration;

use StageEvent::Failed;
use anyhow::anyhow;
use arc_swap::ArcSwap;
use futures::stream::Fuse;
use futures::{StreamExt, TryStreamExt, stream};
use futures_async_stream::for_await;
use itertools::Itertools;
use risingwave_batch::error::BatchError;
use risingwave_batch::executor::ExecutorBuilder;
use risingwave_batch::task::{ShutdownMsg, ShutdownSender, ShutdownToken, TaskId as TaskIdBatch};
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::array::DataChunk;
use risingwave_common::hash::WorkerSlotMapping;
use risingwave_common::util::addr::HostAddr;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_connector::source::SplitMetaData;
use risingwave_expr::expr_context::expr_context_scope;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::{
    DistributedLookupJoinNode, ExchangeNode, ExchangeSource, MergeSortExchangeNode, PlanFragment,
    PlanNode as PbPlanNode, PlanNode, TaskId as PbTaskId, TaskOutputId,
};
use risingwave_pb::common::{BatchQueryEpoch, HostAddress, WorkerNode};
use risingwave_pb::plan_common::ExprContext;
use risingwave_pb::task_service::{CancelTaskRequest, TaskInfoResponse};
use risingwave_rpc_client::ComputeClientPoolRef;
use risingwave_rpc_client::error::RpcError;
use rw_futures_util::select_all;
use thiserror_ext::AsReport;
use tokio::spawn;
use tokio::sync::RwLock;
use tokio::sync::mpsc::{Receiver, Sender};
use tonic::Streaming;
use tracing::{Instrument, debug, error, warn};

use crate::catalog::catalog_service::CatalogReader;
use crate::catalog::{FragmentId, TableId};
use crate::optimizer::plan_node::BatchPlanNodeType;
use crate::scheduler::SchedulerError::{TaskExecutionError, TaskRunningOutOfMemory};
use crate::scheduler::distributed::QueryMessage;
use crate::scheduler::distributed::stage::StageState::Pending;
use crate::scheduler::plan_fragmenter::{
    ExecutionPlanNode, PartitionInfo, QueryStageRef, ROOT_TASK_ID, StageId, TaskId,
};
use crate::scheduler::{ExecutionContextRef, SchedulerError, SchedulerResult};

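// Upper bound on the number of `create_task` RPCs in flight per stage: the
// per-task scheduling futures are driven through
// `buffer_unordered(TASK_SCHEDULING_PARALLELISM)` in `StageRunner::schedule_tasks`.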
const TASK_SCHEDULING_PARALLELISM: usize = 10;

#[derive(Debug)]
enum StageState {
    /// We put `msg_sender` in the `Pending` state to avoid holding it in `StageExecution`. In
    /// this way, it can be efficiently moved into `StageRunner` instead of being cloned. This
    /// also ensures that the sender is dropped once it is used up, preventing issues caused
    /// by an unnecessarily long lifetime.
    Pending {
        msg_sender: Sender<QueryMessage>,
    },
    Started,
    Running,
    Completed,
    Failed,
}
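
// The intended lifecycle, as suggested by the transitions enforced below (a
// sketch, not an exhaustive diagram; failure paths may enter `Failed` from
// more than one state):
//
//     Pending --start()--> Started --all tasks running--> Running
//     Running --all tasks finished--> Completed
//     Running --task abort / failure / RPC error--> Failed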

#[derive(Debug)]
pub enum StageEvent {
    Scheduled(StageId),
    ScheduledRoot(Receiver<SchedulerResult<DataChunk>>),
    /// Stage failed.
    Failed {
        id: StageId,
        reason: SchedulerError,
    },
    /// All tasks in stage finished.
    Completed(#[allow(dead_code)] StageId),
}
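
// Note on `ScheduledRoot`: the root stage does not report a stage id when
// scheduled. Instead it hands the receiving half of the root task's result
// channel to the query runner, which polls result chunks from it
// (see `StageRunner::schedule_tasks_for_root`).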

#[derive(Clone)]
pub struct TaskStatus {
    _task_id: TaskId,

    // None before task is scheduled.
    location: Option<HostAddress>,
}

struct TaskStatusHolder {
    inner: ArcSwap<TaskStatus>,
}

pub struct StageExecution {
    epoch: BatchQueryEpoch,
    stage: QueryStageRef,
    worker_node_manager: WorkerNodeSelector,
    tasks: Arc<HashMap<TaskId, TaskStatusHolder>>,
    state: Arc<RwLock<StageState>>,
    shutdown_tx: RwLock<Option<ShutdownSender>>,
    /// Children stage executions.
    ///
    /// We use `Vec` here since children's size is usually small.
    children: Vec<Arc<StageExecution>>,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,

    /// Execution context ref
    ctx: ExecutionContextRef,
}

struct StageRunner {
    epoch: BatchQueryEpoch,
    state: Arc<RwLock<StageState>>,
    stage: QueryStageRef,
    worker_node_manager: WorkerNodeSelector,
    tasks: Arc<HashMap<TaskId, TaskStatusHolder>>,
    // Send message to `QueryRunner` to notify stage state change.
    msg_sender: Sender<QueryMessage>,
    children: Vec<Arc<StageExecution>>,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,

    ctx: ExecutionContextRef,
}

impl TaskStatusHolder {
    fn new(task_id: TaskId) -> Self {
        let task_status = TaskStatus {
            _task_id: task_id,
            location: None,
        };

        Self {
            inner: ArcSwap::new(Arc::new(task_status)),
        }
    }

    fn get_status(&self) -> Arc<TaskStatus> {
        self.inner.load_full()
    }
}
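
// `ArcSwap` lets `schedule_task` publish a new `TaskStatus` (with the worker
// address filled in) while readers such as `all_exchange_sources_for` take
// lock-free snapshots via `load_full`. A minimal sketch of the pattern (the
// `addr` value is illustrative):
//
//     let holder = TaskStatusHolder::new(0);
//     let snapshot = holder.get_status(); // `location` is `None` until scheduled.
//     holder.inner.store(Arc::new(TaskStatus {
//         _task_id: 0,
//         location: Some(addr),
//     }));
//     assert!(snapshot.location.is_none()); // Old snapshots are unaffected.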

impl StageExecution {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        epoch: BatchQueryEpoch,
        stage: QueryStageRef,
        worker_node_manager: WorkerNodeSelector,
        msg_sender: Sender<QueryMessage>,
        children: Vec<Arc<StageExecution>>,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
        ctx: ExecutionContextRef,
    ) -> Self {
        let tasks = (0..stage.parallelism.unwrap())
            .map(|task_id| (task_id as u64, TaskStatusHolder::new(task_id as u64)))
            .collect();

        Self {
            epoch,
            stage,
            worker_node_manager,
            tasks: Arc::new(tasks),
            state: Arc::new(RwLock::new(Pending { msg_sender })),
            shutdown_tx: RwLock::new(None),
            children,
            compute_client_pool,
            catalog_reader,
            ctx,
        }
    }

    /// Starts execution of this stage. Panics if the stage has already been started.
    pub async fn start(&self) {
        let mut s = self.state.write().await;
        let cur_state = mem::replace(&mut *s, StageState::Failed);
        match cur_state {
            Pending { msg_sender } => {
                let runner = StageRunner {
                    epoch: self.epoch,
                    stage: self.stage.clone(),
                    worker_node_manager: self.worker_node_manager.clone(),
                    tasks: self.tasks.clone(),
                    msg_sender,
                    children: self.children.clone(),
                    state: self.state.clone(),
                    compute_client_pool: self.compute_client_pool.clone(),
                    catalog_reader: self.catalog_reader.clone(),
                    ctx: self.ctx.clone(),
                };

                // The channel used for shutdown signal messaging.
                let (sender, receiver) = ShutdownToken::new();
                // Fill the shutdown sender.
                let mut holder = self.shutdown_tx.write().await;
                *holder = Some(sender);

                // Change the state before spawning the runner.
                *s = StageState::Started;

                let span = tracing::info_span!(
                    "stage",
                    "otel.name" = format!("Stage {}-{}", self.stage.query_id.id, self.stage.id),
                    query_id = self.stage.query_id.id,
                    stage_id = self.stage.id,
                );
                self.ctx
                    .session()
                    .env()
                    .compute_runtime()
                    .spawn(async move { runner.run(receiver).instrument(span).await });

                tracing::trace!(
                    "Stage {:?}-{:?} started.",
                    self.stage.query_id.id,
                    self.stage.id
                )
            }
            _ => {
                unreachable!("Only expect to schedule stage once");
            }
        }
    }

    pub async fn stop(&self, error: Option<String>) {
        // Send a message to tell the stage runner to stop. It's possible that the stage
        // has not been scheduled yet, in which case the shutdown sender is `None` and
        // there is nothing to stop.
        if let Some(shutdown_tx) = self.shutdown_tx.write().await.take() {
            let sent = match error {
                Some(error) => shutdown_tx.abort(error),
                None => shutdown_tx.cancel(),
            };
            if !sent {
                // The stage runner handle has already been closed, so this is a no-op.
                tracing::trace!(
                    "Failed to send stop message to stage: {:?}-{:?}",
                    self.stage.query_id,
                    self.stage.id
                );
            }
        }
    }

    pub async fn is_scheduled(&self) -> bool {
        let s = self.state.read().await;
        matches!(*s, StageState::Running | StageState::Completed)
    }

    pub async fn is_pending(&self) -> bool {
        let s = self.state.read().await;
        matches!(*s, StageState::Pending { .. })
    }

    pub async fn state(&self) -> &'static str {
        let s = self.state.read().await;
        match *s {
            Pending { .. } => "Pending",
            StageState::Started => "Started",
            StageState::Running => "Running",
            StageState::Completed => "Completed",
            StageState::Failed => "Failed",
        }
    }

    /// Returns all exchange sources for `output_id`. Each `ExchangeSource` is identified by
    /// producer's `TaskId` and `output_id` (consumer's `TaskId`), since each task may produce
    /// output to several channels.
    ///
    /// When this method is called, all tasks should have been scheduled, and their `worker_node`
    /// should have been set.
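    ///
    /// For example, if the producer stage has tasks `0` and `1`, calling this with
    /// `output_id = 2` yields sources `(task_id: 0, output_id: 2)` and
    /// `(task_id: 1, output_id: 2)`, each carrying the address of the worker its
    /// producer task was scheduled on.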
    pub fn all_exchange_sources_for(&self, output_id: u64) -> Vec<ExchangeSource> {
        self.tasks
            .iter()
            .map(|(task_id, status_holder)| {
                let task_output_id = TaskOutputId {
                    task_id: Some(PbTaskId {
                        query_id: self.stage.query_id.id.clone(),
                        stage_id: self.stage.id,
                        task_id: *task_id,
                    }),
                    output_id,
                };

                ExchangeSource {
                    task_output_id: Some(task_output_id),
                    host: Some(status_holder.inner.load_full().location.clone().unwrap()),
                    local_execute_plan: None,
                }
            })
            .collect()
    }
}

impl StageRunner {
    async fn run(mut self, shutdown_rx: ShutdownToken) {
        if let Err(e) = self.schedule_tasks_for_all(shutdown_rx).await {
            error!(
                error = %e.as_report(),
                query_id = ?self.stage.query_id,
                stage_id = ?self.stage.id,
                "Failed to schedule tasks"
            );
            self.send_event(QueryMessage::Stage(Failed {
                id: self.stage.id,
                reason: e,
            }))
            .await;
        }
    }

    /// Send a stage event to the listener.
    async fn send_event(&self, event: QueryMessage) {
        if let Err(_e) = self.msg_sender.send(event).await {
            warn!("Failed to send event to Query Runner; it may have been killed by a previous failed event");
        }
    }

    /// Schedules all tasks to compute nodes and processes the status messages they stream
    /// back over RPC. Note that once all tasks are created, this should tell `QueryRunner`
    /// to schedule the next stage.
    async fn schedule_tasks(
        &mut self,
        mut shutdown_rx: ShutdownToken,
        expr_context: ExprContext,
    ) -> SchedulerResult<()> {
        let mut futures = vec![];

        if let Some(table_scan_info) = self.stage.table_scan_info.as_ref()
            && let Some(vnode_bitmaps) = table_scan_info.partitions()
        {
            // If the stage has table scan nodes, we create tasks according to the data
            // distribution and partitioning of the table: each task reads one partition
            // (by setting the `vnode_ranges` of its scan node) and is scheduled to the
            // worker node that owns that partition.
            let worker_slot_ids = vnode_bitmaps.keys().cloned().collect_vec();
            let workers = self
                .worker_node_manager
                .manager
                .get_workers_by_worker_slot_ids(&worker_slot_ids)?;

            for (i, (worker_slot_id, worker)) in worker_slot_ids
                .into_iter()
                .zip_eq_fast(workers.into_iter())
                .enumerate()
            {
                let task_id = PbTaskId {
                    query_id: self.stage.query_id.id.clone(),
                    stage_id: self.stage.id,
                    task_id: i as u64,
                };
                let vnode_ranges = vnode_bitmaps[&worker_slot_id].clone();
                let plan_fragment =
                    self.create_plan_fragment(i as u64, Some(PartitionInfo::Table(vnode_ranges)));
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    Some(worker),
                    expr_context.clone(),
                ));
            }
        } else if let Some(source_info) = self.stage.source_info.as_ref() {
            // Distribute the splits evenly across tasks; if the source has no splits,
            // `.max(1)` keeps `chunk_size` at 1.
            let chunk_size = ((source_info.split_info().unwrap().len() as f32
                / self.stage.parallelism.unwrap() as f32)
                .ceil() as usize)
                .max(1);
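            // For example, 10 splits with parallelism 4 give a chunk size of
            // ceil(10 / 4) = 3, i.e. `chunks(3)` below produces tasks with
            // 3, 3, 3, and 1 splits respectively.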
            if source_info.split_info().unwrap().is_empty() {
                // No file in source, schedule an empty task.
                const EMPTY_TASK_ID: u64 = 0;
                let task_id = PbTaskId {
                    query_id: self.stage.query_id.id.clone(),
                    stage_id: self.stage.id,
                    task_id: EMPTY_TASK_ID,
                };
                let plan_fragment =
                    self.create_plan_fragment(EMPTY_TASK_ID, Some(PartitionInfo::Source(vec![])));
                let worker = self.choose_worker(
                    &plan_fragment,
                    EMPTY_TASK_ID as u32,
                    self.stage.dml_table_id,
                )?;
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    worker,
                    expr_context.clone(),
                ));
            } else {
                for (id, split) in source_info
                    .split_info()
                    .unwrap()
                    .chunks(chunk_size)
                    .enumerate()
                {
                    let task_id = PbTaskId {
                        query_id: self.stage.query_id.id.clone(),
                        stage_id: self.stage.id,
                        task_id: id as u64,
                    };
                    let plan_fragment = self.create_plan_fragment(
                        id as u64,
                        Some(PartitionInfo::Source(split.to_vec())),
                    );
                    let worker =
                        self.choose_worker(&plan_fragment, id as u32, self.stage.dml_table_id)?;
                    futures.push(self.schedule_task(
                        task_id,
                        plan_fragment,
                        worker,
                        expr_context.clone(),
                    ));
                }
            }
        } else if let Some(file_scan_info) = self.stage.file_scan_info.as_ref() {
            let chunk_size = (file_scan_info.file_location.len() as f32
                / self.stage.parallelism.unwrap() as f32)
                .ceil() as usize;
            for (id, files) in file_scan_info.file_location.chunks(chunk_size).enumerate() {
                let task_id = PbTaskId {
                    query_id: self.stage.query_id.id.clone(),
                    stage_id: self.stage.id,
                    task_id: id as u64,
                };
                let plan_fragment =
                    self.create_plan_fragment(id as u64, Some(PartitionInfo::File(files.to_vec())));
                let worker =
                    self.choose_worker(&plan_fragment, id as u32, self.stage.dml_table_id)?;
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    worker,
                    expr_context.clone(),
                ));
            }
        } else if self.stage.has_vector_search {
            let id = 0;
            let task_id = PbTaskId {
                query_id: self.stage.query_id.id.clone(),
                stage_id: self.stage.id,
                task_id: id as u64,
            };
            let plan_fragment = self.create_plan_fragment(id as u64, None);
            let worker = self.choose_worker(&plan_fragment, id, self.stage.dml_table_id)?;
            futures.push(self.schedule_task(task_id, plan_fragment, worker, expr_context.clone()));
        } else {
            for id in 0..self.stage.parallelism.unwrap() {
                let task_id = PbTaskId {
                    query_id: self.stage.query_id.id.clone(),
                    stage_id: self.stage.id,
                    task_id: id as u64,
                };
                let plan_fragment = self.create_plan_fragment(id as u64, None);
                let worker = self.choose_worker(&plan_fragment, id, self.stage.dml_table_id)?;
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    worker,
                    expr_context.clone(),
                ));
            }
        }

        // Await each future and convert them into a set of streams.
        let buffered = stream::iter(futures).buffer_unordered(TASK_SCHEDULING_PARALLELISM);
        let buffered_streams = buffered.try_collect::<Vec<_>>().await?;

        // Merge different task streams into a single stream.
        let cancelled = pin!(shutdown_rx.cancelled());
        let mut all_streams = select_all(buffered_streams).take_until(cancelled);

        // Process the stream until finished.
        let mut running_task_cnt = 0;
        let mut finished_task_cnt = 0;
        let mut sent_signal_to_next = false;

        while let Some(status_res_inner) = all_streams.next().await {
            match status_res_inner {
                Ok(status) => {
                    use risingwave_pb::task_service::task_info_response::TaskStatus as PbTaskStatus;
                    match PbTaskStatus::try_from(status.task_status).unwrap() {
                        PbTaskStatus::Running => {
                            running_task_cnt += 1;
                            // The running task count should always be less than or equal to
                            // the number of registered tasks.
                            assert!(running_task_cnt <= self.tasks.keys().len());
                            // All tasks in this stage have been scheduled. Notify the query
                            // runner to schedule the next stage.
                            if running_task_cnt == self.tasks.keys().len() {
                                self.notify_stage_scheduled(QueryMessage::Stage(
                                    StageEvent::Scheduled(self.stage.id),
                                ))
                                .await;
                                sent_signal_to_next = true;
                            }
                        }

                        PbTaskStatus::Finished => {
                            finished_task_cnt += 1;
                            assert!(finished_task_cnt <= self.tasks.keys().len());
                            assert!(running_task_cnt >= finished_task_cnt);
                            if finished_task_cnt == self.tasks.keys().len() {
                                // All tasks finished without failure; stop processing the
                                // stream and notify completion.
                                self.notify_stage_completed().await;
                                sent_signal_to_next = true;
                                break;
                            }
                        }
                        PbTaskStatus::Aborted => {
                            // Currently, the only reason we receive an abort status is that
                            // the task's memory usage was too high, so it was aborted.
                            error!(
                                "Abort task {:?} because of excessive memory usage. Please try again later.",
                                status.task_id.unwrap()
                            );
                            self.notify_stage_state_changed(
                                |_| StageState::Failed,
                                QueryMessage::Stage(Failed {
                                    id: self.stage.id,
                                    reason: TaskRunningOutOfMemory,
                                }),
                            )
                            .await;
                            sent_signal_to_next = true;
                            break;
                        }
                        PbTaskStatus::Failed => {
                            // A task failed, so the whole query should fail.
                            error!(
                                "Task {:?} failed, reason: {:?}",
                                status.task_id.unwrap(),
                                status.error_message,
                            );
                            self.notify_stage_state_changed(
                                |_| StageState::Failed,
                                QueryMessage::Stage(Failed {
                                    id: self.stage.id,
                                    reason: TaskExecutionError(status.error_message),
                                }),
                            )
                            .await;
                            sent_signal_to_next = true;
                            break;
                        }
                        PbTaskStatus::Ping => {
                            debug!("Receive ping from task {:?}", status.task_id.unwrap());
                        }
                        status => {
                            // The remaining variants are not expected to be pushed from the
                            // compute node.
                            unreachable!("Unexpected task status {:?}", status);
                        }
                    }
                }
                Err(e) => {
                    // An RPC error occurred; notify stage failure as well.
                    error!(
                        "Fetching task status in stage {:?} failed, reason: {:?}",
                        self.stage.id,
                        e.message()
                    );
                    self.notify_stage_state_changed(
                        |_| StageState::Failed,
                        QueryMessage::Stage(Failed {
                            id: self.stage.id,
                            reason: RpcError::from_batch_status(e).into(),
                        }),
                    )
                    .await;
                    sent_signal_to_next = true;
                    break;
                }
            }
        }

        tracing::trace!(
            "Stage [{:?}-{:?}], running task count: {}, finished task count: {}, sent signal to next: {}",
            self.stage.query_id,
            self.stage.id,
            running_task_cnt,
            finished_task_cnt,
            sent_signal_to_next,
        );

        if let Some(shutdown) = all_streams.take_future() {
            tracing::trace!(
                "Stage [{:?}-{:?}] waiting for stopping signal.",
                self.stage.query_id,
                self.stage.id
            );
            // Waiting for shutdown signal.
            shutdown.await;
        }

        // If we received a shutdown signal from the query runner, send a cancel RPC to all
        // compute nodes and change the state accordingly. Note that task cancellation can
        // only happen after all tasks have been scheduled to the compute nodes; stopping
        // before scheduling is a possible future optimization.
        tracing::trace!(
            "Stopping stage: {:?}-{:?}, task_num: {}",
            self.stage.query_id,
            self.stage.id,
            self.tasks.len()
        );
        self.cancel_all_scheduled_tasks().await?;

        tracing::trace!(
            "Stage runner [{:?}-{:?}] exited.",
            self.stage.query_id,
            self.stage.id
        );
        Ok(())
    }

    async fn schedule_tasks_for_root(
        &mut self,
        mut shutdown_rx: ShutdownToken,
        expr_context: ExprContext,
    ) -> SchedulerResult<()> {
        let root_stage_id = self.stage.id;
        // Currently, DML and table scans should never appear in the root fragment, so the
        // partition is `None`. The root fragment also contains exactly one task.
        let plan_fragment = self.create_plan_fragment(ROOT_TASK_ID, None);
        let plan_node = plan_fragment.root.unwrap();
        let task_id = TaskIdBatch {
            query_id: self.stage.query_id.id.clone(),
            stage_id: root_stage_id,
            task_id: 0,
        };

        // Notify `QueryRunner` to poll chunks from `result_rx`.
        let (result_tx, result_rx) = tokio::sync::mpsc::channel(
            self.ctx
                .session
                .env()
                .batch_config()
                .developer
                .root_stage_channel_size,
        );
        self.notify_stage_scheduled(QueryMessage::Stage(StageEvent::ScheduledRoot(result_rx)))
            .await;

        let executor = ExecutorBuilder::new(
            &plan_node,
            &task_id,
            self.ctx.to_batch_task_context(),
            self.epoch,
            shutdown_rx.clone(),
        );

        let shutdown_rx0 = shutdown_rx.clone();

        let result = expr_context_scope(expr_context, async {
            let executor = executor.build().await?;
            let chunk_stream = executor.execute();
            let cancelled = pin!(shutdown_rx.cancelled());
            #[for_await]
            for chunk in chunk_stream.take_until(cancelled) {
                if let Err(ref e) = chunk {
                    if shutdown_rx0.is_cancelled() {
                        break;
                    }
                    let err_str = e.to_report_string();
                    // The send can fail if the query runner is dropped before the root
                    // executor is scheduled; see
                    // https://github.com/risingwavelabs/risingwave/issues/6883#issuecomment-1348102037
                    // for details. The send error itself is just a closed channel, so it can
                    // be ignored.
                    if let Err(_e) = result_tx.send(chunk.map_err(|e| e.into())).await {
                        warn!("Root executor has been dropped before receiving any events, so the send failed");
                    }
                    // Unlike the success path below, return from this function and report
                    // the error.
                    return Err(TaskExecutionError(err_str));
                } else {
                    // The send may fail for the same reason as above.
                    if let Err(_e) = result_tx.send(chunk.map_err(|e| e.into())).await {
                        warn!("Root executor has been dropped before receiving any events, so the send failed");
                    }
                }
            }
            Ok(())
        }).await;

        if let Err(err) = &result {
            // If we encountered an error while executing the root stage locally, we have to
            // notify the result fetcher, which is returned by `distribute_execute` and
            // listened on by the FE handler task. Otherwise the FE handler cannot properly
            // report the error to the PG client.
            if let Err(_e) = result_tx
                .send(Err(TaskExecutionError(err.to_report_string())))
                .await
            {
                warn!("Failed to send the task execution error");
            }
        }

        // If terminated by another task's execution error, there is no need to return an
        // error here.
        match shutdown_rx0.message() {
            ShutdownMsg::Abort(err_str) => {
                // Tell the query result fetcher to stop polling, attaching the failure
                // reason as a string.
                if let Err(_e) = result_tx.send(Err(TaskExecutionError(err_str))).await {
                    warn!("Failed to send the task execution error");
                }
            }
            _ => self.notify_stage_completed().await,
        }

        tracing::trace!(
            "Stage runner [{:?}-{:?}] exited.",
            self.stage.query_id,
            self.stage.id
        );

        // We still have to throw the error in the current task so that `StageRunner::run`
        // can further send a `Failed` event to stop other stages.
        result.map(|_| ())
    }

    async fn schedule_tasks_for_all(&mut self, shutdown_rx: ShutdownToken) -> SchedulerResult<()> {
        let expr_context = ExprContext {
            time_zone: self.ctx.session().config().timezone().to_owned(),
            strict_mode: self.ctx.session().config().batch_expr_strict_mode(),
        };
        // If this is the root stage, execute it locally; otherwise schedule its tasks to
        // compute nodes.
        if !self.is_root_stage() {
            self.schedule_tasks(shutdown_rx, expr_context).await?;
        } else {
            self.schedule_tasks_for_root(shutdown_rx, expr_context)
                .await?;
        }
        Ok(())
    }

    #[inline(always)]
    fn get_fragment_id(&self, table_id: &TableId) -> SchedulerResult<FragmentId> {
        self.catalog_reader
            .read_guard()
            .get_any_table_by_id(table_id)
            .map(|table| table.fragment_id)
            .map_err(|e| SchedulerError::Internal(anyhow!(e)))
    }

    #[inline(always)]
    fn get_table_dml_vnode_mapping(
        &self,
        table_id: &TableId,
    ) -> SchedulerResult<WorkerSlotMapping> {
        let guard = self.catalog_reader.read_guard();

        let table = guard
            .get_any_table_by_id(table_id)
            .map_err(|e| SchedulerError::Internal(anyhow!(e)))?;

        let fragment_id = match table.dml_fragment_id.as_ref() {
            Some(dml_fragment_id) => dml_fragment_id,
            // Backward compatibility for tables without a `dml_fragment_id`.
            None => &table.fragment_id,
        };

        self.worker_node_manager
            .manager
            .get_streaming_fragment_mapping(fragment_id)
            .map_err(|e| e.into())
    }

    fn choose_worker(
        &self,
        plan_fragment: &PlanFragment,
        task_id: u32,
        dml_table_id: Option<TableId>,
    ) -> SchedulerResult<Option<WorkerNode>> {
        let plan_node = plan_fragment.root.as_ref().expect("fail to get plan node");

        if let Some(table_id) = dml_table_id {
            let vnode_mapping = self.get_table_dml_vnode_mapping(&table_id)?;
            let worker_slot_ids = vnode_mapping.iter_unique().collect_vec();
            let candidates = self
                .worker_node_manager
                .manager
                .get_workers_by_worker_slot_ids(&worker_slot_ids)?;
            if candidates.is_empty() {
                return Err(BatchError::EmptyWorkerNodes.into());
            }
            let candidate = if self.stage.batch_enable_distributed_dml {
                // If distributed DML is enabled, we try our best to distribute DML tasks
                // evenly across workers, so we pick a candidate by `task_id`.
                candidates[task_id as usize % candidates.len()].clone()
            } else {
                // If distributed DML is disabled, DML from the same session must be sent to
                // a fixed worker/channel to provide an ordering guarantee.
                candidates[self.stage.session_id.0 as usize % candidates.len()].clone()
            };
            return Ok(Some(candidate));
        };

        if let Some(distributed_lookup_join_node) =
            Self::find_distributed_lookup_join_node(plan_node)
        {
            let fragment_id = self.get_fragment_id(
                &distributed_lookup_join_node
                    .inner_side_table_desc
                    .as_ref()
                    .unwrap()
                    .table_id
                    .into(),
            )?;
            let id_to_worker_slots = self
                .worker_node_manager
                .fragment_mapping(fragment_id)?
                .iter_unique()
                .collect_vec();

            let worker_slot_id = id_to_worker_slots[task_id as usize];
            let candidates = self
                .worker_node_manager
                .manager
                .get_workers_by_worker_slot_ids(&[worker_slot_id])?;
            if candidates.is_empty() {
                return Err(BatchError::EmptyWorkerNodes.into());
            }
            Ok(Some(candidates[0].clone()))
        } else {
            Ok(None)
        }
    }

    fn find_distributed_lookup_join_node(
        plan_node: &PlanNode,
    ) -> Option<&DistributedLookupJoinNode> {
        let node_body = plan_node.node_body.as_ref().expect("fail to get node body");

        match node_body {
            NodeBody::DistributedLookupJoin(distributed_lookup_join_node) => {
                Some(distributed_lookup_join_node)
            }
            _ => plan_node
                .children
                .iter()
                .find_map(Self::find_distributed_lookup_join_node),
        }
    }

    /// Writes a message into the channel to notify the query runner that the current stage
    /// has been scheduled.
    async fn notify_stage_scheduled(&self, msg: QueryMessage) {
        self.notify_stage_state_changed(
            |old_state| {
                assert_matches!(old_state, StageState::Started);
                StageState::Running
            },
            msg,
        )
        .await
    }

    /// Notify query execution that this stage completed.
    async fn notify_stage_completed(&self) {
        self.notify_stage_state_changed(
            |old_state| {
                assert_matches!(old_state, StageState::Running);
                StageState::Completed
            },
            QueryMessage::Stage(StageEvent::Completed(self.stage.id)),
        )
        .await
    }

    async fn notify_stage_state_changed<F>(&self, new_state: F, msg: QueryMessage)
    where
        F: FnOnce(StageState) -> StageState,
    {
        {
            let mut s = self.state.write().await;
            let old_state = mem::replace(&mut *s, StageState::Failed);
            *s = new_state(old_state);
        }

        self.send_event(msg).await;
    }

    /// Aborts all registered tasks. Note that we do not care which tasks have already
    /// failed or completed, because aborting a task does not fail even if the task has
    /// already died. See PR (#4560).
    async fn cancel_all_scheduled_tasks(&self) -> SchedulerResult<()> {
        // Set state to failed.
        // {
        //     let mut state = self.state.write().await;
        //     // Ignore if already finished.
        //     if let &StageState::Completed = &*state {
        //         return Ok(());
        //     }
        //     // FIXME: Be careful for state jump back.
        //     *state = StageState::Failed
        // }

        for (task, task_status) in &*self.tasks {
            // 1. Collect task info and client.
            let loc = &task_status.get_status().location;
            let addr = loc.as_ref().expect("Get address should not fail");
            let client = self
                .compute_client_pool
                .get_by_addr(HostAddr::from(addr))
                .await
                .map_err(|e| anyhow!(e))?;

            // 2. Send an RPC to each compute node for each task asynchronously.
            let query_id = self.stage.query_id.id.clone();
            let stage_id = self.stage.id;
            let task_id = *task;
            spawn(async move {
                if let Err(e) = client
                    .cancel(CancelTaskRequest {
                        task_id: Some(risingwave_pb::batch_plan::TaskId {
                            query_id: query_id.clone(),
                            stage_id,
                            task_id,
                        }),
                    })
                    .await
                {
                    error!(
                        error = %e.as_report(),
                        ?task_id,
                        ?query_id,
                        ?stage_id,
                        "Abort task failed",
                    );
                };
            });
        }
        Ok(())
    }

    async fn schedule_task(
        &self,
        task_id: PbTaskId,
        plan_fragment: PlanFragment,
        worker: Option<WorkerNode>,
        expr_context: ExprContext,
    ) -> SchedulerResult<Fuse<Streaming<TaskInfoResponse>>> {
        let mut worker = worker.unwrap_or(self.worker_node_manager.next_random_worker()?);
        let worker_node_addr = worker.host.take().unwrap();
        let compute_client = self
            .compute_client_pool
            .get_by_addr((&worker_node_addr).into())
            .await
            .inspect_err(|_| self.mask_failed_serving_worker(&worker))
            .map_err(|e| anyhow!(e))?;

        let t_id = task_id.task_id;

        let stream_status: Fuse<Streaming<TaskInfoResponse>> = compute_client
            .create_task(task_id, plan_fragment, self.epoch, expr_context)
            .await
            .inspect_err(|_| self.mask_failed_serving_worker(&worker))
            .map_err(|e| anyhow!(e))?
            .fuse();

        self.tasks[&t_id].inner.store(Arc::new(TaskStatus {
            _task_id: t_id,
            location: Some(worker_node_addr),
        }));

        Ok(stream_status)
    }

    pub fn create_plan_fragment(
        &self,
        task_id: TaskId,
        partition: Option<PartitionInfo>,
    ) -> PlanFragment {
        // Used to maintain the auto-incrementing `identity_id` of nodes within this task.
        let identity_id: Rc<RefCell<u64>> = Rc::new(RefCell::new(0));

        let plan_node_prost =
            self.convert_plan_node(&self.stage.root, task_id, partition, identity_id);
        let exchange_info = self.stage.exchange_info.clone().unwrap();

        PlanFragment {
            root: Some(plan_node_prost),
            exchange_info: Some(exchange_info),
        }
    }

    fn convert_plan_node(
        &self,
        execution_plan_node: &ExecutionPlanNode,
        task_id: TaskId,
        partition: Option<PartitionInfo>,
        identity_id: Rc<RefCell<u64>>,
    ) -> PbPlanNode {
        // Generate identity
        let identity = {
            let identity_type = execution_plan_node.plan_node_type;
            let id = *identity_id.borrow();
            identity_id.replace(id + 1);
            format!("{:?}-{}", identity_type, id)
        };
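        // This yields identities like `BatchExchange-0` or `BatchSeqScan-1` (variant
        // names per `Debug`), unique within a task because `identity_id` is shared
        // across the recursive calls below.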

        match execution_plan_node.plan_node_type {
            BatchPlanNodeType::BatchExchange => {
                // Find the stage this exchange node should fetch from and get all exchange sources.
                let child_stage = self
                    .children
                    .iter()
                    .find(|child_stage| {
                        child_stage.stage.id == execution_plan_node.source_stage_id.unwrap()
                    })
                    .unwrap();
                let exchange_sources = child_stage.all_exchange_sources_for(task_id);

                match &execution_plan_node.node {
                    NodeBody::Exchange(exchange_node) => PbPlanNode {
                        children: vec![],
                        identity,
                        node_body: Some(NodeBody::Exchange(ExchangeNode {
                            sources: exchange_sources,
                            sequential: exchange_node.sequential,
                            input_schema: execution_plan_node.schema.clone(),
                        })),
                    },
                    NodeBody::MergeSortExchange(sort_merge_exchange_node) => PbPlanNode {
                        children: vec![],
                        identity,
                        node_body: Some(NodeBody::MergeSortExchange(MergeSortExchangeNode {
                            exchange: Some(ExchangeNode {
                                sources: exchange_sources,
                                sequential: false,
                                input_schema: execution_plan_node.schema.clone(),
                            }),
                            column_orders: sort_merge_exchange_node.column_orders.clone(),
                        })),
                    },
                    _ => unreachable!(),
                }
            }
            BatchPlanNodeType::BatchSeqScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::RowSeqScan(mut scan_node) = node_body else {
                    unreachable!();
                };
                let partition = partition
                    .expect("no partition info for seq scan")
                    .into_table()
                    .expect("PartitionInfo should be TablePartitionInfo");
                scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
                scan_node.scan_ranges = partition.scan_ranges;
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::RowSeqScan(scan_node)),
                }
            }
            BatchPlanNodeType::BatchLogSeqScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::LogRowSeqScan(mut scan_node) = node_body else {
                    unreachable!();
                };
                let partition = partition
                    .expect("no partition info for seq scan")
                    .into_table()
                    .expect("PartitionInfo should be TablePartitionInfo");
                scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::LogRowSeqScan(scan_node)),
                }
            }
            BatchPlanNodeType::BatchSource | BatchPlanNodeType::BatchKafkaScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::Source(mut source_node) = node_body else {
                    unreachable!();
                };

                let partition = partition
                    .expect("no partition info for source scan")
                    .into_source()
                    .expect("PartitionInfo should be SourcePartitionInfo");
                source_node.split = partition
                    .into_iter()
                    .map(|split| split.encode_to_bytes().into())
                    .collect_vec();
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::Source(source_node)),
                }
            }
            BatchPlanNodeType::BatchIcebergScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::IcebergScan(mut iceberg_scan_node) = node_body else {
                    unreachable!();
                };

                let partition = partition
                    .expect("no partition info for iceberg scan")
                    .into_source()
                    .expect("PartitionInfo should be SourcePartitionInfo");
                iceberg_scan_node.split = partition
                    .into_iter()
                    .map(|split| split.encode_to_bytes().into())
                    .collect_vec();
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::IcebergScan(iceberg_scan_node)),
                }
            }
            _ => {
                let children = execution_plan_node
                    .children
                    .iter()
                    .map(|e| {
                        self.convert_plan_node(e, task_id, partition.clone(), identity_id.clone())
                    })
                    .collect();

                PbPlanNode {
                    children,
                    identity,
                    node_body: Some(execution_plan_node.node.clone()),
                }
            }
        }
    }

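    /// The plan fragmenter assigns id 0 to the root stage (an assumption of this
    /// check), so testing for the root reduces to an id comparison.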
    fn is_root_stage(&self) -> bool {
        self.stage.id == 0
    }

    fn mask_failed_serving_worker(&self, worker: &WorkerNode) {
        if !worker.property.as_ref().is_some_and(|p| p.is_serving) {
            return;
        }
        let duration = Duration::from_secs(std::cmp::max(
            self.ctx
                .session
                .env()
                .batch_config()
                .mask_worker_temporary_secs as u64,
            1,
        ));
        self.worker_node_manager
            .manager
            .mask_worker_node(worker.id, duration);
    }
}