risingwave_common_secret/
secret_manager.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashMap};
16use std::fs::File;
17use std::io::Write;
18use std::path::PathBuf;
19
20use anyhow::Context;
21use parking_lot::RwLock;
22use parking_lot::lock_api::RwLockReadGuard;
23use prost::Message;
24use risingwave_pb::catalog::PbSecret;
25use risingwave_pb::secret::PbSecretRef;
26use risingwave_pb::secret::secret_ref::RefAsType;
27use thiserror_ext::AsReport;
28use tokio::runtime::Handle;
29use tokio::task;
30
31use super::SecretId;
32use super::error::{SecretError, SecretResult};
33use super::vault_client::{HashiCorpVaultClient, HashiCorpVaultConfig};
34
35static INSTANCE: std::sync::OnceLock<LocalSecretManager> = std::sync::OnceLock::new();
36
37#[derive(Debug)]
38pub struct LocalSecretManager {
39    secrets: RwLock<HashMap<SecretId, Vec<u8>>>,
40    /// The local directory used to write secrets into file, so that it can be passed into some libararies
41    secret_file_dir: PathBuf,
42}
43
44impl LocalSecretManager {
45    /// Initialize the secret manager with the given temp file path, cluster id, and encryption key.
46    /// # Panics
47    /// Panics if fail to create the secret file directory.
48    pub fn init(temp_file_dir: String, cluster_id: String, worker_id: u32) {
49        // use `get_or_init` to handle concurrent initialization in single node mode.
50        INSTANCE.get_or_init(|| {
51            let secret_file_dir = PathBuf::from(temp_file_dir)
52                .join(cluster_id)
53                .join(worker_id.to_string());
54            std::fs::remove_dir_all(&secret_file_dir).ok();
55
56            // This will cause file creation conflict in simulation tests.
57            // Should skip testing secret files in simulation tests.
58            #[cfg(not(madsim))]
59            std::fs::create_dir_all(&secret_file_dir).unwrap();
60
61            Self {
62                secrets: RwLock::new(HashMap::new()),
63                secret_file_dir,
64            }
65        });
66    }
67
68    /// Get the global secret manager instance.
69    /// # Panics
70    /// Panics if the secret manager is not initialized.
71    pub fn global() -> &'static LocalSecretManager {
72        // Initialize the secret manager for unit tests.
73        #[cfg(debug_assertions)]
74        LocalSecretManager::init("./tmp".to_owned(), "test_cluster".to_owned(), 0);
75
76        INSTANCE.get().unwrap()
77    }
78
79    pub fn add_secret(&self, secret_id: SecretId, secret: Vec<u8>) {
80        let mut secret_guard = self.secrets.write();
81        if secret_guard.insert(secret_id, secret).is_some() {
82            tracing::error!(
83                secret_id = secret_id,
84                "adding a secret but it already exists, overwriting it"
85            );
86        };
87    }
88
89    pub fn update_secret(&self, secret_id: SecretId, secret: Vec<u8>) {
90        let mut secret_guard = self.secrets.write();
91        if secret_guard.insert(secret_id, secret).is_none() {
92            tracing::error!(
93                secret_id = secret_id,
94                "updating a secret but it does not exist, adding it"
95            );
96        }
97        self.remove_secret_file_if_exist(&secret_id);
98    }
99
100    pub fn init_secrets(&self, secrets: Vec<PbSecret>) {
101        let mut secret_guard = self.secrets.write();
102        // Reset the secrets
103        secret_guard.clear();
104        // Error should only occurs when running simulation tests when we have multiple nodes
105        // in 1 process and can fail .
106        std::fs::remove_dir_all(&self.secret_file_dir)
107            .inspect_err(|e| {
108                tracing::error!(
109            error = %e.as_report(),
110            path = %self.secret_file_dir.to_string_lossy(),
111            "Failed to remove secret directory")
112            })
113            .ok();
114
115        #[cfg(not(madsim))]
116        std::fs::create_dir_all(&self.secret_file_dir).unwrap();
117
118        for secret in secrets {
119            secret_guard.insert(secret.id, secret.value);
120        }
121    }
122
123    pub fn get_secret(&self, secret_id: SecretId) -> Option<Vec<u8>> {
124        let secret_guard = self.secrets.read();
125        secret_guard.get(&secret_id).cloned()
126    }
127
128    pub fn remove_secret(&self, secret_id: SecretId) {
129        let mut secret_guard = self.secrets.write();
130        secret_guard.remove(&secret_id);
131        self.remove_secret_file_if_exist(&secret_id);
132    }
133
134    pub fn fill_secrets(
135        &self,
136        mut options: BTreeMap<String, String>,
137        secret_refs: BTreeMap<String, PbSecretRef>,
138    ) -> SecretResult<BTreeMap<String, String>> {
139        let secret_guard = self.secrets.read();
140        for (option_key, secret_ref) in secret_refs {
141            let path_str = self.fill_secret_inner(secret_ref, &secret_guard)?;
142            options.insert(option_key, path_str);
143        }
144        Ok(options)
145    }
146
147    pub fn fill_secret(&self, secret_ref: PbSecretRef) -> SecretResult<String> {
148        let secret_guard: RwLockReadGuard<'_, parking_lot::RawRwLock, HashMap<u32, Vec<u8>>> =
149            self.secrets.read();
150        self.fill_secret_inner(secret_ref, &secret_guard)
151    }
152
153    fn fill_secret_inner(
154        &self,
155        secret_ref: PbSecretRef,
156        secret_guard: &RwLockReadGuard<'_, parking_lot::RawRwLock, HashMap<u32, Vec<u8>>>,
157    ) -> SecretResult<String> {
158        let secret_id = secret_ref.secret_id;
159        let pb_secret_bytes = secret_guard
160            .get(&secret_id)
161            .ok_or(SecretError::ItemNotFound(secret_id))?;
162        let secret_value_bytes = Self::get_secret_value(pb_secret_bytes)?;
163        match secret_ref.ref_as() {
164            RefAsType::Text => {
165                // We converted the secret string from sql to bytes using `as_bytes` in frontend.
166                // So use `from_utf8` here to convert it back to string.
167                Ok(String::from_utf8(secret_value_bytes.clone())?)
168            }
169            RefAsType::File => {
170                let path_str =
171                    self.get_or_init_secret_file(secret_id, secret_value_bytes.clone())?;
172                Ok(path_str)
173            }
174            RefAsType::Unspecified => Err(SecretError::UnspecifiedRefType(secret_id)),
175        }
176    }
177
178    /// Get the secret file for the given secret id and return the path string. If the file does not exist, create it.
179    /// WARNING: This method should be called only when the secret manager is locked.
180    fn get_or_init_secret_file(
181        &self,
182        secret_id: SecretId,
183        secret_bytes: Vec<u8>,
184    ) -> SecretResult<String> {
185        let path = self.secret_file_dir.join(secret_id.to_string());
186        if !path.exists() {
187            let mut file = File::create(&path)?;
188            file.write_all(&secret_bytes)?;
189            file.sync_all()?;
190        }
191        Ok(path.to_string_lossy().to_string())
192    }
193
194    /// WARNING: This method should be called only when the secret manager is locked.
195    fn remove_secret_file_if_exist(&self, secret_id: &SecretId) {
196        let path = self.secret_file_dir.join(secret_id.to_string());
197        if path.exists() {
198            std::fs::remove_file(&path)
199                .inspect_err(|e| {
200                    tracing::error!(
201                error = %e.as_report(),
202                path = %path.to_string_lossy(),
203                "Failed to remove secret file")
204                })
205                .ok();
206        }
207    }
208
209    #[cfg_or_panic::cfg_or_panic(not(madsim))]
210    fn get_secret_value(pb_secret_bytes: &[u8]) -> SecretResult<Vec<u8>> {
211        let secret_value = match Self::get_pb_secret_backend(pb_secret_bytes)? {
212            risingwave_pb::secret::secret::SecretBackend::Meta(backend) => backend.value.clone(),
213            risingwave_pb::secret::secret::SecretBackend::HashicorpVault(vault_backend) => {
214                let config = HashiCorpVaultConfig::from_protobuf(&vault_backend)?;
215                let client = HashiCorpVaultClient::new(config)?;
216
217                task::block_in_place(move || {
218                    Handle::current().block_on(async move { client.get_secret().await })
219                })?
220            }
221        };
222        Ok(secret_value)
223    }
224
225    /// Get the secret backend from the given decrypted secret bytes.
226    pub fn get_pb_secret_backend(
227        pb_secret_bytes: &[u8],
228    ) -> SecretResult<risingwave_pb::secret::secret::SecretBackend> {
229        let pb_secret = risingwave_pb::secret::Secret::decode(pb_secret_bytes)
230            .context("failed to decode secret")?;
231        Ok(pb_secret.get_secret_backend().unwrap().clone())
232    }
233}