risingwave_connector/connector_common/
maybe_tls_connector.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use core::task;
16use std::error::Error;
17use std::io;
18use std::pin::Pin;
19use std::task::Poll;
20
21use futures::{Future, FutureExt};
22use openssl::error::ErrorStack;
23use postgres_openssl::{MakeTlsConnector, TlsConnector, TlsStream};
24use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
25use tokio_postgres::NoTls;
26use tokio_postgres::tls::{self, MakeTlsConnect, NoTlsFuture, NoTlsStream, TlsConnect};
27
28pub enum MaybeMakeTlsConnector {
29    NoTls(NoTls),
30    Tls(MakeTlsConnector),
31}
32
33impl<S> MakeTlsConnect<S> for MaybeMakeTlsConnector
34where
35    S: AsyncRead + AsyncWrite + Unpin + core::fmt::Debug + 'static + Sync + Send,
36{
37    type Error = ErrorStack;
38    type Stream = MaybeTlsStream<S>;
39    type TlsConnect = MaybeTlsConnector;
40
41    fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
42        match self {
43            MaybeMakeTlsConnector::NoTls(make_connector) => {
44                let connector =
45                    <NoTls as MakeTlsConnect<S>>::make_tls_connect(make_connector, domain)
46                        .expect("make NoTls connector always success");
47                Ok(MaybeTlsConnector::NoTls(connector))
48            }
49            MaybeMakeTlsConnector::Tls(make_connector) => {
50                <MakeTlsConnector as MakeTlsConnect<S>>::make_tls_connect(make_connector, domain)
51                    .map(MaybeTlsConnector::Tls)
52            }
53        }
54    }
55}
56
57pub enum MaybeTlsConnector {
58    NoTls(NoTls),
59    Tls(TlsConnector),
60}
61
62impl<S> TlsConnect<S> for MaybeTlsConnector
63where
64    S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
65{
66    type Error = Box<dyn Error + Send + Sync>;
67    type Future = MaybeTlsFuture<Self::Stream, Self::Error>;
68    type Stream = MaybeTlsStream<S>;
69
70    fn connect(self, stream: S) -> Self::Future {
71        match self {
72            MaybeTlsConnector::NoTls(connector) => MaybeTlsFuture::NoTls(connector.connect(stream)),
73            MaybeTlsConnector::Tls(connector) => MaybeTlsFuture::Tls(Box::pin(
74                connector
75                    .connect(stream)
76                    .map(|x| x.map(|x| MaybeTlsStream::Tls(x))),
77            )),
78        }
79    }
80}
81
82pub enum MaybeTlsStream<S> {
83    NoTls(NoTlsStream),
84    Tls(TlsStream<S>),
85}
86
87impl<S> AsyncRead for MaybeTlsStream<S>
88where
89    S: AsyncRead + AsyncWrite + Unpin,
90{
91    fn poll_read(
92        mut self: Pin<&mut Self>,
93        cx: &mut task::Context<'_>,
94        buf: &mut ReadBuf<'_>,
95    ) -> Poll<io::Result<()>> {
96        match &mut *self {
97            MaybeTlsStream::NoTls(stream) => {
98                <NoTlsStream as AsyncRead>::poll_read(Pin::new(stream), cx, buf)
99            }
100            MaybeTlsStream::Tls(stream) => {
101                <TlsStream<S> as AsyncRead>::poll_read(Pin::new(stream), cx, buf)
102            }
103        }
104    }
105}
106
107impl<S> AsyncWrite for MaybeTlsStream<S>
108where
109    S: AsyncRead + AsyncWrite + Unpin,
110{
111    fn poll_write(
112        mut self: Pin<&mut Self>,
113        cx: &mut task::Context<'_>,
114        buf: &[u8],
115    ) -> Poll<io::Result<usize>> {
116        match &mut *self {
117            MaybeTlsStream::NoTls(stream) => {
118                <NoTlsStream as AsyncWrite>::poll_write(Pin::new(stream), cx, buf)
119            }
120            MaybeTlsStream::Tls(stream) => {
121                <TlsStream<S> as AsyncWrite>::poll_write(Pin::new(stream), cx, buf)
122            }
123        }
124    }
125
126    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
127        match &mut *self {
128            MaybeTlsStream::NoTls(stream) => {
129                <NoTlsStream as AsyncWrite>::poll_flush(Pin::new(stream), cx)
130            }
131            MaybeTlsStream::Tls(stream) => {
132                <TlsStream<S> as AsyncWrite>::poll_flush(Pin::new(stream), cx)
133            }
134        }
135    }
136
137    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
138        match &mut *self {
139            MaybeTlsStream::NoTls(stream) => {
140                <NoTlsStream as AsyncWrite>::poll_shutdown(Pin::new(stream), cx)
141            }
142            MaybeTlsStream::Tls(stream) => {
143                <TlsStream<S> as AsyncWrite>::poll_shutdown(Pin::new(stream), cx)
144            }
145        }
146    }
147}
148
149impl<S> tls::TlsStream for MaybeTlsStream<S>
150where
151    S: AsyncRead + AsyncWrite + Unpin,
152{
153    fn channel_binding(&self) -> tls::ChannelBinding {
154        match self {
155            MaybeTlsStream::NoTls(stream) => stream.channel_binding(),
156            MaybeTlsStream::Tls(stream) => stream.channel_binding(),
157        }
158    }
159}
160
161pub enum MaybeTlsFuture<S, E> {
162    NoTls(NoTlsFuture),
163    Tls(Pin<Box<dyn Future<Output = Result<S, E>> + Send>>),
164}
165
166impl<S, E> Future for MaybeTlsFuture<MaybeTlsStream<S>, E>
167where
168    MaybeTlsStream<S>: Sync + Send,
169    E: std::convert::From<tokio_postgres::tls::NoTlsError>,
170{
171    type Output = Result<MaybeTlsStream<S>, E>;
172
173    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
174        match &mut *self {
175            MaybeTlsFuture::NoTls(fut) => fut
176                .poll_unpin(cx)
177                .map(|x| x.map(|x| MaybeTlsStream::NoTls(x)))
178                .map_err(|x| x.into()),
179            MaybeTlsFuture::Tls(fut) => fut.poll_unpin(cx),
180        }
181    }
182}