pgwire/
ldap_auth.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fs;
17use std::sync::Arc;
18
19use anyhow::anyhow;
20use ldap3::{LdapConnAsync, Scope, SearchEntry, dn_escape, ldap_escape};
21use risingwave_common::config::{AuthMethod, HbaEntry};
22use rustls_pki_types::pem::PemObject;
23use rustls_pki_types::{CertificateDer, PrivateKeyDer};
24use thiserror_ext::AsReport;
25use tracing::warn;
26
27use crate::error::{PsqlError, PsqlResult};
28
// Option keys accepted in an LDAP HBA entry; the names follow PostgreSQL's pg_hba.conf options.
/// Host name or address of the LDAP server (`ldapserver`).
const LDAP_SERVER_KEY: &str = "ldapserver";
/// Port of the LDAP server (`ldapport`).
const LDAP_PORT_KEY: &str = "ldapport";
/// URL scheme, either `ldap` or `ldaps` (`ldapscheme`).
const LDAP_SCHEME_KEY: &str = "ldapscheme";
/// DN under which the user search is performed in search+bind mode (`ldapbasedn`).
const LDAP_BASE_DN_KEY: &str = "ldapbasedn";
/// Search filter template containing a `$username` placeholder (`ldapsearchfilter`).
const LDAP_SEARCH_FILTER_KEY: &str = "ldapsearchfilter";
/// Attribute matched against the user name when no filter is configured (`ldapsearchattribute`).
const LDAP_SEARCH_ATTRIBUTE_KEY: &str = "ldapsearchattribute";
/// DN to bind as before performing the user search (`ldapbinddn`).
const LDAP_BIND_DN_KEY: &str = "ldapbinddn";
/// Password for the `ldapbinddn` account (`ldapbindpasswd`).
const LDAP_BIND_PASSWD_KEY: &str = "ldapbindpasswd";
/// String prepended to the user name to form the DN in simple bind mode (`ldapprefix`).
const LDAP_PREFIX_KEY: &str = "ldapprefix";
/// String appended to the user name to form the DN in simple bind mode (`ldapsuffix`).
const LDAP_SUFFIX_KEY: &str = "ldapsuffix";
/// LDAP URL that replaces the server, port, scheme, base DN and search options (`ldapurl`).
const LDAP_URL_KEY: &str = "ldapurl";

/// Whether to upgrade the connection with STARTTLS (`ldaptls`).
const LDAP_TLS: &str = "ldaptls";

