pgwire/
pg_protocol.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::env_var::env_var_is_true;
33use risingwave_common::util::panic::FutureCatchUnwindExt;
34use risingwave_common::util::query_log::*;
35use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
36use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
37use risingwave_sqlparser::parser::Parser;
38use thiserror_ext::AsReport;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
40use tokio::sync::Mutex;
41use tokio_openssl::SslStream;
42use tracing::Instrument;
43
44use crate::error::{PsqlError, PsqlResult};
45use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
46use crate::net::AddressRef;
47use crate::pg_extended::ResultCache;
48use crate::pg_message::{
49    BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
50    FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
51    FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
52};
53use crate::pg_server::{Session, SessionManager, UserAuthenticator};
54use crate::types::Format;
55
/// Truncates query log if it's longer than `RW_QUERY_LOG_TRUNCATE_LEN`, to avoid log file being too
/// large.
///
/// The limit is read once, lazily, from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable.
/// Surrounding whitespace is ignored. If the variable is unset, not valid unicode, or not a valid
/// `usize`, the limit defaults to 65536.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        // Parse exactly once instead of validating with `is_ok()` and then parsing again.
        .and_then(|len| len.trim().parse::<usize>().ok())
        .unwrap_or(65536)
});
63
tokio::task_local! {
    /// The current session. Concrete type is erased for different session implementations.
    ///
    /// `PgProtocol::do_process` scopes this to the connection's session (as a `Weak` handle, so
    /// the task-local does not keep the session alive) while one frontend message is processed.
    /// It is not set for messages handled before a session exists.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
68
/// The state machine for each psql connection.
/// Read pg messages from tcp stream and write results back.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Used for write/read pg messages.
    stream: PgStream<S>,
    /// Current states of pg connection.
    state: PgProtocolState,
    /// Whether the connection is terminated (set e.g. after handling a cancel request).
    is_terminate: bool,

    /// Creates the session on startup; `end_session` is called on it when this value is dropped.
    session_mgr: Arc<SM>,
    /// The session of this connection. `None` until the startup message has been processed.
    session: Option<Arc<SM::Session>>,

    // Extended query protocol state. NOTE(review): these are used by the extended-query
    // handlers outside this view; presumably keyed by portal / statement name — confirm there.
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    // Used to store the dependency of portal and prepare statement.
    // When we close a prepare statement, we need to close all the portals that depend on it.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    // Used for ssl connection. Built from `tls_config` in `new`.
    // If None (TLS not configured, or the context could not be built), SSL requests are declined
    // with `EncryptionResponseNo` and the connection continues in plaintext.
    tls_context: Option<SslContext>,

    // TLS configuration including SSL enforcement setting
    tls_config: Option<TlsConfig>,

    // Used in extended query protocol. When an error occurs in extended query, we need to ignore
    // the following messages until the next Sync message.
    ignore_util_sync: bool,

    // Client Address
    peer_addr: AddressRef,

    // Keywords whose SQL option values are redacted when SQL is recorded in tracing spans.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    // Accounts memory of incoming message payloads and throttles messages it rejects.
    message_memory_manager: MessageMemoryManagerRef,
}
111
/// Configures TLS encryption for connections.
///
/// Typically built by [`TlsConfig::new_default`] from environment variables, and turned into an
/// SSL context per connection in `PgProtocol::new`.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// The path to the TLS certificate.
    pub cert: String,
    /// The path to the TLS key.
    pub key: String,
    /// Whether to enforce SSL connections (reject non-SSL clients).
    ///
    /// When set, a startup message arriving on a non-SSL stream is rejected with a
    /// `StartupError`.
    pub enforce_ssl: bool,
}
122
123impl TlsConfig {
124    pub fn new_default() -> Option<Self> {
125        let cert = std::env::var("RW_SSL_CERT").ok()?;
126        let key = std::env::var("RW_SSL_KEY").ok()?;
127        let enforce_ssl = env_var_is_true("RW_SSL_ENFORCE");
128        tracing::info!(
129            "RW_SSL_CERT={}, RW_SSL_KEY={}, RW_SSL_ENFORCE={}",
130            cert,
131            key,
132            enforce_ssl
133        );
134        Some(Self {
135            cert,
136            key,
137            enforce_ssl,
138        })
139    }
140}
141
142impl<S, SM> Drop for PgProtocol<S, SM>
143where
144    SM: SessionManager,
145{
146    fn drop(&mut self) {
147        if let Some(session) = &self.session {
148            // Clear the session in session manager.
149            self.session_mgr.end_session(session);
150        }
151    }
152}
153
/// Connection states. States flow from top to bottom: a connection starts in `Startup` and
/// switches to `Regular` once its startup message has been processed.
enum PgProtocolState {
    /// Messages are read with the startup-message framing (`read_startup`), without memory
    /// accounting.
    Startup,
    /// Messages are read as header + body, with the payload charged to the message memory
    /// manager.
    Regular,
}
159
160/// Truncate 0 from C string in Bytes and stringify it (returns slice, no allocations).
161///
162/// PG protocol strings are always C strings.
163pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
164    let without_null = if b.last() == Some(&0) {
165        &b[..b.len() - 1]
166    } else {
167        &b[..]
168    };
169    std::str::from_utf8(without_null)
170}
171
172fn record_sql_in_current_span(
173    sql: &str,
174    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
175) -> String {
176    let mut span = tracing::Span::current();
177    record_sql_in_span(sql, redact_sql_option_keywords, &mut span)
178}
179
180/// Record `sql` in the current tracing span.
181fn record_sql_in_span(
182    sql: &str,
183    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
184    span: &mut tracing::Span,
185) -> String {
186    let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
187        && !keywords.is_empty()
188    {
189        redact_sql(sql, keywords)
190    } else {
191        sql.to_owned()
192    };
193    let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
194    span.record("sql", tracing::field::display(&truncated));
195    truncated.to_string()
196}
197
198/// Redacts SQL options. Data in DML is not redacted.
199fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
200    match Parser::parse_sql(sql) {
201        Ok(sqls) => sqls
202            .into_iter()
203            .map(|sql| sql.to_redacted_string(keywords.clone()))
204            .join(";"),
205        Err(_) => sql.to_owned(),
206    }
207}
208
/// Settings handed to [`PgProtocol::new`] for each accepted connection.
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS settings. `None` means SSL requests are declined and SSL is never enforced.
    pub tls_config: Option<TlsConfig>,
    /// Keywords whose SQL option values are redacted when SQL is recorded in tracing spans.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Memory accounting for incoming message payloads; rejected messages are throttled.
    pub message_memory_manager: MessageMemoryManagerRef,
}
215
216impl<S, SM> PgProtocol<S, SM>
217where
218    S: PgByteStream,
219    SM: SessionManager,
220{
221    pub fn new(
222        stream: S,
223        session_mgr: Arc<SM>,
224        peer_addr: AddressRef,
225        context: ConnectionContext,
226    ) -> Self {
227        let ConnectionContext {
228            tls_config,
229            redact_sql_option_keywords,
230            message_memory_manager,
231        } = context;
232        Self {
233            stream: PgStream::new(stream),
234            is_terminate: false,
235            state: PgProtocolState::Startup,
236            session_mgr,
237            session: None,
238            tls_context: tls_config
239                .as_ref()
240                .and_then(|e| build_ssl_ctx_from_config(e).ok()),
241            tls_config,
242            result_cache: Default::default(),
243            unnamed_prepare_statement: Default::default(),
244            prepare_statement_store: Default::default(),
245            unnamed_portal: Default::default(),
246            portal_store: Default::default(),
247            statement_portal_dependency: Default::default(),
248            ignore_util_sync: false,
249            peer_addr,
250            redact_sql_option_keywords,
251            message_memory_manager,
252        }
253    }
254
    /// Run the protocol to serve the connection.
    ///
    /// Repeatedly reads one frontend message and processes it, until reading fails or
    /// `process` reports that the connection is terminated. Once a session exists, a
    /// background notice-forwarding future is polled alongside each read/process step.
    pub async fn run(&mut self) {
        let mut notice_fut = None;

        loop {
            // Once a session is present, create a future to subscribe and send notices asynchronously.
            // The future is created only once and kept across iterations (polled via `as_mut`
            // below), so no notice subscription is lost between messages.
            if notice_fut.is_none()
                && let Some(session) = self.session.clone()
            {
                // NOTE(review): presumably this clone shares the same underlying connection
                // as `self.stream` — confirm in `PgStream`.
                let mut stream = self.stream.clone();
                notice_fut = Some(Box::pin(async move {
                    // Loops forever: write failures are only logged, never returned.
                    loop {
                        let notice = session.next_notice().await;
                        if let Err(e) = stream.write(BeMessage::NoticeResponse(&notice)).await {
                            tracing::error!(error = %e.as_report(), notice, "failed to send notice");
                        }
                    }
                }));
            }

            // Read and process messages.
            let process = std::pin::pin!(async {
                // `_memory_guard` is a named binding (not `_`), so the message's memory stays
                // accounted until processing of this message has finished.
                let (msg, _memory_guard) = match self.read_message().await {
                    Ok(msg) => msg,
                    Err(e) => {
                        tracing::error!(error = %e.as_report(), "error when reading message");
                        return true; // terminate the connection
                    }
                };
                tracing::trace!(?msg, "received message");
                self.process(msg).await
            });

            let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
                tokio::select! {
                    // The notice future is an infinite loop and never completes.
                    _ = notice_fut => unreachable!(),
                    terminated = process => terminated,
                }
            } else {
                process.await
            };

            if terminated {
                break;
            }
        }
    }
