pgwire/
pg_protocol.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::env_var::env_var_is_true;
33use risingwave_common::util::panic::FutureCatchUnwindExt;
34use risingwave_common::util::query_log::*;
35use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
36use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
37use risingwave_sqlparser::parser::Parser;
38use thiserror_ext::AsReport;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
40use tokio::sync::Mutex;
41use tokio_openssl::SslStream;
42use tracing::Instrument;
43
44use crate::error::{PsqlError, PsqlResult};
45use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
46use crate::net::AddressRef;
47use crate::pg_extended::ResultCache;
48use crate::pg_message::{
49    BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
50    FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
51    FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
52};
53use crate::pg_server::{Session, SessionManager, UserAuthenticator};
54use crate::types::Format;
55
/// Truncation limit applied to SQL recorded in query logs and tracing spans, to avoid the log
/// file growing too large.
///
/// Read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable. If the variable is unset,
/// not valid Unicode, or not a valid `usize`, it defaults to 65536 in debug builds and 1024
/// otherwise.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        // Parse exactly once; any parse failure falls back to the default below.
        .and_then(|len| len.parse::<usize>().ok())
        .unwrap_or(if cfg!(debug_assertions) { 65536 } else { 1024 })
});
69
tokio::task_local! {
    /// The current session. Concrete type is erased for different session implementations.
    ///
    /// Holds a `Weak` reference, so it does not keep the session alive by itself. It is scoped
    /// around the processing of each message by `PgProtocol::do_process` when the connection
    /// has a session; outside such a scope it is not set.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
74
/// The state machine for each psql connection.
/// Read pg messages from tcp stream and write results back.
///
/// Note: fields are dropped in declaration order, after `Drop::drop` (which ends `session` in
/// `session_mgr`) has run; keep that in mind before reordering fields.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Used for write/read pg messages.
    stream: PgStream<S>,
    /// Current states of pg connection.
    state: PgProtocolState,
    /// Whether the connection should be terminated after the current message
    /// (e.g. set after handling a cancel request).
    is_terminate: bool,

    /// Creates the session at startup and is told to end it when this connection is dropped.
    session_mgr: Arc<SM>,
    /// The session of this connection; `None` until the startup message has created one.
    session: Option<Arc<SM::Session>>,

    /// Cached result streams used by the extended query protocol, keyed by name
    /// (presumably the portal name — TODO confirm against `process_execute_msg`).
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    /// The unnamed prepared statement of the extended query protocol, if any.
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    /// Named prepared statements of the extended query protocol.
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    /// The unnamed portal of the extended query protocol, if any.
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    /// Named portals of the extended query protocol.
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    // Used to store the dependency of portal and prepare statement.
    // When we close a prepare statement, we need to close all the portals that depend on it.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    // SSL context used to upgrade the stream when the client sends an SSL request.
    // Built from `tls_config` when the protocol is created. If None (no TLS config, or the
    // context could not be built), SSL requests are answered with `EncryptionResponseNo`.
    tls_context: Option<SslContext>,

    // TLS configuration including SSL enforcement setting
    tls_config: Option<TlsConfig>,

    // Used in extended query protocol. When encounter error in extended query, we need to ignore
    // the following messages until a sync message arrives.
    ignore_util_sync: bool,

    // Client Address
    peer_addr: AddressRef,

    // Keywords of SQL options whose values are redacted before SQL is recorded for logging.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    // Accounts for the memory of incoming message payloads; `read_message` turns rejected
    // messages into `FeMessage::ServerThrottle`.
    message_memory_manager: MessageMemoryManagerRef,
}
117
/// Configures TLS encryption for connections.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// The path to the TLS certificate.
    pub cert: String,
    /// The path to the TLS key.
    pub key: String,
    /// Whether to enforce SSL connections (reject non-SSL clients): when set, a startup
    /// message received over a non-SSL connection is rejected with a `StartupError`.
    pub enforce_ssl: bool,
}
128
129impl TlsConfig {
130    pub fn new_default() -> Option<Self> {
131        let cert = std::env::var("RW_SSL_CERT").ok()?;
132        let key = std::env::var("RW_SSL_KEY").ok()?;
133        let enforce_ssl = env_var_is_true("RW_SSL_ENFORCE");
134        tracing::info!(
135            "RW_SSL_CERT={}, RW_SSL_KEY={}, RW_SSL_ENFORCE={}",
136            cert,
137            key,
138            enforce_ssl
139        );
140        Some(Self {
141            cert,
142            key,
143            enforce_ssl,
144        })
145    }
146}
147
148impl<S, SM> Drop for PgProtocol<S, SM>
149where
150    SM: SessionManager,
151{
152    fn drop(&mut self) {
153        if let Some(session) = &self.session {
154            // Clear the session in session manager.
155            self.session_mgr.end_session(session);
156        }
157    }
158}
159
/// States flow happened from top to down.
enum PgProtocolState {
    /// Initial state: messages are read with `read_startup`.
    Startup,
    /// Entered once the startup message has been handled: messages are read as header + body
    /// and their payload size is checked by the message memory manager.
    Regular,
}
165
166/// Truncate 0 from C string in Bytes and stringify it (returns slice, no allocations).
167///
168/// PG protocol strings are always C strings.
169pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
170    let without_null = if b.last() == Some(&0) {
171        &b[..b.len() - 1]
172    } else {
173        &b[..]
174    };
175    std::str::from_utf8(without_null)
176}
177
178/// Record `sql` in the current tracing span.
179fn record_sql_in_span(
180    sql: &str,
181    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
182) -> String {
183    let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
184        && !keywords.is_empty()
185    {
186        redact_sql(sql, keywords)
187    } else {
188        sql.to_owned()
189    };
190    let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
191    tracing::Span::current().record("sql", tracing::field::display(&truncated));
192    truncated.to_string()
193}
194
195/// Redacts SQL options. Data in DML is not redacted.
196fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
197    match Parser::parse_sql(sql) {
198        Ok(sqls) => sqls
199            .into_iter()
200            .map(|sql| sql.to_redacted_string(keywords.clone()))
201            .join(";"),
202        Err(_) => sql.to_owned(),
203    }
204}
205
/// Settings passed to [`PgProtocol::new`] for each accepted connection.
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS configuration; `None` disables SSL support for the connection.
    pub tls_config: Option<TlsConfig>,
    /// Keywords of SQL options whose values are redacted before SQL is recorded for logging.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Memory manager used to account for (and throttle) incoming message payloads.
    pub message_memory_manager: MessageMemoryManagerRef,
}
212
213impl<S, SM> PgProtocol<S, SM>
214where
215    S: PgByteStream,
216    SM: SessionManager,
217{
218    pub fn new(
219        stream: S,
220        session_mgr: Arc<SM>,
221        peer_addr: AddressRef,
222        context: ConnectionContext,
223    ) -> Self {
224        let ConnectionContext {
225            tls_config,
226            redact_sql_option_keywords,
227            message_memory_manager,
228        } = context;
229        Self {
230            stream: PgStream::new(stream),
231            is_terminate: false,
232            state: PgProtocolState::Startup,
233            session_mgr,
234            session: None,
235            tls_context: tls_config
236                .as_ref()
237                .and_then(|e| build_ssl_ctx_from_config(e).ok()),
238            tls_config,
239            result_cache: Default::default(),
240            unnamed_prepare_statement: Default::default(),
241            prepare_statement_store: Default::default(),
242            unnamed_portal: Default::default(),
243            portal_store: Default::default(),
244            statement_portal_dependency: Default::default(),
245            ignore_util_sync: false,
246            peer_addr,
247            redact_sql_option_keywords,
248            message_memory_manager,
249        }
250    }
251
    /// Run the protocol to serve the connection.
    ///
    /// Reads one message at a time and hands it to [`Self::process`] until that reports
    /// termination or reading a message fails. Once a session exists, a future that forwards the
    /// session's notices to the client (through a clone of the stream) is polled concurrently
    /// with message processing.
    pub async fn run(&mut self) {
        let mut notice_fut = None;

        loop {
            // Once a session is present, create a future to subscribe and send notices asynchronously.
            if notice_fut.is_none()
                && let Some(session) = self.session.clone()
            {
                let mut stream = self.stream.clone();
                // This future loops forever; write failures are only logged.
                notice_fut = Some(Box::pin(async move {
                    loop {
                        let notice = session.next_notice().await;
                        if let Err(e) = stream.write(BeMessage::NoticeResponse(&notice)).await {
                            tracing::error!(error = %e.as_report(), notice, "failed to send notice");
                        }
                    }
                }));
            }

            // Read and process messages.
            let process = std::pin::pin!(async {
                // `_memory_guard` stays bound until the end of this block, i.e. until the message
                // has been processed (presumably it releases the accounted memory on drop).
                let (msg, _memory_guard) = match self.read_message().await {
                    Ok(msg) => msg,
                    Err(e) => {
                        tracing::error!(error = %e.as_report(), "error when reading message");
                        return true; // terminate the connection
                    }
                };
                tracing::trace!(?msg, "received message");
                self.process(msg).await
            });

            // The notice future never completes (it is an infinite loop), so only the
            // processing branch can finish the `select!`.
            let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
                tokio::select! {
                    _ = notice_fut => unreachable!(),
                    terminated = process => terminated,
                }
            } else {
                process.await
            };

            if terminated {
                break;
            }
        }
    }
