1use std::collections::HashMap;
16use std::fmt::{Debug, Formatter};
17use std::mem;
18use std::sync::Arc;
19
20use anyhow::Context;
21use futures::executor::block_on;
22use petgraph::Graph;
23use petgraph::dot::{Config, Dot};
24use pgwire::pg_server::SessionId;
25use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
26use risingwave_common::array::DataChunk;
27use risingwave_pb::batch_plan::{TaskId as PbTaskId, TaskOutputId as PbTaskOutputId};
28use risingwave_pb::common::{BatchQueryEpoch, HostAddress};
29use risingwave_rpc_client::ComputeClientPoolRef;
30use thiserror_ext::AsReport;
31use tokio::sync::mpsc::{Receiver, Sender, channel};
32use tokio::sync::{RwLock, oneshot};
33use tokio::task::JoinHandle;
34use tracing::{Instrument, debug, error, info, warn};
35
36use super::{DistributedQueryMetrics, QueryExecutionInfoRef, QueryResultFetcher, StageEvent};
37use crate::catalog::catalog_service::CatalogReader;
38use crate::scheduler::distributed::StageEvent::Scheduled;
39use crate::scheduler::distributed::StageExecution;
40use crate::scheduler::distributed::query::QueryMessage::Stage;
41use crate::scheduler::distributed::stage::StageEvent::ScheduledRoot;
42use crate::scheduler::plan_fragmenter::{Query, ROOT_TASK_ID, ROOT_TASK_OUTPUT_ID, StageId};
43use crate::scheduler::{ExecutionContextRef, SchedulerError, SchedulerResult};
44
/// Message sent to a running [`QueryRunner`] to drive or cancel the query.
#[derive(Debug)]
pub enum QueryMessage {
    /// An event reported by one of the query's stages (scheduled / failed /
    /// completed — see `StageEvent`).
    Stage(StageEvent),
    /// Cancel the whole query, carrying a human-readable reason.
    CancelQuery(String),
}
53
/// Lifecycle state of a [`QueryExecution`].
enum QueryState {
    /// Not started yet. Holds the receiving half of the message channel,
    /// which is handed over to the `QueryRunner` when the query starts.
    Pending {
        msg_receiver: Receiver<QueryMessage>,
    },

    /// A `QueryRunner` task has been spawned and is driving the stages.
    Running,

    /// Terminal/placeholder state; also what `start` leaves behind while it
    /// takes ownership of the pending receiver.
    Failed,
}
68
/// Handle for a single distributed batch query, owning its lifecycle from
/// creation (pending) through `start` (running) to completion or abort.
pub struct QueryExecution {
    /// The fragmented query plan, shared with the spawned `QueryRunner`.
    query: Arc<Query>,
    /// Current lifecycle state; `Pending` parks the runner's message receiver.
    state: RwLock<QueryState>,
    /// Sender used to push `QueryMessage`s (stage events and cancellation)
    /// to the runner; cloned into every `StageExecution`.
    shutdown_tx: Sender<QueryMessage>,
    /// Session that issued this query.
    pub session_id: SessionId,
    /// Concurrency-limit permit held for the lifetime of this execution
    /// (released when dropped, per tokio semaphore semantics). Never read,
    /// hence `dead_code`.
    #[expect(dead_code)]
    pub permit: Option<tokio::sync::OwnedSemaphorePermit>,
}
79
/// Background task that drives all stages of one query to completion.
struct QueryRunner {
    /// The fragmented query plan being executed.
    query: Arc<Query>,
    /// One execution handle per stage, keyed by stage id.
    stage_executions: HashMap<StageId, Arc<StageExecution>>,
    /// Number of stages that have reached the scheduled state so far.
    scheduled_stages_count: usize,
    /// Receives stage events and cancellation requests; the senders live in
    /// `QueryExecution` and the stage executions.
    msg_receiver: Receiver<QueryMessage>,

    /// One-shot channel used to hand the root-stage result fetcher (or a
    /// scheduling error) back to the caller of `QueryExecution::start`.
    /// `None` once consumed.
    root_stage_sender: Option<oneshot::Sender<SchedulerResult<QueryResultFetcher>>>,

