risingwave_common/config/hba.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::cmp::PartialEq;
16use std::collections::HashMap;
17use std::net::IpAddr;
18use std::str::FromStr;
19
20use serde::{Deserialize, Deserializer, Serialize, Serializer};
21
/// Wildcard keyword: in a database list, a user list, or an address list it matches everything.
const ALL_KEYWORD: &str = "all";
24
25/// RisingWave HBA (Host-Based Authentication) configuration, similar to PostgreSQL's `pg_hba.conf`
26/// This determines which authentication method to use for each connection.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct HbaConfig {
29    pub entries: Vec<HbaEntry>,
30}
31
32impl Default for HbaConfig {
33    fn default() -> Self {
34        Self {
35            entries: vec![
36                // Default rule: allow all local connections without authentication
37                HbaEntry {
38                    connection_type: ConnectionType::Local,
39                    databases: vec![ALL_KEYWORD.to_owned()],
40                    users: vec![ALL_KEYWORD.to_owned()],
41                    addresses: None,
42                    auth_method: AuthMethod::Trust,
43                    auth_options: HashMap::new(),
44                },
45                // Default rule: require password for all remote connections
46                HbaEntry {
47                    connection_type: ConnectionType::Host,
48                    databases: vec![ALL_KEYWORD.to_owned()],
49                    users: vec![ALL_KEYWORD.to_owned()],
50                    addresses: Some(vec![AddressPattern::Cidr("0.0.0.0/0".to_owned())]),
51                    auth_method: AuthMethod::Password,
52                    auth_options: HashMap::new(),
53                },
54            ],
55        }
56    }
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct HbaEntry {
61    /// Connection type (local, host, hostssl, hostnossl)
62    pub connection_type: ConnectionType,
63    /// Database names or "all"
64    pub databases: Vec<String>,
65    /// Usernames or "all"
66    pub users: Vec<String>,
67    /// Client addresses (only for non-local connections)
68    pub addresses: Option<Vec<AddressPattern>>,
69    /// Authentication method
70    pub auth_method: AuthMethod,
71    /// Authentication method options
72    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
73    pub auth_options: HashMap<String, String>,
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
77pub enum ConnectionType {
78    /// Unix socket connections
79    Local,
80    /// TCP/IP connections (both SSL and non-SSL)
81    Host,
82    /// TCP/IP connections that use SSL
83    #[serde(rename = "hostssl")]
84    HostSsl,
85    /// TCP/IP connections that do not use SSL
86    #[serde(rename = "hostnossl")]
87    HostNoSsl,
88}
89
/// Client-address matcher used by an [`HbaEntry`].
///
/// In configuration files each pattern is a plain string. The keyword `"all"` selects
/// [`AddressPattern::All`], a string containing `/` is a CIDR network, and any other string is
/// a host name.
#[derive(Debug, Clone, PartialEq)]
pub enum AddressPattern {
    /// Network in CIDR notation, e.g. `"192.168.1.0/24"`.
    Cidr(String),
    /// Host name; name resolution is not implemented yet, so this never matches a client.
    Hostname(String),
    /// The `"all"` keyword: accepts every client address.
    All,
}
99
100impl Serialize for AddressPattern {
101    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
102    where
103        S: Serializer,
104    {
105        match self {
106            AddressPattern::All => serializer.serialize_str(ALL_KEYWORD),
107            AddressPattern::Cidr(s) => serializer.serialize_str(s),
108            AddressPattern::Hostname(s) => serializer.serialize_str(s),
109        }
110    }
111}
112
113impl<'de> Deserialize<'de> for AddressPattern {
114    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
115    where
116        D: Deserializer<'de>,
117    {
118        let s = String::deserialize(deserializer)?;
119        if s == ALL_KEYWORD {
120            Ok(AddressPattern::All)
121        } else if s.contains('/') {
122            Ok(AddressPattern::Cidr(s))
123        } else {
124            Ok(AddressPattern::Hostname(s))
125        }
126    }
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
130pub enum AuthMethod {
131    /// No authentication
132    #[serde(rename = "trust")]
133    Trust,
134    /// Password with any authentication
135    #[serde(rename = "password")]
136    Password,
137    /// MD5-hashed password
138    #[serde(rename = "md5")]
139    Md5,
140    /// LDAP authentication
141    #[serde(rename = "ldap")]
142    Ldap,
143    /// OAuth/JWT authentication
144    #[serde(rename = "oauth")]
145    OAuth,
146}
147
148impl HbaConfig {
149    /// Find the first matching HBA entry for the given connection parameters
150    pub fn find_matching_entry(
151        &self,
152        connection_type: &ConnectionType,
153        database: &str,
154        user: &str,
155        client_addr: Option<&IpAddr>,
156    ) -> Option<&HbaEntry> {
157        self.entries
158            .iter()
159            .find(|entry| self.matches_entry(entry, connection_type, database, user, client_addr))
160    }
161
162    fn matches_entry(
163        &self,
164        entry: &HbaEntry,
165        connection_type: &ConnectionType,
166        database: &str,
167        user: &str,
168        client_addr: Option<&IpAddr>,
169    ) -> bool {
170        // Check connection type
171        if !self.matches_connection_type(&entry.connection_type, connection_type) {
172            return false;
173        }
174
175        // Check database
176        if !self.matches_list(&entry.databases, database) {
177            return false;
178        }
179
180        // Check user
181        if !self.matches_list(&entry.users, user) {
182            return false;
183        }
184
185        // Check address (only for non-local connections)
186        if *connection_type != ConnectionType::Local
187            && let Some(addresses) = &entry.addresses
188        {
189            if let Some(addr) = client_addr {
190                if !self.matches_address(addresses, addr) {
191                    return false;
192                }
193            } else {
194                return false;
195            }
196        }
197
198        true
199    }
200
201    fn matches_connection_type(
202        &self,
203        entry_type: &ConnectionType,
204        actual_type: &ConnectionType,
205    ) -> bool {
206        matches!(
207            (entry_type, actual_type),
208            (ConnectionType::Local, ConnectionType::Local)
209                | (ConnectionType::Host, ConnectionType::Host)
210                | (ConnectionType::Host, ConnectionType::HostSsl)
211                | (ConnectionType::Host, ConnectionType::HostNoSsl)
212                | (ConnectionType::HostSsl, ConnectionType::HostSsl)
213                | (ConnectionType::HostNoSsl, ConnectionType::HostNoSsl)
214        )
215    }
216
217    fn matches_list(&self, list: &[String], value: &str) -> bool {
218        list.iter().any(|item| item == ALL_KEYWORD || item == value)
219    }
220
221    fn matches_address(&self, patterns: &[AddressPattern], addr: &IpAddr) -> bool {
222        patterns.iter().any(|pattern| match pattern {
223            AddressPattern::All => true,
224            AddressPattern::Cidr(cidr) => self.matches_cidr(cidr, addr),
225            AddressPattern::Hostname(_hostname) => {
226                // TODO: implement hostname resolution
227                false
228            }
229        })
230    }
231
232    fn matches_cidr(&self, cidr: &str, addr: &IpAddr) -> bool {
233        if let Ok(network) = ipnet::IpNet::from_str(cidr) {
234            network.contains(addr)
235        } else {
236            false
237        }
238    }
239
240    /// Load HBA configuration from a TOML string
241    pub fn from_toml(content: &str) -> Result<Self, toml::de::Error> {
242        toml::from_str(content)
243    }
244
245    /// Save HBA configuration to a TOML string
246    pub fn to_toml(&self) -> Result<String, toml::ser::Error> {
247        toml::to_string_pretty(self)
248    }
249}
250
251impl FromStr for HbaConfig {
252    type Err = toml::de::Error;
253
254    fn from_str(s: &str) -> Result<Self, Self::Err> {
255        Self::from_toml(s)
256    }
257}
258
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr};

    use super::*;

    #[test]
    fn test_default_hba_config() {
        let config = HbaConfig::default();
        assert_eq!(config.entries.len(), 2);

        // A local connection must be resolved by the default trust rule.
        let local_rule = config
            .find_matching_entry(&ConnectionType::Local, "testdb", "testuser", None)
            .expect("default config should have a rule for local connections");
        assert_eq!(local_rule.auth_method, AuthMethod::Trust);
    }

    #[test]
    fn test_cidr_matching() {
        let config = HbaConfig::default();
        let client = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100));

        assert!(config.matches_cidr("192.168.1.0/24", &client));
        assert!(!config.matches_cidr("10.0.0.0/8", &client));
    }

    #[test]
    fn test_ldap_config_serialization() {
        let auth_options = HashMap::from([
            ("ldapserver".to_owned(), "ldap.example.com".to_owned()),
            ("ldapport".to_owned(), "389".to_owned()),
            ("ldapprefix".to_owned(), "cn=".to_owned()),
            ("ldapsuffix".to_owned(), ",dc=example,dc=com".to_owned()),
        ]);

        let original = HbaConfig {
            entries: vec![HbaEntry {
                connection_type: ConnectionType::Host,
                databases: vec![ALL_KEYWORD.to_owned()],
                users: vec![ALL_KEYWORD.to_owned()],
                addresses: Some(vec![AddressPattern::Cidr("10.0.0.0/8".to_owned())]),
                auth_method: AuthMethod::Ldap,
                auth_options,
            }],
        };

        let encoded = original.to_toml().unwrap();
        let decoded = HbaConfig::from_toml(&encoded).unwrap();

        assert_eq!(decoded.entries.len(), 1);
        let entry = &decoded.entries[0];
        assert_eq!(entry.auth_method, AuthMethod::Ldap);
        assert_eq!(
            entry.auth_options.get("ldapserver").map(String::as_str),
            Some("ldap.example.com")
        );
    }
}