risingwave_common/config/
hba.rs1use std::cmp::PartialEq;
16use std::collections::HashMap;
17use std::net::IpAddr;
18use std::str::FromStr;
19
20use serde::{Deserialize, Deserializer, Serialize, Serializer};
21
/// Keyword that matches any database or user name in an entry's lists, and (as an
/// address pattern) any client address.
const ALL_KEYWORD: &str = "all";
24
/// HBA (host-based authentication) rule set.
///
/// `find_matching_entry` scans `entries` in order and returns the first entry that
/// matches, so the order of entries is significant: more specific rules must come first.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HbaConfig {
    /// Rules in evaluation order.
    pub entries: Vec<HbaEntry>,
}
31
32impl Default for HbaConfig {
33 fn default() -> Self {
34 Self {
35 entries: vec![
36 HbaEntry {
38 connection_type: ConnectionType::Local,
39 databases: vec![ALL_KEYWORD.to_owned()],
40 users: vec![ALL_KEYWORD.to_owned()],
41 addresses: None,
42 auth_method: AuthMethod::Trust,
43 auth_options: HashMap::new(),
44 },
45 HbaEntry {
47 connection_type: ConnectionType::Host,
48 databases: vec![ALL_KEYWORD.to_owned()],
49 users: vec![ALL_KEYWORD.to_owned()],
50 addresses: Some(vec![AddressPattern::Cidr("0.0.0.0/0".to_owned())]),
51 auth_method: AuthMethod::Password,
52 auth_options: HashMap::new(),
53 },
54 ],
55 }
56 }
57}
58
/// A single HBA rule: which connections it applies to, and which authentication method
/// (with options) is associated with them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HbaEntry {
    /// Connection kind this rule applies to. A `Host` rule also applies to `HostSsl`
    /// and `HostNoSsl` connections.
    pub connection_type: ConnectionType,
    /// Database names this rule applies to; the `all` keyword matches any database.
    /// Comparison is exact and case-sensitive.
    pub databases: Vec<String>,
    /// User names this rule applies to; the `all` keyword matches any user.
    /// Comparison is exact and case-sensitive.
    pub users: Vec<String>,
    /// Client address patterns. `None` places no restriction on the address. The patterns
    /// are not checked for local connections. When `Some`, a non-local connection whose
    /// client address is unknown does not match this rule.
    pub addresses: Option<Vec<AddressPattern>>,
    /// Authentication method associated with this rule.
    pub auth_method: AuthMethod,
    /// Method-specific options (for example LDAP server settings). Defaults to empty when
    /// absent from the input, and is omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub auth_options: HashMap<String, String>,
}
75
76#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
77pub enum ConnectionType {
78 Local,
80 Host,
82 #[serde(rename = "hostssl")]
84 HostSsl,
85 #[serde(rename = "hostnossl")]
87 HostNoSsl,
88}
89
/// A client-address pattern in an HBA rule.
///
/// Values are read from and written to configuration as plain strings.
#[derive(PartialEq, Clone, Debug)]
pub enum AddressPattern {
    /// The `all` keyword: matches every client address.
    All,
    /// Address range in CIDR notation (any configured value containing `/`). The text is
    /// kept as written; a range that cannot be parsed never matches.
    Cidr(String),
    /// Any other configured value, interpreted as a host name.
    Hostname(String),
}
99
100impl Serialize for AddressPattern {
101 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
102 where
103 S: Serializer,
104 {
105 match self {
106 AddressPattern::All => serializer.serialize_str(ALL_KEYWORD),
107 AddressPattern::Cidr(s) => serializer.serialize_str(s),
108 AddressPattern::Hostname(s) => serializer.serialize_str(s),
109 }
110 }
111}
112
113impl<'de> Deserialize<'de> for AddressPattern {
114 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
115 where
116 D: Deserializer<'de>,
117 {
118 let s = String::deserialize(deserializer)?;
119 if s == ALL_KEYWORD {
120 Ok(AddressPattern::All)
121 } else if s.contains('/') {
122 Ok(AddressPattern::Cidr(s))
123 } else {
124 Ok(AddressPattern::Hostname(s))
125 }
126 }
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
130pub enum AuthMethod {
131 #[serde(rename = "trust")]
133 Trust,
134 #[serde(rename = "password")]
136 Password,
137 #[serde(rename = "md5")]
139 Md5,
140 #[serde(rename = "ldap")]
142 Ldap,
143 #[serde(rename = "oauth")]
145 OAuth,
146}
147
148impl HbaConfig {
149 pub fn find_matching_entry(
151 &self,
152 connection_type: &ConnectionType,
153 database: &str,
154 user: &str,
155 client_addr: Option<&IpAddr>,
156 ) -> Option<&HbaEntry> {
157 self.entries
158 .iter()
159 .find(|entry| self.matches_entry(entry, connection_type, database, user, client_addr))
160 }
161
162 fn matches_entry(
163 &self,
164 entry: &HbaEntry,
165 connection_type: &ConnectionType,
166 database: &str,
167 user: &str,
168 client_addr: Option<&IpAddr>,
169 ) -> bool {
170 if !self.matches_connection_type(&entry.connection_type, connection_type) {
172 return false;
173 }
174
175 if !self.matches_list(&entry.databases, database) {
177 return false;
178 }
179
180 if !self.matches_list(&entry.users, user) {
182 return false;
183 }
184
185 if *connection_type != ConnectionType::Local
187 && let Some(addresses) = &entry.addresses
188 {
189 if let Some(addr) = client_addr {
190 if !self.matches_address(addresses, addr) {
191 return false;
192 }
193 } else {
194 return false;
195 }
196 }
197
198 true
199 }
200
201 fn matches_connection_type(
202 &self,
203 entry_type: &ConnectionType,
204 actual_type: &ConnectionType,
205 ) -> bool {
206 matches!(
207 (entry_type, actual_type),
208 (ConnectionType::Local, ConnectionType::Local)
209 | (ConnectionType::Host, ConnectionType::Host)
210 | (ConnectionType::Host, ConnectionType::HostSsl)
211 | (ConnectionType::Host, ConnectionType::HostNoSsl)
212 | (ConnectionType::HostSsl, ConnectionType::HostSsl)
213 | (ConnectionType::HostNoSsl, ConnectionType::HostNoSsl)
214 )
215 }
216
217 fn matches_list(&self, list: &[String], value: &str) -> bool {
218 list.iter().any(|item| item == ALL_KEYWORD || item == value)
219 }
220
221 fn matches_address(&self, patterns: &[AddressPattern], addr: &IpAddr) -> bool {
222 patterns.iter().any(|pattern| match pattern {
223 AddressPattern::All => true,
224 AddressPattern::Cidr(cidr) => self.matches_cidr(cidr, addr),
225 AddressPattern::Hostname(_hostname) => {
226 false
228 }
229 })
230 }
231
232 fn matches_cidr(&self, cidr: &str, addr: &IpAddr) -> bool {
233 if let Ok(network) = ipnet::IpNet::from_str(cidr) {
234 network.contains(addr)
235 } else {
236 false
237 }
238 }
239
240 pub fn from_toml(content: &str) -> Result<Self, toml::de::Error> {
242 toml::from_str(content)
243 }
244
245 pub fn to_toml(&self) -> Result<String, toml::ser::Error> {
247 toml::to_string_pretty(self)
248 }
249}
250
251impl FromStr for HbaConfig {
252 type Err = toml::de::Error;
253
254 fn from_str(s: &str) -> Result<Self, Self::Err> {
255 Self::from_toml(s)
256 }
257}
258
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr};

    use super::*;

    /// The built-in rule set has two entries, and a local connection resolves to `Trust`.
    #[test]
    fn test_default_hba_config() {
        let config = HbaConfig::default();
        assert_eq!(config.entries.len(), 2);

        let matched = config
            .find_matching_entry(&ConnectionType::Local, "testdb", "testuser", None)
            .expect("a local connection should match the default local rule");
        assert_eq!(matched.auth_method, AuthMethod::Trust);
    }

    /// CIDR containment for an IPv4 client, inside and outside the range.
    #[test]
    fn test_cidr_matching() {
        let config = HbaConfig::default();
        let client = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100));

        let cases = [("192.168.1.0/24", true), ("10.0.0.0/8", false)];
        for (cidr, expected) in cases {
            assert_eq!(config.matches_cidr(cidr, &client), expected, "cidr {cidr}");
        }
    }

    /// An LDAP rule with options survives a TOML round trip.
    #[test]
    fn test_ldap_config_serialization() {
        let auth_options: HashMap<String, String> = [
            ("ldapserver", "ldap.example.com"),
            ("ldapport", "389"),
            ("ldapprefix", "cn="),
            ("ldapsuffix", ",dc=example,dc=com"),
        ]
        .into_iter()
        .map(|(key, value)| (key.to_owned(), value.to_owned()))
        .collect();

        let config = HbaConfig {
            entries: vec![HbaEntry {
                connection_type: ConnectionType::Host,
                databases: vec![ALL_KEYWORD.to_owned()],
                users: vec![ALL_KEYWORD.to_owned()],
                addresses: Some(vec![AddressPattern::Cidr("10.0.0.0/8".to_owned())]),
                auth_method: AuthMethod::Ldap,
                auth_options,
            }],
        };

        let rendered = config.to_toml().unwrap();
        let round_tripped = HbaConfig::from_toml(&rendered).unwrap();

        assert_eq!(round_tripped.entries.len(), 1);
        let entry = &round_tripped.entries[0];
        assert_eq!(entry.auth_method, AuthMethod::Ldap);
        assert_eq!(
            entry.auth_options.get("ldapserver").map(String::as_str),
            Some("ldap.example.com")
        );
    }
}