    /// Shared registry of in-flight query executions.
    query_execution_info: QueryExecutionInfoRef,

    /// Metrics sink; the running-query gauge is inc'd in `run` and dec'd in
    /// `Drop`.
    query_metrics: Arc<DistributedQueryMetrics>,
    /// Handle of the optional timeout watchdog task, aborted on drop.
    timeout_abort_task_handle: Option<JoinHandle<()>>,
}
96
impl QueryExecution {
    /// Creates a query execution in the `Pending` state.
    ///
    /// The bounded channel created here carries `QueryMessage`s from stage
    /// executions / `abort` to the future `QueryRunner`; its receiver is
    /// parked inside `QueryState::Pending` until [`Self::start`] is called.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        query: Query,
        session_id: SessionId,
        permit: Option<tokio::sync::OwnedSemaphorePermit>,
    ) -> Self {
        let query = Arc::new(query);
        // Capacity 100 bounds buffered stage events / cancel messages.
        let (sender, receiver) = channel(100);
        let state = QueryState::Pending {
            msg_receiver: receiver,
        };

        Self {
            query,
            state: RwLock::new(state),
            shutdown_tx: sender,
            session_id,
            permit,
        }
    }

    /// Starts the query: builds a `StageExecution` for every stage, spawns a
    /// `QueryRunner` task to drive them (plus an optional timeout watchdog),
    /// and waits until the runner sends back the root-stage
    /// [`QueryResultFetcher`].
    ///
    /// # Panics
    /// Hits `unreachable!` if the query is not in the `Pending` state, i.e.
    /// if `start` is called more than once.
    #[allow(clippy::too_many_arguments)]
    pub async fn start(
        self: Arc<Self>,
        context: ExecutionContextRef,
        worker_node_manager: WorkerNodeSelector,
        batch_query_epoch: BatchQueryEpoch,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
        query_execution_info: QueryExecutionInfoRef,
        query_metrics: Arc<DistributedQueryMetrics>,
    ) -> SchedulerResult<QueryResultFetcher> {
        let mut state = self.state.write().await;
        // Take ownership of the pending receiver; `Failed` is left behind as
        // a placeholder so a second `start` lands in the `unreachable!` arm.
        let cur_state = mem::replace(&mut *state, QueryState::Failed);

        let stage_executions = self.gen_stage_executions(
            batch_query_epoch,
            context.clone(),
            worker_node_manager,
            compute_client_pool.clone(),
            catalog_reader,
        );

        match cur_state {
            QueryState::Pending { msg_receiver } => {
                *state = QueryState::Running;

                // If the session configured a timeout, spawn a watchdog that
                // cancels the whole query once it elapses. The handle is
                // handed to the runner, which aborts it on drop.
                let mut timeout_abort_task_handle: Option<JoinHandle<()>> = None;
                if let Some(timeout) = context.timeout() {
                    let this = self.clone();
                    timeout_abort_task_handle = Some(tokio::spawn(async move {
                        tokio::time::sleep(timeout).await;
                        warn!(
                            "Query {:?} timeout after {} seconds, sending cancel message.",
                            this.query.query_id,
                            timeout.as_secs(),
                        );
                        this.abort(format!("timeout after {} seconds", timeout.as_secs()))
                            .await;
                    }));
                }

                // One-shot channel over which the runner reports either the
                // root-stage result fetcher or a scheduling error.
                let (root_stage_sender, root_stage_receiver) =
                    oneshot::channel::<SchedulerResult<QueryResultFetcher>>();

                let runner = QueryRunner {
                    query: self.query.clone(),
                    stage_executions,
                    msg_receiver,
                    root_stage_sender: Some(root_stage_sender),
                    scheduled_stages_count: 0,
                    query_execution_info,
                    query_metrics,
                    timeout_abort_task_handle,
                };

                let span = tracing::info_span!(
                    "distributed_execute",
                    query_id = self.query.query_id.id,
                    epoch = ?batch_query_epoch,
                );

                tracing::trace!("Starting query: {:?}", self.query.query_id);

                // Drive the event loop on its own task; this method only
                // waits for the root-stage handshake below.
                tokio::spawn(async move { runner.run().instrument(span).await });

                // Outer `?` covers the runner dropping the sender; inner `?`
                // covers a scheduling error it reported.
                let root_stage = root_stage_receiver
                    .await
                    .context("Starting query execution failed")??;

                tracing::trace!(
                    "Received root stage query result fetcher: {:?}, query id: {:?}",
                    root_stage,
                    self.query.query_id
                );

                tracing::trace!("Query {:?} started.", self.query.query_id);
                Ok(root_stage)
            }
            _ => {
                unreachable!("The query runner should not be scheduled twice");
            }
        }
    }

    /// Asks the runner to cancel this query with the given reason. A failed
    /// send means the runner — and hence the query — has already ended.
    pub async fn abort(self: Arc<Self>, reason: String) {
        if self
            .shutdown_tx
            .send(QueryMessage::CancelQuery(reason))
            .await
            .is_err()
        {
            warn!("Send cancel query request failed: the query has ended");
        } else {
            info!("Send cancel request to query-{:?}", self.query.query_id);
        };
    }

    /// Builds a `StageExecution` for every stage of the query.
    ///
    /// Stages are visited in topological order, so each stage's children are
    /// already present in the map when the stage itself is constructed.
    fn gen_stage_executions(
        &self,
        epoch: BatchQueryEpoch,
        context: ExecutionContextRef,
        worker_node_manager: WorkerNodeSelector,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
    ) -> HashMap<StageId, Arc<StageExecution>> {
        let mut stage_executions: HashMap<StageId, Arc<StageExecution>> =
            HashMap::with_capacity(self.query.stage_graph.stages.len());

        for stage_id in self.query.stage_graph.stage_ids_by_topo_order() {
            let children_stages = self
                .query
                .stage_graph
                .get_child_stages_unchecked(&stage_id)
                .iter()
                .map(|s| stage_executions[s].clone())
                .collect::<Vec<Arc<StageExecution>>>();

            let stage_exec = Arc::new(StageExecution::new(
                epoch,
                self.query.stage_graph.stages[&stage_id].clone(),
                worker_node_manager.clone(),
                // Stages report their events through the same channel.
                self.shutdown_tx.clone(),
                children_stages,
                compute_client_pool.clone(),
                catalog_reader.clone(),
                context.clone(),
            ));
            stage_executions.insert(stage_id, stage_exec);
        }
        stage_executions
    }
}
262
263impl Drop for QueryRunner {
264 fn drop(&mut self) {
265 self.query_metrics.running_query_num.dec();
266 self.timeout_abort_task_handle
267 .as_ref()
268 .inspect(|h| h.abort());
269 }
270}
271
272impl Debug for QueryRunner {
273 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
274 let mut graph = Graph::<String, String>::new();
275 let mut stage_id_to_node_id = HashMap::new();
276 for stage in &self.stage_executions {
277 let node_id = graph.add_node(format!("{} {}", stage.0, block_on(stage.1.state())));
278 stage_id_to_node_id.insert(stage.0, node_id);
279 }
280
281 for stage in &self.stage_executions {
282 let stage_id = stage.0;
283 if let Some(child_stages) = self.query.stage_graph.get_child_stages(stage_id) {
284 for child_stage in child_stages {
285 graph.add_edge(
286 *stage_id_to_node_id.get(stage_id).unwrap(),
287 *stage_id_to_node_id.get(child_stage).unwrap(),
288 "".to_owned(),
289 );
290 }
291 }
292 }
293
294 writeln!(f, "{}", Dot::with_config(&graph, &[Config::EdgeNoLabel]))
296 }
297}
298
impl QueryRunner {
    /// Event loop of a query: kicks off the leaf stages, then processes
    /// `QueryMessage`s until the query completes, fails, or is cancelled.
    async fn run(mut self) {
        // Balanced by the `dec()` in `Drop for QueryRunner`.
        self.query_metrics.running_query_num.inc();
        // Leaf stages have no children, so they can start immediately;
        // parents are started as their children report `Scheduled`.
        let leaf_stages = self.query.leaf_stages();
        for stage_id in &leaf_stages {
            self.stage_executions[stage_id].start().await;
            tracing::trace!(
                "Query stage {:?}-{:?} started.",
                self.query.query_id,
                stage_id
            );
        }
        let mut stages_with_table_scan = self.query.stages_with_table_scan();
        let has_lookup_join_stage = self.query.has_lookup_join_stage();

        let mut finished_stage_cnt = 0usize;
        while let Some(msg_inner) = self.msg_receiver.recv().await {
            match msg_inner {
                Stage(Scheduled(stage_id)) => {
                    tracing::trace!(
                        "Query stage {:?}-{:?} scheduled.",
                        self.query.query_id,
                        stage_id
                    );
                    self.scheduled_stages_count += 1;
                    stages_with_table_scan.remove(&stage_id);
                    // Track when every table-scan (iterator-creating) stage
                    // has been scheduled; with a lookup join this milestone
                    // is not meaningful, hence the extra guard.
                    if !has_lookup_join_stage && stages_with_table_scan.is_empty() {
                        tracing::trace!(
                            "Query {:?} has scheduled all of its stages that have table scan (iterator creation).",
                            self.query.query_id
                        );
                    }

                    // A parent may start once all of its children are
                    // scheduled; `is_pending` prevents starting it twice
                    // (it may have several children reporting in).
                    for parent in self.query.get_parents(&stage_id) {
                        if self.all_children_scheduled(parent).await
                            && self.stage_executions[parent].is_pending().await
                        {
                            self.stage_executions[parent].start().await;
                        }
                    }
                }
                Stage(ScheduledRoot(receiver)) => {
                    // Root stage scheduled: hand the result-chunk receiver
                    // back to the caller blocked in `QueryExecution::start`.
                    self.send_root_stage_info(receiver);
                }
                Stage(StageEvent::Failed { id, reason }) => {
                    error!(
                        error = %reason.as_report(),
                        query_id = ?self.query.query_id,
                        stage_id = ?id,
                        "query stage failed"
                    );

                    // Propagate the failure to all stages and terminate.
                    self.clean_all_stages(Some(reason)).await;
                    break;
                }
                Stage(StageEvent::Completed(_)) => {
                    finished_stage_cnt += 1;
                    assert!(finished_stage_cnt <= self.stage_executions.len());
                    // Once every stage has completed, clean up and exit.
                    if finished_stage_cnt == self.stage_executions.len() {
                        tracing::trace!(
                            "Query {:?} completed, starting to clean stage tasks.",
                            &self.query.query_id
                        );
                        self.clean_all_stages(None).await;
                        break;
                    }
                }
                QueryMessage::CancelQuery(reason) => {
                    self.clean_all_stages(Some(SchedulerError::QueryCancelled(reason)))
                        .await;
                    break;
                }
            }
        }
    }

