1use std::collections::HashMap;
16use std::fs;
17use std::sync::Arc;
18
19use anyhow::anyhow;
20use ldap3::{LdapConnAsync, Scope, SearchEntry, dn_escape, ldap_escape};
21use risingwave_common::config::{AuthMethod, HbaEntry};
22use thiserror_ext::AsReport;
23use tracing::warn;
24
25use crate::error::{PsqlError, PsqlResult};
26
/// HBA option: LDAP server host name (used only when `ldapurl` is not given).
const LDAP_SERVER_KEY: &str = "ldapserver";
/// HBA option: LDAP server port. When absent the port depends on the scheme
/// (389 for `ldap`, 636 for `ldaps`).
const LDAP_PORT_KEY: &str = "ldapport";
/// HBA option: connection scheme, either `ldap` (the default) or `ldaps`.
const LDAP_SCHEME_KEY: &str = "ldapscheme";
/// HBA option: base DN. When set (and no prefix/suffix is configured) it selects search+bind
/// mode; in simple-bind mode without prefix/suffix it forms the bind DN `uid=<user>,<base_dn>`.
const LDAP_BASE_DN_KEY: &str = "ldapbasedn";
/// HBA option: search filter template for search+bind; `$username` is replaced by the
/// filter-escaped user name.
const LDAP_SEARCH_FILTER_KEY: &str = "ldapsearchfilter";
/// HBA option: attribute compared with the user name when no search filter is configured
/// (defaults to `uid`).
const LDAP_SEARCH_ATTRIBUTE_KEY: &str = "ldapsearchattribute";
/// HBA option: DN to bind as before searching (only used together with `ldapbindpasswd`).
const LDAP_BIND_DN_KEY: &str = "ldapbinddn";
/// HBA option: password for `ldapbinddn`.
const LDAP_BIND_PASSWD_KEY: &str = "ldapbindpasswd";
/// HBA option: string prepended to the DN-escaped user name to form the simple-bind DN.
const LDAP_PREFIX_KEY: &str = "ldapprefix";
/// HBA option: string appended to the DN-escaped user name to form the simple-bind DN.
const LDAP_SUFFIX_KEY: &str = "ldapsuffix";
/// HBA option: an LDAP URL that replaces `ldapserver`, `ldapport`, `ldapscheme`, `ldapbasedn`,
/// `ldapsearchattribute` and `ldapsearchfilter` (specifying it together with any of them is an
/// error).
const LDAP_URL_KEY: &str = "ldapurl";

/// HBA option: enable STARTTLS on an `ldap://` connection (rejected together with `ldaps`).
/// NOTE(review): unlike the other option constants this one lacks the `_KEY` suffix.
const LDAP_TLS: &str = "ldaptls";

