use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::mem;
use std::sync::Arc;

use anyhow::Context;
use futures::executor::block_on;
use petgraph::Graph;
use petgraph::dot::{Config, Dot};
use pgwire::pg_server::SessionId;
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::array::DataChunk;
use risingwave_pb::batch_plan::{TaskId as PbTaskId, TaskOutputId as PbTaskOutputId};
use risingwave_pb::common::HostAddress;
use risingwave_rpc_client::ComputeClientPoolRef;
use thiserror_ext::AsReport;
use tokio::sync::mpsc::{Receiver, Sender, channel};
use tokio::sync::{RwLock, oneshot};
use tokio::task::JoinHandle;
use tracing::{Instrument, debug, error, info, warn};

use super::{DistributedQueryMetrics, QueryExecutionInfoRef, QueryResultFetcher, StageEvent};
use crate::catalog::catalog_service::CatalogReader;
use crate::scheduler::distributed::StageEvent::Scheduled;
use crate::scheduler::distributed::StageExecution;
use crate::scheduler::distributed::query::QueryMessage::Stage;
use crate::scheduler::distributed::stage::StageEvent::ScheduledRoot;
use crate::scheduler::plan_fragmenter::{Query, ROOT_TASK_ID, ROOT_TASK_OUTPUT_ID, StageId};
use crate::scheduler::{ExecutionContextRef, SchedulerError, SchedulerResult};

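/// Messages consumed by the `QueryRunner` event loop.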
#[derive(Debug)]
pub enum QueryMessage {
    /// Event reported by one of the query's stages.
    Stage(StageEvent),
    /// Cancel the query, carrying the reason for cancellation.
    CancelQuery(String),
}

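/// Lifecycle state of a [`QueryExecution`].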
enum QueryState {
    Pending {
        msg_receiver: Receiver<QueryMessage>,
    },

    Running,

    Failed,
}

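/// Frontend-side handle of a distributed query: owns the shared [`Query`] plan, tracks the
/// execution state, and exposes `start` / `abort`.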
pub struct QueryExecution {
    query: Arc<Query>,
    state: RwLock<QueryState>,
    /// Sender side of the runner's message channel; used by `abort` and shared with the stage
    /// executions.
    shutdown_tx: Sender<QueryMessage>,
    pub session_id: SessionId,
    /// Held (but never read) so that the owning semaphore permit stays acquired for the lifetime
    /// of this execution.
    #[expect(dead_code)]
    pub permit: Option<tokio::sync::OwnedSemaphorePermit>,
}

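/// Runs in a dedicated tokio task and drives all stage executions of a single query, reacting to
/// stage events until the query completes, fails, or is cancelled.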
struct QueryRunner {
    query: Arc<Query>,
    stage_executions: HashMap<StageId, Arc<StageExecution>>,
    /// Number of stages that have been scheduled so far.
    scheduled_stages_count: usize,
    /// Receiver side of the message channel; the sender is held by `QueryExecution` and cloned
    /// into the stage executions.
    msg_receiver: Receiver<QueryMessage>,

    /// Taken (set to `None`) once the root stage result or a failure has been delivered to the
    /// caller of `start`.
    root_stage_sender: Option<oneshot::Sender<SchedulerResult<QueryResultFetcher>>>,

    query_execution_info: QueryExecutionInfoRef,

    query_metrics: Arc<DistributedQueryMetrics>,
    timeout_abort_task_handle: Option<JoinHandle<()>>,
}

impl QueryExecution {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        query: Query,
        session_id: SessionId,
        permit: Option<tokio::sync::OwnedSemaphorePermit>,
    ) -> Self {
        let query = Arc::new(query);
        let (sender, receiver) = channel(100);
        let state = QueryState::Pending {
            msg_receiver: receiver,
        };

        Self {
            query,
            state: RwLock::new(state),
            shutdown_tx: sender,
            session_id,
            permit,
        }
    }

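    /// Start execution of this query: spawn a `QueryRunner` (plus an optional timeout watchdog)
    /// and wait until the root stage is scheduled, returning a [`QueryResultFetcher`] for the
    /// results.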
    #[allow(clippy::too_many_arguments)]
    pub async fn start(
        self: Arc<Self>,
        context: ExecutionContextRef,
        worker_node_manager: WorkerNodeSelector,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
        query_execution_info: QueryExecutionInfoRef,
        query_metrics: Arc<DistributedQueryMetrics>,
    ) -> SchedulerResult<QueryResultFetcher> {
        let mut state = self.state.write().await;
        // Take the current state; it stays `Failed` unless we successfully transition to
        // `Running` below.
        let cur_state = mem::replace(&mut *state, QueryState::Failed);

        let stage_executions = self.gen_stage_executions(
            context.clone(),
            worker_node_manager,
            compute_client_pool.clone(),
            catalog_reader,
        );

        match cur_state {
            QueryState::Pending { msg_receiver } => {
                *state = QueryState::Running;

                // Spawn a watchdog that cancels the query once the configured timeout elapses.
                let mut timeout_abort_task_handle: Option<JoinHandle<()>> = None;
                if let Some(timeout) = context.timeout() {
                    let this = self.clone();
                    timeout_abort_task_handle = Some(tokio::spawn(async move {
                        tokio::time::sleep(timeout).await;
                        warn!(
                            "Query {:?} timeout after {} seconds, sending cancel message.",
                            this.query.query_id,
                            timeout.as_secs(),
                        );
                        this.abort(format!("timeout after {} seconds", timeout.as_secs()))
                            .await;
                    }));
                }

                // Channel through which the runner hands back the root stage result fetcher.
                let (root_stage_sender, root_stage_receiver) =
                    oneshot::channel::<SchedulerResult<QueryResultFetcher>>();

                let runner = QueryRunner {
                    query: self.query.clone(),
                    stage_executions,
                    msg_receiver,
                    root_stage_sender: Some(root_stage_sender),
                    scheduled_stages_count: 0,
                    query_execution_info,
                    query_metrics,
                    timeout_abort_task_handle,
                };

                let span =
                    tracing::info_span!("distributed_execute", query_id = self.query.query_id.id,);

                tracing::trace!("Starting query: {:?}", self.query.query_id);

                // Drive the query runner in a separate tokio task.
                tokio::spawn(async move { runner.run().instrument(span).await });

                let root_stage = root_stage_receiver
                    .await
                    .context("Starting query execution failed")??;

                tracing::trace!(
                    "Received root stage query result fetcher: {:?}, query id: {:?}",
                    root_stage,
                    self.query.query_id
                );

                tracing::trace!("Query {:?} started.", self.query.query_id);
                Ok(root_stage)
            }
            _ => {
                unreachable!("The query runner should not be scheduled twice");
            }
        }
    }

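    /// Cancel the query by sending a `CancelQuery` message to the runner.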
    pub async fn abort(self: Arc<Self>, reason: String) {
        // If the runner has already exited, the channel is closed and the send fails.
        if self
            .shutdown_tx
            .send(QueryMessage::CancelQuery(reason))
            .await
            .is_err()
        {
            warn!("Send cancel query request failed: the query has ended");
        } else {
            info!("Send cancel request to query-{:?}", self.query.query_id);
        };
    }

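    /// Build a [`StageExecution`] for every stage of the query, visiting stages in topological
    /// order so that child executions exist before their parents reference them.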
    fn gen_stage_executions(
        &self,
        context: ExecutionContextRef,
        worker_node_manager: WorkerNodeSelector,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
    ) -> HashMap<StageId, Arc<StageExecution>> {
        let mut stage_executions: HashMap<StageId, Arc<StageExecution>> =
            HashMap::with_capacity(self.query.stage_graph.stages.len());

        for stage_id in self.query.stage_graph.stage_ids_by_topo_order() {
            let children_stages = self
                .query
                .stage_graph
                .get_child_stages_unchecked(&stage_id)
                .iter()
                .map(|s| stage_executions[s].clone())
                .collect::<Vec<Arc<StageExecution>>>();

            let stage_exec = Arc::new(StageExecution::new(
                stage_id,
                self.query.clone(),
                worker_node_manager.clone(),
                self.shutdown_tx.clone(),
                children_stages,
                compute_client_pool.clone(),
                catalog_reader.clone(),
                context.clone(),
            ));
            stage_executions.insert(stage_id, stage_exec);
        }
        stage_executions
    }
}

impl Drop for QueryRunner {
    fn drop(&mut self) {
        self.query_metrics.running_query_num.dec();
        self.timeout_abort_task_handle
            .as_ref()
            .inspect(|h| h.abort());
    }
}

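/// Formats the query's stage DAG (with each stage's current state) as a Graphviz DOT graph.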
impl Debug for QueryRunner {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut graph = Graph::<String, String>::new();
        let mut stage_id_to_node_id = HashMap::new();
        for stage in &self.stage_executions {
            let node_id = graph.add_node(format!("{} {}", stage.0, block_on(stage.1.state())));
            stage_id_to_node_id.insert(stage.0, node_id);
        }

        for stage in &self.stage_executions {
            let stage_id = stage.0;
            if let Some(child_stages) = self.query.stage_graph.get_child_stages(stage_id) {
                for child_stage in child_stages {
                    graph.add_edge(
                        *stage_id_to_node_id.get(stage_id).unwrap(),
                        *stage_id_to_node_id.get(child_stage).unwrap(),
                        "".to_owned(),
                    );
                }
            }
        }

        writeln!(f, "{}", Dot::with_config(&graph, &[Config::EdgeNoLabel]))
    }
}

impl QueryRunner {
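    /// Drive the query to completion: start the leaf stages, then react to stage events from
    /// `msg_receiver` until the query completes, fails, or is cancelled.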
    async fn run(mut self) {
        self.query_metrics.running_query_num.inc();
        // Start leaf stages first; parent stages are started once all of their children have been
        // scheduled.
        let leaf_stages = self.query.leaf_stages();
        for stage_id in &leaf_stages {
            self.stage_executions[stage_id].start().await;
            tracing::trace!(
                "Query stage {:?}-{:?} started.",
                self.query.query_id,
                stage_id
            );
        }
        let mut stages_with_table_scan = self.query.stages_with_table_scan();
        let has_lookup_join_stage = self.query.has_lookup_join_stage();

        let mut finished_stage_cnt = 0usize;
        while let Some(msg_inner) = self.msg_receiver.recv().await {
            match msg_inner {
                Stage(Scheduled(stage_id)) => {
                    tracing::trace!(
                        "Query stage {:?}-{:?} scheduled.",
                        self.query.query_id,
                        stage_id
                    );
                    self.scheduled_stages_count += 1;
                    stages_with_table_scan.remove(&stage_id);
                    if !has_lookup_join_stage && stages_with_table_scan.is_empty() {
                        tracing::trace!(
                            "Query {:?} has scheduled all of its stages that have table scan (iterator creation).",
                            self.query.query_id
                        );
                    }

                    // A parent stage can start once all of its child stages have been scheduled.
                    for parent in self.query.get_parents(&stage_id) {
                        if self.all_children_scheduled(parent).await
                            && self.stage_executions[parent].is_pending().await
                        {
                            self.stage_executions[parent].start().await;
                        }
                    }
                }
                Stage(ScheduledRoot(receiver)) => {
                    // The root stage is scheduled; hand its result channel back to the caller of
                    // `start`.
                    self.send_root_stage_info(receiver);
                }
                Stage(StageEvent::Failed { id, reason }) => {
                    error!(
                        error = %reason.as_report(),
                        query_id = ?self.query.query_id,
                        stage_id = ?id,
                        "query stage failed"
                    );

                    self.clean_all_stages(Some(reason)).await;
                    break;
                }
                Stage(StageEvent::Completed(_)) => {
                    finished_stage_cnt += 1;
                    assert!(finished_stage_cnt <= self.stage_executions.len());
                    if finished_stage_cnt == self.stage_executions.len() {
                        tracing::trace!(
                            "Query {:?} completed, starting to clean stage tasks.",
                            &self.query.query_id
                        );
                        self.clean_all_stages(None).await;
                        break;
                    }
                }
                QueryMessage::CancelQuery(reason) => {
                    // User cancellation or timeout: stop all stages and exit the event loop.
                    self.clean_all_stages(Some(SchedulerError::QueryCancelled(reason)))
                        .await;
                    break;
                }
            }
        }
    }

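    /// Build the [`QueryResultFetcher`] for the root stage and send it to the caller waiting in
    /// `start` via `root_stage_sender`.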
    fn send_root_stage_info(&mut self, chunk_rx: Receiver<SchedulerResult<DataChunk>>) {
        let root_task_output_id = {
            let root_task_id_prost = PbTaskId {
                query_id: self.query.query_id.clone().id,
                stage_id: self.query.root_stage_id().into(),
                task_id: ROOT_TASK_ID,
            };

            PbTaskOutputId {
                task_id: Some(root_task_id_prost),
                output_id: ROOT_TASK_OUTPUT_ID,
            }
        };

        let root_stage_result = QueryResultFetcher::new(
            root_task_output_id,
            // The result chunks arrive through `chunk_rx`; a default host address is used here.
            HostAddress::default(),
            chunk_rx,
            self.query.query_id.clone(),
            self.query_execution_info.clone(),
        );

        let root_stage_sender = mem::take(&mut self.root_stage_sender);

        if let Err(e) = root_stage_sender.unwrap().send(Ok(root_stage_result)) {
            warn!("Query execution dropped: {:?}", e);
        } else {
            debug!("Root stage for {:?} sent.", self.query.query_id);
        }
    }

    async fn all_children_scheduled(&self, stage_id: &StageId) -> bool {
        for child in self.query.stage_graph.get_child_stages_unchecked(stage_id) {
            if !self.stage_executions[child].is_scheduled().await {
                return false;
            }
        }
        true
    }

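    /// Stop all stage executions. If `error` is set, the failure is first propagated to the
    /// caller still waiting on the root stage in `start`.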
    async fn clean_all_stages(&mut self, error: Option<SchedulerError>) {
        let error_msg = error.as_ref().map(|e| e.to_report_string());
        if let Some(reason) = error {
            // If `start` is still waiting on the root stage, notify it of the failure.
            let root_stage_sender = mem::take(&mut self.root_stage_sender);
            if let Some(sender) = root_stage_sender {
                if let Err(e) = sender.send(Err(reason)) {
                    warn!("Query execution dropped: {:?}", e);
                } else {
                    debug!(
                        "Root stage failure event for {:?} sent.",
                        self.query.query_id
                    );
                }
            }
        }

        tracing::trace!("Cleaning stages in query [{:?}]", self.query.query_id);
        for stage_execution in self.stage_executions.values() {
            stage_execution.stop(error_msg.clone()).await;
        }
    }
}

#[cfg(test)]
pub(crate) mod tests {
    use std::collections::HashMap;
    use std::sync::{Arc, RwLock};

    use fixedbitset::FixedBitSet;
    use risingwave_batch::worker_manager::worker_node_manager::{
        WorkerNodeManager, WorkerNodeSelector,
    };
    use risingwave_common::catalog::{
        ColumnCatalog, ColumnDesc, ConflictBehavior, CreateType, DEFAULT_SUPER_USER_ID, Engine,
        StreamJobStatus,
    };
    use risingwave_common::hash::{VirtualNode, VnodeCount, WorkerSlotId, WorkerSlotMapping};
    use risingwave_common::types::DataType;
    use risingwave_pb::common::worker_node::Property;
    use risingwave_pb::common::{HostAddress, WorkerNode, WorkerType};
    use risingwave_pb::plan_common::JoinType;
    use risingwave_rpc_client::ComputeClientPool;

    use crate::TableCatalog;
    use crate::catalog::catalog_service::CatalogReader;
    use crate::catalog::root_catalog::Catalog;
    use crate::catalog::table_catalog::TableType;
    use crate::expr::InputRef;
    use crate::optimizer::OptimizerContext;
    use crate::optimizer::plan_node::{
        BatchExchange, BatchFilter, BatchHashJoin, BatchPlanRef as PlanRef, EqJoinPredicate,
        LogicalScan, ToBatch, generic,
    };
    use crate::optimizer::property::{Cardinality, Distribution, Order};
    use crate::scheduler::distributed::QueryExecution;
    use crate::scheduler::plan_fragmenter::{BatchPlanFragmenter, Query};
    use crate::scheduler::{DistributedQueryMetrics, ExecutionContext, QueryExecutionInfo};
    use crate::session::SessionImpl;
    use crate::utils::Condition;

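    /// Scheduling a query against an empty worker set should fail promptly instead of hanging.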
    #[tokio::test]
    async fn test_query_should_not_hang_with_empty_worker() {
        let worker_node_manager = Arc::new(WorkerNodeManager::mock(vec![]));
        let worker_node_selector = WorkerNodeSelector::new(worker_node_manager.clone(), false);
        let compute_client_pool = Arc::new(ComputeClientPool::for_test());
        let catalog_reader =
            CatalogReader::new(Arc::new(parking_lot::RwLock::new(Catalog::default())));
        let query = create_query().await;
        let query_id = query.query_id().clone();
        let query_execution = Arc::new(QueryExecution::new(query, (0, 0), None));
        let query_execution_info = Arc::new(RwLock::new(QueryExecutionInfo::new_from_map(
            HashMap::from([(query_id, query_execution.clone())]),
        )));

        assert!(
            query_execution
                .start(
                    ExecutionContext::new(SessionImpl::mock().into(), None).into(),
                    worker_node_selector,
                    compute_client_pool,
                    catalog_reader,
                    query_execution_info,
                    Arc::new(DistributedQueryMetrics::for_test()),
                )
                .await
                .is_err()
        );
    }

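    /// Build a query whose plan is: two hash-shard exchanges (one over a scan, one over a
    /// filtered scan) feeding a hash join, with a single-distribution exchange at the root, then
    /// fragment it into a distributed [`Query`].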
    pub async fn create_query() -> Query {
        let ctx = OptimizerContext::mock().await;
        let table_id = 0.into();
        let vnode_count = VirtualNode::COUNT_FOR_TEST;

        let table_catalog: TableCatalog = TableCatalog {
            id: table_id,
            schema_id: 0,
            database_id: 0,
            associated_source_id: None,
            name: "test".to_owned(),
            refreshable: false,
            columns: vec![
                ColumnCatalog {
                    column_desc: ColumnDesc::named("a", 0.into(), DataType::Int32),
                    is_hidden: false,
                },
                ColumnCatalog {
                    column_desc: ColumnDesc::named("b", 1.into(), DataType::Float64),
                    is_hidden: false,
                },
                ColumnCatalog {
                    column_desc: ColumnDesc::named("c", 2.into(), DataType::Int64),
                    is_hidden: false,
                },
            ],
            pk: vec![],
            stream_key: vec![],
            table_type: TableType::Table,
            distribution_key: vec![],
            append_only: false,
            owner: DEFAULT_SUPER_USER_ID,
            retention_seconds: None,
            fragment_id: 0,
            dml_fragment_id: None,
            vnode_col_index: None,
            row_id_index: None,
            value_indices: vec![0, 1, 2],
            definition: "".to_owned(),
            conflict_behavior: ConflictBehavior::NoCheck,
            version_column_indices: vec![],
            read_prefix_len_hint: 0,
            version: None,
            watermark_columns: FixedBitSet::with_capacity(3),
            dist_key_in_pk: vec![],
            cardinality: Cardinality::unknown(),
            cleaned_by_watermark: false,
            created_at_epoch: None,
            initialized_at_epoch: None,
            stream_job_status: StreamJobStatus::Creating,
            create_type: CreateType::Foreground,
            description: None,
            initialized_at_cluster_version: None,
            created_at_cluster_version: None,
            cdc_table_id: None,
            vnode_count: VnodeCount::set(vnode_count),
            webhook_info: None,
            job_id: None,
            engine: Engine::Hummock,
            clean_watermark_index_in_pk: None,
            vector_index_info: None,
            cdc_table_type: None,
        };
        let batch_plan_node = LogicalScan::create(table_catalog.into(), ctx, None)
            .to_batch()
            .unwrap()
            .to_distributed()
            .unwrap();
        let batch_filter = BatchFilter::new(generic::Filter::new(
            Condition {
                conjunctions: vec![],
            },
            batch_plan_node.clone(),
        ))
        .into();
        let batch_exchange_node1: PlanRef = BatchExchange::new(
            batch_plan_node.clone(),
            Order::default(),
            Distribution::HashShard(vec![0, 1]),
        )
        .into();
        let batch_exchange_node2: PlanRef = BatchExchange::new(
            batch_filter,
            Order::default(),
            Distribution::HashShard(vec![0, 1]),
        )
        .into();
        let logical_join_node = generic::Join::with_full_output(
            batch_exchange_node1.clone(),
            batch_exchange_node2.clone(),
            JoinType::Inner,
            Condition::true_cond(),
        );
        let eq_key_1 = (
            InputRef {
                index: 0,
                data_type: DataType::Int32,
            },
            InputRef {
                index: 2,
                data_type: DataType::Int32,
            },
            false,
        );
        let eq_key_2 = (
            InputRef {
                index: 1,
                data_type: DataType::Float64,
            },
            InputRef {
                index: 3,
                data_type: DataType::Float64,
            },
            false,
        );
        let eq_join_predicate =
            EqJoinPredicate::new(Condition::true_cond(), vec![eq_key_1, eq_key_2], 2, 2);
        let hash_join_node: PlanRef =
            BatchHashJoin::new(logical_join_node, eq_join_predicate, None).into();
        let batch_exchange_node: PlanRef = BatchExchange::new(
            hash_join_node.clone(),
            Order::default(),
            Distribution::Single,
        )
        .into();

        // Three mock compute nodes on localhost, each with parallelism 8.
        let worker1 = WorkerNode {
            id: 0,
            r#type: WorkerType::ComputeNode as i32,
            host: Some(HostAddress {
                host: "127.0.0.1".to_owned(),
                port: 5687,
            }),
            state: risingwave_pb::common::worker_node::State::Running as i32,
            property: Some(Property {
                parallelism: 8,
                is_unschedulable: false,
                is_serving: true,
                is_streaming: true,
                ..Default::default()
            }),
            transactional_id: Some(0),
            ..Default::default()
        };
        let worker2 = WorkerNode {
            id: 1,
            r#type: WorkerType::ComputeNode as i32,
            host: Some(HostAddress {
                host: "127.0.0.1".to_owned(),
                port: 5688,
            }),
            state: risingwave_pb::common::worker_node::State::Running as i32,
            property: Some(Property {
                parallelism: 8,
                is_unschedulable: false,
                is_serving: true,
                is_streaming: true,
                ..Default::default()
            }),
            transactional_id: Some(1),
            ..Default::default()
        };
        let worker3 = WorkerNode {
            id: 2,
            r#type: WorkerType::ComputeNode as i32,
            host: Some(HostAddress {
                host: "127.0.0.1".to_owned(),
                port: 5689,
            }),
            state: risingwave_pb::common::worker_node::State::Running as i32,
            property: Some(Property {
                parallelism: 8,
                is_unschedulable: false,
                is_serving: true,
                is_streaming: true,
                ..Default::default()
            }),
            transactional_id: Some(2),
            ..Default::default()
        };
        let workers = vec![worker1, worker2, worker3];
        let worker_node_manager = Arc::new(WorkerNodeManager::mock(workers));
        let worker_node_selector = WorkerNodeSelector::new(worker_node_manager.clone(), false);
        let mapping =
            WorkerSlotMapping::new_uniform(std::iter::once(WorkerSlotId::new(0, 0)), vnode_count);
        worker_node_manager.insert_streaming_fragment_mapping(0, mapping.clone());
        worker_node_manager.set_serving_fragment_mapping(vec![(0, mapping)].into_iter().collect());
        let catalog = Arc::new(parking_lot::RwLock::new(Catalog::default()));
        catalog.write().insert_table_id_mapping(table_id, 0);
        let catalog_reader = CatalogReader::new(catalog);
        let fragmenter = BatchPlanFragmenter::new(
            worker_node_selector,
            catalog_reader,
            None,
            "UTC".to_owned(),
            batch_exchange_node.clone(),
        )
        .unwrap();
        fragmenter.generate_complete_query().await.unwrap()
    }
}