    /// Builds the `QueryResultFetcher` for the root task and sends it to the
    /// caller waiting in `QueryExecution::start`.
    ///
    /// # Panics
    /// The one-shot `root_stage_sender` is consumed here, so this relies on
    /// `ScheduledRoot` being delivered at most once per query; a second call
    /// would panic on the `unwrap`.
    fn send_root_stage_info(&mut self, chunk_rx: Receiver<SchedulerResult<DataChunk>>) {
        let root_task_output_id = {
            let root_task_id_prost = PbTaskId {
                query_id: self.query.query_id.clone().id,
                stage_id: self.query.root_stage_id(),
                task_id: ROOT_TASK_ID,
            };

            PbTaskOutputId {
                task_id: Some(root_task_id_prost),
                output_id: ROOT_TASK_OUTPUT_ID,
            }
        };

        let root_stage_result = QueryResultFetcher::new(
            root_task_output_id,
            // NOTE(review): a default (empty) host address is passed here —
            // presumably unused because result chunks arrive via `chunk_rx`;
            // confirm against `QueryResultFetcher::new`.
            HostAddress::default(),
            chunk_rx,
            self.query.query_id.clone(),
            self.query_execution_info.clone(),
        );

        let root_stage_sender = mem::take(&mut self.root_stage_sender);

        // A send error only means the `start` caller has already gone away.
        if let Err(e) = root_stage_sender.unwrap().send(Ok(root_stage_result)) {
            warn!("Query execution dropped: {:?}", e);
        } else {
            debug!("Root stage for {:?} sent.", self.query.query_id);
        }
    }

