use std::assert_matches::assert_matches;
use std::collections::HashMap;
use std::mem;
use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;

use StageEvent::Failed;
use anyhow::anyhow;
use arc_swap::ArcSwap;
use futures::stream::Fuse;
use futures::{StreamExt, TryStreamExt, stream};
use futures_async_stream::for_await;
use itertools::Itertools;
use risingwave_batch::error::BatchError;
use risingwave_batch::executor::ExecutorBuilder;
use risingwave_batch::task::{ShutdownMsg, ShutdownSender, ShutdownToken, TaskId as TaskIdBatch};
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::array::DataChunk;
use risingwave_common::hash::WorkerSlotMapping;
use risingwave_common::util::addr::HostAddr;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_connector::source::SplitMetaData;
use risingwave_expr::expr_context::expr_context_scope;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::{
    DistributedLookupJoinNode, ExchangeNode, ExchangeSource, MergeSortExchangeNode, PlanFragment,
    PlanNode as PbPlanNode, PlanNode, TaskId as PbTaskId, TaskOutputId,
};
use risingwave_pb::common::{HostAddress, WorkerNode};
use risingwave_pb::plan_common::ExprContext;
use risingwave_pb::task_service::{CancelTaskRequest, TaskInfoResponse};
use risingwave_rpc_client::ComputeClientPoolRef;
use risingwave_rpc_client::error::RpcError;
use rw_futures_util::select_all;
use thiserror_ext::AsReport;
use tokio::spawn;
use tokio::sync::RwLock;
use tokio::sync::mpsc::{Receiver, Sender};
use tonic::Streaming;
use tracing::{Instrument, debug, error, warn};

use crate::catalog::catalog_service::CatalogReader;
use crate::catalog::{FragmentId, TableId};
use crate::optimizer::plan_node::BatchPlanNodeType;
use crate::scheduler::SchedulerError::{TaskExecutionError, TaskRunningOutOfMemory};
use crate::scheduler::distributed::QueryMessage;
use crate::scheduler::distributed::stage::StageState::Pending;
use crate::scheduler::plan_fragmenter::{
    ExecutionPlanNode, PartitionInfo, Query, ROOT_TASK_ID, StageId, TaskId,
};
use crate::scheduler::{ExecutionContextRef, SchedulerError, SchedulerResult};

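/// Maximum number of task-creation RPCs kept in flight at once for a single stage (used with
/// `buffer_unordered` in `schedule_tasks`).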
const TASK_SCHEDULING_PARALLELISM: usize = 10;

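/// Scheduling state of a stage. The normal path is `Pending` -> `Started` -> `Running` ->
/// `Completed`; `Failed` is entered when scheduling or execution fails.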
#[derive(Debug)]
enum StageState {
    Pending {
        msg_sender: Sender<QueryMessage>,
    },
    Started,
    Running,
    Completed,
    Failed,
}

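/// Events a stage reports to the query runner.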
#[derive(Debug)]
pub enum StageEvent {
    Scheduled(StageId),
    ScheduledRoot(Receiver<SchedulerResult<DataChunk>>),
    Failed {
        id: StageId,
        reason: SchedulerError,
    },
    Completed(#[allow(dead_code)] StageId),
}

#[derive(Clone)]
pub struct TaskStatus {
    _task_id: TaskId,

    location: Option<HostAddress>,
}

struct TaskStatusHolder {
    inner: ArcSwap<TaskStatus>,
}

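/// Handle for scheduling and tracking a single stage of a distributed query. It owns the status
/// of every task in the stage; the actual scheduling work is done by a [`StageRunner`] spawned in
/// [`Self::start`].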
pub struct StageExecution {
    stage_id: StageId,
    query: Arc<Query>,
    worker_node_manager: WorkerNodeSelector,
    tasks: Arc<HashMap<TaskId, TaskStatusHolder>>,
    state: Arc<RwLock<StageState>>,
    shutdown_tx: RwLock<Option<ShutdownSender>>,
    children: Vec<Arc<StageExecution>>,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,

    ctx: ExecutionContextRef,
}

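/// Worker half of [`StageExecution`]: schedules the stage's tasks, watches their status streams,
/// and reports [`StageEvent`]s back to the query runner via `msg_sender`.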
struct StageRunner {
    state: Arc<RwLock<StageState>>,
    stage_id: StageId,
    query: Arc<Query>,
    worker_node_manager: WorkerNodeSelector,
    tasks: Arc<HashMap<TaskId, TaskStatusHolder>>,
    msg_sender: Sender<QueryMessage>,
    children: Vec<Arc<StageExecution>>,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,

    ctx: ExecutionContextRef,
}

impl TaskStatusHolder {
    fn new(task_id: TaskId) -> Self {
        let task_status = TaskStatus {
            _task_id: task_id,
            location: None,
        };

        Self {
            inner: ArcSwap::new(Arc::new(task_status)),
        }
    }

    fn get_status(&self) -> Arc<TaskStatus> {
        self.inner.load_full()
    }
}

impl StageExecution {
    pub fn new(
        stage_id: StageId,
        query: Arc<Query>,
        worker_node_manager: WorkerNodeSelector,
        msg_sender: Sender<QueryMessage>,
        children: Vec<Arc<StageExecution>>,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
        ctx: ExecutionContextRef,
    ) -> Self {
        let tasks = (0..query.stage(stage_id).parallelism.unwrap())
            .map(|task_id| (task_id as u64, TaskStatusHolder::new(task_id as u64)))
            .collect();

        Self {
            stage_id,
            query,
            worker_node_manager,
            tasks: Arc::new(tasks),
            state: Arc::new(RwLock::new(Pending { msg_sender })),
            shutdown_tx: RwLock::new(None),
            children,
            compute_client_pool,
            catalog_reader,
            ctx,
        }
    }

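    /// Starts the stage: consumes the `Pending` state, spawns a [`StageRunner`] on the compute
    /// runtime, and transitions to `Started`. Calling this more than once is a bug.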
    pub async fn start(&self) {
        let mut s = self.state.write().await;
        let cur_state = mem::replace(&mut *s, StageState::Failed);
        match cur_state {
            Pending { msg_sender } => {
                let runner = StageRunner {
                    stage_id: self.stage_id,
                    query: self.query.clone(),
                    worker_node_manager: self.worker_node_manager.clone(),
                    tasks: self.tasks.clone(),
                    msg_sender,
                    children: self.children.clone(),
                    state: self.state.clone(),
                    compute_client_pool: self.compute_client_pool.clone(),
                    catalog_reader: self.catalog_reader.clone(),
                    ctx: self.ctx.clone(),
                };

                let (sender, receiver) = ShutdownToken::new();
                let mut holder = self.shutdown_tx.write().await;
                *holder = Some(sender);

                *s = StageState::Started;

                let span = tracing::info_span!(
                    "stage",
                    "otel.name" = format!("Stage {}-{}", self.query.query_id.id, self.stage_id),
                    query_id = self.query.query_id.id,
                    stage_id = %self.stage_id,
                );
                self.ctx
                    .session()
                    .env()
                    .compute_runtime()
                    .spawn(async move { runner.run(receiver).instrument(span).await });

                tracing::trace!(
                    "Stage {:?}-{:?} started.",
                    self.query.query_id.id,
                    self.stage_id
                )
            }
            _ => {
                unreachable!("Only expect to schedule stage once");
            }
        }
    }

    pub async fn stop(&self, error: Option<String>) {
        if let Some(shutdown_tx) = self.shutdown_tx.write().await.take() {
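            // Abort with the error if one is given, otherwise just cancel; if the signal cannot
            // be delivered (e.g. the tasks already exited), only log it.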
            if !if let Some(error) = error {
                shutdown_tx.abort(error)
            } else {
                shutdown_tx.cancel()
            } {
                tracing::trace!(
                    "Failed to send stop message to stage: {:?}-{:?}",
                    self.query.query_id,
                    self.stage_id
                );
            }
        }
    }

    pub async fn is_scheduled(&self) -> bool {
        let s = self.state.read().await;
        matches!(*s, StageState::Running | StageState::Completed)
    }

    pub async fn is_pending(&self) -> bool {
        let s = self.state.read().await;
        matches!(*s, StageState::Pending { .. })
    }

    pub async fn state(&self) -> &'static str {
        let s = self.state.read().await;
        match *s {
            Pending { .. } => "Pending",
            StageState::Started => "Started",
            StageState::Running => "Running",
            StageState::Completed => "Completed",
            StageState::Failed => "Failed",
        }
    }

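    /// Returns one [`ExchangeSource`] per task of this stage for the given `output_id`, so a
    /// downstream exchange can pull this stage's output. All tasks must already have a known
    /// location, i.e. the stage must have been scheduled.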
    pub fn all_exchange_sources_for(&self, output_id: u64) -> Vec<ExchangeSource> {
        self.tasks
            .iter()
            .map(|(task_id, status_holder)| {
                let task_output_id = TaskOutputId {
                    task_id: Some(PbTaskId {
                        query_id: self.query.query_id.id.clone(),
                        stage_id: self.stage_id.into(),
                        task_id: *task_id,
                    }),
                    output_id,
                };

                ExchangeSource {
                    task_output_id: Some(task_output_id),
                    host: Some(status_holder.inner.load_full().location.clone().unwrap()),
                    local_execute_plan: None,
                }
            })
            .collect()
    }
}

impl StageRunner {
    async fn run(mut self, shutdown_rx: ShutdownToken) {
        if let Err(e) = self.schedule_tasks_for_all(shutdown_rx).await {
            error!(
                error = %e.as_report(),
                query_id = ?self.query.query_id,
                stage_id = ?self.stage_id,
                "Failed to schedule tasks"
            );
            self.send_event(QueryMessage::Stage(Failed {
                id: self.stage_id,
                reason: e,
            }))
            .await;
        }
    }

    async fn send_event(&self, event: QueryMessage) {
        if let Err(_e) = self.msg_sender.send(event).await {
            warn!("Failed to send event to Query Runner; it may have been killed by a previous failed event");
        }
    }

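    /// Schedules every task of this (non-root) stage onto a compute node, then polls the task
    /// status streams and reports `Scheduled`, `Completed`, or `Failed` to the query runner.
    /// Any tasks still running when the loop ends are cancelled before returning.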
    async fn schedule_tasks(
        &mut self,
        mut shutdown_rx: ShutdownToken,
        expr_context: ExprContext,
    ) -> SchedulerResult<()> {
        let mut futures = vec![];
        let stage = &self.query.stage(self.stage_id);

        if let Some(table_scan_info) = stage.table_scan_info.as_ref()
            && let Some(vnode_bitmaps) = table_scan_info.partitions()
        {
            let worker_slot_ids = vnode_bitmaps.keys().cloned().collect_vec();
            let workers = self
                .worker_node_manager
                .manager
                .get_workers_by_worker_slot_ids(&worker_slot_ids)?;

            for (i, (worker_slot_id, worker)) in worker_slot_ids
                .into_iter()
                .zip_eq_fast(workers.into_iter())
                .enumerate()
            {
                let task_id = PbTaskId {
                    query_id: self.query.query_id.id.clone(),
                    stage_id: self.stage_id.into(),
                    task_id: i as u64,
                };
                let vnode_ranges = vnode_bitmaps[&worker_slot_id].clone();
                let plan_fragment =
                    self.create_plan_fragment(i as u64, Some(PartitionInfo::Table(vnode_ranges)));
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    Some(worker),
                    expr_context.clone(),
                ));
            }
        } else if let Some(source_info) = stage.source_info.as_ref() {
            let chunk_size = ((source_info.split_info().unwrap().len() as f32
                / stage.parallelism.unwrap() as f32)
                .ceil() as usize)
                .max(1);
            if source_info.split_info().unwrap().is_empty() {
                const EMPTY_TASK_ID: u64 = 0;
                let task_id = PbTaskId {
                    query_id: self.query.query_id.id.clone(),
                    stage_id: self.stage_id.into(),
                    task_id: EMPTY_TASK_ID,
                };
                let plan_fragment =
                    self.create_plan_fragment(EMPTY_TASK_ID, Some(PartitionInfo::Source(vec![])));
                let worker =
                    self.choose_worker(&plan_fragment, EMPTY_TASK_ID as u32, stage.dml_table_id)?;
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    worker,
                    expr_context.clone(),
                ));
            } else {
                for (id, split) in source_info
                    .split_info()
                    .unwrap()
                    .chunks(chunk_size)
                    .enumerate()
                {
                    let task_id = PbTaskId {
                        query_id: self.query.query_id.id.clone(),
                        stage_id: self.stage_id.into(),
                        task_id: id as u64,
                    };
                    let plan_fragment = self.create_plan_fragment(
                        id as u64,
                        Some(PartitionInfo::Source(split.to_vec())),
                    );
                    let worker =
                        self.choose_worker(&plan_fragment, id as u32, stage.dml_table_id)?;
                    futures.push(self.schedule_task(
                        task_id,
                        plan_fragment,
                        worker,
                        expr_context.clone(),
                    ));
                }
            }
        } else if let Some(file_scan_info) = stage.file_scan_info.as_ref() {
            let chunk_size = (file_scan_info.file_location.len() as f32
                / stage.parallelism.unwrap() as f32)
                .ceil() as usize;
            for (id, files) in file_scan_info.file_location.chunks(chunk_size).enumerate() {
                let task_id = PbTaskId {
                    query_id: self.query.query_id.id.clone(),
                    stage_id: self.stage_id.into(),
                    task_id: id as u64,
                };
                let plan_fragment =
                    self.create_plan_fragment(id as u64, Some(PartitionInfo::File(files.to_vec())));
                let worker = self.choose_worker(&plan_fragment, id as u32, stage.dml_table_id)?;
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    worker,
                    expr_context.clone(),
                ));
            }
        } else {
            for id in 0..stage.parallelism.unwrap() {
                let task_id = PbTaskId {
                    query_id: self.query.query_id.id.clone(),
                    stage_id: self.stage_id.into(),
                    task_id: id as u64,
                };
                let plan_fragment = self.create_plan_fragment(id as u64, None);
                let worker = self.choose_worker(&plan_fragment, id, stage.dml_table_id)?;
                futures.push(self.schedule_task(
                    task_id,
                    plan_fragment,
                    worker,
                    expr_context.clone(),
                ));
            }
        }

        let buffered = stream::iter(futures).buffer_unordered(TASK_SCHEDULING_PARALLELISM);
        let buffered_streams = buffered.try_collect::<Vec<_>>().await?;

        let cancelled = pin!(shutdown_rx.cancelled());
        let mut all_streams = select_all(buffered_streams).take_until(cancelled);

        let mut running_task_cnt = 0;
        let mut finished_task_cnt = 0;
        let mut sent_signal_to_next = false;

        while let Some(status_res_inner) = all_streams.next().await {
            match status_res_inner {
                Ok(status) => {
                    use risingwave_pb::task_service::task_info_response::TaskStatus as PbTaskStatus;
                    match PbTaskStatus::try_from(status.task_status).unwrap() {
                        PbTaskStatus::Running => {
                            running_task_cnt += 1;
                            assert!(running_task_cnt <= self.tasks.keys().len());
                            if running_task_cnt == self.tasks.keys().len() {
                                self.notify_stage_scheduled(QueryMessage::Stage(
                                    StageEvent::Scheduled(self.stage_id),
                                ))
                                .await;
                                sent_signal_to_next = true;
                            }
                        }

                        PbTaskStatus::Finished => {
                            finished_task_cnt += 1;
                            assert!(finished_task_cnt <= self.tasks.keys().len());
                            assert!(running_task_cnt >= finished_task_cnt);
                            if finished_task_cnt == self.tasks.keys().len() {
                                self.notify_stage_completed().await;
                                sent_signal_to_next = true;
                                break;
                            }
                        }
                        PbTaskStatus::Aborted => {
                            error!(
                                "Abort task {:?} because of excessive memory usage. Please try again later.",
                                status.task_id.unwrap()
                            );
                            self.notify_stage_state_changed(
                                |_| StageState::Failed,
                                QueryMessage::Stage(Failed {
                                    id: self.stage_id,
                                    reason: TaskRunningOutOfMemory,
                                }),
                            )
                            .await;
                            sent_signal_to_next = true;
                            break;
                        }
                        PbTaskStatus::Failed => {
                            error!(
                                "Task {:?} failed, reason: {:?}",
                                status.task_id.unwrap(),
                                status.error_message,
                            );
                            self.notify_stage_state_changed(
                                |_| StageState::Failed,
                                QueryMessage::Stage(Failed {
                                    id: self.stage_id,
                                    reason: TaskExecutionError(status.error_message),
                                }),
                            )
                            .await;
                            sent_signal_to_next = true;
                            break;
                        }
                        PbTaskStatus::Ping => {
                            debug!("Receive ping from task {:?}", status.task_id.unwrap());
                        }
                        status => {
                            unreachable!("Unexpected task status {:?}", status);
                        }
                    }
                }
                Err(e) => {
                    error!(
                        "Fetching task status in stage {:?} failed, reason: {:?}",
                        self.stage_id,
                        e.message()
                    );
                    self.notify_stage_state_changed(
                        |_| StageState::Failed,
                        QueryMessage::Stage(Failed {
                            id: self.stage_id,
                            reason: RpcError::from_batch_status(e).into(),
                        }),
                    )
                    .await;
                    sent_signal_to_next = true;
                    break;
                }
            }
        }

        tracing::trace!(
            "Stage [{:?}-{:?}], running task count: {}, finished task count: {}, sent signal to next: {}",
            self.query.query_id,
            self.stage_id,
            running_task_cnt,
            finished_task_cnt,
            sent_signal_to_next,
        );

        if let Some(shutdown) = all_streams.take_future() {
            tracing::trace!(
                "Stage [{:?}-{:?}] waiting for stopping signal.",
                self.query.query_id,
                self.stage_id
            );
            shutdown.await;
        }

        tracing::trace!(
            "Stopping stage: {:?}-{:?}, task_num: {}",
            self.query.query_id,
            self.stage_id,
            self.tasks.len()
        );
        self.cancel_all_scheduled_tasks().await?;

        tracing::trace!(
            "Stage runner [{:?}-{:?}] exited.",
            self.query.query_id,
            self.stage_id
        );
        Ok(())
    }

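    /// The root stage is not dispatched to compute nodes. Its plan fragment is executed locally
    /// by the stage runner and result chunks are streamed to the query runner through the
    /// channel handed over in the `ScheduledRoot` event.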
    async fn schedule_tasks_for_root(
        &mut self,
        mut shutdown_rx: ShutdownToken,
        expr_context: ExprContext,
    ) -> SchedulerResult<()> {
        let root_stage_id = self.stage_id;
        let plan_fragment = self.create_plan_fragment(ROOT_TASK_ID, None);
        let plan_node = plan_fragment.root.unwrap();
        let task_id = TaskIdBatch {
            query_id: self.query.query_id.id.clone(),
            stage_id: root_stage_id.into(),
            task_id: 0,
        };

        let (result_tx, result_rx) = tokio::sync::mpsc::channel(
            self.ctx
                .session
                .env()
                .batch_config()
                .developer
                .root_stage_channel_size,
        );
        self.notify_stage_scheduled(QueryMessage::Stage(StageEvent::ScheduledRoot(result_rx)))
            .await;

        let executor = ExecutorBuilder::new(
            &plan_node,
            &task_id,
            self.ctx.to_batch_task_context(),
            shutdown_rx.clone(),
        );

        let shutdown_rx0 = shutdown_rx.clone();

        let result = expr_context_scope(expr_context, async {
            let executor = executor.build().await?;
            let chunk_stream = executor.execute();
            let cancelled = pin!(shutdown_rx.cancelled());
            #[for_await]
            for chunk in chunk_stream.take_until(cancelled) {
                if let Err(ref e) = chunk {
                    if shutdown_rx0.is_cancelled() {
                        break;
                    }
                    let err_str = e.to_report_string();
                    if let Err(_e) = result_tx.send(chunk.map_err(|e| e.into())).await {
                        warn!("Root executor has been dropped before receiving any events, so the send failed");
                    }
                    return Err(TaskExecutionError(err_str));
                } else {
                    if let Err(_e) = result_tx.send(chunk.map_err(|e| e.into())).await {
                        warn!("Root executor has been dropped before receiving any events, so the send failed");
                    }
                }
            }
            Ok(())
        }).await;

        if let Err(err) = &result {
            if let Err(_e) = result_tx
                .send(Err(TaskExecutionError(err.to_report_string())))
                .await
            {
                warn!("Failed to send task execution error");
            }
        }

        match shutdown_rx0.message() {
            ShutdownMsg::Abort(err_str) => {
                if let Err(_e) = result_tx.send(Err(TaskExecutionError(err_str))).await {
                    warn!("Failed to send task execution error");
                }
            }
            _ => self.notify_stage_completed().await,
        }

        tracing::trace!(
            "Stage runner [{:?}-{:?}] exited.",
            self.query.query_id,
            self.stage_id
        );

        result.map(|_| ())
    }

    async fn schedule_tasks_for_all(&mut self, shutdown_rx: ShutdownToken) -> SchedulerResult<()> {
        let expr_context = ExprContext {
            time_zone: self.ctx.session().config().timezone(),
            strict_mode: self.ctx.session().config().batch_expr_strict_mode(),
        };
        if !self.is_root_stage() {
            self.schedule_tasks(shutdown_rx, expr_context).await?;
        } else {
            self.schedule_tasks_for_root(shutdown_rx, expr_context)
                .await?;
        }
        Ok(())
    }

    #[inline(always)]
    fn get_fragment_id(&self, table_id: &TableId) -> SchedulerResult<FragmentId> {
        self.catalog_reader
            .read_guard()
            .get_any_table_by_id(table_id)
            .map(|table| table.fragment_id)
            .map_err(|e| SchedulerError::Internal(anyhow!(e)))
    }

    #[inline(always)]
    fn get_table_dml_vnode_mapping(
        &self,
        table_id: &TableId,
    ) -> SchedulerResult<WorkerSlotMapping> {
        let guard = self.catalog_reader.read_guard();

        let table = guard
            .get_any_table_by_id(table_id)
            .map_err(|e| SchedulerError::Internal(anyhow!(e)))?;

        let fragment_id = match table.dml_fragment_id.as_ref() {
            Some(dml_fragment_id) => dml_fragment_id,
            None => &table.fragment_id,
        };

        self.worker_node_manager
            .manager
            .get_streaming_fragment_mapping(fragment_id)
            .map_err(|e| e.into())
    }

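    /// Decides which worker a task must run on. DML targets the workers of the table's (DML)
    /// fragment, and a distributed lookup join must be co-located with the inner side's worker
    /// slot; `Ok(None)` means any worker is fine and `schedule_task` picks a random one.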
    fn choose_worker(
        &self,
        plan_fragment: &PlanFragment,
        task_id: u32,
        dml_table_id: Option<TableId>,
    ) -> SchedulerResult<Option<WorkerNode>> {
        let plan_node = plan_fragment.root.as_ref().expect("fail to get plan node");

        if let Some(table_id) = dml_table_id {
            let vnode_mapping = self.get_table_dml_vnode_mapping(&table_id)?;
            let worker_slot_ids = vnode_mapping.iter_unique().collect_vec();
            let candidates = self
                .worker_node_manager
                .manager
                .get_workers_by_worker_slot_ids(&worker_slot_ids)?;
            if candidates.is_empty() {
                return Err(BatchError::EmptyWorkerNodes.into());
            }
            let stage = &self.query.stage(self.stage_id);
            let candidate = if stage.batch_enable_distributed_dml {
                candidates[task_id as usize % candidates.len()].clone()
            } else {
                candidates[stage.session_id.0 as usize % candidates.len()].clone()
            };
            return Ok(Some(candidate));
        };

        if let Some(distributed_lookup_join_node) =
            Self::find_distributed_lookup_join_node(plan_node)
        {
            let fragment_id = self.get_fragment_id(
                &distributed_lookup_join_node
                    .inner_side_table_desc
                    .as_ref()
                    .unwrap()
                    .table_id
                    .into(),
            )?;
            let id_to_worker_slots = self
                .worker_node_manager
                .fragment_mapping(fragment_id)?
                .iter_unique()
                .collect_vec();

            let worker_slot_id = id_to_worker_slots[task_id as usize];
            let candidates = self
                .worker_node_manager
                .manager
                .get_workers_by_worker_slot_ids(&[worker_slot_id])?;
            if candidates.is_empty() {
                return Err(BatchError::EmptyWorkerNodes.into());
            }
            Ok(Some(candidates[0].clone()))
        } else {
            Ok(None)
        }
    }

    fn find_distributed_lookup_join_node(
        plan_node: &PlanNode,
    ) -> Option<&DistributedLookupJoinNode> {
        let node_body = plan_node.node_body.as_ref().expect("fail to get node body");

        match node_body {
            NodeBody::DistributedLookupJoin(distributed_lookup_join_node) => {
                Some(distributed_lookup_join_node)
            }
            _ => plan_node
                .children
                .iter()
                .find_map(Self::find_distributed_lookup_join_node),
        }
    }

    async fn notify_stage_scheduled(&self, msg: QueryMessage) {
        self.notify_stage_state_changed(
            |old_state| {
                assert_matches!(old_state, StageState::Started);
                StageState::Running
            },
            msg,
        )
        .await
    }

    async fn notify_stage_completed(&self) {
        self.notify_stage_state_changed(
            |old_state| {
                assert_matches!(old_state, StageState::Running);
                StageState::Completed
            },
            QueryMessage::Stage(StageEvent::Completed(self.stage_id)),
        )
        .await
    }

    async fn notify_stage_state_changed<F>(&self, new_state: F, msg: QueryMessage)
    where
        F: FnOnce(StageState) -> StageState,
    {
        {
            let mut s = self.state.write().await;
            let old_state = mem::replace(&mut *s, StageState::Failed);
            *s = new_state(old_state);
        }

        self.send_event(msg).await;
    }

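    /// Sends a `CancelTask` RPC to every task of this stage. Cancellation is best-effort and runs
    /// in spawned tasks; RPC failures are only logged.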
    async fn cancel_all_scheduled_tasks(&self) -> SchedulerResult<()> {
        for (task, task_status) in &*self.tasks {
            let loc = &task_status.get_status().location;
            let addr = loc.as_ref().expect("Get address should not fail");
            let client = self
                .compute_client_pool
                .get_by_addr(HostAddr::from(addr))
                .await
                .map_err(|e| anyhow!(e))?;

            let query_id = self.query.query_id.id.clone();
            let stage_id = self.stage_id;
            let task_id = *task;
            spawn(async move {
                if let Err(e) = client
                    .cancel(CancelTaskRequest {
                        task_id: Some(risingwave_pb::batch_plan::TaskId {
                            query_id: query_id.clone(),
                            stage_id: stage_id.into(),
                            task_id,
                        }),
                    })
                    .await
                {
                    error!(
                        error = %e.as_report(),
                        ?task_id,
                        ?query_id,
                        ?stage_id,
                        "Abort task failed",
                    );
                };
            });
        }
        Ok(())
    }

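    /// Creates a single task on a compute node: uses `worker` if given (falling back to a random
    /// worker), issues the `create_task` RPC, records the task's location, and returns the
    /// task-info stream to poll. On RPC failure the serving worker is temporarily masked.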
    async fn schedule_task(
        &self,
        task_id: PbTaskId,
        plan_fragment: PlanFragment,
        worker: Option<WorkerNode>,
        expr_context: ExprContext,
    ) -> SchedulerResult<Fuse<Streaming<TaskInfoResponse>>> {
        let mut worker = worker.unwrap_or(self.worker_node_manager.next_random_worker()?);
        let worker_node_addr = worker.host.take().unwrap();
        let compute_client = self
            .compute_client_pool
            .get_by_addr((&worker_node_addr).into())
            .await
            .inspect_err(|_| self.mask_failed_serving_worker(&worker))
            .map_err(|e| anyhow!(e))?;

        let t_id = task_id.task_id;

        let stream_status: Fuse<Streaming<TaskInfoResponse>> = compute_client
            .create_task(task_id, plan_fragment, expr_context)
            .await
            .inspect_err(|_| self.mask_failed_serving_worker(&worker))
            .map_err(|e| anyhow!(e))?
            .fuse();

        self.tasks[&t_id].inner.store(Arc::new(TaskStatus {
            _task_id: t_id,
            location: Some(worker_node_addr),
        }));

        Ok(stream_status)
    }

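    /// Builds the [`PlanFragment`] for one task of this stage, serializing the stage's plan tree
    /// with the task's partition info filled in.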
    fn create_plan_fragment(
        &self,
        task_id: TaskId,
        partition: Option<PartitionInfo>,
    ) -> PlanFragment {
        let mut identity_id = 0;

        let stage = &self.query.stage(self.stage_id);

        let plan_node_prost =
            self.convert_plan_node(&stage.root, task_id, partition, &mut identity_id);
        let exchange_info = stage.exchange_info.clone().unwrap();

        PlanFragment {
            root: Some(plan_node_prost),
            exchange_info: Some(exchange_info),
        }
    }

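    /// Recursively converts an [`ExecutionPlanNode`] into a protobuf [`PbPlanNode`]: exchange
    /// nodes are wired to the child stage's task outputs, scan/source nodes receive their
    /// partition (vnode bitmap, scan ranges, or splits), and `identity_id` numbers the nodes.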
    fn convert_plan_node(
        &self,
        execution_plan_node: &ExecutionPlanNode,
        task_id: TaskId,
        partition: Option<PartitionInfo>,
        identity_id: &mut u64,
    ) -> PbPlanNode {
        let identity = {
            let identity_type = execution_plan_node.plan_node_type;
            let id = *identity_id;
            *identity_id += 1;
            format!("{:?}-{}", identity_type, id)
        };

        match execution_plan_node.plan_node_type {
            BatchPlanNodeType::BatchExchange => {
                let child_stage = self
                    .children
                    .iter()
                    .find(|child_stage| {
                        child_stage.stage_id == execution_plan_node.source_stage_id.unwrap()
                    })
                    .unwrap();
                let exchange_sources = child_stage.all_exchange_sources_for(task_id);

                match &execution_plan_node.node {
                    NodeBody::Exchange(exchange_node) => PbPlanNode {
                        children: vec![],
                        identity,
                        node_body: Some(NodeBody::Exchange(ExchangeNode {
                            sources: exchange_sources,
                            sequential: exchange_node.sequential,
                            input_schema: execution_plan_node.schema.clone(),
                        })),
                    },
                    NodeBody::MergeSortExchange(sort_merge_exchange_node) => PbPlanNode {
                        children: vec![],
                        identity,
                        node_body: Some(NodeBody::MergeSortExchange(MergeSortExchangeNode {
                            exchange: Some(ExchangeNode {
                                sources: exchange_sources,
                                sequential: false,
                                input_schema: execution_plan_node.schema.clone(),
                            }),
                            column_orders: sort_merge_exchange_node.column_orders.clone(),
                        })),
                    },
                    _ => unreachable!(),
                }
            }
            BatchPlanNodeType::BatchSeqScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::RowSeqScan(mut scan_node) = node_body else {
                    unreachable!();
                };
                let partition = partition
                    .expect("no partition info for seq scan")
                    .into_table()
                    .expect("PartitionInfo should be TablePartitionInfo");
                scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
                scan_node.scan_ranges = partition.scan_ranges;
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::RowSeqScan(scan_node)),
                }
            }
            BatchPlanNodeType::BatchLogSeqScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::LogRowSeqScan(mut scan_node) = node_body else {
                    unreachable!();
                };
                let partition = partition
                    .expect("no partition info for seq scan")
                    .into_table()
                    .expect("PartitionInfo should be TablePartitionInfo");
                scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::LogRowSeqScan(scan_node)),
                }
            }
            BatchPlanNodeType::BatchSource | BatchPlanNodeType::BatchKafkaScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::Source(mut source_node) = node_body else {
                    unreachable!();
                };

                let partition = partition
                    .expect("no partition info for seq scan")
                    .into_source()
                    .expect("PartitionInfo should be SourcePartitionInfo");
                source_node.split = partition
                    .into_iter()
                    .map(|split| split.encode_to_bytes().into())
                    .collect_vec();
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::Source(source_node)),
                }
            }
            BatchPlanNodeType::BatchIcebergScan => {
                let node_body = execution_plan_node.node.clone();
                let NodeBody::IcebergScan(mut iceberg_scan_node) = node_body else {
                    unreachable!();
                };

                let partition = partition
                    .expect("no partition info for seq scan")
                    .into_source()
                    .expect("PartitionInfo should be SourcePartitionInfo");
                iceberg_scan_node.split = partition
                    .into_iter()
                    .map(|split| split.encode_to_bytes().into())
                    .collect_vec();
                PbPlanNode {
                    children: vec![],
                    identity,
                    node_body: Some(NodeBody::IcebergScan(iceberg_scan_node)),
                }
            }
            _ => {
                let children = execution_plan_node
                    .children
                    .iter()
                    .map(|e| self.convert_plan_node(e, task_id, partition.clone(), identity_id))
                    .collect();

                PbPlanNode {
                    children,
                    identity,
                    node_body: Some(execution_plan_node.node.clone()),
                }
            }
        }
    }

    fn is_root_stage(&self) -> bool {
        self.stage_id == 0.into()
    }

    fn mask_failed_serving_worker(&self, worker: &WorkerNode) {
        if !worker.property.as_ref().is_some_and(|p| p.is_serving) {
            return;
        }
        let duration = Duration::from_secs(std::cmp::max(
            self.ctx
                .session
                .env()
                .batch_config()
                .mask_worker_temporary_secs as u64,
            1,
        ));
        self.worker_node_manager
            .manager
            .mask_worker_node(worker.id, duration);
    }
}