risingwave_frontend/scheduler/
streaming_manager.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt::{Debug, Formatter};
17use std::sync::Arc;
18
19use itertools::Itertools;
20use parking_lot::RwLock;
21use pgwire::pg_server::SessionId;
22use risingwave_pb::meta::cancel_creating_jobs_request::{
23    CreatingJobInfo, CreatingJobInfos, PbJobs,
24};
25use uuid::Uuid;
26
27use crate::catalog::{DatabaseId, SchemaId};
28use crate::meta_client::FrontendMetaClient;
29
/// Identifier of one tracked streaming-job creation.
///
/// A thin wrapper over a string; equality and hashing compare that string, which lets the
/// id be used as a `HashMap` key.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct TaskId {
    pub id: String,
}

impl std::fmt::Display for TaskId {
    /// Renders the id as `TaskId:<id>`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("TaskId:")?;
        f.write_str(&self.id)
    }
}
40
41impl Default for TaskId {
42    fn default() -> Self {
43        Self {
44            id: Uuid::new_v4().to_string(),
45        }
46    }
47}
48
49pub type StreamingJobTrackerRef = Arc<StreamingJobTracker>;
50
51pub struct StreamingJobTracker {
52    creating_streaming_job: RwLock<HashMap<TaskId, CreatingStreamingJobInfo>>,
53    meta_client: Arc<dyn FrontendMetaClient>,
54}
55
56impl StreamingJobTracker {
57    pub fn new(meta_client: Arc<dyn FrontendMetaClient>) -> Self {
58        Self {
59            creating_streaming_job: RwLock::new(HashMap::default()),
60            meta_client,
61        }
62    }
63}
64
65#[derive(Clone, Default)]
66pub struct CreatingStreamingJobInfo {
67    /// Identified by `process_id`, `secret_key`.
68    session_id: SessionId,
69    info: CreatingJobInfo,
70}
71
72impl CreatingStreamingJobInfo {
73    pub fn new(
74        session_id: SessionId,
75        database_id: DatabaseId,
76        schema_id: SchemaId,
77        name: String,
78    ) -> Self {
79        Self {
80            session_id,
81            info: CreatingJobInfo {
82                database_id,
83                schema_id,
84                name,
85            },
86        }
87    }
88}
89
90pub struct StreamingJobGuard<'a> {
91    task_id: TaskId,
92    tracker: &'a StreamingJobTracker,
93}
94
95impl Drop for StreamingJobGuard<'_> {
96    fn drop(&mut self) {
97        self.tracker.delete_job(&self.task_id);
98    }
99}
100
101impl StreamingJobTracker {
102    pub fn guard(&self, task_info: CreatingStreamingJobInfo) -> StreamingJobGuard<'_> {
103        let task_id = TaskId::default();
104        self.add_job(task_id.clone(), task_info);
105        StreamingJobGuard {
106            task_id,
107            tracker: self,
108        }
109    }
110
111    fn add_job(&self, task_id: TaskId, info: CreatingStreamingJobInfo) {
112        self.creating_streaming_job.write().insert(task_id, info);
113    }
114
115    fn delete_job(&self, task_id: &TaskId) {
116        self.creating_streaming_job.write().remove(task_id);
117    }
118
119    pub fn abort_jobs(&self, session_id: SessionId) {
120        let jobs = self
121            .creating_streaming_job
122            .read()
123            .values()
124            .filter(|job| job.session_id == session_id)
125            .cloned()
126            .collect_vec();
127
128        let client = self.meta_client.clone();
129        tokio::spawn(async move {
130            client
131                .cancel_creating_jobs(PbJobs::Infos(CreatingJobInfos {
132                    infos: jobs.into_iter().map(|job| job.info).collect_vec(),
133                }))
134                .await
135        });
136    }
137}