    /// Returns `true` when every child stage of `stage_id` reports itself as
    /// scheduled.
    async fn all_children_scheduled(&self, stage_id: &StageId) -> bool {
        for child in self.query.stage_graph.get_child_stages_unchecked(stage_id) {
            if !self.stage_executions[child].is_scheduled().await {
                return false;
            }
        }
        true
    }

    /// Stops every stage of the query. If `error` is given, it is first
    /// delivered to the caller waiting in `start` (when not yet answered),
    /// and its rendered message is passed to each stage's `stop`.
    async fn clean_all_stages(&mut self, error: Option<SchedulerError>) {
        // Render the message before `error` is moved into the sender below.
        let error_msg = error.as_ref().map(|e| e.to_report_string());
        if let Some(reason) = error {
            let root_stage_sender = mem::take(&mut self.root_stage_sender);
            // `None` means the root-stage fetcher was already delivered, so
            // there is nobody waiting in `start` to notify here.
            if let Some(sender) = root_stage_sender {
                if let Err(e) = sender.send(Err(reason)) {
                    warn!("Query execution dropped: {:?}", e);
                } else {
                    debug!(
                        "Root stage failure event for {:?} sent.",
                        self.query.query_id
                    );
                }
            }
        }

        tracing::trace!("Cleaning stages in query [{:?}]", self.query.query_id);
        for stage_execution in self.stage_executions.values() {
            stage_execution.stop(error_msg.clone()).await;
        }
    }
}
466
#[cfg(test)]
pub(crate) mod tests {
    use std::collections::{HashMap, HashSet};
    use std::sync::{Arc, RwLock};