299
300    /// Processes one message. Returns true if the connection is terminated.
301    pub async fn process(&mut self, msg: FeMessage) -> bool {
302        self.do_process(msg).await.is_none() || self.is_terminate
303    }
304
305    /// The root tracing span for processing a message. The target of the span is
306    /// [`PGWIRE_ROOT_SPAN_TARGET`].
307    ///
308    /// This is used to provide context for the (slow) query logs and traces.
309    ///
310    /// The span is only effective if there's a current session and the message is
311    /// query-related. Otherwise, `Span::none()` is returned.
312    fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
313        let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
314            return tracing::Span::none();
315        };
316
317        let mode = match msg {
318            FeMessage::Query(_) => "simple query",
319            FeMessage::Parse(_) => "extended query parse",
320            FeMessage::Execute(_) => "extended query execute",
321            _ => return tracing::Span::none(),
322        };
323
324        tracing::info_span!(
325            target: PGWIRE_ROOT_SPAN_TARGET,
326            "handle_query",
327            mode,
328            session_id,
329            sql = tracing::field::Empty, // record SQL later in each `process` call
330        )
331    }
332
    /// Processes one message, wrapped with the concerns that apply to every message.
    ///
    /// Layers, from innermost to outermost:
    /// 1. `do_process_inner` does the actual work.
    /// 2. If the connection has a session, a `Weak` reference to it is set as
    ///    [`CURRENT_SESSION`] while processing.
    /// 3. Panics are caught and converted into `PsqlError::Panic`.
    /// 4. Each time `SLOW_QUERY_LOG_PERIOD` elapses while processing is still running, a
    ///    "slow query" line is logged to `PGWIRE_SLOW_QUERY_LOG`.
    /// 5. Errors are always logged; "started"/finished lines go to `PGWIRE_QUERY_LOG` when
    ///    inside a span.
    /// 6. Everything runs inside the span returned by `root_span_for_msg`.
    ///
    /// Finally, an error is reported to the client in a way that depends on its kind.
    ///
    /// Return type `Option<()>` is essentially a bool, but allows `?` for early return.
    /// - `None` means to terminate the current connection
    /// - `Some(())` means to continue processing the next message
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        let span = self.root_span_for_msg(&msg);
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        // Processing the message itself.
        //
        // Note: pin the future to avoid stack overflow as we'll wrap it multiple times
        // in the following code.
        let fut = Box::pin(self.do_process_inner(msg));

        // Set the current session as the context when processing the message, if exists.
        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        // Catch unwind.
        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        // Slow query log.
        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            let mut elapsed = Duration::ZERO;

            // Report the SQL in the log periodically if the query is slow.
            // (The SQL itself comes from the enclosing span's `sql` field.)
            loop {
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        // Query log.
        let fut = async move {
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = "started",
                );
            }

            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            // Always log if an error occurs.
            // Note: all messages will be processed through this code path, making it the
            //       only necessary place to log errors.
            if let Err(error) = &result {
                if cfg!(debug_assertions) && !Deployment::current().is_ci() {
                    // For local debugging, we print the error with backtrace.
                    // It's useful only when:
                    // - no additional context is added to the error
                    // - backtrace is captured in the error
                    // - backtrace is not printed in the middle
                    tracing::error!(error = ?error.as_report(), "error when process message");
                } else {
                    tracing::error!(error = %error.as_report(), "error when process message");
                }
            }

            // Log to optionally-enabled target `PGWIRE_QUERY_LOG`.
            // Only log if we're currently in a tracing span set in `root_span_for_msg`.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        // Tracing span.
        let fut = fut.instrument(span);

        // Execute the future and handle the error.
        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    // EOF means the client went away: terminate. Other I/O errors fall through:
                    // nothing is written, a best-effort flush is attempted, and the connection
                    // is kept.
                    PsqlError::IoError(io_err) => {
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                    }

                    PsqlError::SslError(_) => {
                        // For ssl error, because the stream has already been consumed, so there is
                        // no way to write more message.
                        return None;
                    }

                    // Startup/authentication failures: report the error, then close.
                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                        let _ = self.stream.flush().await;
                        return None;
                    }

                    // Simple-query errors and throttling: report the error and tell the client
                    // it may send the next query.
                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                        self.ready_for_query().ok()?;
                    }

                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                        let _ = self.stream.flush().await;

                        // 1. Catching the panic during message processing may leave the session in an
                        // inconsistent state. We forcefully close the connection (then end the
                        // session) here for safety.
                        // 2. Idle in transaction timeout should also close the connection.
                        return None;
                    }

                    // Extended-query and uncategorized errors: only the error is reported;
                    // `ReadyForQuery` is sent later when the client's `Sync` arrives.
                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                    }
                }
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
495
496    async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
497        // Ignore util sync message.
498        if self.ignore_util_sync {
499            if let FeMessage::Sync = msg {
500            } else {
501                tracing::trace!("ignore message {:?} until sync.", msg);
502                return Ok(());
503            }
504        }
505
506        match msg {
507            FeMessage::Gss => self.process_gss_msg().await?,
508            FeMessage::Ssl => self.process_ssl_msg().await?,
509            FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
510            FeMessage::Password(msg) => self.process_password_msg(msg).await?,
511            FeMessage::Query(query_msg) => {
512                let sql = Arc::from(query_msg.get_sql()?);
513                // The process_query_msg can be slow. Release potential large FeQueryMessage early.
514                drop(query_msg);
515                self.process_query_msg(sql).await?
516            }
517            FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
518            FeMessage::Terminate => self.process_terminate(),
519            FeMessage::Parse(m) => {
520                if let Err(err) = self.process_parse_msg(m).await {
521                    self.ignore_util_sync = true;
522                    return Err(err);
523                }
524            }
525            FeMessage::Bind(m) => {
526                if let Err(err) = self.process_bind_msg(m) {
527                    self.ignore_util_sync = true;
528                    return Err(err);
529                }
530            }
531            FeMessage::Execute(m) => {
532                if let Err(err) = self.process_execute_msg(m).await {
533                    self.ignore_util_sync = true;
534                    return Err(err);
535                }
536            }
537            FeMessage::Describe(m) => {
538                if let Err(err) = self.process_describe_msg(m) {
539                    self.ignore_util_sync = true;
540                    return Err(err);
541                }
542            }
543            FeMessage::Sync => {
544                self.ignore_util_sync = false;
545                self.ready_for_query()?
546            }
547            FeMessage::Close(m) => {
548                if let Err(err) = self.process_close_msg(m) {
549                    self.ignore_util_sync = true;
550                    return Err(err);
551                }
552            }
553            FeMessage::Flush => {
554                if let Err(err) = self.stream.flush().await {
555                    self.ignore_util_sync = true;
556                    return Err(err.into());
557                }
558            }
559            FeMessage::HealthCheck => self.process_health_check(),
560            FeMessage::ServerThrottle(reason) => match reason {
561                ServerThrottleReason::TooLargeMessage => {
562                    return Err(PsqlError::ServerThrottle(format!(
563                        "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
564                        self.message_memory_manager.max_filter_bytes
565                    )));
566                }
567                ServerThrottleReason::TooManyMemoryUsage => {
568                    return Err(PsqlError::ServerThrottle(format!(
569                        "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
570                        self.message_memory_manager.max_running_bytes
571                    )));
572                }
573            },
574        }
575        self.stream.flush().await?;
576        Ok(())
577    }
578
579    pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
580        match self.state {
581            PgProtocolState::Startup => self
582                .stream
583                .read_startup()
584                .await
585                .map(|message: FeMessage| (message, None)),
586            PgProtocolState::Regular => {
587                self.stream.read_header().await?;
588                let guard = if let Some(ref header) = self.stream.read_header {
589                    let payload_len = std::cmp::max(header.payload_len, 0) as u64;
590                    let (reason, guard) = self.message_memory_manager.add(payload_len);
591                    if let Some(reason) = reason {
592                        // Release the memory ASAP.
593                        drop(guard);
594                        self.stream.skip_body().await?;
595                        return Ok((FeMessage::ServerThrottle(reason), None));
596                    }
597                    guard
598                } else {
599                    None
600                };
601                let message = self.stream.read_body().await?;
602                Ok((message, guard))
603            }
604        }
605    }
606
607    /// Writes a `ReadyForQuery` message to the client without flushing.
608    fn ready_for_query(&mut self) -> io::Result<()> {
609        self.stream.write_no_flush(BeMessage::ReadyForQuery(
610            self.session
611                .as_ref()
612                .map(|s| s.transaction_status())
613                .unwrap_or(TransactionStatus::Idle),
614        ))
615    }
616
617    async fn process_gss_msg(&mut self) -> PsqlResult<()> {
618        // We don't support GSSAPI, so we just say no gracefully.
619        self.stream.write(BeMessage::EncryptionResponseNo).await?;
620        Ok(())
621    }
622
623    async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
624        if let Some(context) = self.tls_context.as_ref() {
625            // If got and ssl context, say yes for ssl connection.
626            // Construct ssl stream and replace with current one.
627            self.stream.write(BeMessage::EncryptionResponseSsl).await?;
628            self.stream.upgrade_to_ssl(context).await?;
629        } else {
630            // If no, say no for encryption.
631            self.stream.write(BeMessage::EncryptionResponseNo).await?;
632        }
633
634        Ok(())
635    }
636
637    async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
638        // Check SSL enforcement: if SSL is enforced but connection is not using SSL, reject
639        if let Some(ref tls_config) = self.tls_config
640            && tls_config.enforce_ssl
641            && !self.stream.is_ssl_connection().await
642        {
643            return Err(PsqlError::StartupError(
644                "SSL connection is required but not established".into(),
645            ));
646        }
647
648        let db_name = msg
649            .config
650            .get("database")
651            .cloned()
652            .unwrap_or_else(|| "dev".to_owned());
653        let user_name = msg
654            .config
655            .get("user")
656            .cloned()
657            .unwrap_or_else(|| "root".to_owned());
658
659        let session = self
660            .session_mgr
661            .connect(&db_name, &user_name, self.peer_addr.clone())
662            .map_err(PsqlError::StartupError)?;
663
664        let application_name = msg.config.get("application_name");
665        if let Some(application_name) = application_name {
666            session
667                .set_config("application_name", application_name.clone())
668                .map_err(PsqlError::StartupError)?;
669        }
670
671        match session.user_authenticator() {
672            UserAuthenticator::None => {
673                self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
674
675                // Cancel request need this for identify and verification. According to postgres
676                // doc, it should be written to buffer after receive AuthenticationOk.
677                self.stream
678                    .write_no_flush(BeMessage::BackendKeyData(session.id()))?;
679
680                self.stream.write_no_flush(BeMessage::ParameterStatus(
681                    BeParameterStatusMessage::TimeZone(&session.get_config("timezone")?),
682                ))?;
683                self.stream
684                    .write_parameter_status_msg_no_flush(&ParameterStatus {
685                        application_name: application_name.cloned(),
686                    })?;
687                self.ready_for_query()?;
688            }
689            UserAuthenticator::ClearText(_)
690            | UserAuthenticator::OAuth { .. }
691            | UserAuthenticator::Ldap(..) => {
692                self.stream
693                    .write_no_flush(BeMessage::AuthenticationCleartextPassword)?;
694            }
695            UserAuthenticator::Md5WithSalt { salt, .. } => {
696                self.stream
697                    .write_no_flush(BeMessage::AuthenticationMd5Password(salt))?;
698            }
699        }
700
701        self.session = Some(session);
702        self.state = PgProtocolState::Regular;
703        Ok(())
704    }
705
706    async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
707        let session = self.session.as_ref().unwrap();
708        let authenticator = session.user_authenticator();
709        authenticator.authenticate(&msg.password).await?;
710        self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
711        self.stream.write_no_flush(BeMessage::ParameterStatus(
712            BeParameterStatusMessage::TimeZone(&session.get_config("timezone")?),
713        ))?;
714        self.stream
715            .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
716        self.ready_for_query()?;
717        self.state = PgProtocolState::Regular;
718        Ok(())
719    }
720
721    fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
722        let session_id = (m.target_process_id, m.target_secret_key);
723        tracing::trace!("cancel query in session: {:?}", session_id);
724        self.session_mgr.cancel_queries_in_session(session_id);
725        self.session_mgr.cancel_creating_jobs_in_session(session_id);
726        self.is_terminate = true;
727        Ok(())
728    }
729
730    async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
731        let truncated_sql = record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
732        let session = self.session.clone().unwrap();
733
734        session.check_idle_in_transaction_timeout()?;
735        // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
736        let _exec_context_guard = session.init_exec_context(truncated_sql.into());
737        self.inner_process_query_msg(sql, session.clone()).await
738    }
739
740    async fn inner_process_query_msg(
741        &mut self,
742        sql: Arc<str>,
743        session: Arc<SM::Session>,
744    ) -> PsqlResult<()> {
745        // Parse sql.
746        let stmts =
747            Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
748        // The following inner_process_query_msg_one_stmt can be slow. Release potential large String early.
749        drop(sql);
750        if stmts.is_empty() {
751            self.stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
752        }
753
754        // Execute multiple statements in simple query. KISS later.
755        for stmt in stmts {
756            self.inner_process_query_msg_one_stmt(stmt, session.clone())
757                .await?;
758        }
759        // Put this line inside the for loop above will lead to unfinished/stuck regress test...Not
760        // sure the reason.
761        self.ready_for_query()?;
762        Ok(())
763    }
764
765    async fn inner_process_query_msg_one_stmt(
766        &mut self,
767        stmt: Statement,
768        session: Arc<SM::Session>,
769    ) -> PsqlResult<()> {
770        let session = session.clone();
771
772        // execute query
773        let res = session.clone().run_one_query(stmt, Format::Text).await;
774
775        // Take all remaining notices (if any) and send them before `CommandComplete`.
776        while let Some(notice) = session.next_notice().now_or_never() {
777            self.stream
778                .write_no_flush(BeMessage::NoticeResponse(&notice))?;
779        }
780
781        let mut res = res.map_err(PsqlError::SimpleQueryError)?;
782
783        for notice in res.notices() {
784            self.stream
785                .write_no_flush(BeMessage::NoticeResponse(notice))?;
786        }
787
788        let status = res.status();
789        if let Some(ref application_name) = status.application_name {
790            self.stream.write_no_flush(BeMessage::ParameterStatus(
791                BeParameterStatusMessage::ApplicationName(application_name),
792            ))?;
793        }
794
795        if res.is_copy_query_to_stdout() {
796            self.stream
797                .write_no_flush(BeMessage::CopyOutResponse(res.row_desc().len()))?;
798            let mut count = 0;
799            while let Some(row_set) = res.values_stream().next().await {
800                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
801                for row in row_set {
802                    self.stream.write_no_flush(BeMessage::CopyData(&row))?;
803                    count += 1;
804                }
805            }
806
807            self.stream.write_no_flush(BeMessage::CopyDone)?;
808
809            // Run the callback before sending the `CommandComplete` message.
810            res.run_callback().await?;
811
812            self.stream
813                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
814                    stmt_type: res.stmt_type(),
815                    rows_cnt: count,
816                }))?;
817        } else if res.is_query() {
818            self.stream
819                .write_no_flush(BeMessage::RowDescription(res.row_desc()))?;
820
821            let mut rows_cnt = 0;
822
823            while let Some(row_set) = res.values_stream().next().await {
824                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
825                for row in row_set {
826                    self.stream.write_no_flush(BeMessage::DataRow(&row))?;
827                    rows_cnt += 1;
828                }
829            }
830
831            // Run the callback before sending the `CommandComplete` message.
832            res.run_callback().await?;
833
834            self.stream
835                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
836                    stmt_type: res.stmt_type(),
837                    rows_cnt,
838                }))?;
839        } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
840            let first_row_set = res.values_stream().next().await;
841            let first_row_set = match first_row_set {
842                None => {
843                    return Err(PsqlError::Uncategorized(
844                        anyhow::anyhow!("no affected rows in output").into(),
845                    ));
846                }
847                Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
848            };
849            let affected_rows_str = first_row_set[0].values()[0]
850                .as_ref()
851                .expect("compute node should return affected rows in output");
852
853            assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
854            let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
855                .unwrap()
856                .parse()
857                .unwrap_or_default();
858
859            // Run the callback before sending the `CommandComplete` message.
860            res.run_callback().await?;
861
862            self.stream
863                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
864                    stmt_type: res.stmt_type(),
865                    rows_cnt: affected_rows_cnt,
866                }))?;
867        } else {
868            // Run the callback before sending the `CommandComplete` message.
869            res.run_callback().await?;
870
871            self.stream
872                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
873                    stmt_type: res.stmt_type(),
874                    rows_cnt: 0,
875                }))?;
876        }
877
878        Ok(())
879    }
880
881    fn process_terminate(&mut self) {
882        self.is_terminate = true;
883    }
884
885    fn process_health_check(&mut self) {
886        tracing::debug!("health check");
887        self.is_terminate = true;
888    }
889
890    async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
891        let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
892        record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
893        let session = self.session.clone().unwrap();
894        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
895        let type_ids = std::mem::take(&mut msg.type_ids);
896        // The inner_process_parse_msg can be slow. Release potential large FeParseMessage early.
897        drop(msg);
898        self.inner_process_parse_msg(session, sql, statement_name, type_ids)
899            .await?;
900        Ok(())
901    }
902
903    async fn inner_process_parse_msg(
904        &mut self,
905        session: Arc<SM::Session>,
906        sql: Arc<str>,
907        statement_name: String,
908        type_ids: Vec<i32>,
909    ) -> PsqlResult<()> {
910        if statement_name.is_empty() {
911            // Remove the unnamed prepare statement first, in case the unsupported sql binds a wrong
912            // prepare statement.
913            self.unnamed_prepare_statement.take();
914        } else if self.prepare_statement_store.contains_key(&statement_name) {
915            return Err(PsqlError::ExtendedPrepareError(
916                "Duplicated statement name".into(),
917            ));
918        }
919
920        let stmt = {
921            let stmts = Parser::parse_sql(&sql)
922                .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
923            drop(sql);
924            if stmts.len() > 1 {
925                return Err(PsqlError::ExtendedPrepareError(
926                    "Only one statement is allowed in extended query mode".into(),
927                ));
928            }
929
930            stmts.into_iter().next()
931        };
932
933        let param_types: Vec<Option<DataType>> = type_ids
934            .iter()
935            .map(|&id| {
936                // 0 means unspecified type
937                // ref: https://www.postgresql.org/docs/15/protocol-message-formats.html#:~:text=Placing%20a%20zero%20here%20is%20equivalent%20to%20leaving%20the%20type%20unspecified.
938                if id == 0 {
939                    Ok(None)
940                } else {
941                    DataType::from_oid(id)
942                        .map(Some)
943                        .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
944                }
945            })
946            .try_collect()?;
947
948        let prepare_statement = session
949            .parse(stmt, param_types)
950            .await
951            .map_err(PsqlError::ExtendedPrepareError)?;
952
953        if statement_name.is_empty() {
954            self.unnamed_prepare_statement.replace(prepare_statement);
955        } else {
956            self.prepare_statement_store
957                .insert(statement_name.clone(), prepare_statement);
958        }
959
960        self.statement_portal_dependency
961            .entry(statement_name)
962            .or_default()
963            .clear();
964
965        self.stream.write_no_flush(BeMessage::ParseComplete)?;
966        Ok(())
967    }
968
969    fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
970        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
971        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
972        let session = self.session.clone().unwrap();
973
974        if self.portal_store.contains_key(&portal_name) {
975            return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
976        }
977
978        let prepare_statement = self.get_statement(&statement_name)?;
979
980        let result_formats = msg
981            .result_format_codes
982            .iter()
983            .map(|&format_code| Format::from_i16(format_code))
984            .try_collect()?;
985        let param_formats = msg
986            .param_format_codes
987            .iter()
988            .map(|&format_code| Format::from_i16(format_code))
989            .try_collect()?;
990
991        let portal = session
992            .bind(prepare_statement, msg.params, param_formats, result_formats)
993            .map_err(PsqlError::Uncategorized)?;
994
995        if portal_name.is_empty() {
996            self.result_cache.remove(&portal_name);
997            self.unnamed_portal.replace(portal);
998        } else {
999            assert!(
1000                !self.result_cache.contains_key(&portal_name),
1001                "Named portal never can be overridden."
1002            );
1003            self.portal_store.insert(portal_name.clone(), portal);
1004        }
1005
1006        self.statement_portal_dependency
1007            .get_mut(&statement_name)
1008            .unwrap()
1009            .push(portal_name);
1010
1011        self.stream.write_no_flush(BeMessage::BindComplete)?;
1012        Ok(())
1013    }
1014
1015    async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
1016        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1017        let row_max = msg.max_rows as usize;
1018        drop(msg);
1019        let session = self.session.clone().unwrap();
1020
1021        match self.result_cache.remove(&portal_name) {
1022            Some(mut result_cache) => {
1023                assert!(self.portal_store.contains_key(&portal_name));
1024
1025                let is_consume_completed =
1026                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
1027
1028                if !is_consume_completed {
1029                    self.result_cache.insert(portal_name, result_cache);
1030                }
1031            }
1032            _ => {
1033                let portal = self.get_portal(&portal_name)?;
1034                let sql = format!("{}", portal);
1035                let truncated_sql =
1036                    record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
1037                drop(sql);
1038
1039                session.check_idle_in_transaction_timeout()?;
1040                // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
1041                let _exec_context_guard = session.init_exec_context(truncated_sql.into());
1042                let result = session.clone().execute(portal).await;
1043
1044                let pg_response = result.map_err(PsqlError::ExtendedExecuteError)?;
1045                let mut result_cache = ResultCache::new(pg_response);
1046                let is_consume_completed =
1047                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
1048                if !is_consume_completed {
1049                    self.result_cache.insert(portal_name, result_cache);
1050                }
1051            }
1052        }
1053
1054        Ok(())
1055    }
1056
1057    fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1058        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1059        let session = self.session.clone().unwrap();
1060        //  b'S' => Statement
1061        //  b'P' => Portal
1062
1063        assert!(msg.kind == b'S' || msg.kind == b'P');
1064        if msg.kind == b'S' {
1065            let prepare_statement = self.get_statement(&name)?;
1066
1067            let (param_types, row_descriptions) = self
1068                .session
1069                .clone()
1070                .unwrap()
1071                .describe_statement(prepare_statement)
1072                .map_err(PsqlError::Uncategorized)?;
1073            self.stream.write_no_flush(BeMessage::ParameterDescription(
1074                &param_types.iter().map(|t| t.to_oid()).collect_vec(),
1075            ))?;
1076
1077            if row_descriptions.is_empty() {
1078                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1079                // return NoData message if the statement is not a query.
1080                self.stream.write_no_flush(BeMessage::NoData)?;
1081            } else {
1082                self.stream
1083                    .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1084            }
1085        } else if msg.kind == b'P' {
1086            let portal = self.get_portal(&name)?;
1087
1088            let row_descriptions = session
1089                .describe_portal(portal)
1090                .map_err(PsqlError::Uncategorized)?;
1091
1092            if row_descriptions.is_empty() {
1093                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1094                // return NoData message if the statement is not a query.
1095                self.stream.write_no_flush(BeMessage::NoData)?;
1096            } else {
1097                self.stream
1098                    .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1099            }
1100        }
1101        Ok(())
1102    }
1103
1104    fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1105        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1106        assert!(msg.kind == b'S' || msg.kind == b'P');
1107        if msg.kind == b'S' {
1108            if name.is_empty() {
1109                self.unnamed_prepare_statement = None;
1110            } else {
1111                self.prepare_statement_store.remove(&name);
1112            }
1113            for portal_name in self
1114                .statement_portal_dependency
1115                .remove(&name)
1116                .unwrap_or_default()
1117            {
1118                self.remove_portal(&portal_name);
1119            }
1120        } else if msg.kind == b'P' {
1121            self.remove_portal(&name);
1122        }
1123        self.stream.write_no_flush(BeMessage::CloseComplete)?;
1124        Ok(())
1125    }
1126
1127    fn remove_portal(&mut self, portal_name: &str) {
1128        if portal_name.is_empty() {
1129            self.unnamed_portal = None;
1130        } else {
1131            self.portal_store.remove(portal_name);
1132        }
1133        self.result_cache.remove(portal_name);
1134    }
1135
1136    fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1137        if portal_name.is_empty() {
1138            Ok(self
1139                .unnamed_portal
1140                .as_ref()
1141                .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1142                .clone())
1143        } else {
1144            Ok(self
1145                .portal_store
1146                .get(portal_name)
1147                .ok_or_else(|| {
1148                    PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1149                })?
1150                .clone())
1151        }
1152    }
1153
1154    fn get_statement(
1155        &self,
1156        statement_name: &str,
1157    ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1158        if statement_name.is_empty() {
1159            Ok(self
1160                .unnamed_prepare_statement
1161                .as_ref()
1162                .ok_or_else(|| {
1163                    PsqlError::Uncategorized("unnamed prepare statement not found".into())
1164                })?
1165                .clone())
1166        } else {
1167            Ok(self
1168                .prepare_statement_store
1169                .get(statement_name)
1170                .ok_or_else(|| {
1171                    PsqlError::Uncategorized(
1172                        format!("Prepare statement {} not found", statement_name).into(),
1173                    )
1174                })?
1175                .clone())
1176        }
1177    }
1178}
1179
/// The transport currently backing a [`PgStream`]: plain, or upgraded to SSL in place.
enum PgStreamInner<S> {
    /// Used for the intermediate state when converting from unencrypted to ssl stream.
    ///
    /// `upgrade_to_ssl` swaps this in while it owns the plain stream. Some of its error paths
    /// return without restoring a real stream, so this state can persist afterwards. Every
    /// read/write method treats encountering it as `unreachable!()`.
    Placeholder,
    /// An unencrypted stream.
    Unencrypted(S),
    /// An ssl stream.
    Ssl(SslStream<S>),
}
1188
1189/// Trait for a byte stream that can be used for pg protocol.
1190pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1191impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1192
/// Wraps a byte stream and reads/writes pg messages.
///
/// Cloning a `PgStream` shares the same underlying stream but creates a fresh, independent
/// write buffer, so that clones can write messages concurrently without interference.
pub struct PgStream<S> {
    /// The underlying stream, shared by all clones behind an async (`tokio`) mutex.
    stream: Arc<Mutex<PgStreamInner<S>>>,
    /// Write into buffer before flush to stream; `flush` sends its contents and clears it.
    write_buf: BytesMut,
    /// Header stored by `read_header`, taken by the next `read_body` or `skip_body`.
    read_header: Option<FeMessageHeader>,
}
1204
1205impl<S> PgStream<S> {
1206    /// Create a new `PgStream` with the given stream and default write buffer capacity.
1207    pub fn new(stream: S) -> Self {
1208        const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1209
1210        Self {
1211            stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1212            write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1213            read_header: None,
1214        }
1215    }
1216
1217    /// Check if the current connection is using SSL
1218    async fn is_ssl_connection(&self) -> bool {
1219        let stream = self.stream.lock().await;
1220        matches!(*stream, PgStreamInner::Ssl(_))
1221    }
1222}
1223
1224impl<S> Clone for PgStream<S> {
1225    fn clone(&self) -> Self {
1226        Self {
1227            stream: Arc::clone(&self.stream),
1228            write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1229            read_header: self.read_header.clone(),
1230        }
1231    }
1232}
1233
/// At present there is a hard-wired set of parameters for which
/// ParameterStatus will be generated: they are:
///
///  * `server_version`
///  * `server_encoding`
///  * `client_encoding`
///  * `application_name`
///  * `is_superuser`
///  * `session_authorization`
///  * `DateStyle`
///  * `IntervalStyle`
///  * `TimeZone`
///  * `integer_datetimes`
///  * `standard_conforming_strings`
///
/// See: <https://www.postgresql.org/docs/9.2/static/protocol-flow.html#PROTOCOL-ASYNC>.
///
/// Of these, this struct carries only `application_name`.
/// `PgStream::write_parameter_status_msg_no_flush` sends fixed values for `client_encoding`,
/// `standard_conforming_strings` and `server_version`, plus `application_name` when it is set.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// Value to report as the `application_name` parameter status, if any.
    pub application_name: Option<String>,
}
1254
1255impl<S> PgStream<S>
1256where
1257    S: PgByteStream,
1258{
1259    async fn read_startup(&mut self) -> io::Result<FeMessage> {
1260        let mut stream = self.stream.lock().await;
1261        match &mut *stream {
1262            PgStreamInner::Placeholder => unreachable!(),
1263            PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1264            PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1265        }
1266    }
1267
1268    async fn read_header(&mut self) -> io::Result<()> {
1269        let mut stream = self.stream.lock().await;
1270        match &mut *stream {
1271            PgStreamInner::Placeholder => unreachable!(),
1272            PgStreamInner::Unencrypted(stream) => {
1273                self.read_header = Some(FeMessage::read_header(stream).await?);
1274                Ok(())
1275            }
1276            PgStreamInner::Ssl(ssl_stream) => {
1277                self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1278                Ok(())
1279            }
1280        }
1281    }
1282
1283    async fn read_body(&mut self) -> io::Result<FeMessage> {
1284        let mut stream = self.stream.lock().await;
1285        let header = self
1286            .read_header
1287            .take()
1288            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1289        match &mut *stream {
1290            PgStreamInner::Placeholder => unreachable!(),
1291            PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1292            PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1293        }
1294    }
1295
1296    async fn skip_body(&mut self) -> io::Result<()> {
1297        let mut stream = self.stream.lock().await;
1298        let header = self
1299            .read_header
1300            .take()
1301            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1302        match &mut *stream {
1303            PgStreamInner::Placeholder => unreachable!(),
1304            PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1305            PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1306        }
1307    }
1308
1309    fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1310        self.write_no_flush(BeMessage::ParameterStatus(
1311            BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1312        ))?;
1313        self.write_no_flush(BeMessage::ParameterStatus(
1314            BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1315        ))?;
1316        self.write_no_flush(BeMessage::ParameterStatus(
1317            BeParameterStatusMessage::ServerVersion(PG_VERSION),
1318        ))?;
1319        if let Some(application_name) = &status.application_name {
1320            self.write_no_flush(BeMessage::ParameterStatus(
1321                BeParameterStatusMessage::ApplicationName(application_name),
1322            ))?;
1323        }
1324        Ok(())
1325    }
1326
1327    pub fn write_no_flush(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1328        BeMessage::write(&mut self.write_buf, message)
1329    }
1330
1331    async fn write(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1332        self.write_no_flush(message)?;
1333        self.flush().await?;
1334        Ok(())
1335    }
1336
1337    async fn flush(&mut self) -> io::Result<()> {
1338        let mut stream = self.stream.lock().await;
1339        match &mut *stream {
1340            PgStreamInner::Placeholder => unreachable!(),
1341            PgStreamInner::Unencrypted(stream) => {
1342                stream.write_all(&self.write_buf).await?;
1343                stream.flush().await?;
1344            }
1345            PgStreamInner::Ssl(ssl_stream) => {
1346                ssl_stream.write_all(&self.write_buf).await?;
1347                ssl_stream.flush().await?;
1348            }
1349        }
1350        self.write_buf.clear();
1351        Ok(())
1352    }
1353}
1354
1355impl<S> PgStream<S>
1356where
1357    S: PgByteStream,
1358{
1359    /// Convert the underlying stream to ssl stream based on the given context.
1360    async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1361        let mut stream = self.stream.lock().await;
1362
1363        match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1364            PgStreamInner::Unencrypted(unencrypted_stream) => {
1365                let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1366                let mut ssl_stream =
1367                    tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1368
1369                if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1370                    tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1371                    let _ = ssl_stream.shutdown().await;
1372                    return Err(e.into());
1373                }
1374
1375                *stream = PgStreamInner::Ssl(ssl_stream);
1376            }
1377            PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1378            PgStreamInner::Placeholder => unreachable!(),
1379        }
1380
1381        Ok(())
1382    }
1383}
1384
1385fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1386    let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1387
1388    let key_path = &tls_config.key;
1389    let cert_path = &tls_config.cert;
1390
1391    // Build ssl acceptor according to the config.
1392    // Now we set every verify to true.
1393    acceptor
1394        .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1395        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1396    acceptor
1397        .set_ca_file(cert_path)
1398        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1399    acceptor
1400        .set_certificate_chain_file(cert_path)
1401        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1402    let acceptor = acceptor.build();
1403
1404    Ok(acceptor.into_context())
1405}
1406
pub mod truncated_fmt {
    use std::fmt::*;