/// Environment variable: path to a PEM file of trusted CA certificates. When unset, the
/// platform's native root certificates are used instead.
const RW_LDAPTLS_CACERT: &str = "LDAPTLS_CACERT";
/// Environment variable: path to a PEM client certificate (requires `LDAPTLS_KEY`).
const RW_LDAPTLS_CERT: &str = "LDAPTLS_CERT";
/// Environment variable: path to the PEM private key belonging to `LDAPTLS_CERT`.
const RW_LDAPTLS_KEY: &str = "LDAPTLS_KEY";
/// Environment variable: server-certificate checking policy; recognized values include
/// `never`, `allow`, `try` and `demand`.
const RW_LDAPTLS_REQCERT: &str = "LDAPTLS_REQCERT";
46
/// Server-certificate checking policy read from the `LDAPTLS_REQCERT` environment variable.
///
/// The variant names correspond to the keywords accepted for that variable. See
/// `LdapAuthenticator::establish_connection` for how each policy is applied to the TLS settings.
#[derive(Debug, Clone, Copy)]
enum ReqCertPolicy {
    /// `never`
    Never,
    /// `allow`
    Allow,
    /// `try`
    Try,
    /// `demand`; also used when the variable is unset or holds an unrecognized value.
    Demand,
}
54
/// TLS settings for LDAP connections, read from the `LDAPTLS_*` environment variables.
///
/// Only consulted for `ldaps://` connections or when STARTTLS is enabled.
#[derive(Debug)]
pub struct LdapTlsConfig {
    /// Path to a PEM file of trusted CA certificates (`LDAPTLS_CACERT`). When `None`, the
    /// platform's native root certificates are used instead.
    ca_cert: Option<String>,
    /// Path to a PEM client certificate (`LDAPTLS_CERT`); setting it requires `key` as well.
    cert: Option<String>,
    /// Path to the PEM private key for `cert` (`LDAPTLS_KEY`).
    key: Option<String>,
    /// Server-certificate checking policy (`LDAPTLS_REQCERT`).
    req_cert: ReqCertPolicy,
}
66
67impl LdapTlsConfig {
68 fn from_env() -> Self {
70 let ca_cert = std::env::var(RW_LDAPTLS_CACERT).ok();
71 let cert = std::env::var(RW_LDAPTLS_CERT).ok();
72 let key = std::env::var(RW_LDAPTLS_KEY).ok();
73 let req_cert = match std::env::var(RW_LDAPTLS_REQCERT).as_deref() {
74 Ok("never") => ReqCertPolicy::Never,
75 Ok("allow") => ReqCertPolicy::Allow,
76 Ok("try") => ReqCertPolicy::Try,
77 Ok("demand") => ReqCertPolicy::Demand,
78 _ => ReqCertPolicy::Demand, };
80
81 Self {
82 ca_cert,
83 cert,
84 key,
85 req_cert,
86 }
87 }
88
89 fn init_client_config(&self) -> PsqlResult<rustls::ClientConfig> {
91 let tls_client_config = rustls::ClientConfig::builder();
92
93 let mut root_cert_store = rustls::RootCertStore::empty();
94 if let Some(tls_config) = &self.ca_cert {
95 let ca_cert_bytes = fs::read(tls_config).map_err(|e| {
96 PsqlError::StartupError(anyhow!(e).context("Failed to read CA certificate").into())
97 })?;
98 for cert in rustls_pemfile::certs(&mut ca_cert_bytes.as_slice()) {
99 let cert = cert.map_err(|e| {
100 PsqlError::StartupError(
101 anyhow!(e).context("Failed to parse CA certificate").into(),
102 )
103 })?;
104 root_cert_store.add(cert).map_err(|err| {
105 PsqlError::StartupError(
106 anyhow!(err).context("Failed to add CA certificate").into(),
107 )
108 })?;
109 }
110 } else {
111 for cert in
113 rustls_native_certs::load_native_certs().expect("could not load platform certs")
114 {
115 root_cert_store.add(cert).map_err(|err| {
116 PsqlError::StartupError(
117 anyhow!(err)
118 .context("Failed to add native CA certificate")
119 .into(),
120 )
121 })?;
122 }
123 }
124 let tls_client_config = tls_client_config.with_root_certificates(root_cert_store);
125
126 if let Some(cert) = &self.cert {
127 let Some(key) = &self.key else {
128 return Err(PsqlError::StartupError(
129 "Client certificate provided without private key".into(),
130 ));
131 };
132 let client_cert_bytes = fs::read(cert).map_err(|e| {
133 PsqlError::StartupError(
134 anyhow!(e)
135 .context("Failed to read client certificate")
136 .into(),
137 )
138 })?;
139 let client_key_bytes = fs::read(key).map_err(|e| {
140 PsqlError::StartupError(anyhow!(e).context("Failed to read client key").into())
141 })?;
142 let client_certs = rustls_pemfile::certs(&mut client_cert_bytes.as_slice())
143 .try_collect()
144 .map_err(|e| {
145 PsqlError::StartupError(
146 anyhow!(e)
147 .context("Failed to parse client certificate")
148 .into(),
149 )
150 })?;
151
152 let private_key = rustls_pemfile::private_key(&mut client_key_bytes.as_slice())
153 .map_err(|e| {
154 PsqlError::StartupError(anyhow!(e).context("Failed to parse client key").into())
155 })?;
156 let Some(client_private_key) = private_key else {
157 return Err(PsqlError::StartupError(
158 "No private key found in client key file".into(),
159 ));
160 };
161
162 tls_client_config
163 .with_client_auth_cert(client_certs, client_private_key)
164 .map_err(|err| {
165 PsqlError::StartupError(
166 anyhow!(err)
167 .context("Failed to set client certificate")
168 .into(),
169 )
170 })
171 } else {
172 Ok(tls_client_config.with_no_client_auth())
173 }
174 }
175}
176
/// Connection and lookup settings for LDAP authentication, parsed from an HBA entry's options.
///
/// `Debug` is implemented by hand so that `bind_passwd` is never written to logs.
#[derive(Clone)]
pub struct LdapConfig {
    /// Server URL in the form `scheme://host:port`.
    pub server: String,
    /// Base DN for search+bind mode. Without prefix/suffix, simple-bind mode uses it as the
    /// parent of the `uid=<user>` bind DN.
    pub base_dn: Option<String>,
    /// Search filter template; `$username` is replaced by the escaped user name.
    pub search_filter: Option<String>,
    /// Attribute matched against the user name when no filter is given.
    pub search_attribute: Option<String>,
    /// DN to bind as before searching.
    pub bind_dn: Option<String>,
    /// Password for `bind_dn`. Redacted in `Debug` output.
    pub bind_passwd: Option<String>,
    /// Prepended to the escaped user name to form the simple-bind DN.
    pub prefix: Option<String>,
    /// Appended to the escaped user name to form the simple-bind DN.
    pub suffix: Option<String>,
    /// Whether to upgrade an `ldap://` connection with STARTTLS.
    pub start_tls: bool,
}

