pgwire/
pg_protocol.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::env_var::env_var_is_true;
33use risingwave_common::util::panic::FutureCatchUnwindExt;
34use risingwave_common::util::query_log::*;
35use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
36use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
37use risingwave_sqlparser::parser::Parser;
38use thiserror_ext::AsReport;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
40use tokio::sync::Mutex;
41use tokio_openssl::SslStream;
42use tracing::Instrument;
43
44use crate::error::{PsqlError, PsqlResult};
45use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
46use crate::net::AddressRef;
47use crate::pg_extended::ResultCache;
48use crate::pg_message::{
49    BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
50    FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
51    FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
52};
53use crate::pg_server::{Session, SessionManager, UserAuthenticator};
54use crate::types::Format;
55
/// Maximum length of SQL text recorded in query logs and tracing spans, to avoid log files
/// growing too large.
///
/// Read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable. When the variable is
/// unset or is not a valid `usize`, falls back to 65536 in debug builds and 1024 in release
/// builds.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    // Parse once instead of "check parses, then parse again and unwrap".
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        .and_then(|len| len.parse::<usize>().ok())
        .unwrap_or(if cfg!(debug_assertions) { 65536 } else { 1024 })
});
69
tokio::task_local! {
    /// The current session. Concrete type is erased for different session implementations.
    ///
    /// Scoped around the processing of each message in `PgProtocol::do_process` whenever a
    /// session exists. Only a `Weak` reference is stored, so this task-local never keeps the
    /// session alive on its own; readers must upgrade it and downcast (via `Any`) to their
    /// concrete session type.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
74
/// The state machine for each psql connection.
/// Read pg messages from tcp stream and write results back.
///
/// When dropped, the session (if one was created) is ended in the session manager.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Used for write/read pg messages.
    stream: PgStream<S>,
    /// Current states of pg connection.
    state: PgProtocolState,
    /// Set when the connection should be closed after the current message (e.g. after a cancel
    /// request has been served).
    is_terminate: bool,

    /// Creates sessions on startup, cancels queries, and ends the session on drop.
    session_mgr: Arc<SM>,
    /// Set once the startup message has been handled.
    session: Option<Arc<SM::Session>>,

    /// Cached result streams for extended-query execution; presumably keyed by portal name —
    /// TODO confirm.
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    /// The unnamed prepared statement of the extended query protocol.
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    /// Named prepared statements, keyed by statement name.
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    /// The unnamed portal of the extended query protocol.
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    /// Named portals, keyed by portal name.
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    // Used to store the dependency of portal and prepare statement.
    // When we close a prepare statement, we need to close all the portals that depend on it.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    // SSL context built from `tls_config` when the protocol is created.
    // If None (no TLS config, or the context could not be built), SSL requests from the client
    // are declined with `EncryptionResponseNo`.
    tls_context: Option<SslContext>,

    // TLS configuration including SSL enforcement setting
    tls_config: Option<TlsConfig>,

    // Used in extended query protocol. When an extended-query message fails, every following
    // message is ignored until the next Sync message ("util" in the name means "until").
    ignore_util_sync: bool,

    // Client address, passed to the session manager when the session is created.
    peer_addr: AddressRef,

    // Option keywords whose values are redacted before SQL is recorded in tracing spans.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    // Accounts memory held by incoming messages and decides when to throttle them.
    message_memory_manager: MessageMemoryManagerRef,
}
117
/// Configures TLS encryption for connections.
///
/// Can be built from environment variables with [`TlsConfig::new_default`].
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// The path to the TLS certificate.
    pub cert: String,
    /// The path to the TLS key.
    pub key: String,
    /// Whether to enforce SSL connections: a startup message received on a connection that has
    /// not been upgraded to SSL is rejected.
    pub enforce_ssl: bool,
}
128
129impl TlsConfig {
130    pub fn new_default() -> Option<Self> {
131        let cert = std::env::var("RW_SSL_CERT").ok()?;
132        let key = std::env::var("RW_SSL_KEY").ok()?;
133        let enforce_ssl = env_var_is_true("RW_SSL_ENFORCE");
134        tracing::info!(
135            "RW_SSL_CERT={}, RW_SSL_KEY={}, RW_SSL_ENFORCE={}",
136            cert,
137            key,
138            enforce_ssl
139        );
140        Some(Self {
141            cert,
142            key,
143            enforce_ssl,
144        })
145    }
146}
147
148impl<S, SM> Drop for PgProtocol<S, SM>
149where
150    SM: SessionManager,
151{
152    fn drop(&mut self) {
153        if let Some(session) = &self.session {
154            // Clear the session in session manager.
155            self.session_mgr.end_session(session);
156        }
157    }
158}
159
/// Phase of a connection. States progress from top to bottom: `Startup`, then `Regular`.
enum PgProtocolState {
    /// Initial phase: messages are read with `PgStream::read_startup` and carry no memory guard.
    Startup,
    /// Entered once the startup message has been handled: messages are read as header + body,
    /// and each message body is accounted against the message memory manager (and may be
    /// throttled).
    Regular,
}
165
166/// Truncate 0 from C string in Bytes and stringify it (returns slice, no allocations).
167///
168/// PG protocol strings are always C strings.
169pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
170    let without_null = if b.last() == Some(&0) {
171        &b[..b.len() - 1]
172    } else {
173        &b[..]
174    };
175    std::str::from_utf8(without_null)
176}
177
178/// Record `sql` in the current tracing span.
179fn record_sql_in_span(
180    sql: &str,
181    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
182) -> String {
183    let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
184        && !keywords.is_empty()
185    {
186        redact_sql(sql, keywords)
187    } else {
188        sql.to_owned()
189    };
190    let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
191    tracing::Span::current().record("sql", tracing::field::display(&truncated));
192    truncated.to_string()
193}
194
195/// Redacts SQL options. Data in DML is not redacted.
196fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
197    match Parser::parse_sql(sql) {
198        Ok(sqls) => sqls
199            .into_iter()
200            .map(|sql| sql.to_redacted_string(keywords.clone()))
201            .join(";"),
202        Err(_) => sql.to_owned(),
203    }
204}
205
/// Connection-level settings handed to [`PgProtocol::new`], which moves them into the
/// protocol state machine.
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS certificate/key and SSL enforcement. `None` means SSL requests are declined.
    pub tls_config: Option<TlsConfig>,
    /// Option keywords whose values are redacted before SQL is recorded in tracing spans.
    /// `None` or an empty set disables redaction.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts memory held by incoming messages and decides when to throttle them.
    pub message_memory_manager: MessageMemoryManagerRef,
}
212
213impl<S, SM> PgProtocol<S, SM>
214where
215    S: PgByteStream,
216    SM: SessionManager,
217{
218    pub fn new(
219        stream: S,
220        session_mgr: Arc<SM>,
221        peer_addr: AddressRef,
222        context: ConnectionContext,
223    ) -> Self {
224        let ConnectionContext {
225            tls_config,
226            redact_sql_option_keywords,
227            message_memory_manager,
228        } = context;
229        Self {
230            stream: PgStream::new(stream),
231            is_terminate: false,
232            state: PgProtocolState::Startup,
233            session_mgr,
234            session: None,
235            tls_context: tls_config
236                .as_ref()
237                .and_then(|e| build_ssl_ctx_from_config(e).ok()),
238            tls_config,
239            result_cache: Default::default(),
240            unnamed_prepare_statement: Default::default(),
241            prepare_statement_store: Default::default(),
242            unnamed_portal: Default::default(),
243            portal_store: Default::default(),
244            statement_portal_dependency: Default::default(),
245            ignore_util_sync: false,
246            peer_addr,
247            redact_sql_option_keywords,
248            message_memory_manager,
249        }
250    }
251
    /// Run the protocol to serve the connection.
    ///
    /// Reads one message at a time and hands it to [`Self::process`] until that reports
    /// termination or a read fails. Once a session exists, a future that forwards the session's
    /// notices to the client is polled concurrently with message processing.
    pub async fn run(&mut self) {
        // Created lazily once a session exists; kept outside the loop so it persists across
        // iterations instead of being restarted for every message.
        let mut notice_fut = None;

        loop {
            // Once a session is present, create a future to subscribe and send notices asynchronously.
            if notice_fut.is_none()
                && let Some(session) = self.session.clone()
            {
                // The notice writer uses its own handle to the stream.
                let mut stream = self.stream.clone();
                notice_fut = Some(Box::pin(async move {
                    loop {
                        let notice = session.next_notice().await;
                        // A failed notice write is only logged; it does not end the connection.
                        if let Err(e) = stream.write(BeMessage::NoticeResponse(&notice)).await {
                            tracing::error!(error = %e.as_report(), notice, "failed to send notice");
                        }
                    }
                }));
            }

            // Read and process messages.
            let process = std::pin::pin!(async {
                // Keep the memory guard alive until the message has been processed.
                let (msg, _memory_guard) = match self.read_message().await {
                    Ok(msg) => msg,
                    Err(e) => {
                        tracing::error!(error = %e.as_report(), "error when reading message");
                        return true; // terminate the connection
                    }
                };
                tracing::trace!(?msg, "received message");
                self.process(msg).await
            });

            // The notice loop never completes, so only `process` can finish the select. The
            // notice future is polled through `as_mut()`, so it is not dropped here.
            let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
                tokio::select! {
                    _ = notice_fut => unreachable!(),
                    terminated = process => terminated,
                }
            } else {
                process.await
            };

            if terminated {
                break;
            }
        }
    }
