pgwire/
pg_protocol.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::env_var::env_var_is_true;
33use risingwave_common::util::panic::FutureCatchUnwindExt;
34use risingwave_common::util::query_log::*;
35use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
36use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
37use risingwave_sqlparser::parser::Parser;
38use thiserror_ext::AsReport;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
40use tokio::sync::Mutex;
41use tokio_openssl::SslStream;
42use tracing::Instrument;
43
44use crate::error::{PsqlError, PsqlResult};
45use crate::error_or_notice::Severity;
46use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
47use crate::net::AddressRef;
48use crate::pg_extended::ResultCache;
49use crate::pg_message::{
50    BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
51    FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
52    FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
53};
54use crate::pg_server::{Session, SessionManager, UserAuthenticator};
55use crate::types::Format;
56
/// Length limit applied to SQL text before it is written to query logs and spans, to avoid
/// log files being too large.
///
/// Read once (lazily) from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable. Falls back to
/// `65536` when the variable is unset or its value is not a valid `usize`.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    const DEFAULT_TRUNCATE_LEN: usize = 65536;

    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        // Parse exactly once; an unparsable value is treated like an unset variable.
        .and_then(|len| len.parse::<usize>().ok())
        .unwrap_or(DEFAULT_TRUNCATE_LEN)
});
64
tokio::task_local! {
    /// The current session. Concrete type is erased for different session implementations.
    ///
    /// Stored as a `Weak` reference, so reading this value does not by itself keep the
    /// session alive. `PgProtocol::do_process` scopes it around the processing of each
    /// message when the connection already has a session.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
69
/// The state machine for each psql connection.
/// Reads pg messages from the tcp stream and writes results back.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Used for write/read pg messages.
    stream: PgStream<S>,
    /// Current state of the pg connection.
    state: PgProtocolState,
    /// Whether the connection is terminated.
    is_terminate: bool,

    // Session manager used to create the session on startup and to end it on drop.
    session_mgr: Arc<SM>,
    // The session of this connection; `None` until a startup message has been processed.
    session: Option<Arc<SM::Session>>,

    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    // Used to store the dependency of portal and prepare statement.
    // When we close a prepare statement, we need to close all the portals that depend on it.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    // Used for ssl connection.
    // If `None`, an SSL request is answered with `EncryptionResponseNo` and the stream is not
    // upgraded.
    tls_context: Option<SslContext>,

    // TLS configuration including SSL enforcement setting
    tls_config: Option<TlsConfig>,

    // Used in extended query protocol. When an error is encountered in an extended query, we
    // need to ignore the following messages until a sync message arrives.
    ignore_util_sync: bool,

    // Client Address
    peer_addr: AddressRef,

    // Keywords passed to SQL redaction before SQL text is logged or recorded in a span.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    // Accounts for the payload size of incoming messages; may ask to throttle a message.
    message_memory_manager: MessageMemoryManagerRef,
}
112
/// Configures TLS encryption for connections.
///
/// `TlsConfig::new_default` builds this from the `RW_SSL_CERT`, `RW_SSL_KEY` and
/// `RW_SSL_ENFORCE` environment variables.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// The path to the TLS certificate.
    pub cert: String,
    /// The path to the TLS key.
    pub key: String,
    /// Whether to enforce SSL connections (reject non-SSL clients).
    pub enforce_ssl: bool,
}
123
124impl TlsConfig {
125    pub fn new_default() -> Option<Self> {
126        let cert = std::env::var("RW_SSL_CERT").ok()?;
127        let key = std::env::var("RW_SSL_KEY").ok()?;
128        let enforce_ssl = env_var_is_true("RW_SSL_ENFORCE");
129        tracing::info!(
130            "RW_SSL_CERT={}, RW_SSL_KEY={}, RW_SSL_ENFORCE={}",
131            cert,
132            key,
133            enforce_ssl
134        );
135        Some(Self {
136            cert,
137            key,
138            enforce_ssl,
139        })
140    }
141}
142
143impl<S, SM> Drop for PgProtocol<S, SM>
144where
145    SM: SessionManager,
146{
147    fn drop(&mut self) {
148        if let Some(session) = &self.session {
149            // Clear the session in session manager.
150            self.session_mgr.end_session(session);
151        }
152    }
153}
154
/// Connection states, listed in the order they are entered (top to bottom).
///
/// A connection is created in `Startup` and switches to `Regular` once a startup message has
/// been processed. The state selects how `PgProtocol::read_message` reads the next message.
enum PgProtocolState {
    Startup,
    Regular,
}
160
161/// Truncate 0 from C string in Bytes and stringify it (returns slice, no allocations).
162///
163/// PG protocol strings are always C strings.
164pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
165    let without_null = if b.last() == Some(&0) {
166        &b[..b.len() - 1]
167    } else {
168        &b[..]
169    };
170    std::str::from_utf8(without_null)
171}
172
173fn get_redacted_and_truncated_sql(
174    sql: &str,
175    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
176) -> String {
177    let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
178        && !keywords.is_empty()
179    {
180        redact_sql(sql, keywords)
181    } else {
182        sql.to_owned()
183    };
184    let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
185    truncated.to_string()
186}
187
188/// Record `sql` in the current tracing span.
189fn record_sql_in_span(
190    sql: &str,
191    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
192    span: &mut tracing::Span,
193) {
194    let redacted_and_truncated_sql =
195        get_redacted_and_truncated_sql(sql, redact_sql_option_keywords);
196    span.record("sql", tracing::field::display(&redacted_and_truncated_sql));
197}
198
199fn record_user_in_span(user: &str, span: &mut tracing::Span) {
200    span.record("user", tracing::field::display(user));
201}
202
203/// Redacts SQL options. Data in DML is not redacted.
204fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
205    match Parser::parse_sql(sql) {
206        Ok(sqls) => sqls
207            .into_iter()
208            .map(|sql| sql.to_redacted_string(keywords.clone()))
209            .join(";"),
210        Err(_) => sql.to_owned(),
211    }
212}
213
/// Per-connection settings handed to `PgProtocol::new`, which moves each field into the
/// protocol state.
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS settings; when present, `PgProtocol::new` tries to build an SSL context from them.
    pub tls_config: Option<TlsConfig>,
    /// Keywords used to redact SQL options before SQL text is logged.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Memory accounting for incoming message payloads.
    pub message_memory_manager: MessageMemoryManagerRef,
}
220
221impl<S, SM> PgProtocol<S, SM>
222where
223    S: PgByteStream,
224    SM: SessionManager,
225{
226    pub fn new(
227        stream: S,
228        session_mgr: Arc<SM>,
229        peer_addr: AddressRef,
230        context: ConnectionContext,
231    ) -> Self {
232        let ConnectionContext {
233            tls_config,
234            redact_sql_option_keywords,
235            message_memory_manager,
236        } = context;
237        Self {
238            stream: PgStream::new(stream),
239            is_terminate: false,
240            state: PgProtocolState::Startup,
241            session_mgr,
242            session: None,
243            tls_context: tls_config
244                .as_ref()
245                .and_then(|e| build_ssl_ctx_from_config(e).ok()),
246            tls_config,
247            result_cache: Default::default(),
248            unnamed_prepare_statement: Default::default(),
249            prepare_statement_store: Default::default(),
250            unnamed_portal: Default::default(),
251            portal_store: Default::default(),
252            statement_portal_dependency: Default::default(),
253            ignore_util_sync: false,
254            peer_addr,
255            redact_sql_option_keywords,
256            message_memory_manager,
257        }
258    }
259
    /// Run the protocol to serve the connection.
    ///
    /// Loops reading one message and processing it until either reading fails or processing
    /// reports that the connection is terminated. Once a session exists, notices from the
    /// session are forwarded to the client concurrently with message processing.
    pub async fn run(&mut self) {
        // Created at most once (when the session first becomes available) and then reused
        // across loop iterations.
        let mut notice_fut = None;

        loop {
            // Once a session is present, create a future to subscribe and send notices asynchronously.
            if notice_fut.is_none()
                && let Some(session) = self.session.clone()
            {
                let mut stream = self.stream.clone();
                notice_fut = Some(Box::pin(async move {
                    // Never completes: a failed write is only logged and the loop continues.
                    loop {
                        let notice = session.next_notice().await;
                        if let Err(e) = stream.write(BeMessage::NoticeResponse(&notice)).await {
                            tracing::error!(error = %e.as_report(), notice, "failed to send notice");
                        }
                    }
                }));
            }

            // Read and process messages.
            let process = std::pin::pin!(async {
                // The memory guard is held until the message has been processed.
                let (msg, _memory_guard) = match self.read_message().await {
                    Ok(msg) => msg,
                    Err(e) => {
                        tracing::error!(error = %e.as_report(), "error when reading message");
                        return true; // terminate the connection
                    }
                };
                tracing::trace!(?msg, "received message");
                self.process(msg).await
            });

            let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
                tokio::select! {
                    // The notice future is an infinite loop, so this branch can never fire.
                    _ = notice_fut => unreachable!(),
                    terminated = process => terminated,
                }
            } else {
                process.await
            };

            if terminated {
                break;
            }
        }
    }
