pgwire/net.rs
1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::io;
16use std::net::SocketAddr as IpSocketAddr;
17#[cfg(madsim)]
18use std::os::unix::net::SocketAddr as UnixSocketAddr;
19use std::sync::Arc;
20
21#[cfg(not(madsim))]
22use tokio::net::unix::SocketAddr as UnixSocketAddr;
23use tokio::net::{TcpListener, TcpStream, UnixListener, UnixStream};
24
25/// A wrapper of either [`TcpListener`] or [`UnixListener`].
26pub(crate) enum Listener {
27 Tcp(TcpListener),
28 Unix(UnixListener),
29}
30
31/// A wrapper of either [`TcpStream`] or [`UnixStream`].
32#[auto_enums::enum_derive(tokio1::AsyncRead, tokio1::AsyncWrite)]
33pub(crate) enum Stream {
34 Tcp(TcpStream),
35 Unix(UnixStream),
36}
37
/// A wrapper of either [`std::net::SocketAddr`] or [`tokio::net::unix::SocketAddr`].
///
/// Under `cfg(madsim)` the Unix variant holds [`std::os::unix::net::SocketAddr`] instead,
/// as selected by the `UnixSocketAddr` alias imported at the top of this file.
///
/// `Debug` is derived so the address can be used in diagnostics and in `Debug`-formatted
/// containers; use the [`Display`](std::fmt::Display) impl for user-facing output.
#[derive(Debug)]
pub enum Address {
    /// Address of a TCP socket.
    Tcp(IpSocketAddr),
    /// Address of a Unix domain socket.
    Unix(UnixSocketAddr),
}

/// A reference-counted, shareable [`Address`].
pub type AddressRef = Arc<Address>;

impl std::fmt::Display for Address {
    /// Formats the address for display.
    ///
    /// * TCP addresses use the `Display` output of the socket address (`ip:port`).
    /// * Unix addresses bound to a filesystem path print that path.
    /// * Unix addresses without a pathname fall back to the address's `Debug` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Tcp(ip) => std::fmt::Display::fmt(ip, f),
            Self::Unix(unix) => match unix.as_pathname() {
                Some(path) => std::fmt::Display::fmt(&path.display(), f),
                // No pathname available, so the `Debug` form is the only description we have.
                None => std::fmt::Debug::fmt(unix, f),
            },
        }
    }
}
60
61impl Listener {
62 /// Creates a new [`Listener`] bound to the specified address.
63 ///
64 /// If the address starts with `unix:`, it will create a [`UnixListener`].
65 /// Otherwise, it will create a [`TcpListener`].
66 pub async fn bind(addr: &str) -> io::Result<Self> {
67 if let Some(path) = addr.strip_prefix("unix:") {
68 UnixListener::bind(path).map(Self::Unix)
69 } else {
70 TcpListener::bind(addr).await.map(Self::Tcp)
71 }
72 }
73
74 /// Accepts a new incoming connection from this listener.
75 ///
76 /// Returns a tuple of the stream and the string representation of the peer address.
77 pub async fn accept(&self, tcp_keepalive: &TcpKeepalive) -> io::Result<(Stream, Address)> {
78 match self {
79 Self::Tcp(listener) => {
80 let (stream, addr) = listener.accept().await?;
81 stream.set_nodelay(true)?;
82 // Set TCP keepalive to 5 minutes, which is less than the connection idle timeout of 350 seconds in AWS ELB.
83 // https://docs.aws.amazon.com/elasticloadbalancing/latest/network/network-load-balancers.html#connection-idle-timeout
84 #[cfg(not(madsim))]
85 {
86 let r = socket2::SockRef::from(&stream);
87 r.set_tcp_keepalive(tcp_keepalive)?;
88 }
89 Ok((Stream::Tcp(stream), Address::Tcp(addr)))
90 }
91 Self::Unix(listener) => {
92 let (stream, addr) = listener.accept().await?;
93 Ok((Stream::Unix(stream), Address::Unix(addr)))
94 }
95 }
96 }
97}
98
99pub use socket2::TcpKeepalive;