pgwire/
pg_protocol.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::env_var::env_var_is_true;
33use risingwave_common::util::panic::FutureCatchUnwindExt;
34use risingwave_common::util::query_log::*;
35use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
36use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
37use risingwave_sqlparser::parser::Parser;
38use thiserror_ext::AsReport;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
40use tokio::sync::Mutex;
41use tokio_openssl::SslStream;
42use tracing::Instrument;
43
44use crate::error::{PsqlError, PsqlResult};
45use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
46use crate::net::AddressRef;
47use crate::pg_extended::ResultCache;
48use crate::pg_message::{
49    BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
50    FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
51    FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
52};
53use crate::pg_server::{Session, SessionManager, UserAuthenticator};
54use crate::types::Format;
55
/// Length limit applied to SQL text in query logs; longer SQL is truncated to avoid the log
/// file being too large.
///
/// The value is read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable. When the
/// variable is unset or does not parse as a `usize`, the limit falls back to 65536 in builds
/// with debug assertions and 1024 otherwise.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        // Parse exactly once; an unparsable value is treated the same as an unset variable.
        .and_then(|len| len.parse::<usize>().ok())
        .unwrap_or(if cfg!(debug_assertions) { 65536 } else { 1024 })
});
69
tokio::task_local! {
    /// The current session. Concrete type is erased for different session implementations.
    ///
    /// `PgProtocol::do_process` scopes this value around the processing of each message when
    /// a session exists. It is a [`Weak`] reference, so the task-local itself does not keep
    /// the session alive.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
74
/// The state machine for each psql connection.
/// Read pg messages from tcp stream and write results back.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Used for write/read pg messages.
    stream: PgStream<S>,
    /// Current state of the pg connection; selects how the next message is read.
    state: PgProtocolState,
    /// Whether the connection is terminated. Checked after every processed message.
    is_terminate: bool,

    /// Manager used to open the session, cancel queries, and end the session on drop.
    session_mgr: Arc<SM>,
    /// The session of this connection; `None` until the startup message has been handled.
    session: Option<Arc<SM::Session>>,

    // State of the extended query protocol: cached results, prepared statements and portals.
    // The unnamed statement/portal are kept separately from the named ones.
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    // Used to store the dependency of portal and prepare statement.
    // When we close a prepare statement, we need to close all the portals that depend on it.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    // Used for ssl connection.
    // If `None`, an SSL request from the client is answered with `EncryptionResponseNo`
    // and the connection stays unencrypted.
    tls_context: Option<SslContext>,

    // TLS configuration, including the SSL enforcement setting checked at startup.
    tls_config: Option<TlsConfig>,

    // Used in extended query protocol. When an error is encountered in an extended query, we
    // need to ignore the following messages until the next sync message.
    ignore_util_sync: bool,

    // Client address.
    peer_addr: AddressRef,

    // Keywords used to redact SQL options before the SQL is recorded in the tracing span.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    // Accounts for the payload size of incoming messages and may throttle them.
    message_memory_manager: MessageMemoryManagerRef,
}
117
/// Configures TLS encryption for connections.
///
/// `TlsConfig::new_default` builds this from the `RW_SSL_CERT`, `RW_SSL_KEY` and
/// `RW_SSL_ENFORCE` environment variables.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// The path to the TLS certificate.
    pub cert: String,
    /// The path to the TLS key.
    pub key: String,
    /// Whether to enforce SSL connections (reject non-SSL clients).
    pub enforce_ssl: bool,
}
128
129impl TlsConfig {
130    pub fn new_default() -> Option<Self> {
131        let cert = std::env::var("RW_SSL_CERT").ok()?;
132        let key = std::env::var("RW_SSL_KEY").ok()?;
133        let enforce_ssl = env_var_is_true("RW_SSL_ENFORCE");
134        tracing::info!(
135            "RW_SSL_CERT={}, RW_SSL_KEY={}, RW_SSL_ENFORCE={}",
136            cert,
137            key,
138            enforce_ssl
139        );
140        Some(Self {
141            cert,
142            key,
143            enforce_ssl,
144        })
145    }
146}
147
148impl<S, SM> Drop for PgProtocol<S, SM>
149where
150    SM: SessionManager,
151{
152    fn drop(&mut self) {
153        if let Some(session) = &self.session {
154            // Clear the session in session manager.
155            self.session_mgr.end_session(session);
156        }
157    }
158}
159
/// Connection states. A connection starts in `Startup` and moves to `Regular`; the flow goes
/// from top to bottom.
enum PgProtocolState {
    /// Initial state; the next message is read with `read_startup`.
    Startup,
    /// Entered once the startup message has been handled; messages are read as a header
    /// followed by a body.
    Regular,
}
165
166/// Truncate 0 from C string in Bytes and stringify it (returns slice, no allocations).
167///
168/// PG protocol strings are always C strings.
169pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
170    let without_null = if b.last() == Some(&0) {
171        &b[..b.len() - 1]
172    } else {
173        &b[..]
174    };
175    std::str::from_utf8(without_null)
176}
177
178/// Record `sql` in the current tracing span.
179fn record_sql_in_span(
180    sql: &str,
181    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
182) -> String {
183    let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
184        && !keywords.is_empty()
185    {
186        redact_sql(sql, keywords)
187    } else {
188        sql.to_owned()
189    };
190    let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
191    tracing::Span::current().record("sql", tracing::field::display(&truncated));
192    truncated.to_string()
193}
194
195/// Redacts SQL options. Data in DML is not redacted.
196fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
197    match Parser::parse_sql(sql) {
198        Ok(sqls) => sqls
199            .into_iter()
200            .map(|sql| sql.to_redacted_string(keywords.clone()))
201            .join(";"),
202        Err(_) => sql.to_owned(),
203    }
204}
205
/// Per-connection settings handed to `PgProtocol::new`.
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS settings; when `None`, no SSL context is built for the connection.
    pub tls_config: Option<TlsConfig>,
    /// Keywords used to redact SQL options before the SQL is recorded in tracing spans.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts for the payload size of incoming messages and may throttle them.
    pub message_memory_manager: MessageMemoryManagerRef,
}
212
213impl<S, SM> PgProtocol<S, SM>
214where
215    S: PgByteStream,
216    SM: SessionManager,
217{
218    pub fn new(
219        stream: S,
220        session_mgr: Arc<SM>,
221        peer_addr: AddressRef,
222        context: ConnectionContext,
223    ) -> Self {
224        let ConnectionContext {
225            tls_config,
226            redact_sql_option_keywords,
227            message_memory_manager,
228        } = context;
229        Self {
230            stream: PgStream::new(stream),
231            is_terminate: false,
232            state: PgProtocolState::Startup,
233            session_mgr,
234            session: None,
235            tls_context: tls_config
236                .as_ref()
237                .and_then(|e| build_ssl_ctx_from_config(e).ok()),
238            tls_config: tls_config.clone(),
239            result_cache: Default::default(),
240            unnamed_prepare_statement: Default::default(),
241            prepare_statement_store: Default::default(),
242            unnamed_portal: Default::default(),
243            portal_store: Default::default(),
244            statement_portal_dependency: Default::default(),
245            ignore_util_sync: false,
246            peer_addr,
247            redact_sql_option_keywords,
248            message_memory_manager,
249        }
250    }
251
    /// Run the protocol to serve the connection.
    ///
    /// Repeatedly reads one frontend message and processes it, until either reading fails or
    /// processing reports that the connection is terminated. Once a session exists, notices
    /// from the session are forwarded to the client concurrently with message processing.
    pub async fn run(&mut self) {
        // Created on the first iteration that sees a session, then reused across iterations.
        let mut notice_fut = None;

        loop {
            // Once a session is present, create a future to subscribe and send notices asynchronously.
            if notice_fut.is_none()
                && let Some(session) = self.session.clone()
            {
                let mut stream = self.stream.clone();
                notice_fut = Some(Box::pin(async move {
                    loop {
                        let notice = session.next_notice().await;
                        // A failed write is only logged; the loop keeps waiting for the next
                        // notice.
                        if let Err(e) = stream.write(BeMessage::NoticeResponse(&notice)).await {
                            tracing::error!(error = %e.as_report(), notice, "failed to send notice");
                        }
                    }
                }));
            }

            // Read and process messages.
            let process = std::pin::pin!(async {
                // `_memory_guard` stays bound until this block finishes, i.e. for the whole
                // processing of the message.
                let (msg, _memory_guard) = match self.read_message().await {
                    Ok(msg) => msg,
                    Err(e) => {
                        tracing::error!(error = %e.as_report(), "error when reading message");
                        return true; // terminate the connection
                    }
                };
                tracing::trace!(?msg, "received message");
                self.process(msg).await
            });

            let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
                tokio::select! {
                    // The notice future is an infinite loop, so it never completes.
                    _ = notice_fut => unreachable!(),
                    terminated = process => terminated,
                }
            } else {
                process.await
            };

            if terminated {
                break;
            }
        }
    }