302
303    /// Processes one message. Returns true if the connection is terminated.
304    pub async fn process(&mut self, msg: FeMessage) -> bool {
305        self.do_process(msg).await.is_none() || self.is_terminate
306    }
307
308    /// The root tracing span for processing a message. The target of the span is
309    /// [`PGWIRE_ROOT_SPAN_TARGET`].
310    ///
311    /// This is used to provide context for the (slow) query logs and traces.
312    ///
313    /// The span is only effective if there's a current session and the message is
314    /// query-related. Otherwise, `Span::none()` is returned.
315    fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
316        let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
317            return tracing::Span::none();
318        };
319
320        let mode = match msg {
321            FeMessage::Query(_) => "simple query",
322            FeMessage::Parse(_) => "extended query parse",
323            FeMessage::Execute(_) => "extended query execute",
324            _ => return tracing::Span::none(),
325        };
326
327        let mut span = tracing::info_span!(
328            target: PGWIRE_ROOT_SPAN_TARGET,
329            "handle_query",
330            mode,
331            session_id,
332            sql = tracing::field::Empty,
333        );
334        if let Ok(sql) = msg.get_sql()
335            && let Some(sql) = sql
336        {
337            record_sql_in_span(sql, self.redact_sql_option_keywords.clone(), &mut span);
338        }
339        span
340    }
341
    /// Processes one message and turns any error into the appropriate client response.
    ///
    /// The processing future is layered, from innermost to outermost, as follows:
    /// 1. `do_process_inner`;
    /// 2. a `CURRENT_SESSION` scope (if a session exists);
    /// 3. panic catching (panics become `PsqlError::Panic`);
    /// 4. a periodic slow-query log;
    /// 5. the started/finished query log plus error logging;
    /// 6. the root span from `root_span_for_msg`.
    ///
    /// Return type `Option<()>` is essentially a bool, but allows `?` for early return.
    /// - `None` means to terminate the current connection
    /// - `Some(())` means to continue processing the next message
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        let span = self.root_span_for_msg(&msg);
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        // Processing the message itself.
        //
        // Note: pin the future to avoid stack overflow as we'll wrap it multiple times
        // in the following code.
        let fut = Box::pin(self.do_process_inner(msg));

        // Set the current session as the context when processing the message, if exists.
        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        // Catch unwind.
        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        // Slow query log.
        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            let mut elapsed = Duration::ZERO;

            // Report the SQL in the log periodically if the query is slow.
            // `timeout` is given `&mut fut`, so an expired timeout does not drop the query
            // future; it keeps running and is polled again in the next iteration.
            loop {
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        // Query log.
        let fut = async move {
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = "started",
                );
            }

            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            // Always log if an error occurs.
            // Note: all messages will be processed through this code path, making it the
            //       only necessary place to log errors.
            if let Err(error) = &result {
                if cfg!(debug_assertions) && !Deployment::current().is_ci() {
                    // For local debugging, we print the error with backtrace.
                    // It's useful only when:
                    // - no additional context is added to the error
                    // - backtrace is captured in the error
                    // - backtrace is not printed in the middle
                    tracing::error!(error = ?error.as_report(), "error when process message");
                } else {
                    tracing::error!(error = %error.as_report(), "error when process message");
                }
            }

            // Log to optionally-enabled target `PGWIRE_QUERY_LOG`.
            // Only log if we're currently in a tracing span set in `root_span_for_msg`.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        // Tracing span.
        let fut = fut.instrument(span);

        // Execute the future and handle the error.
        // In every arm, a failure to write the ErrorResponse (`.ok()?`) terminates the connection.
        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    PsqlError::IoError(io_err) => {
                        // EOF: the client went away, so terminate. Any other I/O error writes
                        // nothing here; it falls through to the flush below and the connection
                        // is kept open.
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                    }

                    PsqlError::SslError(_) => {
                        // For ssl error, because the stream has already been consumed, so there is
                        // no way to write more message.
                        return None;
                    }

                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                // At this time we're not in a session, use compact error message for
                                // better alignment with Postgres' UI.
                                pretty: false,
                            })
                            .ok()?;
                        let _ = self.stream.flush().await;
                        // Startup/authentication failures always close the connection.
                        return None;
                    }

                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                            })
                            .ok()?;
                        // These are answered immediately with ReadyForQuery. Extended-query errors
                        // (below) instead wait for the client's Sync, which writes ReadyForQuery
                        // in `do_process_inner`.
                        self.ready_for_query().ok()?;
                    }

                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                            })
                            .ok()?;
                        let _ = self.stream.flush().await;

                        // 1. Catching the panic during message processing may leave the session in an
                        // inconsistent state. We forcefully close the connection (then end the
                        // session) here for safety.
                        // 2. Idle in transaction timeout should also close the connection.
                        return None;
                    }

                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                            })
                            .ok()?;
                    }
                }
                // Best effort: a failed flush here does not terminate the connection.
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
518
    /// Dispatches `msg` to its handler and flushes the stream after a successful dispatch.
    ///
    /// A failure while handling an extended-query message (Parse, Bind, Execute, Describe,
    /// Close, Flush) sets `ignore_util_sync`. All later messages are then dropped until the
    /// client sends Sync, which clears the flag and writes ReadyForQuery.
    async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
        // While recovering from an extended-query error, ignore every message until Sync.
        if self.ignore_util_sync {
            if let FeMessage::Sync = msg {
            } else {
                tracing::trace!("ignore message {:?} until sync.", msg);
                return Ok(());
            }
        }

        match msg {
            FeMessage::Gss => self.process_gss_msg().await?,
            FeMessage::Ssl => self.process_ssl_msg().await?,
            FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
            FeMessage::Password(msg) => self.process_password_msg(msg).await?,
            FeMessage::Query(query_msg) => {
                let sql = Arc::from(query_msg.get_sql()?);
                // The process_query_msg can be slow. Release potential large FeQueryMessage early.
                drop(query_msg);
                self.process_query_msg(sql).await?
            }
            FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
            FeMessage::Terminate => self.process_terminate(),
            FeMessage::Parse(m) => {
                if let Err(err) = self.process_parse_msg(m).await {
                    self.ignore_util_sync = true;
                    return Err(err);
                }
            }
            FeMessage::Bind(m) => {
                if let Err(err) = self.process_bind_msg(m) {
                    self.ignore_util_sync = true;
                    return Err(err);
                }
            }
            FeMessage::Execute(m) => {
                if let Err(err) = self.process_execute_msg(m).await {
                    self.ignore_util_sync = true;
                    return Err(err);
                }
            }
            FeMessage::Describe(m) => {
                if let Err(err) = self.process_describe_msg(m) {
                    self.ignore_util_sync = true;
                    return Err(err);
                }
            }
            FeMessage::Sync => {
                self.ignore_util_sync = false;
                self.ready_for_query()?
            }
            FeMessage::Close(m) => {
                if let Err(err) = self.process_close_msg(m) {
                    self.ignore_util_sync = true;
                    return Err(err);
                }
            }
            FeMessage::Flush => {
                if let Err(err) = self.stream.flush().await {
                    self.ignore_util_sync = true;
                    return Err(err.into());
                }
            }
            FeMessage::HealthCheck => self.process_health_check(),
            // Synthesized by `read_message` when the memory manager rejects a message.
            // NOTE(review): unlike extended-query errors, this does not set `ignore_util_sync`;
            // `do_process` answers it with ReadyForQuery right away — confirm this is intended
            // for clients using the extended protocol.
            FeMessage::ServerThrottle(reason) => match reason {
                ServerThrottleReason::TooLargeMessage => {
                    return Err(PsqlError::ServerThrottle(format!(
                        "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
                        self.message_memory_manager.max_filter_bytes
                    )));
                }
                ServerThrottleReason::TooManyMemoryUsage => {
                    return Err(PsqlError::ServerThrottle(format!(
                        "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
                        self.message_memory_manager.max_running_bytes
                    )));
                }
            },
        }
        self.stream.flush().await?;
        Ok(())
    }
