pgwire/
pg_protocol.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::env_var::env_var_is_true;
33use risingwave_common::util::panic::FutureCatchUnwindExt;
34use risingwave_common::util::query_log::*;
35use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
36use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
37use risingwave_sqlparser::parser::Parser;
38use thiserror_ext::AsReport;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
40use tokio::sync::Mutex;
41use tokio_openssl::SslStream;
42use tracing::Instrument;
43
44use crate::error::{PsqlError, PsqlResult};
45use crate::error_or_notice::Severity;
46use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
47use crate::net::AddressRef;
48use crate::pg_extended::ResultCache;
49use crate::pg_message::{
50    BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
51    FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
52    FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
53};
54use crate::pg_server::{Session, SessionManager, UserAuthenticator};
55use crate::types::Format;
56
/// Maximum length of a SQL text kept for query logging. Longer SQL is truncated to avoid the
/// log file being too large.
///
/// Read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable. Falls back to `65536`
/// when the variable is unset, not valid unicode, or not parsable as `usize`.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        // Parse exactly once; an invalid value simply selects the default below.
        .and_then(|len| len.parse::<usize>().ok())
        .unwrap_or(65536)
});
64
tokio::task_local! {
    /// The current session. Concrete type is erased for different session implementations.
    ///
    /// Scoped around message processing in `PgProtocol::do_process` whenever the connection
    /// has a session. Only a `Weak` reference is stored, so the task-local itself does not
    /// keep the session alive.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
69
/// The state machine for each psql connection.
/// Reads pg messages from the stream and writes results back.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Used for write/read pg messages.
    stream: PgStream<S>,
    /// Current state of the pg connection. Selects how the next message is read.
    state: PgProtocolState,
    /// Whether the connection is terminated.
    is_terminate: bool,

    /// Session manager used to create the session and to cancel work by session id.
    session_mgr: Arc<SM>,
    /// The session of this connection. `None` until the startup message has been processed.
    session: Option<Arc<SM::Session>>,

    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    // Used to store the dependency between portals and prepared statements.
    // When we close a prepared statement, we need to close all the portals that depend on it.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    // Used for ssl connection.
    // If `None`, SSL requests from the client are answered with "no encryption".
    tls_context: Option<SslContext>,

    // TLS configuration including SSL enforcement setting
    tls_config: Option<TlsConfig>,

    // Used in extended query protocol. When an error is encountered in an extended query, we
    // need to ignore the following messages until the next sync message.
    ignore_util_sync: bool,

    // Client Address
    peer_addr: AddressRef,

    // Keywords of SQL options to redact before SQL is recorded; `None` or empty disables
    // redaction.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    // Accounts for the payload size of incoming messages and throttles oversized ones.
    message_memory_manager: MessageMemoryManagerRef,
}
112
/// Configures TLS encryption for connections.
///
/// Can be built from environment variables with `TlsConfig::new_default`.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// The path to the TLS certificate.
    pub cert: String,
    /// The path to the TLS key.
    pub key: String,
    /// Whether to enforce SSL connections (reject non-SSL clients).
    pub enforce_ssl: bool,
}
123
124impl TlsConfig {
125    pub fn new_default() -> anyhow::Result<Option<Self>> {
126        let cert = std::env::var("RW_SSL_CERT").ok();
127        let key = std::env::var("RW_SSL_KEY").ok();
128        let enforce_ssl = env_var_is_true("RW_SSL_ENFORCE");
129
130        if cert.is_some() ^ key.is_some() {
131            return Err(anyhow::anyhow!(
132                "RW_SSL_CERT and RW_SSL_KEY must be set together"
133            ));
134        }
135
136        if enforce_ssl && cert.is_none() {
137            return Err(anyhow::anyhow!(
138                "RW_SSL_ENFORCE requires RW_SSL_CERT and RW_SSL_KEY to be set"
139            ));
140        }
141
142        let (Some(cert), Some(key)) = (cert, key) else {
143            return Ok(None);
144        };
145
146        tracing::info!(
147            "RW_SSL_CERT={}, RW_SSL_KEY={}, RW_SSL_ENFORCE={}",
148            cert,
149            key,
150            enforce_ssl
151        );
152        Ok(Some(Self {
153            cert,
154            key,
155            enforce_ssl,
156        }))
157    }
158}
159
160impl<S, SM> Drop for PgProtocol<S, SM>
161where
162    SM: SessionManager,
163{
164    fn drop(&mut self) {
165        if let Some(session) = &self.session {
166            // Clear the session in session manager.
167            self.session_mgr.end_session(session);
168        }
169    }
170}
171
/// States of a connection. The flow goes from top to bottom: a connection starts in `Startup`
/// and switches to `Regular` once the startup message has been processed.
enum PgProtocolState {
    /// Messages are read with `read_startup`.
    Startup,
    /// Messages are read as a header followed by a body.
    Regular,
}
177
178/// Truncate 0 from C string in Bytes and stringify it (returns slice, no allocations).
179///
180/// PG protocol strings are always C strings.
181pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
182    let without_null = if b.last() == Some(&0) {
183        &b[..b.len() - 1]
184    } else {
185        &b[..]
186    };
187    std::str::from_utf8(without_null)
188}
189
190fn get_redacted_and_truncated_sql(
191    sql: &str,
192    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
193) -> String {
194    let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
195        && !keywords.is_empty()
196    {
197        redact_sql(sql, keywords)
198    } else {
199        sql.to_owned()
200    };
201    let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
202    truncated.to_string()
203}
204
205/// Record `sql` in the current tracing span.
206fn record_sql_in_span(
207    sql: &str,
208    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
209    span: &mut tracing::Span,
210) {
211    let redacted_and_truncated_sql =
212        get_redacted_and_truncated_sql(sql, redact_sql_option_keywords);
213    span.record("sql", tracing::field::display(&redacted_and_truncated_sql));
214}
215
216fn record_user_in_span(user: &str, span: &mut tracing::Span) {
217    span.record("user", tracing::field::display(user));
218}
219
220/// Redacts SQL options. Data in DML is not redacted.
221fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
222    match Parser::parse_sql(sql) {
223        Ok(sqls) => sqls
224            .into_iter()
225            .map(|sql| sql.to_redacted_string(keywords.clone()))
226            .join(";"),
227        Err(_) => sql.to_owned(),
228    }
229}
230
/// Per-connection settings passed to `PgProtocol::new`.
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS configuration, if any. Used to build the SSL context and to enforce SSL.
    pub tls_config: Option<TlsConfig>,
    /// Keywords of SQL options to redact before SQL is recorded in spans or the execution
    /// context. `None` or an empty set disables redaction.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Memory manager that accounts for the payload size of incoming messages.
    pub message_memory_manager: MessageMemoryManagerRef,
}
237
238impl<S, SM> PgProtocol<S, SM>
239where
240    S: PgByteStream,
241    SM: SessionManager,
242{
243    pub fn new(
244        stream: S,
245        session_mgr: Arc<SM>,
246        peer_addr: AddressRef,
247        context: ConnectionContext,
248    ) -> Self {
249        let ConnectionContext {
250            tls_config,
251            redact_sql_option_keywords,
252            message_memory_manager,
253        } = context;
254        Self {
255            stream: PgStream::new(stream),
256            is_terminate: false,
257            state: PgProtocolState::Startup,
258            session_mgr,
259            session: None,
260            tls_context: tls_config
261                .as_ref()
262                .and_then(|e| build_ssl_ctx_from_config(e).ok()),
263            tls_config,
264            result_cache: Default::default(),
265            unnamed_prepare_statement: Default::default(),
266            prepare_statement_store: Default::default(),
267            unnamed_portal: Default::default(),
268            portal_store: Default::default(),
269            statement_portal_dependency: Default::default(),
270            ignore_util_sync: false,
271            peer_addr,
272            redact_sql_option_keywords,
273            message_memory_manager,
274        }
275    }
276
    /// Run the protocol to serve the connection.
    ///
    /// Reads and processes one message per loop iteration until processing reports that the
    /// connection is terminated or reading a message fails. Once a session exists, notices
    /// from the session are sent to the client concurrently with message processing.
    pub async fn run(&mut self) {
        // Created lazily once a session is available and then reused across iterations.
        let mut notice_fut = None;

        loop {
            // Once a session is present, create a future to subscribe and send notices asynchronously.
            if notice_fut.is_none()
                && let Some(session) = self.session.clone()
            {
                // NOTE(review): presumably the cloned stream shares the underlying connection
                // with `self.stream` - confirm against `PgStream`.
                let mut stream = self.stream.clone();
                notice_fut = Some(Box::pin(async move {
                    // Never completes: a failed write is logged and the loop keeps going.
                    loop {
                        let notice = session.next_notice().await;
                        if let Err(e) = stream.write(BeMessage::NoticeResponse(&notice)).await {
                            tracing::error!(error = %e.as_report(), notice, "failed to send notice");
                        }
                    }
                }));
            }

            // Read and process messages.
            let process = std::pin::pin!(async {
                // `_memory_guard` stays alive until the message has been processed.
                let (msg, _memory_guard) = match self.read_message().await {
                    Ok(msg) => msg,
                    Err(e) => {
                        tracing::error!(error = %e.as_report(), "error when reading message");
                        return true; // terminate the connection
                    }
                };
                tracing::trace!(?msg, "received message");
                self.process(msg).await
            });

            let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
                tokio::select! {
                    // The notice future is an infinite loop, so this branch can never fire.
                    _ = notice_fut => unreachable!(),
                    terminated = process => terminated,
                }
            } else {
                process.await
            };

            if terminated {
                break;
            }
        }
    }