299
300    /// Processes one message. Returns true if the connection is terminated.
301    pub async fn process(&mut self, msg: FeMessage) -> bool {
302        self.do_process(msg).await.is_none() || self.is_terminate
303    }
304
305    /// The root tracing span for processing a message. The target of the span is
306    /// [`PGWIRE_ROOT_SPAN_TARGET`].
307    ///
308    /// This is used to provide context for the (slow) query logs and traces.
309    ///
310    /// The span is only effective if there's a current session and the message is
311    /// query-related. Otherwise, `Span::none()` is returned.
312    fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
313        let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
314            return tracing::Span::none();
315        };
316
317        let mode = match msg {
318            FeMessage::Query(_) => "simple query",
319            FeMessage::Parse(_) => "extended query parse",
320            FeMessage::Execute(_) => "extended query execute",
321            _ => return tracing::Span::none(),
322        };
323
324        tracing::info_span!(
325            target: PGWIRE_ROOT_SPAN_TARGET,
326            "handle_query",
327            mode,
328            session_id,
329            sql = tracing::field::Empty, // record SQL later in each `process` call
330        )
331    }
332
    /// Processes one message with all cross-cutting concerns applied, then reports the
    /// outcome to the client.
    ///
    /// The actual work is done by `do_process_inner`. That future is wrapped, from the inside
    /// out, with: the `CURRENT_SESSION` task-local scope, panic catching, periodic slow-query
    /// logging, query/error logging, and the root tracing span.
    ///
    /// Return type `Option<()>` is essentially a bool, but allows `?` for early return.
    /// - `None` means to terminate the current connection
    /// - `Some(())` means to continue processing the next message
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        let span = self.root_span_for_msg(&msg);
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        // Processing the message itself.
        //
        // Note: box and pin the future to avoid stack overflow as we'll wrap it multiple times
        // in the following code.
        let fut = Box::pin(self.do_process_inner(msg));

        // Set the current session as the context when processing the message, if exists.
        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        // Catch unwind. A caught panic is turned into `PsqlError::Panic` with its message.
        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        // Slow query log.
        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            let mut elapsed = Duration::ZERO;

            // Report the SQL in the log periodically if the query is slow.
            // Each timeout only emits a log line; the same future keeps being polled.
            loop {
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        // Query log.
        let fut = async move {
            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            // Always log if an error occurs.
            // Note: all messages will be processed through this code path, making it the
            //       only necessary place to log errors.
            if let Err(error) = &result {
                if cfg!(debug_assertions) && !Deployment::current().is_ci() {
                    // For local debugging, we print the error with backtrace.
                    // It's useful only when:
                    // - no additional context is added to the error
                    // - backtrace is captured in the error
                    // - backtrace is not printed in the middle
                    tracing::error!(error = ?error.as_report(), "error when process message");
                } else {
                    tracing::error!(error = %error.as_report(), "error when process message");
                }
            }

            // Log to optionally-enabled target `PGWIRE_QUERY_LOG`.
            // Only log if we're currently in a tracing span set in `root_span_for_msg`.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        // Tracing span.
        let fut = fut.instrument(span);

        // Execute the future and handle the error.
        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    PsqlError::IoError(io_err) => {
                        // An unexpected EOF terminates the connection. Other I/O errors fall
                        // through to the flush below and keep the connection open.
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                    }

                    PsqlError::SslError(_) => {
                        // For ssl error, because the stream has already been consumed, so there is
                        // no way to write more message.
                        return None;
                    }

                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        // Report the error to the client, then close the connection.
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                        let _ = self.stream.flush().await;
                        return None;
                    }

                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        // Report the error followed by `ReadyForQuery`; the connection stays
                        // open.
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                        self.ready_for_query().ok()?;
                    }

                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                        let _ = self.stream.flush().await;

                        // 1. Catching the panic during message processing may leave the session in an
                        // inconsistent state. We forcefully close the connection (then end the
                        // session) here for safety.
                        // 2. Idle in transaction timeout should also close the connection.
                        return None;
                    }

                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        // Only the error is reported here; no `ReadyForQuery` is written in
                        // this arm.
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                    }
                }
                // Best-effort flush of whatever was buffered above; a flush failure is ignored.
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
488
489    async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
490        // Ignore util sync message.
491        if self.ignore_util_sync {
492            if let FeMessage::Sync = msg {
493            } else {
494                tracing::trace!("ignore message {:?} until sync.", msg);
495                return Ok(());
496            }
497        }
498
499        match msg {
500            FeMessage::Gss => self.process_gss_msg().await?,
501            FeMessage::Ssl => self.process_ssl_msg().await?,
502            FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
503            FeMessage::Password(msg) => self.process_password_msg(msg).await?,
504            FeMessage::Query(query_msg) => {
505                let sql = Arc::from(query_msg.get_sql()?);
506                // The process_query_msg can be slow. Release potential large FeQueryMessage early.
507                drop(query_msg);
508                self.process_query_msg(sql).await?
509            }
510            FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
511            FeMessage::Terminate => self.process_terminate(),
512            FeMessage::Parse(m) => {
513                if let Err(err) = self.process_parse_msg(m).await {
514                    self.ignore_util_sync = true;
515                    return Err(err);
516                }
517            }
518            FeMessage::Bind(m) => {
519                if let Err(err) = self.process_bind_msg(m) {
520                    self.ignore_util_sync = true;
521                    return Err(err);
522                }
523            }
524            FeMessage::Execute(m) => {
525                if let Err(err) = self.process_execute_msg(m).await {
526                    self.ignore_util_sync = true;
527                    return Err(err);
528                }
529            }
530            FeMessage::Describe(m) => {
531                if let Err(err) = self.process_describe_msg(m) {
532                    self.ignore_util_sync = true;
533                    return Err(err);
534                }
535            }
536            FeMessage::Sync => {
537                self.ignore_util_sync = false;
538                self.ready_for_query()?
539            }
540            FeMessage::Close(m) => {
541                if let Err(err) = self.process_close_msg(m) {
542                    self.ignore_util_sync = true;
543                    return Err(err);
544                }
545            }
546            FeMessage::Flush => {
547                if let Err(err) = self.stream.flush().await {
548                    self.ignore_util_sync = true;
549                    return Err(err.into());
550                }
551            }
552            FeMessage::HealthCheck => self.process_health_check(),
553            FeMessage::ServerThrottle(reason) => match reason {
554                ServerThrottleReason::TooLargeMessage => {
555                    return Err(PsqlError::ServerThrottle(format!(
556                        "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
557                        self.message_memory_manager.max_filter_bytes
558                    )));
559                }
560                ServerThrottleReason::TooManyMemoryUsage => {
561                    return Err(PsqlError::ServerThrottle(format!(
562                        "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
563                        self.message_memory_manager.max_running_bytes
564                    )));
565                }
566            },
567        }
568        self.stream.flush().await?;
569        Ok(())
570    }
571
572    pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
573        match self.state {
574            PgProtocolState::Startup => self
575                .stream
576                .read_startup()
577                .await
578                .map(|message: FeMessage| (message, None)),
579            PgProtocolState::Regular => {
580                self.stream.read_header().await?;
581                let guard = if let Some(ref header) = self.stream.read_header {
582                    let payload_len = std::cmp::max(header.payload_len, 0) as u64;
583                    let (reason, guard) = self.message_memory_manager.add(payload_len);
584                    if let Some(reason) = reason {
585                        // Release the memory ASAP.
586                        drop(guard);
587                        self.stream.skip_body().await?;
588                        return Ok((FeMessage::ServerThrottle(reason), None));
589                    }
590                    guard
591                } else {
592                    None
593                };
594                let message = self.stream.read_body().await?;
595                Ok((message, guard))
596            }
597        }
598    }
599
600    /// Writes a `ReadyForQuery` message to the client without flushing.
601    fn ready_for_query(&mut self) -> io::Result<()> {
602        self.stream.write_no_flush(BeMessage::ReadyForQuery(
603            self.session
604                .as_ref()
605                .map(|s| s.transaction_status())
606                .unwrap_or(TransactionStatus::Idle),
607        ))
608    }
609
610    async fn process_gss_msg(&mut self) -> PsqlResult<()> {
611        // We don't support GSSAPI, so we just say no gracefully.
612        self.stream.write(BeMessage::EncryptionResponseNo).await?;
613        Ok(())
614    }
615
616    async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
617        if let Some(context) = self.tls_context.as_ref() {
618            // If got and ssl context, say yes for ssl connection.
619            // Construct ssl stream and replace with current one.
620            self.stream.write(BeMessage::EncryptionResponseSsl).await?;
621            self.stream.upgrade_to_ssl(context).await?;
622        } else {
623            // If no, say no for encryption.
624            self.stream.write(BeMessage::EncryptionResponseNo).await?;
625        }
626
627        Ok(())
628    }
629
630    async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
631        // Check SSL enforcement: if SSL is enforced but connection is not using SSL, reject
632        if let Some(ref tls_config) = self.tls_config
633            && tls_config.enforce_ssl
634            && !self.stream.is_ssl_connection().await
635        {
636            return Err(PsqlError::StartupError(
637                "SSL connection is required but not established".into(),
638            ));
639        }
640
641        let db_name = msg
642            .config
643            .get("database")
644            .cloned()
645            .unwrap_or_else(|| "dev".to_owned());
646        let user_name = msg
647            .config
648            .get("user")
649            .cloned()
650            .unwrap_or_else(|| "root".to_owned());
651
652        let session = self
653            .session_mgr
654            .connect(&db_name, &user_name, self.peer_addr.clone())
655            .map_err(PsqlError::StartupError)?;
656
657        let application_name = msg.config.get("application_name");
658        if let Some(application_name) = application_name {
659            session
660                .set_config("application_name", application_name.clone())
661                .map_err(PsqlError::StartupError)?;
662        }
663
664        match session.user_authenticator() {
665            UserAuthenticator::None => {
666                self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
667
668                // Cancel request need this for identify and verification. According to postgres
669                // doc, it should be written to buffer after receive AuthenticationOk.
670                self.stream
671                    .write_no_flush(BeMessage::BackendKeyData(session.id()))?;
672
673                self.stream.write_no_flush(BeMessage::ParameterStatus(
674                    BeParameterStatusMessage::TimeZone(&session.get_config("timezone")?),
675                ))?;
676                self.stream
677                    .write_parameter_status_msg_no_flush(&ParameterStatus {
678                        application_name: application_name.cloned(),
679                    })?;
680                self.ready_for_query()?;
681            }
682            UserAuthenticator::ClearText(_) | UserAuthenticator::OAuth(_) => {
683                self.stream
684                    .write_no_flush(BeMessage::AuthenticationCleartextPassword)?;
685            }
686            UserAuthenticator::Md5WithSalt { salt, .. } => {
687                self.stream
688                    .write_no_flush(BeMessage::AuthenticationMd5Password(salt))?;
689            }
690        }
691
692        self.session = Some(session);
693        self.state = PgProtocolState::Regular;
694        Ok(())
695    }
696
697    async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
698        let session = self.session.as_ref().unwrap();
699        let authenticator = session.user_authenticator();
700        authenticator.authenticate(&msg.password).await?;
701        self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
702        self.stream.write_no_flush(BeMessage::ParameterStatus(
703            BeParameterStatusMessage::TimeZone(&session.get_config("timezone")?),
704        ))?;
705        self.stream
706            .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
707        self.ready_for_query()?;
708        self.state = PgProtocolState::Regular;
709        Ok(())
710    }
711
712    fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
713        let session_id = (m.target_process_id, m.target_secret_key);
714        tracing::trace!("cancel query in session: {:?}", session_id);
715        self.session_mgr.cancel_queries_in_session(session_id);
716        self.session_mgr.cancel_creating_jobs_in_session(session_id);
717        self.is_terminate = true;
718        Ok(())
719    }
720
721    async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
722        let truncated_sql = record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
723        let session = self.session.clone().unwrap();
724
725        session.check_idle_in_transaction_timeout()?;
726        // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
727        let _exec_context_guard = session.init_exec_context(truncated_sql.into());
728        self.inner_process_query_msg(sql, session.clone()).await
729    }
730
731    async fn inner_process_query_msg(
732        &mut self,
733        sql: Arc<str>,
734        session: Arc<SM::Session>,
735    ) -> PsqlResult<()> {
736        // Parse sql.
737        let stmts =
738            Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
739        // The following inner_process_query_msg_one_stmt can be slow. Release potential large String early.
740        drop(sql);
741        if stmts.is_empty() {
742            self.stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
743        }
744
745        // Execute multiple statements in simple query. KISS later.
746        for stmt in stmts {
747            self.inner_process_query_msg_one_stmt(stmt, session.clone())
748                .await?;
749        }
750        // Put this line inside the for loop above will lead to unfinished/stuck regress test...Not
751        // sure the reason.
752        self.ready_for_query()?;
753        Ok(())
754    }
755
756    async fn inner_process_query_msg_one_stmt(
757        &mut self,
758        stmt: Statement,
759        session: Arc<SM::Session>,
760    ) -> PsqlResult<()> {
761        let session = session.clone();
762
763        // execute query
764        let res = session.clone().run_one_query(stmt, Format::Text).await;
765
766        // Take all remaining notices (if any) and send them before `CommandComplete`.
767        while let Some(notice) = session.next_notice().now_or_never() {
768            self.stream
769                .write_no_flush(BeMessage::NoticeResponse(&notice))?;
770        }
771
772        let mut res = res.map_err(PsqlError::SimpleQueryError)?;
773
774        for notice in res.notices() {
775            self.stream
776                .write_no_flush(BeMessage::NoticeResponse(notice))?;
777        }
778
779        let status = res.status();
780        if let Some(ref application_name) = status.application_name {
781            self.stream.write_no_flush(BeMessage::ParameterStatus(
782                BeParameterStatusMessage::ApplicationName(application_name),
783            ))?;
784        }
785
786        if res.is_copy_query_to_stdout() {
787            self.stream
788                .write_no_flush(BeMessage::CopyOutResponse(res.row_desc().len()))?;
789            let mut count = 0;
790            while let Some(row_set) = res.values_stream().next().await {
791                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
792                for row in row_set {
793                    self.stream.write_no_flush(BeMessage::CopyData(&row))?;
794                    count += 1;
795                }
796            }
797
798            self.stream.write_no_flush(BeMessage::CopyDone)?;
799
800            // Run the callback before sending the `CommandComplete` message.
801            res.run_callback().await?;
802
803            self.stream
804                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
805                    stmt_type: res.stmt_type(),
806                    rows_cnt: count,
807                }))?;
808        } else if res.is_query() {
809            self.stream
810                .write_no_flush(BeMessage::RowDescription(res.row_desc()))?;
811
812            let mut rows_cnt = 0;
813
814            while let Some(row_set) = res.values_stream().next().await {
815                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
816                for row in row_set {
817                    self.stream.write_no_flush(BeMessage::DataRow(&row))?;
818                    rows_cnt += 1;
819                }
820            }
821
822            // Run the callback before sending the `CommandComplete` message.
823            res.run_callback().await?;
824
825            self.stream
826                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
827                    stmt_type: res.stmt_type(),
828                    rows_cnt,
829                }))?;
830        } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
831            let first_row_set = res.values_stream().next().await;
832            let first_row_set = match first_row_set {
833                None => {
834                    return Err(PsqlError::Uncategorized(
835                        anyhow::anyhow!("no affected rows in output").into(),
836                    ));
837                }
838                Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
839            };
840            let affected_rows_str = first_row_set[0].values()[0]
841                .as_ref()
842                .expect("compute node should return affected rows in output");
843
844            assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
845            let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
846                .unwrap()
847                .parse()
848                .unwrap_or_default();
849
850            // Run the callback before sending the `CommandComplete` message.
851            res.run_callback().await?;
852
853            self.stream
854                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
855                    stmt_type: res.stmt_type(),
856                    rows_cnt: affected_rows_cnt,
857                }))?;
858        } else {
859            // Run the callback before sending the `CommandComplete` message.
860            res.run_callback().await?;
861
862            self.stream
863                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
864                    stmt_type: res.stmt_type(),
865                    rows_cnt: 0,
866                }))?;
867        }
868
869        Ok(())
870    }
871
872    fn process_terminate(&mut self) {
873        self.is_terminate = true;
874    }
875
876    fn process_health_check(&mut self) {
877        tracing::debug!("health check");
878        self.is_terminate = true;
879    }
880
881    async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
882        let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
883        record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
884        let session = self.session.clone().unwrap();
885        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
886        let type_ids = std::mem::take(&mut msg.type_ids);
887        // The inner_process_parse_msg can be slow. Release potential large FeParseMessage early.
888        drop(msg);
889        self.inner_process_parse_msg(session, sql, statement_name, type_ids)
890            .await?;
891        Ok(())
892    }
893
894    async fn inner_process_parse_msg(
895        &mut self,
896        session: Arc<SM::Session>,
897        sql: Arc<str>,
898        statement_name: String,
899        type_ids: Vec<i32>,
900    ) -> PsqlResult<()> {
901        if statement_name.is_empty() {
902            // Remove the unnamed prepare statement first, in case the unsupported sql binds a wrong
903            // prepare statement.
904            self.unnamed_prepare_statement.take();
905        } else if self.prepare_statement_store.contains_key(&statement_name) {
906            return Err(PsqlError::ExtendedPrepareError(
907                "Duplicated statement name".into(),
908            ));
909        }
910
911        let stmt = {
912            let stmts = Parser::parse_sql(&sql)
913                .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
914            drop(sql);
915            if stmts.len() > 1 {
916                return Err(PsqlError::ExtendedPrepareError(
917                    "Only one statement is allowed in extended query mode".into(),
918                ));
919            }
920
921            stmts.into_iter().next()
922        };
923
924        let param_types: Vec<Option<DataType>> = type_ids
925            .iter()
926            .map(|&id| {
927                // 0 means unspecified type
928                // ref: https://www.postgresql.org/docs/15/protocol-message-formats.html#:~:text=Placing%20a%20zero%20here%20is%20equivalent%20to%20leaving%20the%20type%20unspecified.
929                if id == 0 {
930                    Ok(None)
931                } else {
932                    DataType::from_oid(id)
933                        .map(Some)
934                        .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
935                }
936            })
937            .try_collect()?;
938
939        let prepare_statement = session
940            .parse(stmt, param_types)
941            .await
942            .map_err(PsqlError::ExtendedPrepareError)?;
943
944        if statement_name.is_empty() {
945            self.unnamed_prepare_statement.replace(prepare_statement);
946        } else {
947            self.prepare_statement_store
948                .insert(statement_name.clone(), prepare_statement);
949        }
950
951        self.statement_portal_dependency
952            .entry(statement_name)
953            .or_default()
954            .clear();
955
956        self.stream.write_no_flush(BeMessage::ParseComplete)?;
957        Ok(())
958    }
959
960    fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
961        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
962        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
963        let session = self.session.clone().unwrap();
964
965        if self.portal_store.contains_key(&portal_name) {
966            return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
967        }
968
969        let prepare_statement = self.get_statement(&statement_name)?;
970
971        let result_formats = msg
972            .result_format_codes
973            .iter()
974            .map(|&format_code| Format::from_i16(format_code))
975            .try_collect()?;
976        let param_formats = msg
977            .param_format_codes
978            .iter()
979            .map(|&format_code| Format::from_i16(format_code))
980            .try_collect()?;
981
982        let portal = session
983            .bind(prepare_statement, msg.params, param_formats, result_formats)
984            .map_err(PsqlError::Uncategorized)?;
985
986        if portal_name.is_empty() {
987            self.result_cache.remove(&portal_name);
988            self.unnamed_portal.replace(portal);
989        } else {
990            assert!(
991                !self.result_cache.contains_key(&portal_name),
992                "Named portal never can be overridden."
993            );
994            self.portal_store.insert(portal_name.clone(), portal);
995        }
996
997        self.statement_portal_dependency
998            .get_mut(&statement_name)
999            .unwrap()
1000            .push(portal_name);
1001
1002        self.stream.write_no_flush(BeMessage::BindComplete)?;
1003        Ok(())
1004    }
1005
1006    async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
1007        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1008        let row_max = msg.max_rows as usize;
1009        drop(msg);
1010        let session = self.session.clone().unwrap();
1011
1012        match self.result_cache.remove(&portal_name) {
1013            Some(mut result_cache) => {
1014                assert!(self.portal_store.contains_key(&portal_name));
1015
1016                let is_cosume_completed =
1017                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
1018
1019                if !is_cosume_completed {
1020                    self.result_cache.insert(portal_name, result_cache);
1021                }
1022            }
1023            _ => {
1024                let portal = self.get_portal(&portal_name)?;
1025                let sql = format!("{}", portal);
1026                let truncated_sql =
1027                    record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
1028                drop(sql);
1029
1030                session.check_idle_in_transaction_timeout()?;
1031                // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
1032                let _exec_context_guard = session.init_exec_context(truncated_sql.into());
1033                let result = session.clone().execute(portal).await;
1034
1035                let pg_response = result.map_err(PsqlError::ExtendedExecuteError)?;
1036                let mut result_cache = ResultCache::new(pg_response);
1037                let is_consume_completed =
1038                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
1039                if !is_consume_completed {
1040                    self.result_cache.insert(portal_name, result_cache);
1041                }
1042            }
1043        }
1044
1045        Ok(())
1046    }
1047
1048    fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1049        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1050        let session = self.session.clone().unwrap();
1051        //  b'S' => Statement
1052        //  b'P' => Portal
1053
1054        assert!(msg.kind == b'S' || msg.kind == b'P');
1055        if msg.kind == b'S' {
1056            let prepare_statement = self.get_statement(&name)?;
1057
1058            let (param_types, row_descriptions) = self
1059                .session
1060                .clone()
1061                .unwrap()
1062                .describe_statement(prepare_statement)
1063                .map_err(PsqlError::Uncategorized)?;
1064            self.stream.write_no_flush(BeMessage::ParameterDescription(
1065                &param_types.iter().map(|t| t.to_oid()).collect_vec(),
1066            ))?;
1067
1068            if row_descriptions.is_empty() {
1069                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1070                // return NoData message if the statement is not a query.
1071                self.stream.write_no_flush(BeMessage::NoData)?;
1072            } else {
1073                self.stream
1074                    .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1075            }
1076        } else if msg.kind == b'P' {
1077            let portal = self.get_portal(&name)?;
1078
1079            let row_descriptions = session
1080                .describe_portal(portal)
1081                .map_err(PsqlError::Uncategorized)?;
1082
1083            if row_descriptions.is_empty() {
1084                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1085                // return NoData message if the statement is not a query.
1086                self.stream.write_no_flush(BeMessage::NoData)?;
1087            } else {
1088                self.stream
1089                    .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1090            }
1091        }
1092        Ok(())
1093    }
1094
1095    fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1096        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1097        assert!(msg.kind == b'S' || msg.kind == b'P');
1098        if msg.kind == b'S' {
1099            if name.is_empty() {
1100                self.unnamed_prepare_statement = None;
1101            } else {
1102                self.prepare_statement_store.remove(&name);
1103            }
1104            for portal_name in self
1105                .statement_portal_dependency
1106                .remove(&name)
1107                .unwrap_or_default()
1108            {
1109                self.remove_portal(&portal_name);
1110            }
1111        } else if msg.kind == b'P' {
1112            self.remove_portal(&name);
1113        }
1114        self.stream.write_no_flush(BeMessage::CloseComplete)?;
1115        Ok(())
1116    }
1117
1118    fn remove_portal(&mut self, portal_name: &str) {
1119        if portal_name.is_empty() {
1120            self.unnamed_portal = None;
1121        } else {
1122            self.portal_store.remove(portal_name);
1123        }
1124        self.result_cache.remove(portal_name);
1125    }
1126
1127    fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1128        if portal_name.is_empty() {
1129            Ok(self
1130                .unnamed_portal
1131                .as_ref()
1132                .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1133                .clone())
1134        } else {
1135            Ok(self
1136                .portal_store
1137                .get(portal_name)
1138                .ok_or_else(|| {
1139                    PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1140                })?
1141                .clone())
1142        }
1143    }
1144
1145    fn get_statement(
1146        &self,
1147        statement_name: &str,
1148    ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1149        if statement_name.is_empty() {
1150            Ok(self
1151                .unnamed_prepare_statement
1152                .as_ref()
1153                .ok_or_else(|| {
1154                    PsqlError::Uncategorized("unnamed prepare statement not found".into())
1155                })?
1156                .clone())
1157        } else {
1158            Ok(self
1159                .prepare_statement_store
1160                .get(statement_name)
1161                .ok_or_else(|| {
1162                    PsqlError::Uncategorized(
1163                        format!("Prepare statement {} not found", statement_name).into(),
1164                    )
1165                })?
1166                .clone())
1167        }
1168    }
1169}
1170
1171enum PgStreamInner<S> {
1172    /// Used for the intermediate state when converting from unencrypted to ssl stream.
1173    Placeholder,
1174    /// An unencrypted stream.
1175    Unencrypted(S),
1176    /// An ssl stream.
1177    Ssl(SslStream<S>),
1178}
1179
1180/// Trait for a byte stream that can be used for pg protocol.
1181pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1182impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1183
1184/// Wraps a byte stream and read/write pg messages.
1185///
1186/// Cloning a `PgStream` will share the same stream but a fresh & independent write buffer,
1187/// so that it can be used to write messages concurrently without interference.
1188pub struct PgStream<S> {
1189    /// The underlying stream.
1190    stream: Arc<Mutex<PgStreamInner<S>>>,
1191    /// Write into buffer before flush to stream.
1192    write_buf: BytesMut,
1193    read_header: Option<FeMessageHeader>,
1194}
1195
1196impl<S> PgStream<S> {
1197    /// Create a new `PgStream` with the given stream and default write buffer capacity.
1198    pub fn new(stream: S) -> Self {
1199        const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1200
1201        Self {
1202            stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1203            write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1204            read_header: None,
1205        }
1206    }
1207
1208    /// Check if the current connection is using SSL
1209    async fn is_ssl_connection(&self) -> bool {
1210        let stream = self.stream.lock().await;
1211        matches!(*stream, PgStreamInner::Ssl(_))
1212    }
1213}
1214
1215impl<S> Clone for PgStream<S> {
1216    fn clone(&self) -> Self {
1217        Self {
1218            stream: Arc::clone(&self.stream),
1219            write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1220            read_header: self.read_header.clone(),
1221        }
1222    }
1223}
1224
/// Run-time parameters reported to the client through `ParameterStatus` messages.
///
/// According to the PostgreSQL documentation, `ParameterStatus` is generated for a
/// hard-wired set of parameters:
///
///  * `server_version`
///  * `server_encoding`
///  * `client_encoding`
///  * `application_name`
///  * `is_superuser`
///  * `session_authorization`
///  * `DateStyle`
///  * `IntervalStyle`
///  * `TimeZone`
///  * `integer_datetimes`
///  * `standard_conforming_strings`
///
/// Of these, this struct only carries `application_name`.
///
/// See: <https://www.postgresql.org/docs/9.2/static/protocol-flow.html#PROTOCOL-ASYNC>.
#[derive(Clone, Debug, Default)]
pub struct ParameterStatus {
    /// Value of `application_name` to report, or `None` if there is nothing to report.
    pub application_name: Option<String>,
}
1245
1246impl<S> PgStream<S>
1247where
1248    S: PgByteStream,
1249{
1250    async fn read_startup(&mut self) -> io::Result<FeMessage> {
1251        let mut stream = self.stream.lock().await;
1252        match &mut *stream {
1253            PgStreamInner::Placeholder => unreachable!(),
1254            PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1255            PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1256        }
1257    }
1258
1259    async fn read_header(&mut self) -> io::Result<()> {
1260        let mut stream = self.stream.lock().await;
1261        match &mut *stream {
1262            PgStreamInner::Placeholder => unreachable!(),
1263            PgStreamInner::Unencrypted(stream) => {
1264                self.read_header = Some(FeMessage::read_header(stream).await?);
1265                Ok(())
1266            }
1267            PgStreamInner::Ssl(ssl_stream) => {
1268                self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1269                Ok(())
1270            }
1271        }
1272    }
1273
1274    async fn read_body(&mut self) -> io::Result<FeMessage> {
1275        let mut stream = self.stream.lock().await;
1276        let header = self
1277            .read_header
1278            .take()
1279            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1280        match &mut *stream {
1281            PgStreamInner::Placeholder => unreachable!(),
1282            PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1283            PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1284        }
1285    }
1286
1287    async fn skip_body(&mut self) -> io::Result<()> {
1288        let mut stream = self.stream.lock().await;
1289        let header = self
1290            .read_header
1291            .take()
1292            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1293        match &mut *stream {
1294            PgStreamInner::Placeholder => unreachable!(),
1295            PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1296            PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1297        }
1298    }
1299
1300    fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1301        self.write_no_flush(BeMessage::ParameterStatus(
1302            BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1303        ))?;
1304        self.write_no_flush(BeMessage::ParameterStatus(
1305            BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1306        ))?;
1307        self.write_no_flush(BeMessage::ParameterStatus(
1308            BeParameterStatusMessage::ServerVersion(PG_VERSION),
1309        ))?;
1310        if let Some(application_name) = &status.application_name {
1311            self.write_no_flush(BeMessage::ParameterStatus(
1312                BeParameterStatusMessage::ApplicationName(application_name),
1313            ))?;
1314        }
1315        Ok(())
1316    }
1317
1318    pub fn write_no_flush(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1319        BeMessage::write(&mut self.write_buf, message)
1320    }
1321
1322    async fn write(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1323        self.write_no_flush(message)?;
1324        self.flush().await?;
1325        Ok(())
1326    }
1327
1328    async fn flush(&mut self) -> io::Result<()> {
1329        let mut stream = self.stream.lock().await;
1330        match &mut *stream {
1331            PgStreamInner::Placeholder => unreachable!(),
1332            PgStreamInner::Unencrypted(stream) => {
1333                stream.write_all(&self.write_buf).await?;
1334                stream.flush().await?;
1335            }
1336            PgStreamInner::Ssl(ssl_stream) => {
1337                ssl_stream.write_all(&self.write_buf).await?;
1338                ssl_stream.flush().await?;
1339            }
1340        }
1341        self.write_buf.clear();
1342        Ok(())
1343    }
1344}
1345
1346impl<S> PgStream<S>
1347where
1348    S: PgByteStream,
1349{
1350    /// Convert the underlying stream to ssl stream based on the given context.
1351    async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1352        let mut stream = self.stream.lock().await;
1353
1354        match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1355            PgStreamInner::Unencrypted(unencrypted_stream) => {
1356                let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1357                let mut ssl_stream =
1358                    tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1359
1360                if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1361                    tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1362                    let _ = ssl_stream.shutdown().await;
1363                    return Err(e.into());
1364                }
1365
1366                *stream = PgStreamInner::Ssl(ssl_stream);
1367            }
1368            PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1369            PgStreamInner::Placeholder => unreachable!(),
1370        }
1371
1372        Ok(())
1373    }
1374}
1375
1376fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1377    let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1378
1379    let key_path = &tls_config.key;
1380    let cert_path = &tls_config.cert;
1381
1382    // Build ssl acceptor according to the config.
1383    // Now we set every verify to true.
1384    acceptor
1385        .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1386        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1387    acceptor
1388        .set_ca_file(cert_path)
1389        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1390    acceptor
1391        .set_certificate_chain_file(cert_path)
1392        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1393    let acceptor = acceptor.build();
1394
1395    Ok(acceptor.into_context())
1396}
1397
1398pub mod truncated_fmt {
1399    use std::fmt::*;
1400
1401    struct TruncatedFormatter<'a, 'b> {
1402        remaining: usize,
1403        finished: bool,
1404        f: &'a mut Formatter<'b>,
1405    }
1406    impl Write for TruncatedFormatter<'_, '_> {
1407        fn write_str(&mut self, s: &str) -> Result {
1408            if self.finished {
1409                return Ok(());
1410            }
1411
1412            if self.remaining < s.len() {
1413                let actual = s.floor_char_boundary(self.remaining);
1414                self.f.write_str(&s[0..actual])?;
1415                self.remaining -= actual;
1416                self.f.write_str(&format!("...(truncated,{})", s.len()))?;
1417                self.finished = true; // so that ...(truncated) is printed exactly once
1418            } else {
1419                self.f.write_str(s)?;
1420                self.remaining -= s.len();
1421            }
1422            Ok(())
1423        }
1424    }
1425
1426    pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);
1427
1428    impl<T> Debug for TruncatedFmt<'_, T>
1429    where
1430        T: Debug,
1431    {
1432        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
1433            TruncatedFormatter {
1434                remaining: self.1,
1435                finished: false,
1436                f,
1437            }
1438            .write_fmt(format_args!("{:?}", self.0))
1439        }
1440    }
1441
1442    impl<T> Display for TruncatedFmt<'_, T>
1443    where
1444        T: Display,
1445    {
1446        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
1447            TruncatedFormatter {
1448                remaining: self.1,
1449                finished: false,
1450                f,
1451            }
1452            .write_fmt(format_args!("{}", self.0))
1453        }
1454    }
1455
1456    #[cfg(test)]
1457    mod tests {
1458        use super::*;
1459
1460        #[test]
1461        fn test_trunc_utf8() {
1462            assert_eq!(
1463                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
1464                "select '...(truncated,14)",
1465            );
1466        }
1467    }
1468}
1469
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Checks `redact_sql` on a parsable statement: values of the options `v2`, `v4`
    /// and `b` are expected to be replaced by `[REDACTED]`, all other options stay.
    #[test]
    fn test_redact_parsable_sql() {
        let sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        let keywords = Arc::new(HashSet::from(["v2".into(), "v4".into(), "b".into()]));

        let redacted = redact_sql(sql, keywords);

        assert_eq!(
            redacted,
            "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])"
        );
    }
}