1use std::collections::HashMap;
16use std::pin::pin;
17use std::sync::Arc;
18use std::time::Duration;
19use std::{assert_matches, mem};
20
21use StageEvent::Failed;
22use anyhow::anyhow;
23use arc_swap::ArcSwap;
24use futures::stream::Fuse;
25use futures::{StreamExt, TryStreamExt, stream};
26use futures_async_stream::for_await;
27use itertools::Itertools;
28use risingwave_batch::error::BatchError;
29use risingwave_batch::executor::ExecutorBuilder;
30use risingwave_batch::task::{ShutdownMsg, ShutdownSender, ShutdownToken, TaskId as TaskIdBatch};
31use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
32use risingwave_common::array::DataChunk;
33use risingwave_common::hash::WorkerSlotMapping;
34use risingwave_common::util::addr::HostAddr;
35use risingwave_common::util::iter_util::ZipEqFast;
36use risingwave_connector::source::SplitMetaData;
37use risingwave_expr::expr_context::expr_context_scope;
38use risingwave_pb::batch_plan::plan_node::NodeBody;
39use risingwave_pb::batch_plan::{
40 DistributedLookupJoinNode, ExchangeNode, ExchangeSource, MergeSortExchangeNode, PlanFragment,
41 PlanNode as PbPlanNode, PlanNode, TaskId as PbTaskId, TaskOutputId,
42};
43use risingwave_pb::common::{HostAddress, WorkerNode};
44use risingwave_pb::plan_common::ExprContext;
45use risingwave_pb::task_service::{CancelTaskRequest, TaskInfoResponse};
46use risingwave_rpc_client::ComputeClientPoolRef;
47use risingwave_rpc_client::error::RpcError;
48use rw_futures_util::select_all;
49use thiserror_ext::AsReport;
50use tokio::spawn;
51use tokio::sync::RwLock;
52use tokio::sync::mpsc::{Receiver, Sender};
53use tonic::Streaming;
54use tracing::{Instrument, debug, error, warn};
55
56use crate::catalog::catalog_service::CatalogReader;
57use crate::catalog::{FragmentId, TableId};
58use crate::optimizer::plan_node::BatchPlanNodeType;
59use crate::scheduler::SchedulerError::{TaskExecutionError, TaskRunningOutOfMemory};
60use crate::scheduler::distributed::QueryMessage;
61use crate::scheduler::distributed::stage::StageState::Pending;
62use crate::scheduler::plan_fragmenter::{
63 ExecutionPlanNode, PartitionInfo, Query, ROOT_TASK_ID, StageId, TaskId,
64};
65use crate::scheduler::{ExecutionContextRef, SchedulerError, SchedulerResult};
66
/// Maximum number of `schedule_task` futures (each issuing one `create_task` RPC) that a stage
/// keeps in flight at once; used with `buffer_unordered` in `StageRunner::schedule_tasks`.
const TASK_SCHEDULING_PARALLELISM: usize = 10;
68
69#[derive(Debug)]
70enum StageState {
71 Pending {
76 msg_sender: Sender<QueryMessage>,
77 },
78 Started,
79 Running,
80 Completed,
81 Failed,
82}
83
84#[derive(Debug)]
85pub enum StageEvent {
86 Scheduled(StageId),
87 ScheduledRoot(Receiver<SchedulerResult<DataChunk>>),
88 Failed {
90 id: StageId,
91 reason: SchedulerError,
92 },
93 Completed(#[expect(dead_code)] StageId),
95}
96
97#[derive(Clone)]
98pub struct TaskStatus {
99 _task_id: TaskId,
100
101 location: Option<HostAddress>,
103}
104
105struct TaskStatusHolder {
106 inner: ArcSwap<TaskStatus>,
107}
108
109pub struct StageExecution {
110 stage_id: StageId,
111 query: Arc<Query>,
112 worker_node_manager: WorkerNodeSelector,
113 tasks: Arc<HashMap<TaskId, TaskStatusHolder>>,
114 state: Arc<RwLock<StageState>>,
115 shutdown_tx: RwLock<Option<ShutdownSender>>,
116 children: Vec<Arc<StageExecution>>,
120 compute_client_pool: ComputeClientPoolRef,
121 catalog_reader: CatalogReader,
122
123 ctx: ExecutionContextRef,
125}
126
127struct StageRunner {
128 state: Arc<RwLock<StageState>>,
129 stage_id: StageId,
130 query: Arc<Query>,
131 worker_node_manager: WorkerNodeSelector,
132 tasks: Arc<HashMap<TaskId, TaskStatusHolder>>,
133 msg_sender: Sender<QueryMessage>,
135 children: Vec<Arc<StageExecution>>,
136 compute_client_pool: ComputeClientPoolRef,
137 catalog_reader: CatalogReader,
138
139 ctx: ExecutionContextRef,
140}
141
142impl TaskStatusHolder {
143 fn new(task_id: TaskId) -> Self {
144 let task_status = TaskStatus {
145 _task_id: task_id,
146 location: None,
147 };
148
149 Self {
150 inner: ArcSwap::new(Arc::new(task_status)),
151 }
152 }
153
154 fn get_status(&self) -> Arc<TaskStatus> {
155 self.inner.load_full()
156 }
157}
158
159impl StageExecution {
160 pub fn new(
161 stage_id: StageId,
162 query: Arc<Query>,
163 worker_node_manager: WorkerNodeSelector,
164 msg_sender: Sender<QueryMessage>,
165 children: Vec<Arc<StageExecution>>,
166 compute_client_pool: ComputeClientPoolRef,
167 catalog_reader: CatalogReader,
168 ctx: ExecutionContextRef,
169 ) -> Self {
170 let tasks = (0..query.stage(stage_id).parallelism.unwrap())
171 .map(|task_id| (task_id as u64, TaskStatusHolder::new(task_id as u64)))
172 .collect();
173
174 Self {
175 stage_id,
176 query,
177 worker_node_manager,
178 tasks: Arc::new(tasks),
179 state: Arc::new(RwLock::new(Pending { msg_sender })),
180 shutdown_tx: RwLock::new(None),
181 children,
182 compute_client_pool,
183 catalog_reader,
184 ctx,
185 }
186 }
187
188 pub async fn start(&self) {
190 let mut s = self.state.write().await;
191 let cur_state = mem::replace(&mut *s, StageState::Failed);
192 match cur_state {
193 Pending { msg_sender } => {
194 let runner = StageRunner {
195 stage_id: self.stage_id,
196 query: self.query.clone(),
197 worker_node_manager: self.worker_node_manager.clone(),
198 tasks: self.tasks.clone(),
199 msg_sender,
200 children: self.children.clone(),
201 state: self.state.clone(),
202 compute_client_pool: self.compute_client_pool.clone(),
203 catalog_reader: self.catalog_reader.clone(),
204 ctx: self.ctx.clone(),
205 };
206
207 let (sender, receiver) = ShutdownToken::new();
209 let mut holder = self.shutdown_tx.write().await;
211 *holder = Some(sender);
212
213 *s = StageState::Started;
215
216 let span = tracing::info_span!(
217 "stage",
218 "otel.name" = format!("Stage {}-{}", self.query.query_id.id, self.stage_id),
219 query_id = self.query.query_id.id,
220 stage_id = %self.stage_id,
221 );
222 self.ctx
223 .session()
224 .env()
225 .compute_runtime()
226 .spawn(async move { runner.run(receiver).instrument(span).await });
227
228 tracing::trace!(
229 "Stage {:?}-{:?} started.",
230 self.query.query_id.id,
231 self.stage_id
232 )
233 }
234 _ => {
235 unreachable!("Only expect to schedule stage once");
236 }
237 }
238 }
239
240 pub async fn stop(&self, error: Option<String>) {
241 if let Some(shutdown_tx) = self.shutdown_tx.write().await.take() {
243 if !if let Some(error) = error {
247 shutdown_tx.abort(error)
248 } else {
249 shutdown_tx.cancel()
250 } {
251 tracing::trace!(
253 "Failed to send stop message stage: {:?}-{:?}",
254 self.query.query_id,
255 self.stage_id
256 );
257 }
258 }
259 }
260
261 pub async fn is_scheduled(&self) -> bool {
262 let s = self.state.read().await;
263 matches!(*s, StageState::Running | StageState::Completed)
264 }
265
266 pub async fn is_pending(&self) -> bool {
267 let s = self.state.read().await;
268 matches!(*s, StageState::Pending { .. })
269 }
270
271 pub async fn state(&self) -> &'static str {
272 let s = self.state.read().await;
273 match *s {
274 Pending { .. } => "Pending",
275 StageState::Started => "Started",
276 StageState::Running => "Running",
277 StageState::Completed => "Completed",
278 StageState::Failed => "Failed",
279 }
280 }
281
282 pub fn all_exchange_sources_for(&self, output_id: u64) -> Vec<ExchangeSource> {
289 self.tasks
290 .iter()
291 .map(|(task_id, status_holder)| {
292 let task_output_id = TaskOutputId {
293 task_id: Some(PbTaskId {
294 query_id: self.query.query_id.id.clone(),
295 stage_id: self.stage_id.into(),
296 task_id: *task_id,
297 }),
298 output_id,
299 };
300
301 ExchangeSource {
302 task_output_id: Some(task_output_id),
303 host: Some(status_holder.inner.load_full().location.clone().unwrap()),
304 local_execute_plan: None,
305 }
306 })
307 .collect()
308 }
309}
310
311impl StageRunner {
312 async fn run(mut self, shutdown_rx: ShutdownToken) {
313 if let Err(e) = self.schedule_tasks_for_all(shutdown_rx).await {
314 error!(
315 error = %e.as_report(),
316 query_id = ?self.query.query_id,
317 stage_id = ?self.stage_id,
318 "Failed to schedule tasks"
319 );
320 self.send_event(QueryMessage::Stage(Failed {
321 id: self.stage_id,
322 reason: e,
323 }))
324 .await;
325 }
326 }
327
328 async fn send_event(&self, event: QueryMessage) {
330 if let Err(_e) = self.msg_sender.send(event).await {
331 warn!("Failed to send event to Query Runner, may be killed by previous failed event");
332 }
333 }
334
335 async fn schedule_tasks(
338 &mut self,
339 mut shutdown_rx: ShutdownToken,
340 expr_context: ExprContext,
341 ) -> SchedulerResult<()> {
342 let mut futures = vec![];
343 let stage = &self.query.stage(self.stage_id);
344
345 if let Some(table_scan_info) = stage.table_scan_info.as_ref()
346 && let Some(vnode_bitmaps) = table_scan_info.partitions()
347 {
348 let worker_slot_ids = vnode_bitmaps.keys().cloned().collect_vec();
354 let workers = self
355 .worker_node_manager
356 .manager
357 .get_workers_by_worker_slot_ids(&worker_slot_ids)?;
358
359 for (i, (worker_slot_id, worker)) in worker_slot_ids
360 .into_iter()
361 .zip_eq_fast(workers.into_iter())
362 .enumerate()
363 {
364 let task_id = PbTaskId {
365 query_id: self.query.query_id.id.clone(),
366 stage_id: self.stage_id.into(),
367 task_id: i as u64,
368 };
369 let vnode_ranges = vnode_bitmaps[&worker_slot_id].clone();
370 let plan_fragment =
371 self.create_plan_fragment(i as u64, Some(PartitionInfo::Table(vnode_ranges)));
372 futures.push(self.schedule_task(
373 task_id,
374 plan_fragment,
375 Some(worker),
376 expr_context.clone(),
377 ));
378 }
379 } else if let Some(source_info) = stage.source_info.as_ref() {
380 let chunk_size = ((source_info.split_info().unwrap().len() as f32
382 / stage.parallelism.unwrap() as f32)
383 .ceil() as usize)
384 .max(1);
385 if source_info.split_info().unwrap().is_empty() {
386 const EMPTY_TASK_ID: u64 = 0;
388 let task_id = PbTaskId {
389 query_id: self.query.query_id.id.clone(),
390 stage_id: self.stage_id.into(),
391 task_id: EMPTY_TASK_ID,
392 };
393 let plan_fragment =
394 self.create_plan_fragment(EMPTY_TASK_ID, Some(PartitionInfo::Source(vec![])));
395 let worker =
396 self.choose_worker(&plan_fragment, EMPTY_TASK_ID as u32, stage.dml_table_id)?;
397 futures.push(self.schedule_task(
398 task_id,
399 plan_fragment,
400 worker,
401 expr_context.clone(),
402 ));
403 } else {
404 for (id, split) in source_info
405 .split_info()
406 .unwrap()
407 .chunks(chunk_size)
408 .enumerate()
409 {
410 let task_id = PbTaskId {
411 query_id: self.query.query_id.id.clone(),
412 stage_id: self.stage_id.into(),
413 task_id: id as u64,
414 };
415 let plan_fragment = self.create_plan_fragment(
416 id as u64,
417 Some(PartitionInfo::Source(split.to_vec())),
418 );
419 let worker =
420 self.choose_worker(&plan_fragment, id as u32, stage.dml_table_id)?;
421 futures.push(self.schedule_task(
422 task_id,
423 plan_fragment,
424 worker,
425 expr_context.clone(),
426 ));
427 }
428 }
429 } else if let Some(file_scan_info) = stage.file_scan_info.as_ref() {
430 let chunk_size = (file_scan_info.file_location.len() as f32
431 / stage.parallelism.unwrap() as f32)
432 .ceil() as usize;
433 for (id, files) in file_scan_info.file_location.chunks(chunk_size).enumerate() {
434 let task_id = PbTaskId {
435 query_id: self.query.query_id.id.clone(),
436 stage_id: self.stage_id.into(),
437 task_id: id as u64,
438 };
439 let plan_fragment =
440 self.create_plan_fragment(id as u64, Some(PartitionInfo::File(files.to_vec())));
441 let worker = self.choose_worker(&plan_fragment, id as u32, stage.dml_table_id)?;
442 futures.push(self.schedule_task(
443 task_id,
444 plan_fragment,
445 worker,
446 expr_context.clone(),
447 ));
448 }
449 } else {
450 for id in 0..stage.parallelism.unwrap() {
451 let task_id = PbTaskId {
452 query_id: self.query.query_id.id.clone(),
453 stage_id: self.stage_id.into(),
454 task_id: id as u64,
455 };
456 let plan_fragment = self.create_plan_fragment(id as u64, None);
457 let worker = self.choose_worker(&plan_fragment, id, stage.dml_table_id)?;
458 futures.push(self.schedule_task(
459 task_id,
460 plan_fragment,
461 worker,
462 expr_context.clone(),
463 ));
464 }
465 }
466
467 let buffered = stream::iter(futures).buffer_unordered(TASK_SCHEDULING_PARALLELISM);
469 let buffered_streams = buffered.try_collect::<Vec<_>>().await?;
470
471 let cancelled = pin!(shutdown_rx.cancelled());
473 let mut all_streams = select_all(buffered_streams).take_until(cancelled);
474
475 let mut running_task_cnt = 0;
477 let mut finished_task_cnt = 0;
478 let mut sent_signal_to_next = false;
479
480 while let Some(status_res_inner) = all_streams.next().await {
481 match status_res_inner {
482 Ok(status) => {
483 use risingwave_pb::task_service::task_info_response::TaskStatus as PbTaskStatus;
484 match PbTaskStatus::try_from(status.task_status).unwrap() {
485 PbTaskStatus::Running => {
486 running_task_cnt += 1;
487 assert!(running_task_cnt <= self.tasks.keys().len());
490 if running_task_cnt == self.tasks.keys().len() {
493 self.notify_stage_scheduled(QueryMessage::Stage(
494 StageEvent::Scheduled(self.stage_id),
495 ))
496 .await;
497 sent_signal_to_next = true;
498 }
499 }
500
501 PbTaskStatus::Finished => {
502 finished_task_cnt += 1;
503 assert!(finished_task_cnt <= self.tasks.keys().len());
504 assert!(running_task_cnt >= finished_task_cnt);
505 if finished_task_cnt == self.tasks.keys().len() {
506 self.notify_stage_completed().await;
509 sent_signal_to_next = true;
510 break;
511 }
512 }
513 PbTaskStatus::Aborted => {
514 error!(
518 "Abort task {:?} because of excessive memory usage. Please try again later.",
519 status.task_id.unwrap()
520 );
521 self.notify_stage_state_changed(
522 |_| StageState::Failed,
523 QueryMessage::Stage(Failed {
524 id: self.stage_id,
525 reason: TaskRunningOutOfMemory,
526 }),
527 )
528 .await;
529 sent_signal_to_next = true;
530 break;
531 }
532 PbTaskStatus::Failed => {
533 error!(
535 "Task {:?} failed, reason: {:?}",
536 status.task_id.unwrap(),
537 status.error_message,
538 );
539 self.notify_stage_state_changed(
540 |_| StageState::Failed,
541 QueryMessage::Stage(Failed {
542 id: self.stage_id,
543 reason: TaskExecutionError(status.error_message),
544 }),
545 )
546 .await;
547 sent_signal_to_next = true;
548 break;
549 }
550 PbTaskStatus::Ping => {
551 debug!("Receive ping from task {:?}", status.task_id.unwrap());
552 }
553 status => {
554 unreachable!("Unexpected task status {:?}", status);
557 }
558 }
559 }
560 Err(e) => {
561 error!(
563 "Fetching task status in stage {:?} failed, reason: {:?}",
564 self.stage_id,
565 e.message()
566 );
567 self.notify_stage_state_changed(
568 |_| StageState::Failed,
569 QueryMessage::Stage(Failed {
570 id: self.stage_id,
571 reason: RpcError::from_batch_status(e).into(),
572 }),
573 )
574 .await;
575 sent_signal_to_next = true;
576 break;
577 }
578 }
579 }
580
581 tracing::trace!(
582 "Stage [{:?}-{:?}], running task count: {}, finished task count: {}, sent signal to next: {}",
583 self.query.query_id,
584 self.stage_id,
585 running_task_cnt,
586 finished_task_cnt,
587 sent_signal_to_next,
588 );
589
590 if let Some(shutdown) = all_streams.take_future() {
591 tracing::trace!(
592 "Stage [{:?}-{:?}] waiting for stopping signal.",
593 self.query.query_id,
594 self.stage_id
595 );
596 shutdown.await;
598 }
599
600 tracing::trace!(
605 "Stopping stage: {:?}-{:?}, task_num: {}",
606 self.query.query_id,
607 self.stage_id,
608 self.tasks.len()
609 );
610 self.cancel_all_scheducancled_tasks().await?;
611
612 tracing::trace!(
613 "Stage runner [{:?}-{:?}] exited.",
614 self.query.query_id,
615 self.stage_id
616 );
617 Ok(())
618 }
619
620 async fn schedule_tasks_for_root(
621 &mut self,
622 mut shutdown_rx: ShutdownToken,
623 expr_context: ExprContext,
624 ) -> SchedulerResult<()> {
625 let root_stage_id = self.stage_id;
626 let plan_fragment = self.create_plan_fragment(ROOT_TASK_ID, None);
629 let plan_node = plan_fragment.root.unwrap();
630 let task_id = TaskIdBatch {
631 query_id: self.query.query_id.id.clone(),
632 stage_id: root_stage_id.into(),
633 task_id: 0,
634 };
635
636 let (result_tx, result_rx) = tokio::sync::mpsc::channel(
638 self.ctx
639 .session
640 .env()
641 .batch_config()
642 .developer
643 .root_stage_channel_size,
644 );
645 self.notify_stage_scheduled(QueryMessage::Stage(StageEvent::ScheduledRoot(result_rx)))
646 .await;
647
648 let executor = ExecutorBuilder::new(
649 &plan_node,
650 &task_id,
651 self.ctx.to_batch_task_context(),
652 shutdown_rx.clone(),
653 );
654
655 let shutdown_rx0 = shutdown_rx.clone();
656
657 let result = expr_context_scope(expr_context, async {
658 let executor = executor.build().await?;
659 let chunk_stream = executor.execute();
660 let cancelled = pin!(shutdown_rx.cancelled());
661 #[for_await]
662 for chunk in chunk_stream.take_until(cancelled) {
663 if let Err(ref e) = chunk {
664 if shutdown_rx0.is_cancelled() {
665 break;
666 }
667 let err_str = e.to_report_string();
668 if let Err(_e) = result_tx.send(chunk.map_err(|e| e.into())).await {
672 warn!("Root executor has been dropped before receive any events so the send is failed");
673 }
674 return Err(TaskExecutionError(err_str));
676 } else {
677 if let Err(_e) = result_tx.send(chunk.map_err(|e| e.into())).await {
679 warn!("Root executor has been dropped before receive any events so the send is failed");
680 }
681 }
682 }
683 Ok(())
684 }).await;
685
686 if let Err(err) = &result {
687 if let Err(_e) = result_tx
691 .send(Err(TaskExecutionError(err.to_report_string())))
692 .await
693 {
694 warn!("Send task execution failed");
695 }
696 }
697
698 match shutdown_rx0.message() {
700 ShutdownMsg::Abort(err_str) => {
701 if let Err(_e) = result_tx.send(Err(TaskExecutionError(err_str))).await {
703 warn!("Send task execution failed");
704 }
705 }
706 _ => self.notify_stage_completed().await,
707 }
708
709 tracing::trace!(
710 "Stage runner [{:?}-{:?}] existed. ",
711 self.query.query_id,
712 self.stage_id
713 );
714
715 result.map(|_| ())
718 }
719
720 async fn schedule_tasks_for_all(&mut self, shutdown_rx: ShutdownToken) -> SchedulerResult<()> {
721 let expr_context = ExprContext {
722 time_zone: self.ctx.session().config().timezone(),
723 strict_mode: self.ctx.session().config().batch_expr_strict_mode(),
724 };
725 if !self.is_root_stage() {
727 self.schedule_tasks(shutdown_rx, expr_context).await?;
728 } else {
729 self.schedule_tasks_for_root(shutdown_rx, expr_context)
730 .await?;
731 }
732 Ok(())
733 }
734
735 #[inline(always)]
736 fn get_fragment_id(&self, table_id: TableId) -> SchedulerResult<FragmentId> {
737 self.catalog_reader
738 .read_guard()
739 .get_any_table_by_id(table_id)
740 .map(|table| table.fragment_id)
741 .map_err(|e| SchedulerError::Internal(anyhow!(e)))
742 }
743
744 #[inline(always)]
745 fn get_table_dml_vnode_mapping(&self, table_id: TableId) -> SchedulerResult<WorkerSlotMapping> {
746 let guard = self.catalog_reader.read_guard();
747
748 let table = guard
749 .get_any_table_by_id(table_id)
750 .map_err(|e| SchedulerError::Internal(anyhow!(e)))?;
751
752 let fragment_id = match table.dml_fragment_id.as_ref() {
753 Some(dml_fragment_id) => dml_fragment_id,
754 None => &table.fragment_id,
756 };
757
758 self.worker_node_manager
759 .manager
760 .get_streaming_fragment_mapping(fragment_id)
761 .map_err(|e| e.into())
762 }
763
764 fn choose_worker(
765 &self,
766 plan_fragment: &PlanFragment,
767 task_id: u32,
768 dml_table_id: Option<TableId>,
769 ) -> SchedulerResult<Option<WorkerNode>> {
770 let plan_node = plan_fragment.root.as_ref().expect("fail to get plan node");
771
772 if let Some(table_id) = dml_table_id {
773 let vnode_mapping = self.get_table_dml_vnode_mapping(table_id)?;
774 let worker_slot_ids = vnode_mapping.iter_unique().collect_vec();
775 let candidates = self
776 .worker_node_manager
777 .manager
778 .get_workers_by_worker_slot_ids(&worker_slot_ids)?;
779 if candidates.is_empty() {
780 return Err(BatchError::EmptyWorkerNodes.into());
781 }
782 let stage = &self.query.stage(self.stage_id);
783 let candidate = if stage.batch_enable_distributed_dml {
784 candidates[task_id as usize % candidates.len()].clone()
787 } else {
788 candidates[stage.session_id.0 as usize % candidates.len()].clone()
790 };
791 return Ok(Some(candidate));
792 };
793
794 if let Some(distributed_lookup_join_node) =
795 Self::find_distributed_lookup_join_node(plan_node)
796 {
797 let fragment_id = self.get_fragment_id(
798 distributed_lookup_join_node
799 .inner_side_table_desc
800 .as_ref()
801 .unwrap()
802 .table_id,
803 )?;
804 let id_to_worker_slots = self
805 .worker_node_manager
806 .fragment_mapping(fragment_id)?
807 .iter_unique()
808 .collect_vec();
809
810 let worker_slot_id = id_to_worker_slots[task_id as usize];
811 let candidates = self
812 .worker_node_manager
813 .manager
814 .get_workers_by_worker_slot_ids(&[worker_slot_id])?;
815 if candidates.is_empty() {
816 return Err(BatchError::EmptyWorkerNodes.into());
817 }
818 Ok(Some(candidates[0].clone()))
819 } else {
820 Ok(None)
821 }
822 }
823
824 fn find_distributed_lookup_join_node(
825 plan_node: &PlanNode,
826 ) -> Option<&DistributedLookupJoinNode> {
827 let node_body = plan_node.node_body.as_ref().expect("fail to get node body");
828
829 match node_body {
830 NodeBody::DistributedLookupJoin(distributed_lookup_join_node) => {
831 Some(distributed_lookup_join_node)
832 }
833 _ => plan_node
834 .children
835 .iter()
836 .find_map(Self::find_distributed_lookup_join_node),
837 }
838 }
839
840 async fn notify_stage_scheduled(&self, msg: QueryMessage) {
842 self.notify_stage_state_changed(
843 |old_state| {
844 assert_matches!(old_state, StageState::Started);
845 StageState::Running
846 },
847 msg,
848 )
849 .await
850 }
851
852 async fn notify_stage_completed(&self) {
854 self.notify_stage_state_changed(
855 |old_state| {
856 assert_matches!(old_state, StageState::Running);
857 StageState::Completed
858 },
859 QueryMessage::Stage(StageEvent::Completed(self.stage_id)),
860 )
861 .await
862 }
863
864 async fn notify_stage_state_changed<F>(&self, new_state: F, msg: QueryMessage)
865 where
866 F: FnOnce(StageState) -> StageState,
867 {
868 {
869 let mut s = self.state.write().await;
870 let old_state = mem::replace(&mut *s, StageState::Failed);
871 *s = new_state(old_state);
872 }
873
874 self.send_event(msg).await;
875 }
876
877 async fn cancel_all_scheducancled_tasks(&self) -> SchedulerResult<()> {
881 for (task, task_status) in &*self.tasks {
893 let loc = &task_status.get_status().location;
895 let addr = loc.as_ref().expect("Get address should not fail");
896 let client = self
897 .compute_client_pool
898 .get_by_addr(HostAddr::from(addr))
899 .await
900 .map_err(|e| anyhow!(e))?;
901
902 let query_id = self.query.query_id.id.clone();
904 let stage_id = self.stage_id;
905 let task_id = *task;
906 spawn(async move {
907 if let Err(e) = client
908 .cancel(CancelTaskRequest {
909 task_id: Some(risingwave_pb::batch_plan::TaskId {
910 query_id: query_id.clone(),
911 stage_id: stage_id.into(),
912 task_id,
913 }),
914 })
915 .await
916 {
917 error!(
918 error = %e.as_report(),
919 ?task_id,
920 ?query_id,
921 ?stage_id,
922 "Abort task failed",
923 );
924 };
925 });
926 }
927 Ok(())
928 }
929
930 async fn schedule_task(
931 &self,
932 task_id: PbTaskId,
933 plan_fragment: PlanFragment,
934 worker: Option<WorkerNode>,
935 expr_context: ExprContext,
936 ) -> SchedulerResult<Fuse<Streaming<TaskInfoResponse>>> {
937 let mut worker = worker.unwrap_or(self.worker_node_manager.next_random_worker()?);
938 let worker_node_addr = worker.host.take().unwrap();
939 let compute_client = self
940 .compute_client_pool
941 .get_by_addr((&worker_node_addr).into())
942 .await
943 .inspect_err(|_| self.mask_failed_serving_worker(&worker))
944 .map_err(|e| anyhow!(e))?;
945
946 let t_id = task_id.task_id;
947
948 let stream_status: Fuse<Streaming<TaskInfoResponse>> = compute_client
949 .create_task(task_id, plan_fragment, expr_context)
950 .await
951 .inspect_err(|_| self.mask_failed_serving_worker(&worker))
952 .map_err(|e| anyhow!(e))?
953 .fuse();
954
955 self.tasks[&t_id].inner.store(Arc::new(TaskStatus {
956 _task_id: t_id,
957 location: Some(worker_node_addr),
958 }));
959
960 Ok(stream_status)
961 }
962
963 fn create_plan_fragment(
964 &self,
965 task_id: TaskId,
966 partition: Option<PartitionInfo>,
967 ) -> PlanFragment {
968 let mut identity_id = 0;
970
971 let stage = &self.query.stage(self.stage_id);
972
973 let plan_node_prost =
974 self.convert_plan_node(&stage.root, task_id, partition, &mut identity_id);
975 let exchange_info = stage.exchange_info.clone().unwrap();
976
977 PlanFragment {
978 root: Some(plan_node_prost),
979 exchange_info: Some(exchange_info),
980 }
981 }
982
983 fn convert_plan_node(
984 &self,
985 execution_plan_node: &ExecutionPlanNode,
986 task_id: TaskId,
987 partition: Option<PartitionInfo>,
988 identity_id: &mut u64,
989 ) -> PbPlanNode {
990 let identity = {
992 let identity_type = execution_plan_node.plan_node_type;
993 let id = *identity_id;
994 *identity_id += 1;
995 format!("{:?}-{}", identity_type, id)
996 };
997
998 match execution_plan_node.plan_node_type {
999 BatchPlanNodeType::BatchExchange => {
1000 let child_stage = self
1002 .children
1003 .iter()
1004 .find(|child_stage| {
1005 child_stage.stage_id == execution_plan_node.source_stage_id.unwrap()
1006 })
1007 .unwrap();
1008 let exchange_sources = child_stage.all_exchange_sources_for(task_id);
1009
1010 match &execution_plan_node.node {
1011 NodeBody::Exchange(exchange_node) => PbPlanNode {
1012 children: vec![],
1013 identity,
1014 node_body: Some(NodeBody::Exchange(ExchangeNode {
1015 sources: exchange_sources,
1016 sequential: exchange_node.sequential,
1017 input_schema: execution_plan_node.schema.clone(),
1018 })),
1019 },
1020 NodeBody::MergeSortExchange(sort_merge_exchange_node) => PbPlanNode {
1021 children: vec![],
1022 identity,
1023 node_body: Some(NodeBody::MergeSortExchange(MergeSortExchangeNode {
1024 exchange: Some(ExchangeNode {
1025 sources: exchange_sources,
1026 sequential: false,
1027 input_schema: execution_plan_node.schema.clone(),
1028 }),
1029 column_orders: sort_merge_exchange_node.column_orders.clone(),
1030 })),
1031 },
1032 _ => unreachable!(),
1033 }
1034 }
1035 BatchPlanNodeType::BatchSeqScan => {
1036 let node_body = execution_plan_node.node.clone();
1037 let NodeBody::RowSeqScan(mut scan_node) = node_body else {
1038 unreachable!();
1039 };
1040 let partition = partition
1041 .expect("no partition info for seq scan")
1042 .into_table()
1043 .expect("PartitionInfo should be TablePartitionInfo");
1044 scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
1045 scan_node.scan_ranges = partition.scan_ranges;
1046 PbPlanNode {
1047 children: vec![],
1048 identity,
1049 node_body: Some(NodeBody::RowSeqScan(scan_node)),
1050 }
1051 }
1052 BatchPlanNodeType::BatchLogSeqScan => {
1053 let node_body = execution_plan_node.node.clone();
1054 let NodeBody::LogRowSeqScan(mut scan_node) = node_body else {
1055 unreachable!();
1056 };
1057 let partition = partition
1058 .expect("no partition info for seq scan")
1059 .into_table()
1060 .expect("PartitionInfo should be TablePartitionInfo");
1061 scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
1062 PbPlanNode {
1063 children: vec![],
1064 identity,
1065 node_body: Some(NodeBody::LogRowSeqScan(scan_node)),
1066 }
1067 }
1068 BatchPlanNodeType::BatchSource | BatchPlanNodeType::BatchKafkaScan => {
1069 let node_body = execution_plan_node.node.clone();
1070 let NodeBody::Source(mut source_node) = node_body else {
1071 unreachable!();
1072 };
1073
1074 let partition = partition
1075 .expect("no partition info for seq scan")
1076 .into_source()
1077 .expect("PartitionInfo should be SourcePartitionInfo");
1078 source_node.split = partition
1079 .into_iter()
1080 .map(|split| split.encode_to_bytes().into())
1081 .collect_vec();
1082 PbPlanNode {
1083 children: vec![],
1084 identity,
1085 node_body: Some(NodeBody::Source(source_node)),
1086 }
1087 }
1088 BatchPlanNodeType::BatchIcebergScan => {
1089 let node_body = execution_plan_node.node.clone();
1090 let NodeBody::IcebergScan(mut iceberg_scan_node) = node_body else {
1091 unreachable!();
1092 };
1093
1094 let partition = partition
1095 .expect("no partition info for seq scan")
1096 .into_source()
1097 .expect("PartitionInfo should be SourcePartitionInfo");
1098 iceberg_scan_node.split = partition
1099 .into_iter()
1100 .map(|split| split.encode_to_bytes().into())
1101 .collect_vec();
1102 PbPlanNode {
1103 children: vec![],
1104 identity,
1105 node_body: Some(NodeBody::IcebergScan(iceberg_scan_node)),
1106 }
1107 }
1108 _ => {
1109 let children = execution_plan_node
1110 .children
1111 .iter()
1112 .map(|e| self.convert_plan_node(e, task_id, partition.clone(), identity_id))
1113 .collect();
1114
1115 PbPlanNode {
1116 children,
1117 identity,
1118 node_body: Some(execution_plan_node.node.clone()),
1119 }
1120 }
1121 }
1122 }
1123
1124 fn is_root_stage(&self) -> bool {
1125 self.stage_id == 0.into()
1126 }
1127
1128 fn mask_failed_serving_worker(&self, worker: &WorkerNode) {
1129 if !worker.property.as_ref().is_some_and(|p| p.is_serving) {
1130 return;
1131 }
1132 let duration = Duration::from_secs(std::cmp::max(
1133 self.ctx
1134 .session
1135 .env()
1136 .batch_config()
1137 .mask_worker_temporary_secs as u64,
1138 1,
1139 ));
1140 self.worker_node_manager
1141 .manager
1142 .mask_worker_node(worker.id, duration);
1143 }
1144}