risingwave_common_secret/
vault_client.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::sync::LazyLock;
17use std::time::{Duration, Instant};
18
19use anyhow::{Context, Result};
20use moka::future::Cache as MokaCache;
21use reqwest::Client;
22use risingwave_pb::secret;
23use serde::{Deserialize, Serialize};
24use serde_json::Value;
25use serde_with::{DisplayFromStr, serde_as};
26use url::Url;
27
/// Identity under which an AppRole-issued token is shared in the global token cache:
/// the Vault server address (without trailing `/`) together with the AppRole `role_id`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct TokenCacheKey {
    /// Vault server address with trailing slashes trimmed.
    vault_base_url: String,
    /// AppRole role id the token was issued for.
    role_id: String,
}
33
/// A Vault client token held in the global token cache, together with the instant after
/// which it must no longer be served from the cache.
#[derive(Clone)]
struct CachedToken {
    /// The Vault client token itself (a live credential).
    token: String,
    /// Cache entries at or past this instant are treated as expired.
    expires_at: Instant,
}

// `Debug` is written by hand so that formatting a cache entry never prints the live token.
impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CachedToken")
            .field("token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
39
40/// Global cache for Vault tokens to reduce authentication requests
41/// Cache key contains (vault service base url, `role_id`) as requested
42static GLOBAL_VAULT_TOKEN_CACHE: LazyLock<MokaCache<TokenCacheKey, CachedToken>> =
43    LazyLock::new(|| {
44        MokaCache::builder()
45            .max_capacity(1000) // Limit cache size
46            .build()
47    });
48
49#[serde_as]
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct HashiCorpVaultConfig {
52    pub addr: String,
53    pub path: String,
54    pub field: String,
55    #[serde(flatten)]
56    pub auth: HashiCorpVaultAuth,
57    #[serde(default)]
58    #[serde_as(as = "DisplayFromStr")]
59    pub tls_skip_verify: bool,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
63#[serde(tag = "auth_method", rename_all = "lowercase")]
64pub enum HashiCorpVaultAuth {
65    Token {
66        auth_token: String,
67    },
68    #[serde(rename = "approle")]
69    AppRole {
70        auth_role_id: String,
71        auth_secret_id: String,
72    },
73}
74
75#[derive(Debug, Serialize, Deserialize)]
76struct VaultAppRoleLoginRequest {
77    role_id: String,
78    secret_id: String,
79}
80
81#[derive(Debug, Serialize, Deserialize)]
82struct VaultAuthResponse {
83    auth: VaultAuthData,
84}
85
86#[derive(Debug, Serialize, Deserialize)]
87struct VaultAuthData {
88    client_token: String,
89    lease_duration: u64,
90}
91
92#[derive(Debug, Serialize, Deserialize)]
93struct VaultSecretResponse {
94    data: VaultSecretData,
95}
96
97#[derive(Debug, Serialize, Deserialize)]
98struct VaultSecretData {
99    data: HashMap<String, Value>,
100}
101
102#[derive(Debug)]
103pub struct HashiCorpVaultClient {
104    client: Client,
105    config: HashiCorpVaultConfig,
106}
107
108impl HashiCorpVaultConfig {
109    /// Convert from protobuf `SecretHashicorpVaultBackend` to `HashiCorpVaultConfig`
110    pub fn from_protobuf(vault_backend: &secret::SecretHashicorpVaultBackend) -> Result<Self> {
111        let auth = match vault_backend.auth.as_ref() {
112            Some(secret::secret_hashicorp_vault_backend::Auth::TokenAuth(token_auth)) => {
113                HashiCorpVaultAuth::Token {
114                    auth_token: token_auth.token.clone(),
115                }
116            }
117            Some(secret::secret_hashicorp_vault_backend::Auth::ApproleAuth(approle_auth)) => {
118                HashiCorpVaultAuth::AppRole {
119                    auth_role_id: approle_auth.role_id.clone(),
120                    auth_secret_id: approle_auth.secret_id.clone(),
121                }
122            }
123            None => {
124                return Err(anyhow::anyhow!(
125                    "No auth method specified for Vault backend"
126                ));
127            }
128        };
129
130        Ok(HashiCorpVaultConfig {
131            addr: vault_backend.addr.clone(),
132            path: vault_backend.path.clone(),
133            field: vault_backend.field.clone(),
134            auth,
135            tls_skip_verify: vault_backend.tls_skip_verify,
136        })
137    }
138
139    /// Convert `HashiCorpVaultConfig` to protobuf `SecretHashicorpVaultBackend`
140    pub fn to_protobuf(&self) -> secret::SecretHashicorpVaultBackend {
141        let auth = match &self.auth {
142            HashiCorpVaultAuth::Token { auth_token } => Some(
143                secret::secret_hashicorp_vault_backend::Auth::TokenAuth(secret::VaultTokenAuth {
144                    token: auth_token.clone(),
145                }),
146            ),
147            HashiCorpVaultAuth::AppRole {
148                auth_role_id,
149                auth_secret_id,
150            } => Some(secret::secret_hashicorp_vault_backend::Auth::ApproleAuth(
151                secret::VaultAppRoleAuth {
152                    role_id: auth_role_id.clone(),
153                    secret_id: auth_secret_id.clone(),
154                },
155            )),
156        };
157
158        secret::SecretHashicorpVaultBackend {
159            addr: self.addr.clone(),
160            path: self.path.clone(),
161            field: self.field.clone(),
162            auth,
163            tls_skip_verify: self.tls_skip_verify,
164        }
165    }
166}
167
168impl HashiCorpVaultClient {
169    pub fn new(config: HashiCorpVaultConfig) -> Result<Self> {
170        let mut client_builder = Client::builder();
171
172        if config.tls_skip_verify {
173            client_builder = client_builder.danger_accept_invalid_certs(true);
174        }
175
176        let client = client_builder
177            .build()
178            .context("Failed to create HTTP client")?;
179
180        Ok(Self { client, config })
181    }
182
183    pub async fn get_secret(&self) -> Result<Vec<u8>> {
184        // Try to get secret, with retry logic for token invalidation
185        let mut force_refresh_token = false;
186
187        // Retry loop for handling token invalidation
188        for retry_count in 0..2 {
189            // Get token (either directly or via app role)
190            let token = self.get_token_internal(force_refresh_token).await?;
191
192            // Fetch secret from Vault
193            let mut url = Url::parse(&self.config.addr).context("Failed to parse Vault address")?;
194            url.set_path(&format!("v1/{}", self.config.path));
195            let response = self
196                .client
197                .get(url.as_str())
198                .header("X-Vault-Token", &token)
199                .send()
200                .await
201                .context("Failed to send request to Vault")?;
202
203            // Handle authentication failures - token may have been rotated/revoked
204            if (response.status() == 401 || response.status() == 403)
205                && retry_count == 0
206                && matches!(self.config.auth, HashiCorpVaultAuth::AppRole { .. })
207            {
208                // this case means the token changed during cache, need to trigger a refresh
209                force_refresh_token = true;
210                continue;
211            }
212
213            if !response.status().is_success() {
214                return Err(anyhow::anyhow!(
215                    "Vault API returned error status: {} - {}",
216                    response.status(),
217                    response.text().await.unwrap_or_default()
218                ));
219            }
220
221            // Success case - process the response and break out of retry loop
222            return self.process_secret_response(response).await;
223        }
224
225        // todo: refine error message
226        Err(anyhow::anyhow!("Failed to get secret from Vault"))
227    }
228
229    async fn process_secret_response(&self, response: reqwest::Response) -> Result<Vec<u8>> {
230        // https://developer.hashicorp.com/vault/docs/secrets/kv/kv-v2/cookbook/read-data
231        // a demo response:
232        //   {
233        //     "request_id": "e345b77b-8b5a-552b-eb2c-7d80a627c9ad",
234        //     "lease_id": "",
235        //     "renewable": false,
236        //     "lease_duration": 0,
237        //     "data": {
238        //       "data": {
239        //         "key": "test-api-key-12345",
240        //         "secret": "test-api-secret-67890"
241        //       },
242        //       "metadata": {
243        //         "created_time": "2025-07-17T08:07:24.177261949Z",
244        //         "custom_metadata": null,
245        //         "deletion_time": "",
246        //         "destroyed": false,
247        //         "version": 1
248        //       }
249        //     },
250        //     "wrap_info": null,
251        //     "warnings": null,
252        //     "auth": null,
253        //     "mount_type": "kv"
254        //   }
255
256        let secret_response: VaultSecretResponse = response
257            .json()
258            .await
259            .context("Failed to parse Vault secret response")?;
260
261        let field_value = secret_response
262            .data
263            .data
264            .get(&self.config.field)
265            .ok_or_else(|| anyhow::anyhow!("Field '{}' not found in secret", self.config.field))?;
266
267        let secret_bytes = match field_value {
268            Value::String(s) => s.as_bytes().to_vec(),
269            _ => serde_json::to_vec(field_value)
270                .context("Failed to serialize field value to bytes")?,
271        };
272
273        Ok(secret_bytes)
274    }
275
276    async fn get_token_internal(&self, force_refresh: bool) -> Result<String> {
277        match &self.config.auth {
278            HashiCorpVaultAuth::Token { auth_token } => Ok(auth_token.clone()),
279            HashiCorpVaultAuth::AppRole {
280                auth_role_id,
281                auth_secret_id,
282            } => {
283                // Create cache key with vault base URL and role_id
284                let cache_key = TokenCacheKey {
285                    vault_base_url: self.config.addr.trim_end_matches('/').to_owned(),
286                    role_id: auth_role_id.clone(),
287                };
288
289                // Check global token cache first (unless forced refresh)
290                if !force_refresh
291                    && let Some(cached_token) = GLOBAL_VAULT_TOKEN_CACHE.get(&cache_key).await
292                {
293                    if cached_token.expires_at > Instant::now() {
294                        return Ok(cached_token.token);
295                    } else {
296                        // Token expired, remove it from cache
297                        GLOBAL_VAULT_TOKEN_CACHE.invalidate(&cache_key).await;
298                    }
299                }
300
301                // Login with app role
302                let mut login_url =
303                    Url::parse(&self.config.addr).context("Failed to parse Vault address")?;
304                login_url.set_path("v1/auth/approle/login");
305                let login_request = VaultAppRoleLoginRequest {
306                    role_id: auth_role_id.clone(),
307                    secret_id: auth_secret_id.clone(),
308                };
309
310                let response = self
311                    .client
312                    .post(login_url.as_str())
313                    .json(&login_request)
314                    .send()
315                    .await
316                    .context("Failed to send app role login request")?;
317
318                if !response.status().is_success() {
319                    // If authentication fails and we have a cached token, invalidate it
320                    if !force_refresh {
321                        GLOBAL_VAULT_TOKEN_CACHE.invalidate(&cache_key).await;
322                    }
323                    return Err(anyhow::anyhow!(
324                        "Vault app role login failed: {} - {}",
325                        response.status(),
326                        response.text().await.unwrap_or_default()
327                    ));
328                }
329
330                let auth_response: VaultAuthResponse = response
331                    .json()
332                    .await
333                    .context("Failed to parse Vault auth response")?;
334
335                let token = auth_response.auth.client_token;
336                let lease_duration = auth_response.auth.lease_duration;
337
338                // Cache the token with per-entry expiration based on lease duration (90% of lease duration)
339                let expires_at = Instant::now() + Duration::from_secs((lease_duration * 9) / 10);
340                let cached_token = CachedToken {
341                    token: token.clone(),
342                    expires_at,
343                };
344                GLOBAL_VAULT_TOKEN_CACHE
345                    .insert(cache_key, cached_token)
346                    .await;
347
348                Ok(token)
349            }
350        }
351    }
352}
353
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    #[test]
    fn test_hashicorp_vault_config_token_auth_full() {
        let json_config = json!({
            "addr": "https://vault.example.com",
            "path": "secret/data/myapp",
            "field": "api_key",
            "auth_method": "token",
            "auth_token": "hvs.123abc",
            "tls_skip_verify": "true"
        });

        let config: HashiCorpVaultConfig = serde_json::from_value(json_config).unwrap();

        assert_eq!(config.addr, "https://vault.example.com");
        assert_eq!(config.path, "secret/data/myapp");
        assert_eq!(config.field, "api_key");
        assert!(config.tls_skip_verify);

        match config.auth {
            HashiCorpVaultAuth::Token { auth_token } => {
                assert_eq!(auth_token, "hvs.123abc");
            }
            _ => panic!("Expected Token auth method"),
        }
    }

    #[test]
    fn test_hashicorp_vault_config_approle_auth_full() {
        let json_config = json!({
            "addr": "https://vault.example.com",
            "path": "secret/data/myapp",
            "field": "password",
            "auth_method": "approle",
            "auth_role_id": "role123",
            "auth_secret_id": "secret456",
            "tls_skip_verify": "false"
        });

        let config: HashiCorpVaultConfig = serde_json::from_value(json_config).unwrap();

        assert_eq!(config.addr, "https://vault.example.com");
        assert_eq!(config.path, "secret/data/myapp");
        assert_eq!(config.field, "password");
        assert!(!config.tls_skip_verify);

        match config.auth {
            HashiCorpVaultAuth::AppRole {
                auth_role_id,
                auth_secret_id,
            } => {
                assert_eq!(auth_role_id, "role123");
                assert_eq!(auth_secret_id, "secret456");
            }
            _ => panic!("Expected AppRole auth method"),
        }
    }

    #[test]
    fn test_hashicorp_vault_config_tls_skip_verify_defaults_to_false() {
        let json_config = json!({
            "addr": "https://vault.example.com",
            "path": "secret/data/myapp",
            "field": "api_key",
            "auth_method": "token",
            "auth_token": "hvs.123abc"
        });

        let config: HashiCorpVaultConfig = serde_json::from_value(json_config).unwrap();

        assert!(!config.tls_skip_verify);
    }

    #[test]
    fn test_hashicorp_vault_config_rejects_missing_or_unknown_auth_method() {
        let missing = json!({
            "addr": "https://vault.example.com",
            "path": "secret/data/myapp",
            "field": "api_key"
        });
        assert!(serde_json::from_value::<HashiCorpVaultConfig>(missing).is_err());

        let unknown = json!({
            "addr": "https://vault.example.com",
            "path": "secret/data/myapp",
            "field": "api_key",
            "auth_method": "kubernetes"
        });
        assert!(serde_json::from_value::<HashiCorpVaultConfig>(unknown).is_err());
    }

    #[test]
    fn test_hashicorp_vault_config_protobuf_round_trip() {
        let token_config: HashiCorpVaultConfig = serde_json::from_value(json!({
            "addr": "https://vault.example.com",
            "path": "secret/data/myapp",
            "field": "api_key",
            "auth_method": "token",
            "auth_token": "hvs.123abc",
            "tls_skip_verify": "true"
        }))
        .unwrap();
        let restored = HashiCorpVaultConfig::from_protobuf(&token_config.to_protobuf()).unwrap();
        assert_eq!(restored.addr, token_config.addr);
        assert_eq!(restored.path, token_config.path);
        assert_eq!(restored.field, token_config.field);
        assert_eq!(restored.tls_skip_verify, token_config.tls_skip_verify);
        match restored.auth {
            HashiCorpVaultAuth::Token { auth_token } => assert_eq!(auth_token, "hvs.123abc"),
            _ => panic!("Expected Token auth method"),
        }

        let approle_config: HashiCorpVaultConfig = serde_json::from_value(json!({
            "addr": "https://vault.example.com",
            "path": "secret/data/myapp",
            "field": "password",
            "auth_method": "approle",
            "auth_role_id": "role123",
            "auth_secret_id": "secret456"
        }))
        .unwrap();
        let restored = HashiCorpVaultConfig::from_protobuf(&approle_config.to_protobuf()).unwrap();
        assert_eq!(restored.addr, approle_config.addr);
        assert_eq!(restored.path, approle_config.path);
        assert_eq!(restored.field, approle_config.field);
        assert!(!restored.tls_skip_verify);
        match restored.auth {
            HashiCorpVaultAuth::AppRole {
                auth_role_id,
                auth_secret_id,
            } => {
                assert_eq!(auth_role_id, "role123");
                assert_eq!(auth_secret_id, "secret456");
            }
            _ => panic!("Expected AppRole auth method"),
        }
    }

    #[test]
    fn test_hashicorp_vault_config_from_protobuf_requires_auth() {
        let config: HashiCorpVaultConfig = serde_json::from_value(json!({
            "addr": "https://vault.example.com",
            "path": "secret/data/myapp",
            "field": "api_key",
            "auth_method": "token",
            "auth_token": "hvs.123abc"
        }))
        .unwrap();

        let mut pb = config.to_protobuf();
        pb.auth = None;

        let err = HashiCorpVaultConfig::from_protobuf(&pb).unwrap_err();
        assert!(err.to_string().contains("No auth method specified"));
    }
}