pgwire/
ldap_auth.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fs;
17use std::sync::Arc;
18
19use anyhow::anyhow;
20use ldap3::{LdapConnAsync, Scope, SearchEntry, dn_escape, ldap_escape};
21use risingwave_common::config::{AuthMethod, HbaEntry};
22use thiserror_ext::AsReport;
23use tracing::warn;
24
25use crate::error::{PsqlError, PsqlResult};
26
// HBA auth option keys for LDAP authentication. The names match PostgreSQL's pg_hba.conf
// LDAP options.

/// Host name of the LDAP server; combined with the scheme and port into `LdapConfig::server`.
const LDAP_SERVER_KEY: &str = "ldapserver";
/// Port of the LDAP server; defaults to 389 for `ldap` and 636 for `ldaps`.
const LDAP_PORT_KEY: &str = "ldapport";
/// Connection scheme, either `ldap` (the default) or `ldaps`.
const LDAP_SCHEME_KEY: &str = "ldapscheme";
/// Base DN searched for the user in search+bind mode.
const LDAP_BASE_DN_KEY: &str = "ldapbasedn";
/// Search filter template; `$username` is replaced with the escaped user name.
const LDAP_SEARCH_FILTER_KEY: &str = "ldapsearchfilter";
/// Attribute matched against the user name when no filter is configured (default `uid`).
const LDAP_SEARCH_ATTRIBUTE_KEY: &str = "ldapsearchattribute";
/// DN to bind as before searching for the user.
const LDAP_BIND_DN_KEY: &str = "ldapbinddn";
/// Password for the `ldapbinddn` DN.
const LDAP_BIND_PASSWD_KEY: &str = "ldapbindpasswd";
/// String prepended to the user name to form the bind DN in simple bind mode.
const LDAP_PREFIX_KEY: &str = "ldapprefix";
/// String appended to the user name to form the bind DN in simple bind mode.
const LDAP_SUFFIX_KEY: &str = "ldapsuffix";
/// RFC 4516 LDAP URL; replaces the server, port, scheme, base DN, attribute and filter options.
const LDAP_URL_KEY: &str = "ldapurl";

/// Option enabling STARTTLS on an `ldap://` connection; rejected together with `ldaps`.
const LDAP_TLS: &str = "ldaptls";

// Environment variables configuring TLS for LDAP connections, named after OpenLDAP's
// `LDAPTLS_*` client variables.

