1use std::collections::HashMap;
16use std::fs;
17use std::sync::Arc;
18
19use anyhow::anyhow;
20use ldap3::{LdapConnAsync, Scope, SearchEntry, dn_escape, ldap_escape};
21use risingwave_common::config::{AuthMethod, HbaEntry};
22use rustls_pki_types::pem::PemObject;
23use rustls_pki_types::{CertificateDer, PrivateKeyDer};
24use thiserror_ext::AsReport;
25use tracing::warn;
26
27use crate::error::{PsqlError, PsqlResult};
28
// ---- HBA option names: where the LDAP server lives ----

/// HBA option naming the LDAP server host.
const LDAP_SERVER_KEY: &str = "ldapserver";
/// HBA option naming the LDAP server port.
const LDAP_PORT_KEY: &str = "ldapport";
/// HBA option selecting the URL scheme (`ldap` or `ldaps`).
const LDAP_SCHEME_KEY: &str = "ldapscheme";
/// HBA option carrying a complete LDAP URL instead of the individual options.
const LDAP_URL_KEY: &str = "ldapurl";
/// HBA option toggling STARTTLS on the connection.
const LDAP_TLS: &str = "ldaptls";

// ---- HBA option names: search+bind mode ----

/// HBA option naming the base DN under which users are searched.
const LDAP_BASE_DN_KEY: &str = "ldapbasedn";
/// HBA option carrying a search filter template (`$username` is substituted).
const LDAP_SEARCH_FILTER_KEY: &str = "ldapsearchfilter";
/// HBA option naming the attribute matched against the user name.
const LDAP_SEARCH_ATTRIBUTE_KEY: &str = "ldapsearchattribute";
/// HBA option naming the DN used to bind before searching.
const LDAP_BIND_DN_KEY: &str = "ldapbinddn";
/// HBA option carrying the password for [`LDAP_BIND_DN_KEY`].
const LDAP_BIND_PASSWD_KEY: &str = "ldapbindpasswd";

// ---- HBA option names: simple-bind mode ----

/// HBA option prepended to the user name when forming the bind DN.
const LDAP_PREFIX_KEY: &str = "ldapprefix";
/// HBA option appended to the user name when forming the bind DN.
const LDAP_SUFFIX_KEY: &str = "ldapsuffix";

// ---- Environment variables consulted for TLS material ----

