risingwave_connector/connector_common/
maybe_tls_connector.rs1use core::task;
16use std::error::Error;
17use std::io;
18use std::pin::Pin;
19use std::task::Poll;
20
21use futures::{Future, FutureExt};
22use openssl::error::ErrorStack;
23use postgres_openssl::{MakeTlsConnector, TlsConnector, TlsStream};
24use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
25use tokio_postgres::NoTls;
26use tokio_postgres::tls::{self, MakeTlsConnect, NoTlsFuture, NoTlsStream, TlsConnect};
27
28pub enum MaybeMakeTlsConnector {
29 NoTls(NoTls),
30 Tls(MakeTlsConnector),
31}
32
33impl<S> MakeTlsConnect<S> for MaybeMakeTlsConnector
34where
35 S: AsyncRead + AsyncWrite + Unpin + core::fmt::Debug + 'static + Sync + Send,
36{
37 type Error = ErrorStack;
38 type Stream = MaybeTlsStream<S>;
39 type TlsConnect = MaybeTlsConnector;
40
41 fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
42 match self {
43 MaybeMakeTlsConnector::NoTls(make_connector) => {
44 let connector =
45 <NoTls as MakeTlsConnect<S>>::make_tls_connect(make_connector, domain)
46 .expect("make NoTls connector always success");
47 Ok(MaybeTlsConnector::NoTls(connector))
48 }
49 MaybeMakeTlsConnector::Tls(make_connector) => {
50 <MakeTlsConnector as MakeTlsConnect<S>>::make_tls_connect(make_connector, domain)
51 .map(MaybeTlsConnector::Tls)
52 }
53 }
54 }
55}
56
57pub enum MaybeTlsConnector {
58 NoTls(NoTls),
59 Tls(TlsConnector),
60}
61
62impl<S> TlsConnect<S> for MaybeTlsConnector
63where
64 S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
65{
66 type Error = Box<dyn Error + Send + Sync>;
67 type Future = MaybeTlsFuture<Self::Stream, Self::Error>;
68 type Stream = MaybeTlsStream<S>;
69
70 fn connect(self, stream: S) -> Self::Future {
71 match self {
72 MaybeTlsConnector::NoTls(connector) => MaybeTlsFuture::NoTls(connector.connect(stream)),
73 MaybeTlsConnector::Tls(connector) => MaybeTlsFuture::Tls(Box::pin(
74 connector
75 .connect(stream)
76 .map(|x| x.map(|x| MaybeTlsStream::Tls(x))),
77 )),
78 }
79 }
80}
81
82pub enum MaybeTlsStream<S> {
83 NoTls(NoTlsStream),
84 Tls(TlsStream<S>),
85}
86
87impl<S> AsyncRead for MaybeTlsStream<S>
88where
89 S: AsyncRead + AsyncWrite + Unpin,
90{
91 fn poll_read(
92 mut self: Pin<&mut Self>,
93 cx: &mut task::Context<'_>,
94 buf: &mut ReadBuf<'_>,
95 ) -> Poll<io::Result<()>> {
96 match &mut *self {
97 MaybeTlsStream::NoTls(stream) => {
98 <NoTlsStream as AsyncRead>::poll_read(Pin::new(stream), cx, buf)
99 }
100 MaybeTlsStream::Tls(stream) => {
101 <TlsStream<S> as AsyncRead>::poll_read(Pin::new(stream), cx, buf)
102 }
103 }
104 }
105}
106
107impl<S> AsyncWrite for MaybeTlsStream<S>
108where
109 S: AsyncRead + AsyncWrite + Unpin,
110{
111 fn poll_write(
112 mut self: Pin<&mut Self>,
113 cx: &mut task::Context<'_>,
114 buf: &[u8],
115 ) -> Poll<io::Result<usize>> {
116 match &mut *self {
117 MaybeTlsStream::NoTls(stream) => {
118 <NoTlsStream as AsyncWrite>::poll_write(Pin::new(stream), cx, buf)
119 }
120 MaybeTlsStream::Tls(stream) => {
121 <TlsStream<S> as AsyncWrite>::poll_write(Pin::new(stream), cx, buf)
122 }
123 }
124 }
125
126 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
127 match &mut *self {
128 MaybeTlsStream::NoTls(stream) => {
129 <NoTlsStream as AsyncWrite>::poll_flush(Pin::new(stream), cx)
130 }
131 MaybeTlsStream::Tls(stream) => {
132 <TlsStream<S> as AsyncWrite>::poll_flush(Pin::new(stream), cx)
133 }
134 }
135 }
136
137 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
138 match &mut *self {
139 MaybeTlsStream::NoTls(stream) => {
140 <NoTlsStream as AsyncWrite>::poll_shutdown(Pin::new(stream), cx)
141 }
142 MaybeTlsStream::Tls(stream) => {
143 <TlsStream<S> as AsyncWrite>::poll_shutdown(Pin::new(stream), cx)
144 }
145 }
146 }
147}
148
149impl<S> tls::TlsStream for MaybeTlsStream<S>
150where
151 S: AsyncRead + AsyncWrite + Unpin,
152{
153 fn channel_binding(&self) -> tls::ChannelBinding {
154 match self {
155 MaybeTlsStream::NoTls(stream) => stream.channel_binding(),
156 MaybeTlsStream::Tls(stream) => stream.channel_binding(),
157 }
158 }
159}
160
161pub enum MaybeTlsFuture<S, E> {
162 NoTls(NoTlsFuture),
163 Tls(Pin<Box<dyn Future<Output = Result<S, E>> + Send>>),
164}
165
166impl<S, E> Future for MaybeTlsFuture<MaybeTlsStream<S>, E>
167where
168 MaybeTlsStream<S>: Sync + Send,
169 E: std::convert::From<tokio_postgres::tls::NoTlsError>,
170{
171 type Output = Result<MaybeTlsStream<S>, E>;
172
173 fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
174 match &mut *self {
175 MaybeTlsFuture::NoTls(fut) => fut
176 .poll_unpin(cx)
177 .map(|x| x.map(|x| MaybeTlsStream::NoTls(x)))
178 .map_err(|x| x.into()),
179 MaybeTlsFuture::Tls(fut) => fut.poll_unpin(cx),
180 }
181 }
182}