// Environment variables configuring LDAP TLS (OpenLDAP-style `LDAPTLS_*` names).
/// Path to a PEM file of trusted CA certificates.
const RW_LDAPTLS_CACERT: &str = "LDAPTLS_CACERT";
/// Path to a PEM file holding the client certificate chain.
const RW_LDAPTLS_CERT: &str = "LDAPTLS_CERT";
/// Path to a PEM file holding the client private key.
const RW_LDAPTLS_KEY: &str = "LDAPTLS_KEY";
/// Server certificate checking policy: `never`, `allow`, `try` or `demand`.
const RW_LDAPTLS_REQCERT: &str = "LDAPTLS_REQCERT";
48
/// Server certificate checking policy, parsed from the `LDAPTLS_REQCERT` environment variable.
///
/// Each variant corresponds to the `LDAPTLS_REQCERT` value of the same (lower-case) name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReqCertPolicy {
    /// `never`
    Never,
    /// `allow`
    Allow,
    /// `try`
    Try,
    /// `demand`; also used when the variable is unset or holds an unrecognized value.
    Demand,
}
56
57#[derive(Debug)]
58pub struct LdapTlsConfig {
59    /// `LDAPTLS_CACERT` environment variable
60    ca_cert: Option<String>,
61    /// `LDAPTLS_CERT` environment variable
62    cert: Option<String>,
63    /// `LDAPTLS_KEY` environment variable
64    key: Option<String>,
65    /// `LDAPTLS_REQCERT` environment variable
66    req_cert: ReqCertPolicy,
67}
68
69impl LdapTlsConfig {
70    /// Create LDAP TLS configuration from environment variables
71    fn from_env() -> Self {
72        let ca_cert = std::env::var(RW_LDAPTLS_CACERT).ok();
73        let cert = std::env::var(RW_LDAPTLS_CERT).ok();
74        let key = std::env::var(RW_LDAPTLS_KEY).ok();
75        let req_cert = match std::env::var(RW_LDAPTLS_REQCERT).as_deref() {
76            Ok("never") => ReqCertPolicy::Never,
77            Ok("allow") => ReqCertPolicy::Allow,
78            Ok("try") => ReqCertPolicy::Try,
79            Ok("demand") => ReqCertPolicy::Demand,
80            _ => ReqCertPolicy::Demand, // Default to demand
81        };
82
83        Self {
84            ca_cert,
85            cert,
86            key,
87            req_cert,
88        }
89    }
90
91    /// Initialize rustls ClientConfig based on TLS configuration
92    fn init_client_config(&self) -> PsqlResult<rustls::ClientConfig> {
93        let tls_client_config = rustls::ClientConfig::builder();
94
95        let mut root_cert_store = rustls::RootCertStore::empty();
96        if let Some(tls_config) = &self.ca_cert {
97            let ca_cert_bytes = fs::read(tls_config).map_err(|e| {
98                PsqlError::StartupError(anyhow!(e).context("Failed to read CA certificate").into())
99            })?;
100            for cert in CertificateDer::pem_slice_iter(&ca_cert_bytes) {
101                let cert = cert.map_err(|e| {
102                    PsqlError::StartupError(
103                        anyhow!(e).context("Failed to parse CA certificate").into(),
104                    )
105                })?;
106                root_cert_store.add(cert).map_err(|err| {
107                    PsqlError::StartupError(
108                        anyhow!(err).context("Failed to add CA certificate").into(),
109                    )
110                })?;
111            }
112        } else {
113            // If ca certs is not present, load system native certs.
114            for cert in
115                rustls_native_certs::load_native_certs().expect("could not load platform certs")
116            {
117                root_cert_store.add(cert).map_err(|err| {
118                    PsqlError::StartupError(
119                        anyhow!(err)
120                            .context("Failed to add native CA certificate")
121                            .into(),
122                    )
123                })?;
124            }
125        }
126        let tls_client_config = tls_client_config.with_root_certificates(root_cert_store);
127
128        if let Some(cert) = &self.cert {
129            let Some(key) = &self.key else {
130                return Err(PsqlError::StartupError(
131                    "Client certificate provided without private key".into(),
132                ));
133            };
134            let client_cert_bytes = fs::read(cert).map_err(|e| {
135                PsqlError::StartupError(
136                    anyhow!(e)
137                        .context("Failed to read client certificate")
138                        .into(),
139                )
140            })?;
141            let client_key_bytes = fs::read(key).map_err(|e| {
142                PsqlError::StartupError(anyhow!(e).context("Failed to read client key").into())
143            })?;
144            let client_certs = CertificateDer::pem_slice_iter(&client_cert_bytes)
145                .collect::<Result<Vec<_>, _>>()
146                .map_err(|e| {
147                    PsqlError::StartupError(
148                        anyhow!(e)
149                            .context("Failed to parse client certificate")
150                            .into(),
151                    )
152                })?;
153
154            let client_private_key =
155                PrivateKeyDer::from_pem_slice(&client_key_bytes).map_err(|e| {
156                    PsqlError::StartupError(anyhow!(e).context("Failed to parse client key").into())
157                })?;
158
159            tls_client_config
160                .with_client_auth_cert(client_certs, client_private_key)
161                .map_err(|err| {
162                    PsqlError::StartupError(
163                        anyhow!(err)
164                            .context("Failed to set client certificate")
165                            .into(),
166                    )
167                })
168        } else {
169            Ok(tls_client_config.with_no_client_auth())
170        }
171    }
172}
173
/// LDAP configuration extracted from an HBA entry.
///
/// Normally built by [`LdapConfig::from_hba_options`], either from the individual `ldap*`
/// options or from an `ldapurl`.
#[derive(Clone)]
pub struct LdapConfig {
    /// Server URL in `scheme://host:port` form, where the scheme is `ldap` or `ldaps`.
    pub server: String,
    /// Search base for search+bind mode. In simple bind mode without prefix/suffix, it is used
    /// to form the DN `uid=<username>,<base_dn>`.
    pub base_dn: Option<String>,
    /// Search filter template; `$username` is replaced by the escaped user name.
    pub search_filter: Option<String>,
    /// Attribute matched against the user name in search+bind mode when no filter is
    /// configured (defaults to "uid" if not specified).
    pub search_attribute: Option<String>,
    /// DN to bind as before searching for the user (used only together with `bind_passwd`).
    pub bind_dn: Option<String>,
    /// Password for `bind_dn`; redacted in the `Debug` output.
    pub bind_passwd: Option<String>,
    /// Prefix prepended to the escaped user name to form the DN in simple bind mode.
    pub prefix: Option<String>,
    /// Suffix appended to the escaped user name to form the DN in simple bind mode.
    pub suffix: Option<String>,
    /// Whether to upgrade a plain `ldap://` connection with STARTTLS.
    pub start_tls: bool,
}

