risingwave_common/util/
addr.rs1use std::net::{SocketAddr, ToSocketAddrs};
16use std::str::FromStr;
17use std::time::Duration;
18
19use anyhow::Context;
20use risingwave_pb::common::PbHostAddress;
21use thiserror_ext::AsReport;
22use tokio::time::sleep;
23use tokio_retry::strategy::ExponentialBackoff;
24use tracing::error;
25
/// A network endpoint made of a host (name or IP literal) and a port.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct HostAddr {
    /// Host name or IP address, stored as text.
    pub host: String,
    /// Port number.
    pub port: u16,
}
32
33impl std::fmt::Display for HostAddr {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 write!(f, "{}:{}", self.host, self.port)
36 }
37}
38impl From<SocketAddr> for HostAddr {
39 fn from(addr: SocketAddr) -> Self {
40 HostAddr {
41 host: addr.ip().to_string(),
42 port: addr.port(),
43 }
44 }
45}
46
47impl TryFrom<&str> for HostAddr {
48 type Error = anyhow::Error;
49
50 fn try_from(s: &str) -> Result<Self, Self::Error> {
51 let s = format!("http://{s}");
52 let addr = url::Url::parse(&s).with_context(|| format!("failed to parse address: {s}"))?;
53 Ok(HostAddr {
54 host: addr.host().context("invalid host")?.to_string(),
55 port: addr.port().context("invalid port")?,
56 })
57 }
58}
59
60impl TryFrom<&String> for HostAddr {
61 type Error = anyhow::Error;
62
63 fn try_from(s: &String) -> Result<Self, Self::Error> {
64 Self::try_from(s.as_str())
65 }
66}
67
68impl FromStr for HostAddr {
69 type Err = anyhow::Error;
70
71 fn from_str(s: &str) -> Result<Self, Self::Err> {
72 Self::try_from(s)
73 }
74}
75
76impl From<&PbHostAddress> for HostAddr {
77 fn from(addr: &PbHostAddress) -> Self {
78 HostAddr {
79 host: addr.get_host().to_string(),
80 port: addr.get_port() as u16,
81 }
82 }
83}
84
85impl HostAddr {
86 pub fn to_protobuf(&self) -> PbHostAddress {
87 PbHostAddress {
88 host: self.host.clone(),
89 port: self.port as i32,
90 }
91 }
92}
93
94pub fn is_local_address(server_addr: &HostAddr, peer_addr: &HostAddr) -> bool {
95 server_addr == peer_addr
96}
97
98pub async fn try_resolve_dns(host: &str, port: i32) -> Result<SocketAddr, String> {
99 let addr = format!("{}:{}", host, port);
100 let mut backoff = ExponentialBackoff::from_millis(100)
101 .max_delay(Duration::from_secs(3))
102 .factor(5);
103 const MAX_RETRY: usize = 20;
104 for i in 1..=MAX_RETRY {
105 let err = match addr.to_socket_addrs() {
106 Ok(mut addr_iter) => {
107 if let Some(addr) = addr_iter.next() {
108 return Ok(addr);
109 } else {
110 format!("{} resolved to no addr", addr)
111 }
112 }
113 Err(e) => e.to_report_string(),
114 };
115 let delay = backoff.next().unwrap();
118 error!(attempt = i, backoff_delay = ?delay, err, addr, "fail to resolve worker node address");
119 sleep(delay).await;
120 }
121 Err(format!("failed to resolve dns: {}", addr))
122}
123
#[cfg(test)]
mod tests {
    use crate::util::addr::{HostAddr, is_local_address};

    /// Parses both strings and reports whether they count as the same address.
    fn same_address(a: &str, b: &str) -> bool {
        let lhs: HostAddr = a.parse().unwrap();
        let rhs: HostAddr = b.parse().unwrap();
        is_local_address(&lhs, &rhs)
    }

    #[test]
    fn test_is_local_address() {
        let cases = [
            ("localhost:3456", "localhost:3456", true),
            ("10.11.12.13:3456", "10.11.12.13:3456", true),
            ("some.host.in.k8s:3456", "some.host.in.k8s:3456", true),
            ("some.host.in.k8s:3456", "other.host.in.k8s:3456", false),
            ("some.host.in.k8s:3456", "some.host.in.k8s:4567", false),
        ];
        for (a, b, expected) in cases {
            assert_eq!(same_address(a, b), expected, "{a} vs {b}");
        }
    }

    #[test]
    fn test_host_addr_convert() {
        let accepted = [
            ("1.2.3.4:567", "1.2.3.4", 567),
            ("test.test:12345", "test.test", 12345),
        ];
        for (input, host, port) in accepted {
            assert_eq!(
                input.parse::<HostAddr>().unwrap(),
                HostAddr {
                    host: String::from(host),
                    port,
                }
            );
        }

        let rejected = [
            "test.test",
            "test.test:65537",
            "test.test:",
            "test.test:12345:12345",
        ];
        for input in rejected {
            assert!(input.parse::<HostAddr>().is_err(), "{input} should be rejected");
        }
    }
}