use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::mem;
use std::sync::Arc;

use anyhow::Context;
use futures::executor::block_on;
use petgraph::Graph;
use petgraph::dot::{Config, Dot};
use pgwire::pg_server::SessionId;
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::array::DataChunk;
use risingwave_pb::batch_plan::{TaskId as PbTaskId, TaskOutputId as PbTaskOutputId};
use risingwave_pb::common::{BatchQueryEpoch, HostAddress};
use risingwave_rpc_client::ComputeClientPoolRef;
use thiserror_ext::AsReport;
use tokio::sync::mpsc::{Receiver, Sender, channel};
use tokio::sync::{RwLock, oneshot};
use tokio::task::JoinHandle;
use tracing::{Instrument, debug, error, info, warn};

use super::{DistributedQueryMetrics, QueryExecutionInfoRef, QueryResultFetcher, StageEvent};
use crate::catalog::catalog_service::CatalogReader;
use crate::scheduler::distributed::StageEvent::Scheduled;
use crate::scheduler::distributed::StageExecution;
use crate::scheduler::distributed::query::QueryMessage::Stage;
use crate::scheduler::distributed::stage::StageEvent::ScheduledRoot;
use crate::scheduler::plan_fragmenter::{Query, ROOT_TASK_ID, ROOT_TASK_OUTPUT_ID, StageId};
use crate::scheduler::{ExecutionContextRef, SchedulerError, SchedulerResult};

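/// Messages sent to the `QueryRunner` to drive or cancel a running query.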
#[derive(Debug)]
pub enum QueryMessage {
    /// A state-change event reported by one of the query's stages.
    Stage(StageEvent),
    /// Cancel the query, carrying the reason for cancellation.
    CancelQuery(String),
}

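/// Lifecycle state of a `QueryExecution`.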
enum QueryState {
    /// Not scheduled yet. The message receiver is parked here until `start` moves it into the
    /// spawned `QueryRunner`.
    Pending {
        msg_receiver: Receiver<QueryMessage>,
    },

    Running,

    Failed,
}

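/// Handle of a distributed batch query. It tracks the query's lifecycle state and holds the
/// sender used by `abort` to cancel the spawned `QueryRunner`.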
pub struct QueryExecution {
    query: Arc<Query>,
    state: RwLock<QueryState>,
    shutdown_tx: Sender<QueryMessage>,
    /// Id of the session that issued this query.
    pub session_id: SessionId,
    /// Concurrency permit for this query, if any; held only so that it is released when the
    /// execution is dropped.
    #[expect(dead_code)]
    pub permit: Option<tokio::sync::OwnedSemaphorePermit>,
}

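/// Drives the actual execution of a query: starts the leaf stages first, starts each parent stage
/// once all of its children are scheduled, and reports the root stage (or the first error) back
/// to `QueryExecution::start`.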
struct QueryRunner {
    query: Arc<Query>,
    stage_executions: HashMap<StageId, Arc<StageExecution>>,
    scheduled_stages_count: usize,
    /// Receives stage events and cancel requests.
    msg_receiver: Receiver<QueryMessage>,

    /// Sends the root stage result fetcher (or an error) back to `QueryExecution::start`;
    /// consumed once that has happened.
    root_stage_sender: Option<oneshot::Sender<SchedulerResult<QueryResultFetcher>>>,

    query_execution_info: QueryExecutionInfoRef,

    query_metrics: Arc<DistributedQueryMetrics>,
    timeout_abort_task_handle: Option<JoinHandle<()>>,
}

impl QueryExecution {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        query: Query,
        session_id: SessionId,
        permit: Option<tokio::sync::OwnedSemaphorePermit>,
    ) -> Self {
        let query = Arc::new(query);
        let (sender, receiver) = channel(100);
        let state = QueryState::Pending {
            msg_receiver: receiver,
        };

        Self {
            query,
            state: RwLock::new(state),
            shutdown_tx: sender,
            session_id,
            permit,
        }
    }

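    /// Start execution of this query. Spawns a `QueryRunner` in the background and waits for it
    /// to report the root stage's `QueryResultFetcher`, or a scheduling error.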
    #[allow(clippy::too_many_arguments)]
    pub async fn start(
        self: Arc<Self>,
        context: ExecutionContextRef,
        worker_node_manager: WorkerNodeSelector,
        batch_query_epoch: BatchQueryEpoch,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
        query_execution_info: QueryExecutionInfoRef,
        query_metrics: Arc<DistributedQueryMetrics>,
    ) -> SchedulerResult<QueryResultFetcher> {
        let mut state = self.state.write().await;
        let cur_state = mem::replace(&mut *state, QueryState::Failed);

        let stage_executions = self.gen_stage_executions(
            batch_query_epoch,
            context.clone(),
            worker_node_manager,
            compute_client_pool.clone(),
            catalog_reader,
        );

        match cur_state {
            QueryState::Pending { msg_receiver } => {
                *state = QueryState::Running;

                // If a timeout is configured, spawn a task that cancels the query once it fires.
                let mut timeout_abort_task_handle: Option<JoinHandle<()>> = None;
                if let Some(timeout) = context.timeout() {
                    let this = self.clone();
                    timeout_abort_task_handle = Some(tokio::spawn(async move {
                        tokio::time::sleep(timeout).await;
                        warn!(
                            "Query {:?} timeout after {} seconds, sending cancel message.",
                            this.query.query_id,
                            timeout.as_secs(),
                        );
                        this.abort(format!("timeout after {} seconds", timeout.as_secs()))
                            .await;
                    }));
                }

                let (root_stage_sender, root_stage_receiver) =
                    oneshot::channel::<SchedulerResult<QueryResultFetcher>>();

                let runner = QueryRunner {
                    query: self.query.clone(),
                    stage_executions,
                    msg_receiver,
                    root_stage_sender: Some(root_stage_sender),
                    scheduled_stages_count: 0,
                    query_execution_info,
                    query_metrics,
                    timeout_abort_task_handle,
                };

                let span = tracing::info_span!(
                    "distributed_execute",
                    query_id = self.query.query_id.id,
                    epoch = ?batch_query_epoch,
                );

                tracing::trace!("Starting query: {:?}", self.query.query_id);

                // Drive the query in the background. The runner reports the root stage (or an
                // error) through `root_stage_sender`.
                tokio::spawn(async move { runner.run().instrument(span).await });

                let root_stage = root_stage_receiver
                    .await
                    .context("Starting query execution failed")??;

                tracing::trace!(
                    "Received root stage query result fetcher: {:?}, query id: {:?}",
                    root_stage,
                    self.query.query_id
                );

                tracing::trace!("Query {:?} started.", self.query.query_id);
                Ok(root_stage)
            }
            _ => {
                unreachable!("The query runner should not be scheduled twice");
            }
        }
    }

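    /// Cancel execution of this query by sending a `CancelQuery` message to the `QueryRunner`, if
    /// it is still running.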
    pub async fn abort(self: Arc<Self>, reason: String) {
        if self
            .shutdown_tx
            .send(QueryMessage::CancelQuery(reason))
            .await
            .is_err()
        {
            warn!("Send cancel query request failed: the query has ended");
        } else {
            info!("Send cancel request to query-{:?}", self.query.query_id);
        };
    }

    fn gen_stage_executions(
        &self,
        epoch: BatchQueryEpoch,
        context: ExecutionContextRef,
        worker_node_manager: WorkerNodeSelector,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
    ) -> HashMap<StageId, Arc<StageExecution>> {
        let mut stage_executions: HashMap<StageId, Arc<StageExecution>> =
            HashMap::with_capacity(self.query.stage_graph.stages.len());

        for stage_id in self.query.stage_graph.stage_ids_by_topo_order() {
            // Stages are visited in topological order, so child executions already exist here.
            let children_stages = self
                .query
                .stage_graph
                .get_child_stages_unchecked(&stage_id)
                .iter()
                .map(|s| stage_executions[s].clone())
                .collect::<Vec<Arc<StageExecution>>>();

            let stage_exec = Arc::new(StageExecution::new(
                epoch,
                self.query.stage_graph.stages[&stage_id].clone(),
                worker_node_manager.clone(),
                self.shutdown_tx.clone(),
                children_stages,
                compute_client_pool.clone(),
                catalog_reader.clone(),
                context.clone(),
            ));
            stage_executions.insert(stage_id, stage_exec);
        }
        stage_executions
    }
}

impl Drop for QueryRunner {
    fn drop(&mut self) {
        self.query_metrics.running_query_num.dec();
        // Abort the pending timeout task, if any, so it does not fire after the query has ended.
        self.timeout_abort_task_handle
            .as_ref()
            .inspect(|h| h.abort());
    }
}

impl Debug for QueryRunner {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Render the stage graph, annotated with each stage's current state, in Graphviz DOT format.
        let mut graph = Graph::<String, String>::new();
        let mut stage_id_to_node_id = HashMap::new();
        for stage in &self.stage_executions {
            let node_id = graph.add_node(format!("{} {}", stage.0, block_on(stage.1.state())));
            stage_id_to_node_id.insert(stage.0, node_id);
        }

        for stage in &self.stage_executions {
            let stage_id = stage.0;
            if let Some(child_stages) = self.query.stage_graph.get_child_stages(stage_id) {
                for child_stage in child_stages {
                    graph.add_edge(
                        *stage_id_to_node_id.get(stage_id).unwrap(),
                        *stage_id_to_node_id.get(child_stage).unwrap(),
                        "".to_owned(),
                    );
                }
            }
        }

        writeln!(f, "{}", Dot::with_config(&graph, &[Config::EdgeNoLabel]))
    }
}

impl QueryRunner {
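    /// Run the query to completion: start the leaf stages, then react to stage events until all
    /// stages complete, a stage fails, or the query is cancelled.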
    async fn run(mut self) {
        self.query_metrics.running_query_num.inc();
        // Start leaf stages first; parent stages are started as their children get scheduled.
        let leaf_stages = self.query.leaf_stages();
        for stage_id in &leaf_stages {
            self.stage_executions[stage_id].start().await;
            tracing::trace!(
                "Query stage {:?}-{:?} started.",
                self.query.query_id,
                stage_id
            );
        }
        let mut stages_with_table_scan = self.query.stages_with_table_scan();
        let has_lookup_join_stage = self.query.has_lookup_join_stage();

        let mut finished_stage_cnt = 0usize;
        while let Some(msg_inner) = self.msg_receiver.recv().await {
            match msg_inner {
                Stage(Scheduled(stage_id)) => {
                    tracing::trace!(
                        "Query stage {:?}-{:?} scheduled.",
                        self.query.query_id,
                        stage_id
                    );
                    self.scheduled_stages_count += 1;
                    stages_with_table_scan.remove(&stage_id);
                    if !has_lookup_join_stage && stages_with_table_scan.is_empty() {
                        // All stages that perform table scans (iterator creation) are scheduled.
                        tracing::trace!(
                            "Query {:?} has scheduled all of its stages that have table scan (iterator creation).",
                            self.query.query_id
                        );
                    }

                    // Start any parent stage whose children have all been scheduled.
                    for parent in self.query.get_parents(&stage_id) {
                        if self.all_children_scheduled(parent).await
                            // Do not schedule the same stage twice.
                            && self.stage_executions[parent].is_pending().await
                        {
                            self.stage_executions[parent].start().await;
                        }
                    }
                }
                Stage(ScheduledRoot(receiver)) => {
                    // The root stage is scheduled; notify the query result fetcher.
                    self.send_root_stage_info(receiver);
                }
                Stage(StageEvent::Failed { id, reason }) => {
                    error!(
                        error = %reason.as_report(),
                        query_id = ?self.query.query_id,
                        stage_id = ?id,
                        "query stage failed"
                    );

                    self.clean_all_stages(Some(reason)).await;
                    // One stage failed, so there is no need to schedule the remaining stages.
                    break;
                }
                Stage(StageEvent::Completed(_)) => {
                    finished_stage_cnt += 1;
                    assert!(finished_stage_cnt <= self.stage_executions.len());
                    if finished_stage_cnt == self.stage_executions.len() {
                        tracing::trace!(
                            "Query {:?} completed, starting to clean stage tasks.",
                            &self.query.query_id
                        );
                        self.clean_all_stages(None).await;
                        break;
                    }
                }
                QueryMessage::CancelQuery(reason) => {
                    self.clean_all_stages(Some(SchedulerError::QueryCancelled(reason)))
                        .await;
                    break;
                }
            }
        }
    }

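    /// Build the root stage's `QueryResultFetcher` and send it back to `QueryExecution::start`
    /// through the oneshot channel.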
    fn send_root_stage_info(&mut self, chunk_rx: Receiver<SchedulerResult<DataChunk>>) {
        let root_task_output_id = {
            let root_task_id_prost = PbTaskId {
                query_id: self.query.query_id.clone().id,
                stage_id: self.query.root_stage_id(),
                task_id: ROOT_TASK_ID,
            };

            PbTaskOutputId {
                task_id: Some(root_task_id_prost),
                output_id: ROOT_TASK_OUTPUT_ID,
            }
        };

        let root_stage_result = QueryResultFetcher::new(
            root_task_output_id,
            // The root stage executes locally, so no meaningful host address is needed here.
            HostAddress::default(),
            chunk_rx,
            self.query.query_id.clone(),
            self.query_execution_info.clone(),
        );

        // Consume the sender so the root stage is reported at most once.
        let root_stage_sender = mem::take(&mut self.root_stage_sender);

        if let Err(e) = root_stage_sender.unwrap().send(Ok(root_stage_result)) {
            warn!("Query execution dropped: {:?}", e);
        } else {
            debug!("Root stage for {:?} sent.", self.query.query_id);
        }
    }

    async fn all_children_scheduled(&self, stage_id: &StageId) -> bool {
        for child in self.query.stage_graph.get_child_stages_unchecked(stage_id) {
            if !self.stage_executions[child].is_scheduled().await {
                return false;
            }
        }
        true
    }

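    /// Stop all stages and, if an error is given and the root stage has not been reported yet,
    /// forward the error to the result fetcher. Called on completion, failure, and cancellation.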
    async fn clean_all_stages(&mut self, error: Option<SchedulerError>) {
        let error_msg = error.as_ref().map(|e| e.to_report_string());
        if let Some(reason) = error {
            // Consume the sender and forward the error to the root stage receiver.
            let root_stage_sender = mem::take(&mut self.root_stage_sender);
            // A failure may be reported multiple times; the sender may already have been consumed
            // by an earlier event.
            if let Some(sender) = root_stage_sender {
                if let Err(e) = sender.send(Err(reason)) {
                    warn!("Query execution dropped: {:?}", e);
                } else {
                    debug!(
                        "Root stage failure event for {:?} sent.",
                        self.query.query_id
                    );
                }
            }
        }

        tracing::trace!("Cleaning stages in query [{:?}]", self.query.query_id);
        // Stop all running stages.
        for stage_execution in self.stage_executions.values() {
            stage_execution.stop(error_msg.clone()).await;
        }
    }
}

#[cfg(test)]
pub(crate) mod tests {
    use std::collections::{HashMap, HashSet};
    use std::sync::{Arc, RwLock};

    use fixedbitset::FixedBitSet;
    use risingwave_batch::worker_manager::worker_node_manager::{
        WorkerNodeManager, WorkerNodeSelector,
    };
    use risingwave_common::catalog::{
        ColumnCatalog, ColumnDesc, ConflictBehavior, CreateType, DEFAULT_SUPER_USER_ID, Engine,
        StreamJobStatus,
    };
    use risingwave_common::hash::{VirtualNode, VnodeCount, WorkerSlotId, WorkerSlotMapping};
    use risingwave_common::types::DataType;
    use risingwave_pb::common::worker_node::Property;
    use risingwave_pb::common::{HostAddress, WorkerNode, WorkerType};
    use risingwave_pb::plan_common::JoinType;
    use risingwave_rpc_client::ComputeClientPool;

    use crate::TableCatalog;
    use crate::catalog::catalog_service::CatalogReader;
    use crate::catalog::root_catalog::Catalog;
    use crate::catalog::table_catalog::TableType;
    use crate::expr::InputRef;
    use crate::optimizer::OptimizerContext;
    use crate::optimizer::plan_node::{
        BatchExchange, BatchFilter, BatchHashJoin, BatchPlanRef as PlanRef, EqJoinPredicate,
        LogicalScan, ToBatch, generic,
    };
    use crate::optimizer::property::{Cardinality, Distribution, Order};
    use crate::scheduler::distributed::QueryExecution;
    use crate::scheduler::plan_fragmenter::{BatchPlanFragmenter, Query};
    use crate::scheduler::{
        DistributedQueryMetrics, ExecutionContext, QueryExecutionInfo, ReadSnapshot,
    };
    use crate::session::SessionImpl;
    use crate::utils::Condition;

    #[tokio::test]
    async fn test_query_should_not_hang_with_empty_worker() {
        let worker_node_manager = Arc::new(WorkerNodeManager::mock(vec![]));
        let worker_node_selector = WorkerNodeSelector::new(worker_node_manager.clone(), false);
        let compute_client_pool = Arc::new(ComputeClientPool::for_test());
        let catalog_reader =
            CatalogReader::new(Arc::new(parking_lot::RwLock::new(Catalog::default())));
        let query = create_query().await;
        let query_id = query.query_id().clone();
        let query_execution = Arc::new(QueryExecution::new(query, (0, 0), None));
        let query_execution_info = Arc::new(RwLock::new(QueryExecutionInfo::new_from_map(
            HashMap::from([(query_id, query_execution.clone())]),
        )));

        // With no worker nodes available, starting the query should fail instead of hanging.
        assert!(
            query_execution
                .start(
                    ExecutionContext::new(SessionImpl::mock().into(), None).into(),
                    worker_node_selector,
                    ReadSnapshot::ReadUncommitted
                        .batch_query_epoch(&HashSet::from_iter([0.into()]))
                        .unwrap(),
                    compute_client_pool,
                    catalog_reader,
                    query_execution_info,
                    Arc::new(DistributedQueryMetrics::for_test()),
                )
                .await
                .is_err()
        );
    }

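    /// Build a mock distributed query: two scans of the same test table (one behind a trivial
    /// filter) joined by a hash join under a single-distribution exchange, fragmented against a
    /// mocked catalog and worker set.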
    pub async fn create_query() -> Query {
        let ctx = OptimizerContext::mock().await;
        let table_id = 0.into();
        let vnode_count = VirtualNode::COUNT_FOR_TEST;

        let table_catalog: TableCatalog = TableCatalog {
            id: table_id,
            schema_id: 0,
            database_id: 0,
            associated_source_id: None,
            name: "test".to_owned(),
            refreshable: false,
            columns: vec![
                ColumnCatalog {
                    column_desc: ColumnDesc::named("a", 0.into(), DataType::Int32),
                    is_hidden: false,
                },
                ColumnCatalog {
                    column_desc: ColumnDesc::named("b", 1.into(), DataType::Float64),
                    is_hidden: false,
                },
                ColumnCatalog {
                    column_desc: ColumnDesc::named("c", 2.into(), DataType::Int64),
                    is_hidden: false,
                },
            ],
            pk: vec![],
            stream_key: vec![],
            table_type: TableType::Table,
            distribution_key: vec![],
            append_only: false,
            owner: DEFAULT_SUPER_USER_ID,
            retention_seconds: None,
            fragment_id: 0,
            dml_fragment_id: None,
            vnode_col_index: None,
            row_id_index: None,
            value_indices: vec![0, 1, 2],
            definition: "".to_owned(),
            conflict_behavior: ConflictBehavior::NoCheck,
            version_column_index: None,
            read_prefix_len_hint: 0,
            version: None,
            watermark_columns: FixedBitSet::with_capacity(3),
            dist_key_in_pk: vec![],
            cardinality: Cardinality::unknown(),
            cleaned_by_watermark: false,
            created_at_epoch: None,
            initialized_at_epoch: None,
            stream_job_status: StreamJobStatus::Creating,
            create_type: CreateType::Foreground,
            description: None,
            incoming_sinks: vec![],
            initialized_at_cluster_version: None,
            created_at_cluster_version: None,
            cdc_table_id: None,
            vnode_count: VnodeCount::set(vnode_count),
            webhook_info: None,
            job_id: None,
            engine: Engine::Hummock,
            clean_watermark_index_in_pk: None,
            vector_index_info: None,
        };
        let batch_plan_node = LogicalScan::create(table_catalog.into(), ctx, None)
            .to_batch()
            .unwrap()
            .to_distributed()
            .unwrap();
        let batch_filter = BatchFilter::new(generic::Filter::new(
            Condition {
                conjunctions: vec![],
            },
            batch_plan_node.clone(),
        ))
        .into();
        let batch_exchange_node1: PlanRef = BatchExchange::new(
            batch_plan_node.clone(),
            Order::default(),
            Distribution::HashShard(vec![0, 1]),
        )
        .into();
        let batch_exchange_node2: PlanRef = BatchExchange::new(
            batch_filter,
            Order::default(),
            Distribution::HashShard(vec![0, 1]),
        )
        .into();
        let logical_join_node = generic::Join::with_full_output(
            batch_exchange_node1.clone(),
            batch_exchange_node2.clone(),
            JoinType::Inner,
            Condition::true_cond(),
        );
        let eq_key_1 = (
            InputRef {
                index: 0,
                data_type: DataType::Int32,
            },
            InputRef {
                index: 2,
                data_type: DataType::Int32,
            },
            false,
        );
        let eq_key_2 = (
            InputRef {
                index: 1,
                data_type: DataType::Float64,
            },
            InputRef {
                index: 3,
                data_type: DataType::Float64,
            },
            false,
        );
        let eq_join_predicate =
            EqJoinPredicate::new(Condition::true_cond(), vec![eq_key_1, eq_key_2], 2, 2);
        let hash_join_node: PlanRef =
            BatchHashJoin::new(logical_join_node, eq_join_predicate, None).into();
        let batch_exchange_node: PlanRef = BatchExchange::new(
            hash_join_node.clone(),
            Order::default(),
            Distribution::Single,
        )
        .into();

        let worker1 = WorkerNode {
            id: 0,
            r#type: WorkerType::ComputeNode as i32,
            host: Some(HostAddress {
                host: "127.0.0.1".to_owned(),
                port: 5687,
            }),
            state: risingwave_pb::common::worker_node::State::Running as i32,
            property: Some(Property {
                parallelism: 8,
                is_unschedulable: false,
                is_serving: true,
                is_streaming: true,
                ..Default::default()
            }),
            transactional_id: Some(0),
            ..Default::default()
        };
        let worker2 = WorkerNode {
            id: 1,
            r#type: WorkerType::ComputeNode as i32,
            host: Some(HostAddress {
                host: "127.0.0.1".to_owned(),
                port: 5688,
            }),
            state: risingwave_pb::common::worker_node::State::Running as i32,
            property: Some(Property {
                parallelism: 8,
                is_unschedulable: false,
                is_serving: true,
                is_streaming: true,
                ..Default::default()
            }),
            transactional_id: Some(1),
            ..Default::default()
        };
        let worker3 = WorkerNode {
            id: 2,
            r#type: WorkerType::ComputeNode as i32,
            host: Some(HostAddress {
                host: "127.0.0.1".to_owned(),
                port: 5689,
            }),
            state: risingwave_pb::common::worker_node::State::Running as i32,
            property: Some(Property {
                parallelism: 8,
                is_unschedulable: false,
                is_serving: true,
                is_streaming: true,
                ..Default::default()
            }),
            transactional_id: Some(2),
            ..Default::default()
        };
        let workers = vec![worker1, worker2, worker3];
        let worker_node_manager = Arc::new(WorkerNodeManager::mock(workers));
        let worker_node_selector = WorkerNodeSelector::new(worker_node_manager.clone(), false);
        let mapping =
            WorkerSlotMapping::new_uniform(std::iter::once(WorkerSlotId::new(0, 0)), vnode_count);
        worker_node_manager.insert_streaming_fragment_mapping(0, mapping.clone());
        worker_node_manager.set_serving_fragment_mapping(vec![(0, mapping)].into_iter().collect());
        let catalog = Arc::new(parking_lot::RwLock::new(Catalog::default()));
        catalog.write().insert_table_id_mapping(table_id, 0);
        let catalog_reader = CatalogReader::new(catalog);
        let fragmenter = BatchPlanFragmenter::new(
            worker_node_selector,
            catalog_reader,
            None,
            "UTC".to_owned(),
            batch_exchange_node.clone(),
        )
        .unwrap();
        fragmenter.generate_complete_query().await.unwrap()
    }