risingwave_common_secret/
secret_manager.rs1use std::collections::{BTreeMap, HashMap};
16use std::fs::File;
17use std::io::Write;
18use std::path::PathBuf;
19
20use anyhow::Context;
21use parking_lot::RwLock;
22use parking_lot::lock_api::RwLockReadGuard;
23use prost::Message;
24use risingwave_pb::catalog::PbSecret;
25use risingwave_pb::secret::PbSecretRef;
26use risingwave_pb::secret::secret_ref::RefAsType;
27use thiserror_ext::AsReport;
28use tokio::runtime::Handle;
29use tokio::task;
30
31use super::SecretId;
32use super::error::{SecretError, SecretResult};
33use super::vault_client::{HashiCorpVaultClient, HashiCorpVaultConfig};
34
35static INSTANCE: std::sync::OnceLock<LocalSecretManager> = std::sync::OnceLock::new();
36
37#[derive(Debug)]
38pub struct LocalSecretManager {
39 secrets: RwLock<HashMap<SecretId, Vec<u8>>>,
40 secret_file_dir: PathBuf,
42}
43
44impl LocalSecretManager {
45 pub fn init(temp_file_dir: String, cluster_id: String, worker_id: u32) {
49 INSTANCE.get_or_init(|| {
51 let secret_file_dir = PathBuf::from(temp_file_dir)
52 .join(cluster_id)
53 .join(worker_id.to_string());
54 std::fs::remove_dir_all(&secret_file_dir).ok();
55
56 #[cfg(not(madsim))]
59 std::fs::create_dir_all(&secret_file_dir).unwrap();
60
61 Self {
62 secrets: RwLock::new(HashMap::new()),
63 secret_file_dir,
64 }
65 });
66 }
67
68 pub fn global() -> &'static LocalSecretManager {
72 #[cfg(debug_assertions)]
74 LocalSecretManager::init("./tmp".to_owned(), "test_cluster".to_owned(), 0);
75
76 INSTANCE.get().unwrap()
77 }
78
79 pub fn add_secret(&self, secret_id: SecretId, secret: Vec<u8>) {
80 let mut secret_guard = self.secrets.write();
81 if secret_guard.insert(secret_id, secret).is_some() {
82 tracing::error!(
83 secret_id = secret_id,
84 "adding a secret but it already exists, overwriting it"
85 );
86 };
87 }
88
89 pub fn update_secret(&self, secret_id: SecretId, secret: Vec<u8>) {
90 let mut secret_guard = self.secrets.write();
91 if secret_guard.insert(secret_id, secret).is_none() {
92 tracing::error!(
93 secret_id = secret_id,
94 "updating a secret but it does not exist, adding it"
95 );
96 }
97 self.remove_secret_file_if_exist(&secret_id);
98 }
99
100 pub fn init_secrets(&self, secrets: Vec<PbSecret>) {
101 let mut secret_guard = self.secrets.write();
102 secret_guard.clear();
104 std::fs::remove_dir_all(&self.secret_file_dir)
107 .inspect_err(|e| {
108 tracing::error!(
109 error = %e.as_report(),
110 path = %self.secret_file_dir.to_string_lossy(),
111 "Failed to remove secret directory")
112 })
113 .ok();
114
115 #[cfg(not(madsim))]
116 std::fs::create_dir_all(&self.secret_file_dir).unwrap();
117
118 for secret in secrets {
119 secret_guard.insert(secret.id, secret.value);
120 }
121 }
122
123 pub fn get_secret(&self, secret_id: SecretId) -> Option<Vec<u8>> {
124 let secret_guard = self.secrets.read();
125 secret_guard.get(&secret_id).cloned()
126 }
127
128 pub fn remove_secret(&self, secret_id: SecretId) {
129 let mut secret_guard = self.secrets.write();
130 secret_guard.remove(&secret_id);
131 self.remove_secret_file_if_exist(&secret_id);
132 }
133
134 pub fn fill_secrets(
135 &self,
136 mut options: BTreeMap<String, String>,
137 secret_refs: BTreeMap<String, PbSecretRef>,
138 ) -> SecretResult<BTreeMap<String, String>> {
139 let secret_guard = self.secrets.read();
140 for (option_key, secret_ref) in secret_refs {
141 let path_str = self.fill_secret_inner(secret_ref, &secret_guard)?;
142 options.insert(option_key, path_str);
143 }
144 Ok(options)
145 }
146
147 pub fn fill_secret(&self, secret_ref: PbSecretRef) -> SecretResult<String> {
148 let secret_guard: RwLockReadGuard<'_, parking_lot::RawRwLock, HashMap<u32, Vec<u8>>> =
149 self.secrets.read();
150 self.fill_secret_inner(secret_ref, &secret_guard)
151 }
152
153 fn fill_secret_inner(
154 &self,
155 secret_ref: PbSecretRef,
156 secret_guard: &RwLockReadGuard<'_, parking_lot::RawRwLock, HashMap<u32, Vec<u8>>>,
157 ) -> SecretResult<String> {
158 let secret_id = secret_ref.secret_id;
159 let pb_secret_bytes = secret_guard
160 .get(&secret_id)
161 .ok_or(SecretError::ItemNotFound(secret_id))?;
162 let secret_value_bytes = Self::get_secret_value(pb_secret_bytes)?;
163 match secret_ref.ref_as() {
164 RefAsType::Text => {
165 Ok(String::from_utf8(secret_value_bytes.clone())?)
168 }
169 RefAsType::File => {
170 let path_str =
171 self.get_or_init_secret_file(secret_id, secret_value_bytes.clone())?;
172 Ok(path_str)
173 }
174 RefAsType::Unspecified => Err(SecretError::UnspecifiedRefType(secret_id)),
175 }
176 }
177
178 fn get_or_init_secret_file(
181 &self,
182 secret_id: SecretId,
183 secret_bytes: Vec<u8>,
184 ) -> SecretResult<String> {
185 let path = self.secret_file_dir.join(secret_id.to_string());
186 if !path.exists() {
187 let mut file = File::create(&path)?;
188 file.write_all(&secret_bytes)?;
189 file.sync_all()?;
190 }
191 Ok(path.to_string_lossy().to_string())
192 }
193
194 fn remove_secret_file_if_exist(&self, secret_id: &SecretId) {
196 let path = self.secret_file_dir.join(secret_id.to_string());
197 if path.exists() {
198 std::fs::remove_file(&path)
199 .inspect_err(|e| {
200 tracing::error!(
201 error = %e.as_report(),
202 path = %path.to_string_lossy(),
203 "Failed to remove secret file")
204 })
205 .ok();
206 }
207 }
208
209 #[cfg_or_panic::cfg_or_panic(not(madsim))]
210 fn get_secret_value(pb_secret_bytes: &[u8]) -> SecretResult<Vec<u8>> {
211 let secret_value = match Self::get_pb_secret_backend(pb_secret_bytes)? {
212 risingwave_pb::secret::secret::SecretBackend::Meta(backend) => backend.value.clone(),
213 risingwave_pb::secret::secret::SecretBackend::HashicorpVault(vault_backend) => {
214 let config = HashiCorpVaultConfig::from_protobuf(&vault_backend)?;
215 let client = HashiCorpVaultClient::new(config)?;
216
217 task::block_in_place(move || {
218 Handle::current().block_on(async move { client.get_secret().await })
219 })?
220 }
221 };
222 Ok(secret_value)
223 }
224
225 pub fn get_pb_secret_backend(
227 pb_secret_bytes: &[u8],
228 ) -> SecretResult<risingwave_pb::secret::secret::SecretBackend> {
229 let pb_secret = risingwave_pb::secret::Secret::decode(pb_secret_bytes)
230 .context("failed to decode secret")?;
231 Ok(pb_secret.get_secret_backend().unwrap().clone())
232 }
233}