risingwave_frontend/scheduler/distributed/query_manager.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};

use futures::Stream;
use pgwire::pg_server::{BoxedError, Session, SessionId};
use risingwave_batch::worker_manager::worker_node_manager::{
    WorkerNodeManagerRef, WorkerNodeSelector,
};
use risingwave_common::array::DataChunk;
use risingwave_common::session_config::QueryMode;
use risingwave_pb::batch_plan::TaskOutputId;
use risingwave_pb::common::HostAddress;
use risingwave_rpc_client::ComputeClientPoolRef;
use tokio::sync::OwnedSemaphorePermit;

use super::QueryExecution;
use super::stats::DistributedQueryMetrics;
use crate::catalog::catalog_service::CatalogReader;
use crate::scheduler::plan_fragmenter::{Query, QueryId};
use crate::scheduler::{ExecutionContextRef, SchedulerResult};

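/// A [`Stream`] of [`DataChunk`]s produced by a distributed query. Dropping the stream
/// unregisters the corresponding `QueryExecution` from the manager, so resources are
/// released even if the consumer stops polling before the query finishes.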
pub struct DistributedQueryStream {
    chunk_rx: tokio::sync::mpsc::Receiver<SchedulerResult<DataChunk>>,
    // Used for cleaning up `QueryExecution` after all data are polled.
    query_id: QueryId,
    query_execution_info: QueryExecutionInfoRef,
}

impl DistributedQueryStream {
    pub fn query_id(&self) -> &QueryId {
        &self.query_id
    }
}

impl Stream for DistributedQueryStream {
    // TODO(error-handling): use a concrete error type.
    type Item = Result<DataChunk, BoxedError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.chunk_rx.poll_recv(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(chunk) => match chunk {
                Some(chunk_result) => match chunk_result {
                    Ok(chunk) => Poll::Ready(Some(Ok(chunk))),
                    Err(err) => Poll::Ready(Some(Err(Box::new(err)))),
                },
                None => Poll::Ready(None),
            },
        }
    }
}

impl Drop for DistributedQueryStream {
    fn drop(&mut self) {
        // Clear `QueryExecution`. Avoid holding it after execution ends.
        let mut query_execution_info = self.query_execution_info.write().unwrap();
        query_execution_info.delete_query(&self.query_id);
    }
}

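/// Holds the channel receiver for the result chunks of a scheduled query, together with
/// the metadata needed to locate the root task output and to clean up the
/// `QueryExecution` once the results have been consumed.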
pub struct QueryResultFetcher {
    task_output_id: TaskOutputId,
    task_host: HostAddress,

    chunk_rx: tokio::sync::mpsc::Receiver<SchedulerResult<DataChunk>>,

    // `query_id` and `query_execution_info` are used for cleaning up `QueryExecution` after
    // execution.
    query_id: QueryId,
    query_execution_info: QueryExecutionInfoRef,
}

/// [`QueryExecutionInfo`] stores the information needed to track running query executions.
/// Currently, a `QueryExecution` is removed as soon as it finishes execution. Additional
/// fields may be added in the future.
#[derive(Clone, Default)]
pub struct QueryExecutionInfo {
    query_execution_map: HashMap<QueryId, Arc<QueryExecution>>,
}

impl QueryExecutionInfo {
    #[cfg(test)]
    pub fn new_from_map(query_execution_map: HashMap<QueryId, Arc<QueryExecution>>) -> Self {
        Self {
            query_execution_map,
        }
    }
}

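/// Shared, mutable handle to the [`QueryExecutionInfo`] of this frontend, protected by an
/// [`RwLock`].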
pub type QueryExecutionInfoRef = Arc<RwLock<QueryExecutionInfo>>;

impl QueryExecutionInfo {
    pub fn add_query(&mut self, query_id: QueryId, query_execution: Arc<QueryExecution>) {
        self.query_execution_map.insert(query_id, query_execution);
    }

    pub fn delete_query(&mut self, query_id: &QueryId) {
        self.query_execution_map.remove(query_id);
    }

    pub fn abort_queries(&self, session_id: SessionId) {
        for query in self.query_execution_map.values() {
            // `QueryExecutionInfo` might hold queries from different sessions.
            if query.session_id == session_id {
                let query = query.clone();
                // Spawn a task to abort, so this function stays free of await points
                // (it is called while the `RwLock` is held).
                tokio::spawn(async move { query.abort("cancelled by user".to_owned()).await });
            }
        }
    }
}

/// Manages execution of distributed batch queries.
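///
/// Two limits apply: `distributed_query_limit` caps queries per session, while the total
/// limit is enforced with a [`tokio::sync::Semaphore`] whose permit each query holds for
/// its lifetime. A minimal sketch of that semaphore pattern (standalone illustration, not
/// this manager's API):
///
/// ```ignore
/// let semaphore = Arc::new(tokio::sync::Semaphore::new(2)); // at most 2 queries in flight
/// let permit = semaphore.clone().acquire_owned().await?; // waits while the limit is reached
/// // ... run the query while holding `permit` ...
/// drop(permit); // releases the slot for the next query
/// ```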
#[derive(Clone)]
pub struct QueryManager {
    worker_node_manager: WorkerNodeManagerRef,
    compute_client_pool: ComputeClientPoolRef,
    catalog_reader: CatalogReader,
    query_execution_info: QueryExecutionInfoRef,
    pub query_metrics: Arc<DistributedQueryMetrics>,
    /// Limit on the number of running distributed queries per session.
    distributed_query_limit: Option<u64>,
    /// Semaphore limiting the total number of concurrent distributed queries.
    distributed_query_semaphore: Option<Arc<tokio::sync::Semaphore>>,
    /// Total number of distributed queries permitted.
    pub total_distributed_query_limit: Option<u64>,
}

impl QueryManager {
    pub fn new(
        worker_node_manager: WorkerNodeManagerRef,
        compute_client_pool: ComputeClientPoolRef,
        catalog_reader: CatalogReader,
        query_metrics: Arc<DistributedQueryMetrics>,
        distributed_query_limit: Option<u64>,
        total_distributed_query_limit: Option<u64>,
    ) -> Self {
        let distributed_query_semaphore = total_distributed_query_limit
            .map(|limit| Arc::new(tokio::sync::Semaphore::new(limit as usize)));
        Self {
            worker_node_manager,
            compute_client_pool,
            catalog_reader,
            query_execution_info: Arc::new(RwLock::new(QueryExecutionInfo::default())),
            query_metrics,
            distributed_query_limit,
            distributed_query_semaphore,
            total_distributed_query_limit,
        }
    }

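    /// Acquires a permit from the total-query semaphore, if one is configured. Waits while
    /// the total limit is reached; returns `Ok(None)` when no total limit is set.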
    async fn get_permit(&self) -> SchedulerResult<Option<OwnedSemaphorePermit>> {
        match self.distributed_query_semaphore {
            Some(ref semaphore) => {
                let permit = semaphore.clone().acquire_owned().await;
                match permit {
                    Ok(permit) => Ok(Some(permit)),
                    Err(_) => {
                        // `acquire_owned` only fails if the semaphore has been closed.
                        self.query_metrics.rejected_query_counter.inc();
                        Err(crate::scheduler::SchedulerError::QueryReachLimit(
                            QueryMode::Distributed,
                            self.total_distributed_query_limit
                                .expect("should have distributed query limit"),
                        ))
                    }
                }
            }
            None => Ok(None),
        }
    }

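    /// Schedules a distributed query and returns a [`DistributedQueryStream`] of result
    /// chunks. A minimal usage sketch (hypothetical; assumes the caller has already built
    /// an `ExecutionContextRef` named `context` and a fragmented `query`, and that `next`
    /// from `futures::StreamExt` is in scope):
    ///
    /// ```ignore
    /// let mut stream = query_manager.schedule(context, query).await?;
    /// while let Some(chunk) = stream.next().await {
    ///     let chunk: DataChunk = chunk?; // propagate scheduler errors
    ///     // ... consume the chunk ...
    /// }
    /// ```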
    pub async fn schedule(
        &self,
        context: ExecutionContextRef,
        mut query: Query,
    ) -> SchedulerResult<DistributedQueryStream> {
        // TODO: if there's no table scan, we don't need to acquire snapshot.
        let pinned_snapshot = context.session().pinned_snapshot();
        pinned_snapshot.fill_batch_query_epoch(&mut query)?;

        if let Some(query_limit) = self.distributed_query_limit
            && self.query_metrics.running_query_num.get() as u64 == query_limit
        {
            self.query_metrics.rejected_query_counter.inc();
            return Err(crate::scheduler::SchedulerError::QueryReachLimit(
                QueryMode::Distributed,
                query_limit,
            ));
        }
        let query_id = query.query_id.clone();
        let permit = self.get_permit().await?;
        let query_execution = Arc::new(QueryExecution::new(query, context.session().id(), permit));

        // Register the query with the manager before execution begins.
        context
            .session()
            .env()
            .query_manager()
            .add_query(query_id.clone(), query_execution.clone());

        let worker_node_manager_reader = WorkerNodeSelector::new(
            self.worker_node_manager.clone(),
            pinned_snapshot.support_barrier_read(),
        );

        // Start the execution of the query.
        let query_result_fetcher = query_execution
            .start(
                context.clone(),
                worker_node_manager_reader,
                self.compute_client_pool.clone(),
                self.catalog_reader.clone(),
                self.query_execution_info.clone(),
                self.query_metrics.clone(),
            )
            .await
            .inspect_err(|_| {
                // Clean up the query execution on error.
                context
                    .session()
                    .env()
                    .query_manager()
                    .delete_query(&query_id);
            })?;
        Ok(query_result_fetcher.stream_from_channel())
    }

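    /// Aborts all running distributed queries that belong to the given session, e.g. when
    /// the client cancels a request or the session is closed.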
    pub fn cancel_queries_in_session(&self, session_id: SessionId) {
        let query_execution_info = self.query_execution_info.read().unwrap();
        query_execution_info.abort_queries(session_id);
    }

    pub fn add_query(&self, query_id: QueryId, query_execution: Arc<QueryExecution>) {
        let mut query_execution_info = self.query_execution_info.write().unwrap();
        query_execution_info.add_query(query_id, query_execution);
    }

    pub fn delete_query(&self, query_id: &QueryId) {
        let mut query_execution_info = self.query_execution_info.write().unwrap();
        query_execution_info.delete_query(query_id);
    }
}

impl QueryResultFetcher {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        task_output_id: TaskOutputId,
        task_host: HostAddress,
        chunk_rx: tokio::sync::mpsc::Receiver<SchedulerResult<DataChunk>>,
        query_id: QueryId,
        query_execution_info: QueryExecutionInfoRef,
    ) -> Self {
        Self {
            task_output_id,
            task_host,
            chunk_rx,
            query_id,
            query_execution_info,
        }
    }

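    /// Converts the fetcher into a [`DistributedQueryStream`], handing over the channel
    /// receiver and the cleanup metadata.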
    fn stream_from_channel(self) -> DistributedQueryStream {
        DistributedQueryStream {
            chunk_rx: self.chunk_rx,
            query_id: self.query_id,
            query_execution_info: self.query_execution_info,
        }
    }
}

impl Debug for QueryResultFetcher {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QueryResultFetcher")
            .field("task_output_id", &self.task_output_id)
            .field("task_host", &self.task_host)
            .finish()
    }
}