/// Path to a PEM file of trusted CA certificates; the platform store is used when unset.
const RW_LDAPTLS_CACERT: &str = "LDAPTLS_CACERT";
/// Path to the PEM client certificate chain used for TLS client authentication.
const RW_LDAPTLS_CERT: &str = "LDAPTLS_CERT";
/// Path to the PEM private key matching the client certificate.
const RW_LDAPTLS_KEY: &str = "LDAPTLS_KEY";
/// Server certificate checking level: `never`, `allow`, `try` or `demand`.
const RW_LDAPTLS_REQCERT: &str = "LDAPTLS_REQCERT";
46
/// Server certificate checking policy read from the `LDAPTLS_REQCERT` environment variable.
///
/// The levels are named after OpenLDAP's `TLS_REQCERT` settings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReqCertPolicy {
    /// `never`
    Never,
    /// `allow`
    Allow,
    /// `try`
    Try,
    /// `demand`; also used when the variable is unset or has an unrecognized value.
    Demand,
}
54
/// TLS settings for LDAP connections, read from OpenLDAP-style `LDAPTLS_*` environment
/// variables by `LdapTlsConfig::from_env`.
#[derive(Debug)]
pub struct LdapTlsConfig {
    /// `LDAPTLS_CACERT`: path to a PEM file of trusted CA certificates. When unset, the
    /// platform's native certificate store is used instead.
    ca_cert: Option<String>,
    /// `LDAPTLS_CERT`: path to the PEM client certificate chain for TLS client authentication.
    cert: Option<String>,
    /// `LDAPTLS_KEY`: path to the PEM private key (PKCS#8 or PKCS#1/RSA) matching `cert`.
    key: Option<String>,
    /// `LDAPTLS_REQCERT`: server certificate checking policy; `Demand` when unset or unrecognized.
    req_cert: ReqCertPolicy,
}
66
67impl LdapTlsConfig {
68    /// Create LDAP TLS configuration from environment variables
69    fn from_env() -> Self {
70        let ca_cert = std::env::var(RW_LDAPTLS_CACERT).ok();
71        let cert = std::env::var(RW_LDAPTLS_CERT).ok();
72        let key = std::env::var(RW_LDAPTLS_KEY).ok();
73        let req_cert = match std::env::var(RW_LDAPTLS_REQCERT).as_deref() {
74            Ok("never") => ReqCertPolicy::Never,
75            Ok("allow") => ReqCertPolicy::Allow,
76            Ok("try") => ReqCertPolicy::Try,
77            Ok("demand") => ReqCertPolicy::Demand,
78            _ => ReqCertPolicy::Demand, // Default to demand
79        };
80
81        Self {
82            ca_cert,
83            cert,
84            key,
85            req_cert,
86        }
87    }
88
89    /// Initialize rustls ClientConfig based on TLS configuration
90    fn init_client_config(&self) -> PsqlResult<rustls::ClientConfig> {
91        let tls_client_config = rustls::ClientConfig::builder().with_safe_defaults();
92
93        let mut root_cert_store = rustls::RootCertStore::empty();
94        if let Some(tls_config) = &self.ca_cert {
95            let ca_cert_bytes = fs::read(tls_config).map_err(|e| {
96                PsqlError::StartupError(anyhow!(e).context("Failed to read CA certificate").into())
97            })?;
98            let ca_certs = rustls_pemfile::certs(&mut ca_cert_bytes.as_slice()).map_err(|e| {
99                PsqlError::StartupError(anyhow!(e).context("Failed to parse CA certificate").into())
100            })?;
101            for cert in ca_certs {
102                root_cert_store
103                    .add(&rustls::Certificate(cert))
104                    .map_err(|err| {
105                        PsqlError::StartupError(
106                            anyhow!(err).context("Failed to add CA certificate").into(),
107                        )
108                    })?;
109            }
110        } else {
111            // If ca certs is not present, load system native certs.
112            for cert in
113                rustls_native_certs::load_native_certs().expect("could not load platform certs")
114            {
115                root_cert_store
116                    .add(&rustls::Certificate(cert.0))
117                    .map_err(|err| {
118                        PsqlError::StartupError(
119                            anyhow!(err)
120                                .context("Failed to add native CA certificate")
121                                .into(),
122                        )
123                    })?;
124            }
125        }
126        let tls_client_config = tls_client_config.with_root_certificates(root_cert_store);
127
128        if let Some(cert) = &self.cert {
129            let Some(key) = &self.key else {
130                return Err(PsqlError::StartupError(
131                    "Client certificate provided without private key".into(),
132                ));
133            };
134            let client_cert_bytes = fs::read(cert).map_err(|e| {
135                PsqlError::StartupError(
136                    anyhow!(e)
137                        .context("Failed to read client certificate")
138                        .into(),
139                )
140            })?;
141            let client_key_bytes = fs::read(key).map_err(|e| {
142                PsqlError::StartupError(anyhow!(e).context("Failed to read client key").into())
143            })?;
144            let client_certs =
145                rustls_pemfile::certs(&mut client_cert_bytes.as_slice()).map_err(|e| {
146                    PsqlError::StartupError(
147                        anyhow!(e)
148                            .context("Failed to parse client certificate")
149                            .into(),
150                    )
151                })?;
152
153            let mut reader = std::io::Cursor::new(&client_key_bytes);
154            let mut private_keys =
155                rustls_pemfile::pkcs8_private_keys(&mut reader).map_err(|e| {
156                    PsqlError::StartupError(anyhow!(e).context("Failed to parse client key").into())
157                })?;
158            if private_keys.is_empty() {
159                // Try RSA private keys as a fallback
160                reader.set_position(0);
161                private_keys = rustls_pemfile::rsa_private_keys(&mut reader).map_err(|e| {
162                    PsqlError::StartupError(anyhow!(e).context("Failed to parse client key").into())
163                })?;
164            }
165            let client_private_key = private_keys.pop().ok_or_else(|| {
166                PsqlError::StartupError("No private key found in client key file".into())
167            })?;
168            let client_certs_rustls: Vec<rustls::Certificate> =
169                client_certs.into_iter().map(rustls::Certificate).collect();
170
171            tls_client_config
172                .with_client_auth_cert(client_certs_rustls, rustls::PrivateKey(client_private_key))
173                .map_err(|err| {
174                    PsqlError::StartupError(
175                        anyhow!(err)
176                            .context("Failed to set client certificate")
177                            .into(),
178                    )
179                })
180        } else {
181            Ok(tls_client_config.with_no_client_auth())
182        }
183    }
184}
185
/// LDAP configuration extracted from an HBA entry's auth options.
///
/// `Debug` is implemented by hand so that `bind_passwd` is never printed.
#[derive(Clone)]
pub struct LdapConfig {
    /// Server URL in `scheme://host:port` form, e.g. `ldap://ldap.example.com:389`.
    pub server: String,
    /// Base DN searched for the user in search+bind mode.
    pub base_dn: Option<String>,
    /// Search filter template; `$username` is replaced with the escaped user name.
    pub search_filter: Option<String>,
    /// LDAP search attribute (used in search+bind mode, defaults to "uid" if not specified)
    pub search_attribute: Option<String>,
    /// DN to bind as before searching for the user.
    pub bind_dn: Option<String>,
    /// Password for `bind_dn`; redacted in `Debug` output.
    pub bind_passwd: Option<String>,
    /// Prefix prepended to the user name to form the DN in simple bind mode.
    pub prefix: Option<String>,
    /// Suffix appended to the user name to form the DN in simple bind mode.
    pub suffix: Option<String>,
    /// Whether to upgrade a plain `ldap://` connection with STARTTLS.
    pub start_tls: bool,
}