324
325    /// Processes one message. Returns true if the connection is terminated.
326    pub async fn process(&mut self, msg: FeMessage) -> bool {
327        self.do_process(msg).await.is_none() || self.is_terminate
328    }
329
330    /// The root tracing span for processing a message. The target of the span is
331    /// [`PGWIRE_ROOT_SPAN_TARGET`].
332    ///
333    /// This is used to provide context for the (slow) query logs and traces.
334    ///
335    /// The span is only effective if there's a current session and the message is
336    /// query-related. Otherwise, `Span::none()` is returned.
337    fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
338        let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
339            return tracing::Span::none();
340        };
341
342        let mode = match msg {
343            FeMessage::Query(_) => "simple query",
344            FeMessage::Parse(_) => "extended query parse",
345            FeMessage::Execute(_) => "extended query execute",
346            _ => return tracing::Span::none(),
347        };
348
349        let mut span = tracing::info_span!(
350            target: PGWIRE_ROOT_SPAN_TARGET,
351            "handle_query",
352            mode,
353            session_id,
354            sql = tracing::field::Empty,
355            user = tracing::field::Empty,
356        );
357        if let Ok(sql) = msg.get_sql()
358            && let Some(sql) = sql
359        {
360            record_sql_in_span(sql, self.redact_sql_option_keywords.clone(), &mut span);
361        }
362        if let Some(current_session) = self.session.as_ref() {
363            record_user_in_span(&current_session.user(), &mut span);
364        }
365        span
366    }
367
    /// Processes one message with all cross-cutting concerns applied, then reports errors to
    /// the client.
    ///
    /// The actual work is done by `do_process_inner`. Its future is wrapped, from the inside
    /// out, with: the `CURRENT_SESSION` task-local scope, panic catching, the periodic slow
    /// query log, the query log (which also logs every error), and the root tracing span.
    ///
    /// Return type `Option<()>` is essentially a bool, but allows `?` for early return.
    /// - `None` means to terminate the current connection
    /// - `Some(())` means to continue processing the next message
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        let span = self.root_span_for_msg(&msg);
        // Downgraded and type-erased so it can be stored in the `CURRENT_SESSION` task-local.
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        // Processing the message itself.
        //
        // Note: pin the future to avoid stack overflow as we'll wrap it multiple times
        // in the following code.
        let fut = Box::pin(self.do_process_inner(msg));

        // Set the current session as the context when processing the message, if exists.
        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        // Catch unwind.
        // A panic is converted into `PsqlError::Panic` carrying the panic message.
        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        // Slow query log.
        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            // Accumulated in whole periods, so it is a lower bound of the real elapsed time.
            let mut elapsed = Duration::ZERO;

            // Report the SQL in the log periodically if the query is slow.
            loop {
                // The timeout only bounds this wait; the inner future is polled again on the
                // next iteration rather than being cancelled.
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        // Query log.
        let fut = async move {
            // Only log if we're currently in a tracing span set in `root_span_for_msg`.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = "started",
                );
            }

            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            // Always log if an error occurs.
            // Note: all messages will be processed through this code path, making it the
            //       only necessary place to log errors.
            if let Err(error) = &result {
                if cfg!(debug_assertions) && !Deployment::current().is_ci() {
                    // For local debugging, we print the error with backtrace.
                    // It's useful only when:
                    // - no additional context is added to the error
                    // - backtrace is captured in the error
                    // - backtrace is not printed in the middle
                    tracing::error!(error = ?error.as_report(), "error when process message");
                } else {
                    tracing::error!(error = %error.as_report(), "error when process message");
                }
            }

            // Log to optionally-enabled target `PGWIRE_QUERY_LOG`.
            // Only log if we're currently in a tracing span set in `root_span_for_msg`.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        // Tracing span.
        let fut = fut.instrument(span);

        // Execute the future and handle the error.
        // In the arms below, `.ok()?` terminates the connection if the error response itself
        // cannot be written.
        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    PsqlError::IoError(io_err) => {
                        // The peer closed the connection: stop serving it.
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                        // Any other I/O error: no error response is written; fall through to
                        // the final flush and keep the connection.
                    }

                    PsqlError::SslError(_) => {
                        // For ssl error, because the stream has already been consumed, so there is
                        // no way to write more message.
                        return None;
                    }

                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                // At this time we're not in a session, use compact error message for
                                // better alignment with Postgres' UI.
                                pretty: false,
                                severity: Some(Severity::Fatal),
                            })
                            .ok()?;
                        // Best-effort flush; the connection is closed either way.
                        let _ = self.stream.flush().await;
                        return None;
                    }

                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                                severity: None,
                            })
                            .ok()?;
                        // The error response is followed by `ReadyForQuery` so the client can
                        // send the next query.
                        self.ready_for_query().ok()?;
                    }

                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                                severity: None,
                            })
                            .ok()?;
                        let _ = self.stream.flush().await;

                        // 1. Catching the panic during message processing may leave the session in an
                        // inconsistent state. We forcefully close the connection (then end the
                        // session) here for safety.
                        // 2. Idle in transaction timeout should also close the connection.
                        return None;
                    }

                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        // Only the error response is written here; no `ReadyForQuery`.
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                                severity: None,
                            })
                            .ok()?;
                    }
                }
                // Best-effort flush of whatever has been buffered above.
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
548
    /// Dispatches `msg` to its handler and flushes the stream afterwards.
    ///
    /// While `ignore_util_sync` is set, every message except `Sync` is skipped. The flag is
    /// set when handling `Parse`, `Bind`, `Execute`, `Describe`, `Close` or `Flush` fails, and
    /// cleared again by `Sync`.
    ///
    /// # Errors
    ///
    /// Returns the error of the failing handler. `ServerThrottle` messages are always turned
    /// into `PsqlError::ServerThrottle`.
    async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
        // Ignore messages until the next sync message.
        if self.ignore_util_sync {
            if let FeMessage::Sync = msg {
                // `Sync` is handled below, where it also clears the flag.
            } else {
                tracing::trace!("ignore message {:?} until sync.", msg);
                return Ok(());
            }
        }

        match msg {
            FeMessage::Gss => self.process_gss_msg().await?,
            FeMessage::Ssl => self.process_ssl_msg().await?,
            FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
            FeMessage::Password(msg) => self.process_password_msg(msg).await?,
            FeMessage::Query(query_msg) => {
                let sql = Arc::from(query_msg.get_sql()?);
                // The process_query_msg can be slow. Release potential large FeQueryMessage early.
                drop(query_msg);
                self.process_query_msg(sql).await?
            }
            FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
            FeMessage::Terminate => self.process_terminate(),
            // Extended query protocol: on error, skip the following messages until `Sync`.
            FeMessage::Parse(m) => {
                if let Err(err) = self.process_parse_msg(m).await {
                    self.ignore_util_sync = true;
                    return Err(err);
                }
            }
            FeMessage::Bind(m) => {
                if let Err(err) = self.process_bind_msg(m) {
                    self.ignore_util_sync = true;
                    return Err(err);
                }
            }
            FeMessage::Execute(m) => {
                if let Err(err) = self.process_execute_msg(m).await {
                    self.ignore_util_sync = true;
                    return Err(err);
                }
            }
            FeMessage::Describe(m) => {
                if let Err(err) = self.process_describe_msg(m) {
                    self.ignore_util_sync = true;
                    return Err(err);
                }
            }
            FeMessage::Sync => {
                // End of the current extended-query sequence: resume normal processing.
                self.ignore_util_sync = false;
                self.ready_for_query()?
            }
            FeMessage::Close(m) => {
                if let Err(err) = self.process_close_msg(m) {
                    self.ignore_util_sync = true;
                    return Err(err);
                }
            }
            FeMessage::Flush => {
                if let Err(err) = self.stream.flush().await {
                    self.ignore_util_sync = true;
                    return Err(err.into());
                }
            }
            FeMessage::HealthCheck => self.process_health_check(),
            // Produced by `read_message` when the memory manager rejected the message.
            FeMessage::ServerThrottle(reason) => match reason {
                ServerThrottleReason::TooLargeMessage => {
                    return Err(PsqlError::ServerThrottle(format!(
                        "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
                        self.message_memory_manager.max_filter_bytes
                    )));
                }
                ServerThrottleReason::TooManyMemoryUsage => {
                    return Err(PsqlError::ServerThrottle(format!(
                        "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
                        self.message_memory_manager.max_running_bytes
                    )));
                }
            },
        }
        // Send out everything the handler buffered.
        self.stream.flush().await?;
        Ok(())
    }