    /// Largest char boundary of `s` that is `<= index` (callers pass `index < s.len()`).
    ///
    /// A stable stand-in for `str::floor_char_boundary`, which is behind a feature gate on
    /// older toolchains.
    fn floor_char_boundary(s: &str, index: usize) -> usize {
        (0..=index)
            .rev()
            .find(|&i| s.is_char_boundary(i))
            .unwrap_or(0)
    }

    /// `Write` adapter that forwards at most `remaining` bytes to `f`, cut back to a char
    /// boundary. It counts every byte it is offered, so the length of the complete formatted
    /// value can be reported even when `T`'s formatter writes in several chunks.
    struct TruncatedFormatter<'a, 'b> {
        /// Output budget still available, in bytes.
        remaining: usize,
        /// Bytes offered so far, including the dropped ones.
        total: usize,
        /// Set once the budget is exceeded; nothing more is forwarded after that.
        truncated: bool,
        f: &'a mut Formatter<'b>,
    }

    impl<'a, 'b> TruncatedFormatter<'a, 'b> {
        fn new(f: &'a mut Formatter<'b>, limit: usize) -> Self {
            Self {
                remaining: limit,
                total: 0,
                truncated: false,
                f,
            }
        }

        /// Appends `...(truncated,N)` exactly once if output was cut, where `N` is the byte
        /// length of the whole formatted value.
        fn finish(&mut self) -> Result {
            if self.truncated {
                write!(self.f, "...(truncated,{})", self.total)
            } else {
                Ok(())
            }
        }
    }

