1use std::collections::HashMap;
16use std::fs;
17use std::sync::Arc;
18
19use anyhow::anyhow;
20use ldap3::{LdapConnAsync, Scope, SearchEntry, dn_escape, ldap_escape};
21use risingwave_common::config::{AuthMethod, HbaEntry};
22use thiserror_ext::AsReport;
23use tracing::warn;
24
25use crate::error::{PsqlError, PsqlResult};
26
// Keys of the LDAP settings inside an HBA entry's auth options map.

/// Complete LDAP URL; when present it determines server, port, scheme, base
/// DN, search attribute and search filter.
const LDAP_URL_KEY: &str = "ldapurl";
/// Host name of the LDAP server.
const LDAP_SERVER_KEY: &str = "ldapserver";
/// Port of the LDAP server.
const LDAP_PORT_KEY: &str = "ldapport";
/// Connection scheme, `ldap` or `ldaps`.
const LDAP_SCHEME_KEY: &str = "ldapscheme";
/// Option that requests STARTTLS on a plain `ldap` connection.
const LDAP_TLS: &str = "ldaptls";
/// Base DN used when locating users.
const LDAP_BASE_DN_KEY: &str = "ldapbasedn";
/// Search filter template (`$username` is substituted).
const LDAP_SEARCH_FILTER_KEY: &str = "ldapsearchfilter";
/// Attribute compared with the user name when no filter is configured.
const LDAP_SEARCH_ATTRIBUTE_KEY: &str = "ldapsearchattribute";
/// DN to bind as before searching.
const LDAP_BIND_DN_KEY: &str = "ldapbinddn";
/// Password for the search bind DN.
const LDAP_BIND_PASSWD_KEY: &str = "ldapbindpasswd";
/// Text prepended to the user name to form the bind DN.
const LDAP_PREFIX_KEY: &str = "ldapprefix";
/// Text appended to the user name to form the bind DN.
const LDAP_SUFFIX_KEY: &str = "ldapsuffix";

// Environment variables read for TLS settings when the connection uses
// `ldaps://` or STARTTLS.