307
308    /// Processes one message. Returns true if the connection is terminated.
309    pub async fn process(&mut self, msg: FeMessage) -> bool {
310        self.do_process(msg).await.is_none() || self.is_terminate
311    }
312
313    /// The root tracing span for processing a message. The target of the span is
314    /// [`PGWIRE_ROOT_SPAN_TARGET`].
315    ///
316    /// This is used to provide context for the (slow) query logs and traces.
317    ///
318    /// The span is only effective if there's a current session and the message is
319    /// query-related. Otherwise, `Span::none()` is returned.
320    fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
321        let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
322            return tracing::Span::none();
323        };
324
325        let mode = match msg {
326            FeMessage::Query(_) => "simple query",
327            FeMessage::Parse(_) => "extended query parse",
328            FeMessage::Execute(_) => "extended query execute",
329            _ => return tracing::Span::none(),
330        };
331
332        let mut span = tracing::info_span!(
333            target: PGWIRE_ROOT_SPAN_TARGET,
334            "handle_query",
335            mode,
336            session_id,
337            sql = tracing::field::Empty,
338            user = tracing::field::Empty,
339        );
340        if let Ok(sql) = msg.get_sql()
341            && let Some(sql) = sql
342        {
343            record_sql_in_span(sql, self.redact_sql_option_keywords.clone(), &mut span);
344        }
345        if let Some(current_session) = self.session.as_ref() {
346            record_user_in_span(&current_session.user(), &mut span);
347        }
348        span
349    }
350
    /// Processes one message with all cross-cutting concerns applied, then reports any error
    /// to the client.
    ///
    /// The actual work is done by `do_process_inner`. Its future is wrapped, from the inside
    /// out, with: the `CURRENT_SESSION` task-local scope, panic catching, the periodic slow
    /// query log, the query log, and finally the root tracing span.
    ///
    /// Return type `Option<()>` is essentially a bool, but allows `?` for early return.
    /// - `None` means to terminate the current connection
    /// - `Some(())` means to continue processing the next message
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        let span = self.root_span_for_msg(&msg);
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        // Processing the message itself.
        //
        // Note: pin the future to avoid stack overflow as we'll wrap it multiple times
        // in the following code.
        let fut = Box::pin(self.do_process_inner(msg));

        // Set the current session as the context when processing the message, if exists.
        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        // Catch unwind. A caught panic is converted into `PsqlError::Panic`.
        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        // Slow query log.
        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            let mut elapsed = Duration::ZERO;

            // Report the SQL in the log periodically if the query is slow.
            // The timeout only wakes this loop up; the inner future is polled by reference
            // and keeps running across iterations.
            loop {
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        // Query log.
        let fut = async move {
            // Only log if we're currently in a tracing span set in `root_span_for_msg`.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = "started",
                );
            }

            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            // Always log if an error occurs.
            // Note: all messages will be processed through this code path, making it the
            //       only necessary place to log errors.
            if let Err(error) = &result {
                if cfg!(debug_assertions) && !Deployment::current().is_ci() {
                    // For local debugging, we print the error with backtrace.
                    // It's useful only when:
                    // - no additional context is added to the error
                    // - backtrace is captured in the error
                    // - backtrace is not printed in the middle
                    tracing::error!(error = ?error.as_report(), "error when process message");
                } else {
                    tracing::error!(error = %error.as_report(), "error when process message");
                }
            }

            // Log to optionally-enabled target `PGWIRE_QUERY_LOG`.
            // Only log if we're currently in a tracing span set in `root_span_for_msg`.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        // Tracing span.
        let fut = fut.instrument(span);

        // Execute the future and handle the error.
        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    PsqlError::IoError(io_err) => {
                        // An unexpected EOF terminates the connection. Any other I/O error
                        // falls through to the final flush below and keeps the connection.
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                    }

                    PsqlError::SslError(_) => {
                        // For ssl error, because the stream has already been consumed, so there is
                        // no way to write more message.
                        return None;
                    }

                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                // At this time we're not in a session, use compact error message for
                                // better alignment with Postgres' UI.
                                pretty: false,
                                severity: Some(Severity::Fatal),
                            })
                            .ok()?;
                        let _ = self.stream.flush().await;
                        return None;
                    }

                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        // Report the error, then tell the client we are ready for the next
                        // query; both are flushed by the final flush below.
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                                severity: None,
                            })
                            .ok()?;
                        self.ready_for_query().ok()?;
                    }

                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                                severity: None,
                            })
                            .ok()?;
                        let _ = self.stream.flush().await;

                        // 1. Catching the panic during message processing may leave the session in an
                        // inconsistent state. We forcefully close the connection (then end the
                        // session) here for safety.
                        // 2. Idle in transaction timeout should also close the connection.
                        return None;
                    }

                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        // Only the error is written here; no `ReadyForQuery` is sent in this
                        // arm.
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                                severity: None,
                            })
                            .ok()?;
                    }
                }
                // Best-effort flush of whatever was buffered above; a flush failure is ignored.
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
531
532    async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
533        // Ignore util sync message.
534        if self.ignore_util_sync {
535            if let FeMessage::Sync = msg {
536            } else {
537                tracing::trace!("ignore message {:?} until sync.", msg);
538                return Ok(());
539            }
540        }
541
542        match msg {
543            FeMessage::Gss => self.process_gss_msg().await?,
544            FeMessage::Ssl => self.process_ssl_msg().await?,
545            FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
546            FeMessage::Password(msg) => self.process_password_msg(msg).await?,
547            FeMessage::Query(query_msg) => {
548                let sql = Arc::from(query_msg.get_sql()?);
549                // The process_query_msg can be slow. Release potential large FeQueryMessage early.
550                drop(query_msg);
551                self.process_query_msg(sql).await?
552            }
553            FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
554            FeMessage::Terminate => self.process_terminate(),
555            FeMessage::Parse(m) => {
556                if let Err(err) = self.process_parse_msg(m).await {
557                    self.ignore_util_sync = true;
558                    return Err(err);
559                }
560            }
561            FeMessage::Bind(m) => {
562                if let Err(err) = self.process_bind_msg(m) {
563                    self.ignore_util_sync = true;
564                    return Err(err);
565                }
566            }
567            FeMessage::Execute(m) => {
568                if let Err(err) = self.process_execute_msg(m).await {
569                    self.ignore_util_sync = true;
570                    return Err(err);
571                }
572            }
573            FeMessage::Describe(m) => {
574                if let Err(err) = self.process_describe_msg(m) {
575                    self.ignore_util_sync = true;
576                    return Err(err);
577                }
578            }
579            FeMessage::Sync => {
580                self.ignore_util_sync = false;
581                self.ready_for_query()?
582            }
583            FeMessage::Close(m) => {
584                if let Err(err) = self.process_close_msg(m) {
585                    self.ignore_util_sync = true;
586                    return Err(err);
587                }
588            }
589            FeMessage::Flush => {
590                if let Err(err) = self.stream.flush().await {
591                    self.ignore_util_sync = true;
592                    return Err(err.into());
593                }
594            }
595            FeMessage::HealthCheck => self.process_health_check(),
596            FeMessage::ServerThrottle(reason) => match reason {
597                ServerThrottleReason::TooLargeMessage => {
598                    return Err(PsqlError::ServerThrottle(format!(
599                        "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
600                        self.message_memory_manager.max_filter_bytes
601                    )));
602                }
603                ServerThrottleReason::TooManyMemoryUsage => {
604                    return Err(PsqlError::ServerThrottle(format!(
605                        "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
606                        self.message_memory_manager.max_running_bytes
607                    )));
608                }
609            },
610        }
611        self.stream.flush().await?;
612        Ok(())
613    }
614
615    pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
616        match self.state {
617            PgProtocolState::Startup => self
618                .stream
619                .read_startup()
620                .await
621                .map(|message: FeMessage| (message, None)),
622            PgProtocolState::Regular => {
623                self.stream.read_header().await?;
624                let guard = if let Some(ref header) = self.stream.read_header {
625                    let payload_len = std::cmp::max(header.payload_len, 0) as u64;
626                    let (reason, guard) = self.message_memory_manager.add(payload_len);
627                    if let Some(reason) = reason {
628                        // Release the memory ASAP.
629                        drop(guard);
630                        self.stream.skip_body().await?;
631                        return Ok((FeMessage::ServerThrottle(reason), None));
632                    }
633                    guard
634                } else {
635                    None
636                };
637                let message = self.stream.read_body().await?;
638                Ok((message, guard))
639            }
640        }
641    }
642
643    /// Writes a `ReadyForQuery` message to the client without flushing.
644    fn ready_for_query(&mut self) -> io::Result<()> {
645        self.stream.write_no_flush(BeMessage::ReadyForQuery(
646            self.session
647                .as_ref()
648                .map(|s| s.transaction_status())
649                .unwrap_or(TransactionStatus::Idle),
650        ))
651    }
652
653    async fn process_gss_msg(&mut self) -> PsqlResult<()> {
654        // We don't support GSSAPI, so we just say no gracefully.
655        self.stream.write(BeMessage::EncryptionResponseNo).await?;
656        Ok(())
657    }
658
659    async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
660        if let Some(context) = self.tls_context.as_ref() {
661            // If got and ssl context, say yes for ssl connection.
662            // Construct ssl stream and replace with current one.
663            self.stream.write(BeMessage::EncryptionResponseSsl).await?;
664            self.stream.upgrade_to_ssl(context).await?;
665        } else {
666            // If no, say no for encryption.
667            self.stream.write(BeMessage::EncryptionResponseNo).await?;
668        }
669
670        Ok(())
671    }
672
673    async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
674        // Check SSL enforcement: if SSL is enforced but connection is not using SSL, reject
675        if let Some(ref tls_config) = self.tls_config
676            && tls_config.enforce_ssl
677            && !self.stream.is_ssl_connection().await
678        {
679            return Err(PsqlError::StartupError(
680                "SSL connection is required but not established".into(),
681            ));
682        }
683
684        let db_name = msg
685            .config
686            .get("database")
687            .cloned()
688            .unwrap_or_else(|| "dev".to_owned());
689        let user_name = msg
690            .config
691            .get("user")
692            .cloned()
693            .unwrap_or_else(|| "root".to_owned());
694
695        let session = self
696            .session_mgr
697            .connect(&db_name, &user_name, self.peer_addr.clone())
698            .map_err(|e| PsqlError::StartupError(e.into()))?;
699
700        if let Some(options) = msg.config.get("options") {
701            for (key, value) in parse_options(options)? {
702                session
703                    .set_config(&key, value)
704                    .map_err(|e| PsqlError::StartupError(e.into()))?;
705            }
706        }
707        // dedicated `application_name` has higher priority than `options`
708        let application_name = msg.config.get("application_name");
709        if let Some(application_name) = application_name {
710            session
711                .set_config("application_name", application_name.clone())
712                .map_err(|e| PsqlError::StartupError(e.into()))?;
713        }
714
715        match session.user_authenticator() {
716            UserAuthenticator::None => {
717                self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
718
719                // Cancel request need this for identify and verification. According to postgres
720                // doc, it should be written to buffer after receive AuthenticationOk.
721                self.stream
722                    .write_no_flush(BeMessage::BackendKeyData(session.id()))?;
723
724                self.stream.write_no_flush(BeMessage::ParameterStatus(
725                    BeParameterStatusMessage::TimeZone(
726                        &session
727                            .get_config("timezone")
728                            .map_err(|e| PsqlError::StartupError(e.into()))?,
729                    ),
730                ))?;
731                self.stream
732                    .write_parameter_status_msg_no_flush(&ParameterStatus {
733                        application_name: application_name.cloned(),
734                    })?;
735                self.ready_for_query()?;
736            }
737            UserAuthenticator::ClearText(_)
738            | UserAuthenticator::OAuth { .. }
739            | UserAuthenticator::Ldap(..) => {
740                self.stream
741                    .write_no_flush(BeMessage::AuthenticationCleartextPassword)?;
742            }
743            UserAuthenticator::Md5WithSalt { salt, .. } => {
744                self.stream
745                    .write_no_flush(BeMessage::AuthenticationMd5Password(salt))?;
746            }
747        }
748
749        self.session = Some(session);
750        self.state = PgProtocolState::Regular;
751        Ok(())
752    }
753
754    async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
755        let session = self.session.as_ref().unwrap();
756        let authenticator = session.user_authenticator();
757        authenticator.authenticate(&msg.password).await?;
758        self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
759        let timezone = session
760            .get_config("timezone")
761            .map_err(|e| PsqlError::StartupError(e.into()))?;
762        self.stream.write_no_flush(BeMessage::ParameterStatus(
763            BeParameterStatusMessage::TimeZone(&timezone),
764        ))?;
765        self.stream
766            .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
767        self.ready_for_query()?;
768        self.state = PgProtocolState::Regular;
769        Ok(())
770    }
771
772    fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
773        let session_id = (m.target_process_id, m.target_secret_key);
774        tracing::trace!("cancel query in session: {:?}", session_id);
775        self.session_mgr.cancel_queries_in_session(session_id);
776        self.session_mgr.cancel_creating_jobs_in_session(session_id);
777        self.is_terminate = true;
778        Ok(())
779    }
780
781    async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
782        let truncated_sql =
783            get_redacted_and_truncated_sql(&sql, self.redact_sql_option_keywords.clone());
784        let session = self.session.clone().unwrap();
785
786        session.check_idle_in_transaction_timeout()?;
787        // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
788        let _exec_context_guard = session.init_exec_context(truncated_sql.into());
789        self.inner_process_query_msg(sql, session.clone()).await
790    }
791
792    async fn inner_process_query_msg(
793        &mut self,
794        sql: Arc<str>,
795        session: Arc<SM::Session>,
796    ) -> PsqlResult<()> {
797        // Parse sql.
798        let stmts =
799            Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
800        // The following inner_process_query_msg_one_stmt can be slow. Release potential large String early.
801        drop(sql);
802        if stmts.is_empty() {
803            self.stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
804        }
805
806        // Execute multiple statements in simple query. KISS later.
807        for stmt in stmts {
808            self.inner_process_query_msg_one_stmt(stmt, session.clone())
809                .await?;
810        }
811        // Put this line inside the for loop above will lead to unfinished/stuck regress test...Not
812        // sure the reason.
813        self.ready_for_query()?;
814        Ok(())
815    }
816
817    async fn inner_process_query_msg_one_stmt(
818        &mut self,
819        stmt: Statement,
820        session: Arc<SM::Session>,
821    ) -> PsqlResult<()> {
822        let session = session.clone();
823
824        // execute query
825        let res = session.clone().run_one_query(stmt, Format::Text).await;
826
827        // Take all remaining notices (if any) and send them before `CommandComplete`.
828        while let Some(notice) = session.next_notice().now_or_never() {
829            self.stream
830                .write_no_flush(BeMessage::NoticeResponse(&notice))?;
831        }
832
833        let mut res = res.map_err(|e| PsqlError::SimpleQueryError(e.into()))?;
834
835        for notice in res.notices() {
836            self.stream
837                .write_no_flush(BeMessage::NoticeResponse(notice))?;
838        }
839
840        let status = res.status();
841        if let Some(ref application_name) = status.application_name {
842            self.stream.write_no_flush(BeMessage::ParameterStatus(
843                BeParameterStatusMessage::ApplicationName(application_name),
844            ))?;
845        }
846
847        if res.is_copy_query_to_stdout() {
848            self.stream
849                .write_no_flush(BeMessage::CopyOutResponse(res.row_desc().len()))?;
850            let mut count = 0;
851            while let Some(row_set) = res.values_stream().next().await {
852                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
853                for row in row_set {
854                    self.stream.write_no_flush(BeMessage::CopyData(&row))?;
855                    count += 1;
856                }
857            }
858
859            self.stream.write_no_flush(BeMessage::CopyDone)?;
860
861            // Run the callback before sending the `CommandComplete` message.
862            res.run_callback().await?;
863
864            self.stream
865                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
866                    stmt_type: res.stmt_type(),
867                    rows_cnt: count,
868                }))?;
869        } else if res.is_query() {
870            self.stream
871                .write_no_flush(BeMessage::RowDescription(res.row_desc()))?;
872
873            let mut rows_cnt = 0;
874
875            while let Some(row_set) = res.values_stream().next().await {
876                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
877                for row in row_set {
878                    self.stream.write_no_flush(BeMessage::DataRow(&row))?;
879                    rows_cnt += 1;
880                }
881            }
882
883            // Run the callback before sending the `CommandComplete` message.
884            res.run_callback().await?;
885
886            self.stream
887                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
888                    stmt_type: res.stmt_type(),
889                    rows_cnt,
890                }))?;
891        } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
892            let first_row_set = res.values_stream().next().await;
893            let first_row_set = match first_row_set {
894                None => {
895                    return Err(PsqlError::Uncategorized(
896                        anyhow::anyhow!("no affected rows in output").into(),
897                    ));
898                }
899                Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
900            };
901            let affected_rows_str = first_row_set[0].values()[0]
902                .as_ref()
903                .expect("compute node should return affected rows in output");
904
905            assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
906            let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
907                .unwrap()
908                .parse()
909                .unwrap_or_default();
910
911            // Run the callback before sending the `CommandComplete` message.
912            res.run_callback().await?;
913
914            self.stream
915                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
916                    stmt_type: res.stmt_type(),
917                    rows_cnt: affected_rows_cnt,
918                }))?;
919        } else {
920            // Run the callback before sending the `CommandComplete` message.
921            res.run_callback().await?;
922
923            self.stream
924                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
925                    stmt_type: res.stmt_type(),
926                    rows_cnt: 0,
927                }))?;
928        }
929
930        Ok(())
931    }
932
    /// Handles a terminate message by marking the connection for termination.
    fn process_terminate(&mut self) {
        self.is_terminate = true;
    }