    use fixedbitset::FixedBitSet;
    use risingwave_batch::worker_manager::worker_node_manager::{
        WorkerNodeManager, WorkerNodeSelector,
    };
    use risingwave_common::catalog::{
        ColumnCatalog, ColumnDesc, ConflictBehavior, CreateType, DEFAULT_SUPER_USER_ID, Engine,
        StreamJobStatus,
    };
    use risingwave_common::hash::{VirtualNode, VnodeCount, WorkerSlotId, WorkerSlotMapping};
    use risingwave_common::types::DataType;
    use risingwave_pb::common::worker_node::Property;
    use risingwave_pb::common::{HostAddress, WorkerNode, WorkerType};
    use risingwave_pb::plan_common::JoinType;
    use risingwave_rpc_client::ComputeClientPool;

    use crate::TableCatalog;
    use crate::catalog::catalog_service::CatalogReader;
    use crate::catalog::root_catalog::Catalog;
    use crate::catalog::table_catalog::TableType;
    use crate::expr::InputRef;
    use crate::optimizer::plan_node::{
        BatchExchange, BatchFilter, BatchHashJoin, EqJoinPredicate, LogicalScan, ToBatch, generic,
    };
    use crate::optimizer::property::{Cardinality, Distribution, Order};
    use crate::optimizer::{OptimizerContext, PlanRef};
    use crate::scheduler::distributed::QueryExecution;
    use crate::scheduler::plan_fragmenter::{BatchPlanFragmenter, Query};
    use crate::scheduler::{
        DistributedQueryMetrics, ExecutionContext, QueryExecutionInfo, ReadSnapshot,
    };
    use crate::session::SessionImpl;
    use crate::utils::Condition;