601
602    pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
603        match self.state {
604            PgProtocolState::Startup => self
605                .stream
606                .read_startup()
607                .await
608                .map(|message: FeMessage| (message, None)),
609            PgProtocolState::Regular => {
610                self.stream.read_header().await?;
611                let guard = if let Some(ref header) = self.stream.read_header {
612                    let payload_len = std::cmp::max(header.payload_len, 0) as u64;
613                    let (reason, guard) = self.message_memory_manager.add(payload_len);
614                    if let Some(reason) = reason {
615                        // Release the memory ASAP.
616                        drop(guard);
617                        self.stream.skip_body().await?;
618                        return Ok((FeMessage::ServerThrottle(reason), None));
619                    }
620                    guard
621                } else {
622                    None
623                };
624                let message = self.stream.read_body().await?;
625                Ok((message, guard))
626            }
627        }
628    }
629
630    /// Writes a `ReadyForQuery` message to the client without flushing.
631    fn ready_for_query(&mut self) -> io::Result<()> {
632        self.stream.write_no_flush(BeMessage::ReadyForQuery(
633            self.session
634                .as_ref()
635                .map(|s| s.transaction_status())
636                .unwrap_or(TransactionStatus::Idle),
637        ))
638    }
639
640    async fn process_gss_msg(&mut self) -> PsqlResult<()> {
641        // We don't support GSSAPI, so we just say no gracefully.
642        self.stream.write(BeMessage::EncryptionResponseNo).await?;
643        Ok(())
644    }
645
646    async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
647        if let Some(context) = self.tls_context.as_ref() {
648            // If got and ssl context, say yes for ssl connection.
649            // Construct ssl stream and replace with current one.
650            self.stream.write(BeMessage::EncryptionResponseSsl).await?;
651            self.stream.upgrade_to_ssl(context).await?;
652        } else {
653            // If no, say no for encryption.
654            self.stream.write(BeMessage::EncryptionResponseNo).await?;
655        }
656
657        Ok(())
658    }
659
    /// Handles the startup message: creates the session and starts authentication.
    ///
    /// Defaults are `database = "dev"` and `user = "root"` when the client omits them.
    /// With `UserAuthenticator::None` the handshake completes immediately (AuthenticationOk,
    /// BackendKeyData, parameter statuses, ReadyForQuery). Other authenticators only get a
    /// password request here; the reply is handled by `process_password_msg`.
    ///
    /// NOTE(review): `state` becomes `Regular` (and `session` is set) before the password has
    /// been verified. Nothing visible in `do_process_inner` rejects non-Password messages in
    /// that window — confirm authentication is enforced elsewhere.
    async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
        // Check SSL enforcement: if SSL is enforced but connection is not using SSL, reject
        if let Some(ref tls_config) = self.tls_config
            && tls_config.enforce_ssl
            && !self.stream.is_ssl_connection().await
        {
            return Err(PsqlError::StartupError(
                "SSL connection is required but not established".into(),
            ));
        }

        let db_name = msg
            .config
            .get("database")
            .cloned()
            .unwrap_or_else(|| "dev".to_owned());
        let user_name = msg
            .config
            .get("user")
            .cloned()
            .unwrap_or_else(|| "root".to_owned());

        // NOTE(review): if any `?` below fails, `session` is dropped without being stored in
        // `self.session`, so `Drop` never calls `end_session` for it — confirm `connect` does
        // not register state that needs cleanup.
        let session = self
            .session_mgr
            .connect(&db_name, &user_name, self.peer_addr.clone())
            .map_err(|e| PsqlError::StartupError(e.into()))?;

        let application_name = msg.config.get("application_name");
        if let Some(application_name) = application_name {
            session
                .set_config("application_name", application_name.clone())
                .map_err(|e| PsqlError::StartupError(e.into()))?;
        }

        match session.user_authenticator() {
            UserAuthenticator::None => {
                self.stream.write_no_flush(BeMessage::AuthenticationOk)?;

                // Cancel request need this for identify and verification. According to postgres
                // doc, it should be written to buffer after receive AuthenticationOk.
                self.stream
                    .write_no_flush(BeMessage::BackendKeyData(session.id()))?;

                self.stream.write_no_flush(BeMessage::ParameterStatus(
                    BeParameterStatusMessage::TimeZone(
                        &session
                            .get_config("timezone")
                            .map_err(|e| PsqlError::StartupError(e.into()))?,
                    ),
                ))?;
                self.stream
                    .write_parameter_status_msg_no_flush(&ParameterStatus {
                        application_name: application_name.cloned(),
                    })?;
                // `self.session` is still `None` here, so this reports `Idle`.
                self.ready_for_query()?;
            }
            UserAuthenticator::ClearText(_)
            | UserAuthenticator::OAuth { .. }
            | UserAuthenticator::Ldap(..) => {
                self.stream
                    .write_no_flush(BeMessage::AuthenticationCleartextPassword)?;
            }
            UserAuthenticator::Md5WithSalt { salt, .. } => {
                self.stream
                    .write_no_flush(BeMessage::AuthenticationMd5Password(salt))?;
            }
        }

        self.session = Some(session);
        self.state = PgProtocolState::Regular;
        Ok(())
    }