    impl Write for TruncatedFormatter<'_, '_> {
        fn write_str(&mut self, s: &str) -> Result {
            self.total += s.len();
            if self.truncated {
                return Ok(());
            }

            if s.len() <= self.remaining {
                self.f.write_str(s)?;
                self.remaining -= s.len();
            } else {
                let end = floor_char_boundary(s, self.remaining);
                self.f.write_str(&s[..end])?;
                self.remaining -= end;
                self.truncated = true;
            }
            Ok(())
        }
    }

    /// Formats `self.0` (via `Debug` or `Display`) but emits at most `self.1` bytes, followed
    /// by `...(truncated,N)` when the output had to be cut.
    pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);

    impl<T> Debug for TruncatedFmt<'_, T>
    where
        T: Debug,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            let mut out = TruncatedFormatter::new(f, self.1);
            out.write_fmt(format_args!("{:?}", self.0))?;
            out.finish()
        }
    }

    impl<T> Display for TruncatedFmt<'_, T>
    where
        T: Display,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            let mut out = TruncatedFormatter::new(f, self.1);
            out.write_fmt(format_args!("{}", self.0))?;
            out.finish()
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_trunc_utf8() {
            assert_eq!(
                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
                "select '...(truncated,14)",
            );
        }

        #[test]
        fn test_trunc_reports_total_len() {
            // `Debug` for `str` may write in several chunks; the marker must still report the
            // length of the whole formatted value.
            assert_eq!(
                format!("{:?}", TruncatedFmt(&"ab\"cd", 3)),
                "\"ab...(truncated,8)",
            );
            assert_eq!(format!("{}", TruncatedFmt(&"short", 5)), "short");
        }
    }
}
1478
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Values of options whose keys are in the redaction list become `[REDACTED]`, both in the
    /// `WITH (...)` clause and in the `ENCODE ... (...)` options. Other values are kept as-is.
    #[test]
    fn test_redact_parsable_sql() {
        let sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        let keywords = Arc::new(HashSet::from(["v2".into(), "v4".into(), "b".into()]));
        let expected = "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])";

        let redacted = redact_sql(sql, keywords);
        assert_eq!(redacted, expected);
    }
}