pgwire/
ldap_auth.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fs;
17use std::sync::Arc;
18
19use anyhow::anyhow;
20use ldap3::{LdapConnAsync, Scope, SearchEntry, dn_escape, ldap_escape};
21use risingwave_common::config::{AuthMethod, HbaEntry};
22use thiserror_ext::AsReport;
23use tracing::warn;
24
25use crate::error::{PsqlError, PsqlResult};
26
// Option names looked up in an HBA entry's LDAP auth options.
/// Host of the LDAP server.
const LDAP_SERVER_KEY: &str = "ldapserver";
/// Port of the LDAP server.
const LDAP_PORT_KEY: &str = "ldapport";
/// Connection scheme; only `ldap` and `ldaps` are accepted.
const LDAP_SCHEME_KEY: &str = "ldapscheme";
/// Base DN used as the root of the user search.
const LDAP_BASE_DN_KEY: &str = "ldapbasedn";
/// Search filter template; `$username` is replaced by the login name.
const LDAP_SEARCH_FILTER_KEY: &str = "ldapsearchfilter";
/// Attribute matched against the login name when no filter template is given.
const LDAP_SEARCH_ATTRIBUTE_KEY: &str = "ldapsearchattribute";
/// DN to bind as before searching.
const LDAP_BIND_DN_KEY: &str = "ldapbinddn";
/// Password for the search bind DN.
const LDAP_BIND_PASSWD_KEY: &str = "ldapbindpasswd";
/// String prepended to the login name in simple bind mode.
const LDAP_PREFIX_KEY: &str = "ldapprefix";
/// String appended to the login name in simple bind mode.
const LDAP_SUFFIX_KEY: &str = "ldapsuffix";
/// Single LDAP URL that replaces the server/port/scheme/basedn/search options.
const LDAP_URL_KEY: &str = "ldapurl";

/// Option that toggles STARTTLS on the connection.
const LDAP_TLS: &str = "ldaptls";

