use std::assert_matches::assert_matches;
use std::cell::RefCell;
use std::collections::HashMap;
use std::mem;
use std::pin::pin;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Duration;

use StageEvent::Failed;
use anyhow::anyhow;
use arc_swap::ArcSwap;
use futures::stream::Fuse;
use futures::{StreamExt, TryStreamExt, stream};
use futures_async_stream::for_await;
use itertools::Itertools;
use risingwave_batch::error::BatchError;
use risingwave_batch::executor::ExecutorBuilder;
use risingwave_batch::task::{ShutdownMsg, ShutdownSender, ShutdownToken, TaskId as TaskIdBatch};
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::array::DataChunk;
use risingwave_common::hash::WorkerSlotMapping;
use risingwave_common::util::addr::HostAddr;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_connector::source::SplitMetaData;
use risingwave_expr::expr_context::expr_context_scope;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::{
    DistributedLookupJoinNode, ExchangeNode, ExchangeSource, MergeSortExchangeNode, PlanFragment,
    PlanNode as PbPlanNode, PlanNode, TaskId as PbTaskId, TaskOutputId,
};
use risingwave_pb::common::{BatchQueryEpoch, HostAddress, WorkerNode};
use risingwave_pb::plan_common::ExprContext;
use risingwave_pb::task_service::{CancelTaskRequest, TaskInfoResponse};
use risingwave_rpc_client::ComputeClientPoolRef;
use risingwave_rpc_client::error::RpcError;
use rw_futures_util::select_all;
use thiserror_ext::AsReport;
use tokio::spawn;
use tokio::sync::RwLock;
use tokio::sync::mpsc::{Receiver, Sender};
use tonic::Streaming;
use tracing::{Instrument, debug, error, warn};

use crate::catalog::catalog_service::CatalogReader;
use crate::catalog::{FragmentId, TableId};
use crate::optimizer::plan_node::BatchPlanNodeType;
use crate::scheduler::SchedulerError::{TaskExecutionError, TaskRunningOutOfMemory};
use crate::scheduler::distributed::QueryMessage;
use crate::scheduler::distributed::stage::StageState::Pending;
use crate::scheduler::plan_fragmenter::{
    ExecutionPlanNode, PartitionInfo, QueryStageRef, ROOT_TASK_ID, StageId, TaskId,
};
use crate::scheduler::{ExecutionContextRef, SchedulerError, SchedulerResult};

const TASK_SCHEDULING_PARALLELISM: usize = 10;

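/// Scheduling state of a single stage, as driven by its [`StageExecution`] and [`StageRunner`].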
#[derive(Debug)]
enum StageState {
    /// The message sender is kept in the `Pending` state (rather than in `StageExecution`) so
    /// that it can be moved into the `StageRunner` when the stage starts.
    Pending {
        msg_sender: Sender<QueryMessage>,
    },
    Started,
    Running,
    Completed,
    Failed,
}

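/// Events that a stage reports back to the query runner through the message channel.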
#[derive(Debug)]
pub enum StageEvent {
    Scheduled(StageId),
    ScheduledRoot(Receiver<SchedulerResult<DataChunk>>),
    Failed {
        id: StageId,
        reason: SchedulerError,
    },
    Completed(#[allow(dead_code)] StageId),
}

#[derive(Clone)]
pub struct TaskStatus {
    _task_id: TaskId,

    // `None` until the task has been scheduled to a worker node.
    location: Option<HostAddress>,
}

struct TaskStatusHolder {
    inner: ArcSwap<TaskStatus>,
}

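/// Handle of one stage of a distributed batch query. It holds the per-task status map and the
/// scheduling state, and spawns a `StageRunner` to actually schedule and monitor the tasks.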
pub struct StageExecution {
    epoch: BatchQueryEpoch,
    stage: QueryStageRef,
    worker_node_manager: WorkerNodeSelector,
    tasks: Arc<HashMap<TaskId, TaskStatusHolder>>,
    state: Arc<RwLock<StageState>>,
    shutdown_tx: RwLock<Option<ShutdownSender>>,
    children: Vec<Arc<StageExecution>>,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,

    ctx: ExecutionContextRef,
}

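/// Background worker that drives one stage: it creates the tasks on compute nodes (or executes
/// the root fragment locally for the root stage), watches task status, and reports stage events
/// back to the query runner.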
struct StageRunner {
    epoch: BatchQueryEpoch,
    state: Arc<RwLock<StageState>>,
    stage: QueryStageRef,
    worker_node_manager: WorkerNodeSelector,
    tasks: Arc<HashMap<TaskId, TaskStatusHolder>>,
    msg_sender: Sender<QueryMessage>,
    children: Vec<Arc<StageExecution>>,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,

    ctx: ExecutionContextRef,
}

impl TaskStatusHolder {
    fn new(task_id: TaskId) -> Self {
        let task_status = TaskStatus {
            _task_id: task_id,
            location: None,
        };

        Self {
            inner: ArcSwap::new(Arc::new(task_status)),
        }
    }

    fn get_status(&self) -> Arc<TaskStatus> {
        self.inner.load_full()
    }
}

impl StageExecution {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        epoch: BatchQueryEpoch,
        stage: QueryStageRef,
        worker_node_manager: WorkerNodeSelector,
        msg_sender: Sender<QueryMessage>,
        children: Vec<Arc<StageExecution>>,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
        ctx: ExecutionContextRef,
    ) -> Self {
        let tasks = (0..stage.parallelism.unwrap())
            .map(|task_id| (task_id as u64, TaskStatusHolder::new(task_id as u64)))
            .collect();

        Self {
            epoch,
            stage,
            worker_node_manager,
            tasks: Arc::new(tasks),
            state: Arc::new(RwLock::new(Pending { msg_sender })),
            shutdown_tx: RwLock::new(None),
            children,
            compute_client_pool,
            catalog_reader,
            ctx,
        }
    }

    pub async fn start(&self) {
        let mut s = self.state.write().await;
        let cur_state = mem::replace(&mut *s, StageState::Failed);
        match cur_state {
            Pending { msg_sender } => {
                let runner = StageRunner {
                    epoch: self.epoch,
                    stage: self.stage.clone(),
                    worker_node_manager: self.worker_node_manager.clone(),
                    tasks: self.tasks.clone(),
                    msg_sender,
                    children: self.children.clone(),
                    state: self.state.clone(),
                    compute_client_pool: self.compute_client_pool.clone(),
                    catalog_reader: self.catalog_reader.clone(),
                    ctx: self.ctx.clone(),
                };

                // Create a shutdown channel so that the stage can be stopped later.
                let (sender, receiver) = ShutdownToken::new();
                let mut holder = self.shutdown_tx.write().await;
                *holder = Some(sender);

                // Mark the stage as started before spawning the runner.
                *s = StageState::Started;

                let span = tracing::info_span!(
                    "stage",
                    "otel.name" = format!("Stage {}-{}", self.stage.query_id.id, self.stage.id),
                    query_id = self.stage.query_id.id,
                    stage_id = self.stage.id,
                );
                self.ctx
                    .session()
                    .env()
                    .compute_runtime()
                    .spawn(async move { runner.run(receiver).instrument(span).await });

                tracing::trace!(
                    "Stage {:?}-{:?} started.",
                    self.stage.query_id.id,
                    self.stage.id
                )
            }
            _ => {
                unreachable!("Only expect to schedule stage once");
            }
        }
    }

    /// Ask the stage runner to stop, aborting with an error message if one is given.
    pub async fn stop(&self, error: Option<String>) {
        if let Some(shutdown_tx) = self.shutdown_tx.write().await.take() {
            // Abort (with the error) or cancel all running tasks of this stage. Both calls
            // return `false` if the stage runner has already exited.
            let stopped = if let Some(error) = error {
                shutdown_tx.abort(error)
            } else {
                shutdown_tx.cancel()
            };
            if !stopped {
                tracing::trace!(
                    "Failed to send stop message to stage {:?}-{:?}",
                    self.stage.query_id,
                    self.stage.id
                );
            }
        }
    }

    pub async fn is_scheduled(&self) -> bool {
        let s = self.state.read().await;
        matches!(*s, StageState::Running | StageState::Completed)
    }

    pub async fn is_pending(&self) -> bool {
        let s = self.state.read().await;
        matches!(*s, StageState::Pending { .. })
    }

    pub async fn state(&self) -> &'static str {
        let s = self.state.read().await;
        match *s {
            Pending { .. } => "Pending",
            StageState::Started => "Started",
            StageState::Running => "Running",
            StageState::Completed => "Completed",
            StageState::Failed => "Failed",
        }
    }

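    /// Return one `ExchangeSource` per task of this stage for the given output id, so that a
    /// parent stage can pull this stage's output. All tasks must already have been scheduled
    /// (i.e. have a known location) when this is called.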
    pub fn all_exchange_sources_for(&self, output_id: u64) -> Vec<ExchangeSource> {
        self.tasks
            .iter()
            .map(|(task_id, status_holder)| {
                let task_output_id = TaskOutputId {
                    task_id: Some(PbTaskId {
                        query_id: self.stage.query_id.id.clone(),
                        stage_id: self.stage.id,
                        task_id: *task_id,
                    }),
                    output_id,
                };

                ExchangeSource {
                    task_output_id: Some(task_output_id),
                    host: Some(status_holder.inner.load_full().location.clone().unwrap()),
                    local_execute_plan: None,
                }
            })
            .collect()
    }
}

impl StageRunner {
    async fn run(mut self, shutdown_rx: ShutdownToken) {
        if let Err(e) = self.schedule_tasks_for_all(shutdown_rx).await {
            error!(
                error = %e.as_report(),
                query_id = ?self.stage.query_id,
                stage_id = ?self.stage.id,
                "Failed to schedule tasks"
            );
            self.send_event(QueryMessage::Stage(Failed {
                id: self.stage.id,
                reason: e,
            }))
            .await;
        }
    }

    async fn send_event(&self, event: QueryMessage) {
        if let Err(_e) = self.msg_sender.send(event).await {
            warn!("Failed to send event to query runner; it may have been killed by a previous failed event");
        }
    }

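    /// Schedule all tasks of this (non-root) stage to compute nodes, then poll their status
    /// streams until every task finishes, a task fails, or a shutdown signal arrives.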
    async fn schedule_tasks(
        &mut self,
        mut shutdown_rx: ShutdownToken,
        expr_context: ExprContext,
    ) -> SchedulerResult<()> {
        let mut futures = vec![];

        if let Some(table_scan_info) = self.stage.table_scan_info.as_ref()
            && let Some(vnode_bitmaps) = table_scan_info.partitions()
        {
            // Vnode-partitioned table scan: schedule one task per worker slot that owns a
            // partition, so each task only scans the vnodes assigned to it.
            let worker_slot_ids = vnode_bitmaps.keys().cloned().collect_vec();
            let workers = self
                .worker_node_manager
                .manager
                .get_workers_by_worker_slot_ids(&worker_slot_ids)?;

            for (i, (worker_slot_id, worker)) in worker_slot_ids
                .into_iter()
                .zip_eq_fast(workers.into_iter())
                .enumerate()
            {
                let task_id = PbTaskId {
                    query_id: self.stage.query_id.id.clone(),
                    stage_id: self.stage.id,
                    task_id: i as u64,
                };
                let vnode_ranges = vnode_bitmaps[&worker_slot_id].clone();
                let plan_fragment =
                    self.create_plan_fragment(i as u64, Some(PartitionInfo::Table(vnode_ranges)));
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    Some(worker),
                    expr_context.clone(),
                ));
            }
        } else if let Some(source_info) = self.stage.source_info.as_ref() {
            // Distribute the source splits roughly evenly across the stage's parallelism.
            let chunk_size = ((source_info.split_info().unwrap().len() as f32
                / self.stage.parallelism.unwrap() as f32)
                .ceil() as usize)
                .max(1);
            if source_info.split_info().unwrap().is_empty() {
                // No split to read: still schedule a single task so that the stage produces an
                // (empty) result.
                const EMPTY_TASK_ID: u64 = 0;
                let task_id = PbTaskId {
                    query_id: self.stage.query_id.id.clone(),
                    stage_id: self.stage.id,
                    task_id: EMPTY_TASK_ID,
                };
                let plan_fragment =
                    self.create_plan_fragment(EMPTY_TASK_ID, Some(PartitionInfo::Source(vec![])));
                let worker = self.choose_worker(
                    &plan_fragment,
                    EMPTY_TASK_ID as u32,
                    self.stage.dml_table_id,
                )?;
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    worker,
                    expr_context.clone(),
                ));
            } else {
                for (id, split) in source_info
                    .split_info()
                    .unwrap()
                    .chunks(chunk_size)
                    .enumerate()
                {
                    let task_id = PbTaskId {
                        query_id: self.stage.query_id.id.clone(),
                        stage_id: self.stage.id,
                        task_id: id as u64,
                    };
                    let plan_fragment = self.create_plan_fragment(
                        id as u64,
                        Some(PartitionInfo::Source(split.to_vec())),
                    );
                    let worker =
                        self.choose_worker(&plan_fragment, id as u32, self.stage.dml_table_id)?;
                    futures.push(self.schedule_task(
                        task_id,
                        plan_fragment,
                        worker,
                        expr_context.clone(),
                    ));
                }
            }
        } else if let Some(file_scan_info) = self.stage.file_scan_info.as_ref() {
            let chunk_size = (file_scan_info.file_location.len() as f32
                / self.stage.parallelism.unwrap() as f32)
                .ceil() as usize;
            for (id, files) in file_scan_info.file_location.chunks(chunk_size).enumerate() {
                let task_id = PbTaskId {
                    query_id: self.stage.query_id.id.clone(),
                    stage_id: self.stage.id,
                    task_id: id as u64,
                };
                let plan_fragment =
                    self.create_plan_fragment(id as u64, Some(PartitionInfo::File(files.to_vec())));
                let worker =
                    self.choose_worker(&plan_fragment, id as u32, self.stage.dml_table_id)?;
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    worker,
                    expr_context.clone(),
                ));
            }
        } else if self.stage.has_vector_search {
            let id = 0;
            let task_id = PbTaskId {
                query_id: self.stage.query_id.id.clone(),
                stage_id: self.stage.id,
                task_id: id as u64,
            };
            let plan_fragment = self.create_plan_fragment(id as u64, None);
            let worker = self.choose_worker(&plan_fragment, id, self.stage.dml_table_id)?;
            futures.push(self.schedule_task(task_id, plan_fragment, worker, expr_context.clone()));
        } else {
            for id in 0..self.stage.parallelism.unwrap() {
                let task_id = PbTaskId {
                    query_id: self.stage.query_id.id.clone(),
                    stage_id: self.stage.id,
                    task_id: id as u64,
                };
                let plan_fragment = self.create_plan_fragment(id as u64, None);
                let worker = self.choose_worker(&plan_fragment, id, self.stage.dml_table_id)?;
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    worker,
                    expr_context.clone(),
                ));
            }
        }

        // Schedule the tasks with bounded concurrency and wait until all of them have been
        // created on the compute nodes.
        let buffered = stream::iter(futures).buffer_unordered(TASK_SCHEDULING_PARALLELISM);
        let buffered_streams = buffered.try_collect::<Vec<_>>().await?;

        // Merge the status streams of all tasks, and stop polling once a shutdown signal arrives.
        let cancelled = pin!(shutdown_rx.cancelled());
        let mut all_streams = select_all(buffered_streams).take_until(cancelled);

        let mut running_task_cnt = 0;
        let mut finished_task_cnt = 0;
        let mut sent_signal_to_next = false;

        while let Some(status_res_inner) = all_streams.next().await {
            match status_res_inner {
                Ok(status) => {
                    use risingwave_pb::task_service::task_info_response::TaskStatus as PbTaskStatus;
                    match PbTaskStatus::try_from(status.task_status).unwrap() {
                        PbTaskStatus::Running => {
                            running_task_cnt += 1;
                            assert!(running_task_cnt <= self.tasks.keys().len());
                            // All tasks of this stage are running: report `Scheduled` so that
                            // parent stages can be scheduled.
                            if running_task_cnt == self.tasks.keys().len() {
                                self.notify_stage_scheduled(QueryMessage::Stage(
                                    StageEvent::Scheduled(self.stage.id),
                                ))
                                .await;
                                sent_signal_to_next = true;
                            }
                        }

                        PbTaskStatus::Finished => {
                            finished_task_cnt += 1;
                            assert!(finished_task_cnt <= self.tasks.keys().len());
                            assert!(running_task_cnt >= finished_task_cnt);
                            if finished_task_cnt == self.tasks.keys().len() {
                                // All tasks finished without error: the stage is completed.
                                self.notify_stage_completed().await;
                                sent_signal_to_next = true;
                                break;
                            }
                        }
                        PbTaskStatus::Aborted => {
                            // The compute node killed the task because it used too much memory.
                            error!(
                                "Abort task {:?} because of excessive memory usage. Please try again later.",
                                status.task_id.unwrap()
                            );
                            self.notify_stage_state_changed(
                                |_| StageState::Failed,
                                QueryMessage::Stage(Failed {
                                    id: self.stage.id,
                                    reason: TaskRunningOutOfMemory,
                                }),
                            )
                            .await;
                            sent_signal_to_next = true;
                            break;
                        }
                        PbTaskStatus::Failed => {
                            error!(
                                "Task {:?} failed, reason: {:?}",
                                status.task_id.unwrap(),
                                status.error_message,
                            );
                            self.notify_stage_state_changed(
                                |_| StageState::Failed,
                                QueryMessage::Stage(Failed {
                                    id: self.stage.id,
                                    reason: TaskExecutionError(status.error_message),
                                }),
                            )
                            .await;
                            sent_signal_to_next = true;
                            break;
                        }
                        PbTaskStatus::Ping => {
                            debug!("Receive ping from task {:?}", status.task_id.unwrap());
                        }
                        status => {
                            unreachable!("Unexpected task status {:?}", status);
                        }
                    }
                }
                Err(e) => {
                    error!(
                        "Fetching task status in stage {:?} failed, reason: {:?}",
                        self.stage.id,
                        e.message()
                    );
                    self.notify_stage_state_changed(
                        |_| StageState::Failed,
                        QueryMessage::Stage(Failed {
                            id: self.stage.id,
                            reason: RpcError::from_batch_status(e).into(),
                        }),
                    )
                    .await;
                    sent_signal_to_next = true;
                    break;
                }
            }
        }

        tracing::trace!(
            "Stage [{:?}-{:?}], running task count: {}, finished task count: {}, sent signal to next: {}",
            self.stage.query_id,
            self.stage.id,
            running_task_cnt,
            finished_task_cnt,
            sent_signal_to_next,
        );

        if let Some(shutdown) = all_streams.take_future() {
            // The shutdown token has not fired yet; wait for the stop signal from the query
            // runner before tearing the stage down.
            tracing::trace!(
                "Stage [{:?}-{:?}] waiting for stopping signal.",
                self.stage.query_id,
                self.stage.id
            );
            shutdown.await;
        }

        // Whether the stage completed, failed or was cancelled, always cancel the remaining
        // tasks on the compute nodes before exiting.
        tracing::trace!(
            "Stopping stage: {:?}-{:?}, task_num: {}",
            self.stage.query_id,
            self.stage.id,
            self.tasks.len()
        );
        self.cancel_all_scheduled_tasks().await?;

        tracing::trace!(
            "Stage runner [{:?}-{:?}] exited.",
            self.stage.query_id,
            self.stage.id
        );
        Ok(())
    }

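    /// Schedule the root stage: instead of sending it to a compute node, build the root
    /// executor and run it locally on the frontend's compute runtime, streaming result chunks
    /// to the query runner through a bounded channel.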
    async fn schedule_tasks_for_root(
        &mut self,
        mut shutdown_rx: ShutdownToken,
        expr_context: ExprContext,
    ) -> SchedulerResult<()> {
        let root_stage_id = self.stage.id;
        // The root stage has only one task, executed locally on the frontend.
        let plan_fragment = self.create_plan_fragment(ROOT_TASK_ID, None);
        let plan_node = plan_fragment.root.unwrap();
        let task_id = TaskIdBatch {
            query_id: self.stage.query_id.id.clone(),
            stage_id: root_stage_id,
            task_id: 0,
        };

        let (result_tx, result_rx) = tokio::sync::mpsc::channel(
            self.ctx
                .session
                .env()
                .batch_config()
                .developer
                .root_stage_channel_size,
        );
        self.notify_stage_scheduled(QueryMessage::Stage(StageEvent::ScheduledRoot(result_rx)))
            .await;

        let executor = ExecutorBuilder::new(
            &plan_node,
            &task_id,
            self.ctx.to_batch_task_context(),
            self.epoch,
            shutdown_rx.clone(),
        );

        let shutdown_rx0 = shutdown_rx.clone();

        let result = expr_context_scope(expr_context, async {
            let executor = executor.build().await?;
            let chunk_stream = executor.execute();
            let cancelled = pin!(shutdown_rx.cancelled());
            #[for_await]
            for chunk in chunk_stream.take_until(cancelled) {
                if let Err(ref e) = chunk {
                    if shutdown_rx0.is_cancelled() {
                        break;
                    }
                    let err_str = e.to_report_string();
                    if let Err(_e) = result_tx.send(chunk.map_err(|e| e.into())).await {
                        warn!("Root executor receiver has been dropped before receiving any event, so the send failed");
                    }
                    return Err(TaskExecutionError(err_str));
                } else {
                    if let Err(_e) = result_tx.send(chunk.map_err(|e| e.into())).await {
                        warn!("Root executor receiver has been dropped before receiving any event, so the send failed");
                    }
                }
            }
            Ok(())
        }).await;

        if let Err(err) = &result {
            // Also propagate the execution error to the query result channel.
            if let Err(_e) = result_tx
                .send(Err(TaskExecutionError(err.to_report_string())))
                .await
            {
                warn!("Failed to send the task execution error");
            }
        }

        match shutdown_rx0.message() {
            ShutdownMsg::Abort(err_str) => {
                if let Err(_e) = result_tx.send(Err(TaskExecutionError(err_str))).await {
                    warn!("Failed to send the task execution error");
                }
            }
            _ => self.notify_stage_completed().await,
        }

        tracing::trace!(
            "Stage runner [{:?}-{:?}] exited.",
            self.stage.query_id,
            self.stage.id
        );

        result.map(|_| ())
    }

    async fn schedule_tasks_for_all(&mut self, shutdown_rx: ShutdownToken) -> SchedulerResult<()> {
        let expr_context = ExprContext {
            time_zone: self.ctx.session().config().timezone().to_owned(),
            strict_mode: self.ctx.session().config().batch_expr_strict_mode(),
        };
        // Non-root stages are scheduled to compute nodes; the root stage runs locally.
        if !self.is_root_stage() {
            self.schedule_tasks(shutdown_rx, expr_context).await?;
        } else {
            self.schedule_tasks_for_root(shutdown_rx, expr_context)
                .await?;
        }
        Ok(())
    }

    #[inline(always)]
    fn get_fragment_id(&self, table_id: &TableId) -> SchedulerResult<FragmentId> {
        self.catalog_reader
            .read_guard()
            .get_any_table_by_id(table_id)
            .map(|table| table.fragment_id)
            .map_err(|e| SchedulerError::Internal(anyhow!(e)))
    }

    #[inline(always)]
    fn get_table_dml_vnode_mapping(
        &self,
        table_id: &TableId,
    ) -> SchedulerResult<WorkerSlotMapping> {
        let guard = self.catalog_reader.read_guard();

        let table = guard
            .get_any_table_by_id(table_id)
            .map_err(|e| SchedulerError::Internal(anyhow!(e)))?;

        let fragment_id = match table.dml_fragment_id.as_ref() {
            Some(dml_fragment_id) => dml_fragment_id,
            // Fall back to the table's own fragment id when it has no dedicated DML fragment.
            None => &table.fragment_id,
        };

        self.worker_node_manager
            .manager
            .get_streaming_fragment_mapping(fragment_id)
            .map_err(|e| e.into())
    }

    fn choose_worker(
        &self,
        plan_fragment: &PlanFragment,
        task_id: u32,
        dml_table_id: Option<TableId>,
    ) -> SchedulerResult<Option<WorkerNode>> {
        let plan_node = plan_fragment.root.as_ref().expect("fail to get plan node");

        if let Some(table_id) = dml_table_id {
            let vnode_mapping = self.get_table_dml_vnode_mapping(&table_id)?;
            let worker_slot_ids = vnode_mapping.iter_unique().collect_vec();
            let candidates = self
                .worker_node_manager
                .manager
                .get_workers_by_worker_slot_ids(&worker_slot_ids)?;
            if candidates.is_empty() {
                return Err(BatchError::EmptyWorkerNodes.into());
            }
            let candidate = if self.stage.batch_enable_distributed_dml {
                // Distributed DML: spread the tasks of this stage across the candidate workers.
                candidates[task_id as usize % candidates.len()].clone()
            } else {
                // Otherwise, pin all DML of this session to a single worker.
                candidates[self.stage.session_id.0 as usize % candidates.len()].clone()
            };
            return Ok(Some(candidate));
        };

        if let Some(distributed_lookup_join_node) =
            Self::find_distributed_lookup_join_node(plan_node)
        {
            // For distributed lookup join, pin each task to the worker slot that holds the
            // corresponding partition of the inner side table.
            let fragment_id = self.get_fragment_id(
                &distributed_lookup_join_node
                    .inner_side_table_desc
                    .as_ref()
                    .unwrap()
                    .table_id
                    .into(),
            )?;
            let id_to_worker_slots = self
                .worker_node_manager
                .fragment_mapping(fragment_id)?
                .iter_unique()
                .collect_vec();

            let worker_slot_id = id_to_worker_slots[task_id as usize];
            let candidates = self
                .worker_node_manager
                .manager
                .get_workers_by_worker_slot_ids(&[worker_slot_id])?;
            if candidates.is_empty() {
                return Err(BatchError::EmptyWorkerNodes.into());
            }
            Ok(Some(candidates[0].clone()))
        } else {
            Ok(None)
        }
    }

    fn find_distributed_lookup_join_node(
        plan_node: &PlanNode,
    ) -> Option<&DistributedLookupJoinNode> {
        let node_body = plan_node.node_body.as_ref().expect("fail to get node body");

        match node_body {
            NodeBody::DistributedLookupJoin(distributed_lookup_join_node) => {
                Some(distributed_lookup_join_node)
            }
            _ => plan_node
                .children
                .iter()
                .find_map(Self::find_distributed_lookup_join_node),
        }
    }

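    /// Mark the stage as `Running` and forward the given `Scheduled` / `ScheduledRoot` message
    /// to the query runner.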
    async fn notify_stage_scheduled(&self, msg: QueryMessage) {
        self.notify_stage_state_changed(
            |old_state| {
                assert_matches!(old_state, StageState::Started);
                StageState::Running
            },
            msg,
        )
        .await
    }

    /// Mark the stage as `Completed` and notify the query runner.
    async fn notify_stage_completed(&self) {
        self.notify_stage_state_changed(
            |old_state| {
                assert_matches!(old_state, StageState::Running);
                StageState::Completed
            },
            QueryMessage::Stage(StageEvent::Completed(self.stage.id)),
        )
        .await
    }

    async fn notify_stage_state_changed<F>(&self, new_state: F, msg: QueryMessage)
    where
        F: FnOnce(StageState) -> StageState,
    {
        {
            let mut s = self.state.write().await;
            let old_state = mem::replace(&mut *s, StageState::Failed);
            *s = new_state(old_state);
        }

        self.send_event(msg).await;
    }

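    /// Send a `CancelTask` RPC for every task of this stage so the compute nodes release their
    /// resources promptly. The RPCs are fired in background tasks and are best-effort: failures
    /// are only logged.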
    async fn cancel_all_scheduled_tasks(&self) -> SchedulerResult<()> {
        for (task, task_status) in &*self.tasks {
            let loc = &task_status.get_status().location;
            let addr = loc.as_ref().expect("Get address should not fail");
            let client = self
                .compute_client_pool
                .get_by_addr(HostAddr::from(addr))
                .await
                .map_err(|e| anyhow!(e))?;

            // Spawn the cancellation so that one slow or failed RPC does not block the others.
            let query_id = self.stage.query_id.id.clone();
            let stage_id = self.stage.id;
            let task_id = *task;
            spawn(async move {
                if let Err(e) = client
                    .cancel(CancelTaskRequest {
                        task_id: Some(risingwave_pb::batch_plan::TaskId {
                            query_id: query_id.clone(),
                            stage_id,
                            task_id,
                        }),
                    })
                    .await
                {
                    error!(
                        error = %e.as_report(),
                        ?task_id,
                        ?query_id,
                        ?stage_id,
                        "Abort task failed",
                    );
                };
            });
        }
        Ok(())
    }

    async fn schedule_task(
        &self,
        task_id: PbTaskId,
        plan_fragment: PlanFragment,
        worker: Option<WorkerNode>,
        expr_context: ExprContext,
    ) -> SchedulerResult<Fuse<Streaming<TaskInfoResponse>>> {
        let mut worker = worker.unwrap_or(self.worker_node_manager.next_random_worker()?);
        let worker_node_addr = worker.host.take().unwrap();
        let compute_client = self
            .compute_client_pool
            .get_by_addr((&worker_node_addr).into())
            .await
            .inspect_err(|_| self.mask_failed_serving_worker(&worker))
            .map_err(|e| anyhow!(e))?;

        let t_id = task_id.task_id;

        let stream_status: Fuse<Streaming<TaskInfoResponse>> = compute_client
            .create_task(task_id, plan_fragment, self.epoch, expr_context)
            .await
            .inspect_err(|_| self.mask_failed_serving_worker(&worker))
            .map_err(|e| anyhow!(e))?
            .fuse();

        // Record where the task was scheduled so that parent stages can locate its output.
        self.tasks[&t_id].inner.store(Arc::new(TaskStatus {
            _task_id: t_id,
            location: Some(worker_node_addr),
        }));

        Ok(stream_status)
    }

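    /// Build the `PlanFragment` proto for one task of this stage, filling in task-specific
    /// partition information (vnode bitmaps, source splits or files) where applicable.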
    pub fn create_plan_fragment(
        &self,
        task_id: TaskId,
        partition: Option<PartitionInfo>,
    ) -> PlanFragment {
        // Used to generate a unique identity for every plan node within this fragment.
        let identity_id: Rc<RefCell<u64>> = Rc::new(RefCell::new(0));

        let plan_node_prost =
            self.convert_plan_node(&self.stage.root, task_id, partition, identity_id);
        let exchange_info = self.stage.exchange_info.clone().unwrap();

        PlanFragment {
            root: Some(plan_node_prost),
            exchange_info: Some(exchange_info),
        }
    }

    fn convert_plan_node(
        &self,
        execution_plan_node: &ExecutionPlanNode,
        task_id: TaskId,
        partition: Option<PartitionInfo>,
        identity_id: Rc<RefCell<u64>>,
    ) -> PbPlanNode {
        // A per-fragment unique identity in the form `<plan node type>-<counter>`.
        let identity = {
            let identity_type = execution_plan_node.plan_node_type;
            let id = *identity_id.borrow();
            identity_id.replace(id + 1);
            format!("{:?}-{}", identity_type, id)
        };

        match execution_plan_node.plan_node_type {
            BatchPlanNodeType::BatchExchange => {
                // Find the child stage this exchange reads from and wire its task outputs in as
                // exchange sources.
                let child_stage = self
                    .children
                    .iter()
                    .find(|child_stage| {
                        child_stage.stage.id == execution_plan_node.source_stage_id.unwrap()
                    })
                    .unwrap();
                let exchange_sources = child_stage.all_exchange_sources_for(task_id);

                match &execution_plan_node.node {
                    NodeBody::Exchange(exchange_node) => PbPlanNode {
                        children: vec![],
                        identity,
                        node_body: Some(NodeBody::Exchange(ExchangeNode {
                            sources: exchange_sources,
                            sequential: exchange_node.sequential,
                            input_schema: execution_plan_node.schema.clone(),
                        })),
                    },
                    NodeBody::MergeSortExchange(sort_merge_exchange_node) => PbPlanNode {
                        children: vec![],
                        identity,
                        node_body: Some(NodeBody::MergeSortExchange(MergeSortExchangeNode {
                            exchange: Some(ExchangeNode {
                                sources: exchange_sources,
                                sequential: false,
                                input_schema: execution_plan_node.schema.clone(),
                            }),
                            column_orders: sort_merge_exchange_node.column_orders.clone(),
                        })),
                    },
                    _ => unreachable!(),
                }
            }
            BatchPlanNodeType::BatchSeqScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::RowSeqScan(mut scan_node) = node_body else {
                    unreachable!();
                };
                let partition = partition
                    .expect("no partition info for seq scan")
                    .into_table()
                    .expect("PartitionInfo should be TablePartitionInfo");
                scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
                scan_node.scan_ranges = partition.scan_ranges;
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::RowSeqScan(scan_node)),
                }
            }
            BatchPlanNodeType::BatchLogSeqScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::LogRowSeqScan(mut scan_node) = node_body else {
                    unreachable!();
                };
                let partition = partition
                    .expect("no partition info for seq scan")
                    .into_table()
                    .expect("PartitionInfo should be TablePartitionInfo");
                scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::LogRowSeqScan(scan_node)),
                }
            }
            BatchPlanNodeType::BatchSource | BatchPlanNodeType::BatchKafkaScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::Source(mut source_node) = node_body else {
                    unreachable!();
                };

                let partition = partition
                    .expect("no partition info for seq scan")
                    .into_source()
                    .expect("PartitionInfo should be SourcePartitionInfo");
                source_node.split = partition
                    .into_iter()
                    .map(|split| split.encode_to_bytes().into())
                    .collect_vec();
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::Source(source_node)),
                }
            }
            BatchPlanNodeType::BatchIcebergScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::IcebergScan(mut iceberg_scan_node) = node_body else {
                    unreachable!();
                };

                let partition = partition
                    .expect("no partition info for seq scan")
                    .into_source()
                    .expect("PartitionInfo should be SourcePartitionInfo");
                iceberg_scan_node.split = partition
                    .into_iter()
                    .map(|split| split.encode_to_bytes().into())
                    .collect_vec();
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::IcebergScan(iceberg_scan_node)),
                }
            }
            _ => {
                let children = execution_plan_node
                    .children
                    .iter()
                    .map(|e| {
                        self.convert_plan_node(e, task_id, partition.clone(), identity_id.clone())
                    })
                    .collect();

                PbPlanNode {
                    children,
                    identity,
                    node_body: Some(execution_plan_node.node.clone()),
                }
            }
        }
    }

    fn is_root_stage(&self) -> bool {
        self.stage.id == 0
    }

    fn mask_failed_serving_worker(&self, worker: &WorkerNode) {
        if !worker.property.as_ref().is_some_and(|p| p.is_serving) {
            return;
        }
        let duration = Duration::from_secs(std::cmp::max(
            self.ctx
                .session
                .env()
                .batch_config()
                .mask_worker_temporary_secs as u64,
            1,
        ));
        self.worker_node_manager
            .manager
            .mask_worker_node(worker.id, duration);
    }
}