risingwave_common_secret/
secret_manager.rs1use std::collections::{BTreeMap, HashMap};
16use std::fs::File;
17use std::io::Write;
18use std::path::PathBuf;
19
20use anyhow::Context;
21use parking_lot::RwLock;
22use parking_lot::lock_api::RwLockReadGuard;
23use prost::Message;
24use risingwave_pb::catalog::PbSecret;
25use risingwave_pb::id::WorkerId;
26use risingwave_pb::secret::PbSecretRef;
27use risingwave_pb::secret::secret_ref::RefAsType;
28use thiserror_ext::AsReport;
29use tokio::runtime::Handle;
30use tokio::task;
31
32use super::SecretId;
33use super::error::{SecretError, SecretResult};
34use super::vault_client::{HashiCorpVaultClient, HashiCorpVaultConfig};
35
/// Process-wide singleton, populated by [`LocalSecretManager::init`] and read via
/// [`LocalSecretManager::global`].
static INSTANCE: std::sync::OnceLock<LocalSecretManager> = std::sync::OnceLock::new();

/// Worker-local cache of secrets, plus the directory where secrets referenced "as file" are
/// materialized on disk.
#[derive(Debug)]
pub struct LocalSecretManager {
    /// Secret id -> protobuf-encoded secret bytes (decoded with
    /// `risingwave_pb::secret::Secret::decode` when a secret is resolved).
    secrets: RwLock<HashMap<SecretId, Vec<u8>>>,
    /// Directory holding one file per secret that has been resolved as a file, named after the
    /// secret id.
    secret_file_dir: PathBuf,
}
44
45impl LocalSecretManager {
46 pub fn init(temp_file_dir: String, cluster_id: String, worker_id: WorkerId) {
50 INSTANCE.get_or_init(|| {
52 let secret_file_dir = PathBuf::from(temp_file_dir)
53 .join(cluster_id)
54 .join(worker_id.to_string());
55 std::fs::remove_dir_all(&secret_file_dir).ok();
56
57 #[cfg(not(madsim))]
60 std::fs::create_dir_all(&secret_file_dir).unwrap();
61
62 Self {
63 secrets: RwLock::new(HashMap::new()),
64 secret_file_dir,
65 }
66 });
67 }
68
69 pub fn global() -> &'static LocalSecretManager {
73 #[cfg(debug_assertions)]
75 LocalSecretManager::init("./tmp".to_owned(), "test_cluster".to_owned(), 0.into());
76
77 INSTANCE.get().unwrap()
78 }
79
80 pub fn add_secret(&self, secret_id: SecretId, secret: Vec<u8>) {
81 let mut secret_guard = self.secrets.write();
82 if secret_guard.insert(secret_id, secret).is_some() {
83 tracing::error!(
84 secret_id = %secret_id,
85 "adding a secret but it already exists, overwriting it"
86 );
87 };
88 }
89
90 pub fn update_secret(&self, secret_id: SecretId, secret: Vec<u8>) {
91 let mut secret_guard = self.secrets.write();
92 if secret_guard.insert(secret_id, secret).is_none() {
93 tracing::error!(
94 secret_id = %secret_id,
95 "updating a secret but it does not exist, adding it"
96 );
97 }
98 self.remove_secret_file_if_exist(&secret_id);
99 }
100
101 pub fn init_secrets(&self, secrets: Vec<PbSecret>) {
102 let mut secret_guard = self.secrets.write();
103 secret_guard.clear();
105 std::fs::remove_dir_all(&self.secret_file_dir)
108 .inspect_err(|e| {
109 tracing::error!(
110 error = %e.as_report(),
111 path = %self.secret_file_dir.to_string_lossy(),
112 "Failed to remove secret directory")
113 })
114 .ok();
115
116 #[cfg(not(madsim))]
117 std::fs::create_dir_all(&self.secret_file_dir).unwrap();
118
119 for secret in secrets {
120 secret_guard.insert(secret.id, secret.value);
121 }
122 }
123
124 pub fn get_secret(&self, secret_id: SecretId) -> Option<Vec<u8>> {
125 let secret_guard = self.secrets.read();
126 secret_guard.get(&secret_id).cloned()
127 }
128
129 pub fn remove_secret(&self, secret_id: SecretId) {
130 let mut secret_guard = self.secrets.write();
131 secret_guard.remove(&secret_id);
132 self.remove_secret_file_if_exist(&secret_id);
133 }
134
135 pub fn fill_secrets(
136 &self,
137 mut options: BTreeMap<String, String>,
138 secret_refs: BTreeMap<String, PbSecretRef>,
139 ) -> SecretResult<BTreeMap<String, String>> {
140 let secret_guard = self.secrets.read();
141 for (option_key, secret_ref) in secret_refs {
142 let path_str = self.fill_secret_inner(secret_ref, &secret_guard)?;
143 options.insert(option_key, path_str);
144 }
145 Ok(options)
146 }
147
148 pub fn fill_secret(&self, secret_ref: PbSecretRef) -> SecretResult<String> {
149 let secret_guard: RwLockReadGuard<'_, parking_lot::RawRwLock, HashMap<SecretId, Vec<u8>>> =
150 self.secrets.read();
151 self.fill_secret_inner(secret_ref, &secret_guard)
152 }
153
154 fn fill_secret_inner(
155 &self,
156 secret_ref: PbSecretRef,
157 secret_guard: &RwLockReadGuard<'_, parking_lot::RawRwLock, HashMap<SecretId, Vec<u8>>>,
158 ) -> SecretResult<String> {
159 let secret_id = secret_ref.secret_id;
160 let pb_secret_bytes = secret_guard
161 .get(&secret_id)
162 .ok_or(SecretError::ItemNotFound(secret_id))?;
163 let secret_value_bytes = Self::get_secret_value(pb_secret_bytes)?;
164 match secret_ref.ref_as() {
165 RefAsType::Text => {
166 Ok(String::from_utf8(secret_value_bytes)?)
169 }
170 RefAsType::File => {
171 let path_str = self.get_or_init_secret_file(secret_id, secret_value_bytes)?;
172 Ok(path_str)
173 }
174 RefAsType::Unspecified => Err(SecretError::UnspecifiedRefType(secret_id)),
175 }
176 }
177
178 fn get_or_init_secret_file(
181 &self,
182 secret_id: SecretId,
183 secret_bytes: Vec<u8>,
184 ) -> SecretResult<String> {
185 let path = self.secret_file_dir.join(secret_id.to_string());
186 if !path.exists() {
187 let mut file = File::create(&path)?;
188 file.write_all(&secret_bytes)?;
189 file.sync_all()?;
190 }
191 Ok(path.to_string_lossy().to_string())
192 }
193
194 fn remove_secret_file_if_exist(&self, secret_id: &SecretId) {
196 let path = self.secret_file_dir.join(secret_id.to_string());
197 if path.exists() {
198 std::fs::remove_file(&path)
199 .inspect_err(|e| {
200 tracing::error!(
201 error = %e.as_report(),
202 path = %path.to_string_lossy(),
203 "Failed to remove secret file")
204 })
205 .ok();
206 }
207 }
208
209 #[cfg_or_panic::cfg_or_panic(not(madsim))]
210 fn get_secret_value(pb_secret_bytes: &[u8]) -> SecretResult<Vec<u8>> {
211 let secret_value = match Self::get_pb_secret_backend(pb_secret_bytes)? {
212 risingwave_pb::secret::secret::SecretBackend::Meta(backend) => backend.value,
213 risingwave_pb::secret::secret::SecretBackend::HashicorpVault(vault_backend) => {
214 let config = HashiCorpVaultConfig::from_protobuf(&vault_backend)?;
215 let client = HashiCorpVaultClient::new(config)?;
216
217 task::block_in_place(move || {
218 Handle::current().block_on(async move { client.get_secret().await })
219 })?
220 }
221 };
222 Ok(secret_value)
223 }
224
225 pub fn get_pb_secret_backend(
227 pb_secret_bytes: &[u8],
228 ) -> SecretResult<risingwave_pb::secret::secret::SecretBackend> {
229 let pb_secret = risingwave_pb::secret::Secret::decode(pb_secret_bytes)
230 .context("failed to decode secret")?;
231 Ok(pb_secret.get_secret_backend().unwrap().clone())
232 }
233}