732
733    async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
734        let session = self.session.as_ref().unwrap();
735        let authenticator = session.user_authenticator();
736        authenticator.authenticate(&msg.password).await?;
737        self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
738        let timezone = session
739            .get_config("timezone")
740            .map_err(|e| PsqlError::StartupError(e.into()))?;
741        self.stream.write_no_flush(BeMessage::ParameterStatus(
742            BeParameterStatusMessage::TimeZone(&timezone),
743        ))?;
744        self.stream
745            .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
746        self.ready_for_query()?;
747        self.state = PgProtocolState::Regular;
748        Ok(())
749    }
750
751    fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
752        let session_id = (m.target_process_id, m.target_secret_key);
753        tracing::trace!("cancel query in session: {:?}", session_id);
754        self.session_mgr.cancel_queries_in_session(session_id);
755        self.session_mgr.cancel_creating_jobs_in_session(session_id);
756        self.is_terminate = true;
757        Ok(())
758    }
759
760    async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
761        let truncated_sql =
762            record_sql_in_current_span(&sql, self.redact_sql_option_keywords.clone());
763        let session = self.session.clone().unwrap();
764
765        session.check_idle_in_transaction_timeout()?;
766        // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
767        let _exec_context_guard = session.init_exec_context(truncated_sql.into());
768        self.inner_process_query_msg(sql, session.clone()).await
769    }
770
771    async fn inner_process_query_msg(
772        &mut self,
773        sql: Arc<str>,
774        session: Arc<SM::Session>,
775    ) -> PsqlResult<()> {
776        // Parse sql.
777        let stmts =
778            Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
779        // The following inner_process_query_msg_one_stmt can be slow. Release potential large String early.
780        drop(sql);
781        if stmts.is_empty() {
782            self.stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
783        }
784
785        // Execute multiple statements in simple query. KISS later.
786        for stmt in stmts {
787            self.inner_process_query_msg_one_stmt(stmt, session.clone())
788                .await?;
789        }
790        // Put this line inside the for loop above will lead to unfinished/stuck regress test...Not
791        // sure the reason.
792        self.ready_for_query()?;
793        Ok(())
794    }
795
796    async fn inner_process_query_msg_one_stmt(
797        &mut self,
798        stmt: Statement,
799        session: Arc<SM::Session>,
800    ) -> PsqlResult<()> {
801        let session = session.clone();
802
803        // execute query
804        let res = session.clone().run_one_query(stmt, Format::Text).await;
805
806        // Take all remaining notices (if any) and send them before `CommandComplete`.
807        while let Some(notice) = session.next_notice().now_or_never() {
808            self.stream
809                .write_no_flush(BeMessage::NoticeResponse(&notice))?;
810        }
811
812        let mut res = res.map_err(|e| PsqlError::SimpleQueryError(e.into()))?;
813
814        for notice in res.notices() {
815            self.stream
816                .write_no_flush(BeMessage::NoticeResponse(notice))?;
817        }
818
819        let status = res.status();
820        if let Some(ref application_name) = status.application_name {
821            self.stream.write_no_flush(BeMessage::ParameterStatus(
822                BeParameterStatusMessage::ApplicationName(application_name),
823            ))?;
824        }
825
826        if res.is_copy_query_to_stdout() {
827            self.stream
828                .write_no_flush(BeMessage::CopyOutResponse(res.row_desc().len()))?;
829            let mut count = 0;
830            while let Some(row_set) = res.values_stream().next().await {
831                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
832                for row in row_set {
833                    self.stream.write_no_flush(BeMessage::CopyData(&row))?;
834                    count += 1;
835                }
836            }
837
838            self.stream.write_no_flush(BeMessage::CopyDone)?;
839
840            // Run the callback before sending the `CommandComplete` message.
841            res.run_callback().await?;
842
843            self.stream
844                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
845                    stmt_type: res.stmt_type(),
846                    rows_cnt: count,
847                }))?;
848        } else if res.is_query() {
849            self.stream
850                .write_no_flush(BeMessage::RowDescription(res.row_desc()))?;
851
852            let mut rows_cnt = 0;
853
854            while let Some(row_set) = res.values_stream().next().await {
855                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
856                for row in row_set {
857                    self.stream.write_no_flush(BeMessage::DataRow(&row))?;
858                    rows_cnt += 1;
859                }
860            }
861
862            // Run the callback before sending the `CommandComplete` message.
863            res.run_callback().await?;
864
865            self.stream
866                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
867                    stmt_type: res.stmt_type(),
868                    rows_cnt,
869                }))?;
870        } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
871            let first_row_set = res.values_stream().next().await;
872            let first_row_set = match first_row_set {
873                None => {
874                    return Err(PsqlError::Uncategorized(
875                        anyhow::anyhow!("no affected rows in output").into(),
876                    ));
877                }
878                Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
879            };
880            let affected_rows_str = first_row_set[0].values()[0]
881                .as_ref()
882                .expect("compute node should return affected rows in output");
883
884            assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
885            let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
886                .unwrap()
887                .parse()
888                .unwrap_or_default();
889
890            // Run the callback before sending the `CommandComplete` message.
891            res.run_callback().await?;
892
893            self.stream
894                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
895                    stmt_type: res.stmt_type(),
896                    rows_cnt: affected_rows_cnt,
897                }))?;
898        } else {
899            // Run the callback before sending the `CommandComplete` message.
900            res.run_callback().await?;
901
902            self.stream
903                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
904                    stmt_type: res.stmt_type(),
905                    rows_cnt: 0,
906                }))?;
907        }
908
909        Ok(())
910    }
911
912    fn process_terminate(&mut self) {
913        self.is_terminate = true;
914    }
915
916    fn process_health_check(&mut self) {
917        tracing::debug!("health check");
918        self.is_terminate = true;
919    }
920
921    async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
922        let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
923        record_sql_in_current_span(&sql, self.redact_sql_option_keywords.clone());
924        let session = self.session.clone().unwrap();
925        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
926        let type_ids = std::mem::take(&mut msg.type_ids);
927        // The inner_process_parse_msg can be slow. Release potential large FeParseMessage early.
928        drop(msg);
929        self.inner_process_parse_msg(session, sql, statement_name, type_ids)
930            .await?;
931        Ok(())
932    }
933
934    async fn inner_process_parse_msg(
935        &mut self,
936        session: Arc<SM::Session>,
937        sql: Arc<str>,
938        statement_name: String,
939        type_ids: Vec<i32>,
940    ) -> PsqlResult<()> {
941        if statement_name.is_empty() {
942            // Remove the unnamed prepare statement first, in case the unsupported sql binds a wrong
943            // prepare statement.
944            self.unnamed_prepare_statement.take();
945        } else if self.prepare_statement_store.contains_key(&statement_name) {
946            return Err(PsqlError::ExtendedPrepareError(
947                "Duplicated statement name".into(),
948            ));
949        }
950
951        let stmt = {
952            let stmts = Parser::parse_sql(&sql)
953                .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
954            drop(sql);
955            if stmts.len() > 1 {
956                return Err(PsqlError::ExtendedPrepareError(
957                    "Only one statement is allowed in extended query mode".into(),
958                ));
959            }
960
961            stmts.into_iter().next()
962        };
963
964        let param_types: Vec<Option<DataType>> = type_ids
965            .iter()
966            .map(|&id| {
967                // 0 means unspecified type
968                // ref: https://www.postgresql.org/docs/15/protocol-message-formats.html#:~:text=Placing%20a%20zero%20here%20is%20equivalent%20to%20leaving%20the%20type%20unspecified.
969                if id == 0 {
970                    Ok(None)
971                } else {
972                    DataType::from_oid(id)
973                        .map(Some)
974                        .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
975                }
976            })
977            .try_collect()?;
978
979        let prepare_statement = session
980            .parse(stmt, param_types)
981            .await
982            .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))?;
983
984        if statement_name.is_empty() {
985            self.unnamed_prepare_statement.replace(prepare_statement);
986        } else {
987            self.prepare_statement_store
988                .insert(statement_name.clone(), prepare_statement);
989        }
990
991        self.statement_portal_dependency
992            .entry(statement_name)
993            .or_default()
994            .clear();
995
996        self.stream.write_no_flush(BeMessage::ParseComplete)?;
997        Ok(())
998    }
999
1000    fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
1001        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
1002        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1003        let session = self.session.clone().unwrap();
1004
1005        if self.portal_store.contains_key(&portal_name) {
1006            return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
1007        }
1008
1009        let prepare_statement = self.get_statement(&statement_name)?;
1010
1011        let result_formats = msg
1012            .result_format_codes
1013            .iter()
1014            .map(|&format_code| Format::from_i16(format_code))
1015            .try_collect()?;
1016        let param_formats = msg
1017            .param_format_codes
1018            .iter()
1019            .map(|&format_code| Format::from_i16(format_code))
1020            .try_collect()?;
1021
1022        let portal = session
1023            .bind(prepare_statement, msg.params, param_formats, result_formats)
1024            .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1025
1026        if portal_name.is_empty() {
1027            self.result_cache.remove(&portal_name);
1028            self.unnamed_portal.replace(portal);
1029        } else {
1030            assert!(
1031                !self.result_cache.contains_key(&portal_name),
1032                "Named portal never can be overridden."
1033            );
1034            self.portal_store.insert(portal_name.clone(), portal);
1035        }
1036
1037        self.statement_portal_dependency
1038            .get_mut(&statement_name)
1039            .unwrap()
1040            .push(portal_name);
1041
1042        self.stream.write_no_flush(BeMessage::BindComplete)?;
1043        Ok(())
1044    }
1045
1046    async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
1047        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1048        let row_max = msg.max_rows as usize;
1049        drop(msg);
1050        let session = self.session.clone().unwrap();
1051
1052        match self.result_cache.remove(&portal_name) {
1053            Some(mut result_cache) => {
1054                assert!(self.portal_store.contains_key(&portal_name));
1055
1056                let is_consume_completed =
1057                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
1058
1059                if !is_consume_completed {
1060                    self.result_cache.insert(portal_name, result_cache);
1061                }
1062            }
1063            _ => {
1064                let portal = self.get_portal(&portal_name)?;
1065                let sql = format!("{}", portal);
1066                let truncated_sql =
1067                    record_sql_in_current_span(&sql, self.redact_sql_option_keywords.clone());
1068                drop(sql);
1069
1070                session.check_idle_in_transaction_timeout()?;
1071                // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
1072                let _exec_context_guard = session.init_exec_context(truncated_sql.into());
1073                let result = session.clone().execute(portal).await;
1074
1075                let pg_response = result.map_err(|e| PsqlError::ExtendedExecuteError(e.into()))?;
1076                let mut result_cache = ResultCache::new(pg_response);
1077                let is_consume_completed =
1078                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
1079                if !is_consume_completed {
1080                    self.result_cache.insert(portal_name, result_cache);
1081                }
1082            }
1083        }
1084
1085        Ok(())
1086    }
1087
1088    fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1089        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1090        let session = self.session.clone().unwrap();
1091        //  b'S' => Statement
1092        //  b'P' => Portal
1093
1094        assert!(msg.kind == b'S' || msg.kind == b'P');
1095        if msg.kind == b'S' {
1096            let prepare_statement = self.get_statement(&name)?;
1097
1098            let (param_types, row_descriptions) = self
1099                .session
1100                .clone()
1101                .unwrap()
1102                .describe_statement(prepare_statement)
1103                .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1104            self.stream.write_no_flush(BeMessage::ParameterDescription(
1105                &param_types.iter().map(|t| t.to_oid()).collect_vec(),
1106            ))?;
1107
1108            if row_descriptions.is_empty() {
1109                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1110                // return NoData message if the statement is not a query.
1111                self.stream.write_no_flush(BeMessage::NoData)?;
1112            } else {
1113                self.stream
1114                    .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1115            }
1116        } else if msg.kind == b'P' {
1117            let portal = self.get_portal(&name)?;
1118
1119            let row_descriptions = session
1120                .describe_portal(portal)
1121                .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1122
1123            if row_descriptions.is_empty() {
1124                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1125                // return NoData message if the statement is not a query.
1126                self.stream.write_no_flush(BeMessage::NoData)?;
1127            } else {
1128                self.stream
1129                    .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1130            }
1131        }
1132        Ok(())
1133    }
1134
1135    fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1136        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1137        assert!(msg.kind == b'S' || msg.kind == b'P');
1138        if msg.kind == b'S' {
1139            if name.is_empty() {
1140                self.unnamed_prepare_statement = None;
1141            } else {
1142                self.prepare_statement_store.remove(&name);
1143            }
1144            for portal_name in self
1145                .statement_portal_dependency
1146                .remove(&name)
1147                .unwrap_or_default()
1148            {
1149                self.remove_portal(&portal_name);
1150            }
1151        } else if msg.kind == b'P' {
1152            self.remove_portal(&name);
1153        }
1154        self.stream.write_no_flush(BeMessage::CloseComplete)?;
1155        Ok(())
1156    }
1157
1158    fn remove_portal(&mut self, portal_name: &str) {
1159        if portal_name.is_empty() {
1160            self.unnamed_portal = None;
1161        } else {
1162            self.portal_store.remove(portal_name);
1163        }
1164        self.result_cache.remove(portal_name);
1165    }
1166
1167    fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1168        if portal_name.is_empty() {
1169            Ok(self
1170                .unnamed_portal
1171                .as_ref()
1172                .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1173                .clone())
1174        } else {
1175            Ok(self
1176                .portal_store
1177                .get(portal_name)
1178                .ok_or_else(|| {
1179                    PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1180                })?
1181                .clone())
1182        }
1183    }
1184
1185    fn get_statement(
1186        &self,
1187        statement_name: &str,
1188    ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1189        if statement_name.is_empty() {
1190            Ok(self
1191                .unnamed_prepare_statement
1192                .as_ref()
1193                .ok_or_else(|| {
1194                    PsqlError::Uncategorized("unnamed prepare statement not found".into())
1195                })?
1196                .clone())
1197        } else {
1198            Ok(self
1199                .prepare_statement_store
1200                .get(statement_name)
1201                .ok_or_else(|| {
1202                    PsqlError::Uncategorized(
1203                        format!("Prepare statement {} not found", statement_name).into(),
1204                    )
1205                })?
1206                .clone())
1207        }
1208    }
1209}
1210
/// The transport underneath a `PgStream`: either the plain byte stream or its TLS wrapper.
enum PgStreamInner<S> {
    /// Used for the intermediate state when converting from unencrypted to ssl stream.
    /// `upgrade_to_ssl` swaps this in while it owns the plain stream. It is not swapped back
    /// if the TLS handshake fails. Every read/write path treats it as `unreachable!()`.
    /// NOTE(review): callers presumably stop using the stream after a failed upgrade; confirm.
    Placeholder,
    /// An unencrypted stream.
    Unencrypted(S),
    /// An ssl stream.
    Ssl(SslStream<S>),
}
1219
1220/// Trait for a byte stream that can be used for pg protocol.
1221pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1222impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1223
/// Wraps a byte stream and reads/writes pg messages.
///
/// Cloning a `PgStream` will share the same stream but a fresh & independent write buffer,
/// so that it can be used to write messages concurrently without interference.
pub struct PgStream<S> {
    /// The underlying stream, shared by all clones behind an async mutex. Reads and flushes
    /// hold the lock for the duration of the operation.
    stream: Arc<Mutex<PgStreamInner<S>>>,
    /// Write into buffer before flush to stream.
    write_buf: BytesMut,
    /// Header stored by `read_header`, taken by the next `read_body` / `skip_body`.
    read_header: Option<FeMessageHeader>,
}
1235
1236impl<S> PgStream<S> {
1237    /// Create a new `PgStream` with the given stream and default write buffer capacity.
1238    pub fn new(stream: S) -> Self {
1239        const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1240
1241        Self {
1242            stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1243            write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1244            read_header: None,
1245        }
1246    }
1247
1248    /// Check if the current connection is using SSL
1249    async fn is_ssl_connection(&self) -> bool {
1250        let stream = self.stream.lock().await;
1251        matches!(*stream, PgStreamInner::Ssl(_))
1252    }
1253}
1254
1255impl<S> Clone for PgStream<S> {
1256    fn clone(&self) -> Self {
1257        Self {
1258            stream: Arc::clone(&self.stream),
1259            write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1260            read_header: self.read_header.clone(),
1261        }
1262    }
1263}
1264
/// At present there is a hard-wired set of parameters for which
/// ParameterStatus will be generated: they are:
///
///  * `server_version`
///  * `server_encoding`
///  * `client_encoding`
///  * `application_name`
///  * `is_superuser`
///  * `session_authorization`
///  * `DateStyle`
///  * `IntervalStyle`
///  * `TimeZone`
///  * `integer_datetimes`
///  * `standard_conforming_strings`
///
/// See: <https://www.postgresql.org/docs/9.2/static/protocol-flow.html#PROTOCOL-ASYNC>.
///
/// Of those parameters, this struct currently carries only `application_name`.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    /// `application_name` to report to the client. `None` means no such ParameterStatus
    /// message is written.
    pub application_name: Option<String>,
}
1285
1286impl<S> PgStream<S>
1287where
1288    S: PgByteStream,
1289{
1290    async fn read_startup(&mut self) -> io::Result<FeMessage> {
1291        let mut stream = self.stream.lock().await;
1292        match &mut *stream {
1293            PgStreamInner::Placeholder => unreachable!(),
1294            PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1295            PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1296        }
1297    }
1298
1299    async fn read_header(&mut self) -> io::Result<()> {
1300        let mut stream = self.stream.lock().await;
1301        match &mut *stream {
1302            PgStreamInner::Placeholder => unreachable!(),
1303            PgStreamInner::Unencrypted(stream) => {
1304                self.read_header = Some(FeMessage::read_header(stream).await?);
1305                Ok(())
1306            }
1307            PgStreamInner::Ssl(ssl_stream) => {
1308                self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1309                Ok(())
1310            }
1311        }
1312    }
1313
1314    async fn read_body(&mut self) -> io::Result<FeMessage> {
1315        let mut stream = self.stream.lock().await;
1316        let header = self
1317            .read_header
1318            .take()
1319            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1320        match &mut *stream {
1321            PgStreamInner::Placeholder => unreachable!(),
1322            PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1323            PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1324        }
1325    }
1326
1327    async fn skip_body(&mut self) -> io::Result<()> {
1328        let mut stream = self.stream.lock().await;
1329        let header = self
1330            .read_header
1331            .take()
1332            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1333        match &mut *stream {
1334            PgStreamInner::Placeholder => unreachable!(),
1335            PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1336            PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1337        }
1338    }
1339
1340    fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1341        self.write_no_flush(BeMessage::ParameterStatus(
1342            BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1343        ))?;
1344        self.write_no_flush(BeMessage::ParameterStatus(
1345            BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1346        ))?;
1347        self.write_no_flush(BeMessage::ParameterStatus(
1348            BeParameterStatusMessage::ServerVersion(PG_VERSION),
1349        ))?;
1350        if let Some(application_name) = &status.application_name {
1351            self.write_no_flush(BeMessage::ParameterStatus(
1352                BeParameterStatusMessage::ApplicationName(application_name),
1353            ))?;
1354        }
1355        Ok(())
1356    }
1357
1358    pub fn write_no_flush(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1359        BeMessage::write(&mut self.write_buf, message)
1360    }
1361
1362    async fn write(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1363        self.write_no_flush(message)?;
1364        self.flush().await?;
1365        Ok(())
1366    }
1367
1368    async fn flush(&mut self) -> io::Result<()> {
1369        let mut stream = self.stream.lock().await;
1370        match &mut *stream {
1371            PgStreamInner::Placeholder => unreachable!(),
1372            PgStreamInner::Unencrypted(stream) => {
1373                stream.write_all(&self.write_buf).await?;
1374                stream.flush().await?;
1375            }
1376            PgStreamInner::Ssl(ssl_stream) => {
1377                ssl_stream.write_all(&self.write_buf).await?;
1378                ssl_stream.flush().await?;
1379            }
1380        }
1381        self.write_buf.clear();
1382        Ok(())
1383    }
1384}
1385
1386impl<S> PgStream<S>
1387where
1388    S: PgByteStream,
1389{
1390    /// Convert the underlying stream to ssl stream based on the given context.
1391    async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
1392        let mut stream = self.stream.lock().await;
1393
1394        match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
1395            PgStreamInner::Unencrypted(unencrypted_stream) => {
1396                let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
1397                let mut ssl_stream =
1398                    tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();
1399
1400                if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1401                    tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
1402                    let _ = ssl_stream.shutdown().await;
1403                    return Err(e.into());
1404                }
1405
1406                *stream = PgStreamInner::Ssl(ssl_stream);
1407            }
1408            PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
1409            PgStreamInner::Placeholder => unreachable!(),
1410        }
1411
1412        Ok(())
1413    }
1414}
1415
1416fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1417    let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1418
1419    let key_path = &tls_config.key;
1420    let cert_path = &tls_config.cert;
1421
1422    // Build ssl acceptor according to the config.
1423    // Now we set every verify to true.
1424    acceptor
1425        .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1426        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1427    acceptor
1428        .set_ca_file(cert_path)
1429        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1430    acceptor
1431        .set_certificate_chain_file(cert_path)
1432        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1433    let acceptor = acceptor.build();
1434
1435    Ok(acceptor.into_context())
1436}
1437
1438pub mod truncated_fmt {
1439    use std::fmt::*;
1440
1441    struct TruncatedFormatter<'a, 'b> {
1442        remaining: usize,
1443        finished: bool,
1444        f: &'a mut Formatter<'b>,
1445    }
1446    impl Write for TruncatedFormatter<'_, '_> {
1447        fn write_str(&mut self, s: &str) -> Result {
1448            if self.finished {
1449                return Ok(());
1450            }
1451
1452            if self.remaining < s.len() {
1453                let actual = s.floor_char_boundary(self.remaining);
1454                self.f.write_str(&s[0..actual])?;
1455                self.remaining -= actual;
1456                self.f.write_str(&format!("...(truncated,{})", s.len()))?;
1457                self.finished = true; // so that ...(truncated) is printed exactly once
1458            } else {
1459                self.f.write_str(s)?;
1460                self.remaining -= s.len();
1461            }
1462            Ok(())
1463        }
1464    }
1465
1466    pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);
1467
1468    impl<T> Debug for TruncatedFmt<'_, T>
1469    where
1470        T: Debug,
1471    {
1472        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
1473            TruncatedFormatter {
1474                remaining: self.1,
1475                finished: false,
1476                f,
1477            }
1478            .write_fmt(format_args!("{:?}", self.0))
1479        }
1480    }
1481
1482    impl<T> Display for TruncatedFmt<'_, T>
1483    where
1484        T: Display,
1485    {
1486        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
1487            TruncatedFormatter {
1488                remaining: self.1,
1489                finished: false,
1490                f,
1491            }
1492            .write_fmt(format_args!("{}", self.0))
1493        }
1494    }
1495
1496    #[cfg(test)]
1497    mod tests {
1498        use super::*;
1499
1500        #[test]
1501        fn test_trunc_utf8() {
1502            assert_eq!(
1503                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
1504                "select '...(truncated,14)",
1505            );
1506        }
1507    }
1508}
1509
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Values of the option keys in the redaction set must be rendered as `[REDACTED]`.
    /// That covers `v2` and `v4` in `WITH (...)` and `b` in the `ENCODE` options. The rest
    /// of the statement is re-rendered in canonical form.
    #[test]
    fn test_redact_parsable_sql() {
        let sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        let keywords = Arc::new(HashSet::from(["v2", "v4", "b"].map(Into::into)));
        let expected = "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])";

        let redacted = redact_sql(sql, keywords);
        assert_eq!(redacted, expected);
    }
}