risingwave_common_secret/
secret_manager.rsuse std::collections::{BTreeMap, HashMap};
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
use anyhow::{anyhow, Context};
use parking_lot::lock_api::RwLockReadGuard;
use parking_lot::RwLock;
use prost::Message;
use risingwave_pb::catalog::PbSecret;
use risingwave_pb::secret::secret_ref::RefAsType;
use risingwave_pb::secret::PbSecretRef;
use thiserror_ext::AsReport;
use super::error::{SecretError, SecretResult};
use super::SecretId;
static INSTANCE: std::sync::OnceLock<LocalSecretManager> = std::sync::OnceLock::new();
#[derive(Debug)]
pub struct LocalSecretManager {
secrets: RwLock<HashMap<SecretId, Vec<u8>>>,
secret_file_dir: PathBuf,
}
impl LocalSecretManager {
pub fn init(temp_file_dir: String, cluster_id: String, worker_id: u32) {
INSTANCE.get_or_init(|| {
let secret_file_dir = PathBuf::from(temp_file_dir)
.join(cluster_id)
.join(worker_id.to_string());
std::fs::remove_dir_all(&secret_file_dir).ok();
#[cfg(not(madsim))]
std::fs::create_dir_all(&secret_file_dir).unwrap();
Self {
secrets: RwLock::new(HashMap::new()),
secret_file_dir,
}
});
}
pub fn global() -> &'static LocalSecretManager {
#[cfg(debug_assertions)]
LocalSecretManager::init("./tmp".to_owned(), "test_cluster".to_owned(), 0);
INSTANCE.get().unwrap()
}
pub fn add_secret(&self, secret_id: SecretId, secret: Vec<u8>) {
let mut secret_guard = self.secrets.write();
if secret_guard.insert(secret_id, secret).is_some() {
tracing::error!(
secret_id = secret_id,
"adding a secret but it already exists, overwriting it"
);
};
}
pub fn update_secret(&self, secret_id: SecretId, secret: Vec<u8>) {
let mut secret_guard = self.secrets.write();
if secret_guard.insert(secret_id, secret).is_none() {
tracing::error!(
secret_id = secret_id,
"updating a secret but it does not exist, adding it"
);
}
self.remove_secret_file_if_exist(&secret_id);
}
pub fn init_secrets(&self, secrets: Vec<PbSecret>) {
let mut secret_guard = self.secrets.write();
secret_guard.clear();
std::fs::remove_dir_all(&self.secret_file_dir)
.inspect_err(|e| {
tracing::error!(
error = %e.as_report(),
path = %self.secret_file_dir.to_string_lossy(),
"Failed to remove secret directory")
})
.ok();
#[cfg(not(madsim))]
std::fs::create_dir_all(&self.secret_file_dir).unwrap();
for secret in secrets {
secret_guard.insert(secret.id, secret.value);
}
}
pub fn get_secret(&self, secret_id: SecretId) -> Option<Vec<u8>> {
let secret_guard = self.secrets.read();
secret_guard.get(&secret_id).cloned()
}
pub fn remove_secret(&self, secret_id: SecretId) {
let mut secret_guard = self.secrets.write();
secret_guard.remove(&secret_id);
self.remove_secret_file_if_exist(&secret_id);
}
pub fn fill_secrets(
&self,
mut options: BTreeMap<String, String>,
secret_refs: BTreeMap<String, PbSecretRef>,
) -> SecretResult<BTreeMap<String, String>> {
let secret_guard = self.secrets.read();
for (option_key, secret_ref) in secret_refs {
let path_str = self.fill_secret_inner(secret_ref, &secret_guard)?;
options.insert(option_key, path_str);
}
Ok(options)
}
pub fn fill_secret(&self, secret_ref: PbSecretRef) -> SecretResult<String> {
let secret_guard: RwLockReadGuard<'_, parking_lot::RawRwLock, HashMap<u32, Vec<u8>>> =
self.secrets.read();
self.fill_secret_inner(secret_ref, &secret_guard)
}
fn fill_secret_inner(
&self,
secret_ref: PbSecretRef,
secret_guard: &RwLockReadGuard<'_, parking_lot::RawRwLock, HashMap<u32, Vec<u8>>>,
) -> SecretResult<String> {
let secret_id = secret_ref.secret_id;
let pb_secret_bytes = secret_guard
.get(&secret_id)
.ok_or(SecretError::ItemNotFound(secret_id))?;
let secret_value_bytes = Self::get_secret_value(pb_secret_bytes)?;
match secret_ref.ref_as() {
RefAsType::Text => {
Ok(String::from_utf8(secret_value_bytes.clone())?)
}
RefAsType::File => {
let path_str =
self.get_or_init_secret_file(secret_id, secret_value_bytes.clone())?;
Ok(path_str)
}
RefAsType::Unspecified => Err(SecretError::UnspecifiedRefType(secret_id)),
}
}
fn get_or_init_secret_file(
&self,
secret_id: SecretId,
secret_bytes: Vec<u8>,
) -> SecretResult<String> {
let path = self.secret_file_dir.join(secret_id.to_string());
if !path.exists() {
let mut file = File::create(&path)?;
file.write_all(&secret_bytes)?;
file.sync_all()?;
}
Ok(path.to_string_lossy().to_string())
}
fn remove_secret_file_if_exist(&self, secret_id: &SecretId) {
let path = self.secret_file_dir.join(secret_id.to_string());
if path.exists() {
std::fs::remove_file(&path)
.inspect_err(|e| {
tracing::error!(
error = %e.as_report(),
path = %path.to_string_lossy(),
"Failed to remove secret file")
})
.ok();
}
}
fn get_secret_value(pb_secret_bytes: &[u8]) -> SecretResult<Vec<u8>> {
let secret_value = match Self::get_pb_secret_backend(pb_secret_bytes)? {
risingwave_pb::secret::secret::SecretBackend::Meta(backend) => backend.value.clone(),
risingwave_pb::secret::secret::SecretBackend::HashicorpVault(_) => {
return Err(anyhow!("hashicorp_vault backend is not implemented yet").into())
}
};
Ok(secret_value)
}
pub fn get_pb_secret_backend(
pb_secret_bytes: &[u8],
) -> SecretResult<risingwave_pb::secret::secret::SecretBackend> {
let pb_secret = risingwave_pb::secret::Secret::decode(pb_secret_bytes)
.context("failed to decode secret")?;
Ok(pb_secret.get_secret_backend().unwrap().clone())
}
}