use std::assert_matches::assert_matches;
use std::collections::HashMap;
use std::mem;
use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;

use StageEvent::Failed;
use anyhow::anyhow;
use arc_swap::ArcSwap;
use futures::stream::Fuse;
use futures::{StreamExt, TryStreamExt, stream};
use futures_async_stream::for_await;
use itertools::Itertools;
use risingwave_batch::error::BatchError;
use risingwave_batch::executor::ExecutorBuilder;
use risingwave_batch::task::{ShutdownMsg, ShutdownSender, ShutdownToken, TaskId as TaskIdBatch};
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::array::DataChunk;
use risingwave_common::hash::WorkerSlotMapping;
use risingwave_common::util::addr::HostAddr;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_connector::source::SplitMetaData;
use risingwave_expr::expr_context::expr_context_scope;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::{
    DistributedLookupJoinNode, ExchangeNode, ExchangeSource, MergeSortExchangeNode, PlanFragment,
    PlanNode as PbPlanNode, PlanNode, TaskId as PbTaskId, TaskOutputId,
};
use risingwave_pb::common::{HostAddress, WorkerNode};
use risingwave_pb::plan_common::ExprContext;
use risingwave_pb::task_service::{CancelTaskRequest, TaskInfoResponse};
use risingwave_rpc_client::ComputeClientPoolRef;
use risingwave_rpc_client::error::RpcError;
use rw_futures_util::select_all;
use thiserror_ext::AsReport;
use tokio::spawn;
use tokio::sync::RwLock;
use tokio::sync::mpsc::{Receiver, Sender};
use tonic::Streaming;
use tracing::{Instrument, debug, error, warn};

use crate::catalog::catalog_service::CatalogReader;
use crate::catalog::{FragmentId, TableId};
use crate::optimizer::plan_node::BatchPlanNodeType;
use crate::scheduler::SchedulerError::{TaskExecutionError, TaskRunningOutOfMemory};
use crate::scheduler::distributed::QueryMessage;
use crate::scheduler::distributed::stage::StageState::Pending;
use crate::scheduler::plan_fragmenter::{
    ExecutionPlanNode, PartitionInfo, Query, ROOT_TASK_ID, StageId, TaskId,
};
use crate::scheduler::{ExecutionContextRef, SchedulerError, SchedulerResult};

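/// Maximum number of task-creation RPCs a stage keeps in flight at once when scheduling its
/// tasks (see the `buffer_unordered` call in `StageRunner::schedule_tasks`).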
const TASK_SCHEDULING_PARALLELISM: usize = 10;

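/// Lifecycle of a stage. A stage starts as `Pending` (holding the channel used to report back to
/// the query runner), becomes `Started` once its runner is spawned, `Running` when every task has
/// reported `Running`, and ends in `Completed` or `Failed`.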
#[derive(Debug)]
enum StageState {
    Pending {
        msg_sender: Sender<QueryMessage>,
    },
    Started,
    Running,
    Completed,
    Failed,
}

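/// Events a stage reports to the query runner. `ScheduledRoot` carries the receiver end of the
/// root stage's result channel instead of a plain stage id.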
#[derive(Debug)]
pub enum StageEvent {
    Scheduled(StageId),
    ScheduledRoot(Receiver<SchedulerResult<DataChunk>>),
    Failed {
        id: StageId,
        reason: SchedulerError,
    },
    Completed(#[allow(dead_code)] StageId),
}

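/// Last known status of a single task; currently this only records the compute node the task was
/// scheduled on.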
#[derive(Clone)]
pub struct TaskStatus {
    _task_id: TaskId,

    location: Option<HostAddress>,
}

struct TaskStatusHolder {
    inner: ArcSwap<TaskStatus>,
}

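/// Handle to one stage of a query held by the query runner. It owns the shared task-status table
/// and stage state, and spawns a `StageRunner` when `StageExecution::start` is called.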
pub struct StageExecution {
    stage_id: StageId,
    query: Arc<Query>,
    worker_node_manager: WorkerNodeSelector,
    tasks: Arc<HashMap<TaskId, TaskStatusHolder>>,
    state: Arc<RwLock<StageState>>,
    shutdown_tx: RwLock<Option<ShutdownSender>>,
    children: Vec<Arc<StageExecution>>,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,

    ctx: ExecutionContextRef,
}

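/// The worker spawned by `StageExecution::start`. It schedules the stage's tasks, watches their
/// status streams, and reports progress to the query runner through `msg_sender`. `state` and
/// `tasks` are shared with the owning `StageExecution`.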
struct StageRunner {
    state: Arc<RwLock<StageState>>,
    stage_id: StageId,
    query: Arc<Query>,
    worker_node_manager: WorkerNodeSelector,
    tasks: Arc<HashMap<TaskId, TaskStatusHolder>>,
    msg_sender: Sender<QueryMessage>,
    children: Vec<Arc<StageExecution>>,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,

    ctx: ExecutionContextRef,
}

impl TaskStatusHolder {
    fn new(task_id: TaskId) -> Self {
        let task_status = TaskStatus {
            _task_id: task_id,
            location: None,
        };

        Self {
            inner: ArcSwap::new(Arc::new(task_status)),
        }
    }

    fn get_status(&self) -> Arc<TaskStatus> {
        self.inner.load_full()
    }
}

impl StageExecution {
    pub fn new(
        stage_id: StageId,
        query: Arc<Query>,
        worker_node_manager: WorkerNodeSelector,
        msg_sender: Sender<QueryMessage>,
        children: Vec<Arc<StageExecution>>,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
        ctx: ExecutionContextRef,
    ) -> Self {
        let tasks = (0..query.stage(stage_id).parallelism.unwrap())
            .map(|task_id| (task_id as u64, TaskStatusHolder::new(task_id as u64)))
            .collect();

        Self {
            stage_id,
            query,
            worker_node_manager,
            tasks: Arc::new(tasks),
            state: Arc::new(RwLock::new(Pending { msg_sender })),
            shutdown_tx: RwLock::new(None),
            children,
            compute_client_pool,
            catalog_reader,
            ctx,
        }
    }

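    /// Starts the stage: consumes the `Pending` state, installs a shutdown sender, and spawns a
    /// `StageRunner` on the session's compute runtime. Calling this on a stage that is not
    /// `Pending` is a bug and panics via `unreachable!`.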
    pub async fn start(&self) {
        let mut s = self.state.write().await;
        let cur_state = mem::replace(&mut *s, StageState::Failed);
        match cur_state {
            Pending { msg_sender } => {
                let runner = StageRunner {
                    stage_id: self.stage_id,
                    query: self.query.clone(),
                    worker_node_manager: self.worker_node_manager.clone(),
                    tasks: self.tasks.clone(),
                    msg_sender,
                    children: self.children.clone(),
                    state: self.state.clone(),
                    compute_client_pool: self.compute_client_pool.clone(),
                    catalog_reader: self.catalog_reader.clone(),
                    ctx: self.ctx.clone(),
                };

                let (sender, receiver) = ShutdownToken::new();
                let mut holder = self.shutdown_tx.write().await;
                *holder = Some(sender);

                *s = StageState::Started;

                let span = tracing::info_span!(
                    "stage",
                    "otel.name" = format!("Stage {}-{}", self.query.query_id.id, self.stage_id),
                    query_id = self.query.query_id.id,
                    stage_id = %self.stage_id,
                );
                self.ctx
                    .session()
                    .env()
                    .compute_runtime()
                    .spawn(async move { runner.run(receiver).instrument(span).await });

                tracing::trace!(
                    "Stage {:?}-{:?} started.",
                    self.query.query_id.id,
                    self.stage_id
                )
            }
            _ => {
                unreachable!("Only expect to schedule stage once");
            }
        }
    }

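    /// Stops the stage by taking the shutdown sender and either aborting it with an error message
    /// or cancelling it. Does nothing if the stage was never started or has already been stopped.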
    pub async fn stop(&self, error: Option<String>) {
        if let Some(shutdown_tx) = self.shutdown_tx.write().await.take() {
            if !if let Some(error) = error {
                shutdown_tx.abort(error)
            } else {
                shutdown_tx.cancel()
            } {
                tracing::trace!(
                    "Failed to send stop message to stage: {:?}-{:?}",
                    self.query.query_id,
                    self.stage_id
                );
            }
        }
    }

    pub async fn is_scheduled(&self) -> bool {
        let s = self.state.read().await;
        matches!(*s, StageState::Running | StageState::Completed)
    }

    pub async fn is_pending(&self) -> bool {
        let s = self.state.read().await;
        matches!(*s, StageState::Pending { .. })
    }

    pub async fn state(&self) -> &'static str {
        let s = self.state.read().await;
        match *s {
            Pending { .. } => "Pending",
            StageState::Started => "Started",
            StageState::Running => "Running",
            StageState::Completed => "Completed",
            StageState::Failed => "Failed",
        }
    }

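    /// Returns one `ExchangeSource` per task of this stage for the given output id, so a
    /// downstream exchange operator knows which compute node to pull each task's output from.
    /// Task locations must already be filled in, i.e. the stage must have been scheduled.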
    pub fn all_exchange_sources_for(&self, output_id: u64) -> Vec<ExchangeSource> {
        self.tasks
            .iter()
            .map(|(task_id, status_holder)| {
                let task_output_id = TaskOutputId {
                    task_id: Some(PbTaskId {
                        query_id: self.query.query_id.id.clone(),
                        stage_id: self.stage_id.into(),
                        task_id: *task_id,
                    }),
                    output_id,
                };

                ExchangeSource {
                    task_output_id: Some(task_output_id),
                    host: Some(status_holder.inner.load_full().location.clone().unwrap()),
                    local_execute_plan: None,
                }
            })
            .collect()
    }
}

impl StageRunner {
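    /// Entry point of the runner: schedules all tasks of the stage and, on failure, reports a
    /// `Failed` stage event back to the query runner.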
    async fn run(mut self, shutdown_rx: ShutdownToken) {
        if let Err(e) = self.schedule_tasks_for_all(shutdown_rx).await {
            error!(
                error = %e.as_report(),
                query_id = ?self.query.query_id,
                stage_id = ?self.stage_id,
                "Failed to schedule tasks"
            );
            self.send_event(QueryMessage::Stage(Failed {
                id: self.stage_id,
                reason: e,
            }))
            .await;
        }
    }

    async fn send_event(&self, event: QueryMessage) {
        if let Err(_e) = self.msg_sender.send(event).await {
            warn!("Failed to send event to Query Runner, may be killed by previous failed event");
        }
    }

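    /// Schedules every task of a non-root stage on compute nodes, then polls the merged status
    /// streams until all tasks finish, one fails or aborts, or a shutdown signal arrives.
    /// Partitioning follows the stage plan: vnode partitions for table scans, split chunks for
    /// sources, file chunks for file scans, and plain `parallelism` otherwise.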
    async fn schedule_tasks(
        &mut self,
        mut shutdown_rx: ShutdownToken,
        expr_context: ExprContext,
    ) -> SchedulerResult<()> {
        let mut futures = vec![];
        let stage = &self.query.stage(self.stage_id);

        if let Some(table_scan_info) = stage.table_scan_info.as_ref()
            && let Some(vnode_bitmaps) = table_scan_info.partitions()
        {
            let worker_slot_ids = vnode_bitmaps.keys().cloned().collect_vec();
            let workers = self
                .worker_node_manager
                .manager
                .get_workers_by_worker_slot_ids(&worker_slot_ids)?;

            for (i, (worker_slot_id, worker)) in worker_slot_ids
                .into_iter()
                .zip_eq_fast(workers.into_iter())
                .enumerate()
            {
                let task_id = PbTaskId {
                    query_id: self.query.query_id.id.clone(),
                    stage_id: self.stage_id.into(),
                    task_id: i as u64,
                };
                let vnode_ranges = vnode_bitmaps[&worker_slot_id].clone();
                let plan_fragment =
                    self.create_plan_fragment(i as u64, Some(PartitionInfo::Table(vnode_ranges)));
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    Some(worker),
                    expr_context.clone(),
                ));
            }
        } else if let Some(source_info) = stage.source_info.as_ref() {
            let chunk_size = ((source_info.split_info().unwrap().len() as f32
                / stage.parallelism.unwrap() as f32)
                .ceil() as usize)
                .max(1);
            if source_info.split_info().unwrap().is_empty() {
                const EMPTY_TASK_ID: u64 = 0;
                let task_id = PbTaskId {
                    query_id: self.query.query_id.id.clone(),
                    stage_id: self.stage_id.into(),
                    task_id: EMPTY_TASK_ID,
                };
                let plan_fragment =
                    self.create_plan_fragment(EMPTY_TASK_ID, Some(PartitionInfo::Source(vec![])));
                let worker =
                    self.choose_worker(&plan_fragment, EMPTY_TASK_ID as u32, stage.dml_table_id)?;
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    worker,
                    expr_context.clone(),
                ));
            } else {
                for (id, split) in source_info
                    .split_info()
                    .unwrap()
                    .chunks(chunk_size)
                    .enumerate()
                {
                    let task_id = PbTaskId {
                        query_id: self.query.query_id.id.clone(),
                        stage_id: self.stage_id.into(),
                        task_id: id as u64,
                    };
                    let plan_fragment = self.create_plan_fragment(
                        id as u64,
                        Some(PartitionInfo::Source(split.to_vec())),
                    );
                    let worker =
                        self.choose_worker(&plan_fragment, id as u32, stage.dml_table_id)?;
                    futures.push(self.schedule_task(
                        task_id,
                        plan_fragment,
                        worker,
                        expr_context.clone(),
                    ));
                }
            }
        } else if let Some(file_scan_info) = stage.file_scan_info.as_ref() {
            let chunk_size = (file_scan_info.file_location.len() as f32
                / stage.parallelism.unwrap() as f32)
                .ceil() as usize;
            for (id, files) in file_scan_info.file_location.chunks(chunk_size).enumerate() {
                let task_id = PbTaskId {
                    query_id: self.query.query_id.id.clone(),
                    stage_id: self.stage_id.into(),
                    task_id: id as u64,
                };
                let plan_fragment =
                    self.create_plan_fragment(id as u64, Some(PartitionInfo::File(files.to_vec())));
                let worker = self.choose_worker(&plan_fragment, id as u32, stage.dml_table_id)?;
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    worker,
                    expr_context.clone(),
                ));
            }
        } else {
            for id in 0..stage.parallelism.unwrap() {
                let task_id = PbTaskId {
                    query_id: self.query.query_id.id.clone(),
                    stage_id: self.stage_id.into(),
                    task_id: id as u64,
                };
                let plan_fragment = self.create_plan_fragment(id as u64, None);
                let worker = self.choose_worker(&plan_fragment, id, stage.dml_table_id)?;
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    worker,
                    expr_context.clone(),
                ));
            }
        }

        let buffered = stream::iter(futures).buffer_unordered(TASK_SCHEDULING_PARALLELISM);
        let buffered_streams = buffered.try_collect::<Vec<_>>().await?;

        let cancelled = pin!(shutdown_rx.cancelled());
        let mut all_streams = select_all(buffered_streams).take_until(cancelled);

        let mut running_task_cnt = 0;
        let mut finished_task_cnt = 0;
        let mut sent_signal_to_next = false;

        while let Some(status_res_inner) = all_streams.next().await {
            match status_res_inner {
                Ok(status) => {
                    use risingwave_pb::task_service::task_info_response::TaskStatus as PbTaskStatus;
                    match PbTaskStatus::try_from(status.task_status).unwrap() {
                        PbTaskStatus::Running => {
                            running_task_cnt += 1;
                            assert!(running_task_cnt <= self.tasks.keys().len());
                            if running_task_cnt == self.tasks.keys().len() {
                                self.notify_stage_scheduled(QueryMessage::Stage(
                                    StageEvent::Scheduled(self.stage_id),
                                ))
                                .await;
                                sent_signal_to_next = true;
                            }
                        }

                        PbTaskStatus::Finished => {
                            finished_task_cnt += 1;
                            assert!(finished_task_cnt <= self.tasks.keys().len());
                            assert!(running_task_cnt >= finished_task_cnt);
                            if finished_task_cnt == self.tasks.keys().len() {
                                self.notify_stage_completed().await;
                                sent_signal_to_next = true;
                                break;
                            }
                        }
                        PbTaskStatus::Aborted => {
                            error!(
                                "Abort task {:?} because of excessive memory usage. Please try again later.",
                                status.task_id.unwrap()
                            );
                            self.notify_stage_state_changed(
                                |_| StageState::Failed,
                                QueryMessage::Stage(Failed {
                                    id: self.stage_id,
                                    reason: TaskRunningOutOfMemory,
                                }),
                            )
                            .await;
                            sent_signal_to_next = true;
                            break;
                        }
                        PbTaskStatus::Failed => {
                            error!(
                                "Task {:?} failed, reason: {:?}",
                                status.task_id.unwrap(),
                                status.error_message,
                            );
                            self.notify_stage_state_changed(
                                |_| StageState::Failed,
                                QueryMessage::Stage(Failed {
                                    id: self.stage_id,
                                    reason: TaskExecutionError(status.error_message),
                                }),
                            )
                            .await;
                            sent_signal_to_next = true;
                            break;
                        }
                        PbTaskStatus::Ping => {
                            debug!("Receive ping from task {:?}", status.task_id.unwrap());
                        }
                        status => {
                            unreachable!("Unexpected task status {:?}", status);
                        }
                    }
                }
                Err(e) => {
                    error!(
                        "Fetching task status in stage {:?} failed, reason: {:?}",
                        self.stage_id,
                        e.message()
                    );
                    self.notify_stage_state_changed(
                        |_| StageState::Failed,
                        QueryMessage::Stage(Failed {
                            id: self.stage_id,
                            reason: RpcError::from_batch_status(e).into(),
                        }),
                    )
                    .await;
                    sent_signal_to_next = true;
                    break;
                }
            }
        }

        tracing::trace!(
            "Stage [{:?}-{:?}], running task count: {}, finished task count: {}, sent signal to next: {}",
            self.query.query_id,
            self.stage_id,
            running_task_cnt,
            finished_task_cnt,
            sent_signal_to_next,
        );

        if let Some(shutdown) = all_streams.take_future() {
            tracing::trace!(
                "Stage [{:?}-{:?}] waiting for stopping signal.",
                self.query.query_id,
                self.stage_id
            );
            shutdown.await;
        }

        tracing::trace!(
            "Stopping stage: {:?}-{:?}, task_num: {}",
            self.query.query_id,
            self.stage_id,
            self.tasks.len()
        );
        self.cancel_all_scheduled_tasks().await?;

        tracing::trace!(
            "Stage runner [{:?}-{:?}] exited.",
            self.query.query_id,
            self.stage_id
        );
        Ok(())
    }

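    /// Schedules the root stage. Unlike other stages it is not sent to a compute node: the root
    /// plan fragment is built into a local executor, the query runner is notified with
    /// `ScheduledRoot` (carrying the result receiver), and the executor's chunk stream is
    /// forwarded into the result channel until it ends, fails, or is shut down.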
    async fn schedule_tasks_for_root(
        &mut self,
        mut shutdown_rx: ShutdownToken,
        expr_context: ExprContext,
    ) -> SchedulerResult<()> {
        let root_stage_id = self.stage_id;
        let plan_fragment = self.create_plan_fragment(ROOT_TASK_ID, None);
        let plan_node = plan_fragment.root.unwrap();
        let task_id = TaskIdBatch {
            query_id: self.query.query_id.id.clone(),
            stage_id: root_stage_id.into(),
            task_id: 0,
        };

        let (result_tx, result_rx) = tokio::sync::mpsc::channel(
            self.ctx
                .session
                .env()
                .batch_config()
                .developer
                .root_stage_channel_size,
        );
        self.notify_stage_scheduled(QueryMessage::Stage(StageEvent::ScheduledRoot(result_rx)))
            .await;

        let executor = ExecutorBuilder::new(
            &plan_node,
            &task_id,
            self.ctx.to_batch_task_context(),
            shutdown_rx.clone(),
        );

        let shutdown_rx0 = shutdown_rx.clone();

        let result = expr_context_scope(expr_context, async {
            let executor = executor.build().await?;
            let chunk_stream = executor.execute();
            let cancelled = pin!(shutdown_rx.cancelled());
            #[for_await]
            for chunk in chunk_stream.take_until(cancelled) {
                if let Err(ref e) = chunk {
                    if shutdown_rx0.is_cancelled() {
                        break;
                    }
                    let err_str = e.to_report_string();
                    if let Err(_e) = result_tx.send(chunk.map_err(|e| e.into())).await {
                        warn!("Root executor has been dropped before receiving any events, so the send failed");
                    }
                    return Err(TaskExecutionError(err_str));
                } else {
                    if let Err(_e) = result_tx.send(chunk.map_err(|e| e.into())).await {
                        warn!("Root executor has been dropped before receiving any events, so the send failed");
                    }
                }
            }
            Ok(())
        }).await;

        if let Err(err) = &result {
            if let Err(_e) = result_tx
                .send(Err(TaskExecutionError(err.to_report_string())))
                .await
            {
                warn!("Send task execution failed");
            }
        }

        match shutdown_rx0.message() {
            ShutdownMsg::Abort(err_str) => {
                if let Err(_e) = result_tx.send(Err(TaskExecutionError(err_str))).await {
                    warn!("Send task execution failed");
                }
            }
            _ => self.notify_stage_completed().await,
        }

        tracing::trace!(
            "Stage runner [{:?}-{:?}] exited.",
            self.query.query_id,
            self.stage_id
        );

        result.map(|_| ())
    }

    async fn schedule_tasks_for_all(&mut self, shutdown_rx: ShutdownToken) -> SchedulerResult<()> {
        let expr_context = ExprContext {
            time_zone: self.ctx.session().config().timezone(),
            strict_mode: self.ctx.session().config().batch_expr_strict_mode(),
        };
        if !self.is_root_stage() {
            self.schedule_tasks(shutdown_rx, expr_context).await?;
        } else {
            self.schedule_tasks_for_root(shutdown_rx, expr_context)
                .await?;
        }
        Ok(())
    }

    #[inline(always)]
    fn get_fragment_id(&self, table_id: TableId) -> SchedulerResult<FragmentId> {
        self.catalog_reader
            .read_guard()
            .get_any_table_by_id(table_id)
            .map(|table| table.fragment_id)
            .map_err(|e| SchedulerError::Internal(anyhow!(e)))
    }

    #[inline(always)]
    fn get_table_dml_vnode_mapping(&self, table_id: TableId) -> SchedulerResult<WorkerSlotMapping> {
        let guard = self.catalog_reader.read_guard();

        let table = guard
            .get_any_table_by_id(table_id)
            .map_err(|e| SchedulerError::Internal(anyhow!(e)))?;

        let fragment_id = match table.dml_fragment_id.as_ref() {
            Some(dml_fragment_id) => dml_fragment_id,
            None => &table.fragment_id,
        };

        self.worker_node_manager
            .manager
            .get_streaming_fragment_mapping(fragment_id)
            .map_err(|e| e.into())
    }

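    /// Picks the compute node a task must run on, when the plan constrains it: DML statements are
    /// pinned via the table's DML fragment mapping, and distributed lookup joins are pinned to the
    /// worker owning the matching inner-side worker slot. Returns `None` when any worker will do,
    /// in which case `schedule_task` falls back to a random worker.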
    fn choose_worker(
        &self,
        plan_fragment: &PlanFragment,
        task_id: u32,
        dml_table_id: Option<TableId>,
    ) -> SchedulerResult<Option<WorkerNode>> {
        let plan_node = plan_fragment.root.as_ref().expect("fail to get plan node");

        if let Some(table_id) = dml_table_id {
            let vnode_mapping = self.get_table_dml_vnode_mapping(table_id)?;
            let worker_slot_ids = vnode_mapping.iter_unique().collect_vec();
            let candidates = self
                .worker_node_manager
                .manager
                .get_workers_by_worker_slot_ids(&worker_slot_ids)?;
            if candidates.is_empty() {
                return Err(BatchError::EmptyWorkerNodes.into());
            }
            let stage = &self.query.stage(self.stage_id);
            let candidate = if stage.batch_enable_distributed_dml {
                candidates[task_id as usize % candidates.len()].clone()
            } else {
                candidates[stage.session_id.0 as usize % candidates.len()].clone()
            };
            return Ok(Some(candidate));
        };

        if let Some(distributed_lookup_join_node) =
            Self::find_distributed_lookup_join_node(plan_node)
        {
            let fragment_id = self.get_fragment_id(
                distributed_lookup_join_node
                    .inner_side_table_desc
                    .as_ref()
                    .unwrap()
                    .table_id,
            )?;
            let id_to_worker_slots = self
                .worker_node_manager
                .fragment_mapping(fragment_id)?
                .iter_unique()
                .collect_vec();

            let worker_slot_id = id_to_worker_slots[task_id as usize];
            let candidates = self
                .worker_node_manager
                .manager
                .get_workers_by_worker_slot_ids(&[worker_slot_id])?;
            if candidates.is_empty() {
                return Err(BatchError::EmptyWorkerNodes.into());
            }
            Ok(Some(candidates[0].clone()))
        } else {
            Ok(None)
        }
    }

    fn find_distributed_lookup_join_node(
        plan_node: &PlanNode,
    ) -> Option<&DistributedLookupJoinNode> {
        let node_body = plan_node.node_body.as_ref().expect("fail to get node body");

        match node_body {
            NodeBody::DistributedLookupJoin(distributed_lookup_join_node) => {
                Some(distributed_lookup_join_node)
            }
            _ => plan_node
                .children
                .iter()
                .find_map(Self::find_distributed_lookup_join_node),
        }
    }

    async fn notify_stage_scheduled(&self, msg: QueryMessage) {
        self.notify_stage_state_changed(
            |old_state| {
                assert_matches!(old_state, StageState::Started);
                StageState::Running
            },
            msg,
        )
        .await
    }

    async fn notify_stage_completed(&self) {
        self.notify_stage_state_changed(
            |old_state| {
                assert_matches!(old_state, StageState::Running);
                StageState::Completed
            },
            QueryMessage::Stage(StageEvent::Completed(self.stage_id)),
        )
        .await
    }

    async fn notify_stage_state_changed<F>(&self, new_state: F, msg: QueryMessage)
    where
        F: FnOnce(StageState) -> StageState,
    {
        {
            let mut s = self.state.write().await;
            let old_state = mem::replace(&mut *s, StageState::Failed);
            *s = new_state(old_state);
        }

        self.send_event(msg).await;
    }

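    /// Sends a `CancelTask` RPC to the compute node of every task in this stage. Cancellation is
    /// fire-and-forget: each RPC runs on its own tokio task and failures are only logged.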
    async fn cancel_all_scheduled_tasks(&self) -> SchedulerResult<()> {
        for (task, task_status) in &*self.tasks {
            let loc = &task_status.get_status().location;
            let addr = loc.as_ref().expect("Get address should not fail");
            let client = self
                .compute_client_pool
                .get_by_addr(HostAddr::from(addr))
                .await
                .map_err(|e| anyhow!(e))?;

            let query_id = self.query.query_id.id.clone();
            let stage_id = self.stage_id;
            let task_id = *task;
            spawn(async move {
                if let Err(e) = client
                    .cancel(CancelTaskRequest {
                        task_id: Some(risingwave_pb::batch_plan::TaskId {
                            query_id: query_id.clone(),
                            stage_id: stage_id.into(),
                            task_id,
                        }),
                    })
                    .await
                {
                    error!(
                        error = %e.as_report(),
                        ?task_id,
                        ?query_id,
                        ?stage_id,
                        "Abort task failed",
                    );
                };
            });
        }
        Ok(())
    }

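    /// Creates one task on a compute node: uses the given worker (or a random one if `None`),
    /// issues a `create_task` RPC with the plan fragment, records the worker as the task's
    /// location, and returns the fused status stream for the caller to poll. A serving worker
    /// whose RPC fails is temporarily masked from further scheduling.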
    async fn schedule_task(
        &self,
        task_id: PbTaskId,
        plan_fragment: PlanFragment,
        worker: Option<WorkerNode>,
        expr_context: ExprContext,
    ) -> SchedulerResult<Fuse<Streaming<TaskInfoResponse>>> {
        let mut worker = worker.unwrap_or(self.worker_node_manager.next_random_worker()?);
        let worker_node_addr = worker.host.take().unwrap();
        let compute_client = self
            .compute_client_pool
            .get_by_addr((&worker_node_addr).into())
            .await
            .inspect_err(|_| self.mask_failed_serving_worker(&worker))
            .map_err(|e| anyhow!(e))?;

        let t_id = task_id.task_id;

        let stream_status: Fuse<Streaming<TaskInfoResponse>> = compute_client
            .create_task(task_id, plan_fragment, expr_context)
            .await
            .inspect_err(|_| self.mask_failed_serving_worker(&worker))
            .map_err(|e| anyhow!(e))?
            .fuse();

        self.tasks[&t_id].inner.store(Arc::new(TaskStatus {
            _task_id: t_id,
            location: Some(worker_node_addr),
        }));

        Ok(stream_status)
    }

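    /// Builds the `PlanFragment` protobuf for one task of this stage by serializing the stage's
    /// plan tree for the given task id and partition and attaching the stage's exchange info.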
    fn create_plan_fragment(
        &self,
        task_id: TaskId,
        partition: Option<PartitionInfo>,
    ) -> PlanFragment {
        let mut identity_id = 0;

        let stage = &self.query.stage(self.stage_id);

        let plan_node_prost =
            self.convert_plan_node(&stage.root, task_id, partition, &mut identity_id);
        let exchange_info = stage.exchange_info.clone().unwrap();

        PlanFragment {
            root: Some(plan_node_prost),
            exchange_info: Some(exchange_info),
        }
    }

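    /// Recursively serializes an `ExecutionPlanNode` into a `PbPlanNode` for one task. Exchange
    /// nodes are rewritten to pull from the already scheduled child stage's task outputs, scan and
    /// source nodes get their partition (vnode bitmap, scan ranges, or splits) filled in, and each
    /// node gets an identity of the form `"{plan_node_type:?}-{n}"`, e.g. `"BatchExchange-0"`.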
    fn convert_plan_node(
        &self,
        execution_plan_node: &ExecutionPlanNode,
        task_id: TaskId,
        partition: Option<PartitionInfo>,
        identity_id: &mut u64,
    ) -> PbPlanNode {
        let identity = {
            let identity_type = execution_plan_node.plan_node_type;
            let id = *identity_id;
            *identity_id += 1;
            format!("{:?}-{}", identity_type, id)
        };

        match execution_plan_node.plan_node_type {
            BatchPlanNodeType::BatchExchange => {
                let child_stage = self
                    .children
                    .iter()
                    .find(|child_stage| {
                        child_stage.stage_id == execution_plan_node.source_stage_id.unwrap()
                    })
                    .unwrap();
                let exchange_sources = child_stage.all_exchange_sources_for(task_id);

                match &execution_plan_node.node {
                    NodeBody::Exchange(exchange_node) => PbPlanNode {
                        children: vec![],
                        identity,
                        node_body: Some(NodeBody::Exchange(ExchangeNode {
                            sources: exchange_sources,
                            sequential: exchange_node.sequential,
                            input_schema: execution_plan_node.schema.clone(),
                        })),
                    },
                    NodeBody::MergeSortExchange(sort_merge_exchange_node) => PbPlanNode {
                        children: vec![],
                        identity,
                        node_body: Some(NodeBody::MergeSortExchange(MergeSortExchangeNode {
                            exchange: Some(ExchangeNode {
                                sources: exchange_sources,
                                sequential: false,
                                input_schema: execution_plan_node.schema.clone(),
                            }),
                            column_orders: sort_merge_exchange_node.column_orders.clone(),
                        })),
                    },
                    _ => unreachable!(),
                }
            }
            BatchPlanNodeType::BatchSeqScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::RowSeqScan(mut scan_node) = node_body else {
                    unreachable!();
                };
                let partition = partition
                    .expect("no partition info for seq scan")
                    .into_table()
                    .expect("PartitionInfo should be TablePartitionInfo");
                scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
                scan_node.scan_ranges = partition.scan_ranges;
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::RowSeqScan(scan_node)),
                }
            }
            BatchPlanNodeType::BatchLogSeqScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::LogRowSeqScan(mut scan_node) = node_body else {
                    unreachable!();
                };
                let partition = partition
                    .expect("no partition info for seq scan")
                    .into_table()
                    .expect("PartitionInfo should be TablePartitionInfo");
                scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::LogRowSeqScan(scan_node)),
                }
            }
            BatchPlanNodeType::BatchSource | BatchPlanNodeType::BatchKafkaScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::Source(mut source_node) = node_body else {
                    unreachable!();
                };

                let partition = partition
                    .expect("no partition info for seq scan")
                    .into_source()
                    .expect("PartitionInfo should be SourcePartitionInfo");
                source_node.split = partition
                    .into_iter()
                    .map(|split| split.encode_to_bytes().into())
                    .collect_vec();
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::Source(source_node)),
                }
            }
            BatchPlanNodeType::BatchIcebergScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::IcebergScan(mut iceberg_scan_node) = node_body else {
                    unreachable!();
                };

                let partition = partition
                    .expect("no partition info for seq scan")
                    .into_source()
                    .expect("PartitionInfo should be SourcePartitionInfo");
                iceberg_scan_node.split = partition
                    .into_iter()
                    .map(|split| split.encode_to_bytes().into())
                    .collect_vec();
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::IcebergScan(iceberg_scan_node)),
                }
            }
            _ => {
                let children = execution_plan_node
                    .children
                    .iter()
                    .map(|e| self.convert_plan_node(e, task_id, partition.clone(), identity_id))
                    .collect();

                PbPlanNode {
                    children,
                    identity,
                    node_body: Some(execution_plan_node.node.clone()),
                }
            }
        }
    }

    fn is_root_stage(&self) -> bool {
        self.stage_id == 0.into()
    }

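    /// Temporarily removes a serving worker from scheduling after an RPC to it failed, for at
    /// least one second (`batch_config().mask_worker_temporary_secs`). Workers that are not
    /// serving are left untouched.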
    fn mask_failed_serving_worker(&self, worker: &WorkerNode) {
        if !worker.property.as_ref().is_some_and(|p| p.is_serving) {
            return;
        }
        let duration = Duration::from_secs(std::cmp::max(
            self.ctx
                .session
                .env()
                .batch_config()
                .mask_worker_temporary_secs as u64,
            1,
        ));
        self.worker_node_manager
            .manager
            .mask_worker_node(worker.id, duration);
    }
}