risingwave_common_secret/
secret_manager.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16use std::fs::File;
17use std::io::Write;
18use std::path::PathBuf;
19
20use anyhow::{Context, anyhow};
21use parking_lot::RwLock;
22use parking_lot::lock_api::RwLockReadGuard;
23use prost::Message;
24use risingwave_pb::catalog::PbSecret;
25use risingwave_pb::secret::PbSecretRef;
26use risingwave_pb::secret::secret_ref::RefAsType;
27use thiserror_ext::AsReport;
28
29use super::SecretId;
30use super::error::{SecretError, SecretResult};
31
32static INSTANCE: std::sync::OnceLock<LocalSecretManager> = std::sync::OnceLock::new();
33
34#[derive(Debug)]
35pub struct LocalSecretManager {
36    secrets: RwLock<HashMap<SecretId, Vec<u8>>>,
37    /// The local directory used to write secrets into file, so that it can be passed into some libararies
38    secret_file_dir: PathBuf,
39}
40
41impl LocalSecretManager {
42    /// Initialize the secret manager with the given temp file path, cluster id, and encryption key.
43    /// # Panics
44    /// Panics if fail to create the secret file directory.
45    pub fn init(temp_file_dir: String, cluster_id: String, worker_id: u32) {
46        // use `get_or_init` to handle concurrent initialization in single node mode.
47        INSTANCE.get_or_init(|| {
48            let secret_file_dir = PathBuf::from(temp_file_dir)
49                .join(cluster_id)
50                .join(worker_id.to_string());
51            std::fs::remove_dir_all(&secret_file_dir).ok();
52
53            // This will cause file creation conflict in simulation tests.
54            // Should skip testing secret files in simulation tests.
55            #[cfg(not(madsim))]
56            std::fs::create_dir_all(&secret_file_dir).unwrap();
57
58            Self {
59                secrets: RwLock::new(HashMap::new()),
60                secret_file_dir,
61            }
62        });
63    }
64
65    /// Get the global secret manager instance.
66    /// # Panics
67    /// Panics if the secret manager is not initialized.
68    pub fn global() -> &'static LocalSecretManager {
69        // Initialize the secret manager for unit tests.
70        #[cfg(debug_assertions)]
71        LocalSecretManager::init("./tmp".to_owned(), "test_cluster".to_owned(), 0);
72
73        INSTANCE.get().unwrap()
74    }
75
76    pub fn add_secret(&self, secret_id: SecretId, secret: Vec<u8>) {
77        let mut secret_guard = self.secrets.write();
78        if secret_guard.insert(secret_id, secret).is_some() {
79            tracing::error!(
80                secret_id = secret_id,
81                "adding a secret but it already exists, overwriting it"
82            );
83        };
84    }
85
86    pub fn update_secret(&self, secret_id: SecretId, secret: Vec<u8>) {
87        let mut secret_guard = self.secrets.write();
88        if secret_guard.insert(secret_id, secret).is_none() {
89            tracing::error!(
90                secret_id = secret_id,
91                "updating a secret but it does not exist, adding it"
92            );
93        }
94        self.remove_secret_file_if_exist(&secret_id);
95    }
96
97    pub fn init_secrets(&self, secrets: Vec<PbSecret>) {
98        let mut secret_guard = self.secrets.write();
99        // Reset the secrets
100        secret_guard.clear();
101        // Error should only occurs when running simulation tests when we have multiple nodes
102        // in 1 process and can fail .
103        std::fs::remove_dir_all(&self.secret_file_dir)
104            .inspect_err(|e| {
105                tracing::error!(
106            error = %e.as_report(),
107            path = %self.secret_file_dir.to_string_lossy(),
108            "Failed to remove secret directory")
109            })
110            .ok();
111
112        #[cfg(not(madsim))]
113        std::fs::create_dir_all(&self.secret_file_dir).unwrap();
114
115        for secret in secrets {
116            secret_guard.insert(secret.id, secret.value);
117        }
118    }
119
120    pub fn get_secret(&self, secret_id: SecretId) -> Option<Vec<u8>> {
121        let secret_guard = self.secrets.read();
122        secret_guard.get(&secret_id).cloned()
123    }
124
125    pub fn remove_secret(&self, secret_id: SecretId) {
126        let mut secret_guard = self.secrets.write();
127        secret_guard.remove(&secret_id);
128        self.remove_secret_file_if_exist(&secret_id);
129    }
130
131    pub fn fill_secrets(
132        &self,
133        mut options: BTreeMap<String, String>,
134        secret_refs: BTreeMap<String, PbSecretRef>,
135    ) -> SecretResult<BTreeMap<String, String>> {
136        let secret_guard = self.secrets.read();
137        for (option_key, secret_ref) in secret_refs {
138            let path_str = self.fill_secret_inner(secret_ref, &secret_guard)?;
139            options.insert(option_key, path_str);
140        }
141        Ok(options)
142    }
143
144    pub fn fill_secret(&self, secret_ref: PbSecretRef) -> SecretResult<String> {
145        let secret_guard: RwLockReadGuard<'_, parking_lot::RawRwLock, HashMap<u32, Vec<u8>>> =
146            self.secrets.read();
147        self.fill_secret_inner(secret_ref, &secret_guard)
148    }
149
150    fn fill_secret_inner(
151        &self,
152        secret_ref: PbSecretRef,
153        secret_guard: &RwLockReadGuard<'_, parking_lot::RawRwLock, HashMap<u32, Vec<u8>>>,
154    ) -> SecretResult<String> {
155        let secret_id = secret_ref.secret_id;
156        let pb_secret_bytes = secret_guard
157            .get(&secret_id)
158            .ok_or(SecretError::ItemNotFound(secret_id))?;
159        let secret_value_bytes = Self::get_secret_value(pb_secret_bytes)?;
160        match secret_ref.ref_as() {
161            RefAsType::Text => {
162                // We converted the secret string from sql to bytes using `as_bytes` in frontend.
163                // So use `from_utf8` here to convert it back to string.
164                Ok(String::from_utf8(secret_value_bytes.clone())?)
165            }
166            RefAsType::File => {
167                let path_str =
168                    self.get_or_init_secret_file(secret_id, secret_value_bytes.clone())?;
169                Ok(path_str)
170            }
171            RefAsType::Unspecified => Err(SecretError::UnspecifiedRefType(secret_id)),
172        }
173    }
174
175    /// Get the secret file for the given secret id and return the path string. If the file does not exist, create it.
176    /// WARNING: This method should be called only when the secret manager is locked.
177    fn get_or_init_secret_file(
178        &self,
179        secret_id: SecretId,
180        secret_bytes: Vec<u8>,
181    ) -> SecretResult<String> {
182        let path = self.secret_file_dir.join(secret_id.to_string());
183        if !path.exists() {
184            let mut file = File::create(&path)?;
185            file.write_all(&secret_bytes)?;
186            file.sync_all()?;
187        }
188        Ok(path.to_string_lossy().to_string())
189    }
190
191    /// WARNING: This method should be called only when the secret manager is locked.
192    fn remove_secret_file_if_exist(&self, secret_id: &SecretId) {
193        let path = self.secret_file_dir.join(secret_id.to_string());
194        if path.exists() {
195            std::fs::remove_file(&path)
196                .inspect_err(|e| {
197                    tracing::error!(
198                error = %e.as_report(),
199                path = %path.to_string_lossy(),
200                "Failed to remove secret file")
201                })
202                .ok();
203        }
204    }
205
206    fn get_secret_value(pb_secret_bytes: &[u8]) -> SecretResult<Vec<u8>> {
207        let secret_value = match Self::get_pb_secret_backend(pb_secret_bytes)? {
208            risingwave_pb::secret::secret::SecretBackend::Meta(backend) => backend.value.clone(),
209            risingwave_pb::secret::secret::SecretBackend::HashicorpVault(_) => {
210                return Err(anyhow!("hashicorp_vault backend is not implemented yet").into());
211            }
212        };
213        Ok(secret_value)
214    }
215
216    /// Get the secret backend from the given decrypted secret bytes.
217    pub fn get_pb_secret_backend(
218        pb_secret_bytes: &[u8],
219    ) -> SecretResult<risingwave_pb::secret::secret::SecretBackend> {
220        let pb_secret = risingwave_pb::secret::Secret::decode(pb_secret_bytes)
221            .context("failed to decode secret")?;
222        Ok(pb_secret.get_secret_backend().unwrap().clone())
223    }
224}