936
    /// Handles a health check: logs it at debug level and marks the connection for
    /// termination.
    fn process_health_check(&mut self) {
        tracing::debug!("health check");
        self.is_terminate = true;
    }
941
942    async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
943        let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
944        let session = self.session.clone().unwrap();
945        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
946        let type_ids = std::mem::take(&mut msg.type_ids);
947        // The inner_process_parse_msg can be slow. Release potential large FeParseMessage early.
948        drop(msg);
949        self.inner_process_parse_msg(session, sql, statement_name, type_ids)
950            .await?;
951        Ok(())
952    }
953
954    async fn inner_process_parse_msg(
955        &mut self,
956        session: Arc<SM::Session>,
957        sql: Arc<str>,
958        statement_name: String,
959        type_ids: Vec<i32>,
960    ) -> PsqlResult<()> {
961        if statement_name.is_empty() {
962            // Remove the unnamed prepare statement first, in case the unsupported sql binds a wrong
963            // prepare statement.
964            self.unnamed_prepare_statement.take();
965        } else if self.prepare_statement_store.contains_key(&statement_name) {
966            return Err(PsqlError::ExtendedPrepareError(
967                "Duplicated statement name".into(),
968            ));
969        }
970
971        let stmt = {
972            let stmts = Parser::parse_sql(&sql)
973                .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
974            drop(sql);
975            if stmts.len() > 1 {
976                return Err(PsqlError::ExtendedPrepareError(
977                    "Only one statement is allowed in extended query mode".into(),
978                ));
979            }
980
981            stmts.into_iter().next()
982        };
983
984        let param_types: Vec<Option<DataType>> = type_ids
985            .iter()
986            .map(|&id| {
987                // 0 means unspecified type
988                // ref: https://www.postgresql.org/docs/15/protocol-message-formats.html#:~:text=Placing%20a%20zero%20here%20is%20equivalent%20to%20leaving%20the%20type%20unspecified.
989                if id == 0 {
990                    Ok(None)
991                } else {
992                    DataType::from_oid(id)
993                        .map(Some)
994                        .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
995                }
996            })
997            .try_collect()?;
998
999        let prepare_statement = session
1000            .parse(stmt, param_types)
1001            .await
1002            .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))?;
1003
1004        if statement_name.is_empty() {
1005            self.unnamed_prepare_statement.replace(prepare_statement);
1006        } else {
1007            self.prepare_statement_store
1008                .insert(statement_name.clone(), prepare_statement);
1009        }
1010
1011        self.statement_portal_dependency
1012            .entry(statement_name)
1013            .or_default()
1014            .clear();
1015
1016        self.stream.write_no_flush(BeMessage::ParseComplete)?;
1017        Ok(())
1018    }
1019
1020    fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
1021        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
1022        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1023        let session = self.session.clone().unwrap();
1024
1025        if self.portal_store.contains_key(&portal_name) {
1026            return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
1027        }
1028
1029        let prepare_statement = self.get_statement(&statement_name)?;
1030
1031        let result_formats = msg
1032            .result_format_codes
1033            .iter()
1034            .map(|&format_code| Format::from_i16(format_code))
1035            .try_collect()?;
1036        let param_formats = msg
1037            .param_format_codes
1038            .iter()
1039            .map(|&format_code| Format::from_i16(format_code))
1040            .try_collect()?;
1041
1042        let portal = session
1043            .bind(prepare_statement, msg.params, param_formats, result_formats)
1044            .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1045
1046        if portal_name.is_empty() {
1047            self.result_cache.remove(&portal_name);
1048            self.unnamed_portal.replace(portal);
1049        } else {
1050            assert!(
1051                !self.result_cache.contains_key(&portal_name),
1052                "Named portal never can be overridden."
1053            );
1054            self.portal_store.insert(portal_name.clone(), portal);
1055        }
1056
1057        self.statement_portal_dependency
1058            .get_mut(&statement_name)
1059            .unwrap()
1060            .push(portal_name);
1061
1062        self.stream.write_no_flush(BeMessage::BindComplete)?;
1063        Ok(())
1064    }
1065
1066    async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
1067        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1068        let row_max = msg.max_rows as usize;
1069        drop(msg);
1070        let session = self.session.clone().unwrap();
1071
1072        match self.result_cache.remove(&portal_name) {
1073            Some(mut result_cache) => {
1074                assert!(self.portal_store.contains_key(&portal_name));
1075
1076                let is_consume_completed =
1077                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
1078
1079                if !is_consume_completed {
1080                    self.result_cache.insert(portal_name, result_cache);
1081                }
1082            }
1083            _ => {
1084                let portal = self.get_portal(&portal_name)?;
1085                let sql = format!("{}", portal);
1086                let truncated_sql =
1087                    get_redacted_and_truncated_sql(&sql, self.redact_sql_option_keywords.clone());
1088                drop(sql);
1089
1090                session.check_idle_in_transaction_timeout()?;
1091                // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
1092                let _exec_context_guard = session.init_exec_context(truncated_sql.into());
1093                let result = session.clone().execute(portal).await;
1094
1095                let pg_response = result.map_err(|e| PsqlError::ExtendedExecuteError(e.into()))?;
1096                let mut result_cache = ResultCache::new(pg_response);
1097                let is_consume_completed =
1098                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
1099                if !is_consume_completed {
1100                    self.result_cache.insert(portal_name, result_cache);
1101                }
1102            }
1103        }
1104
1105        Ok(())
1106    }
1107
1108    fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1109        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1110        let session = self.session.clone().unwrap();
1111        //  b'S' => Statement
1112        //  b'P' => Portal
1113
1114        assert!(msg.kind == b'S' || msg.kind == b'P');
1115        if msg.kind == b'S' {
1116            let prepare_statement = self.get_statement(&name)?;
1117
1118            let (param_types, row_descriptions) = self
1119                .session
1120                .clone()
1121                .unwrap()
1122                .describe_statement(prepare_statement)
1123                .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1124            self.stream.write_no_flush(BeMessage::ParameterDescription(
1125                &param_types.iter().map(|t| t.to_oid()).collect_vec(),
1126            ))?;
1127
1128            if row_descriptions.is_empty() {
1129                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1130                // return NoData message if the statement is not a query.
1131                self.stream.write_no_flush(BeMessage::NoData)?;
1132            } else {
1133                self.stream
1134                    .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1135            }
1136        } else if msg.kind == b'P' {
1137            let portal = self.get_portal(&name)?;
1138
1139            let row_descriptions = session
1140                .describe_portal(portal)
1141                .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1142
1143            if row_descriptions.is_empty() {
1144                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1145                // return NoData message if the statement is not a query.
1146                self.stream.write_no_flush(BeMessage::NoData)?;
1147            } else {
1148                self.stream
1149                    .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1150            }
1151        }
1152        Ok(())
1153    }
1154
1155    fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1156        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1157        assert!(msg.kind == b'S' || msg.kind == b'P');
1158        if msg.kind == b'S' {
1159            if name.is_empty() {
1160                self.unnamed_prepare_statement = None;
1161            } else {
1162                self.prepare_statement_store.remove(&name);
1163            }
1164            for portal_name in self
1165                .statement_portal_dependency
1166                .remove(&name)
1167                .unwrap_or_default()
1168            {
1169                self.remove_portal(&portal_name);
1170            }
1171        } else if msg.kind == b'P' {
1172            self.remove_portal(&name);
1173        }
1174        self.stream.write_no_flush(BeMessage::CloseComplete)?;
1175        Ok(())
1176    }
1177
1178    fn remove_portal(&mut self, portal_name: &str) {
1179        if portal_name.is_empty() {
1180            self.unnamed_portal = None;
1181        } else {
1182            self.portal_store.remove(portal_name);
1183        }
1184        self.result_cache.remove(portal_name);
1185    }
1186
1187    fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1188        if portal_name.is_empty() {
1189            Ok(self
1190                .unnamed_portal
1191                .as_ref()
1192                .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1193                .clone())
1194        } else {
1195            Ok(self
1196                .portal_store
1197                .get(portal_name)
1198                .ok_or_else(|| {
1199                    PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1200                })?
1201                .clone())
1202        }
1203    }
1204
1205    fn get_statement(
1206        &self,
1207        statement_name: &str,
1208    ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1209        if statement_name.is_empty() {
1210            Ok(self
1211                .unnamed_prepare_statement
1212                .as_ref()
1213                .ok_or_else(|| {
1214                    PsqlError::Uncategorized("unnamed prepare statement not found".into())
1215                })?
1216                .clone())
1217        } else {
1218            Ok(self
1219                .prepare_statement_store
1220                .get(statement_name)
1221                .ok_or_else(|| {
1222                    PsqlError::Uncategorized(
1223                        format!("Prepare statement {} not found", statement_name).into(),
1224                    )
1225                })?
1226                .clone())
1227        }
1228    }
1229}
1230
/// The transport behind a [`PgStream`]: either the raw byte stream or the same stream
/// wrapped in TLS.
enum PgStreamInner<S> {
    /// Used for the intermediate state when converting from unencrypted to ssl stream.
    /// Reading from or writing to a stream in this state is treated as unreachable.
    Placeholder,
    /// An unencrypted stream.
    Unencrypted(S),
    /// An ssl stream.
    Ssl(SslStream<S>),
}
1239
/// Trait for a byte stream that can be used for pg protocol.
///
/// This is a marker trait without methods of its own. It is implemented automatically for
/// every type that satisfies the listed bounds.
pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
// Blanket implementation: any type meeting the bounds is a `PgByteStream`.
impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1243
/// Wraps a byte stream and read/write pg messages.
///
/// Cloning a `PgStream` will share the same stream but a fresh & independent write buffer,
/// so that it can be used to write messages concurrently without interference.
pub struct PgStream<S> {
    /// The underlying stream, shared between clones and guarded by an async mutex.
    stream: Arc<Mutex<PgStreamInner<S>>>,
    /// Write into buffer before flush to stream.
    write_buf: BytesMut,
    /// Header stored by `read_header` and taken by the next `read_body` or `skip_body`.
    read_header: Option<FeMessageHeader>,
}
1255
1256impl<S> PgStream<S> {
1257    /// Create a new `PgStream` with the given stream and default write buffer capacity.
1258    pub fn new(stream: S) -> Self {
1259        const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1260
1261        Self {
1262            stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1263            write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1264            read_header: None,
1265        }
1266    }
1267
1268    /// Check if the current connection is using SSL
1269    async fn is_ssl_connection(&self) -> bool {
1270        let stream = self.stream.lock().await;
1271        matches!(*stream, PgStreamInner::Ssl(_))
1272    }
1273}
1274
1275impl<S> Clone for PgStream<S> {
1276    fn clone(&self) -> Self {
1277        Self {
1278            stream: Arc::clone(&self.stream),
1279            write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1280            read_header: self.read_header.clone(),
1281        }
1282    }
1283}
1284
/// At present there is a hard-wired set of parameters for which
/// ParameterStatus will be generated. They are:
///
///  * `server_version`
///  * `server_encoding`
///  * `client_encoding`
///  * `application_name`
///  * `is_superuser`
///  * `session_authorization`
///  * `DateStyle`
///  * `IntervalStyle`
///  * `TimeZone`
///  * `integer_datetimes`
///  * `standard_conforming_strings`
///
/// See: <https://www.postgresql.org/docs/9.2/static/protocol-flow.html#PROTOCOL-ASYNC>.
///
/// Of these, this struct itself only carries `application_name`.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// The application name to report, if any. `None` by default.
    pub application_name: Option<String>,
}
1305
1306impl<S> PgStream<S>
1307where
1308    S: PgByteStream,
1309{
1310    async fn read_startup(&mut self) -> io::Result<FeMessage> {
1311        let mut stream = self.stream.lock().await;
1312        match &mut *stream {
1313            PgStreamInner::Placeholder => unreachable!(),
1314            PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1315            PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1316        }
1317    }
1318
1319    async fn read_header(&mut self) -> io::Result<()> {
1320        let mut stream = self.stream.lock().await;
1321        match &mut *stream {
1322            PgStreamInner::Placeholder => unreachable!(),
1323            PgStreamInner::Unencrypted(stream) => {
1324                self.read_header = Some(FeMessage::read_header(stream).await?);
1325                Ok(())
1326            }
1327            PgStreamInner::Ssl(ssl_stream) => {
1328                self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1329                Ok(())
1330            }
1331        }
1332    }
1333
1334    async fn read_body(&mut self) -> io::Result<FeMessage> {
1335        let mut stream = self.stream.lock().await;
1336        let header = self
1337            .read_header
1338            .take()
1339            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1340        match &mut *stream {
1341            PgStreamInner::Placeholder => unreachable!(),
1342            PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1343            PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1344        }
1345    }
1346
1347    async fn skip_body(&mut self) -> io::Result<()> {
1348        let mut stream = self.stream.lock().await;
1349        let header = self
1350            .read_header
1351            .take()
1352            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1353        match &mut *stream {
1354            PgStreamInner::Placeholder => unreachable!(),
1355            PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1356            PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1357        }
1358    }
1359
1360    fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1361        self.write_no_flush(BeMessage::ParameterStatus(
1362            BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1363        ))?;
1364        self.write_no_flush(BeMessage::ParameterStatus(
1365            BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1366        ))?;
1367        self.write_no_flush(BeMessage::ParameterStatus(
1368            BeParameterStatusMessage::ServerVersion(PG_VERSION),
1369        ))?;
1370        if let Some(application_name) = &status.application_name {
1371            self.write_no_flush(BeMessage::ParameterStatus(
1372                BeParameterStatusMessage::ApplicationName(application_name),
1373            ))?;
1374        }
1375        Ok(())
1376    }
1377
1378    pub fn write_no_flush(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1379        BeMessage::write(&mut self.write_buf, message)
1380    }
1381
1382    async fn write(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1383        self.write_no_flush(message)?;
1384        self.flush().await?;
1385        Ok(())
1386    }
1387
1388    async fn flush(&mut self) -> io::Result<()> {
1389        let mut stream = self.stream.lock().await;
1390        match &mut *stream {
1391            PgStreamInner::Placeholder => unreachable!(),
1392            PgStreamInner::Unencrypted(stream) => {
1393                stream.write_all(&self.write_buf).await?;
1394                stream.flush().await?;
1395            }
1396            PgStreamInner::Ssl(ssl_stream) => {
1397                ssl_stream.write_all(&self.write_buf).await?;
1398                ssl_stream.flush().await?;
1399            }
1400        }
1401        self.write_buf.clear();
1402        Ok(())
1403    }
1404}
1405
1406impl<S> PgStream<S>
1407where
1408    S: PgByteStream,
1409{
1410    /// Convert the underlying stream to ssl stream based on the given context.
1411    async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1412        let mut stream = self.stream.lock().await;
1413
1414        match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1415            PgStreamInner::Unencrypted(unencrypted_stream) => {
1416                let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1417                let mut ssl_stream =
1418                    tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1419
1420                if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1421                    tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1422                    let _ = ssl_stream.shutdown().await;
1423                    return Err(e.into());
1424                }
1425
1426                *stream = PgStreamInner::Ssl(ssl_stream);
1427            }
1428            PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1429            PgStreamInner::Placeholder => unreachable!(),
1430        }
1431
1432        Ok(())
1433    }
1434}
1435
1436fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1437    let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1438
1439    let key_path = &tls_config.key;
1440    let cert_path = &tls_config.cert;
1441
1442    // Build ssl acceptor according to the config.
1443    // Now we set every verify to true.
1444    acceptor
1445        .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1446        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1447    acceptor
1448        .set_ca_file(cert_path)
1449        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1450    acceptor
1451        .set_certificate_chain_file(cert_path)
1452        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1453    let acceptor = acceptor.build();
1454
1455    Ok(acceptor.into_context())
1456}
1457
1458pub mod truncated_fmt {
1459    use std::fmt::*;
1460
1461    struct TruncatedFormatter<'a, 'b> {
1462        remaining: usize,
1463        finished: bool,
1464        f: &'a mut Formatter<'b>,
1465    }
1466    impl Write for TruncatedFormatter<'_, '_> {
1467        fn write_str(&mut self, s: &str) -> Result {
1468            if self.finished {
1469                return Ok(());
1470            }
1471
1472            if self.remaining < s.len() {
1473                let actual = s.floor_char_boundary(self.remaining);
1474                self.f.write_str(&s[0..actual])?;
1475                self.remaining -= actual;
1476                self.f.write_str(&format!("...(truncated,{})", s.len()))?;
1477                self.finished = true; // so that ...(truncated) is printed exactly once
1478            } else {
1479                self.f.write_str(s)?;
1480                self.remaining -= s.len();
1481            }
1482            Ok(())
1483        }
1484    }
1485
1486    pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);
1487
1488    impl<T> Debug for TruncatedFmt<'_, T>
1489    where
1490        T: Debug,
1491    {
1492        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
1493            TruncatedFormatter {
1494                remaining: self.1,
1495                finished: false,
1496                f,
1497            }
1498            .write_fmt(format_args!("{:?}", self.0))
1499        }
1500    }
1501
1502    impl<T> Display for TruncatedFmt<'_, T>
1503    where
1504        T: Display,
1505    {
1506        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
1507            TruncatedFormatter {
1508                remaining: self.1,
1509                finished: false,
1510                f,
1511            }
1512            .write_fmt(format_args!("{}", self.0))
1513        }
1514    }
1515
1516    #[cfg(test)]
1517    mod tests {
1518        use super::*;
1519
1520        #[test]
1521        fn test_trunc_utf8() {
1522            assert_eq!(
1523                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
1524                "select '...(truncated,14)",
1525            );
1526        }
1527    }
1528}
1529
1530/// Handle `options` in `StartupMessage` from client
1531///
1532/// It is like shell arguments but only respects backslash-escape and space;
1533/// quotes have no special meaning and are handled literally.
1534///
1535/// PostgreSQL allows both `-c key=value` and `--key=value`.
1536///
1537/// `key-name` is normalized as `key_name`.
1538///
1539/// * <https://github.com/postgres/postgres/blob/REL_18_1/src/backend/utils/init/postinit.c#L487>
1540/// * <https://github.com/postgres/postgres/blob/REL_18_1/src/backend/tcop/postgres.c#L3866>
1541/// * <https://github.com/postgres/postgres/blob/REL_18_1/src/backend/utils/misc/guc.c#L6361>
1542fn parse_options(options: &str) -> PsqlResult<Vec<(String, String)>> {
1543    let mut args = Vec::new();
1544    let mut current_arg = String::new();
1545    let mut chars = options.chars().peekable();
1546
1547    while let Some(c) = chars.next() {
1548        if c == '\\' {
1549            if let Some(next_c) = chars.next() {
1550                current_arg.push(next_c);
1551            }
1552        } else if c.is_ascii_whitespace() {
1553            if !current_arg.is_empty() {
1554                args.push(std::mem::take(&mut current_arg));
1555            }
1556        } else {
1557            current_arg.push(c);
1558        }
1559    }
1560    if !current_arg.is_empty() {
1561        args.push(current_arg);
1562    }
1563
1564    let mut args_iter = args.into_iter();
1565    let mut config = Vec::new();
1566
1567    while let Some(arg) = args_iter.next() {
1568        if arg == "-c" {
1569            if let Some(config_str) = args_iter.next() {
1570                if let Some((key, value)) = config_str.split_once('=') {
1571                    let key = key.replace("-", "_");
1572                    config.push((key, value.to_owned()));
1573                } else {
1574                    return Err(PsqlError::StartupError(
1575                        format!("invalid config format: {}", config_str).into(),
1576                    ));
1577                }
1578            } else {
1579                return Err(PsqlError::StartupError("missing argument for -c".into()));
1580            }
1581        } else if let Some(config_str) = arg.strip_prefix("--") {
1582            if let Some((key, value)) = config_str.split_once('=') {
1583                let key = key.replace("-", "_");
1584                config.push((key, value.to_owned()));
1585            } else {
1586                return Err(PsqlError::StartupError(
1587                    format!("invalid config format: {}", config_str).into(),
1588                ));
1589            }
1590        } else {
1591            tracing::warn!(
1592                arg,
1593                "ignoring unrecognized option for backward compatibility"
1594            );
1595        }
1596    }
1597    Ok(config)
1598}
1599
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Option values whose keyword is in the redaction set are replaced by `[REDACTED]`
    /// in the re-rendered statement; all other options are kept.
    #[test]
    fn test_redact_parsable_sql() {
        let redacted_keys: HashSet<_> = ["v2", "v4", "b"].into_iter().map(Into::into).collect();
        let input = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        let expected = "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])";

        assert_eq!(redact_sql(input, Arc::new(redacted_keys)), expected);
    }

    /// Exercises `parse_options` on accepted inputs (exact output) and on malformed ones
    /// (must be rejected).
    #[test]
    fn test_parse_options() {
        /// Asserts that `input` parses successfully into exactly `expected`.
        #[track_caller]
        fn assert_parsed(input: &str, expected: &[(&str, &str)]) {
            let expected: Vec<(String, String)> = expected
                .iter()
                .map(|&(key, value)| (key.to_owned(), value.to_owned()))
                .collect();
            assert_eq!(parse_options(input).unwrap(), expected);
        }

        /// Asserts that `input` is refused.
        #[track_caller]
        fn assert_rejected(input: &str) {
            assert!(parse_options(input).is_err());
        }

        // Basic `-c key=value` handling, including repeated separators.
        assert_parsed("", &[]);
        assert_parsed("-c a=1 -c b=2", &[("a", "1"), ("b", "2")]);
        assert_parsed("-c   key=value", &[("key", "value")]);
        // Quotes are ordinary characters for this parser, so they stay in the value.
        assert_parsed("-c key='value'", &[("key", "'value'")]);

        // A backslash keeps the following space inside the argument (the Postgres way).
        assert_parsed(
            r#"-c key=value\ with\ spaces"#,
            &[("key", "value with spaces")],
        );
        assert_parsed(r#"-c search_path=my\ schema"#, &[("search_path", "my schema")]);

        // Malformed input.
        assert_rejected("-c");
        assert_rejected("-c foo"); // missing =
        assert_rejected("--foo"); // missing = in -- option

        // Long `--key=value` form, alone and mixed with `-c`.
        assert_parsed("--foo=bar", &[("foo", "bar")]);
        assert_parsed(r#"--foo=bar\ baz"#, &[("foo", "bar baz")]);
        assert_parsed("-c a=1 --b=2", &[("a", "1"), ("b", "2")]);

        // An unpaired trailing backslash is silently dropped, same as PostgreSQL.
        assert_parsed(r#"-c a=b\"#, &[("a", "b")]);
    }
}