use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::mem;
use std::sync::Arc;

use anyhow::Context;
use futures::executor::block_on;
use petgraph::Graph;
use petgraph::dot::{Config, Dot};
use pgwire::pg_server::SessionId;
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::array::DataChunk;
use risingwave_pb::batch_plan::{TaskId as PbTaskId, TaskOutputId as PbTaskOutputId};
use risingwave_pb::common::{BatchQueryEpoch, HostAddress};
use risingwave_rpc_client::ComputeClientPoolRef;
use thiserror_ext::AsReport;
use tokio::sync::mpsc::{Receiver, Sender, channel};
use tokio::sync::{RwLock, oneshot};
use tokio::task::JoinHandle;
use tracing::{Instrument, debug, error, info, warn};

use super::{DistributedQueryMetrics, QueryExecutionInfoRef, QueryResultFetcher, StageEvent};
use crate::catalog::catalog_service::CatalogReader;
use crate::scheduler::distributed::StageEvent::Scheduled;
use crate::scheduler::distributed::StageExecution;
use crate::scheduler::distributed::query::QueryMessage::Stage;
use crate::scheduler::distributed::stage::StageEvent::ScheduledRoot;
use crate::scheduler::plan_fragmenter::{Query, ROOT_TASK_ID, ROOT_TASK_OUTPUT_ID, StageId};
use crate::scheduler::{ExecutionContextRef, SchedulerError, SchedulerResult};

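/// Messages sent to the query runner: either a status event reported by one of the stage
/// executions, or a cancellation request carrying the reason as a string.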
#[derive(Debug)]
pub enum QueryMessage {
    Stage(StageEvent),
    CancelQuery(String),
}

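/// Lifecycle of a `QueryExecution`, protected by the `state` lock.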
enum QueryState {
    /// Not started yet. The message receiver is parked here until `start` is called.
    Pending {
        msg_receiver: Receiver<QueryMessage>,
    },

    /// The query runner has been spawned and is processing stage events.
    Running,

    /// Placeholder installed while leaving `Pending`; it remains only if startup does not
    /// complete.
    Failed,
}

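/// Handle of a single distributed batch query. It owns the shared `Query` plan, tracks the
/// current `QueryState`, and keeps the sender used to push `QueryMessage`s to the runner.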
pub struct QueryExecution {
    query: Arc<Query>,
    state: RwLock<QueryState>,
    shutdown_tx: Sender<QueryMessage>,
    pub session_id: SessionId,
    /// Held for the lifetime of the execution; dropping it releases a slot in the semaphore it
    /// was acquired from.
    #[expect(dead_code)]
    pub permit: Option<tokio::sync::OwnedSemaphorePermit>,
}

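/// Task that drives a query to completion: it starts the leaf stages, reacts to `StageEvent`s
/// from the message channel, and hands the root stage's result fetcher back to the caller of
/// `QueryExecution::start`.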
struct QueryRunner {
    query: Arc<Query>,
    stage_executions: HashMap<StageId, Arc<StageExecution>>,
    scheduled_stages_count: usize,
    /// Receives status events from the stage executions as well as cancellation requests.
    msg_receiver: Receiver<QueryMessage>,

    /// Sends the root stage's result fetcher (or a scheduling error) back to the caller of
    /// `QueryExecution::start`. Consumed once it has been used.
    root_stage_sender: Option<oneshot::Sender<SchedulerResult<QueryResultFetcher>>>,

    query_execution_info: QueryExecutionInfoRef,

    query_metrics: Arc<DistributedQueryMetrics>,
    /// Handle of the timeout watchdog task, aborted when the runner is dropped.
    timeout_abort_task_handle: Option<JoinHandle<()>>,
}

impl QueryExecution {
    pub fn new(
        query: Query,
        session_id: SessionId,
        permit: Option<tokio::sync::OwnedSemaphorePermit>,
    ) -> Self {
        let query = Arc::new(query);
        let (sender, receiver) = channel(100);
        let state = QueryState::Pending {
            msg_receiver: receiver,
        };

        Self {
            query,
            state: RwLock::new(state),
            shutdown_tx: sender,
            session_id,
            permit,
        }
    }

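    /// Start the query: create a `StageExecution` for every stage, spawn the `QueryRunner`
    /// (plus an optional timeout watchdog), and wait for the root stage to report back with a
    /// `QueryResultFetcher` that can be used to pull the result chunks.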
    #[allow(clippy::too_many_arguments)]
    pub async fn start(
        self: Arc<Self>,
        context: ExecutionContextRef,
        worker_node_manager: WorkerNodeSelector,
        batch_query_epoch: BatchQueryEpoch,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
        query_execution_info: QueryExecutionInfoRef,
        query_metrics: Arc<DistributedQueryMetrics>,
    ) -> SchedulerResult<QueryResultFetcher> {
        let mut state = self.state.write().await;
        let cur_state = mem::replace(&mut *state, QueryState::Failed);

        let stage_executions = self.gen_stage_executions(
            batch_query_epoch,
            context.clone(),
            worker_node_manager,
            compute_client_pool.clone(),
            catalog_reader,
        );

        match cur_state {
            QueryState::Pending { msg_receiver } => {
                *state = QueryState::Running;

                // Spawn a watchdog task that cancels the query once the configured timeout
                // elapses.
                let mut timeout_abort_task_handle: Option<JoinHandle<()>> = None;
                if let Some(timeout) = context.timeout() {
                    let this = self.clone();
                    timeout_abort_task_handle = Some(tokio::spawn(async move {
                        tokio::time::sleep(timeout).await;
                        warn!(
                            "Query {:?} timed out after {} seconds, sending cancel message.",
                            this.query.query_id,
                            timeout.as_secs(),
                        );
                        this.abort(format!("timeout after {} seconds", timeout.as_secs()))
                            .await;
                    }));
                }

                let (root_stage_sender, root_stage_receiver) =
                    oneshot::channel::<SchedulerResult<QueryResultFetcher>>();

                let runner = QueryRunner {
                    query: self.query.clone(),
                    stage_executions,
                    msg_receiver,
                    root_stage_sender: Some(root_stage_sender),
                    scheduled_stages_count: 0,
                    query_execution_info,
                    query_metrics,
                    timeout_abort_task_handle,
                };

                let span = tracing::info_span!(
                    "distributed_execute",
                    query_id = self.query.query_id.id,
                    epoch = ?batch_query_epoch,
                );

                tracing::trace!("Starting query: {:?}", self.query.query_id);

                tokio::spawn(async move { runner.run().instrument(span).await });

                let root_stage = root_stage_receiver
                    .await
                    .context("Starting query execution failed")??;

                tracing::trace!(
                    "Received root stage query result fetcher: {:?}, query id: {:?}",
                    root_stage,
                    self.query.query_id
                );

                tracing::trace!("Query {:?} started.", self.query.query_id);
                Ok(root_stage)
            }
            _ => {
                unreachable!("The query runner should not be scheduled twice");
            }
        }
    }

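    /// Cancel the query by sending a `CancelQuery` message to the runner. If the runner has
    /// already exited, this only logs a warning.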
    pub async fn abort(self: Arc<Self>, reason: String) {
        if self
            .shutdown_tx
            .send(QueryMessage::CancelQuery(reason))
            .await
            .is_err()
        {
            warn!("Send cancel query request failed: the query has ended");
        } else {
            info!("Send cancel request to query-{:?}", self.query.query_id);
        };
    }

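    /// Create a `StageExecution` for every stage of the plan, wiring each one to its already
    /// created children (stages are visited in topological order, children before parents).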
    fn gen_stage_executions(
        &self,
        epoch: BatchQueryEpoch,
        context: ExecutionContextRef,
        worker_node_manager: WorkerNodeSelector,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
    ) -> HashMap<StageId, Arc<StageExecution>> {
        let mut stage_executions: HashMap<StageId, Arc<StageExecution>> =
            HashMap::with_capacity(self.query.stage_graph.stages.len());

        for stage_id in self.query.stage_graph.stage_ids_by_topo_order() {
            let children_stages = self
                .query
                .stage_graph
                .get_child_stages_unchecked(&stage_id)
                .iter()
                .map(|s| stage_executions[s].clone())
                .collect::<Vec<Arc<StageExecution>>>();

            let stage_exec = Arc::new(StageExecution::new(
                epoch,
                self.query.stage_graph.stages[&stage_id].clone(),
                worker_node_manager.clone(),
                self.shutdown_tx.clone(),
                children_stages,
                compute_client_pool.clone(),
                catalog_reader.clone(),
                context.clone(),
            ));
            stage_executions.insert(stage_id, stage_exec);
        }
        stage_executions
    }
}

impl Drop for QueryRunner {
    fn drop(&mut self) {
        self.query_metrics.running_query_num.dec();
        // Make sure the timeout watchdog does not outlive the runner.
        self.timeout_abort_task_handle
            .as_ref()
            .inspect(|h| h.abort());
    }
}

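/// Renders the stage DAG, annotated with each stage's current state, in Graphviz DOT format.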
impl Debug for QueryRunner {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut graph = Graph::<String, String>::new();
        let mut stage_id_to_node_id = HashMap::new();
        for stage in &self.stage_executions {
            let node_id = graph.add_node(format!("{} {}", stage.0, block_on(stage.1.state())));
            stage_id_to_node_id.insert(stage.0, node_id);
        }

        for stage in &self.stage_executions {
            let stage_id = stage.0;
            if let Some(child_stages) = self.query.stage_graph.get_child_stages(stage_id) {
                for child_stage in child_stages {
                    graph.add_edge(
                        *stage_id_to_node_id.get(stage_id).unwrap(),
                        *stage_id_to_node_id.get(child_stage).unwrap(),
                        "".to_owned(),
                    );
                }
            }
        }

        writeln!(f, "{}", Dot::with_config(&graph, &[Config::EdgeNoLabel]))
    }
}

impl QueryRunner {
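    /// Main loop of the runner: start all leaf stages, then handle stage events until every
    /// stage completes, a stage fails, or the query is cancelled.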
    async fn run(mut self) {
        self.query_metrics.running_query_num.inc();
        let leaf_stages = self.query.leaf_stages();
        for stage_id in &leaf_stages {
            self.stage_executions[stage_id].start().await;
            tracing::trace!(
                "Query stage {:?}-{:?} started.",
                self.query.query_id,
                stage_id
            );
        }
        let mut stages_with_table_scan = self.query.stages_with_table_scan();
        let has_lookup_join_stage = self.query.has_lookup_join_stage();

        let mut finished_stage_cnt = 0usize;
        while let Some(msg_inner) = self.msg_receiver.recv().await {
            match msg_inner {
                Stage(Scheduled(stage_id)) => {
                    tracing::trace!(
                        "Query stage {:?}-{:?} scheduled.",
                        self.query.query_id,
                        stage_id
                    );
                    self.scheduled_stages_count += 1;
                    stages_with_table_scan.remove(&stage_id);
                    if !has_lookup_join_stage && stages_with_table_scan.is_empty() {
                        tracing::trace!(
                            "Query {:?} has scheduled all of its stages that have table scan (iterator creation).",
                            self.query.query_id
                        );
                    }

                    // Start a parent stage once all of its children have been scheduled.
                    for parent in self.query.get_parents(&stage_id) {
                        if self.all_children_scheduled(parent).await
                            && self.stage_executions[parent].is_pending().await
                        {
                            self.stage_executions[parent].start().await;
                        }
                    }
                }
                Stage(ScheduledRoot(receiver)) => {
                    // The root stage is scheduled; hand its chunk receiver back to the caller.
                    self.send_root_stage_info(receiver);
                }
                Stage(StageEvent::Failed { id, reason }) => {
                    error!(
                        error = %reason.as_report(),
                        query_id = ?self.query.query_id,
                        stage_id = ?id,
                        "query stage failed"
                    );

                    self.clean_all_stages(Some(reason)).await;
                    break;
                }
                Stage(StageEvent::Completed(_)) => {
                    finished_stage_cnt += 1;
                    assert!(finished_stage_cnt <= self.stage_executions.len());
                    if finished_stage_cnt == self.stage_executions.len() {
                        tracing::trace!(
                            "Query {:?} completed, starting to clean stage tasks.",
                            &self.query.query_id
                        );
                        self.clean_all_stages(None).await;
                        break;
                    }
                }
                QueryMessage::CancelQuery(reason) => {
                    self.clean_all_stages(Some(SchedulerError::QueryCancelled(reason)))
                        .await;
                    break;
                }
            }
        }
    }

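    /// Build the `QueryResultFetcher` for the root task's output and deliver it through
    /// `root_stage_sender`, unblocking the caller of `QueryExecution::start`.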
    fn send_root_stage_info(&mut self, chunk_rx: Receiver<SchedulerResult<DataChunk>>) {
        let root_task_output_id = {
            let root_task_id_prost = PbTaskId {
                query_id: self.query.query_id.clone().id,
                stage_id: self.query.root_stage_id(),
                task_id: ROOT_TASK_ID,
            };

            PbTaskOutputId {
                task_id: Some(root_task_id_prost),
                output_id: ROOT_TASK_OUTPUT_ID,
            }
        };

        let root_stage_result = QueryResultFetcher::new(
            root_task_output_id,
            // Result chunks arrive through `chunk_rx` rather than over RPC, so a placeholder
            // host address suffices here.
            HostAddress::default(),
            chunk_rx,
            self.query.query_id.clone(),
            self.query_execution_info.clone(),
        );

        // Consume the sender; it stays `None` afterwards.
        let root_stage_sender = mem::take(&mut self.root_stage_sender);

        if let Err(e) = root_stage_sender.unwrap().send(Ok(root_stage_result)) {
            warn!("Query execution dropped: {:?}", e);
        } else {
            debug!("Root stage for {:?} sent.", self.query.query_id);
        }
    }

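    /// Returns `true` once every child stage of `stage_id` has been scheduled.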
    async fn all_children_scheduled(&self, stage_id: &StageId) -> bool {
        for child in self.query.stage_graph.get_child_stages_unchecked(stage_id) {
            if !self.stage_executions[child].is_scheduled().await {
                return false;
            }
        }
        true
    }

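    /// Stop all stage executions. If `error` is given, it is first reported to the caller of
    /// `start` (when the root stage has not been delivered yet) and then propagated to the
    /// stages as the stop reason.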
    async fn clean_all_stages(&mut self, error: Option<SchedulerError>) {
        let error_msg = error.as_ref().map(|e| e.to_report_string());
        if let Some(reason) = error {
            // If the caller of `start` is still waiting on the root stage, deliver the failure
            // to it directly.
            let root_stage_sender = mem::take(&mut self.root_stage_sender);
            if let Some(sender) = root_stage_sender {
                if let Err(e) = sender.send(Err(reason)) {
                    warn!("Query execution dropped: {:?}", e);
                } else {
                    debug!(
                        "Root stage failure event for {:?} sent.",
                        self.query.query_id
                    );
                }
            }
        }

        tracing::trace!("Cleaning stages in query [{:?}]", self.query.query_id);
        for stage_execution in self.stage_executions.values() {
            // Stop every stage, passing the error message (if any) as the stop reason.
            stage_execution.stop(error_msg.clone()).await;
        }
    }
}

#[cfg(test)]
pub(crate) mod tests {
    use std::collections::{HashMap, HashSet};
    use std::sync::{Arc, RwLock};

    use fixedbitset::FixedBitSet;
    use risingwave_batch::worker_manager::worker_node_manager::{
        WorkerNodeManager, WorkerNodeSelector,
    };
    use risingwave_common::catalog::{
        ColumnCatalog, ColumnDesc, ConflictBehavior, CreateType, DEFAULT_SUPER_USER_ID, Engine,
        StreamJobStatus,
    };
    use risingwave_common::hash::{VirtualNode, VnodeCount, WorkerSlotId, WorkerSlotMapping};
    use risingwave_common::types::DataType;
    use risingwave_pb::common::worker_node::Property;
    use risingwave_pb::common::{HostAddress, WorkerNode, WorkerType};
    use risingwave_pb::plan_common::JoinType;
    use risingwave_rpc_client::ComputeClientPool;

    use crate::TableCatalog;
    use crate::catalog::catalog_service::CatalogReader;
    use crate::catalog::root_catalog::Catalog;
    use crate::catalog::table_catalog::TableType;
    use crate::expr::InputRef;
    use crate::optimizer::plan_node::{
        BatchExchange, BatchFilter, BatchHashJoin, EqJoinPredicate, LogicalScan, ToBatch, generic,
    };
    use crate::optimizer::property::{Cardinality, Distribution, Order};
    use crate::optimizer::{OptimizerContext, PlanRef};
    use crate::scheduler::distributed::QueryExecution;
    use crate::scheduler::plan_fragmenter::{BatchPlanFragmenter, Query};
    use crate::scheduler::{
        DistributedQueryMetrics, ExecutionContext, QueryExecutionInfo, ReadSnapshot,
    };
    use crate::session::SessionImpl;
    use crate::utils::Condition;

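    /// With no compute nodes registered, starting a distributed query must fail promptly
    /// instead of hanging while waiting for workers.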
    #[tokio::test]
    async fn test_query_should_not_hang_with_empty_worker() {
        let worker_node_manager = Arc::new(WorkerNodeManager::mock(vec![]));
        let worker_node_selector = WorkerNodeSelector::new(worker_node_manager.clone(), false);
        let compute_client_pool = Arc::new(ComputeClientPool::for_test());
        let catalog_reader =
            CatalogReader::new(Arc::new(parking_lot::RwLock::new(Catalog::default())));
        let query = create_query().await;
        let query_id = query.query_id().clone();
        let query_execution = Arc::new(QueryExecution::new(query, (0, 0), None));
        let query_execution_info = Arc::new(RwLock::new(QueryExecutionInfo::new_from_map(
            HashMap::from([(query_id, query_execution.clone())]),
        )));

        assert!(
            query_execution
                .start(
                    ExecutionContext::new(SessionImpl::mock().into(), None).into(),
                    worker_node_selector,
                    ReadSnapshot::ReadUncommitted
                        .batch_query_epoch(&HashSet::from_iter([0.into()]))
                        .unwrap(),
                    compute_client_pool,
                    catalog_reader,
                    query_execution_info,
                    Arc::new(DistributedQueryMetrics::for_test()),
                )
                .await
                .is_err()
        );
    }

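    /// Builds a mock `Query`: a hash join of two exchanges over the test table (one side behind
    /// an empty filter), topped by a single-distribution exchange, then run through the
    /// `BatchPlanFragmenter`.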
    pub async fn create_query() -> Query {
        let ctx = OptimizerContext::mock().await;
        let table_id = 0.into();
        let vnode_count = VirtualNode::COUNT_FOR_TEST;

        let table_catalog: TableCatalog = TableCatalog {
            id: table_id,
            schema_id: 0,
            database_id: 0,
            associated_source_id: None,
            name: "test".to_owned(),
            columns: vec![
                ColumnCatalog {
                    column_desc: ColumnDesc::named("a", 0.into(), DataType::Int32),
                    is_hidden: false,
                },
                ColumnCatalog {
                    column_desc: ColumnDesc::named("b", 1.into(), DataType::Float64),
                    is_hidden: false,
                },
                ColumnCatalog {
                    column_desc: ColumnDesc::named("c", 2.into(), DataType::Int64),
                    is_hidden: false,
                },
            ],
            pk: vec![],
            stream_key: vec![],
            table_type: TableType::Table,
            distribution_key: vec![],
            append_only: false,
            owner: DEFAULT_SUPER_USER_ID,
            retention_seconds: None,
            fragment_id: 0,
            dml_fragment_id: None,
            vnode_col_index: None,
            row_id_index: None,
            value_indices: vec![0, 1, 2],
            definition: "".to_owned(),
            conflict_behavior: ConflictBehavior::NoCheck,
            version_column_index: None,
            read_prefix_len_hint: 0,
            version: None,
            watermark_columns: FixedBitSet::with_capacity(3),
            dist_key_in_pk: vec![],
            cardinality: Cardinality::unknown(),
            cleaned_by_watermark: false,
            created_at_epoch: None,
            initialized_at_epoch: None,
            stream_job_status: StreamJobStatus::Creating,
            create_type: CreateType::Foreground,
            description: None,
            incoming_sinks: vec![],
            initialized_at_cluster_version: None,
            created_at_cluster_version: None,
            cdc_table_id: None,
            vnode_count: VnodeCount::set(vnode_count),
            webhook_info: None,
            job_id: None,
            engine: Engine::Hummock,
            clean_watermark_index_in_pk: None,
        };
        let batch_plan_node: PlanRef = LogicalScan::create(
            "".to_owned(),
            table_catalog.into(),
            vec![],
            ctx,
            None,
            Cardinality::unknown(),
        )
        .to_batch()
        .unwrap()
        .to_distributed()
        .unwrap();
        let batch_filter = BatchFilter::new(generic::Filter::new(
            Condition {
                conjunctions: vec![],
            },
            batch_plan_node.clone(),
        ))
        .into();
        let batch_exchange_node1: PlanRef = BatchExchange::new(
            batch_plan_node.clone(),
            Order::default(),
            Distribution::HashShard(vec![0, 1]),
        )
        .into();
        let batch_exchange_node2: PlanRef = BatchExchange::new(
            batch_filter,
            Order::default(),
            Distribution::HashShard(vec![0, 1]),
        )
        .into();
        let logical_join_node = generic::Join::with_full_output(
            batch_exchange_node1.clone(),
            batch_exchange_node2.clone(),
            JoinType::Inner,
            Condition::true_cond(),
        );
        let eq_key_1 = (
            InputRef {
                index: 0,
                data_type: DataType::Int32,
            },
            InputRef {
                index: 2,
                data_type: DataType::Int32,
            },
            false,
        );
        let eq_key_2 = (
            InputRef {
                index: 1,
                data_type: DataType::Float64,
            },
            InputRef {
                index: 3,
                data_type: DataType::Float64,
            },
            false,
        );
        let eq_join_predicate =
            EqJoinPredicate::new(Condition::true_cond(), vec![eq_key_1, eq_key_2], 2, 2);
        let hash_join_node: PlanRef =
            BatchHashJoin::new(logical_join_node, eq_join_predicate, None).into();
        let batch_exchange_node: PlanRef = BatchExchange::new(
            hash_join_node.clone(),
            Order::default(),
            Distribution::Single,
        )
        .into();

        let worker1 = WorkerNode {
            id: 0,
            r#type: WorkerType::ComputeNode as i32,
            host: Some(HostAddress {
                host: "127.0.0.1".to_owned(),
                port: 5687,
            }),
            state: risingwave_pb::common::worker_node::State::Running as i32,
            property: Some(Property {
                parallelism: 8,
                is_unschedulable: false,
                is_serving: true,
                is_streaming: true,
                ..Default::default()
            }),
            transactional_id: Some(0),
            ..Default::default()
        };
        let worker2 = WorkerNode {
            id: 1,
            r#type: WorkerType::ComputeNode as i32,
            host: Some(HostAddress {
                host: "127.0.0.1".to_owned(),
                port: 5688,
            }),
            state: risingwave_pb::common::worker_node::State::Running as i32,
            property: Some(Property {
                parallelism: 8,
                is_unschedulable: false,
                is_serving: true,
                is_streaming: true,
                ..Default::default()
            }),
            transactional_id: Some(1),
            ..Default::default()
        };
        let worker3 = WorkerNode {
            id: 2,
            r#type: WorkerType::ComputeNode as i32,
            host: Some(HostAddress {
                host: "127.0.0.1".to_owned(),
                port: 5689,
            }),
            state: risingwave_pb::common::worker_node::State::Running as i32,
            property: Some(Property {
                parallelism: 8,
                is_unschedulable: false,
                is_serving: true,
                is_streaming: true,
                ..Default::default()
            }),
            transactional_id: Some(2),
            ..Default::default()
        };
        let workers = vec![worker1, worker2, worker3];
        let worker_node_manager = Arc::new(WorkerNodeManager::mock(workers));
        let worker_node_selector = WorkerNodeSelector::new(worker_node_manager.clone(), false);
        let mapping =
            WorkerSlotMapping::new_uniform(std::iter::once(WorkerSlotId::new(0, 0)), vnode_count);
        worker_node_manager.insert_streaming_fragment_mapping(0, mapping.clone());
        worker_node_manager.set_serving_fragment_mapping(vec![(0, mapping)].into_iter().collect());
        let catalog = Arc::new(parking_lot::RwLock::new(Catalog::default()));
        catalog.write().insert_table_id_mapping(table_id, 0);
        let catalog_reader = CatalogReader::new(catalog);
        let fragmenter = BatchPlanFragmenter::new(
            worker_node_selector,
            catalog_reader,
            None,
            "UTC".to_owned(),
            batch_exchange_node.clone(),
        )
        .unwrap();
        fragmenter.generate_complete_query().await.unwrap()
    }
}