631
632    pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
633        match self.state {
634            PgProtocolState::Startup => self
635                .stream
636                .read_startup()
637                .await
638                .map(|message: FeMessage| (message, None)),
639            PgProtocolState::Regular => {
640                self.stream.read_header().await?;
641                let guard = if let Some(ref header) = self.stream.read_header {
642                    let payload_len = std::cmp::max(header.payload_len, 0) as u64;
643                    let (reason, guard) = self.message_memory_manager.add(payload_len);
644                    if let Some(reason) = reason {
645                        // Release the memory ASAP.
646                        drop(guard);
647                        self.stream.skip_body().await?;
648                        return Ok((FeMessage::ServerThrottle(reason), None));
649                    }
650                    guard
651                } else {
652                    None
653                };
654                let message = self.stream.read_body().await?;
655                Ok((message, guard))
656            }
657        }
658    }
659
660    /// Writes a `ReadyForQuery` message to the client without flushing.
661    fn ready_for_query(&mut self) -> io::Result<()> {
662        self.stream.write_no_flush(BeMessage::ReadyForQuery(
663            self.session
664                .as_ref()
665                .map(|s| s.transaction_status())
666                .unwrap_or(TransactionStatus::Idle),
667        ))
668    }
669
    /// Handles a GSSAPI encryption request by answering that encryption is not available.
    async fn process_gss_msg(&mut self) -> PsqlResult<()> {
        // We don't support GSSAPI, so we just say no gracefully.
        self.stream.write(BeMessage::EncryptionResponseNo).await?;
        Ok(())
    }