299
300    /// Processes one message. Returns true if the connection is terminated.
301    pub async fn process(&mut self, msg: FeMessage) -> bool {
302        self.do_process(msg).await.is_none() || self.is_terminate
303    }
304
305    /// The root tracing span for processing a message. The target of the span is
306    /// [`PGWIRE_ROOT_SPAN_TARGET`].
307    ///
308    /// This is used to provide context for the (slow) query logs and traces.
309    ///
310    /// The span is only effective if there's a current session and the message is
311    /// query-related. Otherwise, `Span::none()` is returned.
312    fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
313        let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
314            return tracing::Span::none();
315        };
316
317        let mode = match msg {
318            FeMessage::Query(_) => "simple query",
319            FeMessage::Parse(_) => "extended query parse",
320            FeMessage::Execute(_) => "extended query execute",
321            _ => return tracing::Span::none(),
322        };
323
324        tracing::info_span!(
325            target: PGWIRE_ROOT_SPAN_TARGET,
326            "handle_query",
327            mode,
328            session_id,
329            sql = tracing::field::Empty, // record SQL later in each `process` call
330        )
331    }
332
    /// Processes one message and turns any error into the appropriate client response.
    ///
    /// The message handler is wrapped, from inside out, with:
    /// - the current session installed in [`CURRENT_SESSION`] (if there is a session),
    /// - panic catching (a panic becomes [`PsqlError::Panic`]),
    /// - periodic slow-query logging every `SLOW_QUERY_LOG_PERIOD`,
    /// - error logging plus a per-message entry on the `PGWIRE_QUERY_LOG` target,
    /// - the root tracing span from `root_span_for_msg`.
    ///
    /// Return type `Option<()>` is essentially a bool, but allows `?` for early return.
    /// - `None` means to terminate the current connection
    /// - `Some(())` means to continue processing the next message
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        let span = self.root_span_for_msg(&msg);
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        // Processing the message itself.
        //
        // Note: pin the future to avoid stack overflow as we'll wrap it multiple times
        // in the following code.
        let fut = Box::pin(self.do_process_inner(msg));

        // Set the current session as the context when processing the message, if exists.
        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        // Catch unwind.
        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        // Slow query log.
        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            let mut elapsed = Duration::ZERO;

            // Report the SQL in the log periodically if the query is slow.
            // The timeout only interrupts the wait: `fut` is polled by reference, so it keeps
            // running across iterations until it completes.
            loop {
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        // Query log.
        let fut = async move {
            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            // Always log if an error occurs.
            // Note: all messages will be processed through this code path, making it the
            //       only necessary place to log errors.
            if let Err(error) = &result {
                if cfg!(debug_assertions) && !Deployment::current().is_ci() {
                    // For local debugging, we print the error with backtrace.
                    // It's useful only when:
                    // - no additional context is added to the error
                    // - backtrace is captured in the error
                    // - backtrace is not printed in the middle
                    tracing::error!(error = ?error.as_report(), "error when process message");
                } else {
                    tracing::error!(error = %error.as_report(), "error when process message");
                }
            }

            // Log to optionally-enabled target `PGWIRE_QUERY_LOG`.
            // Only log if we're currently in a tracing span set in `root_span_for_msg`.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        // Tracing span.
        let fut = fut.instrument(span);

        // Execute the future and handle the error.
        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    // Unexpected EOF means the client went away: close silently. Other I/O
                    // errors fall through: nothing is written, the stream is flushed below and
                    // the connection keeps processing messages.
                    PsqlError::IoError(io_err) => {
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                    }

                    PsqlError::SslError(_) => {
                        // For ssl error, because the stream has already been consumed, so there is
                        // no way to write more message.
                        return None;
                    }

                    // The handshake cannot continue: report the error, then close.
                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                        let _ = self.stream.flush().await;
                        return None;
                    }

                    // Report the error followed by `ReadyForQuery`, and keep the connection.
                    // NOTE(review): `ServerThrottle` is produced by `read_message` for any
                    // oversized message, including extended-query ones, where a `ReadyForQuery`
                    // before `Sync` (and not entering ignore-until-sync) may be unexpected —
                    // TODO confirm.
                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                        self.ready_for_query().ok()?;
                    }

                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                        let _ = self.stream.flush().await;

                        // 1. Catching the panic during message processing may leave the session in an
                        // inconsistent state. We forcefully close the connection (then end the
                        // session) here for safety.
                        // 2. Idle in transaction timeout should also close the connection.
                        return None;
                    }

                    // Report the error without `ReadyForQuery`. For extended-query failures,
                    // `do_process_inner` has set `ignore_util_sync`, and the client's next
                    // `Sync` produces the `ReadyForQuery`.
                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse(&e))
                            .ok()?;
                    }
                }
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
488
489    async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
490        // Ignore util sync message.
491        if self.ignore_util_sync {
492            if let FeMessage::Sync = msg {
493            } else {
494                tracing::trace!("ignore message {:?} until sync.", msg);
495                return Ok(());
496            }
497        }
498
499        match msg {
500            FeMessage::Gss => self.process_gss_msg().await?,
501            FeMessage::Ssl => self.process_ssl_msg().await?,
502            FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
503            FeMessage::Password(msg) => self.process_password_msg(msg).await?,
504            FeMessage::Query(query_msg) => {
505                let sql = Arc::from(query_msg.get_sql()?);
506                // The process_query_msg can be slow. Release potential large FeQueryMessage early.
507                drop(query_msg);
508                self.process_query_msg(sql).await?
509            }
510            FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
511            FeMessage::Terminate => self.process_terminate(),
512            FeMessage::Parse(m) => {
513                if let Err(err) = self.process_parse_msg(m).await {
514                    self.ignore_util_sync = true;
515                    return Err(err);
516                }
517            }
518            FeMessage::Bind(m) => {
519                if let Err(err) = self.process_bind_msg(m) {
520                    self.ignore_util_sync = true;
521                    return Err(err);
522                }
523            }
524            FeMessage::Execute(m) => {
525                if let Err(err) = self.process_execute_msg(m).await {
526                    self.ignore_util_sync = true;
527                    return Err(err);
528                }
529            }
530            FeMessage::Describe(m) => {
531                if let Err(err) = self.process_describe_msg(m) {
532                    self.ignore_util_sync = true;
533                    return Err(err);
534                }
535            }
536            FeMessage::Sync => {
537                self.ignore_util_sync = false;
538                self.ready_for_query()?
539            }
540            FeMessage::Close(m) => {
541                if let Err(err) = self.process_close_msg(m) {
542                    self.ignore_util_sync = true;
543                    return Err(err);
544                }
545            }
546            FeMessage::Flush => {
547                if let Err(err) = self.stream.flush().await {
548                    self.ignore_util_sync = true;
549                    return Err(err.into());
550                }
551            }
552            FeMessage::HealthCheck => self.process_health_check(),
553            FeMessage::ServerThrottle(reason) => match reason {
554                ServerThrottleReason::TooLargeMessage => {
555                    return Err(PsqlError::ServerThrottle(format!(
556                        "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
557                        self.message_memory_manager.max_filter_bytes
558                    )));
559                }
560                ServerThrottleReason::TooManyMemoryUsage => {
561                    return Err(PsqlError::ServerThrottle(format!(
562                        "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
563                        self.message_memory_manager.max_running_bytes
564                    )));
565                }
566            },
567        }
568        self.stream.flush().await?;
569        Ok(())
570    }
571
572    pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
573        match self.state {
574            PgProtocolState::Startup => self
575                .stream
576                .read_startup()
577                .await
578                .map(|message: FeMessage| (message, None)),
579            PgProtocolState::Regular => {
580                self.stream.read_header().await?;
581                let guard = if let Some(ref header) = self.stream.read_header {
582                    let payload_len = std::cmp::max(header.payload_len, 0) as u64;
583                    let (reason, guard) = self.message_memory_manager.add(payload_len);
584                    if let Some(reason) = reason {
585                        // Release the memory ASAP.
586                        drop(guard);
587                        self.stream.skip_body().await?;
588                        return Ok((FeMessage::ServerThrottle(reason), None));
589                    }
590                    guard
591                } else {
592                    None
593                };
594                let message = self.stream.read_body().await?;
595                Ok((message, guard))
596            }
597        }
598    }
599
600    /// Writes a `ReadyForQuery` message to the client without flushing.
601    fn ready_for_query(&mut self) -> io::Result<()> {
602        self.stream.write_no_flush(BeMessage::ReadyForQuery(
603            self.session
604                .as_ref()
605                .map(|s| s.transaction_status())
606                .unwrap_or(TransactionStatus::Idle),
607        ))
608    }
609
610    async fn process_gss_msg(&mut self) -> PsqlResult<()> {
611        // We don't support GSSAPI, so we just say no gracefully.
612        self.stream.write(BeMessage::EncryptionResponseNo).await?;
613        Ok(())
614    }
615
616    async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
617        if let Some(context) = self.tls_context.as_ref() {
618            // If got and ssl context, say yes for ssl connection.
619            // Construct ssl stream and replace with current one.
620            self.stream.write(BeMessage::EncryptionResponseSsl).await?;
621            self.stream.upgrade_to_ssl(context).await?;
622        } else {
623            // If no, say no for encryption.
624            self.stream.write(BeMessage::EncryptionResponseNo).await?;
625        }
626
627        Ok(())
628    }
629
630    async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
631        // Check SSL enforcement: if SSL is enforced but connection is not using SSL, reject
632        if let Some(ref tls_config) = self.tls_config
633            && tls_config.enforce_ssl
634            && !self.stream.is_ssl_connection().await
635        {
636            return Err(PsqlError::StartupError(
637                "SSL connection is required but not established".into(),
638            ));
639        }
640
641        let db_name = msg
642            .config
643            .get("database")
644            .cloned()
645            .unwrap_or_else(|| "dev".to_owned());
646        let user_name = msg
647            .config
648            .get("user")
649            .cloned()
650            .unwrap_or_else(|| "root".to_owned());
651
652        let session = self
653            .session_mgr
654            .connect(&db_name, &user_name, self.peer_addr.clone())
655            .map_err(PsqlError::StartupError)?;
656
657        let application_name = msg.config.get("application_name");
658        if let Some(application_name) = application_name {
659            session
660                .set_config("application_name", application_name.clone())
661                .map_err(PsqlError::StartupError)?;
662        }
663
664        match session.user_authenticator() {
665            UserAuthenticator::None => {
666                self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
667
668                // Cancel request need this for identify and verification. According to postgres
669                // doc, it should be written to buffer after receive AuthenticationOk.
670                self.stream
671                    .write_no_flush(BeMessage::BackendKeyData(session.id()))?;
672
673                self.stream.write_no_flush(BeMessage::ParameterStatus(
674                    BeParameterStatusMessage::TimeZone(&session.get_config("timezone")?),
675                ))?;
676                self.stream
677                    .write_parameter_status_msg_no_flush(&ParameterStatus {
678                        application_name: application_name.cloned(),
679                    })?;
680                self.ready_for_query()?;
681            }
682            UserAuthenticator::ClearText(_)
683            | UserAuthenticator::OAuth { .. }
684            | UserAuthenticator::Ldap(..) => {
685                self.stream
686                    .write_no_flush(BeMessage::AuthenticationCleartextPassword)?;
687            }
688            UserAuthenticator::Md5WithSalt { salt, .. } => {
689                self.stream
690                    .write_no_flush(BeMessage::AuthenticationMd5Password(salt))?;
691            }
692        }
693
694        self.session = Some(session);
695        self.state = PgProtocolState::Regular;
696        Ok(())
697    }
698
699    async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
700        let session = self.session.as_ref().unwrap();
701        let authenticator = session.user_authenticator();
702        authenticator.authenticate(&msg.password).await?;
703        self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
704        self.stream.write_no_flush(BeMessage::ParameterStatus(
705            BeParameterStatusMessage::TimeZone(&session.get_config("timezone")?),
706        ))?;
707        self.stream
708            .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
709        self.ready_for_query()?;
710        self.state = PgProtocolState::Regular;
711        Ok(())
712    }
713
714    fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
715        let session_id = (m.target_process_id, m.target_secret_key);
716        tracing::trace!("cancel query in session: {:?}", session_id);
717        self.session_mgr.cancel_queries_in_session(session_id);
718        self.session_mgr.cancel_creating_jobs_in_session(session_id);
719        self.is_terminate = true;
720        Ok(())
721    }
722
723    async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
724        let truncated_sql = record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
725        let session = self.session.clone().unwrap();
726
727        session.check_idle_in_transaction_timeout()?;
728        // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
729        let _exec_context_guard = session.init_exec_context(truncated_sql.into());
730        self.inner_process_query_msg(sql, session.clone()).await
731    }
732
733    async fn inner_process_query_msg(
734        &mut self,
735        sql: Arc<str>,
736        session: Arc<SM::Session>,
737    ) -> PsqlResult<()> {
738        // Parse sql.
739        let stmts =
740            Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
741        // The following inner_process_query_msg_one_stmt can be slow. Release potential large String early.
742        drop(sql);
743        if stmts.is_empty() {
744            self.stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
745        }
746
747        // Execute multiple statements in simple query. KISS later.
748        for stmt in stmts {
749            self.inner_process_query_msg_one_stmt(stmt, session.clone())
750                .await?;
751        }
752        // Put this line inside the for loop above will lead to unfinished/stuck regress test...Not
753        // sure the reason.
754        self.ready_for_query()?;
755        Ok(())
756    }
757
758    async fn inner_process_query_msg_one_stmt(
759        &mut self,
760        stmt: Statement,
761        session: Arc<SM::Session>,
762    ) -> PsqlResult<()> {
763        let session = session.clone();
764
765        // execute query
766        let res = session.clone().run_one_query(stmt, Format::Text).await;
767
768        // Take all remaining notices (if any) and send them before `CommandComplete`.
769        while let Some(notice) = session.next_notice().now_or_never() {
770            self.stream
771                .write_no_flush(BeMessage::NoticeResponse(&notice))?;
772        }
773
774        let mut res = res.map_err(PsqlError::SimpleQueryError)?;
775
776        for notice in res.notices() {
777            self.stream
778                .write_no_flush(BeMessage::NoticeResponse(notice))?;
779        }
780
781        let status = res.status();
782        if let Some(ref application_name) = status.application_name {
783            self.stream.write_no_flush(BeMessage::ParameterStatus(
784                BeParameterStatusMessage::ApplicationName(application_name),
785            ))?;
786        }
787
788        if res.is_copy_query_to_stdout() {
789            self.stream
790                .write_no_flush(BeMessage::CopyOutResponse(res.row_desc().len()))?;
791            let mut count = 0;
792            while let Some(row_set) = res.values_stream().next().await {
793                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
794                for row in row_set {
795                    self.stream.write_no_flush(BeMessage::CopyData(&row))?;
796                    count += 1;
797                }
798            }
799
800            self.stream.write_no_flush(BeMessage::CopyDone)?;
801
802            // Run the callback before sending the `CommandComplete` message.
803            res.run_callback().await?;
804
805            self.stream
806                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
807                    stmt_type: res.stmt_type(),
808                    rows_cnt: count,
809                }))?;
810        } else if res.is_query() {
811            self.stream
812                .write_no_flush(BeMessage::RowDescription(res.row_desc()))?;
813
814            let mut rows_cnt = 0;
815
816            while let Some(row_set) = res.values_stream().next().await {
817                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
818                for row in row_set {
819                    self.stream.write_no_flush(BeMessage::DataRow(&row))?;
820                    rows_cnt += 1;
821                }
822            }
823
824            // Run the callback before sending the `CommandComplete` message.
825            res.run_callback().await?;
826
827            self.stream
828                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
829                    stmt_type: res.stmt_type(),
830                    rows_cnt,
831                }))?;
832        } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
833            let first_row_set = res.values_stream().next().await;
834            let first_row_set = match first_row_set {
835                None => {
836                    return Err(PsqlError::Uncategorized(
837                        anyhow::anyhow!("no affected rows in output").into(),
838                    ));
839                }
840                Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
841            };
842            let affected_rows_str = first_row_set[0].values()[0]
843                .as_ref()
844                .expect("compute node should return affected rows in output");
845
846            assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
847            let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
848                .unwrap()
849                .parse()
850                .unwrap_or_default();
851
852            // Run the callback before sending the `CommandComplete` message.
853            res.run_callback().await?;
854
855            self.stream
856                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
857                    stmt_type: res.stmt_type(),
858                    rows_cnt: affected_rows_cnt,
859                }))?;
860        } else {
861            // Run the callback before sending the `CommandComplete` message.
862            res.run_callback().await?;
863
864            self.stream
865                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
866                    stmt_type: res.stmt_type(),
867                    rows_cnt: 0,
868                }))?;
869        }
870
871        Ok(())
872    }
873
    /// Flags the connection to be closed by setting `is_terminate`.
    ///
    /// NOTE(review): presumably invoked for the frontend `Terminate` message; the dispatch site
    /// is outside this view.
    fn process_terminate(&mut self) {
        self.is_terminate = true;
    }