impl std::fmt::Debug for LdapConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LdapConfig")
            .field("server", &self.server)
            .field("base_dn", &self.base_dn)
            .field("search_filter", &self.search_filter)
            .field("search_attribute", &self.search_attribute)
            .field("bind_dn", &self.bind_dn)
            // Never print the service account password, e.g. when the config is logged.
            .field(
                "bind_passwd",
                &self.bind_passwd.as_ref().map(|_| "<redacted>"),
            )
            .field("prefix", &self.prefix)
            .field("suffix", &self.suffix)
            .field("start_tls", &self.start_tls)
            .finish()
    }
}
196
197impl LdapConfig {
198    /// Create LDAP configuration from HBA entry options
199    pub fn from_hba_options(options: &HashMap<String, String>) -> PsqlResult<Self> {
200        if let Some(ldap_url) = options.get(LDAP_URL_KEY) {
201            return Self::from_ldap_url(ldap_url, options);
202        }
203
204        let server = options
205            .get(LDAP_SERVER_KEY)
206            .ok_or_else(|| PsqlError::StartupError("LDAP server (ldapserver) is required".into()))?
207            .clone();
208
209        let scheme = options
210            .get(LDAP_SCHEME_KEY)
211            .map(|s| s.as_str())
212            .unwrap_or("ldap");
213        if scheme != "ldap" && scheme != "ldaps" {
214            return Err(PsqlError::StartupError(
215                "LDAP scheme (ldapscheme) must be either 'ldap' or 'ldaps'".into(),
216            ));
217        }
218
219        let port = options
220            .get(LDAP_PORT_KEY)
221            .and_then(|p| p.parse::<u16>().ok())
222            .unwrap_or_else(|| if scheme == "ldaps" { 636 } else { 389 });
223
224        let start_tls = options
225            .get(LDAP_TLS)
226            .and_then(|p| p.parse::<bool>().ok())
227            .unwrap_or(false);
228
229        // Validate that StartTLS and ldaps are not used together
230        if start_tls && scheme == "ldaps" {
231            return Err(PsqlError::StartupError(
232                "Cannot use STARTTLS (ldaptls) with ldaps scheme".into(),
233            ));
234        }
235
236        let server = format!("{}://{}:{}", scheme, server, port);
237        let base_dn = options.get(LDAP_BASE_DN_KEY).cloned();
238        let search_filter = options.get(LDAP_SEARCH_FILTER_KEY).cloned();
239        let search_attribute = options.get(LDAP_SEARCH_ATTRIBUTE_KEY).cloned();
240        let bind_dn = options.get(LDAP_BIND_DN_KEY).cloned();
241        let bind_passwd = options.get(LDAP_BIND_PASSWD_KEY).cloned();
242        let prefix = options.get(LDAP_PREFIX_KEY).cloned();
243        let suffix = options.get(LDAP_SUFFIX_KEY).cloned();
244
245        Ok(Self {
246            server,
247            base_dn,
248            search_filter,
249            search_attribute,
250            bind_dn,
251            bind_passwd,
252            prefix,
253            suffix,
254            start_tls,
255        })
256    }
257
258    /// Parse LDAP URL (RFC 4516 format)
259    /// Format: ldap\[s\]://host:port/basedn?attributes?scope?filter
260    fn from_ldap_url(ldap_url: &str, options: &HashMap<String, String>) -> PsqlResult<Self> {
261        // Validate that conflicting parameters are not present
262        // According to PostgreSQL docs, ldapurl cannot be mixed with parameters that would conflict
263        let conflicting_params = [
264            LDAP_SERVER_KEY,
265            LDAP_PORT_KEY,
266            LDAP_SCHEME_KEY,
267            LDAP_BASE_DN_KEY,
268            LDAP_SEARCH_ATTRIBUTE_KEY,
269            LDAP_SEARCH_FILTER_KEY,
270        ];
271
272        for param in &conflicting_params {
273            if options.contains_key(*param) {
274                return Err(PsqlError::StartupError(
275                    format!("Cannot specify both ldapurl and {} parameter", param).into(),
276                ));
277            }
278        }
279
280        // Parse the URL using standard URL parsing
281        let url = url::Url::parse(ldap_url).map_err(|e| {
282            PsqlError::StartupError(anyhow!(e).context("Failed to parse ldap url").into())
283        })?;
284
285        // Validate scheme
286        let scheme = url.scheme();
287        if scheme != "ldap" && scheme != "ldaps" {
288            return Err(PsqlError::StartupError(
289                "LDAP URL scheme must be either 'ldap' or 'ldaps'".into(),
290            ));
291        }
292
293        // Extract host and port
294        let host = url
295            .host_str()
296            .ok_or_else(|| PsqlError::StartupError("LDAP URL must contain a host".into()))?;
297        let port = url
298            .port()
299            .unwrap_or_else(|| if scheme == "ldaps" { 636 } else { 389 });
300
301        let server = format!("{}://{}:{}", scheme, host, port);
302
303        // Extract basedn from path (remove leading /)
304        let base_dn = if url.path().len() > 1 {
305            Some(url.path()[1..].to_string())
306        } else {
307            None
308        };
309
310        // Parse query parameters for attributes, scope, filter
311        // Format: ?attributes?scope?filter
312        let mut search_attribute = None;
313        let mut search_filter = None;
314
315        if let Some(query) = url.query() {
316            let parts: Vec<&str> = query.split('?').collect();
317
318            // First part is attributes (comma-separated, we only care about the first one for search)
319            if !parts.is_empty() && !parts[0].is_empty() {
320                search_attribute = Some(parts[0].split(',').next().unwrap().to_owned());
321            }
322
323            // Third part is filter (index 2)
324            if parts.len() > 2 && !parts[2].is_empty() {
325                search_filter = Some(parts[2].to_owned());
326            }
327        }
328
329        // Only allow supplementary parameters with ldapurl:
330        // - ldaptls: for StartTLS
331        // - ldapbinddn/ldapbindpasswd: for authenticated searches
332        let start_tls = options
333            .get(LDAP_TLS)
334            .and_then(|p| p.parse::<bool>().ok())
335            .unwrap_or(false);
336
337        // Validate that StartTLS and ldaps are not used together
338        if start_tls && scheme == "ldaps" {
339            return Err(PsqlError::StartupError(
340                "Cannot use STARTTLS (ldaptls) with ldaps scheme".into(),
341            ));
342        }
343
344        let bind_dn = options.get(LDAP_BIND_DN_KEY).cloned();
345        let bind_passwd = options.get(LDAP_BIND_PASSWD_KEY).cloned();
346        let prefix = options.get(LDAP_PREFIX_KEY).cloned();
347        let suffix = options.get(LDAP_SUFFIX_KEY).cloned();
348
349        Ok(Self {
350            server,
351            base_dn,
352            search_filter,
353            search_attribute,
354            bind_dn,
355            bind_passwd,
356            prefix,
357            suffix,
358            start_tls,
359        })
360    }
361
362    fn certs_required(&self) -> bool {
363        self.server.starts_with("ldaps://") || self.start_tls
364    }
365}
366
367/// LDAP authenticator that validates user credentials against an LDAP server
368#[derive(Debug, Clone)]
369pub struct LdapAuthenticator {
370    /// LDAP server configuration from HBA entry
371    config: LdapConfig,
372}
373
374impl LdapAuthenticator {
375    /// Create a new LDAP authenticator from HBA entry options
376    pub fn new(hba_entry: &HbaEntry) -> PsqlResult<Self> {
377        if hba_entry.auth_method != AuthMethod::Ldap {
378            return Err(PsqlError::StartupError(
379                "HBA entry is not configured for LDAP authentication".into(),
380            ));
381        }
382
383        let config = LdapConfig::from_hba_options(&hba_entry.auth_options)?;
384        Ok(Self { config })
385    }
386
387    /// Authenticate a user
388    pub async fn authenticate(&self, username: &str, password: &str) -> PsqlResult<bool> {
389        // Skip authentication if password is empty
390        if password.is_empty() {
391            return Ok(false);
392        }
393
394        // Determine the authentication strategy based on configured parameters
395        // According to PostgreSQL documentation:
396        // - Simple bind mode: Triggered by ldapprefix and/or ldapsuffix
397        // - Search+bind mode: Triggered by ldapbasedn (when no prefix/suffix present)
398        // - It's an error to mix simple bind params with search+bind-only params
399
400        let has_simple_bind_params = self.config.prefix.is_some() || self.config.suffix.is_some();
401
402        // Search+bind-only parameters that shouldn't be mixed with simple bind
403        let has_search_only_params = self.config.search_filter.is_some()
404            || self.config.bind_dn.is_some()
405            || self.config.search_attribute.is_some();
406
407        // Validate that we don't mix simple bind params with search+bind-only params
408        if has_simple_bind_params && has_search_only_params {
409            return Err(PsqlError::StartupError(
410                "Cannot mix simple bind parameters (ldapprefix/ldapsuffix) with search+bind parameters (ldapsearchfilter/ldapbinddn/ldapsearchattribute)".into()
411            ));
412        }
413
414        // Decision logic based on PostgreSQL behavior:
415        // The mode is determined by which parameters are present:
416        if has_simple_bind_params {
417            // Simple bind mode: prefix/suffix present
418            self.simple_bind(username, password).await
419        } else if self.config.base_dn.is_some() {
420            // Search+bind mode: basedn present without prefix/suffix
421            self.search_and_bind(username, password).await
422        } else {
423            // Fallback: no basedn, no prefix/suffix - use username directly as DN
424            self.simple_bind(username, password).await
425        }
426    }
427
428    /// Establish an LDAP connection with configurable options
429    #[allow(rw::format_error)]
430    async fn establish_connection(&self) -> PsqlResult<ldap3::Ldap> {
431        let config = &self.config;
432        let mut settings = ldap3::LdapConnSettings::new();
433
434        // Configure STARTTLS if specified
435        settings = settings.set_starttls(config.start_tls);
436
437        if config.certs_required() {
438            let tls_config = LdapTlsConfig::from_env();
439            tracing::debug!("fetched tls config from env: {:?}", tls_config);
440
441            let client_config = tls_config.init_client_config()?;
442            settings = settings.set_config(Arc::new(client_config));
443
444            if matches!(tls_config.req_cert, ReqCertPolicy::Demand) {
445                settings = settings.set_no_tls_verify(false);
446            } else {
447                warn!(
448                    "LDAP client certificate verification is disabled due to LDAPTLS_REQCERT policy"
449                );
450                settings = settings.set_no_tls_verify(true);
451            }
452        }
453
454        let (conn, ldap) = LdapConnAsync::with_settings(settings, &config.server)
455            .await
456            .map_err(|err| {
457                PsqlError::StartupError(
458                    anyhow!(err)
459                        .context("Failed to connect to LDAP server")
460                        .into(),
461                )
462            })?;
463        ldap3::drive!(conn);
464
465        Ok(ldap)
466    }
467
468    /// Search for user in LDAP directory and then bind
469    async fn search_and_bind(&self, username: &str, password: &str) -> PsqlResult<bool> {
470        // Establish connection to LDAP server
471        let mut ldap = self.establish_connection().await?;
472
473        // Validate base_dn configuration
474        let base_dn = self
475            .config
476            .base_dn
477            .as_ref()
478            .ok_or_else(|| PsqlError::StartupError("LDAP base_dn not configured".into()))?;
479
480        // If bind_dn and bind_passwd are provided, bind as that user first
481        if let (Some(bind_dn), Some(bind_passwd)) = (&self.config.bind_dn, &self.config.bind_passwd)
482        {
483            ldap.simple_bind(bind_dn, bind_passwd)
484                .await
485                .map_err(|e| {
486                    PsqlError::StartupError(
487                        anyhow!(e).context("LDAP bind as search user failed").into(),
488                    )
489                })?
490                .success()
491                .map_err(|e| {
492                    PsqlError::StartupError(
493                        anyhow!(e).context("LDAP bind as search user failed").into(),
494                    )
495                })?;
496        }
497
498        // Build search filter
499        let search_filter = if let Some(filter_template) = &self.config.search_filter {
500            // Use custom filter template with $username placeholder
501            // SECURITY: Escape username to prevent LDAP filter injection
502            let escaped_username = ldap_escape(username);
503            filter_template.replace("$username", &escaped_username)
504        } else {
505            // Default filter using search_attribute (defaults to "uid" if not configured)
506            // SECURITY: Escape username to prevent LDAP filter injection
507            let escaped_username = ldap_escape(username);
508            let attr = self.config.search_attribute.as_deref().unwrap_or("uid");
509            format!("({}={})", attr, escaped_username)
510        };
511
512        let rs = ldap
513            .search(base_dn, Scope::Subtree, &search_filter, vec!["dn"])
514            .await
515            .map_err(|e| {
516                PsqlError::StartupError(anyhow!(e).context("LDAP search failed").into())
517            })?;
518
519        // If no user found, authentication fails
520        let search_entries: Vec<SearchEntry> =
521            rs.0.into_iter().map(SearchEntry::construct).collect();
522        if search_entries.is_empty() {
523            return Ok(false);
524        }
525
526        // Attempt to bind with the user's DN and password
527        let user_dn = &search_entries[0].dn;
528
529        let bind_result = ldap
530            .simple_bind(user_dn, password)
531            .await
532            .map_err(|e| PsqlError::StartupError(anyhow!(e).context("LDAP bind failed").into()));
533
534        // Explicitly unbind the connection
535        let _ = ldap.unbind().await;
536
537        let bind_result = bind_result?;
538        match bind_result.success() {
539            Ok(_) => Ok(true),
540            Err(e) => {
541                tracing::error!(error = %e.as_report(), "LDAP bind unsuccessful");
542                Err(PsqlError::StartupError(
543                    anyhow!(e).context("LDAP bind failed").into(),
544                ))
545            }
546        }
547    }
548
549    /// Simple bind authentication
550    async fn simple_bind(&self, username: &str, password: &str) -> PsqlResult<bool> {
551        // Construct DN from username according to PostgreSQL simple bind rules:
552        // 1. If prefix/suffix are configured, use them: prefix + username + suffix
553        // 2. If only basedn is configured (legacy/fallback), use: uid=username,basedn
554        // 3. Otherwise, use username directly as DN
555        //
556        // SECURITY: Escape username to prevent LDAP DN injection
557        let escaped_username = dn_escape(username);
558
559        let dn = if self.config.prefix.is_some() || self.config.suffix.is_some() {
560            // Use prefix/suffix to construct DN
561            let prefix = self.config.prefix.as_deref().unwrap_or("");
562            let suffix = self.config.suffix.as_deref().unwrap_or("");
563            format!("{}{}{}", prefix, escaped_username, suffix)
564        } else if let Some(base_dn) = &self.config.base_dn {
565            // Fallback: construct DN as uid=username,basedn
566            // Note: If basedn is present without prefix/suffix, this should normally
567            // trigger search+bind mode, but we support this for backwards compatibility
568            format!("uid={},{}", escaped_username, base_dn)
569        } else {
570            // Use username as-is as the DN (still escaped)
571            escaped_username.to_string()
572        };
573
574        // Attempt to bind
575        let mut ldap = self.establish_connection().await?;
576
577        tracing::info!(%self.config.server, %dn, "simple bind authentication with LDAP server");
578
579        let bind_result = ldap
580            .simple_bind(&dn, password)
581            .await
582            .map_err(|e| PsqlError::StartupError(anyhow!(e).context("LDAP bind failed").into()));
583
584        // Explicitly unbind the connection
585        let _ = ldap.unbind().await;
586
587        let bind_result = bind_result?;
588        match bind_result.success() {
589            Ok(_) => Ok(true),
590            Err(e) => {
591                tracing::error!(error = %e.as_report(), "LDAP bind unsuccessful");
592                Err(PsqlError::StartupError(
593                    format!("LDAP bind failed: {}", e.as_report()).into(),
594                ))
595            }
596        }
597    }
598}