use std::assert_matches::assert_matches;
use std::cell::RefCell;
use std::collections::HashMap;
use std::mem;
use std::pin::pin;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Duration;

use StageEvent::Failed;
use anyhow::anyhow;
use arc_swap::ArcSwap;
use futures::stream::Fuse;
use futures::{StreamExt, TryStreamExt, stream};
use futures_async_stream::for_await;
use itertools::Itertools;
use risingwave_batch::error::BatchError;
use risingwave_batch::executor::ExecutorBuilder;
use risingwave_batch::task::{ShutdownMsg, ShutdownSender, ShutdownToken, TaskId as TaskIdBatch};
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::array::DataChunk;
use risingwave_common::hash::WorkerSlotMapping;
use risingwave_common::util::addr::HostAddr;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_connector::source::SplitMetaData;
use risingwave_expr::expr_context::expr_context_scope;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::{
    DistributedLookupJoinNode, ExchangeNode, ExchangeSource, MergeSortExchangeNode, PlanFragment,
    PlanNode as PbPlanNode, PlanNode, TaskId as PbTaskId, TaskOutputId,
};
use risingwave_pb::common::{BatchQueryEpoch, HostAddress, WorkerNode};
use risingwave_pb::plan_common::ExprContext;
use risingwave_pb::task_service::{CancelTaskRequest, TaskInfoResponse};
use risingwave_rpc_client::ComputeClientPoolRef;
use risingwave_rpc_client::error::RpcError;
use rw_futures_util::select_all;
use thiserror_ext::AsReport;
use tokio::spawn;
use tokio::sync::RwLock;
use tokio::sync::mpsc::{Receiver, Sender};
use tonic::Streaming;
use tracing::{Instrument, debug, error, warn};

use crate::catalog::catalog_service::CatalogReader;
use crate::catalog::{FragmentId, TableId};
use crate::optimizer::plan_node::PlanNodeType;
use crate::scheduler::SchedulerError::{TaskExecutionError, TaskRunningOutOfMemory};
use crate::scheduler::distributed::QueryMessage;
use crate::scheduler::distributed::stage::StageState::Pending;
use crate::scheduler::plan_fragmenter::{
    ExecutionPlanNode, PartitionInfo, QueryStageRef, ROOT_TASK_ID, StageId, TaskId,
};
use crate::scheduler::{ExecutionContextRef, SchedulerError, SchedulerResult};

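/// Maximum number of `create_task` RPCs that are in flight at once when scheduling a
/// stage's tasks.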
const TASK_SCHEDULING_PARALLELISM: usize = 10;

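/// Lifecycle of a stage. `Pending` carries the sender for query-runner messages so that it
/// can be moved into the `StageRunner` when the stage starts, instead of being cloned.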
#[derive(Debug)]
enum StageState {
    Pending {
        msg_sender: Sender<QueryMessage>,
    },
    Started,
    Running,
    Completed,
    Failed,
}

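/// Events a stage reports back to the query runner. `ScheduledRoot` carries the receiving
/// end of the root stage's result channel.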
#[derive(Debug)]
pub enum StageEvent {
    Scheduled(StageId),
    ScheduledRoot(Receiver<SchedulerResult<DataChunk>>),
    Failed {
        id: StageId,
        reason: SchedulerError,
    },
    Completed(#[allow(dead_code)] StageId),
}

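/// Status of a single task, most importantly the worker it was scheduled to.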
#[derive(Clone)]
pub struct TaskStatus {
    _task_id: TaskId,

    location: Option<HostAddress>,
}

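/// Holds a `TaskStatus` behind an `ArcSwap` so status updates are lock-free.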
struct TaskStatusHolder {
    inner: ArcSwap<TaskStatus>,
}

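/// Handle of a stage from the query runner's perspective: it owns the per-task status map
/// and the shutdown sender, and spawns a `StageRunner` to do the actual scheduling.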
pub struct StageExecution {
    epoch: BatchQueryEpoch,
    stage: QueryStageRef,
    worker_node_manager: WorkerNodeSelector,
    tasks: Arc<HashMap<TaskId, TaskStatusHolder>>,
    state: Arc<RwLock<StageState>>,
    shutdown_tx: RwLock<Option<ShutdownSender>>,
    children: Vec<Arc<StageExecution>>,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,

    ctx: ExecutionContextRef,
}

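/// The worker side of a `StageExecution`: runs on the compute runtime, schedules the
/// stage's tasks, and watches their status streams until completion or failure.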
struct StageRunner {
    epoch: BatchQueryEpoch,
    state: Arc<RwLock<StageState>>,
    stage: QueryStageRef,
    worker_node_manager: WorkerNodeSelector,
    tasks: Arc<HashMap<TaskId, TaskStatusHolder>>,
    msg_sender: Sender<QueryMessage>,
    children: Vec<Arc<StageExecution>>,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,

    ctx: ExecutionContextRef,
}

impl TaskStatusHolder {
    fn new(task_id: TaskId) -> Self {
        let task_status = TaskStatus {
            _task_id: task_id,
            location: None,
        };

        Self {
            inner: ArcSwap::new(Arc::new(task_status)),
        }
    }

    fn get_status(&self) -> Arc<TaskStatus> {
        self.inner.load_full()
    }
}

impl StageExecution {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        epoch: BatchQueryEpoch,
        stage: QueryStageRef,
        worker_node_manager: WorkerNodeSelector,
        msg_sender: Sender<QueryMessage>,
        children: Vec<Arc<StageExecution>>,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
        ctx: ExecutionContextRef,
    ) -> Self {
        let tasks = (0..stage.parallelism.unwrap())
            .map(|task_id| (task_id as u64, TaskStatusHolder::new(task_id as u64)))
            .collect();

        Self {
            epoch,
            stage,
            worker_node_manager,
            tasks: Arc::new(tasks),
            state: Arc::new(RwLock::new(Pending { msg_sender })),
            shutdown_tx: RwLock::new(None),
            children,
            compute_client_pool,
            catalog_reader,
            ctx,
        }
    }

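    /// Starts scheduling tasks for this stage: moves the state out of `Pending`, creates a
    /// `StageRunner`, and spawns it on the session's compute runtime. Must be called at
    /// most once per stage.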
    pub async fn start(&self) {
        let mut s = self.state.write().await;
        let cur_state = mem::replace(&mut *s, StageState::Failed);
        match cur_state {
            Pending { msg_sender } => {
                let runner = StageRunner {
                    epoch: self.epoch,
                    stage: self.stage.clone(),
                    worker_node_manager: self.worker_node_manager.clone(),
                    tasks: self.tasks.clone(),
                    msg_sender,
                    children: self.children.clone(),
                    state: self.state.clone(),
                    compute_client_pool: self.compute_client_pool.clone(),
                    catalog_reader: self.catalog_reader.clone(),
                    ctx: self.ctx.clone(),
                };

                let (sender, receiver) = ShutdownToken::new();
                let mut holder = self.shutdown_tx.write().await;
                *holder = Some(sender);

                *s = StageState::Started;

                let span = tracing::info_span!(
                    "stage",
                    "otel.name" = format!("Stage {}-{}", self.stage.query_id.id, self.stage.id),
                    query_id = self.stage.query_id.id,
                    stage_id = self.stage.id,
                );
                self.ctx
                    .session()
                    .env()
                    .compute_runtime()
                    .spawn(async move { runner.run(receiver).instrument(span).await });

                tracing::trace!(
                    "Stage {:?}-{:?} started.",
                    self.stage.query_id.id,
                    self.stage.id
                )
            }
            _ => {
                unreachable!("Only expect to schedule stage once");
            }
        }
    }

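    /// Stops the stage by taking the shutdown sender and either aborting (with an error
    /// message) or cancelling the running tasks.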
    pub async fn stop(&self, error: Option<String>) {
        if let Some(shutdown_tx) = self.shutdown_tx.write().await.take() {
            let sent = if let Some(error) = error {
                shutdown_tx.abort(error)
            } else {
                shutdown_tx.cancel()
            };
            if !sent {
                tracing::trace!(
                    "Failed to send stop message to stage: {:?}-{:?}",
                    self.stage.query_id,
                    self.stage.id
                );
            }
        }
    }

    pub async fn is_scheduled(&self) -> bool {
        let s = self.state.read().await;
        matches!(*s, StageState::Running | StageState::Completed)
    }

    pub async fn is_pending(&self) -> bool {
        let s = self.state.read().await;
        matches!(*s, StageState::Pending { .. })
    }

    pub async fn state(&self) -> &'static str {
        let s = self.state.read().await;
        match *s {
            Pending { .. } => "Pending",
            StageState::Started => "Started",
            StageState::Running => "Running",
            StageState::Completed => "Completed",
            StageState::Failed => "Failed",
        }
    }

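    /// Returns one `ExchangeSource` per task of this stage for the given consumer
    /// `output_id`. All tasks must already be scheduled, since their locations are read
    /// from the status holders.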
    pub fn all_exchange_sources_for(&self, output_id: u64) -> Vec<ExchangeSource> {
        self.tasks
            .iter()
            .map(|(task_id, status_holder)| {
                let task_output_id = TaskOutputId {
                    task_id: Some(PbTaskId {
                        query_id: self.stage.query_id.id.clone(),
                        stage_id: self.stage.id,
                        task_id: *task_id,
                    }),
                    output_id,
                };

                ExchangeSource {
                    task_output_id: Some(task_output_id),
                    host: Some(status_holder.inner.load_full().location.clone().unwrap()),
                    local_execute_plan: None,
                }
            })
            .collect()
    }
}

impl StageRunner {
    async fn run(mut self, shutdown_rx: ShutdownToken) {
        if let Err(e) = self.schedule_tasks_for_all(shutdown_rx).await {
            error!(
                error = %e.as_report(),
                query_id = ?self.stage.query_id,
                stage_id = ?self.stage.id,
                "Failed to schedule tasks"
            );
            self.send_event(QueryMessage::Stage(Failed {
                id: self.stage.id,
                reason: e,
            }))
            .await;
        }
    }

    async fn send_event(&self, event: QueryMessage) {
        if let Err(_e) = self.msg_sender.send(event).await {
            warn!("Failed to send event to query runner; it may have been killed by an earlier failed event");
        }
    }

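    /// Schedules all tasks of a non-root stage to compute nodes, then drives the merged
    /// status streams: reports `Scheduled` once every task is running, `Completed` once
    /// every task finishes, and `Failed` on the first error or abort.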
    async fn schedule_tasks(
        &mut self,
        mut shutdown_rx: ShutdownToken,
        expr_context: ExprContext,
    ) -> SchedulerResult<()> {
        let mut futures = vec![];

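        // Choose a placement strategy based on the stage's inputs: partitioned table scans
        // are pinned to the workers owning the vnodes; source and file scans are split
        // into roughly even chunks across the stage's parallelism; everything else falls
        // back to letting `choose_worker` pick a node per task.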
        if let Some(table_scan_info) = self.stage.table_scan_info.as_ref()
            && let Some(vnode_bitmaps) = table_scan_info.partitions()
        {
            let worker_slot_ids = vnode_bitmaps.keys().cloned().collect_vec();
            let workers = self
                .worker_node_manager
                .manager
                .get_workers_by_worker_slot_ids(&worker_slot_ids)?;

            for (i, (worker_slot_id, worker)) in worker_slot_ids
                .into_iter()
                .zip_eq_fast(workers.into_iter())
                .enumerate()
            {
                let task_id = PbTaskId {
                    query_id: self.stage.query_id.id.clone(),
                    stage_id: self.stage.id,
                    task_id: i as u64,
                };
                let vnode_ranges = vnode_bitmaps[&worker_slot_id].clone();
                let plan_fragment =
                    self.create_plan_fragment(i as u64, Some(PartitionInfo::Table(vnode_ranges)));
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    Some(worker),
                    expr_context.clone(),
                ));
            }
        } else if let Some(source_info) = self.stage.source_info.as_ref() {
            let chunk_size = ((source_info.split_info().unwrap().len() as f32
                / self.stage.parallelism.unwrap() as f32)
                .ceil() as usize)
                .max(1);
            if source_info.split_info().unwrap().is_empty() {
                const EMPTY_TASK_ID: u64 = 0;
                let task_id = PbTaskId {
                    query_id: self.stage.query_id.id.clone(),
                    stage_id: self.stage.id,
                    task_id: EMPTY_TASK_ID,
                };
                let plan_fragment =
                    self.create_plan_fragment(EMPTY_TASK_ID, Some(PartitionInfo::Source(vec![])));
                let worker = self.choose_worker(
                    &plan_fragment,
                    EMPTY_TASK_ID as u32,
                    self.stage.dml_table_id,
                )?;
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    worker,
                    expr_context.clone(),
                ));
            } else {
                for (id, split) in source_info
                    .split_info()
                    .unwrap()
                    .chunks(chunk_size)
                    .enumerate()
                {
                    let task_id = PbTaskId {
                        query_id: self.stage.query_id.id.clone(),
                        stage_id: self.stage.id,
                        task_id: id as u64,
                    };
                    let plan_fragment = self.create_plan_fragment(
                        id as u64,
                        Some(PartitionInfo::Source(split.to_vec())),
                    );
                    let worker =
                        self.choose_worker(&plan_fragment, id as u32, self.stage.dml_table_id)?;
                    futures.push(self.schedule_task(
                        task_id,
                        plan_fragment,
                        worker,
                        expr_context.clone(),
                    ));
                }
            }
        } else if let Some(file_scan_info) = self.stage.file_scan_info.as_ref() {
            let chunk_size = (file_scan_info.file_location.len() as f32
                / self.stage.parallelism.unwrap() as f32)
                .ceil() as usize;
            for (id, files) in file_scan_info.file_location.chunks(chunk_size).enumerate() {
                let task_id = PbTaskId {
                    query_id: self.stage.query_id.id.clone(),
                    stage_id: self.stage.id,
                    task_id: id as u64,
                };
                let plan_fragment =
                    self.create_plan_fragment(id as u64, Some(PartitionInfo::File(files.to_vec())));
                let worker =
                    self.choose_worker(&plan_fragment, id as u32, self.stage.dml_table_id)?;
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    worker,
                    expr_context.clone(),
                ));
            }
        } else {
            for id in 0..self.stage.parallelism.unwrap() {
                let task_id = PbTaskId {
                    query_id: self.stage.query_id.id.clone(),
                    stage_id: self.stage.id,
                    task_id: id as u64,
                };
                let plan_fragment = self.create_plan_fragment(id as u64, None);
                let worker = self.choose_worker(&plan_fragment, id, self.stage.dml_table_id)?;
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    worker,
                    expr_context.clone(),
                ));
            }
        }

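        // Issue the scheduling RPCs with bounded concurrency, then merge the returned
        // status streams and watch them until all tasks finish or the stage is shut down.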
        let buffered = stream::iter(futures).buffer_unordered(TASK_SCHEDULING_PARALLELISM);
        let buffered_streams = buffered.try_collect::<Vec<_>>().await?;

        let cancelled = pin!(shutdown_rx.cancelled());
        let mut all_streams = select_all(buffered_streams).take_until(cancelled);

        let mut running_task_cnt = 0;
        let mut finished_task_cnt = 0;
        let mut sent_signal_to_next = false;

        while let Some(status_res_inner) = all_streams.next().await {
            match status_res_inner {
                Ok(status) => {
                    use risingwave_pb::task_service::task_info_response::TaskStatus as PbTaskStatus;
                    match PbTaskStatus::try_from(status.task_status).unwrap() {
                        PbTaskStatus::Running => {
                            running_task_cnt += 1;
                            assert!(running_task_cnt <= self.tasks.keys().len());
                            if running_task_cnt == self.tasks.keys().len() {
                                self.notify_stage_scheduled(QueryMessage::Stage(
                                    StageEvent::Scheduled(self.stage.id),
                                ))
                                .await;
                                sent_signal_to_next = true;
                            }
                        }

                        PbTaskStatus::Finished => {
                            finished_task_cnt += 1;
                            assert!(finished_task_cnt <= self.tasks.keys().len());
                            assert!(running_task_cnt >= finished_task_cnt);
                            if finished_task_cnt == self.tasks.keys().len() {
                                self.notify_stage_completed().await;
                                sent_signal_to_next = true;
                                break;
                            }
                        }
                        PbTaskStatus::Aborted => {
                            error!(
                                "Aborted task {:?} because of excessive memory usage. Please try again later.",
                                status.task_id.unwrap()
                            );
                            self.notify_stage_state_changed(
                                |_| StageState::Failed,
                                QueryMessage::Stage(Failed {
                                    id: self.stage.id,
                                    reason: TaskRunningOutOfMemory,
                                }),
                            )
                            .await;
                            sent_signal_to_next = true;
                            break;
                        }
                        PbTaskStatus::Failed => {
                            error!(
                                "Task {:?} failed, reason: {:?}",
                                status.task_id.unwrap(),
                                status.error_message,
                            );
                            self.notify_stage_state_changed(
                                |_| StageState::Failed,
                                QueryMessage::Stage(Failed {
                                    id: self.stage.id,
                                    reason: TaskExecutionError(status.error_message),
                                }),
                            )
                            .await;
                            sent_signal_to_next = true;
                            break;
                        }
                        PbTaskStatus::Ping => {
                            debug!("Receive ping from task {:?}", status.task_id.unwrap());
                        }
                        status => {
                            unreachable!("Unexpected task status {:?}", status);
                        }
                    }
                }
                Err(e) => {
                    error!(
                        "Fetching task status in stage {:?} failed, reason: {:?}",
                        self.stage.id,
                        e.message()
                    );
                    self.notify_stage_state_changed(
                        |_| StageState::Failed,
                        QueryMessage::Stage(Failed {
                            id: self.stage.id,
                            reason: RpcError::from_batch_status(e).into(),
                        }),
                    )
                    .await;
                    sent_signal_to_next = true;
                    break;
                }
            }
        }

        tracing::trace!(
            "Stage [{:?}-{:?}], running task count: {}, finished task count: {}, sent signal to next: {}",
            self.stage.query_id,
            self.stage.id,
            running_task_cnt,
            finished_task_cnt,
            sent_signal_to_next,
        );

        if let Some(shutdown) = all_streams.take_future() {
            tracing::trace!(
                "Stage [{:?}-{:?}] waiting for stopping signal.",
                self.stage.query_id,
                self.stage.id
            );
            shutdown.await;
        }

        tracing::trace!(
            "Stopping stage: {:?}-{:?}, task_num: {}",
            self.stage.query_id,
            self.stage.id,
            self.tasks.len()
        );
        self.cancel_all_scheduled_tasks().await?;

        tracing::trace!(
            "Stage runner [{:?}-{:?}] exited.",
            self.stage.query_id,
            self.stage.id
        );
        Ok(())
    }

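    /// Schedules the root stage: instead of sending the fragment to a compute node, it is
    /// built and executed locally, streaming result chunks to the query runner through a
    /// bounded channel.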
    async fn schedule_tasks_for_root(
        &mut self,
        mut shutdown_rx: ShutdownToken,
        expr_context: ExprContext,
    ) -> SchedulerResult<()> {
        let root_stage_id = self.stage.id;
        let plan_fragment = self.create_plan_fragment(ROOT_TASK_ID, None);
        let plan_node = plan_fragment.root.unwrap();
        let task_id = TaskIdBatch {
            query_id: self.stage.query_id.id.clone(),
            stage_id: root_stage_id,
            task_id: 0,
        };

        let (result_tx, result_rx) = tokio::sync::mpsc::channel(
            self.ctx
                .session
                .env()
                .batch_config()
                .developer
                .root_stage_channel_size,
        );
        self.notify_stage_scheduled(QueryMessage::Stage(StageEvent::ScheduledRoot(result_rx)))
            .await;

        let executor = ExecutorBuilder::new(
            &plan_node,
            &task_id,
            self.ctx.to_batch_task_context(),
            self.epoch,
            shutdown_rx.clone(),
        );

        let shutdown_rx0 = shutdown_rx.clone();

        let result = expr_context_scope(expr_context, async {
            let executor = executor.build().await?;
            let chunk_stream = executor.execute();
            let cancelled = pin!(shutdown_rx.cancelled());
            #[for_await]
            for chunk in chunk_stream.take_until(cancelled) {
                if let Err(ref e) = chunk {
                    if shutdown_rx0.is_cancelled() {
                        break;
                    }
                    let err_str = e.to_report_string();
                    if let Err(_e) = result_tx.send(chunk.map_err(|e| e.into())).await {
                        warn!("Root executor has been dropped before receiving any events, so the send failed");
                    }
                    return Err(TaskExecutionError(err_str));
                } else if let Err(_e) = result_tx.send(chunk.map_err(|e| e.into())).await {
                    warn!("Root executor has been dropped before receiving any events, so the send failed");
                }
            }
            Ok(())
        })
        .await;

        if let Err(err) = &result {
            if let Err(_e) = result_tx
                .send(Err(TaskExecutionError(err.to_report_string())))
                .await
            {
                warn!("Failed to send task execution error");
            }
        }

        match shutdown_rx0.message() {
            ShutdownMsg::Abort(err_str) => {
                if let Err(_e) = result_tx.send(Err(TaskExecutionError(err_str))).await {
                    warn!("Failed to send task execution error");
                }
            }
            _ => self.notify_stage_completed().await,
        }

        tracing::trace!(
            "Stage runner [{:?}-{:?}] exited.",
            self.stage.query_id,
            self.stage.id
        );

        result.map(|_| ())
    }

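    /// Entry point of the runner: builds the expression context from the session config
    /// and dispatches to root or non-root scheduling.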
    async fn schedule_tasks_for_all(&mut self, shutdown_rx: ShutdownToken) -> SchedulerResult<()> {
        let expr_context = ExprContext {
            time_zone: self.ctx.session().config().timezone().to_owned(),
            strict_mode: self.ctx.session().config().batch_expr_strict_mode(),
        };
        if !self.is_root_stage() {
            self.schedule_tasks(shutdown_rx, expr_context).await?;
        } else {
            self.schedule_tasks_for_root(shutdown_rx, expr_context)
                .await?;
        }
        Ok(())
    }

    #[inline(always)]
    fn get_fragment_id(&self, table_id: &TableId) -> SchedulerResult<FragmentId> {
        self.catalog_reader
            .read_guard()
            .get_any_table_by_id(table_id)
            .map(|table| table.fragment_id)
            .map_err(|e| SchedulerError::Internal(anyhow!(e)))
    }

    #[inline(always)]
    fn get_table_dml_vnode_mapping(
        &self,
        table_id: &TableId,
    ) -> SchedulerResult<WorkerSlotMapping> {
        let guard = self.catalog_reader.read_guard();

        let table = guard
            .get_any_table_by_id(table_id)
            .map_err(|e| SchedulerError::Internal(anyhow!(e)))?;

        let fragment_id = match table.dml_fragment_id.as_ref() {
            Some(dml_fragment_id) => dml_fragment_id,
            None => &table.fragment_id,
        };

        self.worker_node_manager
            .manager
            .get_streaming_fragment_mapping(fragment_id)
            .map_err(|e| e.into())
    }

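    /// Picks a worker for a task when placement matters: DML tasks go to workers holding
    /// the table's DML fragment, and distributed lookup joins are aligned with the inner
    /// side table's vnode mapping. Returns `None` when any worker will do.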
    fn choose_worker(
        &self,
        plan_fragment: &PlanFragment,
        task_id: u32,
        dml_table_id: Option<TableId>,
    ) -> SchedulerResult<Option<WorkerNode>> {
        let plan_node = plan_fragment.root.as_ref().expect("fail to get plan node");

        if let Some(table_id) = dml_table_id {
            let vnode_mapping = self.get_table_dml_vnode_mapping(&table_id)?;
            let worker_slot_ids = vnode_mapping.iter_unique().collect_vec();
            let candidates = self
                .worker_node_manager
                .manager
                .get_workers_by_worker_slot_ids(&worker_slot_ids)?;
            if candidates.is_empty() {
                return Err(BatchError::EmptyWorkerNodes.into());
            }
            let candidate = if self.stage.batch_enable_distributed_dml {
                candidates[task_id as usize % candidates.len()].clone()
            } else {
                candidates[self.stage.session_id.0 as usize % candidates.len()].clone()
            };
            return Ok(Some(candidate));
        }

        if let Some(distributed_lookup_join_node) =
            Self::find_distributed_lookup_join_node(plan_node)
        {
            let fragment_id = self.get_fragment_id(
                &distributed_lookup_join_node
                    .inner_side_table_desc
                    .as_ref()
                    .unwrap()
                    .table_id
                    .into(),
            )?;
            let id_to_worker_slots = self
                .worker_node_manager
                .fragment_mapping(fragment_id)?
                .iter_unique()
                .collect_vec();

            let worker_slot_id = id_to_worker_slots[task_id as usize];
            let candidates = self
                .worker_node_manager
                .manager
                .get_workers_by_worker_slot_ids(&[worker_slot_id])?;
            if candidates.is_empty() {
                return Err(BatchError::EmptyWorkerNodes.into());
            }
            Ok(Some(candidates[0].clone()))
        } else {
            Ok(None)
        }
    }

    fn find_distributed_lookup_join_node(
        plan_node: &PlanNode,
    ) -> Option<&DistributedLookupJoinNode> {
        let node_body = plan_node.node_body.as_ref().expect("fail to get node body");

        match node_body {
            NodeBody::DistributedLookupJoin(distributed_lookup_join_node) => {
                Some(distributed_lookup_join_node)
            }
            _ => plan_node
                .children
                .iter()
                .find_map(Self::find_distributed_lookup_join_node),
        }
    }

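    /// Marks the stage as `Running` (it must currently be `Started`) and forwards `msg` to
    /// the query runner.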
    async fn notify_stage_scheduled(&self, msg: QueryMessage) {
        self.notify_stage_state_changed(
            |old_state| {
                assert_matches!(old_state, StageState::Started);
                StageState::Running
            },
            msg,
        )
        .await
    }

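    /// Marks the stage as `Completed` (it must currently be `Running`) and reports the
    /// completion event to the query runner.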
    async fn notify_stage_completed(&self) {
        self.notify_stage_state_changed(
            |old_state| {
                assert_matches!(old_state, StageState::Running);
                StageState::Completed
            },
            QueryMessage::Stage(StageEvent::Completed(self.stage.id)),
        )
        .await
    }

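    /// Applies a state transition under the write lock, then sends `msg` to the query
    /// runner. The lock is released before sending so it is not held across that await.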
    async fn notify_stage_state_changed<F>(&self, new_state: F, msg: QueryMessage)
    where
        F: FnOnce(StageState) -> StageState,
    {
        {
            let mut s = self.state.write().await;
            let old_state = mem::replace(&mut *s, StageState::Failed);
            *s = new_state(old_state);
        }

        self.send_event(msg).await;
    }

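    /// Best-effort cancellation: sends a `CancelTaskRequest` for every task of this stage.
    /// Each RPC runs in a spawned tokio task so a slow or unreachable worker does not
    /// block the stage runner from exiting.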
    async fn cancel_all_scheduled_tasks(&self) -> SchedulerResult<()> {
        for (task, task_status) in &*self.tasks {
            let loc = &task_status.get_status().location;
            let addr = loc.as_ref().expect("Get address should not fail");
            let client = self
                .compute_client_pool
                .get_by_addr(HostAddr::from(addr))
                .await
                .map_err(|e| anyhow!(e))?;

            let query_id = self.stage.query_id.id.clone();
            let stage_id = self.stage.id;
            let task_id = *task;
            spawn(async move {
                if let Err(e) = client
                    .cancel(CancelTaskRequest {
                        task_id: Some(risingwave_pb::batch_plan::TaskId {
                            query_id: query_id.clone(),
                            stage_id,
                            task_id,
                        }),
                    })
                    .await
                {
                    error!(
                        error = %e.as_report(),
                        ?task_id,
                        ?query_id,
                        ?stage_id,
                        "Abort task failed",
                    );
                }
            });
        }
        Ok(())
    }

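    /// Sends `create_task` to the chosen worker (or a random one if `worker` is `None`),
    /// records the task's location, and returns the fused status stream of the task.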
    async fn schedule_task(
        &self,
        task_id: PbTaskId,
        plan_fragment: PlanFragment,
        worker: Option<WorkerNode>,
        expr_context: ExprContext,
    ) -> SchedulerResult<Fuse<Streaming<TaskInfoResponse>>> {
        let mut worker = worker.unwrap_or(self.worker_node_manager.next_random_worker()?);
        let worker_node_addr = worker.host.take().unwrap();
        let compute_client = self
            .compute_client_pool
            .get_by_addr((&worker_node_addr).into())
            .await
            .inspect_err(|_| self.mask_failed_serving_worker(&worker))
            .map_err(|e| anyhow!(e))?;

        let t_id = task_id.task_id;

        let stream_status: Fuse<Streaming<TaskInfoResponse>> = compute_client
            .create_task(task_id, plan_fragment, self.epoch, expr_context)
            .await
            .inspect_err(|_| self.mask_failed_serving_worker(&worker))
            .map_err(|e| anyhow!(e))?
            .fuse();

        self.tasks[&t_id].inner.store(Arc::new(TaskStatus {
            _task_id: t_id,
            location: Some(worker_node_addr),
        }));

        Ok(stream_status)
    }

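    /// Serializes this stage's plan into a `PlanFragment` for `task_id`, attaching the
    /// partition info (vnode bitmaps or splits) that this particular task should read.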
    pub fn create_plan_fragment(
        &self,
        task_id: TaskId,
        partition: Option<PartitionInfo>,
    ) -> PlanFragment {
        let identity_id: Rc<RefCell<u64>> = Rc::new(RefCell::new(0));

        let plan_node_prost =
            self.convert_plan_node(&self.stage.root, task_id, partition, identity_id);
        let exchange_info = self.stage.exchange_info.clone().unwrap();

        PlanFragment {
            root: Some(plan_node_prost),
            exchange_info: Some(exchange_info),
        }
    }

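    /// Recursively converts an `ExecutionPlanNode` into its protobuf form, numbering each
    /// node via `identity_id`. Exchange nodes are wired to the child stage's task outputs,
    /// and scan/source nodes receive this task's partition assignment.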
    fn convert_plan_node(
        &self,
        execution_plan_node: &ExecutionPlanNode,
        task_id: TaskId,
        partition: Option<PartitionInfo>,
        identity_id: Rc<RefCell<u64>>,
    ) -> PbPlanNode {
        let identity = {
            let identity_type = execution_plan_node.plan_node_type;
            let id = *identity_id.borrow();
            identity_id.replace(id + 1);
            format!("{:?}-{}", identity_type, id)
        };

        match execution_plan_node.plan_node_type {
            PlanNodeType::BatchExchange => {
                let child_stage = self
                    .children
                    .iter()
                    .find(|child_stage| {
                        child_stage.stage.id == execution_plan_node.source_stage_id.unwrap()
                    })
                    .unwrap();
                let exchange_sources = child_stage.all_exchange_sources_for(task_id);

                match &execution_plan_node.node {
                    NodeBody::Exchange(exchange_node) => PbPlanNode {
                        children: vec![],
                        identity,
                        node_body: Some(NodeBody::Exchange(ExchangeNode {
                            sources: exchange_sources,
                            sequential: exchange_node.sequential,
                            input_schema: execution_plan_node.schema.clone(),
                        })),
                    },
                    NodeBody::MergeSortExchange(sort_merge_exchange_node) => PbPlanNode {
                        children: vec![],
                        identity,
                        node_body: Some(NodeBody::MergeSortExchange(MergeSortExchangeNode {
                            exchange: Some(ExchangeNode {
                                sources: exchange_sources,
                                sequential: false,
                                input_schema: execution_plan_node.schema.clone(),
                            }),
                            column_orders: sort_merge_exchange_node.column_orders.clone(),
                        })),
                    },
                    _ => unreachable!(),
                }
            }
            PlanNodeType::BatchSeqScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::RowSeqScan(mut scan_node) = node_body else {
                    unreachable!();
                };
                let partition = partition
                    .expect("no partition info for seq scan")
                    .into_table()
                    .expect("PartitionInfo should be TablePartitionInfo");
                scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
                scan_node.scan_ranges = partition.scan_ranges;
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::RowSeqScan(scan_node)),
                }
            }
            PlanNodeType::BatchLogSeqScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::LogRowSeqScan(mut scan_node) = node_body else {
                    unreachable!();
                };
                let partition = partition
                    .expect("no partition info for seq scan")
                    .into_table()
                    .expect("PartitionInfo should be TablePartitionInfo");
                scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::LogRowSeqScan(scan_node)),
                }
            }
            PlanNodeType::BatchSource | PlanNodeType::BatchKafkaScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::Source(mut source_node) = node_body else {
                    unreachable!();
                };

                let partition = partition
                    .expect("no partition info for source scan")
                    .into_source()
                    .expect("PartitionInfo should be SourcePartitionInfo");
                source_node.split = partition
                    .into_iter()
                    .map(|split| split.encode_to_bytes().into())
                    .collect_vec();
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::Source(source_node)),
                }
            }
            PlanNodeType::BatchIcebergScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::IcebergScan(mut iceberg_scan_node) = node_body else {
                    unreachable!();
                };

                let partition = partition
                    .expect("no partition info for iceberg scan")
                    .into_source()
                    .expect("PartitionInfo should be SourcePartitionInfo");
                iceberg_scan_node.split = partition
                    .into_iter()
                    .map(|split| split.encode_to_bytes().into())
                    .collect_vec();
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::IcebergScan(iceberg_scan_node)),
                }
            }
            _ => {
                let children = execution_plan_node
                    .children
                    .iter()
                    .map(|e| {
                        self.convert_plan_node(e, task_id, partition.clone(), identity_id.clone())
                    })
                    .collect();

                PbPlanNode {
                    children,
                    identity,
                    node_body: Some(execution_plan_node.node.clone()),
                }
            }
        }
    }

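    /// The root stage is the one executed locally in the frontend; by convention it is
    /// always stage 0.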
    fn is_root_stage(&self) -> bool {
        self.stage.id == 0
    }

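    /// Temporarily masks a serving worker after an RPC failure so that subsequent batch
    /// queries avoid it for the configured number of seconds (at least one).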
    fn mask_failed_serving_worker(&self, worker: &WorkerNode) {
        if !worker.property.as_ref().is_some_and(|p| p.is_serving) {
            return;
        }
        let duration = Duration::from_secs(std::cmp::max(
            self.ctx
                .session
                .env()
                .batch_config()
                .mask_worker_temporary_secs as u64,
            1,
        ));
        self.worker_node_manager
            .manager
            .mask_worker_node(worker.id, duration);
    }
}