risingwave_frontend/scheduler/distributed/query_manager.rs

use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};

use futures::Stream;
use pgwire::pg_server::{BoxedError, Session, SessionId};
use risingwave_batch::worker_manager::worker_node_manager::{
    WorkerNodeManagerRef, WorkerNodeSelector,
};
use risingwave_common::array::DataChunk;
use risingwave_common::session_config::QueryMode;
use risingwave_pb::batch_plan::TaskOutputId;
use risingwave_pb::common::HostAddress;
use risingwave_rpc_client::ComputeClientPoolRef;
use tokio::sync::OwnedSemaphorePermit;

use super::QueryExecution;
use super::stats::DistributedQueryMetrics;
use crate::catalog::catalog_service::CatalogReader;
use crate::scheduler::plan_fragmenter::{Query, QueryId};
use crate::scheduler::{ExecutionContextRef, SchedulerResult};

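/// Result stream handed back to the session for a query running in distributed mode.
/// It forwards `DataChunk`s received from the query execution over an mpsc channel,
/// and unregisters the query from the shared `QueryExecutionInfo` when dropped (see
/// the `Drop` impl below).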
pub struct DistributedQueryStream {
    chunk_rx: tokio::sync::mpsc::Receiver<SchedulerResult<DataChunk>>,
    query_id: QueryId,
    query_execution_info: QueryExecutionInfoRef,
}

impl DistributedQueryStream {
    pub fn query_id(&self) -> &QueryId {
        &self.query_id
    }
}

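// Forward chunks out of the channel; errors are boxed to match the `BoxedError` type
// from the pgwire session layer.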
impl Stream for DistributedQueryStream {
    type Item = Result<DataChunk, BoxedError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.chunk_rx.poll_recv(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(chunk) => match chunk {
                Some(chunk_result) => match chunk_result {
                    Ok(chunk) => Poll::Ready(Some(Ok(chunk))),
                    Err(err) => Poll::Ready(Some(Err(Box::new(err)))),
                },
                None => Poll::Ready(None),
            },
        }
    }
}

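// Dropping the stream (e.g. once the result set has been fully consumed or the caller
// gives up on it) removes the query from the shared registry, so later
// `abort_queries` calls no longer see it.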
impl Drop for DistributedQueryStream {
    fn drop(&mut self) {
        let mut query_execution_info = self.query_execution_info.write().unwrap();
        query_execution_info.delete_query(&self.query_id);
    }
}

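/// Handle produced when a `QueryExecution` has been started successfully. It records
/// the task output (and the host serving it) that the result comes from, and owns the
/// receiving end of the chunk channel, which `stream_from_channel` turns into a
/// `DistributedQueryStream`.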
pub struct QueryResultFetcher {
    task_output_id: TaskOutputId,
    task_host: HostAddress,

    chunk_rx: tokio::sync::mpsc::Receiver<SchedulerResult<DataChunk>>,

    query_id: QueryId,
    query_execution_info: QueryExecutionInfoRef,
}

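/// Registry of the distributed queries currently known to this frontend, keyed by
/// `QueryId`. It is shared as `QueryExecutionInfoRef` (an `Arc<RwLock<_>>`) between
/// the `QueryManager`, the result streams, and the query executions themselves.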
#[derive(Clone, Default)]
pub struct QueryExecutionInfo {
    query_execution_map: HashMap<QueryId, Arc<QueryExecution>>,
}

impl QueryExecutionInfo {
    #[cfg(test)]
    pub fn new_from_map(query_execution_map: HashMap<QueryId, Arc<QueryExecution>>) -> Self {
        Self {
            query_execution_map,
        }
    }
}

pub type QueryExecutionInfoRef = Arc<RwLock<QueryExecutionInfo>>;

impl QueryExecutionInfo {
    pub fn add_query(&mut self, query_id: QueryId, query_execution: Arc<QueryExecution>) {
        self.query_execution_map.insert(query_id, query_execution);
    }

    pub fn delete_query(&mut self, query_id: &QueryId) {
        self.query_execution_map.remove(query_id);
    }

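    /// Aborts every query that belongs to the given session (the message passed to
    /// `abort` is "cancelled by user"). The abort itself runs on a spawned tokio task,
    /// so this method returns without waiting for the queries to actually stop.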
    pub fn abort_queries(&self, session_id: SessionId) {
        for query in self.query_execution_map.values() {
            if query.session_id == session_id {
                let query = query.clone();
                tokio::spawn(async move { query.abort("cancelled by user".to_owned()).await });
            }
        }
    }
}

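/// Manages distributed batch queries on the frontend.
///
/// Two optional limits are applied when scheduling a query:
/// * `distributed_query_limit`, checked against the `running_query_num` metric; the
///   query is rejected with `QueryReachLimit` once the limit is reached;
/// * `total_distributed_query_limit`, enforced through a semaphore whose permit is
///   handed to the `QueryExecution` and released when the permit is dropped.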
#[derive(Clone)]
pub struct QueryManager {
    worker_node_manager: WorkerNodeManagerRef,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,
    query_execution_info: QueryExecutionInfoRef,
    pub query_metrics: Arc<DistributedQueryMetrics>,
    distributed_query_limit: Option<u64>,
    distributed_query_semaphore: Option<Arc<tokio::sync::Semaphore>>,
    pub total_distributed_query_limit: Option<u64>,
}

impl QueryManager {
    pub fn new(
        worker_node_manager: WorkerNodeManagerRef,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
        query_metrics: Arc<DistributedQueryMetrics>,
        distributed_query_limit: Option<u64>,
        total_distributed_query_limit: Option<u64>,
    ) -> Self {
        let distributed_query_semaphore = total_distributed_query_limit
            .map(|limit| Arc::new(tokio::sync::Semaphore::new(limit as usize)));
        Self {
            worker_node_manager,
            compute_client_pool,
            catalog_reader,
            query_execution_info: Arc::new(RwLock::new(QueryExecutionInfo::default())),
            query_metrics,
            distributed_query_limit,
            distributed_query_semaphore,
            total_distributed_query_limit,
        }
    }

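    /// Acquires a permit from the total-query semaphore, if one is configured;
    /// returns `Ok(None)` when no total limit is set. A failed acquisition (the
    /// semaphore has been closed) is counted in `rejected_query_counter` and surfaced
    /// as a `QueryReachLimit` error.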
    async fn get_permit(&self) -> SchedulerResult<Option<OwnedSemaphorePermit>> {
        match self.distributed_query_semaphore {
            Some(ref semaphore) => {
                let permit = semaphore.clone().acquire_owned().await;
                match permit {
                    Ok(permit) => Ok(Some(permit)),
                    Err(_) => {
                        self.query_metrics.rejected_query_counter.inc();
                        Err(crate::scheduler::SchedulerError::QueryReachLimit(
                            QueryMode::Distributed,
                            self.total_distributed_query_limit
                                .expect("should have distributed query limit"),
                        ))
                    }
                }
            }
            None => Ok(None),
        }
    }

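    /// Schedules a distributed query and returns a stream of its result chunks.
    ///
    /// The steps are: fill the batch query epoch from the session's pinned snapshot,
    /// enforce `distributed_query_limit`, acquire a permit for the total limit,
    /// register the new `QueryExecution`, start it, and wrap the resulting
    /// `QueryResultFetcher` into a `DistributedQueryStream`. If starting the execution
    /// fails, the query is unregistered again.
    ///
    /// A rough usage sketch (simplified; obtaining `query_manager`, `context` and the
    /// fragmented `query` is omitted here):
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = query_manager.schedule(context, query).await?;
    /// while let Some(chunk) = stream.next().await {
    ///     let chunk = chunk?;
    ///     // forward `chunk` to the client
    /// }
    /// ```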
    pub async fn schedule(
        &self,
        context: ExecutionContextRef,
        mut query: Query,
    ) -> SchedulerResult<DistributedQueryStream> {
        let pinned_snapshot = context.session().pinned_snapshot();
        pinned_snapshot.fill_batch_query_epoch(&mut query)?;

        if let Some(query_limit) = self.distributed_query_limit
            && self.query_metrics.running_query_num.get() as u64 == query_limit
        {
            self.query_metrics.rejected_query_counter.inc();
            return Err(crate::scheduler::SchedulerError::QueryReachLimit(
                QueryMode::Distributed,
                query_limit,
            ));
        }
        let query_id = query.query_id.clone();
        let permit = self.get_permit().await?;
        let query_execution = Arc::new(QueryExecution::new(query, context.session().id(), permit));

        context
            .session()
            .env()
            .query_manager()
            .add_query(query_id.clone(), query_execution.clone());

        let worker_node_manager_reader = WorkerNodeSelector::new(
            self.worker_node_manager.clone(),
            pinned_snapshot.support_barrier_read(),
        );

        let query_result_fetcher = query_execution
            .start(
                context.clone(),
                worker_node_manager_reader,
                self.compute_client_pool.clone(),
                self.catalog_reader.clone(),
                self.query_execution_info.clone(),
                self.query_metrics.clone(),
            )
            .await
            .inspect_err(|_| {
                context
                    .session()
                    .env()
                    .query_manager()
                    .delete_query(&query_id);
            })?;
        Ok(query_result_fetcher.stream_from_channel())
    }

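    /// Aborts all queries started by the given session. A read lock is sufficient
    /// because `abort_queries` does not modify the registry; aborted queries are
    /// removed from it separately (e.g. when their result streams are dropped).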
    pub fn cancel_queries_in_session(&self, session_id: SessionId) {
        let query_execution_info = self.query_execution_info.read().unwrap();
        query_execution_info.abort_queries(session_id);
    }

    pub fn add_query(&self, query_id: QueryId, query_execution: Arc<QueryExecution>) {
        let mut query_execution_info = self.query_execution_info.write().unwrap();
        query_execution_info.add_query(query_id, query_execution);
    }

    pub fn delete_query(&self, query_id: &QueryId) {
        let mut query_execution_info = self.query_execution_info.write().unwrap();
        query_execution_info.delete_query(query_id);
    }
}

impl QueryResultFetcher {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        task_output_id: TaskOutputId,
        task_host: HostAddress,
        chunk_rx: tokio::sync::mpsc::Receiver<SchedulerResult<DataChunk>>,
        query_id: QueryId,
        query_execution_info: QueryExecutionInfoRef,
    ) -> Self {
        Self {
            task_output_id,
            task_host,
            chunk_rx,
            query_id,
            query_execution_info,
        }
    }

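    /// Consumes the fetcher and wraps its chunk channel into a `DistributedQueryStream`
    /// for the session to poll.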
    fn stream_from_channel(self) -> DistributedQueryStream {
        DistributedQueryStream {
            chunk_rx: self.chunk_rx,
            query_id: self.query_id,
            query_execution_info: self.query_execution_info,
        }
    }
}

impl Debug for QueryResultFetcher {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QueryResultFetcher")
            .field("task_output_id", &self.task_output_id)
            .field("task_host", &self.task_host)
            .finish()
    }
}