// risingwave_common_secret/vault_client.rs
use std::collections::HashMap;
use std::fmt;
use std::sync::LazyLock;
use std::time::{Duration, Instant};

use anyhow::{Context, Result};
use moka::future::Cache as MokaCache;
use reqwest::Client;
use risingwave_pb::secret;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::{DisplayFromStr, serde_as};
use url::Url;
27
28#[derive(Debug, Clone, PartialEq, Eq, Hash)]
29struct TokenCacheKey {
30 vault_base_url: String,
31 role_id: String,
32}
33
34#[derive(Debug, Clone)]
35struct CachedToken {
36 token: String,
37 expires_at: Instant,
38}
39
40static GLOBAL_VAULT_TOKEN_CACHE: LazyLock<MokaCache<TokenCacheKey, CachedToken>> =
43 LazyLock::new(|| {
44 MokaCache::builder()
45 .max_capacity(1000) .build()
47 });
48
49#[serde_as]
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct HashiCorpVaultConfig {
52 pub addr: String,
53 pub path: String,
54 pub field: String,
55 #[serde(flatten)]
56 pub auth: HashiCorpVaultAuth,
57 #[serde(default)]
58 #[serde_as(as = "DisplayFromStr")]
59 pub tls_skip_verify: bool,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
63#[serde(tag = "auth_method", rename_all = "lowercase")]
64pub enum HashiCorpVaultAuth {
65 Token {
66 auth_token: String,
67 },
68 #[serde(rename = "approle")]
69 AppRole {
70 auth_role_id: String,
71 auth_secret_id: String,
72 },
73}
74
75#[derive(Debug, Serialize, Deserialize)]
76struct VaultAppRoleLoginRequest {
77 role_id: String,
78 secret_id: String,
79}
80
81#[derive(Debug, Serialize, Deserialize)]
82struct VaultAuthResponse {
83 auth: VaultAuthData,
84}
85
86#[derive(Debug, Serialize, Deserialize)]
87struct VaultAuthData {
88 client_token: String,
89 lease_duration: u64,
90}
91
92#[derive(Debug, Serialize, Deserialize)]
93struct VaultSecretResponse {
94 data: VaultSecretData,
95}
96
97#[derive(Debug, Serialize, Deserialize)]
98struct VaultSecretData {
99 data: HashMap<String, Value>,
100}
101
102#[derive(Debug)]
103pub struct HashiCorpVaultClient {
104 client: Client,
105 config: HashiCorpVaultConfig,
106}
107
108impl HashiCorpVaultConfig {
109 pub fn from_protobuf(vault_backend: &secret::SecretHashicorpVaultBackend) -> Result<Self> {
111 let auth = match vault_backend.auth.as_ref() {
112 Some(secret::secret_hashicorp_vault_backend::Auth::TokenAuth(token_auth)) => {
113 HashiCorpVaultAuth::Token {
114 auth_token: token_auth.token.clone(),
115 }
116 }
117 Some(secret::secret_hashicorp_vault_backend::Auth::ApproleAuth(approle_auth)) => {
118 HashiCorpVaultAuth::AppRole {
119 auth_role_id: approle_auth.role_id.clone(),
120 auth_secret_id: approle_auth.secret_id.clone(),
121 }
122 }
123 None => {
124 return Err(anyhow::anyhow!(
125 "No auth method specified for Vault backend"
126 ));
127 }
128 };
129
130 Ok(HashiCorpVaultConfig {
131 addr: vault_backend.addr.clone(),
132 path: vault_backend.path.clone(),
133 field: vault_backend.field.clone(),
134 auth,
135 tls_skip_verify: vault_backend.tls_skip_verify,
136 })
137 }
138
139 pub fn to_protobuf(&self) -> secret::SecretHashicorpVaultBackend {
141 let auth = match &self.auth {
142 HashiCorpVaultAuth::Token { auth_token } => Some(
143 secret::secret_hashicorp_vault_backend::Auth::TokenAuth(secret::VaultTokenAuth {
144 token: auth_token.clone(),
145 }),
146 ),
147 HashiCorpVaultAuth::AppRole {
148 auth_role_id,
149 auth_secret_id,
150 } => Some(secret::secret_hashicorp_vault_backend::Auth::ApproleAuth(
151 secret::VaultAppRoleAuth {
152 role_id: auth_role_id.clone(),
153 secret_id: auth_secret_id.clone(),
154 },
155 )),
156 };
157
158 secret::SecretHashicorpVaultBackend {
159 addr: self.addr.clone(),
160 path: self.path.clone(),
161 field: self.field.clone(),
162 auth,
163 tls_skip_verify: self.tls_skip_verify,
164 }
165 }
166}
167
168impl HashiCorpVaultClient {
169 pub fn new(config: HashiCorpVaultConfig) -> Result<Self> {
170 let mut client_builder = Client::builder();
171
172 if config.tls_skip_verify {
173 client_builder = client_builder.danger_accept_invalid_certs(true);
174 }
175
176 let client = client_builder
177 .build()
178 .context("Failed to create HTTP client")?;
179
180 Ok(Self { client, config })
181 }
182
183 pub async fn get_secret(&self) -> Result<Vec<u8>> {
184 let mut force_refresh_token = false;
186
187 for retry_count in 0..2 {
189 let token = self.get_token_internal(force_refresh_token).await?;
191
192 let mut url = Url::parse(&self.config.addr).context("Failed to parse Vault address")?;
194 url.set_path(&format!("v1/{}", self.config.path));
195 let response = self
196 .client
197 .get(url.as_str())
198 .header("X-Vault-Token", &token)
199 .send()
200 .await
201 .context("Failed to send request to Vault")?;
202
203 if (response.status() == 401 || response.status() == 403)
205 && retry_count == 0
206 && matches!(self.config.auth, HashiCorpVaultAuth::AppRole { .. })
207 {
208 force_refresh_token = true;
210 continue;
211 }
212
213 if !response.status().is_success() {
214 return Err(anyhow::anyhow!(
215 "Vault API returned error status: {} - {}",
216 response.status(),
217 response.text().await.unwrap_or_default()
218 ));
219 }
220
221 return self.process_secret_response(response).await;
223 }
224
225 Err(anyhow::anyhow!("Failed to get secret from Vault"))
227 }
228
229 async fn process_secret_response(&self, response: reqwest::Response) -> Result<Vec<u8>> {
230 let secret_response: VaultSecretResponse = response
257 .json()
258 .await
259 .context("Failed to parse Vault secret response")?;
260
261 let field_value = secret_response
262 .data
263 .data
264 .get(&self.config.field)
265 .ok_or_else(|| anyhow::anyhow!("Field '{}' not found in secret", self.config.field))?;
266
267 let secret_bytes = match field_value {
268 Value::String(s) => s.as_bytes().to_vec(),
269 _ => serde_json::to_vec(field_value)
270 .context("Failed to serialize field value to bytes")?,
271 };
272
273 Ok(secret_bytes)
274 }
275
276 async fn get_token_internal(&self, force_refresh: bool) -> Result<String> {
277 match &self.config.auth {
278 HashiCorpVaultAuth::Token { auth_token } => Ok(auth_token.clone()),
279 HashiCorpVaultAuth::AppRole {
280 auth_role_id,
281 auth_secret_id,
282 } => {
283 let cache_key = TokenCacheKey {
285 vault_base_url: self.config.addr.trim_end_matches('/').to_owned(),
286 role_id: auth_role_id.clone(),
287 };
288
289 if !force_refresh
291 && let Some(cached_token) = GLOBAL_VAULT_TOKEN_CACHE.get(&cache_key).await
292 {
293 if cached_token.expires_at > Instant::now() {
294 return Ok(cached_token.token);
295 } else {
296 GLOBAL_VAULT_TOKEN_CACHE.invalidate(&cache_key).await;
298 }
299 }
300
301 let mut login_url =
303 Url::parse(&self.config.addr).context("Failed to parse Vault address")?;
304 login_url.set_path("v1/auth/approle/login");
305 let login_request = VaultAppRoleLoginRequest {
306 role_id: auth_role_id.clone(),
307 secret_id: auth_secret_id.clone(),
308 };
309
310 let response = self
311 .client
312 .post(login_url.as_str())
313 .json(&login_request)
314 .send()
315 .await
316 .context("Failed to send app role login request")?;
317
318 if !response.status().is_success() {
319 if !force_refresh {
321 GLOBAL_VAULT_TOKEN_CACHE.invalidate(&cache_key).await;
322 }
323 return Err(anyhow::anyhow!(
324 "Vault app role login failed: {} - {}",
325 response.status(),
326 response.text().await.unwrap_or_default()
327 ));
328 }
329
330 let auth_response: VaultAuthResponse = response
331 .json()
332 .await
333 .context("Failed to parse Vault auth response")?;
334
335 let token = auth_response.auth.client_token;
336 let lease_duration = auth_response.auth.lease_duration;
337
338 let expires_at = Instant::now() + Duration::from_secs((lease_duration * 9) / 10);
340 let cached_token = CachedToken {
341 token: token.clone(),
342 expires_at,
343 };
344 GLOBAL_VAULT_TOKEN_CACHE
345 .insert(cache_key, cached_token)
346 .await;
347
348 Ok(token)
349 }
350 }
351 }
352}
353
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Deserializes a config from its flat JSON key/value form, panicking on failure.
    fn parse_config(raw: serde_json::Value) -> HashiCorpVaultConfig {
        serde_json::from_value(raw).unwrap()
    }

    #[test]
    fn test_hashicorp_vault_config_token_auth_full() {
        let config = parse_config(json!({
            "addr": "https://vault.example.com",
            "path": "secret/data/myapp",
            "field": "api_key",
            "auth_method": "token",
            "auth_token": "hvs.123abc",
            "tls_skip_verify": "true"
        }));

        assert_eq!(config.addr, "https://vault.example.com");
        assert_eq!(config.path, "secret/data/myapp");
        assert_eq!(config.field, "api_key");
        assert!(config.tls_skip_verify);

        let HashiCorpVaultAuth::Token { auth_token } = config.auth else {
            panic!("Expected Token auth method");
        };
        assert_eq!(auth_token, "hvs.123abc");
    }

    #[test]
    fn test_hashicorp_vault_config_approle_auth_full() {
        let config = parse_config(json!({
            "addr": "https://vault.example.com",
            "path": "secret/data/myapp",
            "field": "password",
            "auth_method": "approle",
            "auth_role_id": "role123",
            "auth_secret_id": "secret456",
            "tls_skip_verify": "false"
        }));

        assert_eq!(config.addr, "https://vault.example.com");
        assert_eq!(config.path, "secret/data/myapp");
        assert_eq!(config.field, "password");
        assert!(!config.tls_skip_verify);

        let HashiCorpVaultAuth::AppRole {
            auth_role_id,
            auth_secret_id,
        } = config.auth
        else {
            panic!("Expected AppRole auth method");
        };
        assert_eq!(auth_role_id, "role123");
        assert_eq!(auth_secret_id, "secret456");
    }
}