675
676    async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
677        if let Some(context) = self.tls_context.as_ref() {
678            // If got and ssl context, say yes for ssl connection.
679            // Construct ssl stream and replace with current one.
680            self.stream.write(BeMessage::EncryptionResponseSsl).await?;
681            self.stream.upgrade_to_ssl(context).await?;
682        } else {
683            // If no, say no for encryption.
684            self.stream.write(BeMessage::EncryptionResponseNo).await?;
685        }
686
687        Ok(())
688    }
689
    /// Handles the startup message: creates the session and starts authentication.
    ///
    /// Rejects the connection if SSL is enforced but the stream is not an SSL connection.
    /// Otherwise connects to the requested database (default `dev`) as the requested user
    /// (default `root`), applies the `options` and `application_name` startup parameters to
    /// the session and buffers the messages required by the session's authenticator with
    /// `write_no_flush`. Finally the session is stored and the connection switches to the
    /// regular state.
    ///
    /// # Errors
    ///
    /// Failures to connect, to apply a config or to read the timezone are reported as
    /// `PsqlError::StartupError`.
    async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
        // Check SSL enforcement: if SSL is enforced but connection is not using SSL, reject
        if let Some(ref tls_config) = self.tls_config
            && tls_config.enforce_ssl
            && !self.stream.is_ssl_connection().await
        {
            return Err(PsqlError::StartupError(
                "SSL connection is required but not established".into(),
            ));
        }

        let db_name = msg
            .config
            .get("database")
            .cloned()
            .unwrap_or_else(|| "dev".to_owned());
        let user_name = msg
            .config
            .get("user")
            .cloned()
            .unwrap_or_else(|| "root".to_owned());

        let session = self
            .session_mgr
            .connect(&db_name, &user_name, self.peer_addr.clone())
            .map_err(|e| PsqlError::StartupError(e.into()))?;

        // Apply the key/value pairs from the `options` startup parameter.
        if let Some(options) = msg.config.get("options") {
            for (key, value) in parse_options(options)? {
                session
                    .set_config(&key, value)
                    .map_err(|e| PsqlError::StartupError(e.into()))?;
            }
        }
        // dedicated `application_name` has higher priority than `options`
        let application_name = msg.config.get("application_name");
        if let Some(application_name) = application_name {
            session
                .set_config("application_name", application_name.clone())
                .map_err(|e| PsqlError::StartupError(e.into()))?;
        }

        match session.user_authenticator() {
            // No authentication required: complete the startup sequence right away.
            UserAuthenticator::None => {
                self.stream.write_no_flush(BeMessage::AuthenticationOk)?;

                // Cancel request need this for identify and verification. According to postgres
                // doc, it should be written to buffer after receive AuthenticationOk.
                self.stream
                    .write_no_flush(BeMessage::BackendKeyData(session.id()))?;

                self.stream.write_no_flush(BeMessage::ParameterStatus(
                    BeParameterStatusMessage::TimeZone(
                        &session
                            .get_config("timezone")
                            .map_err(|e| PsqlError::StartupError(e.into()))?,
                    ),
                ))?;
                self.stream
                    .write_parameter_status_msg_no_flush(&ParameterStatus {
                        application_name: application_name.cloned(),
                    })?;
                self.ready_for_query()?;
            }
            // These authenticators ask the client for a cleartext password.
            UserAuthenticator::ClearText(_)
            | UserAuthenticator::OAuth { .. }
            | UserAuthenticator::Ldap(..) => {
                self.stream
                    .write_no_flush(BeMessage::AuthenticationCleartextPassword)?;
            }
            UserAuthenticator::Md5WithSalt { salt, .. } => {
                self.stream
                    .write_no_flush(BeMessage::AuthenticationMd5Password(salt))?;
            }
        }

        // NOTE(review): the session is stored and the state becomes `Regular` even when a
        // password is still pending - confirm that messages other than `Password` cannot run
        // on a not-yet-authenticated session.
        self.session = Some(session);
        self.state = PgProtocolState::Regular;
        Ok(())
    }