impl std::fmt::Debug for LdapConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LdapConfig")
            .field("server", &self.server)
            .field("base_dn", &self.base_dn)
            .field("search_filter", &self.search_filter)
            .field("search_attribute", &self.search_attribute)
            .field("bind_dn", &self.bind_dn)
            // Show only whether a password is configured, never its value.
            .field(
                "bind_passwd",
                &self.bind_passwd.as_ref().map(|_| "<redacted>"),
            )
            .field("prefix", &self.prefix)
            .field("suffix", &self.suffix)
            .field("start_tls", &self.start_tls)
            .finish()
    }
}
208
209impl LdapConfig {
210    /// Create LDAP configuration from HBA entry options
211    pub fn from_hba_options(options: &HashMap<String, String>) -> PsqlResult<Self> {
212        if let Some(ldap_url) = options.get(LDAP_URL_KEY) {
213            return Self::from_ldap_url(ldap_url, options);
214        }
215
216        let server = options
217            .get(LDAP_SERVER_KEY)
218            .ok_or_else(|| PsqlError::StartupError("LDAP server (ldapserver) is required".into()))?
219            .clone();
220
221        let scheme = options
222            .get(LDAP_SCHEME_KEY)
223            .map(|s| s.as_str())
224            .unwrap_or("ldap");
225        if scheme != "ldap" && scheme != "ldaps" {
226            return Err(PsqlError::StartupError(
227                "LDAP scheme (ldapscheme) must be either 'ldap' or 'ldaps'".into(),
228            ));
229        }
230
231        let port = options
232            .get(LDAP_PORT_KEY)
233            .and_then(|p| p.parse::<u16>().ok())
234            .unwrap_or_else(|| if scheme == "ldaps" { 636 } else { 389 });
235
236        let start_tls = options
237            .get(LDAP_TLS)
238            .and_then(|p| p.parse::<bool>().ok())
239            .unwrap_or(false);
240
241        // Validate that StartTLS and ldaps are not used together
242        if start_tls && scheme == "ldaps" {
243            return Err(PsqlError::StartupError(
244                "Cannot use STARTTLS (ldaptls) with ldaps scheme".into(),
245            ));
246        }
247
248        let server = format!("{}://{}:{}", scheme, server, port);
249        let base_dn = options.get(LDAP_BASE_DN_KEY).cloned();
250        let search_filter = options.get(LDAP_SEARCH_FILTER_KEY).cloned();
251        let search_attribute = options.get(LDAP_SEARCH_ATTRIBUTE_KEY).cloned();
252        let bind_dn = options.get(LDAP_BIND_DN_KEY).cloned();
253        let bind_passwd = options.get(LDAP_BIND_PASSWD_KEY).cloned();
254        let prefix = options.get(LDAP_PREFIX_KEY).cloned();
255        let suffix = options.get(LDAP_SUFFIX_KEY).cloned();
256
257        Ok(Self {
258            server,
259            base_dn,
260            search_filter,
261            search_attribute,
262            bind_dn,
263            bind_passwd,
264            prefix,
265            suffix,
266            start_tls,
267        })
268    }
269
270    /// Parse LDAP URL (RFC 4516 format)
271    /// Format: ldap\[s\]://host:port/basedn?attributes?scope?filter
272    fn from_ldap_url(ldap_url: &str, options: &HashMap<String, String>) -> PsqlResult<Self> {
273        // Validate that conflicting parameters are not present
274        // According to PostgreSQL docs, ldapurl cannot be mixed with parameters that would conflict
275        let conflicting_params = [
276            LDAP_SERVER_KEY,
277            LDAP_PORT_KEY,
278            LDAP_SCHEME_KEY,
279            LDAP_BASE_DN_KEY,
280            LDAP_SEARCH_ATTRIBUTE_KEY,
281            LDAP_SEARCH_FILTER_KEY,
282        ];
283
284        for param in &conflicting_params {
285            if options.contains_key(*param) {
286                return Err(PsqlError::StartupError(
287                    format!("Cannot specify both ldapurl and {} parameter", param).into(),
288                ));
289            }
290        }
291
292        // Parse the URL using standard URL parsing
293        let url = url::Url::parse(ldap_url).map_err(|e| {
294            PsqlError::StartupError(anyhow!(e).context("Failed to parse ldap url").into())
295        })?;
296
297        // Validate scheme
298        let scheme = url.scheme();
299        if scheme != "ldap" && scheme != "ldaps" {
300            return Err(PsqlError::StartupError(
301                "LDAP URL scheme must be either 'ldap' or 'ldaps'".into(),
302            ));
303        }
304
305        // Extract host and port
306        let host = url
307            .host_str()
308            .ok_or_else(|| PsqlError::StartupError("LDAP URL must contain a host".into()))?;
309        let port = url
310            .port()
311            .unwrap_or_else(|| if scheme == "ldaps" { 636 } else { 389 });
312
313        let server = format!("{}://{}:{}", scheme, host, port);
314
315        // Extract basedn from path (remove leading /)
316        let base_dn = if url.path().len() > 1 {
317            Some(url.path()[1..].to_string())
318        } else {
319            None
320        };
321
322        // Parse query parameters for attributes, scope, filter
323        // Format: ?attributes?scope?filter
324        let mut search_attribute = None;
325        let mut search_filter = None;
326
327        if let Some(query) = url.query() {
328            let parts: Vec<&str> = query.split('?').collect();
329
330            // First part is attributes (comma-separated, we only care about the first one for search)
331            if !parts.is_empty() && !parts[0].is_empty() {
332                search_attribute = Some(parts[0].split(',').next().unwrap().to_owned());
333            }
334
335            // Third part is filter (index 2)
336            if parts.len() > 2 && !parts[2].is_empty() {
337                search_filter = Some(parts[2].to_owned());
338            }
339        }
340
341        // Only allow supplementary parameters with ldapurl:
342        // - ldaptls: for StartTLS
343        // - ldapbinddn/ldapbindpasswd: for authenticated searches
344        let start_tls = options
345            .get(LDAP_TLS)
346            .and_then(|p| p.parse::<bool>().ok())
347            .unwrap_or(false);
348
349        // Validate that StartTLS and ldaps are not used together
350        if start_tls && scheme == "ldaps" {
351            return Err(PsqlError::StartupError(
352                "Cannot use STARTTLS (ldaptls) with ldaps scheme".into(),
353            ));
354        }
355
356        let bind_dn = options.get(LDAP_BIND_DN_KEY).cloned();
357        let bind_passwd = options.get(LDAP_BIND_PASSWD_KEY).cloned();
358        let prefix = options.get(LDAP_PREFIX_KEY).cloned();
359        let suffix = options.get(LDAP_SUFFIX_KEY).cloned();
360
361        Ok(Self {
362            server,
363            base_dn,
364            search_filter,
365            search_attribute,
366            bind_dn,
367            bind_passwd,
368            prefix,
369            suffix,
370            start_tls,
371        })
372    }
373
374    fn certs_required(&self) -> bool {
375        self.server.starts_with("ldaps://") || self.start_tls
376    }
377}
378
/// LDAP authenticator that validates user credentials against an LDAP server.
///
/// Each call to `LdapAuthenticator::authenticate` picks simple bind or search+bind from the
/// options present in `config`, and opens a fresh connection to the server.
#[derive(Debug, Clone)]
pub struct LdapAuthenticator {
    /// Server, TLS and bind/search settings parsed from the HBA entry's auth options.
    config: LdapConfig,
}
385
386impl LdapAuthenticator {
387    /// Create a new LDAP authenticator from HBA entry options
388    pub fn new(hba_entry: &HbaEntry) -> PsqlResult<Self> {
389        if hba_entry.auth_method != AuthMethod::Ldap {
390            return Err(PsqlError::StartupError(
391                "HBA entry is not configured for LDAP authentication".into(),
392            ));
393        }
394
395        let config = LdapConfig::from_hba_options(&hba_entry.auth_options)?;
396        Ok(Self { config })
397    }
398
399    /// Authenticate a user
400    pub async fn authenticate(&self, username: &str, password: &str) -> PsqlResult<bool> {
401        // Skip authentication if password is empty
402        if password.is_empty() {
403            return Ok(false);
404        }
405
406        // Determine the authentication strategy based on configured parameters
407        // According to PostgreSQL documentation:
408        // - Simple bind mode: Triggered by ldapprefix and/or ldapsuffix
409        // - Search+bind mode: Triggered by ldapbasedn (when no prefix/suffix present)
410        // - It's an error to mix simple bind params with search+bind-only params
411
412        let has_simple_bind_params = self.config.prefix.is_some() || self.config.suffix.is_some();
413
414        // Search+bind-only parameters that shouldn't be mixed with simple bind
415        let has_search_only_params = self.config.search_filter.is_some()
416            || self.config.bind_dn.is_some()
417            || self.config.search_attribute.is_some();
418
419        // Validate that we don't mix simple bind params with search+bind-only params
420        if has_simple_bind_params && has_search_only_params {
421            return Err(PsqlError::StartupError(
422                "Cannot mix simple bind parameters (ldapprefix/ldapsuffix) with search+bind parameters (ldapsearchfilter/ldapbinddn/ldapsearchattribute)".into()
423            ));
424        }
425
426        // Decision logic based on PostgreSQL behavior:
427        // The mode is determined by which parameters are present:
428        if has_simple_bind_params {
429            // Simple bind mode: prefix/suffix present
430            self.simple_bind(username, password).await
431        } else if self.config.base_dn.is_some() {
432            // Search+bind mode: basedn present without prefix/suffix
433            self.search_and_bind(username, password).await
434        } else {
435            // Fallback: no basedn, no prefix/suffix - use username directly as DN
436            self.simple_bind(username, password).await
437        }
438    }
439
440    /// Establish an LDAP connection with configurable options
441    #[allow(rw::format_error)]
442    async fn establish_connection(&self) -> PsqlResult<ldap3::Ldap> {
443        let config = &self.config;
444        let mut settings = ldap3::LdapConnSettings::new();
445
446        // Configure STARTTLS if specified
447        settings = settings.set_starttls(config.start_tls);
448
449        if config.certs_required() {
450            let tls_config = LdapTlsConfig::from_env();
451            tracing::debug!("fetched tls config from env: {:?}", tls_config);
452
453            let client_config = tls_config.init_client_config()?;
454            settings = settings.set_config(Arc::new(client_config));
455
456            if matches!(tls_config.req_cert, ReqCertPolicy::Demand) {
457                settings = settings.set_no_tls_verify(false);
458            } else {
459                warn!(
460                    "LDAP client certificate verification is disabled due to LDAPTLS_REQCERT policy"
461                );
462                settings = settings.set_no_tls_verify(true);
463            }
464        }
465
466        let (conn, ldap) = LdapConnAsync::with_settings(settings, &config.server)
467            .await
468            .map_err(|err| {
469                PsqlError::StartupError(
470                    anyhow!(err)
471                        .context("Failed to connect to LDAP server")
472                        .into(),
473                )
474            })?;
475        ldap3::drive!(conn);
476
477        Ok(ldap)
478    }
479
480    /// Search for user in LDAP directory and then bind
481    async fn search_and_bind(&self, username: &str, password: &str) -> PsqlResult<bool> {
482        // Establish connection to LDAP server
483        let mut ldap = self.establish_connection().await?;
484
485        // Validate base_dn configuration
486        let base_dn = self
487            .config
488            .base_dn
489            .as_ref()
490            .ok_or_else(|| PsqlError::StartupError("LDAP base_dn not configured".into()))?;
491
492        // If bind_dn and bind_passwd are provided, bind as that user first
493        if let (Some(bind_dn), Some(bind_passwd)) = (&self.config.bind_dn, &self.config.bind_passwd)
494        {
495            ldap.simple_bind(bind_dn, bind_passwd)
496                .await
497                .map_err(|e| {
498                    PsqlError::StartupError(
499                        anyhow!(e).context("LDAP bind as search user failed").into(),
500                    )
501                })?
502                .success()
503                .map_err(|e| {
504                    PsqlError::StartupError(
505                        anyhow!(e).context("LDAP bind as search user failed").into(),
506                    )
507                })?;
508        }
509
510        // Build search filter
511        let search_filter = if let Some(filter_template) = &self.config.search_filter {
512            // Use custom filter template with $username placeholder
513            // SECURITY: Escape username to prevent LDAP filter injection
514            let escaped_username = ldap_escape(username);
515            filter_template.replace("$username", &escaped_username)
516        } else {
517            // Default filter using search_attribute (defaults to "uid" if not configured)
518            // SECURITY: Escape username to prevent LDAP filter injection
519            let escaped_username = ldap_escape(username);
520            let attr = self.config.search_attribute.as_deref().unwrap_or("uid");
521            format!("({}={})", attr, escaped_username)
522        };
523
524        let rs = ldap
525            .search(base_dn, Scope::Subtree, &search_filter, vec!["dn"])
526            .await
527            .map_err(|e| {
528                PsqlError::StartupError(anyhow!(e).context("LDAP search failed").into())
529            })?;
530
531        // If no user found, authentication fails
532        let search_entries: Vec<SearchEntry> =
533            rs.0.into_iter().map(SearchEntry::construct).collect();
534        if search_entries.is_empty() {
535            return Ok(false);
536        }
537
538        // Attempt to bind with the user's DN and password
539        let user_dn = &search_entries[0].dn;
540
541        let bind_result = ldap
542            .simple_bind(user_dn, password)
543            .await
544            .map_err(|e| PsqlError::StartupError(anyhow!(e).context("LDAP bind failed").into()));
545
546        // Explicitly unbind the connection
547        let _ = ldap.unbind().await;
548
549        let bind_result = bind_result?;
550        match bind_result.success() {
551            Ok(_) => Ok(true),
552            Err(e) => {
553                tracing::error!(error = %e.as_report(), "LDAP bind unsuccessful");
554                Err(PsqlError::StartupError(
555                    anyhow!(e).context("LDAP bind failed").into(),
556                ))
557            }
558        }
559    }
560
561    /// Simple bind authentication
562    async fn simple_bind(&self, username: &str, password: &str) -> PsqlResult<bool> {
563        // Construct DN from username according to PostgreSQL simple bind rules:
564        // 1. If prefix/suffix are configured, use them: prefix + username + suffix
565        // 2. If only basedn is configured (legacy/fallback), use: uid=username,basedn
566        // 3. Otherwise, use username directly as DN
567        //
568        // SECURITY: Escape username to prevent LDAP DN injection
569        let escaped_username = dn_escape(username);
570
571        let dn = if self.config.prefix.is_some() || self.config.suffix.is_some() {
572            // Use prefix/suffix to construct DN
573            let prefix = self.config.prefix.as_deref().unwrap_or("");
574            let suffix = self.config.suffix.as_deref().unwrap_or("");
575            format!("{}{}{}", prefix, escaped_username, suffix)
576        } else if let Some(base_dn) = &self.config.base_dn {
577            // Fallback: construct DN as uid=username,basedn
578            // Note: If basedn is present without prefix/suffix, this should normally
579            // trigger search+bind mode, but we support this for backwards compatibility
580            format!("uid={},{}", escaped_username, base_dn)
581        } else {
582            // Use username as-is as the DN (still escaped)
583            escaped_username.to_string()
584        };
585
586        // Attempt to bind
587        let mut ldap = self.establish_connection().await?;
588
589        tracing::info!(%self.config.server, %dn, "simple bind authentication with LDAP server");
590
591        let bind_result = ldap
592            .simple_bind(&dn, password)
593            .await
594            .map_err(|e| PsqlError::StartupError(anyhow!(e).context("LDAP bind failed").into()));
595
596        // Explicitly unbind the connection
597        let _ = ldap.unbind().await;
598
599        let bind_result = bind_result?;
600        match bind_result.success() {
601            Ok(_) => Ok(true),
602            Err(e) => {
603                tracing::error!(error = %e.as_report(), "LDAP bind unsuccessful");
604                Err(PsqlError::StartupError(
605                    format!("LDAP bind failed: {}", e.as_report()).into(),
606                ))
607            }
608        }
609    }
610}