risingwave_common_secret/
secret_manager.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16use std::fs::File;
17use std::io::Write;
18use std::path::PathBuf;
19
20use anyhow::Context;
21use parking_lot::RwLock;
22use parking_lot::lock_api::RwLockReadGuard;
23use prost::Message;
24use risingwave_pb::catalog::PbSecret;
25use risingwave_pb::id::WorkerId;
26use risingwave_pb::secret::PbSecretRef;
27use risingwave_pb::secret::secret_ref::RefAsType;
28use thiserror_ext::AsReport;
29use tokio::runtime::Handle;
30use tokio::task;
31
32use super::SecretId;
33use super::error::{SecretError, SecretResult};
34use super::vault_client::{HashiCorpVaultClient, HashiCorpVaultConfig};
35
36static INSTANCE: std::sync::OnceLock<LocalSecretManager> = std::sync::OnceLock::new();
37
38#[derive(Debug)]
39pub struct LocalSecretManager {
40    secrets: RwLock<HashMap<SecretId, Vec<u8>>>,
41    /// The local directory used to write secrets into file, so that it can be passed into some libraries
42    secret_file_dir: PathBuf,
43}
44
45impl LocalSecretManager {
46    /// Initialize the secret manager with the given temp file path, cluster id, and encryption key.
47    /// # Panics
48    /// Panics if fail to create the secret file directory.
49    pub fn init(temp_file_dir: String, cluster_id: String, worker_id: WorkerId) {
50        // use `get_or_init` to handle concurrent initialization in single node mode.
51        INSTANCE.get_or_init(|| {
52            let secret_file_dir = PathBuf::from(temp_file_dir)
53                .join(cluster_id)
54                .join(worker_id.to_string());
55            std::fs::remove_dir_all(&secret_file_dir).ok();
56
57            // This will cause file creation conflict in simulation tests.
58            // Should skip testing secret files in simulation tests.
59            #[cfg(not(madsim))]
60            std::fs::create_dir_all(&secret_file_dir).unwrap();
61
62            Self {
63                secrets: RwLock::new(HashMap::new()),
64                secret_file_dir,
65            }
66        });
67    }
68
69    /// Get the global secret manager instance.
70    /// # Panics
71    /// Panics if the secret manager is not initialized.
72    pub fn global() -> &'static LocalSecretManager {
73        // Initialize the secret manager for unit tests.
74        #[cfg(debug_assertions)]
75        LocalSecretManager::init("./tmp".to_owned(), "test_cluster".to_owned(), 0.into());
76
77        INSTANCE.get().unwrap()
78    }
79
80    pub fn add_secret(&self, secret_id: SecretId, secret: Vec<u8>) {
81        let mut secret_guard = self.secrets.write();
82        if secret_guard.insert(secret_id, secret).is_some() {
83            tracing::error!(
84                secret_id = %secret_id,
85                "adding a secret but it already exists, overwriting it"
86            );
87        };
88    }
89
90    pub fn update_secret(&self, secret_id: SecretId, secret: Vec<u8>) {
91        let mut secret_guard = self.secrets.write();
92        if secret_guard.insert(secret_id, secret).is_none() {
93            tracing::error!(
94                secret_id = %secret_id,
95                "updating a secret but it does not exist, adding it"
96            );
97        }
98        self.remove_secret_file_if_exist(&secret_id);
99    }
100
101    pub fn init_secrets(&self, secrets: Vec<PbSecret>) {
102        let mut secret_guard = self.secrets.write();
103        // Reset the secrets
104        secret_guard.clear();
105        // Error should only occurs when running simulation tests when we have multiple nodes
106        // in 1 process and can fail .
107        std::fs::remove_dir_all(&self.secret_file_dir)
108            .inspect_err(|e| {
109                tracing::error!(
110            error = %e.as_report(),
111            path = %self.secret_file_dir.to_string_lossy(),
112            "Failed to remove secret directory")
113            })
114            .ok();
115
116        #[cfg(not(madsim))]
117        std::fs::create_dir_all(&self.secret_file_dir).unwrap();
118
119        for secret in secrets {
120            secret_guard.insert(secret.id, secret.value);
121        }
122    }
123
124    pub fn get_secret(&self, secret_id: SecretId) -> Option<Vec<u8>> {
125        let secret_guard = self.secrets.read();
126        secret_guard.get(&secret_id).cloned()
127    }
128
129    pub fn remove_secret(&self, secret_id: SecretId) {
130        let mut secret_guard = self.secrets.write();
131        secret_guard.remove(&secret_id);
132        self.remove_secret_file_if_exist(&secret_id);
133    }
134
135    pub fn fill_secrets(
136        &self,
137        mut options: BTreeMap<String, String>,
138        secret_refs: BTreeMap<String, PbSecretRef>,
139    ) -> SecretResult<BTreeMap<String, String>> {
140        let secret_guard = self.secrets.read();
141        for (option_key, secret_ref) in secret_refs {
142            let path_str = self.fill_secret_inner(secret_ref, &secret_guard)?;
143            options.insert(option_key, path_str);
144        }
145        Ok(options)
146    }
147
148    pub fn fill_secret(&self, secret_ref: PbSecretRef) -> SecretResult<String> {
149        let secret_guard: RwLockReadGuard<'_, parking_lot::RawRwLock, HashMap<SecretId, Vec<u8>>> =
150            self.secrets.read();
151        self.fill_secret_inner(secret_ref, &secret_guard)
152    }
153
154    fn fill_secret_inner(
155        &self,
156        secret_ref: PbSecretRef,
157        secret_guard: &RwLockReadGuard<'_, parking_lot::RawRwLock, HashMap<SecretId, Vec<u8>>>,
158    ) -> SecretResult<String> {
159        let secret_id = secret_ref.secret_id;
160        let pb_secret_bytes = secret_guard
161            .get(&secret_id)
162            .ok_or(SecretError::ItemNotFound(secret_id))?;
163        let secret_value_bytes = Self::get_secret_value(pb_secret_bytes)?;
164        match secret_ref.ref_as() {
165            RefAsType::Text => {
166                // We converted the secret string from sql to bytes using `as_bytes` in frontend.
167                // So use `from_utf8` here to convert it back to string.
168                Ok(String::from_utf8(secret_value_bytes)?)
169            }
170            RefAsType::File => {
171                let path_str = self.get_or_init_secret_file(secret_id, secret_value_bytes)?;
172                Ok(path_str)
173            }
174            RefAsType::Unspecified => Err(SecretError::UnspecifiedRefType(secret_id)),
175        }
176    }
177
178    /// Get the secret file for the given secret id and return the path string. If the file does not exist, create it.
179    /// WARNING: This method should be called only when the secret manager is locked.
180    fn get_or_init_secret_file(
181        &self,
182        secret_id: SecretId,
183        secret_bytes: Vec<u8>,
184    ) -> SecretResult<String> {
185        let path = self.secret_file_dir.join(secret_id.to_string());
186        if !path.exists() {
187            let mut file = File::create(&path)?;
188            file.write_all(&secret_bytes)?;
189            file.sync_all()?;
190        }
191        Ok(path.to_string_lossy().to_string())
192    }
193
194    /// WARNING: This method should be called only when the secret manager is locked.
195    fn remove_secret_file_if_exist(&self, secret_id: &SecretId) {
196        let path = self.secret_file_dir.join(secret_id.to_string());
197        if path.exists() {
198            std::fs::remove_file(&path)
199                .inspect_err(|e| {
200                    tracing::error!(
201                error = %e.as_report(),
202                path = %path.to_string_lossy(),
203                "Failed to remove secret file")
204                })
205                .ok();
206        }
207    }
208
209    #[cfg_or_panic::cfg_or_panic(not(madsim))]
210    fn get_secret_value(pb_secret_bytes: &[u8]) -> SecretResult<Vec<u8>> {
211        let secret_value = match Self::get_pb_secret_backend(pb_secret_bytes)? {
212            risingwave_pb::secret::secret::SecretBackend::Meta(backend) => backend.value,
213            risingwave_pb::secret::secret::SecretBackend::HashicorpVault(vault_backend) => {
214                let config = HashiCorpVaultConfig::from_protobuf(&vault_backend)?;
215                let client = HashiCorpVaultClient::new(config)?;
216
217                task::block_in_place(move || {
218                    Handle::current().block_on(async move { client.get_secret().await })
219                })?
220            }
221        };
222        Ok(secret_value)
223    }
224
225    /// Get the secret backend from the given decrypted secret bytes.
226    pub fn get_pb_secret_backend(
227        pb_secret_bytes: &[u8],
228    ) -> SecretResult<risingwave_pb::secret::secret::SecretBackend> {
229        let pb_secret = risingwave_pb::secret::Secret::decode(pb_secret_bytes)
230            .context("failed to decode secret")?;
231        Ok(pb_secret.get_secret_backend().unwrap().clone())
232    }
233}