770
771    async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
772        let session = self.session.as_ref().unwrap();
773        let authenticator = session.user_authenticator();
774        authenticator.authenticate(&msg.password).await?;
775        self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
776        let timezone = session
777            .get_config("timezone")
778            .map_err(|e| PsqlError::StartupError(e.into()))?;
779        self.stream.write_no_flush(BeMessage::ParameterStatus(
780            BeParameterStatusMessage::TimeZone(&timezone),
781        ))?;
782        self.stream
783            .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
784        self.ready_for_query()?;
785        self.state = PgProtocolState::Regular;
786        Ok(())
787    }
788
789    fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
790        let session_id = (m.target_process_id, m.target_secret_key);
791        tracing::trace!("cancel query in session: {:?}", session_id);
792        self.session_mgr.cancel_queries_in_session(session_id);
793        self.session_mgr.cancel_creating_jobs_in_session(session_id);
794        self.is_terminate = true;
795        Ok(())
796    }
797
798    async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
799        let truncated_sql =
800            get_redacted_and_truncated_sql(&sql, self.redact_sql_option_keywords.clone());
801        let session = self.session.clone().unwrap();
802
803        session.check_idle_in_transaction_timeout()?;
804        // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
805        let _exec_context_guard = session.init_exec_context(truncated_sql.into());
806        self.inner_process_query_msg(sql, session.clone()).await
807    }
808
809    async fn inner_process_query_msg(
810        &mut self,
811        sql: Arc<str>,
812        session: Arc<SM::Session>,
813    ) -> PsqlResult<()> {
814        // Parse sql.
815        let stmts =
816            Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
817        // The following inner_process_query_msg_one_stmt can be slow. Release potential large String early.
818        drop(sql);
819        if stmts.is_empty() {
820            self.stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
821        }
822
823        // Execute multiple statements in simple query. KISS later.
824        for stmt in stmts {
825            self.inner_process_query_msg_one_stmt(stmt, session.clone())
826                .await?;
827        }
828        // Put this line inside the for loop above will lead to unfinished/stuck regress test...Not
829        // sure the reason.
830        self.ready_for_query()?;
831        Ok(())
832    }
833
834    async fn inner_process_query_msg_one_stmt(
835        &mut self,
836        stmt: Statement,
837        session: Arc<SM::Session>,
838    ) -> PsqlResult<()> {
839        let session = session.clone();
840
841        // execute query
842        let res = session.clone().run_one_query(stmt, Format::Text).await;
843
844        // Take all remaining notices (if any) and send them before `CommandComplete`.
845        while let Some(notice) = session.next_notice().now_or_never() {
846            self.stream
847                .write_no_flush(BeMessage::NoticeResponse(&notice))?;
848        }
849
850        let mut res = res.map_err(|e| PsqlError::SimpleQueryError(e.into()))?;
851
852        for notice in res.notices() {
853            self.stream
854                .write_no_flush(BeMessage::NoticeResponse(notice))?;
855        }
856
857        let status = res.status();
858        if let Some(ref application_name) = status.application_name {
859            self.stream.write_no_flush(BeMessage::ParameterStatus(
860                BeParameterStatusMessage::ApplicationName(application_name),
861            ))?;
862        }
863
864        if res.is_copy_query_to_stdout() {
865            self.stream
866                .write_no_flush(BeMessage::CopyOutResponse(res.row_desc().len()))?;
867            let mut count = 0;
868            while let Some(row_set) = res.values_stream().next().await {
869                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
870                for row in row_set {
871                    self.stream.write_no_flush(BeMessage::CopyData(&row))?;
872                    count += 1;
873                }
874            }
875
876            self.stream.write_no_flush(BeMessage::CopyDone)?;
877
878            // Run the callback before sending the `CommandComplete` message.
879            res.run_callback().await?;
880
881            self.stream
882                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
883                    stmt_type: res.stmt_type(),
884                    rows_cnt: count,
885                }))?;
886        } else if res.is_query() {
887            self.stream
888                .write_no_flush(BeMessage::RowDescription(res.row_desc()))?;
889
890            let mut rows_cnt = 0;
891
892            while let Some(row_set) = res.values_stream().next().await {
893                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
894                for row in row_set {
895                    self.stream.write_no_flush(BeMessage::DataRow(&row))?;
896                    rows_cnt += 1;
897                }
898            }
899
900            // Run the callback before sending the `CommandComplete` message.
901            res.run_callback().await?;
902
903            self.stream
904                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
905                    stmt_type: res.stmt_type(),
906                    rows_cnt,
907                }))?;
908        } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
909            let first_row_set = res.values_stream().next().await;
910            let first_row_set = match first_row_set {
911                None => {
912                    return Err(PsqlError::Uncategorized(
913                        anyhow::anyhow!("no affected rows in output").into(),
914                    ));
915                }
916                Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
917            };
918            let affected_rows_str = first_row_set[0].values()[0]
919                .as_ref()
920                .expect("compute node should return affected rows in output");
921
922            assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
923            let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
924                .unwrap()
925                .parse()
926                .unwrap_or_default();
927
928            // Run the callback before sending the `CommandComplete` message.
929            res.run_callback().await?;
930
931            self.stream
932                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
933                    stmt_type: res.stmt_type(),
934                    rows_cnt: affected_rows_cnt,
935                }))?;
936        } else {
937            // Run the callback before sending the `CommandComplete` message.
938            res.run_callback().await?;
939
940            self.stream
941                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
942                    stmt_type: res.stmt_type(),
943                    rows_cnt: 0,
944                }))?;
945        }
946
947        Ok(())
948    }
949
    /// Flags this connection for termination by setting `is_terminate`.
    ///
    /// No message is written here; how the flag is consumed is outside this method.
    fn process_terminate(&mut self) {
        self.is_terminate = true;
    }
953
    /// Handles a health check: logs it at debug level and flags the connection for termination
    /// by setting `is_terminate`. No response is written here.
    fn process_health_check(&mut self) {
        tracing::debug!("health check");
        self.is_terminate = true;
    }
