pgwire/
pg_protocol.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::env_var::env_var_is_true;
33use risingwave_common::util::panic::FutureCatchUnwindExt;
34use risingwave_common::util::query_log::*;
35use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
36use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
37use risingwave_sqlparser::parser::Parser;
38use thiserror_ext::AsReport;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
40use tokio::sync::Mutex;
41use tokio_openssl::SslStream;
42use tracing::Instrument;
43
44use crate::error::{PsqlError, PsqlResult};
45use crate::error_or_notice::Severity;
46use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
47use crate::net::AddressRef;
48use crate::pg_extended::ResultCache;
49use crate::pg_message::{
50    BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
51    FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
52    FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
53};
54use crate::pg_server::{Session, SessionManager, UserAuthenticator};
55use crate::types::Format;
56
/// Length limit applied to SQL text before it is logged, to avoid the log file being too
/// large.
///
/// Read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable. Falls back to `65536`
/// when the variable is unset or does not parse as a `usize`.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        // Parse exactly once; an unparsable value is treated like an unset one.
        .and_then(|len| len.parse::<usize>().ok())
        .unwrap_or(65536)
});
64
tokio::task_local! {
    /// The current session. Concrete type is erased for different session implementations.
    ///
    /// Held as a `Weak` reference. `PgProtocol::do_process` installs it via
    /// `CURRENT_SESSION.scope` around message processing when the connection has a session.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
69
/// The state machine for each psql connection.
/// Read pg messages from tcp stream and write results back.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Used for write/read pg messages.
    stream: PgStream<S>,
    /// Current states of pg connection.
    state: PgProtocolState,
    /// Whether the connection is terminated.
    is_terminate: bool,

    /// Session manager used to open the session, cancel queries, and end the session on drop.
    session_mgr: Arc<SM>,
    /// The session bound to this connection; `None` until a startup message has been processed.
    session: Option<Arc<SM::Session>>,

    /// Result caches for the extended query protocol, keyed by name.
    /// NOTE(review): presumably keyed by portal name — confirm against the execute handler.
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    /// The unnamed prepared statement, if any.
    unnamed_prepare_statement:
        Option<PreparedStatementData<<SM::Session as Session>::PreparedStatement>>,
    /// Named prepared statements, keyed by statement name.
    prepare_statement_store:
        HashMap<String, PreparedStatementData<<SM::Session as Session>::PreparedStatement>>,
    /// The unnamed portal, if any.
    unnamed_portal: Option<PortalData<<SM::Session as Session>::Portal>>,
    /// Named portals, keyed by portal name.
    portal_store: HashMap<String, PortalData<<SM::Session as Session>::Portal>>,
    // Used to store the dependency of portal and prepare statement.
    // When we close a prepare statement, we need to close all the portals that depend on it.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    // Used for ssl connection.
    // If `None`, SSL requests are declined with `EncryptionResponseNo`.
    tls_context: Option<SslContext>,

    // TLS configuration including SSL enforcement setting
    tls_config: Option<TlsConfig>,

    // Used in the extended query protocol. When an error is encountered in an extended query,
    // the following messages are ignored until a sync message arrives.
    ignore_util_sync: bool,

    // Client Address
    peer_addr: AddressRef,

    /// Option keywords to redact from SQL before it is logged; `None` disables redaction.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts for the memory of incoming messages and decides when to throttle.
    message_memory_manager: MessageMemoryManagerRef,
}
114
/// Configures TLS encryption for connections.
///
/// `TlsConfig::new_default` builds this from the `RW_SSL_CERT`, `RW_SSL_KEY` and
/// `RW_SSL_ENFORCE` environment variables.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// The path to the TLS certificate.
    pub cert: String,
    /// The path to the TLS key.
    pub key: String,
    /// Whether to enforce SSL connections (reject non-SSL clients).
    pub enforce_ssl: bool,
}
125
126impl TlsConfig {
127    pub fn new_default() -> anyhow::Result<Option<Self>> {
128        let cert = std::env::var("RW_SSL_CERT").ok();
129        let key = std::env::var("RW_SSL_KEY").ok();
130        let enforce_ssl = env_var_is_true("RW_SSL_ENFORCE");
131
132        if cert.is_some() ^ key.is_some() {
133            return Err(anyhow::anyhow!(
134                "RW_SSL_CERT and RW_SSL_KEY must be set together"
135            ));
136        }
137
138        if enforce_ssl && cert.is_none() {
139            return Err(anyhow::anyhow!(
140                "RW_SSL_ENFORCE requires RW_SSL_CERT and RW_SSL_KEY to be set"
141            ));
142        }
143
144        let (Some(cert), Some(key)) = (cert, key) else {
145            return Ok(None);
146        };
147
148        tracing::info!(
149            "RW_SSL_CERT={}, RW_SSL_KEY={}, RW_SSL_ENFORCE={}",
150            cert,
151            key,
152            enforce_ssl
153        );
154        Ok(Some(Self {
155            cert,
156            key,
157            enforce_ssl,
158        }))
159    }
160}
161
162impl<S, SM> Drop for PgProtocol<S, SM>
163where
164    SM: SessionManager,
165{
166    fn drop(&mut self) {
167        if let Some(session) = &self.session {
168            // Clear the session in session manager.
169            self.session_mgr.end_session(session);
170        }
171    }
172}
173
/// Connection states. The flow goes from top to bottom.
enum PgProtocolState {
    /// Initial state: `read_message` reads startup-phase messages.
    Startup,
    /// Entered after a startup message is processed: `read_message` reads a header, then a body.
    Regular,
}
179
/// A prepared statement together with the SQL text it was created from.
#[derive(Clone)]
struct PreparedStatementData<S> {
    /// The session-specific prepared statement.
    statement: S,
    /// The SQL text associated with the statement.
    sql: Arc<str>,
}
185
/// A portal together with the SQL text associated with it.
#[derive(Clone)]
struct PortalData<P> {
    /// The session-specific portal.
    portal: P,
    /// The SQL text associated with the portal.
    sql: Arc<str>,
}
191
192/// Truncate 0 from C string in Bytes and stringify it (returns slice, no allocations).
193///
194/// PG protocol strings are always C strings.
195pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
196    let without_null = if b.last() == Some(&0) {
197        &b[..b.len() - 1]
198    } else {
199        &b[..]
200    };
201    std::str::from_utf8(without_null)
202}
203
204fn get_redacted_and_truncated_sql(
205    sql: &str,
206    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
207) -> String {
208    let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
209        && !keywords.is_empty()
210    {
211        redact_sql(sql, keywords)
212    } else {
213        sql.to_owned()
214    };
215    let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
216    truncated.to_string()
217}
218
219/// Record `sql` in the current tracing span.
220fn record_sql_in_span(
221    sql: &str,
222    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
223    span: &mut tracing::Span,
224) {
225    let redacted_and_truncated_sql =
226        get_redacted_and_truncated_sql(sql, redact_sql_option_keywords);
227    span.record("sql", tracing::field::display(&redacted_and_truncated_sql));
228}
229
/// Record `user` in the `user` field of the given tracing span.
fn record_user_in_span(user: &str, span: &mut tracing::Span) {
    span.record("user", tracing::field::display(user));
}
233
234/// Redacts sensitive SQL fields. Data in DML is not redacted.
235fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
236    match Parser::parse_sql(sql) {
237        Ok(sqls) => sqls
238            .into_iter()
239            .map(|sql| sql.to_redacted_string(keywords.clone()))
240            .join(";"),
241        Err(_) => sql.to_owned(),
242    }
243}
244
/// Per-connection settings handed to `PgProtocol::new`.
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS settings; `None` means no TLS is configured.
    pub tls_config: Option<TlsConfig>,
    /// Option keywords to redact from SQL before logging; `None` disables redaction.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Memory manager used to account for incoming messages.
    pub message_memory_manager: MessageMemoryManagerRef,
}
251
252impl<S, SM> PgProtocol<S, SM>
253where
254    S: PgByteStream,
255    SM: SessionManager,
256{
257    pub fn new(
258        stream: S,
259        session_mgr: Arc<SM>,
260        peer_addr: AddressRef,
261        context: ConnectionContext,
262    ) -> Self {
263        let ConnectionContext {
264            tls_config,
265            redact_sql_option_keywords,
266            message_memory_manager,
267        } = context;
268        Self {
269            stream: PgStream::new(stream),
270            is_terminate: false,
271            state: PgProtocolState::Startup,
272            session_mgr,
273            session: None,
274            tls_context: tls_config
275                .as_ref()
276                .and_then(|e| build_ssl_ctx_from_config(e).ok()),
277            tls_config,
278            result_cache: Default::default(),
279            unnamed_prepare_statement: Default::default(),
280            prepare_statement_store: Default::default(),
281            unnamed_portal: Default::default(),
282            portal_store: Default::default(),
283            statement_portal_dependency: Default::default(),
284            ignore_util_sync: false,
285            peer_addr,
286            redact_sql_option_keywords,
287            message_memory_manager,
288        }
289    }
290
291    /// Run the protocol to serve the connection.
292    pub async fn run(&mut self) {
293        let mut notice_fut = None;
294
295        loop {
296            // Once a session is present, create a future to subscribe and send notices asynchronously.
297            if notice_fut.is_none()
298                && let Some(session) = self.session.clone()
299            {
300                let mut stream = self.stream.clone();
301                notice_fut = Some(Box::pin(async move {
302                    loop {
303                        let notice = session.next_notice().await;
304                        if let Err(e) = stream.write(BeMessage::NoticeResponse(&notice)).await {
305                            tracing::error!(error = %e.as_report(), notice, "failed to send notice");
306                        }
307                    }
308                }));
309            }
310
311            // Read and process messages.
312            let process = std::pin::pin!(async {
313                let (msg, _memory_guard) = match self.read_message().await {
314                    Ok(msg) => msg,
315                    Err(e) => {
316                        tracing::error!(error = %e.as_report(), "error when reading message");
317                        return true; // terminate the connection
318                    }
319                };
320                tracing::trace!(?msg, "received message");
321                self.process(msg).await
322            });
323
324            let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
325                tokio::select! {
326                    _ = notice_fut => unreachable!(),
327                    terminated = process => terminated,
328                }
329            } else {
330                process.await
331            };
332
333            if terminated {
334                break;
335            }
336        }
337    }
338
339    /// Processes one message. Returns true if the connection is terminated.
340    pub async fn process(&mut self, msg: FeMessage) -> bool {
341        self.do_process(msg).await.is_none() || self.is_terminate
342    }
343
344    /// The root tracing span for processing a message. The target of the span is
345    /// [`PGWIRE_ROOT_SPAN_TARGET`].
346    ///
347    /// This is used to provide context for the (slow) query logs and traces.
348    ///
349    /// The span is only effective if there's a current session and the message is
350    /// query-related. Otherwise, `Span::none()` is returned.
351    fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
352        let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
353            return tracing::Span::none();
354        };
355
356        let mode = match msg {
357            FeMessage::Query(_) => "simple query",
358            FeMessage::Parse(_) => "extended query parse",
359            FeMessage::Execute(_) => "extended query execute",
360            _ => return tracing::Span::none(),
361        };
362
363        let mut span = tracing::info_span!(
364            target: PGWIRE_ROOT_SPAN_TARGET,
365            "handle_query",
366            mode,
367            session_id,
368            sql = tracing::field::Empty,
369            user = tracing::field::Empty,
370        );
371        match msg {
372            FeMessage::Execute(execute_msg) => {
373                if let Ok(portal_name) = cstr_to_str(&execute_msg.portal_name)
374                    && let Ok(sql) = self.get_portal_sql(portal_name)
375                {
376                    record_sql_in_span(&sql, self.redact_sql_option_keywords.clone(), &mut span);
377                }
378            }
379            _ => {
380                if let Ok(sql) = msg.get_sql()
381                    && let Some(sql) = sql
382                {
383                    record_sql_in_span(sql, self.redact_sql_option_keywords.clone(), &mut span);
384                }
385            }
386        }
387        if let Some(current_session) = self.session.as_ref() {
388            record_user_in_span(&current_session.user(), &mut span);
389        }
390        span
391    }
392
    /// Processes one message with all cross-cutting concerns applied, then reports errors to
    /// the client.
    ///
    /// The inner processing future is wrapped, from innermost to outermost, with: the
    /// `CURRENT_SESSION` task-local scope, panic catching, periodic slow-query logging,
    /// query logging, and the root tracing span.
    ///
    /// Return type `Option<()>` is essentially a bool, but allows `?` for early return.
    /// - `None` means to terminate the current connection
    /// - `Some(())` means to continue processing the next message
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        let span = self.root_span_for_msg(&msg);
        // Downgrade so that the task-local does not keep the session alive.
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        // Processing the message itself.
        //
        // Note: pin the future to avoid stack overflow as we'll wrap it multiple times
        // in the following code.
        let fut = Box::pin(self.do_process_inner(msg));

        // Set the current session as the context when processing the message, if exists.
        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        // Catch unwind. A panic is converted into `PsqlError::Panic` carrying the panic message.
        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        // Slow query log.
        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            let mut elapsed = Duration::ZERO;

            // Report the SQL in the log periodically if the query is slow.
            loop {
                // The timeout only bounds this wait; `fut` is polled by reference and keeps
                // running across iterations.
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        // Query log.
        let fut = async move {
            // Only log if we're currently in a tracing span set in `root_span_for_msg`.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = "started",
                );
            }

            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            // Always log if an error occurs.
            // Note: all messages will be processed through this code path, making it the
            //       only necessary place to log errors.
            if let Err(error) = &result {
                if cfg!(debug_assertions) && !Deployment::current().is_ci() {
                    // For local debugging, we print the error with backtrace.
                    // It's useful only when:
                    // - no additional context is added to the error
                    // - backtrace is captured in the error
                    // - backtrace is not printed in the middle
                    tracing::error!(error = ?error.as_report(), "error when process message");
                } else {
                    tracing::error!(error = %error.as_report(), "error when process message");
                }
            }

            // Log to optionally-enabled target `PGWIRE_QUERY_LOG`.
            // Only log if we're currently in a tracing span set in `root_span_for_msg`.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        // Tracing span.
        let fut = fut.instrument(span);

        // Execute the future and handle the error.
        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    PsqlError::IoError(io_err) => {
                        // An unexpected EOF closes the connection. Other I/O errors fall
                        // through to the flush below and the connection continues.
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                    }

                    PsqlError::SslError(_) => {
                        // For ssl error, because the stream has already been consumed, so there is
                        // no way to write more message.
                        return None;
                    }

                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                // At this time we're not in a session, use compact error message for
                                // better alignment with Postgres' UI.
                                pretty: false,
                                severity: Some(Severity::Fatal),
                            })
                            .ok()?;
                        let _ = self.stream.flush().await;
                        return None;
                    }

                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        // Report the error, then tell the client we are ready for the next query.
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                                severity: None,
                            })
                            .ok()?;
                        self.ready_for_query().ok()?;
                    }

                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                                severity: None,
                            })
                            .ok()?;
                        let _ = self.stream.flush().await;

                        // 1. Catching the panic during message processing may leave the session in an
                        // inconsistent state. We forcefully close the connection (then end the
                        // session) here for safety.
                        // 2. Idle in transaction timeout should also close the connection.
                        return None;
                    }

                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        // Only the error is written here; no `ReadyForQuery` is sent in this arm.
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                                severity: None,
                            })
                            .ok()?;
                    }
                }
                // Best-effort flush of whatever was buffered above; the connection stays open.
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
573
574    async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
575        // Ignore util sync message.
576        if self.ignore_util_sync {
577            if let FeMessage::Sync = msg {
578            } else {
579                tracing::trace!("ignore message {:?} until sync.", msg);
580                return Ok(());
581            }
582        }
583
584        match msg {
585            FeMessage::Gss => self.process_gss_msg().await?,
586            FeMessage::Ssl => self.process_ssl_msg().await?,
587            FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
588            FeMessage::Password(msg) => self.process_password_msg(msg).await?,
589            FeMessage::Query(query_msg) => {
590                let sql = Arc::from(query_msg.get_sql()?);
591                // The process_query_msg can be slow. Release potential large FeQueryMessage early.
592                drop(query_msg);
593                self.process_query_msg(sql).await?
594            }
595            FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
596            FeMessage::Terminate => self.process_terminate(),
597            FeMessage::Parse(m) => {
598                if let Err(err) = self.process_parse_msg(m).await {
599                    self.ignore_util_sync = true;
600                    return Err(err);
601                }
602            }
603            FeMessage::Bind(m) => {
604                if let Err(err) = self.process_bind_msg(m) {
605                    self.ignore_util_sync = true;
606                    return Err(err);
607                }
608            }
609            FeMessage::Execute(m) => {
610                if let Err(err) = self.process_execute_msg(m).await {
611                    self.ignore_util_sync = true;
612                    return Err(err);
613                }
614            }
615            FeMessage::Describe(m) => {
616                if let Err(err) = self.process_describe_msg(m) {
617                    self.ignore_util_sync = true;
618                    return Err(err);
619                }
620            }
621            FeMessage::Sync => {
622                self.ignore_util_sync = false;
623                self.ready_for_query()?
624            }
625            FeMessage::Close(m) => {
626                if let Err(err) = self.process_close_msg(m) {
627                    self.ignore_util_sync = true;
628                    return Err(err);
629                }
630            }
631            FeMessage::Flush => {
632                if let Err(err) = self.stream.flush().await {
633                    self.ignore_util_sync = true;
634                    return Err(err.into());
635                }
636            }
637            FeMessage::HealthCheck => self.process_health_check(),
638            FeMessage::ServerThrottle(reason) => match reason {
639                ServerThrottleReason::TooLargeMessage => {
640                    return Err(PsqlError::ServerThrottle(format!(
641                        "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
642                        self.message_memory_manager.max_filter_bytes
643                    )));
644                }
645                ServerThrottleReason::TooManyMemoryUsage => {
646                    return Err(PsqlError::ServerThrottle(format!(
647                        "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
648                        self.message_memory_manager.max_running_bytes
649                    )));
650                }
651            },
652        }
653        self.stream.flush().await?;
654        Ok(())
655    }
656
657    pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
658        match self.state {
659            PgProtocolState::Startup => self
660                .stream
661                .read_startup()
662                .await
663                .map(|message: FeMessage| (message, None)),
664            PgProtocolState::Regular => {
665                self.stream.read_header().await?;
666                let guard = if let Some(ref header) = self.stream.read_header {
667                    let payload_len = std::cmp::max(header.payload_len, 0) as u64;
668                    let (reason, guard) = self.message_memory_manager.add(payload_len);
669                    if let Some(reason) = reason {
670                        // Release the memory ASAP.
671                        drop(guard);
672                        self.stream.skip_body().await?;
673                        return Ok((FeMessage::ServerThrottle(reason), None));
674                    }
675                    guard
676                } else {
677                    None
678                };
679                let message = self.stream.read_body().await?;
680                Ok((message, guard))
681            }
682        }
683    }
684
685    /// Writes a `ReadyForQuery` message to the client without flushing.
686    fn ready_for_query(&mut self) -> io::Result<()> {
687        self.stream.write_no_flush(BeMessage::ReadyForQuery(
688            self.session
689                .as_ref()
690                .map(|s| s.transaction_status())
691                .unwrap_or(TransactionStatus::Idle),
692        ))
693    }
694
695    async fn process_gss_msg(&mut self) -> PsqlResult<()> {
696        // We don't support GSSAPI, so we just say no gracefully.
697        self.stream.write(BeMessage::EncryptionResponseNo).await?;
698        Ok(())
699    }
700
701    async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
702        if let Some(context) = self.tls_context.as_ref() {
703            // If got and ssl context, say yes for ssl connection.
704            // Construct ssl stream and replace with current one.
705            self.stream.write(BeMessage::EncryptionResponseSsl).await?;
706            self.stream.upgrade_to_ssl(context).await?;
707        } else {
708            // If no, say no for encryption.
709            self.stream.write(BeMessage::EncryptionResponseNo).await?;
710        }
711
712        Ok(())
713    }
714
715    async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
716        // Check SSL enforcement: if SSL is enforced but connection is not using SSL, reject
717        if let Some(ref tls_config) = self.tls_config
718            && tls_config.enforce_ssl
719            && !self.stream.is_ssl_connection().await
720        {
721            return Err(PsqlError::StartupError(
722                "SSL connection is required but not established".into(),
723            ));
724        }
725
726        let db_name = msg
727            .config
728            .get("database")
729            .cloned()
730            .unwrap_or_else(|| "dev".to_owned());
731        let user_name = msg
732            .config
733            .get("user")
734            .cloned()
735            .unwrap_or_else(|| "root".to_owned());
736
737        let session = self
738            .session_mgr
739            .connect(&db_name, &user_name, self.peer_addr.clone())
740            .map_err(|e| PsqlError::StartupError(e.into()))?;
741
742        if let Some(options) = msg.config.get("options") {
743            for (key, value) in parse_options(options)? {
744                session
745                    .set_config(&key, value)
746                    .map_err(|e| PsqlError::StartupError(e.into()))?;
747            }
748        }
749        // dedicated `application_name` has higher priority than `options`
750        let application_name = msg.config.get("application_name");
751        if let Some(application_name) = application_name {
752            session
753                .set_config("application_name", application_name.clone())
754                .map_err(|e| PsqlError::StartupError(e.into()))?;
755        }
756
757        match session.user_authenticator() {
758            UserAuthenticator::None => {
759                self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
760
761                // Cancel request need this for identify and verification. According to postgres
762                // doc, it should be written to buffer after receive AuthenticationOk.
763                self.stream
764                    .write_no_flush(BeMessage::BackendKeyData(session.id()))?;
765
766                self.stream.write_no_flush(BeMessage::ParameterStatus(
767                    BeParameterStatusMessage::TimeZone(
768                        &session
769                            .get_config("timezone")
770                            .map_err(|e| PsqlError::StartupError(e.into()))?,
771                    ),
772                ))?;
773                self.stream
774                    .write_parameter_status_msg_no_flush(&ParameterStatus {
775                        application_name: application_name.cloned(),
776                    })?;
777                self.ready_for_query()?;
778            }
779            UserAuthenticator::ClearText(_)
780            | UserAuthenticator::OAuth { .. }
781            | UserAuthenticator::Ldap(..) => {
782                self.stream
783                    .write_no_flush(BeMessage::AuthenticationCleartextPassword)?;
784            }
785            UserAuthenticator::Md5WithSalt { salt, .. } => {
786                self.stream
787                    .write_no_flush(BeMessage::AuthenticationMd5Password(salt))?;
788            }
789        }
790
791        self.session = Some(session);
792        self.state = PgProtocolState::Regular;
793        Ok(())
794    }
795
796    async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
797        let session = self.session.as_ref().unwrap();
798        let authenticator = session.user_authenticator();
799        authenticator.authenticate(&msg.password).await?;
800        self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
801        let timezone = session
802            .get_config("timezone")
803            .map_err(|e| PsqlError::StartupError(e.into()))?;
804        self.stream.write_no_flush(BeMessage::ParameterStatus(
805            BeParameterStatusMessage::TimeZone(&timezone),
806        ))?;
807        self.stream
808            .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
809        self.ready_for_query()?;
810        self.state = PgProtocolState::Regular;
811        Ok(())
812    }
813
814    fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
815        let session_id = (m.target_process_id, m.target_secret_key);
816        tracing::trace!("cancel query in session: {:?}", session_id);
817        self.session_mgr.cancel_queries_in_session(session_id);
818        self.session_mgr.cancel_creating_jobs_in_session(session_id);
819        self.is_terminate = true;
820        Ok(())
821    }
822
823    async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
824        let truncated_sql =
825            get_redacted_and_truncated_sql(&sql, self.redact_sql_option_keywords.clone());
826        let session = self.session.clone().unwrap();
827
828        session.check_idle_in_transaction_timeout()?;
829        // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
830        let _exec_context_guard = session.init_exec_context(truncated_sql.into());
831        self.inner_process_query_msg(sql, session.clone()).await
832    }
833
834    async fn inner_process_query_msg(
835        &mut self,
836        sql: Arc<str>,
837        session: Arc<SM::Session>,
838    ) -> PsqlResult<()> {
839        // Parse sql.
840        let stmts =
841            Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
842        // The following inner_process_query_msg_one_stmt can be slow. Release potential large String early.
843        drop(sql);
844        if stmts.is_empty() {
845            self.stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
846        }
847
848        // Execute multiple statements in simple query. KISS later.
849        for stmt in stmts {
850            self.inner_process_query_msg_one_stmt(stmt, session.clone())
851                .await?;
852        }
853        // Put this line inside the for loop above will lead to unfinished/stuck regress test...Not
854        // sure the reason.
855        self.ready_for_query()?;
856        Ok(())
857    }
858
859    async fn inner_process_query_msg_one_stmt(
860        &mut self,
861        stmt: Statement,
862        session: Arc<SM::Session>,
863    ) -> PsqlResult<()> {
864        let session = session.clone();
865
866        // execute query
867        let res = session.clone().run_one_query(stmt, Format::Text).await;
868
869        // Take all remaining notices (if any) and send them before `CommandComplete`.
870        while let Some(notice) = session.next_notice().now_or_never() {
871            self.stream
872                .write_no_flush(BeMessage::NoticeResponse(&notice))?;
873        }
874
875        let mut res = res.map_err(|e| PsqlError::SimpleQueryError(e.into()))?;
876
877        for notice in res.notices() {
878            self.stream
879                .write_no_flush(BeMessage::NoticeResponse(notice))?;
880        }
881
882        let status = res.status();
883        if let Some(ref application_name) = status.application_name {
884            self.stream.write_no_flush(BeMessage::ParameterStatus(
885                BeParameterStatusMessage::ApplicationName(application_name),
886            ))?;
887        }
888
889        if res.is_copy_query_to_stdout() {
890            self.stream
891                .write_no_flush(BeMessage::CopyOutResponse(res.row_desc().len()))?;
892            let mut count = 0;
893            while let Some(row_set) = res.values_stream().next().await {
894                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
895                for row in row_set {
896                    self.stream.write_no_flush(BeMessage::CopyData(&row))?;
897                    count += 1;
898                }
899            }
900
901            self.stream.write_no_flush(BeMessage::CopyDone)?;
902
903            // Run the callback before sending the `CommandComplete` message.
904            res.run_callback().await?;
905
906            self.stream
907                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
908                    stmt_type: res.stmt_type(),
909                    rows_cnt: count,
910                }))?;
911        } else if res.is_query() {
912            self.stream
913                .write_no_flush(BeMessage::RowDescription(res.row_desc()))?;
914
915            let mut rows_cnt = 0;
916
917            while let Some(row_set) = res.values_stream().next().await {
918                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
919                for row in row_set {
920                    self.stream.write_no_flush(BeMessage::DataRow(&row))?;
921                    rows_cnt += 1;
922                }
923            }
924
925            // Run the callback before sending the `CommandComplete` message.
926            res.run_callback().await?;
927
928            self.stream
929                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
930                    stmt_type: res.stmt_type(),
931                    rows_cnt,
932                }))?;
933        } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
934            let first_row_set = res.values_stream().next().await;
935            let first_row_set = match first_row_set {
936                None => {
937                    return Err(PsqlError::Uncategorized(
938                        anyhow::anyhow!("no affected rows in output").into(),
939                    ));
940                }
941                Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
942            };
943            let affected_rows_str = first_row_set[0].values()[0]
944                .as_ref()
945                .expect("compute node should return affected rows in output");
946
947            assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
948            let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
949                .unwrap()
950                .parse()
951                .unwrap_or_default();
952
953            // Run the callback before sending the `CommandComplete` message.
954            res.run_callback().await?;
955
956            self.stream
957                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
958                    stmt_type: res.stmt_type(),
959                    rows_cnt: affected_rows_cnt,
960                }))?;
961        } else {
962            // Run the callback before sending the `CommandComplete` message.
963            res.run_callback().await?;
964
965            self.stream
966                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
967                    stmt_type: res.stmt_type(),
968                    rows_cnt: 0,
969                }))?;
970        }
971
972        Ok(())
973    }
974
    /// Marks the connection for termination by setting `is_terminate`.
    fn process_terminate(&mut self) {
        self.is_terminate = true;
    }
