risingwave_frontend/scheduler/distributed/query_manager.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{HashMap, HashSet};
use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};

use futures::Stream;
use pgwire::pg_server::{BoxedError, Session, SessionId};
use risingwave_batch::worker_manager::worker_node_manager::{
    WorkerNodeManagerRef, WorkerNodeSelector,
};
use risingwave_common::array::DataChunk;
use risingwave_common::session_config::QueryMode;
use risingwave_pb::batch_plan::TaskOutputId;
use risingwave_pb::common::HostAddress;
use risingwave_rpc_client::ComputeClientPoolRef;
use tokio::sync::OwnedSemaphorePermit;

use super::QueryExecution;
use super::stats::DistributedQueryMetrics;
use crate::catalog::TableId;
use crate::catalog::catalog_service::CatalogReader;
use crate::scheduler::plan_fragmenter::{Query, QueryId};
use crate::scheduler::{ExecutionContextRef, SchedulerResult};

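/// A stream of [`DataChunk`]s produced by a distributed query.
///
/// Dropping the stream removes the backing `QueryExecution` from
/// [`QueryExecutionInfo`], so a partially consumed query is still cleaned up.
///
/// A minimal consumption sketch (the `query_manager` and other values are
/// assumed to come from a live frontend session, hence `ignore`):
///
/// ```ignore
/// use futures::StreamExt;
///
/// let mut stream = query_manager.schedule(context, query, tables).await?;
/// while let Some(chunk) = stream.next().await {
///     // Each item is a `Result<DataChunk, BoxedError>`.
///     let chunk = chunk?;
///     println!("received {} rows", chunk.cardinality());
/// }
/// // Dropping `stream` here removes the query from `QueryExecutionInfo`.
/// ```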
pub struct DistributedQueryStream {
    chunk_rx: tokio::sync::mpsc::Receiver<SchedulerResult<DataChunk>>,
    // Used to clean up the `QueryExecution` once the stream is fully consumed
    // or dropped.
    query_id: QueryId,
    query_execution_info: QueryExecutionInfoRef,
}

impl DistributedQueryStream {
    pub fn query_id(&self) -> &QueryId {
        &self.query_id
    }
}

impl Stream for DistributedQueryStream {
    // TODO(error-handling): use a concrete error type.
    type Item = Result<DataChunk, BoxedError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.chunk_rx.poll_recv(cx) {
            Poll::Pending => Poll::Pending,
            // Channel closed: all chunks are drained, so the stream ends.
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Ready(Some(Ok(chunk))) => Poll::Ready(Some(Ok(chunk))),
            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(Box::new(err)))),
        }
    }
}

impl Drop for DistributedQueryStream {
    fn drop(&mut self) {
        // Remove the `QueryExecution` so it is not kept alive after the
        // stream (and thus the query) is finished or cancelled.
        let mut query_execution_info = self.query_execution_info.write().unwrap();
        query_execution_info.delete_query(&self.query_id);
    }
}

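/// Returned by `QueryExecution::start` once the root stage is scheduled. It
/// identifies the root task's output and owns the channel on which result
/// chunks arrive; `stream_from_channel` converts it into the
/// [`DistributedQueryStream`] that is handed to the caller.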
pub struct QueryResultFetcher {
    task_output_id: TaskOutputId,
    task_host: HostAddress,

    chunk_rx: tokio::sync::mpsc::Receiver<SchedulerResult<DataChunk>>,

    // `query_id` and `query_execution_info` are used for cleaning up the
    // `QueryExecution` after execution.
    query_id: QueryId,
    query_execution_info: QueryExecutionInfoRef,
}

/// [`QueryExecutionInfo`] stores the information needed to manage query
/// executions. Currently, a `QueryExecution` is removed as soon as it finishes
/// execution; additional fields may be added in the future.
#[derive(Clone, Default)]
pub struct QueryExecutionInfo {
    query_execution_map: HashMap<QueryId, Arc<QueryExecution>>,
}

impl QueryExecutionInfo {
    #[cfg(test)]
    pub fn new_from_map(query_execution_map: HashMap<QueryId, Arc<QueryExecution>>) -> Self {
        Self {
            query_execution_map,
        }
    }
}

pub type QueryExecutionInfoRef = Arc<RwLock<QueryExecutionInfo>>;

impl QueryExecutionInfo {
    pub fn add_query(&mut self, query_id: QueryId, query_execution: Arc<QueryExecution>) {
        self.query_execution_map.insert(query_id, query_execution);
    }

    pub fn delete_query(&mut self, query_id: &QueryId) {
        self.query_execution_map.remove(query_id);
    }

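    /// Aborts all queries that belong to `session_id`. The abort is spawned
    /// onto the runtime, so this method never awaits; callers may invoke it
    /// while holding the `RwLock` around this struct.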
    pub fn abort_queries(&self, session_id: SessionId) {
        for query in self.query_execution_map.values() {
            // `QueryExecutionInfo` may hold queries from several sessions.
            if query.session_id == session_id {
                let query = query.clone();
                // Spawn a task to perform the abort, so this function has no
                // await point while the caller holds the lock.
                tokio::spawn(async move { query.abort("cancelled by user".to_owned()).await });
            }
        }
    }
}

/// Manages the execution of distributed batch queries.
#[derive(Clone)]
pub struct QueryManager {
    worker_node_manager: WorkerNodeManagerRef,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,
    query_execution_info: QueryExecutionInfoRef,
    pub query_metrics: Arc<DistributedQueryMetrics>,
    /// Limit on concurrently running distributed queries, checked against the
    /// `running_query_num` metric when scheduling.
    distributed_query_limit: Option<u64>,
    /// Semaphore enforcing `total_distributed_query_limit`.
    distributed_query_semaphore: Option<Arc<tokio::sync::Semaphore>>,
    /// Total number of distributed queries allowed to run concurrently.
    pub total_distributed_query_limit: Option<u64>,
}

impl QueryManager {
    pub fn new(
        worker_node_manager: WorkerNodeManagerRef,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
        query_metrics: Arc<DistributedQueryMetrics>,
        distributed_query_limit: Option<u64>,
        total_distributed_query_limit: Option<u64>,
    ) -> Self {
        let distributed_query_semaphore = total_distributed_query_limit
            .map(|limit| Arc::new(tokio::sync::Semaphore::new(limit as usize)));
        Self {
            worker_node_manager,
            compute_client_pool,
            catalog_reader,
            query_execution_info: Arc::new(RwLock::new(QueryExecutionInfo::default())),
            query_metrics,
            distributed_query_limit,
            distributed_query_semaphore,
            total_distributed_query_limit,
        }
    }

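    /// Acquires a permit from the shared distributed-query semaphore, waiting
    /// if the total limit has been reached. Returns `Ok(None)` when no total
    /// limit is configured. The permit is handed to the `QueryExecution` and
    /// released on drop, freeing a slot for the next query.
    ///
    /// The owned-permit pattern in isolation (a standalone sketch of the
    /// underlying `tokio` API, not this exact function):
    ///
    /// ```ignore
    /// use std::sync::Arc;
    ///
    /// let semaphore = Arc::new(tokio::sync::Semaphore::new(2)); // limit = 2
    /// let permit = semaphore.clone().acquire_owned().await?; // waits at the limit
    /// tokio::spawn(async move {
    ///     // ... run the query ...
    ///     drop(permit); // releases the slot
    /// });
    /// ```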
    async fn get_permit(&self) -> SchedulerResult<Option<OwnedSemaphorePermit>> {
        match self.distributed_query_semaphore {
            Some(ref semaphore) => {
                let permit = semaphore.clone().acquire_owned().await;
                match permit {
                    Ok(permit) => Ok(Some(permit)),
                    Err(_) => {
                        // `acquire_owned` only fails if the semaphore is closed;
                        // report it as hitting the distributed query limit.
                        self.query_metrics.rejected_query_counter.inc();
                        Err(crate::scheduler::SchedulerError::QueryReachLimit(
                            QueryMode::Distributed,
                            self.total_distributed_query_limit
                                .expect("should have distributed query limit"),
                        ))
                    }
                }
            }
            None => Ok(None),
        }
    }

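    /// Schedules a distributed batch query: enforces the query limits, registers
    /// a `QueryExecution`, starts it against the session's pinned snapshot, and
    /// returns the stream of result chunks.
    ///
    /// A hedged usage sketch (`context`, `query`, and `read_storage_tables` are
    /// assumed to come from the session and the plan fragmenter):
    ///
    /// ```ignore
    /// let stream = query_manager
    ///     .schedule(context, query, read_storage_tables)
    ///     .await?;
    /// // Consume with `futures::StreamExt`; see `DistributedQueryStream`.
    /// ```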
    pub async fn schedule(
        &self,
        context: ExecutionContextRef,
        query: Query,
        read_storage_tables: HashSet<TableId>,
    ) -> SchedulerResult<DistributedQueryStream> {
        if let Some(query_limit) = self.distributed_query_limit
            // Use `>=` so the limit still holds if the gauge ever overshoots.
            && self.query_metrics.running_query_num.get() as u64 >= query_limit
        {
            self.query_metrics.rejected_query_counter.inc();
            return Err(crate::scheduler::SchedulerError::QueryReachLimit(
                QueryMode::Distributed,
                query_limit,
            ));
        }
        let query_id = query.query_id.clone();
        let permit = self.get_permit().await?;
        let query_execution = Arc::new(QueryExecution::new(query, context.session().id(), permit));

        // Register the query execution up front so it can be aborted while running.
        context
            .session()
            .env()
            .query_manager()
            .add_query(query_id.clone(), query_execution.clone());

        // TODO: if there's no table scan, we don't need to acquire a snapshot.
        let pinned_snapshot = context.session().pinned_snapshot();

        let worker_node_manager_reader = WorkerNodeSelector::new(
            self.worker_node_manager.clone(),
            pinned_snapshot.support_barrier_read(),
        );
        // Start executing the query.
        let query_result_fetcher = query_execution
            .start(
                context.clone(),
                worker_node_manager_reader,
                pinned_snapshot.batch_query_epoch(&read_storage_tables)?,
                self.compute_client_pool.clone(),
                self.catalog_reader.clone(),
                self.query_execution_info.clone(),
                self.query_metrics.clone(),
            )
            .await
            .inspect_err(|_| {
                // Clean up the registered query execution if scheduling fails.
                context
                    .session()
                    .env()
                    .query_manager()
                    .delete_query(&query_id);
            })?;
        Ok(query_result_fetcher.stream_from_channel())
    }

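    /// Cancels every query still running in the given session, e.g. when the
    /// client sends a cancel request or the session is closed.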
    pub fn cancel_queries_in_session(&self, session_id: SessionId) {
        let query_execution_info = self.query_execution_info.read().unwrap();
        query_execution_info.abort_queries(session_id);
    }

    pub fn add_query(&self, query_id: QueryId, query_execution: Arc<QueryExecution>) {
        let mut query_execution_info = self.query_execution_info.write().unwrap();
        query_execution_info.add_query(query_id, query_execution);
    }

    pub fn delete_query(&self, query_id: &QueryId) {
        let mut query_execution_info = self.query_execution_info.write().unwrap();
        query_execution_info.delete_query(query_id);
    }
}

impl QueryResultFetcher {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        task_output_id: TaskOutputId,
        task_host: HostAddress,
        chunk_rx: tokio::sync::mpsc::Receiver<SchedulerResult<DataChunk>>,
        query_id: QueryId,
        query_execution_info: QueryExecutionInfoRef,
    ) -> Self {
        Self {
            task_output_id,
            task_host,
            chunk_rx,
            query_id,
            query_execution_info,
        }
    }

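    /// Consumes the fetcher, moving the chunk channel and the cleanup handles
    /// into the user-facing [`DistributedQueryStream`].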
    fn stream_from_channel(self) -> DistributedQueryStream {
        DistributedQueryStream {
            chunk_rx: self.chunk_rx,
            query_id: self.query_id,
            query_execution_info: self.query_execution_info,
        }
    }
}

impl Debug for QueryResultFetcher {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QueryResultFetcher")
            .field("task_output_id", &self.task_output_id)
            .field("task_host", &self.task_host)
            .finish()
    }
}