958
959    async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
960        let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
961        let session = self.session.clone().unwrap();
962        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
963        let type_ids = std::mem::take(&mut msg.type_ids);
964        // The inner_process_parse_msg can be slow. Release potential large FeParseMessage early.
965        drop(msg);
966        self.inner_process_parse_msg(session, sql, statement_name, type_ids)
967            .await?;
968        Ok(())
969    }
970
971    async fn inner_process_parse_msg(
972        &mut self,
973        session: Arc<SM::Session>,
974        sql: Arc<str>,
975        statement_name: String,
976        type_ids: Vec<i32>,
977    ) -> PsqlResult<()> {
978        if statement_name.is_empty() {
979            // Remove the unnamed prepare statement first, in case the unsupported sql binds a wrong
980            // prepare statement.
981            self.unnamed_prepare_statement.take();
982        } else if self.prepare_statement_store.contains_key(&statement_name) {
983            return Err(PsqlError::ExtendedPrepareError(
984                "Duplicated statement name".into(),
985            ));
986        }
987
988        let stmt = {
989            let stmts = Parser::parse_sql(&sql)
990                .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
991            drop(sql);
992            if stmts.len() > 1 {
993                return Err(PsqlError::ExtendedPrepareError(
994                    "Only one statement is allowed in extended query mode".into(),
995                ));
996            }
997
998            stmts.into_iter().next()
999        };
1000
1001        let param_types: Vec<Option<DataType>> = type_ids
1002            .iter()
1003            .map(|&id| {
1004                // 0 means unspecified type
1005                // ref: https://www.postgresql.org/docs/15/protocol-message-formats.html#:~:text=Placing%20a%20zero%20here%20is%20equivalent%20to%20leaving%20the%20type%20unspecified.
1006                if id == 0 {
1007                    Ok(None)
1008                } else {
1009                    DataType::from_oid(id)
1010                        .map(Some)
1011                        .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
1012                }
1013            })
1014            .try_collect()?;
1015
1016        let prepare_statement = session
1017            .parse(stmt, param_types)
1018            .await
1019            .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))?;
1020
1021        if statement_name.is_empty() {
1022            self.unnamed_prepare_statement.replace(prepare_statement);
1023        } else {
1024            self.prepare_statement_store
1025                .insert(statement_name.clone(), prepare_statement);
1026        }
1027
1028        self.statement_portal_dependency
1029            .entry(statement_name)
1030            .or_default()
1031            .clear();
1032
1033        self.stream.write_no_flush(BeMessage::ParseComplete)?;
1034        Ok(())
1035    }
1036
1037    fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
1038        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
1039        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1040        let session = self.session.clone().unwrap();
1041
1042        if self.portal_store.contains_key(&portal_name) {
1043            return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
1044        }
1045
1046        let prepare_statement = self.get_statement(&statement_name)?;
1047
1048        let result_formats = msg
1049            .result_format_codes
1050            .iter()
1051            .map(|&format_code| Format::from_i16(format_code))
1052            .try_collect()?;
1053        let param_formats = msg
1054            .param_format_codes
1055            .iter()
1056            .map(|&format_code| Format::from_i16(format_code))
1057            .try_collect()?;
1058
1059        let portal = session
1060            .bind(prepare_statement, msg.params, param_formats, result_formats)
1061            .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1062
1063        if portal_name.is_empty() {
1064            self.result_cache.remove(&portal_name);
1065            self.unnamed_portal.replace(portal);
1066        } else {
1067            assert!(
1068                !self.result_cache.contains_key(&portal_name),
1069                "Named portal never can be overridden."
1070            );
1071            self.portal_store.insert(portal_name.clone(), portal);
1072        }
1073
1074        self.statement_portal_dependency
1075            .get_mut(&statement_name)
1076            .unwrap()
1077            .push(portal_name);
1078
1079        self.stream.write_no_flush(BeMessage::BindComplete)?;
1080        Ok(())
1081    }
1082
1083    async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
1084        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1085        let row_max = msg.max_rows as usize;
1086        drop(msg);
1087        let session = self.session.clone().unwrap();
1088
1089        match self.result_cache.remove(&portal_name) {
1090            Some(mut result_cache) => {
1091                assert!(self.portal_store.contains_key(&portal_name));
1092
1093                let is_consume_completed =
1094                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
1095
1096                if !is_consume_completed {
1097                    self.result_cache.insert(portal_name, result_cache);
1098                }
1099            }
1100            _ => {
1101                let portal = self.get_portal(&portal_name)?;
1102                let sql = format!("{}", portal);
1103                let truncated_sql =
1104                    get_redacted_and_truncated_sql(&sql, self.redact_sql_option_keywords.clone());
1105                drop(sql);
1106
1107                session.check_idle_in_transaction_timeout()?;
1108                // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
1109                let _exec_context_guard = session.init_exec_context(truncated_sql.into());
1110                let result = session.clone().execute(portal).await;
1111
1112                let pg_response = result.map_err(|e| PsqlError::ExtendedExecuteError(e.into()))?;
1113                let mut result_cache = ResultCache::new(pg_response);
1114                let is_consume_completed =
1115                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
1116                if !is_consume_completed {
1117                    self.result_cache.insert(portal_name, result_cache);
1118                }
1119            }
1120        }
1121
1122        Ok(())
1123    }
1124
1125    fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1126        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1127        let session = self.session.clone().unwrap();
1128        //  b'S' => Statement
1129        //  b'P' => Portal
1130
1131        assert!(msg.kind == b'S' || msg.kind == b'P');
1132        if msg.kind == b'S' {
1133            let prepare_statement = self.get_statement(&name)?;
1134
1135            let (param_types, row_descriptions) = self
1136                .session
1137                .clone()
1138                .unwrap()
1139                .describe_statement(prepare_statement)
1140                .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1141            self.stream.write_no_flush(BeMessage::ParameterDescription(
1142                &param_types.iter().map(|t| t.to_oid()).collect_vec(),
1143            ))?;
1144
1145            if row_descriptions.is_empty() {
1146                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1147                // return NoData message if the statement is not a query.
1148                self.stream.write_no_flush(BeMessage::NoData)?;
1149            } else {
1150                self.stream
1151                    .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1152            }
1153        } else if msg.kind == b'P' {
1154            let portal = self.get_portal(&name)?;
1155
1156            let row_descriptions = session
1157                .describe_portal(portal)
1158                .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1159
1160            if row_descriptions.is_empty() {
1161                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1162                // return NoData message if the statement is not a query.
1163                self.stream.write_no_flush(BeMessage::NoData)?;
1164            } else {
1165                self.stream
1166                    .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1167            }
1168        }
1169        Ok(())
1170    }
1171
1172    fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1173        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1174        assert!(msg.kind == b'S' || msg.kind == b'P');
1175        if msg.kind == b'S' {
1176            if name.is_empty() {
1177                self.unnamed_prepare_statement = None;
1178            } else {
1179                self.prepare_statement_store.remove(&name);
1180            }
1181            for portal_name in self
1182                .statement_portal_dependency
1183                .remove(&name)
1184                .unwrap_or_default()
1185            {
1186                self.remove_portal(&portal_name);
1187            }
1188        } else if msg.kind == b'P' {
1189            self.remove_portal(&name);
1190        }
1191        self.stream.write_no_flush(BeMessage::CloseComplete)?;
1192        Ok(())
1193    }
1194
1195    fn remove_portal(&mut self, portal_name: &str) {
1196        if portal_name.is_empty() {
1197            self.unnamed_portal = None;
1198        } else {
1199            self.portal_store.remove(portal_name);
1200        }
1201        self.result_cache.remove(portal_name);
1202    }
1203
1204    fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1205        if portal_name.is_empty() {
1206            Ok(self
1207                .unnamed_portal
1208                .as_ref()
1209                .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1210                .clone())
1211        } else {
1212            Ok(self
1213                .portal_store
1214                .get(portal_name)
1215                .ok_or_else(|| {
1216                    PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1217                })?
1218                .clone())
1219        }
1220    }
1221
1222    fn get_statement(
1223        &self,
1224        statement_name: &str,
1225    ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1226        if statement_name.is_empty() {
1227            Ok(self
1228                .unnamed_prepare_statement
1229                .as_ref()
1230                .ok_or_else(|| {
1231                    PsqlError::Uncategorized("unnamed prepare statement not found".into())
1232                })?
1233                .clone())
1234        } else {
1235            Ok(self
1236                .prepare_statement_store
1237                .get(statement_name)
1238                .ok_or_else(|| {
1239                    PsqlError::Uncategorized(
1240                        format!("Prepare statement {} not found", statement_name).into(),
1241                    )
1242                })?
1243                .clone())
1244        }
1245    }
1246}
1247
/// The transport behind a [`PgStream`]: either the raw byte stream `S` or the same stream
/// wrapped in TLS.
enum PgStreamInner<S> {
    /// Used for the intermediate state when converting from unencrypted to ssl stream.
    ///
    /// The read/write paths treat this variant as unreachable, so it must never be left in
    /// place while the stream is still being used.
    Placeholder,
    /// An unencrypted stream.
    Unencrypted(S),
    /// An ssl stream.
    Ssl(SslStream<S>),
}
1256
/// Trait for a byte stream that can be used for pg protocol.
///
/// This is a pure alias for the listed bounds: it declares no methods of its own.
pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
// Blanket impl: every type that satisfies the bounds is a `PgByteStream` automatically.
impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1260
/// Wraps a byte stream and read/write pg messages.
///
/// Outgoing messages are first encoded into `write_buf` and only reach the underlying stream
/// when the buffer is flushed.
///
/// Cloning a `PgStream` will share the same stream but a fresh & independent write buffer,
/// so that it can be used to write messages concurrently without interference.
pub struct PgStream<S> {
    /// The underlying stream.
    ///
    /// Shared between clones; the async mutex serializes access to it.
    stream: Arc<Mutex<PgStreamInner<S>>>,
    /// Write into buffer before flush to stream.
    write_buf: BytesMut,
    /// Header of a message that has been read while its body has not been read or skipped yet.
    read_header: Option<FeMessageHeader>,
}
1272
1273impl<S> PgStream<S> {
1274    /// Create a new `PgStream` with the given stream and default write buffer capacity.
1275    pub fn new(stream: S) -> Self {
1276        const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1277
1278        Self {
1279            stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1280            write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1281            read_header: None,
1282        }
1283    }
1284
1285    /// Check if the current connection is using SSL
1286    async fn is_ssl_connection(&self) -> bool {
1287        let stream = self.stream.lock().await;
1288        matches!(*stream, PgStreamInner::Ssl(_))
1289    }
1290}
1291
1292impl<S> Clone for PgStream<S> {
1293    fn clone(&self) -> Self {
1294        Self {
1295            stream: Arc::clone(&self.stream),
1296            write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1297            read_header: self.read_header.clone(),
1298        }
1299    }
1300}
1301
/// Values used when emitting `ParameterStatus` messages to the client.
///
/// At present there is a hard-wired set of parameters for which
/// `ParameterStatus` will be generated. They are:
///
///  * `server_version`
///  * `server_encoding`
///  * `client_encoding`
///  * `application_name`
///  * `is_superuser`
///  * `session_authorization`
///  * `DateStyle`
///  * `IntervalStyle`
///  * `TimeZone`
///  * `integer_datetimes`
///  * `standard_conforming_strings`
///
/// See: <https://www.postgresql.org/docs/9.2/static/protocol-flow.html#PROTOCOL-ASYNC>.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// Value to report for `application_name`; nothing is reported for it when `None`.
    pub application_name: Option<String>,
}
1322
1323impl<S> PgStream<S>
1324where
1325    S: PgByteStream,
1326{
1327    async fn read_startup(&mut self) -> io::Result<FeMessage> {
1328        let mut stream = self.stream.lock().await;
1329        match &mut *stream {
1330            PgStreamInner::Placeholder => unreachable!(),
1331            PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1332            PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1333        }
1334    }
1335
1336    async fn read_header(&mut self) -> io::Result<()> {
1337        let mut stream = self.stream.lock().await;
1338        match &mut *stream {
1339            PgStreamInner::Placeholder => unreachable!(),
1340            PgStreamInner::Unencrypted(stream) => {
1341                self.read_header = Some(FeMessage::read_header(stream).await?);
1342                Ok(())
1343            }
1344            PgStreamInner::Ssl(ssl_stream) => {
1345                self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1346                Ok(())
1347            }
1348        }
1349    }
1350
1351    async fn read_body(&mut self) -> io::Result<FeMessage> {
1352        let mut stream = self.stream.lock().await;
1353        let header = self
1354            .read_header
1355            .take()
1356            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1357        match &mut *stream {
1358            PgStreamInner::Placeholder => unreachable!(),
1359            PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1360            PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1361        }
1362    }
1363
1364    async fn skip_body(&mut self) -> io::Result<()> {
1365        let mut stream = self.stream.lock().await;
1366        let header = self
1367            .read_header
1368            .take()
1369            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1370        match &mut *stream {
1371            PgStreamInner::Placeholder => unreachable!(),
1372            PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1373            PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1374        }
1375    }
1376
1377    fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1378        self.write_no_flush(BeMessage::ParameterStatus(
1379            BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1380        ))?;
1381        self.write_no_flush(BeMessage::ParameterStatus(
1382            BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1383        ))?;
1384        self.write_no_flush(BeMessage::ParameterStatus(
1385            BeParameterStatusMessage::ServerVersion(PG_VERSION),
1386        ))?;
1387        if let Some(application_name) = &status.application_name {
1388            self.write_no_flush(BeMessage::ParameterStatus(
1389                BeParameterStatusMessage::ApplicationName(application_name),
1390            ))?;
1391        }
1392        Ok(())
1393    }
1394
1395    pub fn write_no_flush(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1396        BeMessage::write(&mut self.write_buf, message)
1397    }
1398
1399    async fn write(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1400        self.write_no_flush(message)?;
1401        self.flush().await?;
1402        Ok(())
1403    }
1404
1405    async fn flush(&mut self) -> io::Result<()> {
1406        let mut stream = self.stream.lock().await;
1407        match &mut *stream {
1408            PgStreamInner::Placeholder => unreachable!(),
1409            PgStreamInner::Unencrypted(stream) => {
1410                stream.write_all(&self.write_buf).await?;
1411                stream.flush().await?;
1412            }
1413            PgStreamInner::Ssl(ssl_stream) => {
1414                ssl_stream.write_all(&self.write_buf).await?;
1415                ssl_stream.flush().await?;
1416            }
1417        }
1418        self.write_buf.clear();
1419        Ok(())
1420    }
1421}
1422
1423impl<S> PgStream<S>
1424where
1425    S: PgByteStream,
1426{
1427    /// Convert the underlying stream to ssl stream based on the given context.
1428    async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1429        let mut stream = self.stream.lock().await;
1430
1431        match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1432            PgStreamInner::Unencrypted(unencrypted_stream) => {
1433                let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1434                let mut ssl_stream =
1435                    tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1436
1437                if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1438                    tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1439                    let _ = ssl_stream.shutdown().await;
1440                    return Err(e.into());
1441                }
1442
1443                *stream = PgStreamInner::Ssl(ssl_stream);
1444            }
1445            PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1446            PgStreamInner::Placeholder => unreachable!(),
1447        }
1448
1449        Ok(())
1450    }
1451}
1452
1453fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1454    let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1455
1456    let key_path = &tls_config.key;
1457    let cert_path = &tls_config.cert;
1458
1459    // Build ssl acceptor according to the config.
1460    // Now we set every verify to true.
1461    acceptor
1462        .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1463        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1464    acceptor
1465        .set_ca_file(cert_path)
1466        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1467    acceptor
1468        .set_certificate_chain_file(cert_path)
1469        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1470    let acceptor = acceptor.build();
1471
1472    Ok(acceptor.into_context())
1473}
1474
1475pub mod truncated_fmt {
1476    use std::fmt::*;
1477
1478    struct TruncatedFormatter<'a, 'b> {
1479        remaining: usize,
1480        finished: bool,
1481        f: &'a mut Formatter<'b>,
1482    }
1483    impl Write for TruncatedFormatter<'_, '_> {
1484        fn write_str(&mut self, s: &str) -> Result {
1485            if self.finished {
1486                return Ok(());
1487            }
1488
1489            if self.remaining < s.len() {
1490                let actual = s.floor_char_boundary(self.remaining);
1491                self.f.write_str(&s[0..actual])?;
1492                self.remaining -= actual;
1493                self.f.write_str(&format!("...(truncated,{})", s.len()))?;
1494                self.finished = true; // so that ...(truncated) is printed exactly once
1495            } else {
1496                self.f.write_str(s)?;
1497                self.remaining -= s.len();
1498            }
1499            Ok(())
1500        }
1501    }
1502
1503    pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);
1504
1505    impl<T> Debug for TruncatedFmt<'_, T>
1506    where
1507        T: Debug,
1508    {
1509        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
1510            TruncatedFormatter {
1511                remaining: self.1,
1512                finished: false,
1513                f,
1514            }
1515            .write_fmt(format_args!("{:?}", self.0))
1516        }
1517    }
1518
1519    impl<T> Display for TruncatedFmt<'_, T>
1520    where
1521        T: Display,
1522    {
1523        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
1524            TruncatedFormatter {
1525                remaining: self.1,
1526                finished: false,
1527                f,
1528            }
1529            .write_fmt(format_args!("{}", self.0))
1530        }
1531    }
1532
1533    #[cfg(test)]
1534    mod tests {
1535        use super::*;
1536
1537        #[test]
1538        fn test_trunc_utf8() {
1539            assert_eq!(
1540                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
1541                "select '...(truncated,14)",
1542            );
1543        }
1544    }
1545}
1546
1547/// Handle `options` in `StartupMessage` from client
1548///
1549/// It is like shell arguments but only respects backslash-escape and space;
1550/// quotes have no special meaning and are handled literally.
1551///
1552/// PostgreSQL allows both `-c key=value` and `--key=value`.
1553///
1554/// `key-name` is normalized as `key_name`.
1555///
1556/// * <https://github.com/postgres/postgres/blob/REL_18_1/src/backend/utils/init/postinit.c#L487>
1557/// * <https://github.com/postgres/postgres/blob/REL_18_1/src/backend/tcop/postgres.c#L3866>
1558/// * <https://github.com/postgres/postgres/blob/REL_18_1/src/backend/utils/misc/guc.c#L6361>
1559fn parse_options(options: &str) -> PsqlResult<Vec<(String, String)>> {
1560    let mut args = Vec::new();
1561    let mut current_arg = String::new();
1562    let mut chars = options.chars().peekable();
1563
1564    while let Some(c) = chars.next() {
1565        if c == '\\' {
1566            if let Some(next_c) = chars.next() {
1567                current_arg.push(next_c);
1568            }
1569        } else if c.is_ascii_whitespace() {
1570            if !current_arg.is_empty() {
1571                args.push(std::mem::take(&mut current_arg));
1572            }
1573        } else {
1574            current_arg.push(c);
1575        }
1576    }
1577    if !current_arg.is_empty() {
1578        args.push(current_arg);
1579    }
1580
1581    let mut args_iter = args.into_iter();
1582    let mut config = Vec::new();
1583
1584    while let Some(arg) = args_iter.next() {
1585        if arg == "-c" {
1586            if let Some(config_str) = args_iter.next() {
1587                if let Some((key, value)) = config_str.split_once('=') {
1588                    let key = key.replace("-", "_");
1589                    config.push((key, value.to_owned()));
1590                } else {
1591                    return Err(PsqlError::StartupError(
1592                        format!("invalid config format: {}", config_str).into(),
1593                    ));
1594                }
1595            } else {
1596                return Err(PsqlError::StartupError("missing argument for -c".into()));
1597            }
1598        } else if let Some(config_str) = arg.strip_prefix("--") {
1599            if let Some((key, value)) = config_str.split_once('=') {
1600                let key = key.replace("-", "_");
1601                config.push((key, value.to_owned()));
1602            } else {
1603                return Err(PsqlError::StartupError(
1604                    format!("invalid config format: {}", config_str).into(),
1605                ));
1606            }
1607        } else {
1608            tracing::warn!(
1609                arg,
1610                "ignoring unrecognized option for backward compatibility"
1611            );
1612        }
1613    }
1614    Ok(config)
1615}
1616
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Asserts that `parse_options(input)` succeeds and yields exactly `expected`, in order.
    #[track_caller]
    fn assert_parses(input: &str, expected: &[(&str, &str)]) {
        let expected: Vec<(String, String)> = expected
            .iter()
            .map(|&(key, value)| (key.into(), value.into()))
            .collect();
        assert_eq!(parse_options(input).unwrap(), expected);
    }

    /// Asserts that `parse_options(input)` reports an error.
    #[track_caller]
    fn assert_rejected(input: &str) {
        assert!(parse_options(input).is_err());
    }

    /// Option values whose keys are listed as redaction keywords must be replaced
    /// with `[REDACTED]` in the re-rendered statement; other options are kept.
    #[test]
    fn test_redact_parsable_sql() {
        let sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        let expected = "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])";
        let redact_keywords = Arc::new(HashSet::from(["v2".into(), "v4".into(), "b".into()]));

        let redacted = redact_sql(sql, redact_keywords);

        assert_eq!(redacted, expected);
    }

    #[test]
    fn test_parse_options() {
        // Nothing to parse, and the plain `-c key=value` form (repeated, extra spaces).
        assert_parses("", &[]);
        assert_parses("-c a=1 -c b=2", &[("a", "1"), ("b", "2")]);
        assert_parses("-c   key=value", &[("key", "value")]);

        // Quotes are ordinary characters for this parser, so they end up in the value.
        assert_parses("-c key='value'", &[("key", "'value'")]);

        // A backslash-escaped space stays inside the argument (the standard Postgres way).
        assert_parses(
            r#"-c key=value\ with\ spaces"#,
            &[("key", "value with spaces")],
        );
        assert_parses(r#"-c search_path=my\ schema"#, &[("search_path", "my schema")]);

        // Malformed input: `-c` without an argument, and settings without `=`.
        assert_rejected("-c");
        assert_rejected("-c foo");
        assert_rejected("--foo");

        // The `--key=value` form, alone and mixed with `-c`.
        assert_parses("--foo=bar", &[("foo", "bar")]);
        assert_parses(r#"--foo=bar\ baz"#, &[("foo", "bar baz")]);
        assert_parses("-c a=1 --b=2", &[("a", "1"), ("b", "2")]);

        // Unpaired trailing backslash is silently dropped, same as PostgreSQL.
        assert_parses(r#"-c a=b\"#, &[("a", "b")]);
    }
}