978
    /// Handles a health-check probe: logs it at debug level and marks the connection for
    /// termination, exactly like a terminate request.
    fn process_health_check(&mut self) {
        tracing::debug!("health check");
        self.is_terminate = true;
    }
983
984    async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
985        let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
986        let session = self.session.clone().unwrap();
987        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
988        let type_ids = std::mem::take(&mut msg.type_ids);
989        // The inner_process_parse_msg can be slow. Release potential large FeParseMessage early.
990        drop(msg);
991        self.inner_process_parse_msg(session, sql, statement_name, type_ids)
992            .await?;
993        Ok(())
994    }
995
996    async fn inner_process_parse_msg(
997        &mut self,
998        session: Arc<SM::Session>,
999        sql: Arc<str>,
1000        statement_name: String,
1001        type_ids: Vec<i32>,
1002    ) -> PsqlResult<()> {
1003        if statement_name.is_empty() {
1004            // Remove the unnamed prepare statement first, in case the unsupported sql binds a wrong
1005            // prepare statement.
1006            self.unnamed_prepare_statement.take();
1007        } else if self.prepare_statement_store.contains_key(&statement_name) {
1008            return Err(PsqlError::ExtendedPrepareError(
1009                "Duplicated statement name".into(),
1010            ));
1011        }
1012
1013        let stmt = {
1014            let stmts = Parser::parse_sql(&sql)
1015                .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
1016            if stmts.len() > 1 {
1017                return Err(PsqlError::ExtendedPrepareError(
1018                    "Only one statement is allowed in extended query mode".into(),
1019                ));
1020            }
1021
1022            stmts.into_iter().next()
1023        };
1024
1025        let param_types: Vec<Option<DataType>> = type_ids
1026            .iter()
1027            .map(|&id| {
1028                // 0 means unspecified type
1029                // ref: https://www.postgresql.org/docs/15/protocol-message-formats.html#:~:text=Placing%20a%20zero%20here%20is%20equivalent%20to%20leaving%20the%20type%20unspecified.
1030                if id == 0 {
1031                    Ok(None)
1032                } else {
1033                    DataType::from_oid(id)
1034                        .map(Some)
1035                        .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
1036                }
1037            })
1038            .try_collect()?;
1039
1040        let prepare_statement = session
1041            .parse(stmt, param_types)
1042            .await
1043            .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))?;
1044        let prepare_statement = PreparedStatementData {
1045            statement: prepare_statement,
1046            sql,
1047        };
1048
1049        if statement_name.is_empty() {
1050            self.unnamed_prepare_statement.replace(prepare_statement);
1051        } else {
1052            self.prepare_statement_store
1053                .insert(statement_name.clone(), prepare_statement);
1054        }
1055
1056        self.statement_portal_dependency
1057            .entry(statement_name)
1058            .or_default()
1059            .clear();
1060
1061        self.stream.write_no_flush(BeMessage::ParseComplete)?;
1062        Ok(())
1063    }
1064
1065    fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
1066        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
1067        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1068        let session = self.session.clone().unwrap();
1069
1070        if self.portal_store.contains_key(&portal_name) {
1071            return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
1072        }
1073
1074        let prepare_statement = self.get_statement_data(&statement_name)?.clone();
1075
1076        let result_formats = msg
1077            .result_format_codes
1078            .iter()
1079            .map(|&format_code| Format::from_i16(format_code))
1080            .try_collect()?;
1081        let param_formats = msg
1082            .param_format_codes
1083            .iter()
1084            .map(|&format_code| Format::from_i16(format_code))
1085            .try_collect()?;
1086
1087        let portal = session
1088            .bind(
1089                prepare_statement.statement,
1090                msg.params,
1091                param_formats,
1092                result_formats,
1093            )
1094            .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1095        let portal = PortalData {
1096            portal,
1097            sql: prepare_statement.sql,
1098        };
1099
1100        if portal_name.is_empty() {
1101            self.result_cache.remove(&portal_name);
1102            self.unnamed_portal.replace(portal);
1103        } else {
1104            assert!(
1105                !self.result_cache.contains_key(&portal_name),
1106                "Named portal never can be overridden."
1107            );
1108            self.portal_store.insert(portal_name.clone(), portal);
1109        }
1110
1111        self.statement_portal_dependency
1112            .get_mut(&statement_name)
1113            .unwrap()
1114            .push(portal_name);
1115
1116        self.stream.write_no_flush(BeMessage::BindComplete)?;
1117        Ok(())
1118    }
1119
1120    async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
1121        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1122        let row_max = msg.max_rows as usize;
1123        drop(msg);
1124        let session = self.session.clone().unwrap();
1125
1126        match self.result_cache.remove(&portal_name) {
1127            Some(mut result_cache) => {
1128                assert!(self.portal_store.contains_key(&portal_name));
1129
1130                let is_consume_completed =
1131                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
1132
1133                if !is_consume_completed {
1134                    self.result_cache.insert(portal_name, result_cache);
1135                }
1136            }
1137            _ => {
1138                let portal = self.get_portal_data(&portal_name)?.clone();
1139                let sql = format!("{}", portal.portal);
1140                let truncated_sql =
1141                    get_redacted_and_truncated_sql(&sql, self.redact_sql_option_keywords.clone());
1142                drop(sql);
1143
1144                session.check_idle_in_transaction_timeout()?;
1145                // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
1146                let _exec_context_guard = session.init_exec_context(truncated_sql.into());
1147                let result = session.clone().execute(portal.portal).await;
1148
1149                let pg_response = result.map_err(|e| PsqlError::ExtendedExecuteError(e.into()))?;
1150                let mut result_cache = ResultCache::new(pg_response);
1151                let is_consume_completed =
1152                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
1153                if !is_consume_completed {
1154                    self.result_cache.insert(portal_name, result_cache);
1155                }
1156            }
1157        }
1158
1159        Ok(())
1160    }
1161
1162    fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1163        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1164        let session = self.session.clone().unwrap();
1165        //  b'S' => Statement
1166        //  b'P' => Portal
1167
1168        assert!(msg.kind == b'S' || msg.kind == b'P');
1169        if msg.kind == b'S' {
1170            let prepare_statement = self.get_statement(&name)?;
1171
1172            let (param_types, row_descriptions) = self
1173                .session
1174                .clone()
1175                .unwrap()
1176                .describe_statement(prepare_statement)
1177                .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1178            self.stream.write_no_flush(BeMessage::ParameterDescription(
1179                &param_types.iter().map(|t| t.to_oid()).collect_vec(),
1180            ))?;
1181
1182            if row_descriptions.is_empty() {
1183                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1184                // return NoData message if the statement is not a query.
1185                self.stream.write_no_flush(BeMessage::NoData)?;
1186            } else {
1187                self.stream
1188                    .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1189            }
1190        } else if msg.kind == b'P' {
1191            let portal = self.get_portal(&name)?;
1192
1193            let row_descriptions = session
1194                .describe_portal(portal)
1195                .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1196
1197            if row_descriptions.is_empty() {
1198                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1199                // return NoData message if the statement is not a query.
1200                self.stream.write_no_flush(BeMessage::NoData)?;
1201            } else {
1202                self.stream
1203                    .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1204            }
1205        }
1206        Ok(())
1207    }
1208
1209    fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1210        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1211        assert!(msg.kind == b'S' || msg.kind == b'P');
1212        if msg.kind == b'S' {
1213            if name.is_empty() {
1214                self.unnamed_prepare_statement = None;
1215            } else {
1216                self.prepare_statement_store.remove(&name);
1217            }
1218            for portal_name in self
1219                .statement_portal_dependency
1220                .remove(&name)
1221                .unwrap_or_default()
1222            {
1223                self.remove_portal(&portal_name);
1224            }
1225        } else if msg.kind == b'P' {
1226            self.remove_portal(&name);
1227        }
1228        self.stream.write_no_flush(BeMessage::CloseComplete)?;
1229        Ok(())
1230    }
1231
1232    fn remove_portal(&mut self, portal_name: &str) {
1233        if portal_name.is_empty() {
1234            self.unnamed_portal = None;
1235        } else {
1236            self.portal_store.remove(portal_name);
1237        }
1238        self.result_cache.remove(portal_name);
1239    }
1240
1241    fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1242        Ok(self.get_portal_data(portal_name)?.portal.clone())
1243    }
1244
1245    fn get_portal_data(
1246        &self,
1247        portal_name: &str,
1248    ) -> PsqlResult<&PortalData<<SM::Session as Session>::Portal>> {
1249        if portal_name.is_empty() {
1250            self.unnamed_portal
1251                .as_ref()
1252                .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))
1253        } else {
1254            self.portal_store.get(portal_name).ok_or_else(|| {
1255                PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1256            })
1257        }
1258    }
1259
1260    fn get_statement(
1261        &self,
1262        statement_name: &str,
1263    ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1264        Ok(self.get_statement_data(statement_name)?.statement.clone())
1265    }
1266
1267    fn get_statement_data(
1268        &self,
1269        statement_name: &str,
1270    ) -> PsqlResult<&PreparedStatementData<<SM::Session as Session>::PreparedStatement>> {
1271        if statement_name.is_empty() {
1272            self.unnamed_prepare_statement.as_ref().ok_or_else(|| {
1273                PsqlError::Uncategorized("unnamed prepare statement not found".into())
1274            })
1275        } else {
1276            self.prepare_statement_store
1277                .get(statement_name)
1278                .ok_or_else(|| {
1279                    PsqlError::Uncategorized(
1280                        format!("Prepare statement {} not found", statement_name).into(),
1281                    )
1282                })
1283        }
1284    }
1285
1286    fn get_portal_sql(&self, portal_name: &str) -> PsqlResult<Arc<str>> {
1287        Ok(self.get_portal_data(portal_name)?.sql.clone())
1288    }
1289}
1290
/// The transport currently backing a `PgStream`.
///
/// A stream starts out as `Unencrypted` and may be replaced by `Ssl` when the connection is
/// upgraded.
enum PgStreamInner<S> {
    /// Used for the intermediate state when converting from unencrypted to ssl stream.
    /// Reading from or writing to a stream in this state is treated as unreachable.
    Placeholder,
    /// An unencrypted stream.
    Unencrypted(S),
    /// An ssl stream.
    Ssl(SslStream<S>),
}
1299
/// Trait for a byte stream that can be used for pg protocol.
///
/// This is a pure alias for its supertraits and has no methods of its own.
pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
// Blanket implementation: every type satisfying the bounds is a `PgByteStream`.
impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1303
/// Wraps a byte stream and read/write pg messages.
///
/// Cloning a `PgStream` will share the same stream but a fresh & independent write buffer,
/// so that it can be used to write messages concurrently without interference.
pub struct PgStream<S> {
    /// The underlying stream, shared between all clones and guarded by an async mutex.
    stream: Arc<Mutex<PgStreamInner<S>>>,
    /// Write into buffer before flush to stream.
    write_buf: BytesMut,
    /// Header of the message currently being read: stored by `read_header` and taken again
    /// by `read_body` or `skip_body`.
    read_header: Option<FeMessageHeader>,
}
1315
1316impl<S> PgStream<S> {
1317    /// Create a new `PgStream` with the given stream and default write buffer capacity.
1318    pub fn new(stream: S) -> Self {
1319        const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1320
1321        Self {
1322            stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1323            write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1324            read_header: None,
1325        }
1326    }
1327
1328    /// Check if the current connection is using SSL
1329    async fn is_ssl_connection(&self) -> bool {
1330        let stream = self.stream.lock().await;
1331        matches!(*stream, PgStreamInner::Ssl(_))
1332    }
1333}
1334
1335impl<S> Clone for PgStream<S> {
1336    fn clone(&self) -> Self {
1337        Self {
1338            stream: Arc::clone(&self.stream),
1339            write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1340            read_header: self.read_header.clone(),
1341        }
1342    }
1343}
1344
/// Parameter values to report to the client through `ParameterStatus` messages.
///
/// At present PostgreSQL generates `ParameterStatus` for a hard-wired set of parameters:
///
///  * `server_version`
///  * `server_encoding`
///  * `client_encoding`
///  * `application_name`
///  * `is_superuser`
///  * `session_authorization`
///  * `DateStyle`
///  * `IntervalStyle`
///  * `TimeZone`
///  * `integer_datetimes`
///  * `standard_conforming_strings`
///
/// See: <https://www.postgresql.org/docs/9.2/static/protocol-flow.html#PROTOCOL-ASYNC>.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// Value of `application_name` to report; nothing is reported for it when `None`.
    pub application_name: Option<String>,
}
1365
1366impl<S> PgStream<S>
1367where
1368    S: PgByteStream,
1369{
1370    async fn read_startup(&mut self) -> io::Result<FeMessage> {
1371        let mut stream = self.stream.lock().await;
1372        match &mut *stream {
1373            PgStreamInner::Placeholder => unreachable!(),
1374            PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1375            PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1376        }
1377    }
1378
1379    async fn read_header(&mut self) -> io::Result<()> {
1380        let mut stream = self.stream.lock().await;
1381        match &mut *stream {
1382            PgStreamInner::Placeholder => unreachable!(),
1383            PgStreamInner::Unencrypted(stream) => {
1384                self.read_header = Some(FeMessage::read_header(stream).await?);
1385                Ok(())
1386            }
1387            PgStreamInner::Ssl(ssl_stream) => {
1388                self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1389                Ok(())
1390            }
1391        }
1392    }
1393
1394    async fn read_body(&mut self) -> io::Result<FeMessage> {
1395        let mut stream = self.stream.lock().await;
1396        let header = self
1397            .read_header
1398            .take()
1399            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1400        match &mut *stream {
1401            PgStreamInner::Placeholder => unreachable!(),
1402            PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1403            PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1404        }
1405    }
1406
1407    async fn skip_body(&mut self) -> io::Result<()> {
1408        let mut stream = self.stream.lock().await;
1409        let header = self
1410            .read_header
1411            .take()
1412            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1413        match &mut *stream {
1414            PgStreamInner::Placeholder => unreachable!(),
1415            PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1416            PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1417        }
1418    }
1419
1420    fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1421        self.write_no_flush(BeMessage::ParameterStatus(
1422            BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1423        ))?;
1424        self.write_no_flush(BeMessage::ParameterStatus(
1425            BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1426        ))?;
1427        self.write_no_flush(BeMessage::ParameterStatus(
1428            BeParameterStatusMessage::ServerVersion(PG_VERSION),
1429        ))?;
1430        if let Some(application_name) = &status.application_name {
1431            self.write_no_flush(BeMessage::ParameterStatus(
1432                BeParameterStatusMessage::ApplicationName(application_name),
1433            ))?;
1434        }
1435        Ok(())
1436    }
1437
1438    pub fn write_no_flush(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1439        BeMessage::write(&mut self.write_buf, message)
1440    }
1441
1442    async fn write(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1443        self.write_no_flush(message)?;
1444        self.flush().await?;
1445        Ok(())
1446    }
1447
1448    async fn flush(&mut self) -> io::Result<()> {
1449        let mut stream = self.stream.lock().await;
1450        match &mut *stream {
1451            PgStreamInner::Placeholder => unreachable!(),
1452            PgStreamInner::Unencrypted(stream) => {
1453                stream.write_all(&self.write_buf).await?;
1454                stream.flush().await?;
1455            }
1456            PgStreamInner::Ssl(ssl_stream) => {
1457                ssl_stream.write_all(&self.write_buf).await?;
1458                ssl_stream.flush().await?;
1459            }
1460        }
1461        self.write_buf.clear();
1462        Ok(())
1463    }
1464}
1465
1466impl<S> PgStream<S>
1467where
1468    S: PgByteStream,
1469{
1470    /// Convert the underlying stream to ssl stream based on the given context.
1471    async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1472        let mut stream = self.stream.lock().await;
1473
1474        match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1475            PgStreamInner::Unencrypted(unencrypted_stream) => {
1476                let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1477                let mut ssl_stream =
1478                    tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1479
1480                if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1481                    tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1482                    let _ = ssl_stream.shutdown().await;
1483                    return Err(e.into());
1484                }
1485
1486                *stream = PgStreamInner::Ssl(ssl_stream);
1487            }
1488            PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1489            PgStreamInner::Placeholder => unreachable!(),
1490        }
1491
1492        Ok(())
1493    }
1494}
1495
1496fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1497    let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1498
1499    let key_path = &tls_config.key;
1500    let cert_path = &tls_config.cert;
1501
1502    // Build ssl acceptor according to the config.
1503    // Now we set every verify to true.
1504    acceptor
1505        .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1506        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1507    acceptor
1508        .set_ca_file(cert_path)
1509        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1510    acceptor
1511        .set_certificate_chain_file(cert_path)
1512        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1513    let acceptor = acceptor.build();
1514
1515    Ok(acceptor.into_context())
1516}
1517
1518pub mod truncated_fmt {
1519    use std::fmt::*;
1520
1521    struct TruncatedFormatter<'a, 'b> {
1522        remaining: usize,
1523        finished: bool,
1524        f: &'a mut Formatter<'b>,
1525    }
1526    impl Write for TruncatedFormatter<'_, '_> {
1527        fn write_str(&mut self, s: &str) -> Result {
1528            if self.finished {
1529                return Ok(());
1530            }
1531
1532            if self.remaining < s.len() {
1533                let actual = s.floor_char_boundary(self.remaining);
1534                self.f.write_str(&s[0..actual])?;
1535                self.remaining -= actual;
1536                self.f.write_str(&format!("...(truncated,{})", s.len()))?;
1537                self.finished = true; // so that ...(truncated) is printed exactly once
1538            } else {
1539                self.f.write_str(s)?;
1540                self.remaining -= s.len();
1541            }
1542            Ok(())
1543        }
1544    }
1545
1546    pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);
1547
1548    impl<T> Debug for TruncatedFmt<'_, T>
1549    where
1550        T: Debug,
1551    {
1552        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
1553            TruncatedFormatter {
1554                remaining: self.1,
1555                finished: false,
1556                f,
1557            }
1558            .write_fmt(format_args!("{:?}", self.0))
1559        }
1560    }
1561
1562    impl<T> Display for TruncatedFmt<'_, T>
1563    where
1564        T: Display,
1565    {
1566        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
1567            TruncatedFormatter {
1568                remaining: self.1,
1569                finished: false,
1570                f,
1571            }
1572            .write_fmt(format_args!("{}", self.0))
1573        }
1574    }
1575
1576    #[cfg(test)]
1577    mod tests {
1578        use super::*;
1579
1580        #[test]
1581        fn test_trunc_utf8() {
1582            assert_eq!(
1583                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
1584                "select '...(truncated,14)",
1585            );
1586        }
1587    }
1588}
1589
1590/// Handle `options` in `StartupMessage` from client
1591///
1592/// It is like shell arguments but only respects backslash-escape and space;
1593/// quotes have no special meaning and are handled literally.
1594///
1595/// PostgreSQL allows both `-c key=value` and `--key=value`.
1596///
1597/// `key-name` is normalized as `key_name`.
1598///
1599/// * <https://github.com/postgres/postgres/blob/REL_18_1/src/backend/utils/init/postinit.c#L487>
1600/// * <https://github.com/postgres/postgres/blob/REL_18_1/src/backend/tcop/postgres.c#L3866>
1601/// * <https://github.com/postgres/postgres/blob/REL_18_1/src/backend/utils/misc/guc.c#L6361>
1602fn parse_options(options: &str) -> PsqlResult<Vec<(String, String)>> {
1603    let mut args = Vec::new();
1604    let mut current_arg = String::new();
1605    let mut chars = options.chars().peekable();
1606
1607    while let Some(c) = chars.next() {
1608        if c == '\\' {
1609            if let Some(next_c) = chars.next() {
1610                current_arg.push(next_c);
1611            }
1612        } else if c.is_ascii_whitespace() {
1613            if !current_arg.is_empty() {
1614                args.push(std::mem::take(&mut current_arg));
1615            }
1616        } else {
1617            current_arg.push(c);
1618        }
1619    }
1620    if !current_arg.is_empty() {
1621        args.push(current_arg);
1622    }
1623
1624    let mut args_iter = args.into_iter();
1625    let mut config = Vec::new();
1626
1627    while let Some(arg) = args_iter.next() {
1628        if arg == "-c" {
1629            if let Some(config_str) = args_iter.next() {
1630                if let Some((key, value)) = config_str.split_once('=') {
1631                    let key = key.replace("-", "_");
1632                    config.push((key, value.to_owned()));
1633                } else {
1634                    return Err(PsqlError::StartupError(
1635                        format!("invalid config format: {}", config_str).into(),
1636                    ));
1637                }
1638            } else {
1639                return Err(PsqlError::StartupError("missing argument for -c".into()));
1640            }
1641        } else if let Some(config_str) = arg.strip_prefix("--") {
1642            if let Some((key, value)) = config_str.split_once('=') {
1643                let key = key.replace("-", "_");
1644                config.push((key, value.to_owned()));
1645            } else {
1646                return Err(PsqlError::StartupError(
1647                    format!("invalid config format: {}", config_str).into(),
1648                ));
1649            }
1650        } else {
1651            tracing::warn!(
1652                arg,
1653                "ignoring unrecognized option for backward compatibility"
1654            );
1655        }
1656    }
1657    Ok(config)
1658}
1659
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Options whose names are in the keyword set are replaced by `[REDACTED]` in both the
    /// `WITH (...)` clause and the `ENCODE ... (...)` clause; other options are kept.
    #[test]
    fn test_redact_parsable_sql() {
        let keywords = Arc::new(HashSet::from(["v2".into(), "v4".into(), "b".into()]));
        let sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        assert_eq!(
            redact_sql(sql, keywords),
            "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])"
        );
    }

    /// Passwords in `ALTER USER` / `CREATE USER` statements are redacted, plain or encrypted.
    #[test]
    fn test_redact_user_password_sql() {
        let keywords = Arc::new(HashSet::from(["password".into()]));

        assert_eq!(
            redact_sql("ALTER USER WITH PASSWORD 'rw_password_2'", keywords.clone()),
            "ALTER USER WITH PASSWORD [REDACTED]"
        );
        assert_eq!(
            redact_sql(
                "ALTER USER foo WITH ENCRYPTED PASSWORD 'md5827ccb0eea8a706c4c34a16891f84e7b'",
                keywords.clone(),
            ),
            "ALTER USER foo WITH ENCRYPTED PASSWORD [REDACTED]"
        );
        assert_eq!(
            redact_sql("CREATE USER foo WITH PASSWORD 'rw_password_2'", keywords),
            "CREATE USER foo WITH PASSWORD [REDACTED]"
        );
    }

    /// Core tokenizing and `-c` / `--` handling of `parse_options`.
    #[test]
    fn test_parse_options() {
        assert_eq!(parse_options("").unwrap(), vec![]);
        assert_eq!(
            parse_options("-c a=1 -c b=2").unwrap(),
            vec![("a".into(), "1".into()), ("b".into(), "2".into())]
        );
        assert_eq!(
            parse_options("-c   key=value").unwrap(),
            vec![("key".into(), "value".into())]
        );
        // Custom parser treats quotes as normal characters, so they are included in value
        assert_eq!(
            parse_options("-c key='value'").unwrap(),
            vec![("key".into(), "'value'".into())]
        );

        // Test backslash escaping for spaces (standard Postgres way)
        assert_eq!(
            parse_options(r#"-c key=value\ with\ spaces"#).unwrap(),
            vec![("key".into(), "value with spaces".into())]
        );
        assert_eq!(
            parse_options(r#"-c search_path=my\ schema"#).unwrap(),
            vec![("search_path".into(), "my schema".into())]
        );

        assert!(parse_options("-c").is_err());
        assert!(parse_options("-c foo").is_err()); // missing =
        assert!(parse_options("--foo").is_err()); // missing = in -- option

        assert_eq!(
            parse_options("--foo=bar").unwrap(),
            vec![("foo".into(), "bar".into())]
        );
        assert_eq!(
            parse_options(r#"--foo=bar\ baz"#).unwrap(),
            vec![("foo".into(), "bar baz".into())]
        );
        assert_eq!(
            parse_options("-c a=1 --b=2").unwrap(),
            vec![("a".into(), "1".into()), ("b".into(), "2".into())]
        );
        // Unpaired trailing backslash is silently dropped, same as PostgreSQL
        assert_eq!(
            parse_options(r#"-c a=b\"#).unwrap(),
            vec![("a".into(), "b".into())]
        );
    }

    /// `-` in a key becomes `_` for both syntaxes, while the value is left untouched.
    #[test]
    fn test_parse_options_key_normalization() {
        assert_eq!(
            parse_options("-c key-name=some-value").unwrap(),
            vec![("key_name".into(), "some-value".into())]
        );
        assert_eq!(
            parse_options("--application-name=my-app").unwrap(),
            vec![("application_name".into(), "my-app".into())]
        );
    }

    /// Values are split off at the first `=` only, may be empty, and may contain an escaped
    /// backslash.
    #[test]
    fn test_parse_options_value_edge_cases() {
        assert_eq!(
            parse_options("-c a=b=c").unwrap(),
            vec![("a".into(), "b=c".into())]
        );
        assert_eq!(
            parse_options("--a=").unwrap(),
            vec![("a".into(), "".into())]
        );
        assert_eq!(
            parse_options(r#"-c a=b\\c"#).unwrap(),
            vec![("a".into(), r#"b\c"#.into())]
        );
    }

    /// Arguments that are neither `-c` nor `--key=value` are skipped instead of failing the
    /// startup, and whitespace-only input yields no settings.
    #[test]
    fn test_parse_options_ignores_unrecognized() {
        assert_eq!(
            parse_options("foo -c a=1 -x").unwrap(),
            vec![("a".into(), "1".into())]
        );
        assert_eq!(parse_options(" \t\n").unwrap(), vec![]);
    }
}