impl std::fmt::Debug for LdapConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LdapConfig")
            .field("server", &self.server)
            .field("base_dn", &self.base_dn)
            .field("search_filter", &self.search_filter)
            .field("search_attribute", &self.search_attribute)
            .field("bind_dn", &self.bind_dn)
            // Show only whether a password is configured, never its value.
            .field(
                "bind_passwd",
                &self.bind_passwd.as_ref().map(|_| "<redacted>"),
            )
            .field("prefix", &self.prefix)
            .field("suffix", &self.suffix)
            .field("start_tls", &self.start_tls)
            .finish()
    }
}
199
200impl LdapConfig {
201 pub fn from_hba_options(options: &HashMap<String, String>) -> PsqlResult<Self> {
203 if let Some(ldap_url) = options.get(LDAP_URL_KEY) {
204 return Self::from_ldap_url(ldap_url, options);
205 }
206
207 let server = options
208 .get(LDAP_SERVER_KEY)
209 .ok_or_else(|| PsqlError::StartupError("LDAP server (ldapserver) is required".into()))?
210 .clone();
211
212 let scheme = options
213 .get(LDAP_SCHEME_KEY)
214 .map(|s| s.as_str())
215 .unwrap_or("ldap");
216 if scheme != "ldap" && scheme != "ldaps" {
217 return Err(PsqlError::StartupError(
218 "LDAP scheme (ldapscheme) must be either 'ldap' or 'ldaps'".into(),
219 ));
220 }
221
222 let port = options
223 .get(LDAP_PORT_KEY)
224 .and_then(|p| p.parse::<u16>().ok())
225 .unwrap_or_else(|| if scheme == "ldaps" { 636 } else { 389 });
226
227 let start_tls = options
228 .get(LDAP_TLS)
229 .and_then(|p| p.parse::<bool>().ok())
230 .unwrap_or(false);
231
232 if start_tls && scheme == "ldaps" {
234 return Err(PsqlError::StartupError(
235 "Cannot use STARTTLS (ldaptls) with ldaps scheme".into(),
236 ));
237 }
238
239 let server = format!("{}://{}:{}", scheme, server, port);
240 let base_dn = options.get(LDAP_BASE_DN_KEY).cloned();
241 let search_filter = options.get(LDAP_SEARCH_FILTER_KEY).cloned();
242 let search_attribute = options.get(LDAP_SEARCH_ATTRIBUTE_KEY).cloned();
243 let bind_dn = options.get(LDAP_BIND_DN_KEY).cloned();
244 let bind_passwd = options.get(LDAP_BIND_PASSWD_KEY).cloned();
245 let prefix = options.get(LDAP_PREFIX_KEY).cloned();
246 let suffix = options.get(LDAP_SUFFIX_KEY).cloned();
247
248 Ok(Self {
249 server,
250 base_dn,
251 search_filter,
252 search_attribute,
253 bind_dn,
254 bind_passwd,
255 prefix,
256 suffix,
257 start_tls,
258 })
259 }
260
261 fn from_ldap_url(ldap_url: &str, options: &HashMap<String, String>) -> PsqlResult<Self> {
264 let conflicting_params = [
267 LDAP_SERVER_KEY,
268 LDAP_PORT_KEY,
269 LDAP_SCHEME_KEY,
270 LDAP_BASE_DN_KEY,
271 LDAP_SEARCH_ATTRIBUTE_KEY,
272 LDAP_SEARCH_FILTER_KEY,
273 ];
274
275 for param in &conflicting_params {
276 if options.contains_key(*param) {
277 return Err(PsqlError::StartupError(
278 format!("Cannot specify both ldapurl and {} parameter", param).into(),
279 ));
280 }
281 }
282
283 let url = url::Url::parse(ldap_url).map_err(|e| {
285 PsqlError::StartupError(anyhow!(e).context("Failed to parse ldap url").into())
286 })?;
287
288 let scheme = url.scheme();
290 if scheme != "ldap" && scheme != "ldaps" {
291 return Err(PsqlError::StartupError(
292 "LDAP URL scheme must be either 'ldap' or 'ldaps'".into(),
293 ));
294 }
295
296 let host = url
298 .host_str()
299 .ok_or_else(|| PsqlError::StartupError("LDAP URL must contain a host".into()))?;
300 let port = url
301 .port()
302 .unwrap_or_else(|| if scheme == "ldaps" { 636 } else { 389 });
303
304 let server = format!("{}://{}:{}", scheme, host, port);
305
306 let base_dn = if url.path().len() > 1 {
308 Some(url.path()[1..].to_string())
309 } else {
310 None
311 };
312
313 let mut search_attribute = None;
316 let mut search_filter = None;
317
318 if let Some(query) = url.query() {
319 let parts: Vec<&str> = query.split('?').collect();
320
321 if !parts.is_empty() && !parts[0].is_empty() {
323 search_attribute = Some(parts[0].split(',').next().unwrap().to_owned());
324 }
325
326 if parts.len() > 2 && !parts[2].is_empty() {
328 search_filter = Some(parts[2].to_owned());
329 }
330 }
331
332 let start_tls = options
336 .get(LDAP_TLS)
337 .and_then(|p| p.parse::<bool>().ok())
338 .unwrap_or(false);
339
340 if start_tls && scheme == "ldaps" {
342 return Err(PsqlError::StartupError(
343 "Cannot use STARTTLS (ldaptls) with ldaps scheme".into(),
344 ));
345 }
346
347 let bind_dn = options.get(LDAP_BIND_DN_KEY).cloned();
348 let bind_passwd = options.get(LDAP_BIND_PASSWD_KEY).cloned();
349 let prefix = options.get(LDAP_PREFIX_KEY).cloned();
350 let suffix = options.get(LDAP_SUFFIX_KEY).cloned();
351
352 Ok(Self {
353 server,
354 base_dn,
355 search_filter,
356 search_attribute,
357 bind_dn,
358 bind_passwd,
359 prefix,
360 suffix,
361 start_tls,
362 })
363 }
364
365 fn certs_required(&self) -> bool {
366 self.server.starts_with("ldaps://") || self.start_tls
367 }
368}
369
/// Authenticates users against an LDAP server, configured from an LDAP HBA entry's options.
#[derive(Debug, Clone)]
pub struct LdapAuthenticator {
    /// Connection and lookup settings parsed from the HBA entry.
    config: LdapConfig,
}
376
377impl LdapAuthenticator {
378 pub fn new(hba_entry: &HbaEntry) -> PsqlResult<Self> {
380 if hba_entry.auth_method != AuthMethod::Ldap {
381 return Err(PsqlError::StartupError(
382 "HBA entry is not configured for LDAP authentication".into(),
383 ));
384 }
385
386 let config = LdapConfig::from_hba_options(&hba_entry.auth_options)?;
387 Ok(Self { config })
388 }
389
390 pub async fn authenticate(&self, username: &str, password: &str) -> PsqlResult<bool> {
392 if password.is_empty() {
394 return Ok(false);
395 }
396
397 let has_simple_bind_params = self.config.prefix.is_some() || self.config.suffix.is_some();
404
405 let has_search_only_params = self.config.search_filter.is_some()
407 || self.config.bind_dn.is_some()
408 || self.config.search_attribute.is_some();
409
410 if has_simple_bind_params && has_search_only_params {
412 return Err(PsqlError::StartupError(
413 "Cannot mix simple bind parameters (ldapprefix/ldapsuffix) with search+bind parameters (ldapsearchfilter/ldapbinddn/ldapsearchattribute)".into()
414 ));
415 }
416
417 if has_simple_bind_params {
420 self.simple_bind(username, password).await
422 } else if self.config.base_dn.is_some() {
423 self.search_and_bind(username, password).await
425 } else {
426 self.simple_bind(username, password).await
428 }
429 }
430
431 #[allow(rw::format_error)]
433 async fn establish_connection(&self) -> PsqlResult<ldap3::Ldap> {
434 let config = &self.config;
435 let mut settings = ldap3::LdapConnSettings::new();
436
437 settings = settings.set_starttls(config.start_tls);
439
440 if config.certs_required() {
441 let tls_config = LdapTlsConfig::from_env();
442 tracing::debug!("fetched tls config from env: {:?}", tls_config);
443
444 let client_config = tls_config.init_client_config()?;
445 settings = settings.set_config(Arc::new(client_config));
446
447 if matches!(tls_config.req_cert, ReqCertPolicy::Demand) {
448 settings = settings.set_no_tls_verify(false);
449 } else {
450 warn!(
451 "LDAP client certificate verification is disabled due to LDAPTLS_REQCERT policy"
452 );
453 settings = settings.set_no_tls_verify(true);
454 }
455 }
456
457 let (conn, ldap) = LdapConnAsync::with_settings(settings, &config.server)
458 .await
459 .map_err(|err| {
460 PsqlError::StartupError(
461 anyhow!(err)
462 .context("Failed to connect to LDAP server")
463 .into(),
464 )
465 })?;
466 ldap3::drive!(conn);
467
468 Ok(ldap)
469 }
470
471 async fn search_and_bind(&self, username: &str, password: &str) -> PsqlResult<bool> {
473 let mut ldap = self.establish_connection().await?;
475
476 let base_dn = self
478 .config
479 .base_dn
480 .as_ref()
481 .ok_or_else(|| PsqlError::StartupError("LDAP base_dn not configured".into()))?;
482
483 if let (Some(bind_dn), Some(bind_passwd)) = (&self.config.bind_dn, &self.config.bind_passwd)
485 {
486 ldap.simple_bind(bind_dn, bind_passwd)
487 .await
488 .map_err(|e| {
489 PsqlError::StartupError(
490 anyhow!(e).context("LDAP bind as search user failed").into(),
491 )
492 })?
493 .success()
494 .map_err(|e| {
495 PsqlError::StartupError(
496 anyhow!(e).context("LDAP bind as search user failed").into(),
497 )
498 })?;
499 }
500
501 let search_filter = if let Some(filter_template) = &self.config.search_filter {
503 let escaped_username = ldap_escape(username);
506 filter_template.replace("$username", &escaped_username)
507 } else {
508 let escaped_username = ldap_escape(username);
511 let attr = self.config.search_attribute.as_deref().unwrap_or("uid");
512 format!("({}={})", attr, escaped_username)
513 };
514
515 let rs = ldap
516 .search(base_dn, Scope::Subtree, &search_filter, vec!["dn"])
517 .await
518 .map_err(|e| {
519 PsqlError::StartupError(anyhow!(e).context("LDAP search failed").into())
520 })?;
521
522 let search_entries: Vec<SearchEntry> =
524 rs.0.into_iter().map(SearchEntry::construct).collect();
525 if search_entries.is_empty() {
526 return Ok(false);
527 }
528
529 let user_dn = &search_entries[0].dn;
531
532 let bind_result = ldap
533 .simple_bind(user_dn, password)
534 .await
535 .map_err(|e| PsqlError::StartupError(anyhow!(e).context("LDAP bind failed").into()));
536
537 let _ = ldap.unbind().await;
539
540 let bind_result = bind_result?;
541 match bind_result.success() {
542 Ok(_) => Ok(true),
543 Err(e) => {
544 tracing::error!(error = %e.as_report(), "LDAP bind unsuccessful");
545 Err(PsqlError::StartupError(
546 anyhow!(e).context("LDAP bind failed").into(),
547 ))
548 }
549 }
550 }
551
552 async fn simple_bind(&self, username: &str, password: &str) -> PsqlResult<bool> {
554 let escaped_username = dn_escape(username);
561
562 let dn = if self.config.prefix.is_some() || self.config.suffix.is_some() {
563 let prefix = self.config.prefix.as_deref().unwrap_or("");
565 let suffix = self.config.suffix.as_deref().unwrap_or("");
566 format!("{}{}{}", prefix, escaped_username, suffix)
567 } else if let Some(base_dn) = &self.config.base_dn {
568 format!("uid={},{}", escaped_username, base_dn)
572 } else {
573 escaped_username.to_string()
575 };
576
577 let mut ldap = self.establish_connection().await?;
579
580 tracing::info!(%self.config.server, %dn, "simple bind authentication with LDAP server");
581
582 let bind_result = ldap
583 .simple_bind(&dn, password)
584 .await
585 .map_err(|e| PsqlError::StartupError(anyhow!(e).context("LDAP bind failed").into()));
586
587 let _ = ldap.unbind().await;
589
590 let bind_result = bind_result?;
591 match bind_result.success() {
592 Ok(_) => Ok(true),
593 Err(e) => {
594 tracing::error!(error = %e.as_report(), "LDAP bind unsuccessful");
595 Err(PsqlError::StartupError(
596 format!("LDAP bind failed: {}", e.as_report()).into(),
597 ))
598 }
599 }
600 }
601}