877
878    fn process_health_check(&mut self) {
879        tracing::debug!("health check");
880        self.is_terminate = true;
881    }
882
883    async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
884        let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
885        record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
886        let session = self.session.clone().unwrap();
887        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
888        let type_ids = std::mem::take(&mut msg.type_ids);
889        // The inner_process_parse_msg can be slow. Release potential large FeParseMessage early.
890        drop(msg);
891        self.inner_process_parse_msg(session, sql, statement_name, type_ids)
892            .await?;
893        Ok(())
894    }
895
896    async fn inner_process_parse_msg(
897        &mut self,
898        session: Arc<SM::Session>,
899        sql: Arc<str>,
900        statement_name: String,
901        type_ids: Vec<i32>,
902    ) -> PsqlResult<()> {
903        if statement_name.is_empty() {
904            // Remove the unnamed prepare statement first, in case the unsupported sql binds a wrong
905            // prepare statement.
906            self.unnamed_prepare_statement.take();
907        } else if self.prepare_statement_store.contains_key(&statement_name) {
908            return Err(PsqlError::ExtendedPrepareError(
909                "Duplicated statement name".into(),
910            ));
911        }
912
913        let stmt = {
914            let stmts = Parser::parse_sql(&sql)
915                .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
916            drop(sql);
917            if stmts.len() > 1 {
918                return Err(PsqlError::ExtendedPrepareError(
919                    "Only one statement is allowed in extended query mode".into(),
920                ));
921            }
922
923            stmts.into_iter().next()
924        };
925
926        let param_types: Vec<Option<DataType>> = type_ids
927            .iter()
928            .map(|&id| {
929                // 0 means unspecified type
930                // ref: https://www.postgresql.org/docs/15/protocol-message-formats.html#:~:text=Placing%20a%20zero%20here%20is%20equivalent%20to%20leaving%20the%20type%20unspecified.
931                if id == 0 {
932                    Ok(None)
933                } else {
934                    DataType::from_oid(id)
935                        .map(Some)
936                        .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
937                }
938            })
939            .try_collect()?;
940
941        let prepare_statement = session
942            .parse(stmt, param_types)
943            .await
944            .map_err(PsqlError::ExtendedPrepareError)?;
945
946        if statement_name.is_empty() {
947            self.unnamed_prepare_statement.replace(prepare_statement);
948        } else {
949            self.prepare_statement_store
950                .insert(statement_name.clone(), prepare_statement);
951        }
952
953        self.statement_portal_dependency
954            .entry(statement_name)
955            .or_default()
956            .clear();
957
958        self.stream.write_no_flush(BeMessage::ParseComplete)?;
959        Ok(())
960    }
961
962    fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
963        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
964        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
965        let session = self.session.clone().unwrap();
966
967        if self.portal_store.contains_key(&portal_name) {
968            return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
969        }
970
971        let prepare_statement = self.get_statement(&statement_name)?;
972
973        let result_formats = msg
974            .result_format_codes
975            .iter()
976            .map(|&format_code| Format::from_i16(format_code))
977            .try_collect()?;
978        let param_formats = msg
979            .param_format_codes
980            .iter()
981            .map(|&format_code| Format::from_i16(format_code))
982            .try_collect()?;
983
984        let portal = session
985            .bind(prepare_statement, msg.params, param_formats, result_formats)
986            .map_err(PsqlError::Uncategorized)?;
987
988        if portal_name.is_empty() {
989            self.result_cache.remove(&portal_name);
990            self.unnamed_portal.replace(portal);
991        } else {
992            assert!(
993                !self.result_cache.contains_key(&portal_name),
994                "Named portal never can be overridden."
995            );
996            self.portal_store.insert(portal_name.clone(), portal);
997        }
998
999        self.statement_portal_dependency
1000            .get_mut(&statement_name)
1001            .unwrap()
1002            .push(portal_name);
1003
1004        self.stream.write_no_flush(BeMessage::BindComplete)?;
1005        Ok(())
1006    }
1007
1008    async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
1009        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1010        let row_max = msg.max_rows as usize;
1011        drop(msg);
1012        let session = self.session.clone().unwrap();
1013
1014        match self.result_cache.remove(&portal_name) {
1015            Some(mut result_cache) => {
1016                assert!(self.portal_store.contains_key(&portal_name));
1017
1018                let is_consume_completed =
1019                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
1020
1021                if !is_consume_completed {
1022                    self.result_cache.insert(portal_name, result_cache);
1023                }
1024            }
1025            _ => {
1026                let portal = self.get_portal(&portal_name)?;
1027                let sql = format!("{}", portal);
1028                let truncated_sql =
1029                    record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
1030                drop(sql);
1031
1032                session.check_idle_in_transaction_timeout()?;
1033                // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
1034                let _exec_context_guard = session.init_exec_context(truncated_sql.into());
1035                let result = session.clone().execute(portal).await;
1036
1037                let pg_response = result.map_err(PsqlError::ExtendedExecuteError)?;
1038                let mut result_cache = ResultCache::new(pg_response);
1039                let is_consume_completed =
1040                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
1041                if !is_consume_completed {
1042                    self.result_cache.insert(portal_name, result_cache);
1043                }
1044            }
1045        }
1046
1047        Ok(())
1048    }
1049
1050    fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1051        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1052        let session = self.session.clone().unwrap();
1053        //  b'S' => Statement
1054        //  b'P' => Portal
1055
1056        assert!(msg.kind == b'S' || msg.kind == b'P');
1057        if msg.kind == b'S' {
1058            let prepare_statement = self.get_statement(&name)?;
1059
1060            let (param_types, row_descriptions) = self
1061                .session
1062                .clone()
1063                .unwrap()
1064                .describe_statement(prepare_statement)
1065                .map_err(PsqlError::Uncategorized)?;
1066            self.stream.write_no_flush(BeMessage::ParameterDescription(
1067                &param_types.iter().map(|t| t.to_oid()).collect_vec(),
1068            ))?;
1069
1070            if row_descriptions.is_empty() {
1071                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1072                // return NoData message if the statement is not a query.
1073                self.stream.write_no_flush(BeMessage::NoData)?;
1074            } else {
1075                self.stream
1076                    .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1077            }
1078        } else if msg.kind == b'P' {
1079            let portal = self.get_portal(&name)?;
1080
1081            let row_descriptions = session
1082                .describe_portal(portal)
1083                .map_err(PsqlError::Uncategorized)?;
1084
1085            if row_descriptions.is_empty() {
1086                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1087                // return NoData message if the statement is not a query.
1088                self.stream.write_no_flush(BeMessage::NoData)?;
1089            } else {
1090                self.stream
1091                    .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1092            }
1093        }
1094        Ok(())
1095    }
1096
1097    fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1098        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1099        assert!(msg.kind == b'S' || msg.kind == b'P');
1100        if msg.kind == b'S' {
1101            if name.is_empty() {
1102                self.unnamed_prepare_statement = None;
1103            } else {
1104                self.prepare_statement_store.remove(&name);
1105            }
1106            for portal_name in self
1107                .statement_portal_dependency
1108                .remove(&name)
1109                .unwrap_or_default()
1110            {
1111                self.remove_portal(&portal_name);
1112            }
1113        } else if msg.kind == b'P' {
1114            self.remove_portal(&name);
1115        }
1116        self.stream.write_no_flush(BeMessage::CloseComplete)?;
1117        Ok(())
1118    }
1119
1120    fn remove_portal(&mut self, portal_name: &str) {
1121        if portal_name.is_empty() {
1122            self.unnamed_portal = None;
1123        } else {
1124            self.portal_store.remove(portal_name);
1125        }
1126        self.result_cache.remove(portal_name);
1127    }
1128
1129    fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1130        if portal_name.is_empty() {
1131            Ok(self
1132                .unnamed_portal
1133                .as_ref()
1134                .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1135                .clone())
1136        } else {
1137            Ok(self
1138                .portal_store
1139                .get(portal_name)
1140                .ok_or_else(|| {
1141                    PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1142                })?
1143                .clone())
1144        }
1145    }
1146
1147    fn get_statement(
1148        &self,
1149        statement_name: &str,
1150    ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1151        if statement_name.is_empty() {
1152            Ok(self
1153                .unnamed_prepare_statement
1154                .as_ref()
1155                .ok_or_else(|| {
1156                    PsqlError::Uncategorized("unnamed prepare statement not found".into())
1157                })?
1158                .clone())
1159        } else {
1160            Ok(self
1161                .prepare_statement_store
1162                .get(statement_name)
1163                .ok_or_else(|| {
1164                    PsqlError::Uncategorized(
1165                        format!("Prepare statement {} not found", statement_name).into(),
1166                    )
1167                })?
1168                .clone())
1169        }
1170    }
1171}
1172
/// The transport currently backing a `PgStream`: either plain or SSL-wrapped.
enum PgStreamInner<S> {
    /// Used for the intermediate state when converting from unencrypted to ssl stream.
    /// Every read/write path in this file treats this state as `unreachable!()`.
    /// NOTE(review): `upgrade_to_ssl` returns early on a failed handshake without restoring a
    /// real stream, so this state can persist afterwards — confirm the connection is always
    /// dropped on that error.
    Placeholder,
    /// An unencrypted stream.
    Unencrypted(S),
    /// An ssl stream.
    Ssl(SslStream<S>),
}
1181
1182/// Trait for a byte stream that can be used for pg protocol.
1183pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1184impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1185
/// Wraps a byte stream and read/write pg messages.
///
/// Cloning a `PgStream` will share the same stream but a fresh & independent write buffer,
/// so that it can be used to write messages concurrently without interference.
pub struct PgStream<S> {
    /// The underlying stream. It is shared between clones behind an async (`tokio`) mutex, which
    /// is locked for the duration of each read or flush.
    stream: Arc<Mutex<PgStreamInner<S>>>,
    /// Write into buffer before flush to stream.
    write_buf: BytesMut,
    /// Header of a frontend message whose body has not been consumed yet. It is set by
    /// `read_header` and taken by `read_body` / `skip_body`.
    read_header: Option<FeMessageHeader>,
}
1197
1198impl<S> PgStream<S> {
1199    /// Create a new `PgStream` with the given stream and default write buffer capacity.
1200    pub fn new(stream: S) -> Self {
1201        const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1202
1203        Self {
1204            stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1205            write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1206            read_header: None,
1207        }
1208    }
1209
1210    /// Check if the current connection is using SSL
1211    async fn is_ssl_connection(&self) -> bool {
1212        let stream = self.stream.lock().await;
1213        matches!(*stream, PgStreamInner::Ssl(_))
1214    }
1215}
1216
1217impl<S> Clone for PgStream<S> {
1218    fn clone(&self) -> Self {
1219        Self {
1220            stream: Arc::clone(&self.stream),
1221            write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1222            read_header: self.read_header.clone(),
1223        }
1224    }
1225}
1226
/// Parameters reported to the client through `ParameterStatus` messages.
///
/// The PostgreSQL 9.2 protocol documentation lists a hard-wired set of parameters for which
/// `ParameterStatus` is generated:
///
///  * `server_version`
///  * `server_encoding`
///  * `client_encoding`
///  * `application_name`
///  * `is_superuser`
///  * `session_authorization`
///  * `DateStyle`
///  * `IntervalStyle`
///  * `TimeZone`
///  * `integer_datetimes`
///  * `standard_conforming_strings`
///
/// Of these, only `application_name` is carried in this struct.
/// `PgStream::write_parameter_status_msg_no_flush` writes `client_encoding`,
/// `standard_conforming_strings` and `server_version` from constants.
///
/// See: <https://www.postgresql.org/docs/9.2/static/protocol-flow.html#PROTOCOL-ASYNC>.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// The `application_name` to report, if any.
    pub application_name: Option<String>,
}
1247
1248impl<S> PgStream<S>
1249where
1250    S: PgByteStream,
1251{
1252    async fn read_startup(&mut self) -> io::Result<FeMessage> {
1253        let mut stream = self.stream.lock().await;
1254        match &mut *stream {
1255            PgStreamInner::Placeholder => unreachable!(),
1256            PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1257            PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1258        }
1259    }
1260
1261    async fn read_header(&mut self) -> io::Result<()> {
1262        let mut stream = self.stream.lock().await;
1263        match &mut *stream {
1264            PgStreamInner::Placeholder => unreachable!(),
1265            PgStreamInner::Unencrypted(stream) => {
1266                self.read_header = Some(FeMessage::read_header(stream).await?);
1267                Ok(())
1268            }
1269            PgStreamInner::Ssl(ssl_stream) => {
1270                self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1271                Ok(())
1272            }
1273        }
1274    }
1275
1276    async fn read_body(&mut self) -> io::Result<FeMessage> {
1277        let mut stream = self.stream.lock().await;
1278        let header = self
1279            .read_header
1280            .take()
1281            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1282        match &mut *stream {
1283            PgStreamInner::Placeholder => unreachable!(),
1284            PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1285            PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1286        }
1287    }
1288
1289    async fn skip_body(&mut self) -> io::Result<()> {
1290        let mut stream = self.stream.lock().await;
1291        let header = self
1292            .read_header
1293            .take()
1294            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1295        match &mut *stream {
1296            PgStreamInner::Placeholder => unreachable!(),
1297            PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1298            PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1299        }
1300    }
1301
1302    fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1303        self.write_no_flush(BeMessage::ParameterStatus(
1304            BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1305        ))?;
1306        self.write_no_flush(BeMessage::ParameterStatus(
1307            BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1308        ))?;
1309        self.write_no_flush(BeMessage::ParameterStatus(
1310            BeParameterStatusMessage::ServerVersion(PG_VERSION),
1311        ))?;
1312        if let Some(application_name) = &status.application_name {
1313            self.write_no_flush(BeMessage::ParameterStatus(
1314                BeParameterStatusMessage::ApplicationName(application_name),
1315            ))?;
1316        }
1317        Ok(())
1318    }
1319
1320    pub fn write_no_flush(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1321        BeMessage::write(&mut self.write_buf, message)
1322    }
1323
1324    async fn write(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1325        self.write_no_flush(message)?;
1326        self.flush().await?;
1327        Ok(())
1328    }
1329
1330    async fn flush(&mut self) -> io::Result<()> {
1331        let mut stream = self.stream.lock().await;
1332        match &mut *stream {
1333            PgStreamInner::Placeholder => unreachable!(),
1334            PgStreamInner::Unencrypted(stream) => {
1335                stream.write_all(&self.write_buf).await?;
1336                stream.flush().await?;
1337            }
1338            PgStreamInner::Ssl(ssl_stream) => {
1339                ssl_stream.write_all(&self.write_buf).await?;
1340                ssl_stream.flush().await?;
1341            }
1342        }
1343        self.write_buf.clear();
1344        Ok(())
1345    }
1346}
1347
1348impl<S> PgStream<S>
1349where
1350    S: PgByteStream,
1351{
1352    /// Convert the underlying stream to ssl stream based on the given context.
1353    async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1354        let mut stream = self.stream.lock().await;
1355
1356        match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1357            PgStreamInner::Unencrypted(unencrypted_stream) => {
1358                let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1359                let mut ssl_stream =
1360                    tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1361
1362                if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1363                    tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1364                    let _ = ssl_stream.shutdown().await;
1365                    return Err(e.into());
1366                }
1367
1368                *stream = PgStreamInner::Ssl(ssl_stream);
1369            }
1370            PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1371            PgStreamInner::Placeholder => unreachable!(),
1372        }
1373
1374        Ok(())
1375    }
1376}
1377
1378fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1379    let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1380
1381    let key_path = &tls_config.key;
1382    let cert_path = &tls_config.cert;
1383
1384    // Build ssl acceptor according to the config.
1385    // Now we set every verify to true.
1386    acceptor
1387        .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1388        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1389    acceptor
1390        .set_ca_file(cert_path)
1391        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1392    acceptor
1393        .set_certificate_chain_file(cert_path)
1394        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1395    let acceptor = acceptor.build();
1396
1397    Ok(acceptor.into_context())
1398}
1399
pub mod truncated_fmt {
    use std::fmt::*;

