risingwave_frontend/scheduler/distributed/query_manager.rs

use std::collections::{HashMap, HashSet};
use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};

use futures::Stream;
use pgwire::pg_server::{BoxedError, Session, SessionId};
use risingwave_batch::worker_manager::worker_node_manager::{
    WorkerNodeManagerRef, WorkerNodeSelector,
};
use risingwave_common::array::DataChunk;
use risingwave_common::session_config::QueryMode;
use risingwave_pb::batch_plan::TaskOutputId;
use risingwave_pb::common::HostAddress;
use risingwave_rpc_client::ComputeClientPoolRef;
use tokio::sync::OwnedSemaphorePermit;

use super::QueryExecution;
use super::stats::DistributedQueryMetrics;
use crate::catalog::TableId;
use crate::catalog::catalog_service::CatalogReader;
use crate::scheduler::plan_fragmenter::{Query, QueryId};
use crate::scheduler::{ExecutionContextRef, SchedulerResult};

/// The streaming result of a distributed batch query, polled by the frontend session.
pub struct DistributedQueryStream {
    chunk_rx: tokio::sync::mpsc::Receiver<SchedulerResult<DataChunk>>,
    query_id: QueryId,
    /// Lets the stream deregister its query from the manager when dropped.
    query_execution_info: QueryExecutionInfoRef,
}

impl DistributedQueryStream {
    pub fn query_id(&self) -> &QueryId {
        &self.query_id
    }
}

impl Stream for DistributedQueryStream {
    type Item = Result<DataChunk, BoxedError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.chunk_rx.poll_recv(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(chunk) => match chunk {
                Some(chunk_result) => match chunk_result {
                    Ok(chunk) => Poll::Ready(Some(Ok(chunk))),
                    Err(err) => Poll::Ready(Some(Err(Box::new(err)))),
                },
                None => Poll::Ready(None),
            },
        }
    }
}

impl Drop for DistributedQueryStream {
    fn drop(&mut self) {
        // Deregister the query so its `QueryExecution` can be released once the
        // client stops consuming results.
        let mut query_execution_info = self.query_execution_info.write().unwrap();
        query_execution_info.delete_query(&self.query_id);
    }
}
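
// Illustrative sketch (not part of the original file): how a caller might drain a
// `DistributedQueryStream`. `collect_query_results` is a hypothetical helper; real
// callers live in the pgwire layer. Finishing or dropping the stream triggers the
// `Drop` impl above, which deregisters the query.
#[allow(dead_code)]
async fn collect_query_results(
    mut stream: DistributedQueryStream,
) -> Result<Vec<DataChunk>, BoxedError> {
    use futures::StreamExt;

    let mut chunks = Vec::new();
    // Each item is a `Result<DataChunk, BoxedError>`; the stream ends when the
    // producing `QueryExecution` closes the channel.
    while let Some(chunk) = stream.next().await {
        chunks.push(chunk?);
    }
    Ok(chunks)
}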

/// Holds the receiving end of a query's result channel until it is turned into a
/// `DistributedQueryStream` for the client.
pub struct QueryResultFetcher {
    task_output_id: TaskOutputId,
    task_host: HostAddress,

    chunk_rx: tokio::sync::mpsc::Receiver<SchedulerResult<DataChunk>>,

    query_id: QueryId,
    query_execution_info: QueryExecutionInfoRef,
}

/// Bookkeeping for the distributed queries currently running on this frontend,
/// keyed by `QueryId`.
#[derive(Clone, Default)]
pub struct QueryExecutionInfo {
    query_execution_map: HashMap<QueryId, Arc<QueryExecution>>,
}

impl QueryExecutionInfo {
    #[cfg(test)]
    pub fn new_from_map(query_execution_map: HashMap<QueryId, Arc<QueryExecution>>) -> Self {
        Self {
            query_execution_map,
        }
    }
}

pub type QueryExecutionInfoRef = Arc<RwLock<QueryExecutionInfo>>;

impl QueryExecutionInfo {
    pub fn add_query(&mut self, query_id: QueryId, query_execution: Arc<QueryExecution>) {
        self.query_execution_map.insert(query_id, query_execution);
    }

    pub fn delete_query(&mut self, query_id: &QueryId) {
        self.query_execution_map.remove(query_id);
    }

    /// Abort all running queries that belong to the given session.
    pub fn abort_queries(&self, session_id: SessionId) {
        for query in self.query_execution_map.values() {
            if query.session_id == session_id {
                // Aborting may take a while, so spawn a task per query rather than
                // awaiting here while the surrounding lock is held.
                let query = query.clone();
                tokio::spawn(async move { query.abort("cancelled by user".to_owned()).await });
            }
        }
    }
}
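
// Illustrative sketch (not part of the original file): reading the shared map
// through a `QueryExecutionInfoRef`. `num_queries_in_session` is a hypothetical
// helper, not an API of this module; it only shows the lock-then-read pattern
// used by `abort_queries` above.
#[allow(dead_code)]
fn num_queries_in_session(info: &QueryExecutionInfoRef, session_id: SessionId) -> usize {
    let info = info.read().unwrap();
    info.query_execution_map
        .values()
        .filter(|query| query.session_id == session_id)
        .count()
}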

/// Manages the scheduling and lifecycle of distributed batch queries on this
/// frontend node.
#[derive(Clone)]
pub struct QueryManager {
    worker_node_manager: WorkerNodeManagerRef,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,
    query_execution_info: QueryExecutionInfoRef,
    pub query_metrics: Arc<DistributedQueryMetrics>,
    /// Per-frontend cap on concurrently running distributed queries, if configured.
    distributed_query_limit: Option<u64>,
    /// Issues one permit per query to enforce `total_distributed_query_limit`.
    distributed_query_semaphore: Option<Arc<tokio::sync::Semaphore>>,
    pub total_distributed_query_limit: Option<u64>,
}

impl QueryManager {
    pub fn new(
        worker_node_manager: WorkerNodeManagerRef,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
        query_metrics: Arc<DistributedQueryMetrics>,
        distributed_query_limit: Option<u64>,
        total_distributed_query_limit: Option<u64>,
    ) -> Self {
        let distributed_query_semaphore = total_distributed_query_limit
            .map(|limit| Arc::new(tokio::sync::Semaphore::new(limit as usize)));
        Self {
            worker_node_manager,
            compute_client_pool,
            catalog_reader,
            query_execution_info: Arc::new(RwLock::new(QueryExecutionInfo::default())),
            query_metrics,
            distributed_query_limit,
            distributed_query_semaphore,
            total_distributed_query_limit,
        }
    }

    /// Acquire a permit from the total distributed query semaphore, if one is
    /// configured. A failed acquisition is reported as reaching the query limit.
    async fn get_permit(&self) -> SchedulerResult<Option<OwnedSemaphorePermit>> {
        match self.distributed_query_semaphore {
            Some(ref semaphore) => {
                let permit = semaphore.clone().acquire_owned().await;
                match permit {
                    Ok(permit) => Ok(Some(permit)),
                    Err(_) => {
                        self.query_metrics.rejected_query_counter.inc();
                        Err(crate::scheduler::SchedulerError::QueryReachLimit(
                            QueryMode::Distributed,
                            self.total_distributed_query_limit
                                .expect("should have distributed query limit"),
                        ))
                    }
                }
            }
            None => Ok(None),
        }
    }

    /// Schedules a distributed query and returns a stream of its result chunks.
    pub async fn schedule(
        &self,
        context: ExecutionContextRef,
        query: Query,
        read_storage_tables: HashSet<TableId>,
    ) -> SchedulerResult<DistributedQueryStream> {
        if let Some(query_limit) = self.distributed_query_limit
            && self.query_metrics.running_query_num.get() as u64 == query_limit
        {
            self.query_metrics.rejected_query_counter.inc();
            return Err(crate::scheduler::SchedulerError::QueryReachLimit(
                QueryMode::Distributed,
                query_limit,
            ));
        }
        let query_id = query.query_id.clone();
        let permit = self.get_permit().await?;
        let query_execution = Arc::new(QueryExecution::new(query, context.session().id(), permit));

        // Register the query so it can be cancelled while it is running.
        context
            .session()
            .env()
            .query_manager()
            .add_query(query_id.clone(), query_execution.clone());

        let pinned_snapshot = context.session().pinned_snapshot();

        let worker_node_manager_reader = WorkerNodeSelector::new(
            self.worker_node_manager.clone(),
            pinned_snapshot.support_barrier_read(),
        );
        // Start execution; on failure, deregister the query before returning the error.
        let query_result_fetcher = query_execution
            .start(
                context.clone(),
                worker_node_manager_reader,
                pinned_snapshot.batch_query_epoch(&read_storage_tables)?,
                self.compute_client_pool.clone(),
                self.catalog_reader.clone(),
                self.query_execution_info.clone(),
                self.query_metrics.clone(),
            )
            .await
            .inspect_err(|_| {
                context
                    .session()
                    .env()
                    .query_manager()
                    .delete_query(&query_id);
            })?;
        Ok(query_result_fetcher.stream_from_channel())
    }

    pub fn cancel_queries_in_session(&self, session_id: SessionId) {
        let query_execution_info = self.query_execution_info.read().unwrap();
        query_execution_info.abort_queries(session_id);
    }

    pub fn add_query(&self, query_id: QueryId, query_execution: Arc<QueryExecution>) {
        let mut query_execution_info = self.query_execution_info.write().unwrap();
        query_execution_info.add_query(query_id, query_execution);
    }

    pub fn delete_query(&self, query_id: &QueryId) {
        let mut query_execution_info = self.query_execution_info.write().unwrap();
        query_execution_info.delete_query(query_id);
    }
}
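
// Illustrative sketch (not part of the original file): the admission-control
// pattern behind `get_permit`, shown standalone. `try_admit` is a hypothetical
// helper. `acquire_owned` ties the permit to an owned guard, so it can be stored
// inside the `QueryExecution` and released only when the query finishes;
// acquisition fails only if the semaphore has been closed.
#[allow(dead_code)]
async fn try_admit(semaphore: Arc<tokio::sync::Semaphore>) -> Option<OwnedSemaphorePermit> {
    semaphore.acquire_owned().await.ok()
}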

impl QueryResultFetcher {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        task_output_id: TaskOutputId,
        task_host: HostAddress,
        chunk_rx: tokio::sync::mpsc::Receiver<SchedulerResult<DataChunk>>,
        query_id: QueryId,
        query_execution_info: QueryExecutionInfoRef,
    ) -> Self {
        Self {
            task_output_id,
            task_host,
            chunk_rx,
            query_id,
            query_execution_info,
        }
    }

    fn stream_from_channel(self) -> DistributedQueryStream {
        DistributedQueryStream {
            chunk_rx: self.chunk_rx,
            query_id: self.query_id,
            query_execution_info: self.query_execution_info,
        }
    }
}

impl Debug for QueryResultFetcher {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QueryResultFetcher")
            .field("task_output_id", &self.task_output_id)
            .field("task_host", &self.task_host)
            .finish()
    }
}