    /// Regression test: with zero compute workers registered, starting a
    /// query must return an error instead of hanging while waiting for task
    /// placement.
    #[tokio::test]
    async fn test_query_should_not_hang_with_empty_worker() {
        let worker_node_manager = Arc::new(WorkerNodeManager::mock(vec![]));
        let worker_node_selector = WorkerNodeSelector::new(worker_node_manager.clone(), false);
        let compute_client_pool = Arc::new(ComputeClientPool::for_test());
        let catalog_reader =
            CatalogReader::new(Arc::new(parking_lot::RwLock::new(Catalog::default())));
        let query = create_query().await;
        let query_id = query.query_id().clone();
        let query_execution = Arc::new(QueryExecution::new(query, (0, 0), None));
        let query_execution_info = Arc::new(RwLock::new(QueryExecutionInfo::new_from_map(
            HashMap::from([(query_id, query_execution.clone())]),
        )));

        assert!(
            query_execution
                .start(
                    ExecutionContext::new(SessionImpl::mock().into(), None).into(),
                    worker_node_selector,
                    ReadSnapshot::ReadUncommitted
                        .batch_query_epoch(&HashSet::from_iter([0.into()]))
                        .unwrap(),
                    compute_client_pool,
                    catalog_reader,
                    query_execution_info,
                    Arc::new(DistributedQueryMetrics::for_test()),
                )
                .await
                .is_err()
        );
    }