    /// A `fmt::Write` adapter that forwards at most `remaining` bytes to the wrapped formatter.
    ///
    /// When a chunk would exceed the budget, two things are written: the longest prefix of that
    /// chunk that ends on a char boundary, then `...(truncated,N)`, where `N` is the byte length
    /// of that chunk (not of the whole formatted value). All later output is discarded.
    struct TruncatedFormatter<'a, 'b> {
        /// Bytes that may still be written before truncation kicks in.
        remaining: usize,
        /// Set once the truncation marker has been written.
        finished: bool,
        f: &'a mut Formatter<'b>,
    }

    /// Returns the largest index `<= index` that lies on a char boundary of `s`, clamped to
    /// `s.len()`.
    ///
    /// Equivalent to `str::floor_char_boundary`, which older stable toolchains lack.
    fn floor_char_boundary(s: &str, index: usize) -> usize {
        if index >= s.len() {
            s.len()
        } else {
            // Index 0 is always a char boundary, so the search always succeeds.
            (0..=index)
                .rev()
                .find(|&i| s.is_char_boundary(i))
                .unwrap_or(0)
        }
    }

    impl Write for TruncatedFormatter<'_, '_> {
        fn write_str(&mut self, s: &str) -> Result {
            if self.finished {
                return Ok(());
            }

            if s.len() <= self.remaining {
                self.f.write_str(s)?;
                self.remaining -= s.len();
                return Ok(());
            }

            let kept = floor_char_boundary(s, self.remaining);
            self.f.write_str(&s[..kept])?;
            self.remaining -= kept;
            // Format the marker straight into the formatter instead of allocating a `String`.
            write!(self.f, "...(truncated,{})", s.len())?;
            self.finished = true; // so that ...(truncated) is printed exactly once
            Ok(())
        }
    }

    /// Formats `self.0` (via `Debug` or `Display`), keeping at most `self.1` bytes of output
    /// followed by a `...(truncated,N)` marker when the limit is exceeded.
    pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);

    impl<T> Debug for TruncatedFmt<'_, T>
    where
        T: Debug,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            TruncatedFormatter {
                remaining: self.1,
                finished: false,
                f,
            }
            .write_fmt(format_args!("{:?}", self.0))
        }
    }

    impl<T> Display for TruncatedFmt<'_, T>
    where
        T: Display,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            TruncatedFormatter {
                remaining: self.1,
                finished: false,
                f,
            }
            .write_fmt(format_args!("{}", self.0))
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_trunc_utf8() {
            assert_eq!(
                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
                "select '...(truncated,14)",
            );
        }
    }
}
1471
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Values of the configured option keys (at any nesting level of the WITH / ENCODE options)
    /// must be replaced with `[REDACTED]`, while all other options are kept verbatim.
    #[test]
    fn test_redact_parsable_sql() {
        let redacted_keys = Arc::new(HashSet::from(["v2".into(), "v4".into(), "b".into()]));
        let input = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        let expected = "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])";
        let actual = redact_sql(input, redacted_keys);
        assert_eq!(actual, expected);
    }
}