/// Path of a PEM bundle of trusted CA certificates.
const RW_LDAPTLS_CACERT: &str = "LDAPTLS_CACERT";
/// Path of the PEM client certificate chain.
const RW_LDAPTLS_CERT: &str = "LDAPTLS_CERT";
/// Path of the PEM client private key.
const RW_LDAPTLS_KEY: &str = "LDAPTLS_KEY";
/// Server-certificate checking policy.
const RW_LDAPTLS_REQCERT: &str = "LDAPTLS_REQCERT";
46
/// Policy for checking the LDAP server's TLS certificate, chosen through the
/// `LDAPTLS_REQCERT` environment variable. Each variant corresponds to one of
/// the accepted values `never`, `allow`, `try` and `demand`; how a policy is
/// enforced is decided where the connection settings are built.
#[derive(Clone, Copy, Debug)]
enum ReqCertPolicy {
    /// Selected by `never`.
    Never,
    /// Selected by `allow`.
    Allow,
    /// Selected by `try`.
    Try,
    /// Selected by `demand`; also used when the variable is not set.
    Demand,
}
54
55#[derive(Debug)]
56pub struct LdapTlsConfig {
57 ca_cert: Option<String>,
59 cert: Option<String>,
61 key: Option<String>,
63 req_cert: ReqCertPolicy,
65}
66
67impl LdapTlsConfig {
68 fn from_env() -> Self {
70 let ca_cert = std::env::var(RW_LDAPTLS_CACERT).ok();
71 let cert = std::env::var(RW_LDAPTLS_CERT).ok();
72 let key = std::env::var(RW_LDAPTLS_KEY).ok();
73 let req_cert = match std::env::var(RW_LDAPTLS_REQCERT).as_deref() {
74 Ok("never") => ReqCertPolicy::Never,
75 Ok("allow") => ReqCertPolicy::Allow,
76 Ok("try") => ReqCertPolicy::Try,
77 Ok("demand") => ReqCertPolicy::Demand,
78 _ => ReqCertPolicy::Demand, };
80
81 Self {
82 ca_cert,
83 cert,
84 key,
85 req_cert,
86 }
87 }
88
89 fn init_client_config(&self) -> PsqlResult<rustls::ClientConfig> {
91 let tls_client_config = rustls::ClientConfig::builder().with_safe_defaults();
92
93 let mut root_cert_store = rustls::RootCertStore::empty();
94 if let Some(tls_config) = &self.ca_cert {
95 let ca_cert_bytes = fs::read(tls_config).map_err(|e| {
96 PsqlError::StartupError(anyhow!(e).context("Failed to read CA certificate").into())
97 })?;
98 let ca_certs = rustls_pemfile::certs(&mut ca_cert_bytes.as_slice()).map_err(|e| {
99 PsqlError::StartupError(anyhow!(e).context("Failed to parse CA certificate").into())
100 })?;
101 for cert in ca_certs {
102 root_cert_store
103 .add(&rustls::Certificate(cert))
104 .map_err(|err| {
105 PsqlError::StartupError(
106 anyhow!(err).context("Failed to add CA certificate").into(),
107 )
108 })?;
109 }
110 } else {
111 for cert in
113 rustls_native_certs::load_native_certs().expect("could not load platform certs")
114 {
115 root_cert_store
116 .add(&rustls::Certificate(cert.0))
117 .map_err(|err| {
118 PsqlError::StartupError(
119 anyhow!(err)
120 .context("Failed to add native CA certificate")
121 .into(),
122 )
123 })?;
124 }
125 }
126 let tls_client_config = tls_client_config.with_root_certificates(root_cert_store);
127
128 if let Some(cert) = &self.cert {
129 let Some(key) = &self.key else {
130 return Err(PsqlError::StartupError(
131 "Client certificate provided without private key".into(),
132 ));
133 };
134 let client_cert_bytes = fs::read(cert).map_err(|e| {
135 PsqlError::StartupError(
136 anyhow!(e)
137 .context("Failed to read client certificate")
138 .into(),
139 )
140 })?;
141 let client_key_bytes = fs::read(key).map_err(|e| {
142 PsqlError::StartupError(anyhow!(e).context("Failed to read client key").into())
143 })?;
144 let client_certs =
145 rustls_pemfile::certs(&mut client_cert_bytes.as_slice()).map_err(|e| {
146 PsqlError::StartupError(
147 anyhow!(e)
148 .context("Failed to parse client certificate")
149 .into(),
150 )
151 })?;
152
153 let mut reader = std::io::Cursor::new(&client_key_bytes);
154 let mut private_keys =
155 rustls_pemfile::pkcs8_private_keys(&mut reader).map_err(|e| {
156 PsqlError::StartupError(anyhow!(e).context("Failed to parse client key").into())
157 })?;
158 if private_keys.is_empty() {
159 reader.set_position(0);
161 private_keys = rustls_pemfile::rsa_private_keys(&mut reader).map_err(|e| {
162 PsqlError::StartupError(anyhow!(e).context("Failed to parse client key").into())
163 })?;
164 }
165 let client_private_key = private_keys.pop().ok_or_else(|| {
166 PsqlError::StartupError("No private key found in client key file".into())
167 })?;
168 let client_certs_rustls: Vec<rustls::Certificate> =
169 client_certs.into_iter().map(rustls::Certificate).collect();
170
171 tls_client_config
172 .with_client_auth_cert(client_certs_rustls, rustls::PrivateKey(client_private_key))
173 .map_err(|err| {
174 PsqlError::StartupError(
175 anyhow!(err)
176 .context("Failed to set client certificate")
177 .into(),
178 )
179 })
180 } else {
181 Ok(tls_client_config.with_no_client_auth())
182 }
183 }
184}
185
/// LDAP authentication settings resolved from an HBA entry's auth options.
#[derive(Clone, Debug)]
pub struct LdapConfig {
    /// Server address in the form `scheme://host:port`.
    pub server: String,
    /// Root of the user search in search+bind mode. In simple-bind mode, when
    /// no prefix/suffix is set, it is also the parent of `uid=<user>`.
    pub base_dn: Option<String>,
    /// Search filter template in which `$username` is replaced by the escaped
    /// user name.
    pub search_filter: Option<String>,
    /// Attribute compared with the user name when no filter is configured
    /// (`uid` when unset).
    pub search_attribute: Option<String>,
    /// DN to bind as before searching for the user.
    pub bind_dn: Option<String>,
    /// Password for `bind_dn`.
    pub bind_passwd: Option<String>,
    /// Text placed before the escaped user name to form the bind DN.
    pub prefix: Option<String>,
    /// Text placed after the escaped user name to form the bind DN.
    pub suffix: Option<String>,
    /// Whether a plain `ldap` connection is upgraded with STARTTLS.
    pub start_tls: bool,
}
208
209impl LdapConfig {
210 pub fn from_hba_options(options: &HashMap<String, String>) -> PsqlResult<Self> {
212 if let Some(ldap_url) = options.get(LDAP_URL_KEY) {
213 return Self::from_ldap_url(ldap_url, options);
214 }
215
216 let server = options
217 .get(LDAP_SERVER_KEY)
218 .ok_or_else(|| PsqlError::StartupError("LDAP server (ldapserver) is required".into()))?
219 .clone();
220
221 let scheme = options
222 .get(LDAP_SCHEME_KEY)
223 .map(|s| s.as_str())
224 .unwrap_or("ldap");
225 if scheme != "ldap" && scheme != "ldaps" {
226 return Err(PsqlError::StartupError(
227 "LDAP scheme (ldapscheme) must be either 'ldap' or 'ldaps'".into(),
228 ));
229 }
230
231 let port = options
232 .get(LDAP_PORT_KEY)
233 .and_then(|p| p.parse::<u16>().ok())
234 .unwrap_or_else(|| if scheme == "ldaps" { 636 } else { 389 });
235
236 let start_tls = options
237 .get(LDAP_TLS)
238 .and_then(|p| p.parse::<bool>().ok())
239 .unwrap_or(false);
240
241 if start_tls && scheme == "ldaps" {
243 return Err(PsqlError::StartupError(
244 "Cannot use STARTTLS (ldaptls) with ldaps scheme".into(),
245 ));
246 }
247
248 let server = format!("{}://{}:{}", scheme, server, port);
249 let base_dn = options.get(LDAP_BASE_DN_KEY).cloned();
250 let search_filter = options.get(LDAP_SEARCH_FILTER_KEY).cloned();
251 let search_attribute = options.get(LDAP_SEARCH_ATTRIBUTE_KEY).cloned();
252 let bind_dn = options.get(LDAP_BIND_DN_KEY).cloned();
253 let bind_passwd = options.get(LDAP_BIND_PASSWD_KEY).cloned();
254 let prefix = options.get(LDAP_PREFIX_KEY).cloned();
255 let suffix = options.get(LDAP_SUFFIX_KEY).cloned();
256
257 Ok(Self {
258 server,
259 base_dn,
260 search_filter,
261 search_attribute,
262 bind_dn,
263 bind_passwd,
264 prefix,
265 suffix,
266 start_tls,
267 })
268 }
269
270 fn from_ldap_url(ldap_url: &str, options: &HashMap<String, String>) -> PsqlResult<Self> {
273 let conflicting_params = [
276 LDAP_SERVER_KEY,
277 LDAP_PORT_KEY,
278 LDAP_SCHEME_KEY,
279 LDAP_BASE_DN_KEY,
280 LDAP_SEARCH_ATTRIBUTE_KEY,
281 LDAP_SEARCH_FILTER_KEY,
282 ];
283
284 for param in &conflicting_params {
285 if options.contains_key(*param) {
286 return Err(PsqlError::StartupError(
287 format!("Cannot specify both ldapurl and {} parameter", param).into(),
288 ));
289 }
290 }
291
292 let url = url::Url::parse(ldap_url).map_err(|e| {
294 PsqlError::StartupError(anyhow!(e).context("Failed to parse ldap url").into())
295 })?;
296
297 let scheme = url.scheme();
299 if scheme != "ldap" && scheme != "ldaps" {
300 return Err(PsqlError::StartupError(
301 "LDAP URL scheme must be either 'ldap' or 'ldaps'".into(),
302 ));
303 }
304
305 let host = url
307 .host_str()
308 .ok_or_else(|| PsqlError::StartupError("LDAP URL must contain a host".into()))?;
309 let port = url
310 .port()
311 .unwrap_or_else(|| if scheme == "ldaps" { 636 } else { 389 });
312
313 let server = format!("{}://{}:{}", scheme, host, port);
314
315 let base_dn = if url.path().len() > 1 {
317 Some(url.path()[1..].to_string())
318 } else {
319 None
320 };
321
322 let mut search_attribute = None;
325 let mut search_filter = None;
326
327 if let Some(query) = url.query() {
328 let parts: Vec<&str> = query.split('?').collect();
329
330 if !parts.is_empty() && !parts[0].is_empty() {
332 search_attribute = Some(parts[0].split(',').next().unwrap().to_owned());
333 }
334
335 if parts.len() > 2 && !parts[2].is_empty() {
337 search_filter = Some(parts[2].to_owned());
338 }
339 }
340
341 let start_tls = options
345 .get(LDAP_TLS)
346 .and_then(|p| p.parse::<bool>().ok())
347 .unwrap_or(false);
348
349 if start_tls && scheme == "ldaps" {
351 return Err(PsqlError::StartupError(
352 "Cannot use STARTTLS (ldaptls) with ldaps scheme".into(),
353 ));
354 }
355
356 let bind_dn = options.get(LDAP_BIND_DN_KEY).cloned();
357 let bind_passwd = options.get(LDAP_BIND_PASSWD_KEY).cloned();
358 let prefix = options.get(LDAP_PREFIX_KEY).cloned();
359 let suffix = options.get(LDAP_SUFFIX_KEY).cloned();
360
361 Ok(Self {
362 server,
363 base_dn,
364 search_filter,
365 search_attribute,
366 bind_dn,
367 bind_passwd,
368 prefix,
369 suffix,
370 start_tls,
371 })
372 }
373
374 fn certs_required(&self) -> bool {
375 self.server.starts_with("ldaps://") || self.start_tls
376 }
377}
378
379#[derive(Debug, Clone)]
381pub struct LdapAuthenticator {
382 config: LdapConfig,
384}
385
386impl LdapAuthenticator {
387 pub fn new(hba_entry: &HbaEntry) -> PsqlResult<Self> {
389 if hba_entry.auth_method != AuthMethod::Ldap {
390 return Err(PsqlError::StartupError(
391 "HBA entry is not configured for LDAP authentication".into(),
392 ));
393 }
394
395 let config = LdapConfig::from_hba_options(&hba_entry.auth_options)?;
396 Ok(Self { config })
397 }
398
399 pub async fn authenticate(&self, username: &str, password: &str) -> PsqlResult<bool> {
401 if password.is_empty() {
403 return Ok(false);
404 }
405
406 let has_simple_bind_params = self.config.prefix.is_some() || self.config.suffix.is_some();
413
414 let has_search_only_params = self.config.search_filter.is_some()
416 || self.config.bind_dn.is_some()
417 || self.config.search_attribute.is_some();
418
419 if has_simple_bind_params && has_search_only_params {
421 return Err(PsqlError::StartupError(
422 "Cannot mix simple bind parameters (ldapprefix/ldapsuffix) with search+bind parameters (ldapsearchfilter/ldapbinddn/ldapsearchattribute)".into()
423 ));
424 }
425
426 if has_simple_bind_params {
429 self.simple_bind(username, password).await
431 } else if self.config.base_dn.is_some() {
432 self.search_and_bind(username, password).await
434 } else {
435 self.simple_bind(username, password).await
437 }
438 }
439
440 #[allow(rw::format_error)]
442 async fn establish_connection(&self) -> PsqlResult<ldap3::Ldap> {
443 let config = &self.config;
444 let mut settings = ldap3::LdapConnSettings::new();
445
446 settings = settings.set_starttls(config.start_tls);
448
449 if config.certs_required() {
450 let tls_config = LdapTlsConfig::from_env();
451 tracing::debug!("fetched tls config from env: {:?}", tls_config);
452
453 let client_config = tls_config.init_client_config()?;
454 settings = settings.set_config(Arc::new(client_config));
455
456 if matches!(tls_config.req_cert, ReqCertPolicy::Demand) {
457 settings = settings.set_no_tls_verify(false);
458 } else {
459 warn!(
460 "LDAP client certificate verification is disabled due to LDAPTLS_REQCERT policy"
461 );
462 settings = settings.set_no_tls_verify(true);
463 }
464 }
465
466 let (conn, ldap) = LdapConnAsync::with_settings(settings, &config.server)
467 .await
468 .map_err(|err| {
469 PsqlError::StartupError(
470 anyhow!(err)
471 .context("Failed to connect to LDAP server")
472 .into(),
473 )
474 })?;
475 ldap3::drive!(conn);
476
477 Ok(ldap)
478 }
479
480 async fn search_and_bind(&self, username: &str, password: &str) -> PsqlResult<bool> {
482 let mut ldap = self.establish_connection().await?;
484
485 let base_dn = self
487 .config
488 .base_dn
489 .as_ref()
490 .ok_or_else(|| PsqlError::StartupError("LDAP base_dn not configured".into()))?;
491
492 if let (Some(bind_dn), Some(bind_passwd)) = (&self.config.bind_dn, &self.config.bind_passwd)
494 {
495 ldap.simple_bind(bind_dn, bind_passwd)
496 .await
497 .map_err(|e| {
498 PsqlError::StartupError(
499 anyhow!(e).context("LDAP bind as search user failed").into(),
500 )
501 })?
502 .success()
503 .map_err(|e| {
504 PsqlError::StartupError(
505 anyhow!(e).context("LDAP bind as search user failed").into(),
506 )
507 })?;
508 }
509
510 let search_filter = if let Some(filter_template) = &self.config.search_filter {
512 let escaped_username = ldap_escape(username);
515 filter_template.replace("$username", &escaped_username)
516 } else {
517 let escaped_username = ldap_escape(username);
520 let attr = self.config.search_attribute.as_deref().unwrap_or("uid");
521 format!("({}={})", attr, escaped_username)
522 };
523
524 let rs = ldap
525 .search(base_dn, Scope::Subtree, &search_filter, vec!["dn"])
526 .await
527 .map_err(|e| {
528 PsqlError::StartupError(anyhow!(e).context("LDAP search failed").into())
529 })?;
530
531 let search_entries: Vec<SearchEntry> =
533 rs.0.into_iter().map(SearchEntry::construct).collect();
534 if search_entries.is_empty() {
535 return Ok(false);
536 }
537
538 let user_dn = &search_entries[0].dn;
540
541 let bind_result = ldap
542 .simple_bind(user_dn, password)
543 .await
544 .map_err(|e| PsqlError::StartupError(anyhow!(e).context("LDAP bind failed").into()));
545
546 let _ = ldap.unbind().await;
548
549 let bind_result = bind_result?;
550 match bind_result.success() {
551 Ok(_) => Ok(true),
552 Err(e) => {
553 tracing::error!(error = %e.as_report(), "LDAP bind unsuccessful");
554 Err(PsqlError::StartupError(
555 anyhow!(e).context("LDAP bind failed").into(),
556 ))
557 }
558 }
559 }
560
561 async fn simple_bind(&self, username: &str, password: &str) -> PsqlResult<bool> {
563 let escaped_username = dn_escape(username);
570
571 let dn = if self.config.prefix.is_some() || self.config.suffix.is_some() {
572 let prefix = self.config.prefix.as_deref().unwrap_or("");
574 let suffix = self.config.suffix.as_deref().unwrap_or("");
575 format!("{}{}{}", prefix, escaped_username, suffix)
576 } else if let Some(base_dn) = &self.config.base_dn {
577 format!("uid={},{}", escaped_username, base_dn)
581 } else {
582 escaped_username.to_string()
584 };
585
586 let mut ldap = self.establish_connection().await?;
588
589 tracing::info!(%self.config.server, %dn, "simple bind authentication with LDAP server");
590
591 let bind_result = ldap
592 .simple_bind(&dn, password)
593 .await
594 .map_err(|e| PsqlError::StartupError(anyhow!(e).context("LDAP bind failed").into()));
595
596 let _ = ldap.unbind().await;
598
599 let bind_result = bind_result?;
600 match bind_result.success() {
601 Ok(_) => Ok(true),
602 Err(e) => {
603 tracing::error!(error = %e.as_report(), "LDAP bind unsuccessful");
604 Err(PsqlError::StartupError(
605 format!("LDAP bind failed: {}", e.as_report()).into(),
606 ))
607 }
608 }
609 }
610}