pgwire/
net.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::io;
16use std::net::SocketAddr as IpSocketAddr;
17#[cfg(madsim)]
18use std::os::unix::net::SocketAddr as UnixSocketAddr;
19use std::sync::Arc;
20
21#[cfg(not(madsim))]
22use tokio::net::unix::SocketAddr as UnixSocketAddr;
23use tokio::net::{TcpListener, TcpStream, UnixListener, UnixStream};
24
25/// A wrapper of either [`TcpListener`] or [`UnixListener`].
26pub(crate) enum Listener {
27    Tcp(TcpListener),
28    Unix(UnixListener),
29}
30
31/// A wrapper of either [`TcpStream`] or [`UnixStream`].
32#[auto_enums::enum_derive(tokio1::AsyncRead, tokio1::AsyncWrite)]
33pub(crate) enum Stream {
34    Tcp(TcpStream),
35    Unix(UnixStream),
36}
37
/// The peer address of an accepted connection.
///
/// Wraps either a [`std::net::SocketAddr`] (TCP) or a Unix domain socket address.
/// The Unix address type is [`tokio::net::unix::SocketAddr`] in regular builds and
/// the standard library's `std::os::unix::net::SocketAddr` when built with `madsim`.
///
/// `Debug` is derived so the address can be used in `{:?}` formatting and embedded in
/// other `Debug` types; use the `Display` implementation for user-facing output.
#[derive(Debug)]
pub enum Address {
    /// Address of a TCP peer.
    Tcp(IpSocketAddr),
    /// Address of a Unix domain socket peer; it may have no pathname.
    Unix(UnixSocketAddr),
}

/// A reference-counted, shareable handle to an [`Address`].
pub type AddressRef = Arc<Address>;
45
46impl std::fmt::Display for Address {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        match self {
49            Address::Tcp(addr) => addr.fmt(f),
50            Address::Unix(addr) => {
51                if let Some(path) = addr.as_pathname() {
52                    path.display().fmt(f)
53                } else {
54                    std::fmt::Debug::fmt(addr, f)
55                }
56            }
57        }
58    }
59}
60
61impl Listener {
62    /// Creates a new [`Listener`] bound to the specified address.
63    ///
64    /// If the address starts with `unix:`, it will create a [`UnixListener`].
65    /// Otherwise, it will create a [`TcpListener`].
66    pub async fn bind(addr: &str) -> io::Result<Self> {
67        if let Some(path) = addr.strip_prefix("unix:") {
68            UnixListener::bind(path).map(Self::Unix)
69        } else {
70            TcpListener::bind(addr).await.map(Self::Tcp)
71        }
72    }
73
74    /// Accepts a new incoming connection from this listener.
75    ///
76    /// Returns a tuple of the stream and the string representation of the peer address.
77    pub async fn accept(&self, tcp_keepalive: &TcpKeepalive) -> io::Result<(Stream, Address)> {
78        match self {
79            Self::Tcp(listener) => {
80                let (stream, addr) = listener.accept().await?;
81                stream.set_nodelay(true)?;
82                // Set TCP keepalive to 5 minutes, which is less than the connection idle timeout of 350 seconds in AWS ELB.
83                // https://docs.aws.amazon.com/elasticloadbalancing/latest/network/network-load-balancers.html#connection-idle-timeout
84                #[cfg(not(madsim))]
85                {
86                    let r = socket2::SockRef::from(&stream);
87                    r.set_tcp_keepalive(tcp_keepalive)?;
88                }
89                Ok((Stream::Tcp(stream), Address::Tcp(addr)))
90            }
91            Self::Unix(listener) => {
92                let (stream, addr) = listener.accept().await?;
93                Ok((Stream::Unix(stream), Address::Unix(addr)))
94            }
95        }
96    }
97}
98
99pub use socket2::TcpKeepalive;