/// Environment variable holding the path of the CA certificate file.
const RW_LDAPTLS_CACERT: &str = "LDAPTLS_CACERT";
/// Environment variable holding the path of the client certificate file.
const RW_LDAPTLS_CERT: &str = "LDAPTLS_CERT";
/// Environment variable holding the path of the client private key file.
const RW_LDAPTLS_KEY: &str = "LDAPTLS_KEY";
/// Environment variable selecting the certificate verification policy.
const RW_LDAPTLS_REQCERT: &str = "LDAPTLS_REQCERT";
48
/// Certificate verification policy, as spelled in the `LDAPTLS_REQCERT` environment variable.
///
/// When a TLS connection is set up, only [`ReqCertPolicy::Demand`] leaves verification
/// enabled; every other variant switches it off.
#[derive(Clone, Copy, Debug)]
enum ReqCertPolicy {
    /// Selected by the value `never`.
    Never,
    /// Selected by the value `allow`.
    Allow,
    /// Selected by the value `try`.
    Try,
    /// Selected by the value `demand`; also the fallback for unset or unknown values.
    Demand,
}
56
57#[derive(Debug)]
58pub struct LdapTlsConfig {
59 ca_cert: Option<String>,
61 cert: Option<String>,
63 key: Option<String>,
65 req_cert: ReqCertPolicy,
67}
68
69impl LdapTlsConfig {
70 fn from_env() -> Self {
72 let ca_cert = std::env::var(RW_LDAPTLS_CACERT).ok();
73 let cert = std::env::var(RW_LDAPTLS_CERT).ok();
74 let key = std::env::var(RW_LDAPTLS_KEY).ok();
75 let req_cert = match std::env::var(RW_LDAPTLS_REQCERT).as_deref() {
76 Ok("never") => ReqCertPolicy::Never,
77 Ok("allow") => ReqCertPolicy::Allow,
78 Ok("try") => ReqCertPolicy::Try,
79 Ok("demand") => ReqCertPolicy::Demand,
80 _ => ReqCertPolicy::Demand, };
82
83 Self {
84 ca_cert,
85 cert,
86 key,
87 req_cert,
88 }
89 }
90
91 fn init_client_config(&self) -> PsqlResult<rustls::ClientConfig> {
93 let tls_client_config = rustls::ClientConfig::builder();
94
95 let mut root_cert_store = rustls::RootCertStore::empty();
96 if let Some(tls_config) = &self.ca_cert {
97 let ca_cert_bytes = fs::read(tls_config).map_err(|e| {
98 PsqlError::StartupError(anyhow!(e).context("Failed to read CA certificate").into())
99 })?;
100 for cert in CertificateDer::pem_slice_iter(&ca_cert_bytes) {
101 let cert = cert.map_err(|e| {
102 PsqlError::StartupError(
103 anyhow!(e).context("Failed to parse CA certificate").into(),
104 )
105 })?;
106 root_cert_store.add(cert).map_err(|err| {
107 PsqlError::StartupError(
108 anyhow!(err).context("Failed to add CA certificate").into(),
109 )
110 })?;
111 }
112 } else {
113 for cert in
115 rustls_native_certs::load_native_certs().expect("could not load platform certs")
116 {
117 root_cert_store.add(cert).map_err(|err| {
118 PsqlError::StartupError(
119 anyhow!(err)
120 .context("Failed to add native CA certificate")
121 .into(),
122 )
123 })?;
124 }
125 }
126 let tls_client_config = tls_client_config.with_root_certificates(root_cert_store);
127
128 if let Some(cert) = &self.cert {
129 let Some(key) = &self.key else {
130 return Err(PsqlError::StartupError(
131 "Client certificate provided without private key".into(),
132 ));
133 };
134 let client_cert_bytes = fs::read(cert).map_err(|e| {
135 PsqlError::StartupError(
136 anyhow!(e)
137 .context("Failed to read client certificate")
138 .into(),
139 )
140 })?;
141 let client_key_bytes = fs::read(key).map_err(|e| {
142 PsqlError::StartupError(anyhow!(e).context("Failed to read client key").into())
143 })?;
144 let client_certs = CertificateDer::pem_slice_iter(&client_cert_bytes)
145 .collect::<Result<Vec<_>, _>>()
146 .map_err(|e| {
147 PsqlError::StartupError(
148 anyhow!(e)
149 .context("Failed to parse client certificate")
150 .into(),
151 )
152 })?;
153
154 let client_private_key =
155 PrivateKeyDer::from_pem_slice(&client_key_bytes).map_err(|e| {
156 PsqlError::StartupError(anyhow!(e).context("Failed to parse client key").into())
157 })?;
158
159 tls_client_config
160 .with_client_auth_cert(client_certs, client_private_key)
161 .map_err(|err| {
162 PsqlError::StartupError(
163 anyhow!(err)
164 .context("Failed to set client certificate")
165 .into(),
166 )
167 })
168 } else {
169 Ok(tls_client_config.with_no_client_auth())
170 }
171 }
172}
173
/// LDAP settings derived from the options of one HBA entry.
#[derive(Clone, Debug)]
pub struct LdapConfig {
    /// Server URL in `scheme://host:port` form.
    pub server: String,
    /// Base DN used for user searches.
    pub base_dn: Option<String>,
    /// Search filter template in which `$username` is replaced by the escaped user name.
    pub search_filter: Option<String>,
    /// Attribute compared with the user name when no filter template is set.
    pub search_attribute: Option<String>,
    /// DN to bind as before searching.
    pub bind_dn: Option<String>,
    /// Password that goes with `bind_dn`.
    pub bind_passwd: Option<String>,
    /// Text placed before the user name to form the DN for a simple bind.
    pub prefix: Option<String>,
    /// Text placed after the user name to form the DN for a simple bind.
    pub suffix: Option<String>,
    /// Whether STARTTLS is requested on the connection.
    pub start_tls: bool,
}
196
197impl LdapConfig {
198 pub fn from_hba_options(options: &HashMap<String, String>) -> PsqlResult<Self> {
200 if let Some(ldap_url) = options.get(LDAP_URL_KEY) {
201 return Self::from_ldap_url(ldap_url, options);
202 }
203
204 let server = options
205 .get(LDAP_SERVER_KEY)
206 .ok_or_else(|| PsqlError::StartupError("LDAP server (ldapserver) is required".into()))?
207 .clone();
208
209 let scheme = options
210 .get(LDAP_SCHEME_KEY)
211 .map(|s| s.as_str())
212 .unwrap_or("ldap");
213 if scheme != "ldap" && scheme != "ldaps" {
214 return Err(PsqlError::StartupError(
215 "LDAP scheme (ldapscheme) must be either 'ldap' or 'ldaps'".into(),
216 ));
217 }
218
219 let port = options
220 .get(LDAP_PORT_KEY)
221 .and_then(|p| p.parse::<u16>().ok())
222 .unwrap_or_else(|| if scheme == "ldaps" { 636 } else { 389 });
223
224 let start_tls = options
225 .get(LDAP_TLS)
226 .and_then(|p| p.parse::<bool>().ok())
227 .unwrap_or(false);
228
229 if start_tls && scheme == "ldaps" {
231 return Err(PsqlError::StartupError(
232 "Cannot use STARTTLS (ldaptls) with ldaps scheme".into(),
233 ));
234 }
235
236 let server = format!("{}://{}:{}", scheme, server, port);
237 let base_dn = options.get(LDAP_BASE_DN_KEY).cloned();
238 let search_filter = options.get(LDAP_SEARCH_FILTER_KEY).cloned();
239 let search_attribute = options.get(LDAP_SEARCH_ATTRIBUTE_KEY).cloned();
240 let bind_dn = options.get(LDAP_BIND_DN_KEY).cloned();
241 let bind_passwd = options.get(LDAP_BIND_PASSWD_KEY).cloned();
242 let prefix = options.get(LDAP_PREFIX_KEY).cloned();
243 let suffix = options.get(LDAP_SUFFIX_KEY).cloned();
244
245 Ok(Self {
246 server,
247 base_dn,
248 search_filter,
249 search_attribute,
250 bind_dn,
251 bind_passwd,
252 prefix,
253 suffix,
254 start_tls,
255 })
256 }
257
258 fn from_ldap_url(ldap_url: &str, options: &HashMap<String, String>) -> PsqlResult<Self> {
261 let conflicting_params = [
264 LDAP_SERVER_KEY,
265 LDAP_PORT_KEY,
266 LDAP_SCHEME_KEY,
267 LDAP_BASE_DN_KEY,
268 LDAP_SEARCH_ATTRIBUTE_KEY,
269 LDAP_SEARCH_FILTER_KEY,
270 ];
271
272 for param in &conflicting_params {
273 if options.contains_key(*param) {
274 return Err(PsqlError::StartupError(
275 format!("Cannot specify both ldapurl and {} parameter", param).into(),
276 ));
277 }
278 }
279
280 let url = url::Url::parse(ldap_url).map_err(|e| {
282 PsqlError::StartupError(anyhow!(e).context("Failed to parse ldap url").into())
283 })?;
284
285 let scheme = url.scheme();
287 if scheme != "ldap" && scheme != "ldaps" {
288 return Err(PsqlError::StartupError(
289 "LDAP URL scheme must be either 'ldap' or 'ldaps'".into(),
290 ));
291 }
292
293 let host = url
295 .host_str()
296 .ok_or_else(|| PsqlError::StartupError("LDAP URL must contain a host".into()))?;
297 let port = url
298 .port()
299 .unwrap_or_else(|| if scheme == "ldaps" { 636 } else { 389 });
300
301 let server = format!("{}://{}:{}", scheme, host, port);
302
303 let base_dn = if url.path().len() > 1 {
305 Some(url.path()[1..].to_string())
306 } else {
307 None
308 };
309
310 let mut search_attribute = None;
313 let mut search_filter = None;
314
315 if let Some(query) = url.query() {
316 let parts: Vec<&str> = query.split('?').collect();
317
318 if !parts.is_empty() && !parts[0].is_empty() {
320 search_attribute = Some(parts[0].split(',').next().unwrap().to_owned());
321 }
322
323 if parts.len() > 2 && !parts[2].is_empty() {
325 search_filter = Some(parts[2].to_owned());
326 }
327 }
328
329 let start_tls = options
333 .get(LDAP_TLS)
334 .and_then(|p| p.parse::<bool>().ok())
335 .unwrap_or(false);
336
337 if start_tls && scheme == "ldaps" {
339 return Err(PsqlError::StartupError(
340 "Cannot use STARTTLS (ldaptls) with ldaps scheme".into(),
341 ));
342 }
343
344 let bind_dn = options.get(LDAP_BIND_DN_KEY).cloned();
345 let bind_passwd = options.get(LDAP_BIND_PASSWD_KEY).cloned();
346 let prefix = options.get(LDAP_PREFIX_KEY).cloned();
347 let suffix = options.get(LDAP_SUFFIX_KEY).cloned();
348
349 Ok(Self {
350 server,
351 base_dn,
352 search_filter,
353 search_attribute,
354 bind_dn,
355 bind_passwd,
356 prefix,
357 suffix,
358 start_tls,
359 })
360 }
361
362 fn certs_required(&self) -> bool {
363 self.server.starts_with("ldaps://") || self.start_tls
364 }
365}
366
367#[derive(Debug, Clone)]
369pub struct LdapAuthenticator {
370 config: LdapConfig,
372}
373
374impl LdapAuthenticator {
375 pub fn new(hba_entry: &HbaEntry) -> PsqlResult<Self> {
377 if hba_entry.auth_method != AuthMethod::Ldap {
378 return Err(PsqlError::StartupError(
379 "HBA entry is not configured for LDAP authentication".into(),
380 ));
381 }
382
383 let config = LdapConfig::from_hba_options(&hba_entry.auth_options)?;
384 Ok(Self { config })
385 }
386
387 pub async fn authenticate(&self, username: &str, password: &str) -> PsqlResult<bool> {
389 if password.is_empty() {
391 return Ok(false);
392 }
393
394 let has_simple_bind_params = self.config.prefix.is_some() || self.config.suffix.is_some();
401
402 let has_search_only_params = self.config.search_filter.is_some()
404 || self.config.bind_dn.is_some()
405 || self.config.search_attribute.is_some();
406
407 if has_simple_bind_params && has_search_only_params {
409 return Err(PsqlError::StartupError(
410 "Cannot mix simple bind parameters (ldapprefix/ldapsuffix) with search+bind parameters (ldapsearchfilter/ldapbinddn/ldapsearchattribute)".into()
411 ));
412 }
413
414 if has_simple_bind_params {
417 self.simple_bind(username, password).await
419 } else if self.config.base_dn.is_some() {
420 self.search_and_bind(username, password).await
422 } else {
423 self.simple_bind(username, password).await
425 }
426 }
427
428 #[allow(rw::format_error)]
430 async fn establish_connection(&self) -> PsqlResult<ldap3::Ldap> {
431 let config = &self.config;
432 let mut settings = ldap3::LdapConnSettings::new();
433
434 settings = settings.set_starttls(config.start_tls);
436
437 if config.certs_required() {
438 let tls_config = LdapTlsConfig::from_env();
439 tracing::debug!("fetched tls config from env: {:?}", tls_config);
440
441 let client_config = tls_config.init_client_config()?;
442 settings = settings.set_config(Arc::new(client_config));
443
444 if matches!(tls_config.req_cert, ReqCertPolicy::Demand) {
445 settings = settings.set_no_tls_verify(false);
446 } else {
447 warn!(
448 "LDAP client certificate verification is disabled due to LDAPTLS_REQCERT policy"
449 );
450 settings = settings.set_no_tls_verify(true);
451 }
452 }
453
454 let (conn, ldap) = LdapConnAsync::with_settings(settings, &config.server)
455 .await
456 .map_err(|err| {
457 PsqlError::StartupError(
458 anyhow!(err)
459 .context("Failed to connect to LDAP server")
460 .into(),
461 )
462 })?;
463 ldap3::drive!(conn);
464
465 Ok(ldap)
466 }
467
468 async fn search_and_bind(&self, username: &str, password: &str) -> PsqlResult<bool> {
470 let mut ldap = self.establish_connection().await?;
472
473 let base_dn = self
475 .config
476 .base_dn
477 .as_ref()
478 .ok_or_else(|| PsqlError::StartupError("LDAP base_dn not configured".into()))?;
479
480 if let (Some(bind_dn), Some(bind_passwd)) = (&self.config.bind_dn, &self.config.bind_passwd)
482 {
483 ldap.simple_bind(bind_dn, bind_passwd)
484 .await
485 .map_err(|e| {
486 PsqlError::StartupError(
487 anyhow!(e).context("LDAP bind as search user failed").into(),
488 )
489 })?
490 .success()
491 .map_err(|e| {
492 PsqlError::StartupError(
493 anyhow!(e).context("LDAP bind as search user failed").into(),
494 )
495 })?;
496 }
497
498 let search_filter = if let Some(filter_template) = &self.config.search_filter {
500 let escaped_username = ldap_escape(username);
503 filter_template.replace("$username", &escaped_username)
504 } else {
505 let escaped_username = ldap_escape(username);
508 let attr = self.config.search_attribute.as_deref().unwrap_or("uid");
509 format!("({}={})", attr, escaped_username)
510 };
511
512 let rs = ldap
513 .search(base_dn, Scope::Subtree, &search_filter, vec!["dn"])
514 .await
515 .map_err(|e| {
516 PsqlError::StartupError(anyhow!(e).context("LDAP search failed").into())
517 })?;
518
519 let search_entries: Vec<SearchEntry> =
521 rs.0.into_iter().map(SearchEntry::construct).collect();
522 if search_entries.is_empty() {
523 return Ok(false);
524 }
525
526 let user_dn = &search_entries[0].dn;
528
529 let bind_result = ldap
530 .simple_bind(user_dn, password)
531 .await
532 .map_err(|e| PsqlError::StartupError(anyhow!(e).context("LDAP bind failed").into()));
533
534 let _ = ldap.unbind().await;
536
537 let bind_result = bind_result?;
538 match bind_result.success() {
539 Ok(_) => Ok(true),
540 Err(e) => {
541 tracing::error!(error = %e.as_report(), "LDAP bind unsuccessful");
542 Err(PsqlError::StartupError(
543 anyhow!(e).context("LDAP bind failed").into(),
544 ))
545 }
546 }
547 }
548
549 async fn simple_bind(&self, username: &str, password: &str) -> PsqlResult<bool> {
551 let escaped_username = dn_escape(username);
558
559 let dn = if self.config.prefix.is_some() || self.config.suffix.is_some() {
560 let prefix = self.config.prefix.as_deref().unwrap_or("");
562 let suffix = self.config.suffix.as_deref().unwrap_or("");
563 format!("{}{}{}", prefix, escaped_username, suffix)
564 } else if let Some(base_dn) = &self.config.base_dn {
565 format!("uid={},{}", escaped_username, base_dn)
569 } else {
570 escaped_username.to_string()
572 };
573
574 let mut ldap = self.establish_connection().await?;
576
577 tracing::info!(%self.config.server, %dn, "simple bind authentication with LDAP server");
578
579 let bind_result = ldap
580 .simple_bind(&dn, password)
581 .await
582 .map_err(|e| PsqlError::StartupError(anyhow!(e).context("LDAP bind failed").into()));
583
584 let _ = ldap.unbind().await;
586
587 let bind_result = bind_result?;
588 match bind_result.success() {
589 Ok(_) => Ok(true),
590 Err(e) => {
591 tracing::error!(error = %e.as_report(), "LDAP bind unsuccessful");
592 Err(PsqlError::StartupError(
593 format!("LDAP bind failed: {}", e.as_report()).into(),
594 ))
595 }
596 }
597 }
598}