// Environment variables that configure TLS for LDAP connections.
/// File with the trusted CA certificate(s).
const RW_LDAPTLS_CACERT: &str = "LDAPTLS_CACERT";
/// File with the client certificate.
const RW_LDAPTLS_CERT: &str = "LDAPTLS_CERT";
/// File with the client private key.
const RW_LDAPTLS_KEY: &str = "LDAPTLS_KEY";
/// Certificate checking policy: `never`, `allow`, `try` or `demand`.
const RW_LDAPTLS_REQCERT: &str = "LDAPTLS_REQCERT";
46
/// TLS certificate checking policy selected through the `LDAPTLS_REQCERT`
/// environment variable.
///
/// Each variant corresponds to the lowercase value of the same name. How a
/// policy is enforced is decided where the LDAP connection is established.
#[derive(Debug, Clone, Copy)]
enum ReqCertPolicy {
    /// Selected by `never`.
    Never,
    /// Selected by `allow`.
    Allow,
    /// Selected by `try`.
    Try,
    /// Selected by `demand`; also used when the variable is unset or unrecognized.
    Demand,
}
54
/// TLS settings for LDAP connections, read from the `LDAPTLS_*` environment
/// variables.
#[derive(Debug)]
pub struct LdapTlsConfig {
    /// `LDAPTLS_CACERT` environment variable: path of a PEM file with the CA
    /// certificates to trust. When absent, the platform's native roots are used.
    ca_cert: Option<String>,
    /// `LDAPTLS_CERT` environment variable: path of a PEM file with the client
    /// certificate chain. Requires `key` to be set as well.
    cert: Option<String>,
    /// `LDAPTLS_KEY` environment variable: path of a PEM file with the client
    /// private key. Only read when `cert` is set.
    key: Option<String>,
    /// `LDAPTLS_REQCERT` environment variable, parsed into a policy.
    req_cert: ReqCertPolicy,
}
66
67impl LdapTlsConfig {
68    /// Create LDAP TLS configuration from environment variables
69    fn from_env() -> Self {
70        let ca_cert = std::env::var(RW_LDAPTLS_CACERT).ok();
71        let cert = std::env::var(RW_LDAPTLS_CERT).ok();
72        let key = std::env::var(RW_LDAPTLS_KEY).ok();
73        let req_cert = match std::env::var(RW_LDAPTLS_REQCERT).as_deref() {
74            Ok("never") => ReqCertPolicy::Never,
75            Ok("allow") => ReqCertPolicy::Allow,
76            Ok("try") => ReqCertPolicy::Try,
77            Ok("demand") => ReqCertPolicy::Demand,
78            _ => ReqCertPolicy::Demand, // Default to demand
79        };
80
81        Self {
82            ca_cert,
83            cert,
84            key,
85            req_cert,
86        }
87    }
88
89    /// Initialize rustls ClientConfig based on TLS configuration
90    fn init_client_config(&self) -> PsqlResult<rustls::ClientConfig> {
91        let tls_client_config = rustls::ClientConfig::builder();
92
93        let mut root_cert_store = rustls::RootCertStore::empty();
94        if let Some(tls_config) = &self.ca_cert {
95            let ca_cert_bytes = fs::read(tls_config).map_err(|e| {
96                PsqlError::StartupError(anyhow!(e).context("Failed to read CA certificate").into())
97            })?;
98            for cert in rustls_pemfile::certs(&mut ca_cert_bytes.as_slice()) {
99                let cert = cert.map_err(|e| {
100                    PsqlError::StartupError(
101                        anyhow!(e).context("Failed to parse CA certificate").into(),
102                    )
103                })?;
104                root_cert_store.add(cert).map_err(|err| {
105                    PsqlError::StartupError(
106                        anyhow!(err).context("Failed to add CA certificate").into(),
107                    )
108                })?;
109            }
110        } else {
111            // If ca certs is not present, load system native certs.
112            for cert in
113                rustls_native_certs::load_native_certs().expect("could not load platform certs")
114            {
115                root_cert_store.add(cert).map_err(|err| {
116                    PsqlError::StartupError(
117                        anyhow!(err)
118                            .context("Failed to add native CA certificate")
119                            .into(),
120                    )
121                })?;
122            }
123        }
124        let tls_client_config = tls_client_config.with_root_certificates(root_cert_store);
125
126        if let Some(cert) = &self.cert {
127            let Some(key) = &self.key else {
128                return Err(PsqlError::StartupError(
129                    "Client certificate provided without private key".into(),
130                ));
131            };
132            let client_cert_bytes = fs::read(cert).map_err(|e| {
133                PsqlError::StartupError(
134                    anyhow!(e)
135                        .context("Failed to read client certificate")
136                        .into(),
137                )
138            })?;
139            let client_key_bytes = fs::read(key).map_err(|e| {
140                PsqlError::StartupError(anyhow!(e).context("Failed to read client key").into())
141            })?;
142            let client_certs = rustls_pemfile::certs(&mut client_cert_bytes.as_slice())
143                .try_collect()
144                .map_err(|e| {
145                    PsqlError::StartupError(
146                        anyhow!(e)
147                            .context("Failed to parse client certificate")
148                            .into(),
149                    )
150                })?;
151
152            let private_key = rustls_pemfile::private_key(&mut client_key_bytes.as_slice())
153                .map_err(|e| {
154                    PsqlError::StartupError(anyhow!(e).context("Failed to parse client key").into())
155                })?;
156            let Some(client_private_key) = private_key else {
157                return Err(PsqlError::StartupError(
158                    "No private key found in client key file".into(),
159                ));
160            };
161
162            tls_client_config
163                .with_client_auth_cert(client_certs, client_private_key)
164                .map_err(|err| {
165                    PsqlError::StartupError(
166                        anyhow!(err)
167                            .context("Failed to set client certificate")
168                            .into(),
169                    )
170                })
171        } else {
172            Ok(tls_client_config.with_no_client_auth())
173        }
174    }
175}
176
/// LDAP configuration extracted from HBA entry
#[derive(Debug, Clone)]
pub struct LdapConfig {
    /// LDAP server URL in the form `scheme://host:port`
    pub server: String,
    /// Base DN: the search root in search+bind mode; in simple bind without
    /// prefix/suffix it is used to build `uid=<username>,<base_dn>`
    pub base_dn: Option<String>,
    /// LDAP search filter template; `$username` is replaced by the escaped login name
    pub search_filter: Option<String>,
    /// LDAP search attribute (used in search+bind mode, defaults to "uid" if not specified)
    pub search_attribute: Option<String>,
    /// DN to bind as when performing searches
    pub bind_dn: Option<String>,
    /// Password for bind DN (the search bind only happens when both are set)
    pub bind_passwd: Option<String>,
    /// Prefix to prepend to username in simple bind
    pub prefix: Option<String>,
    /// Suffix to append to username in simple bind
    pub suffix: Option<String>,
    /// Whether to use STARTTLS
    pub start_tls: bool,
}
199
200impl LdapConfig {
201    /// Create LDAP configuration from HBA entry options
202    pub fn from_hba_options(options: &HashMap<String, String>) -> PsqlResult<Self> {
203        if let Some(ldap_url) = options.get(LDAP_URL_KEY) {
204            return Self::from_ldap_url(ldap_url, options);
205        }
206
207        let server = options
208            .get(LDAP_SERVER_KEY)
209            .ok_or_else(|| PsqlError::StartupError("LDAP server (ldapserver) is required".into()))?
210            .clone();
211
212        let scheme = options
213            .get(LDAP_SCHEME_KEY)
214            .map(|s| s.as_str())
215            .unwrap_or("ldap");
216        if scheme != "ldap" && scheme != "ldaps" {
217            return Err(PsqlError::StartupError(
218                "LDAP scheme (ldapscheme) must be either 'ldap' or 'ldaps'".into(),
219            ));
220        }
221
222        let port = options
223            .get(LDAP_PORT_KEY)
224            .and_then(|p| p.parse::<u16>().ok())
225            .unwrap_or_else(|| if scheme == "ldaps" { 636 } else { 389 });
226
227        let start_tls = options
228            .get(LDAP_TLS)
229            .and_then(|p| p.parse::<bool>().ok())
230            .unwrap_or(false);
231
232        // Validate that StartTLS and ldaps are not used together
233        if start_tls && scheme == "ldaps" {
234            return Err(PsqlError::StartupError(
235                "Cannot use STARTTLS (ldaptls) with ldaps scheme".into(),
236            ));
237        }
238
239        let server = format!("{}://{}:{}", scheme, server, port);
240        let base_dn = options.get(LDAP_BASE_DN_KEY).cloned();
241        let search_filter = options.get(LDAP_SEARCH_FILTER_KEY).cloned();
242        let search_attribute = options.get(LDAP_SEARCH_ATTRIBUTE_KEY).cloned();
243        let bind_dn = options.get(LDAP_BIND_DN_KEY).cloned();
244        let bind_passwd = options.get(LDAP_BIND_PASSWD_KEY).cloned();
245        let prefix = options.get(LDAP_PREFIX_KEY).cloned();
246        let suffix = options.get(LDAP_SUFFIX_KEY).cloned();
247
248        Ok(Self {
249            server,
250            base_dn,
251            search_filter,
252            search_attribute,
253            bind_dn,
254            bind_passwd,
255            prefix,
256            suffix,
257            start_tls,
258        })
259    }
260
261    /// Parse LDAP URL (RFC 4516 format)
262    /// Format: ldap\[s\]://host:port/basedn?attributes?scope?filter
263    fn from_ldap_url(ldap_url: &str, options: &HashMap<String, String>) -> PsqlResult<Self> {
264        // Validate that conflicting parameters are not present
265        // According to PostgreSQL docs, ldapurl cannot be mixed with parameters that would conflict
266        let conflicting_params = [
267            LDAP_SERVER_KEY,
268            LDAP_PORT_KEY,
269            LDAP_SCHEME_KEY,
270            LDAP_BASE_DN_KEY,
271            LDAP_SEARCH_ATTRIBUTE_KEY,
272            LDAP_SEARCH_FILTER_KEY,
273        ];
274
275        for param in &conflicting_params {
276            if options.contains_key(*param) {
277                return Err(PsqlError::StartupError(
278                    format!("Cannot specify both ldapurl and {} parameter", param).into(),
279                ));
280            }
281        }
282
283        // Parse the URL using standard URL parsing
284        let url = url::Url::parse(ldap_url).map_err(|e| {
285            PsqlError::StartupError(anyhow!(e).context("Failed to parse ldap url").into())
286        })?;
287
288        // Validate scheme
289        let scheme = url.scheme();
290        if scheme != "ldap" && scheme != "ldaps" {
291            return Err(PsqlError::StartupError(
292                "LDAP URL scheme must be either 'ldap' or 'ldaps'".into(),
293            ));
294        }
295
296        // Extract host and port
297        let host = url
298            .host_str()
299            .ok_or_else(|| PsqlError::StartupError("LDAP URL must contain a host".into()))?;
300        let port = url
301            .port()
302            .unwrap_or_else(|| if scheme == "ldaps" { 636 } else { 389 });
303
304        let server = format!("{}://{}:{}", scheme, host, port);
305
306        // Extract basedn from path (remove leading /)
307        let base_dn = if url.path().len() > 1 {
308            Some(url.path()[1..].to_string())
309        } else {
310            None
311        };
312
313        // Parse query parameters for attributes, scope, filter
314        // Format: ?attributes?scope?filter
315        let mut search_attribute = None;
316        let mut search_filter = None;
317
318        if let Some(query) = url.query() {
319            let parts: Vec<&str> = query.split('?').collect();
320
321            // First part is attributes (comma-separated, we only care about the first one for search)
322            if !parts.is_empty() && !parts[0].is_empty() {
323                search_attribute = Some(parts[0].split(',').next().unwrap().to_owned());
324            }
325
326            // Third part is filter (index 2)
327            if parts.len() > 2 && !parts[2].is_empty() {
328                search_filter = Some(parts[2].to_owned());
329            }
330        }
331
332        // Only allow supplementary parameters with ldapurl:
333        // - ldaptls: for StartTLS
334        // - ldapbinddn/ldapbindpasswd: for authenticated searches
335        let start_tls = options
336            .get(LDAP_TLS)
337            .and_then(|p| p.parse::<bool>().ok())
338            .unwrap_or(false);
339
340        // Validate that StartTLS and ldaps are not used together
341        if start_tls && scheme == "ldaps" {
342            return Err(PsqlError::StartupError(
343                "Cannot use STARTTLS (ldaptls) with ldaps scheme".into(),
344            ));
345        }
346
347        let bind_dn = options.get(LDAP_BIND_DN_KEY).cloned();
348        let bind_passwd = options.get(LDAP_BIND_PASSWD_KEY).cloned();
349        let prefix = options.get(LDAP_PREFIX_KEY).cloned();
350        let suffix = options.get(LDAP_SUFFIX_KEY).cloned();
351
352        Ok(Self {
353            server,
354            base_dn,
355            search_filter,
356            search_attribute,
357            bind_dn,
358            bind_passwd,
359            prefix,
360            suffix,
361            start_tls,
362        })
363    }
364
365    fn certs_required(&self) -> bool {
366        self.server.starts_with("ldaps://") || self.start_tls
367    }
368}
369
/// LDAP authenticator that validates user credentials against an LDAP server
///
/// It holds only configuration; a new LDAP connection is opened for every
/// authentication attempt.
#[derive(Debug, Clone)]
pub struct LdapAuthenticator {
    /// LDAP server configuration from HBA entry
    config: LdapConfig,
}
376
377impl LdapAuthenticator {
378    /// Create a new LDAP authenticator from HBA entry options
379    pub fn new(hba_entry: &HbaEntry) -> PsqlResult<Self> {
380        if hba_entry.auth_method != AuthMethod::Ldap {
381            return Err(PsqlError::StartupError(
382                "HBA entry is not configured for LDAP authentication".into(),
383            ));
384        }
385
386        let config = LdapConfig::from_hba_options(&hba_entry.auth_options)?;
387        Ok(Self { config })
388    }
389
390    /// Authenticate a user
391    pub async fn authenticate(&self, username: &str, password: &str) -> PsqlResult<bool> {
392        // Skip authentication if password is empty
393        if password.is_empty() {
394            return Ok(false);
395        }
396
397        // Determine the authentication strategy based on configured parameters
398        // According to PostgreSQL documentation:
399        // - Simple bind mode: Triggered by ldapprefix and/or ldapsuffix
400        // - Search+bind mode: Triggered by ldapbasedn (when no prefix/suffix present)
401        // - It's an error to mix simple bind params with search+bind-only params
402
403        let has_simple_bind_params = self.config.prefix.is_some() || self.config.suffix.is_some();
404
405        // Search+bind-only parameters that shouldn't be mixed with simple bind
406        let has_search_only_params = self.config.search_filter.is_some()
407            || self.config.bind_dn.is_some()
408            || self.config.search_attribute.is_some();
409
410        // Validate that we don't mix simple bind params with search+bind-only params
411        if has_simple_bind_params && has_search_only_params {
412            return Err(PsqlError::StartupError(
413                "Cannot mix simple bind parameters (ldapprefix/ldapsuffix) with search+bind parameters (ldapsearchfilter/ldapbinddn/ldapsearchattribute)".into()
414            ));
415        }
416
417        // Decision logic based on PostgreSQL behavior:
418        // The mode is determined by which parameters are present:
419        if has_simple_bind_params {
420            // Simple bind mode: prefix/suffix present
421            self.simple_bind(username, password).await
422        } else if self.config.base_dn.is_some() {
423            // Search+bind mode: basedn present without prefix/suffix
424            self.search_and_bind(username, password).await
425        } else {
426            // Fallback: no basedn, no prefix/suffix - use username directly as DN
427            self.simple_bind(username, password).await
428        }
429    }
430
431    /// Establish an LDAP connection with configurable options
432    #[allow(rw::format_error)]
433    async fn establish_connection(&self) -> PsqlResult<ldap3::Ldap> {
434        let config = &self.config;
435        let mut settings = ldap3::LdapConnSettings::new();
436
437        // Configure STARTTLS if specified
438        settings = settings.set_starttls(config.start_tls);
439
440        if config.certs_required() {
441            let tls_config = LdapTlsConfig::from_env();
442            tracing::debug!("fetched tls config from env: {:?}", tls_config);
443
444            let client_config = tls_config.init_client_config()?;
445            settings = settings.set_config(Arc::new(client_config));
446
447            if matches!(tls_config.req_cert, ReqCertPolicy::Demand) {
448                settings = settings.set_no_tls_verify(false);
449            } else {
450                warn!(
451                    "LDAP client certificate verification is disabled due to LDAPTLS_REQCERT policy"
452                );
453                settings = settings.set_no_tls_verify(true);
454            }
455        }
456
457        let (conn, ldap) = LdapConnAsync::with_settings(settings, &config.server)
458            .await
459            .map_err(|err| {
460                PsqlError::StartupError(
461                    anyhow!(err)
462                        .context("Failed to connect to LDAP server")
463                        .into(),
464                )
465            })?;
466        ldap3::drive!(conn);
467
468        Ok(ldap)
469    }
470
471    /// Search for user in LDAP directory and then bind
472    async fn search_and_bind(&self, username: &str, password: &str) -> PsqlResult<bool> {
473        // Establish connection to LDAP server
474        let mut ldap = self.establish_connection().await?;
475
476        // Validate base_dn configuration
477        let base_dn = self
478            .config
479            .base_dn
480            .as_ref()
481            .ok_or_else(|| PsqlError::StartupError("LDAP base_dn not configured".into()))?;
482
483        // If bind_dn and bind_passwd are provided, bind as that user first
484        if let (Some(bind_dn), Some(bind_passwd)) = (&self.config.bind_dn, &self.config.bind_passwd)
485        {
486            ldap.simple_bind(bind_dn, bind_passwd)
487                .await
488                .map_err(|e| {
489                    PsqlError::StartupError(
490                        anyhow!(e).context("LDAP bind as search user failed").into(),
491                    )
492                })?
493                .success()
494                .map_err(|e| {
495                    PsqlError::StartupError(
496                        anyhow!(e).context("LDAP bind as search user failed").into(),
497                    )
498                })?;
499        }
500
501        // Build search filter
502        let search_filter = if let Some(filter_template) = &self.config.search_filter {
503            // Use custom filter template with $username placeholder
504            // SECURITY: Escape username to prevent LDAP filter injection
505            let escaped_username = ldap_escape(username);
506            filter_template.replace("$username", &escaped_username)
507        } else {
508            // Default filter using search_attribute (defaults to "uid" if not configured)
509            // SECURITY: Escape username to prevent LDAP filter injection
510            let escaped_username = ldap_escape(username);
511            let attr = self.config.search_attribute.as_deref().unwrap_or("uid");
512            format!("({}={})", attr, escaped_username)
513        };
514
515        let rs = ldap
516            .search(base_dn, Scope::Subtree, &search_filter, vec!["dn"])
517            .await
518            .map_err(|e| {
519                PsqlError::StartupError(anyhow!(e).context("LDAP search failed").into())
520            })?;
521
522        // If no user found, authentication fails
523        let search_entries: Vec<SearchEntry> =
524            rs.0.into_iter().map(SearchEntry::construct).collect();
525        if search_entries.is_empty() {
526            return Ok(false);
527        }
528
529        // Attempt to bind with the user's DN and password
530        let user_dn = &search_entries[0].dn;
531
532        let bind_result = ldap
533            .simple_bind(user_dn, password)
534            .await
535            .map_err(|e| PsqlError::StartupError(anyhow!(e).context("LDAP bind failed").into()));
536
537        // Explicitly unbind the connection
538        let _ = ldap.unbind().await;
539
540        let bind_result = bind_result?;
541        match bind_result.success() {
542            Ok(_) => Ok(true),
543            Err(e) => {
544                tracing::error!(error = %e.as_report(), "LDAP bind unsuccessful");
545                Err(PsqlError::StartupError(
546                    anyhow!(e).context("LDAP bind failed").into(),
547                ))
548            }
549        }
550    }
551
552    /// Simple bind authentication
553    async fn simple_bind(&self, username: &str, password: &str) -> PsqlResult<bool> {
554        // Construct DN from username according to PostgreSQL simple bind rules:
555        // 1. If prefix/suffix are configured, use them: prefix + username + suffix
556        // 2. If only basedn is configured (legacy/fallback), use: uid=username,basedn
557        // 3. Otherwise, use username directly as DN
558        //
559        // SECURITY: Escape username to prevent LDAP DN injection
560        let escaped_username = dn_escape(username);
561
562        let dn = if self.config.prefix.is_some() || self.config.suffix.is_some() {
563            // Use prefix/suffix to construct DN
564            let prefix = self.config.prefix.as_deref().unwrap_or("");
565            let suffix = self.config.suffix.as_deref().unwrap_or("");
566            format!("{}{}{}", prefix, escaped_username, suffix)
567        } else if let Some(base_dn) = &self.config.base_dn {
568            // Fallback: construct DN as uid=username,basedn
569            // Note: If basedn is present without prefix/suffix, this should normally
570            // trigger search+bind mode, but we support this for backwards compatibility
571            format!("uid={},{}", escaped_username, base_dn)
572        } else {
573            // Use username as-is as the DN (still escaped)
574            escaped_username.to_string()
575        };
576
577        // Attempt to bind
578        let mut ldap = self.establish_connection().await?;
579
580        tracing::info!(%self.config.server, %dn, "simple bind authentication with LDAP server");
581
582        let bind_result = ldap
583            .simple_bind(&dn, password)
584            .await
585            .map_err(|e| PsqlError::StartupError(anyhow!(e).context("LDAP bind failed").into()));
586
587        // Explicitly unbind the connection
588        let _ = ldap.unbind().await;
589
590        let bind_result = bind_result?;
591        match bind_result.success() {
592            Ok(_) => Ok(true),
593            Err(e) => {
594                tracing::error!(error = %e.as_report(), "LDAP bind unsuccessful");
595                Err(PsqlError::StartupError(
596                    format!("LDAP bind failed: {}", e.as_report()).into(),
597                ))
598            }
599        }
600    }
601}