    /// Builds a fragmented `Query` for testing: two scans of the same mock
    /// table (one of them behind a trivial filter) are hash-joined and
    /// gathered through a single-distribution exchange, then fragmented
    /// against a mocked three-worker cluster.
    pub async fn create_query() -> Query {
        let ctx = OptimizerContext::mock().await;
        let table_id = 0.into();
        let vnode_count = VirtualNode::COUNT_FOR_TEST;

        // Minimal three-column table catalog backing the scans.
        let table_catalog: TableCatalog = TableCatalog {
            id: table_id,
            schema_id: 0,
            database_id: 0,
            associated_source_id: None,
            name: "test".to_owned(),
            dependent_relations: vec![],
            columns: vec![
                ColumnCatalog {
                    column_desc: ColumnDesc::named("a", 0.into(), DataType::Int32),
                    is_hidden: false,
                },
                ColumnCatalog {
                    column_desc: ColumnDesc::named("b", 1.into(), DataType::Float64),
                    is_hidden: false,
                },
                ColumnCatalog {
                    column_desc: ColumnDesc::named("c", 2.into(), DataType::Int64),
                    is_hidden: false,
                },
            ],
            pk: vec![],
            stream_key: vec![],
            table_type: TableType::Table,
            distribution_key: vec![],
            append_only: false,
            owner: DEFAULT_SUPER_USER_ID,
            retention_seconds: None,
            fragment_id: 0,
            dml_fragment_id: None,
            vnode_col_index: None,
            row_id_index: None,
            value_indices: vec![0, 1, 2],
            definition: "".to_owned(),
            conflict_behavior: ConflictBehavior::NoCheck,
            version_column_index: None,
            read_prefix_len_hint: 0,
            version: None,
            watermark_columns: FixedBitSet::with_capacity(3),
            dist_key_in_pk: vec![],
            cardinality: Cardinality::unknown(),
            cleaned_by_watermark: false,
            created_at_epoch: None,
            initialized_at_epoch: None,
            stream_job_status: StreamJobStatus::Creating,
            create_type: CreateType::Foreground,
            description: None,
            incoming_sinks: vec![],
            initialized_at_cluster_version: None,
            created_at_cluster_version: None,
            cdc_table_id: None,
            vnode_count: VnodeCount::set(vnode_count),
            webhook_info: None,
            job_id: None,
            engine: Engine::Hummock,
            clean_watermark_index_in_pk: None,
        };
        // Distributed batch scan of the table above.
        let batch_plan_node: PlanRef = LogicalScan::create(
            "".to_owned(),
            table_catalog.into(),
            vec![],
            ctx,
            None,
            Cardinality::unknown(),
        )
        .to_batch()
        .unwrap()
        .to_distributed()
        .unwrap();
        // Trivial filter (no conjunctions) over the same scan node.
        let batch_filter = BatchFilter::new(generic::Filter::new(
            Condition {
                conjunctions: vec![],
            },
            batch_plan_node.clone(),
        ))
        .into();
        // Hash-shard exchanges feeding both join sides.
        let batch_exchange_node1: PlanRef = BatchExchange::new(
            batch_plan_node.clone(),
            Order::default(),
            Distribution::HashShard(vec![0, 1]),
        )
        .into();
        let batch_exchange_node2: PlanRef = BatchExchange::new(
            batch_filter,
            Order::default(),
            Distribution::HashShard(vec![0, 1]),
        )
        .into();
        let logical_join_node = generic::Join::with_full_output(
            batch_exchange_node1.clone(),
            batch_exchange_node2.clone(),
            JoinType::Inner,
            Condition::true_cond(),
        );
        // Equi-join keys: output column 0 = column 2, column 1 = column 3.
        let eq_key_1 = (
            InputRef {
                index: 0,
                data_type: DataType::Int32,
            },
            InputRef {
                index: 2,
                data_type: DataType::Int32,
            },
            false,
        );
        let eq_key_2 = (
            InputRef {
                index: 1,
                data_type: DataType::Float64,
            },
            InputRef {
                index: 3,
                data_type: DataType::Float64,
            },
            false,
        );
        let eq_join_predicate =
            EqJoinPredicate::new(Condition::true_cond(), vec![eq_key_1, eq_key_2], 2, 2);
        let hash_join_node: PlanRef =
            BatchHashJoin::new(logical_join_node, eq_join_predicate, None).into();
        // Gather the join result on a single node (the query root).
        let batch_exchange_node: PlanRef = BatchExchange::new(
            hash_join_node.clone(),
            Order::default(),
            Distribution::Single,
        )
        .into();

        // Three mock compute nodes on distinct ports of localhost.
        let worker1 = WorkerNode {
            id: 0,
            r#type: WorkerType::ComputeNode as i32,
            host: Some(HostAddress {
                host: "127.0.0.1".to_owned(),
                port: 5687,
            }),
            state: risingwave_pb::common::worker_node::State::Running as i32,
            property: Some(Property {
                parallelism: 8,
                is_unschedulable: false,
                is_serving: true,
                is_streaming: true,
                ..Default::default()
            }),
            transactional_id: Some(0),
            ..Default::default()
        };
        let worker2 = WorkerNode {
            id: 1,
            r#type: WorkerType::ComputeNode as i32,
            host: Some(HostAddress {
                host: "127.0.0.1".to_owned(),
                port: 5688,
            }),
            state: risingwave_pb::common::worker_node::State::Running as i32,
            property: Some(Property {
                parallelism: 8,
                is_unschedulable: false,
                is_serving: true,
                is_streaming: true,
                ..Default::default()
            }),
            transactional_id: Some(1),
            ..Default::default()
        };
        let worker3 = WorkerNode {
            id: 2,
            r#type: WorkerType::ComputeNode as i32,
            host: Some(HostAddress {
                host: "127.0.0.1".to_owned(),
                port: 5689,
            }),
            state: risingwave_pb::common::worker_node::State::Running as i32,
            property: Some(Property {
                parallelism: 8,
                is_unschedulable: false,
                is_serving: true,
                is_streaming: true,
                ..Default::default()
            }),
            transactional_id: Some(2),
            ..Default::default()
        };
        let workers = vec![worker1, worker2, worker3];
        let worker_node_manager = Arc::new(WorkerNodeManager::mock(workers));
        let worker_node_selector = WorkerNodeSelector::new(worker_node_manager.clone(), false);
        // Register a uniform vnode mapping for fragment 0 on worker slot (0, 0).
        let mapping =
            WorkerSlotMapping::new_uniform(std::iter::once(WorkerSlotId::new(0, 0)), vnode_count);
        worker_node_manager.insert_streaming_fragment_mapping(0, mapping.clone());
        worker_node_manager.set_serving_fragment_mapping(vec![(0, mapping)].into_iter().collect());
        let catalog = Arc::new(parking_lot::RwLock::new(Catalog::default()));
        catalog.write().insert_table_id_mapping(table_id, 0);
        let catalog_reader = CatalogReader::new(catalog);
        // Fragment the plan into a distributed `Query`.
        let fragmenter = BatchPlanFragmenter::new(
            worker_node_selector,
            catalog_reader,
            None,
            "UTC".to_owned(),
            batch_exchange_node.clone(),
        )
        .unwrap();
        fragmenter.generate_complete_query().await.unwrap()
    }
}