pgwire/
pg_protocol.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::collections::HashMap;
17use std::io::ErrorKind;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::str::Utf8Error;
21use std::sync::{Arc, LazyLock, Weak};
22use std::time::{Duration, Instant};
23use std::{io, str};
24
25use bytes::{Bytes, BytesMut};
26use futures::FutureExt;
27use futures::stream::StreamExt;
28use itertools::Itertools;
29use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod};
30use risingwave_common::types::DataType;
31use risingwave_common::util::deployment::Deployment;
32use risingwave_common::util::env_var::env_var_is_true;
33use risingwave_common::util::panic::FutureCatchUnwindExt;
34use risingwave_common::util::query_log::*;
35use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
36use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
37use risingwave_sqlparser::parser::Parser;
38use thiserror_ext::AsReport;
39use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
40use tokio::sync::Mutex;
41use tokio_openssl::SslStream;
42use tracing::Instrument;
43
44use crate::error::{PsqlError, PsqlResult};
45use crate::memory_manager::{MessageMemoryGuard, MessageMemoryManagerRef};
46use crate::net::AddressRef;
47use crate::pg_extended::ResultCache;
48use crate::pg_message::{
49    BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage,
50    FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeMessageHeader,
51    FeParseMessage, FePasswordMessage, FeStartupMessage, ServerThrottleReason, TransactionStatus,
52};
53use crate::pg_server::{Session, SessionManager, UserAuthenticator};
54use crate::types::Format;
55
/// Truncates query log if it's longer than `RW_QUERY_LOG_TRUNCATE_LEN`, to avoid log file being too
/// large.
///
/// The limit is read once from the `RW_QUERY_LOG_TRUNCATE_LEN` environment variable. If the
/// variable is unset or does not parse as a `usize`, it defaults to 65536 in builds with debug
/// assertions and 1024 otherwise.
static RW_QUERY_LOG_TRUNCATE_LEN: LazyLock<usize> = LazyLock::new(|| {
    // Parse exactly once; a missing or malformed value falls back to the build-dependent default.
    std::env::var("RW_QUERY_LOG_TRUNCATE_LEN")
        .ok()
        .and_then(|len| len.parse::<usize>().ok())
        .unwrap_or(if cfg!(debug_assertions) { 65536 } else { 1024 })
});
69
tokio::task_local! {
    /// The current session. Concrete type is erased for different session implementations.
    ///
    /// It is scoped around the processing of each message in `PgProtocol::do_process`, and only
    /// when the connection already has a session. A `Weak` reference is stored, so this
    /// task-local never keeps the session alive by itself.
    pub static CURRENT_SESSION: Weak<dyn Any + Send + Sync>
}
74
/// The state machine for each pgwire (psql) connection.
///
/// Reads frontend messages from the underlying byte stream and writes backend results back.
pub struct PgProtocol<S, SM>
where
    SM: SessionManager,
{
    /// Used for write/read pg messages.
    stream: PgStream<S>,
    /// Current states of pg connection.
    state: PgProtocolState,
    /// Whether the connection is terminated. Handlers set this (e.g. after a cancel request) to
    /// close the connection once the current message has been processed.
    is_terminate: bool,

    /// Creates the session at startup and is told to end it when this protocol is dropped.
    session_mgr: Arc<SM>,
    /// The session bound to this connection; `None` until the startup message is accepted.
    session: Option<Arc<SM::Session>>,

    // Extended-query protocol state: cached results, prepared statements and portals.
    // The unnamed statement/portal are kept apart from the named stores.
    result_cache: HashMap<String, ResultCache<<SM::Session as Session>::ValuesStream>>,
    unnamed_prepare_statement: Option<<SM::Session as Session>::PreparedStatement>,
    prepare_statement_store: HashMap<String, <SM::Session as Session>::PreparedStatement>,
    unnamed_portal: Option<<SM::Session as Session>::Portal>,
    portal_store: HashMap<String, <SM::Session as Session>::Portal>,
    // Used to store the dependency of portal and prepare statement.
    // When we close a prepare statement, we need to close all the portals that depend on it.
    statement_portal_dependency: HashMap<String, Vec<String>>,

    // Used for ssl connection.
    // If None, SSL requests from the client are declined and the connection stays plaintext.
    tls_context: Option<SslContext>,

    // TLS configuration including SSL enforcement setting
    tls_config: Option<TlsConfig>,

    // Used in extended query protocol. When we encounter an error in an extended query, we need
    // to ignore the following messages until the Sync message.
    ignore_util_sync: bool,

    // Client Address
    peer_addr: AddressRef,

    // Keywords whose SQL option values are redacted before SQL is recorded in logs/spans.
    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    // Accounts for the memory of incoming messages; may throttle oversized or excessive input.
    message_memory_manager: MessageMemoryManagerRef,
}
117
/// Configures TLS encryption for connections.
///
/// Can be built from the environment with [`TlsConfig::new_default`].
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// The path to the TLS certificate.
    pub cert: String,
    /// The path to the TLS key.
    pub key: String,
    /// Whether to enforce SSL connections (reject non-SSL clients).
    pub enforce_ssl: bool,
}
128
129impl TlsConfig {
130    pub fn new_default() -> Option<Self> {
131        let cert = std::env::var("RW_SSL_CERT").ok()?;
132        let key = std::env::var("RW_SSL_KEY").ok()?;
133        let enforce_ssl = env_var_is_true("RW_SSL_ENFORCE");
134        tracing::info!(
135            "RW_SSL_CERT={}, RW_SSL_KEY={}, RW_SSL_ENFORCE={}",
136            cert,
137            key,
138            enforce_ssl
139        );
140        Some(Self {
141            cert,
142            key,
143            enforce_ssl,
144        })
145    }
146}
147
148impl<S, SM> Drop for PgProtocol<S, SM>
149where
150    SM: SessionManager,
151{
152    fn drop(&mut self) {
153        if let Some(session) = &self.session {
154            // Clear the session in session manager.
155            self.session_mgr.end_session(session);
156        }
157    }
158}
159
/// Protocol phases of a connection. A connection only moves forward, from `Startup` to
/// `Regular`.
enum PgProtocolState {
    /// Messages are read in the untagged startup format (`read_startup`), e.g. SSL/GSS
    /// requests and the startup packet.
    Startup,
    /// Messages are read as tagged frames (`read_header` + `read_body`) and are subject to
    /// message memory accounting.
    Regular,
}
165
166/// Truncate 0 from C string in Bytes and stringify it (returns slice, no allocations).
167///
168/// PG protocol strings are always C strings.
169pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
170    let without_null = if b.last() == Some(&0) {
171        &b[..b.len() - 1]
172    } else {
173        &b[..]
174    };
175    std::str::from_utf8(without_null)
176}
177
178fn record_sql_in_current_span(
179    sql: &str,
180    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
181) -> String {
182    let mut span = tracing::Span::current();
183    record_sql_in_span(sql, redact_sql_option_keywords, &mut span)
184}
185
186/// Record `sql` in the current tracing span.
187fn record_sql_in_span(
188    sql: &str,
189    redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
190    span: &mut tracing::Span,
191) -> String {
192    let redacted_sql = if let Some(keywords) = redact_sql_option_keywords
193        && !keywords.is_empty()
194    {
195        redact_sql(sql, keywords)
196    } else {
197        sql.to_owned()
198    };
199    let truncated = truncated_fmt::TruncatedFmt(&redacted_sql, *RW_QUERY_LOG_TRUNCATE_LEN);
200    span.record("sql", tracing::field::display(&truncated));
201    truncated.to_string()
202}
203
204/// Redacts SQL options. Data in DML is not redacted.
205fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
206    match Parser::parse_sql(sql) {
207        Ok(sqls) => sqls
208            .into_iter()
209            .map(|sql| sql.to_redacted_string(keywords.clone()))
210            .join(";"),
211        Err(_) => sql.to_owned(),
212    }
213}
214
/// Connection-level settings consumed by [`PgProtocol::new`].
#[derive(Clone)]
pub struct ConnectionContext {
    /// TLS settings; when `None`, SSL requests from clients are declined.
    pub tls_config: Option<TlsConfig>,
    /// Keywords whose SQL option values are redacted before SQL is recorded in logs/spans.
    pub redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
    /// Accounts for (and may throttle) the memory used by incoming messages.
    pub message_memory_manager: MessageMemoryManagerRef,
}
221
222impl<S, SM> PgProtocol<S, SM>
223where
224    S: PgByteStream,
225    SM: SessionManager,
226{
227    pub fn new(
228        stream: S,
229        session_mgr: Arc<SM>,
230        peer_addr: AddressRef,
231        context: ConnectionContext,
232    ) -> Self {
233        let ConnectionContext {
234            tls_config,
235            redact_sql_option_keywords,
236            message_memory_manager,
237        } = context;
238        Self {
239            stream: PgStream::new(stream),
240            is_terminate: false,
241            state: PgProtocolState::Startup,
242            session_mgr,
243            session: None,
244            tls_context: tls_config
245                .as_ref()
246                .and_then(|e| build_ssl_ctx_from_config(e).ok()),
247            tls_config,
248            result_cache: Default::default(),
249            unnamed_prepare_statement: Default::default(),
250            prepare_statement_store: Default::default(),
251            unnamed_portal: Default::default(),
252            portal_store: Default::default(),
253            statement_portal_dependency: Default::default(),
254            ignore_util_sync: false,
255            peer_addr,
256            redact_sql_option_keywords,
257            message_memory_manager,
258        }
259    }
260
    /// Run the protocol to serve the connection.
    ///
    /// Loops reading and processing one message at a time until processing reports that the
    /// connection should terminate, or reading a message fails. Once a session exists, a
    /// long-lived notice future runs concurrently with message processing. It forwards session
    /// notices to the client.
    pub async fn run(&mut self) {
        let mut notice_fut = None;

        loop {
            // Once a session is present, create a future to subscribe and send notices asynchronously.
            if notice_fut.is_none()
                && let Some(session) = self.session.clone()
            {
                let mut stream = self.stream.clone();
                notice_fut = Some(Box::pin(async move {
                    // Never returns: failures to send a notice are only logged.
                    loop {
                        let notice = session.next_notice().await;
                        if let Err(e) = stream.write(BeMessage::NoticeResponse(&notice)).await {
                            tracing::error!(error = %e.as_report(), notice, "failed to send notice");
                        }
                    }
                }));
            }

            // Read and process messages.
            let process = std::pin::pin!(async {
                // The memory guard is held until this message has been processed.
                let (msg, _memory_guard) = match self.read_message().await {
                    Ok(msg) => msg,
                    Err(e) => {
                        tracing::error!(error = %e.as_report(), "error when reading message");
                        return true; // terminate the connection
                    }
                };
                tracing::trace!(?msg, "received message");
                self.process(msg).await
            });

            // The notice future is polled by mutable reference, so it survives across
            // iterations. Since it loops forever, only `process` can complete the select.
            let terminated = if let Some(notice_fut) = notice_fut.as_mut() {
                tokio::select! {
                    _ = notice_fut => unreachable!(),
                    terminated = process => terminated,
                }
            } else {
                process.await
            };

            if terminated {
                break;
            }
        }
    }
308
309    /// Processes one message. Returns true if the connection is terminated.
310    pub async fn process(&mut self, msg: FeMessage) -> bool {
311        self.do_process(msg).await.is_none() || self.is_terminate
312    }
313
314    /// The root tracing span for processing a message. The target of the span is
315    /// [`PGWIRE_ROOT_SPAN_TARGET`].
316    ///
317    /// This is used to provide context for the (slow) query logs and traces.
318    ///
319    /// The span is only effective if there's a current session and the message is
320    /// query-related. Otherwise, `Span::none()` is returned.
321    fn root_span_for_msg(&self, msg: &FeMessage) -> tracing::Span {
322        let Some(session_id) = self.session.as_ref().map(|s| s.id().0) else {
323            return tracing::Span::none();
324        };
325
326        let mode = match msg {
327            FeMessage::Query(_) => "simple query",
328            FeMessage::Parse(_) => "extended query parse",
329            FeMessage::Execute(_) => "extended query execute",
330            _ => return tracing::Span::none(),
331        };
332
333        let mut span = tracing::info_span!(
334            target: PGWIRE_ROOT_SPAN_TARGET,
335            "handle_query",
336            mode,
337            session_id,
338            sql = tracing::field::Empty,
339        );
340        if let Ok(sql) = msg.get_sql()
341            && let Some(sql) = sql
342        {
343            record_sql_in_span(sql, self.redact_sql_option_keywords.clone(), &mut span);
344        }
345        span
346    }
347
    /// Processes one message inside the per-message instrumentation stack, then turns any
    /// error into the matching client response.
    ///
    /// From the inside out, `do_process_inner` is wrapped with these layers:
    /// - the `CURRENT_SESSION` task-local scope, when a session exists;
    /// - panic catching, where a panic becomes `PsqlError::Panic`;
    /// - periodic slow-query logging;
    /// - the query/error log;
    /// - the span from `root_span_for_msg`.
    ///
    /// Return type `Option<()>` is essentially a bool, but allows `?` for early return.
    /// - `None` means to terminate the current connection
    /// - `Some(())` means to continue processing the next message
    async fn do_process(&mut self, msg: FeMessage) -> Option<()> {
        let span = self.root_span_for_msg(&msg);
        let weak_session = self
            .session
            .as_ref()
            .map(|s| Arc::downgrade(s) as Weak<dyn Any + Send + Sync>);

        // Processing the message itself.
        //
        // Note: pin the future to avoid stack overflow as we'll wrap it multiple times
        // in the following code.
        let fut = Box::pin(self.do_process_inner(msg));

        // Set the current session as the context when processing the message, if exists.
        let fut = async move {
            if let Some(session) = weak_session {
                CURRENT_SESSION.scope(session, fut).await
            } else {
                fut.await
            }
        };

        // Catch unwind.
        let fut = async move {
            AssertUnwindSafe(fut)
                .rw_catch_unwind()
                .await
                .unwrap_or_else(|payload| {
                    Err(PsqlError::Panic(
                        panic_message::panic_message(&payload).to_owned(),
                    ))
                })
        };

        // Slow query log.
        let fut = async move {
            let period = *SLOW_QUERY_LOG_PERIOD;
            let mut fut = std::pin::pin!(fut);
            let mut elapsed = Duration::ZERO;

            // Report the SQL in the log periodically if the query is slow.
            // The timeout only stops waiting; the inner future is polled again on the next round.
            loop {
                match tokio::time::timeout(period, &mut fut).await {
                    Ok(result) => break result,
                    Err(_) => {
                        elapsed += period;
                        tracing::info!(
                            target: PGWIRE_SLOW_QUERY_LOG,
                            elapsed = %format_args!("{}ms", elapsed.as_millis()),
                            "slow query"
                        );
                    }
                }
            }
        };

        // Query log.
        let fut = async move {
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = "started",
                );
            }

            let start = Instant::now();
            let result = fut.await;
            let elapsed = start.elapsed();

            // Always log if an error occurs.
            // Note: all messages will be processed through this code path, making it the
            //       only necessary place to log errors.
            if let Err(error) = &result {
                if cfg!(debug_assertions) && !Deployment::current().is_ci() {
                    // For local debugging, we print the error with backtrace.
                    // It's useful only when:
                    // - no additional context is added to the error
                    // - backtrace is captured in the error
                    // - backtrace is not printed in the middle
                    tracing::error!(error = ?error.as_report(), "error when process message");
                } else {
                    tracing::error!(error = %error.as_report(), "error when process message");
                }
            }

            // Log to optionally-enabled target `PGWIRE_QUERY_LOG`.
            // Only log if we're currently in a tracing span set in `root_span_for_msg`.
            if !tracing::Span::current().is_none() {
                tracing::info!(
                    target: PGWIRE_QUERY_LOG,
                    status = if result.is_ok() { "ok" } else { "err" },
                    time = %format_args!("{}ms", elapsed.as_millis()),
                );
            }

            result
        };

        // Tracing span.
        let fut = fut.instrument(span);

        // Execute the future and handle the error.
        // In all arms below, a failed write of the error response (`.ok()?`) terminates the
        // connection.
        match fut.await {
            Ok(()) => Some(()),
            Err(e) => {
                match e {
                    PsqlError::IoError(io_err) => {
                        if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                            return None;
                        }
                        // Other I/O errors fall through: nothing is written to the client and
                        // the connection keeps going.
                    }

                    PsqlError::SslError(_) => {
                        // For ssl error, because the stream has already been consumed, so there is
                        // no way to write more message.
                        return None;
                    }

                    PsqlError::StartupError(_) | PsqlError::PasswordError => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                // At this time we're not in a session, use compact error message for
                                // better alignment with Postgres' UI.
                                pretty: false,
                            })
                            .ok()?;
                        let _ = self.stream.flush().await;
                        return None;
                    }

                    PsqlError::SimpleQueryError(_) | PsqlError::ServerThrottle(_) => {
                        // The error ends this query cycle, so ReadyForQuery follows immediately.
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                            })
                            .ok()?;
                        self.ready_for_query().ok()?;
                    }

                    PsqlError::IdleInTxnTimeout | PsqlError::Panic(_) => {
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                            })
                            .ok()?;
                        let _ = self.stream.flush().await;

                        // 1. Catching the panic during message processing may leave the session in an
                        // inconsistent state. We forcefully close the connection (then end the
                        // session) here for safety.
                        // 2. Idle in transaction timeout should also close the connection.
                        return None;
                    }

                    PsqlError::Uncategorized(_)
                    | PsqlError::ExtendedPrepareError(_)
                    | PsqlError::ExtendedExecuteError(_) => {
                        // No ReadyForQuery here. For extended-query errors it is sent when the
                        // client's Sync arrives (see `ignore_util_sync`).
                        self.stream
                            .write_no_flush(BeMessage::ErrorResponse {
                                error: &e,
                                pretty: true,
                            })
                            .ok()?;
                    }
                }
                let _ = self.stream.flush().await;
                Some(())
            }
        }
    }
524
525    async fn do_process_inner(&mut self, msg: FeMessage) -> PsqlResult<()> {
526        // Ignore util sync message.
527        if self.ignore_util_sync {
528            if let FeMessage::Sync = msg {
529            } else {
530                tracing::trace!("ignore message {:?} until sync.", msg);
531                return Ok(());
532            }
533        }
534
535        match msg {
536            FeMessage::Gss => self.process_gss_msg().await?,
537            FeMessage::Ssl => self.process_ssl_msg().await?,
538            FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
539            FeMessage::Password(msg) => self.process_password_msg(msg).await?,
540            FeMessage::Query(query_msg) => {
541                let sql = Arc::from(query_msg.get_sql()?);
542                // The process_query_msg can be slow. Release potential large FeQueryMessage early.
543                drop(query_msg);
544                self.process_query_msg(sql).await?
545            }
546            FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
547            FeMessage::Terminate => self.process_terminate(),
548            FeMessage::Parse(m) => {
549                if let Err(err) = self.process_parse_msg(m).await {
550                    self.ignore_util_sync = true;
551                    return Err(err);
552                }
553            }
554            FeMessage::Bind(m) => {
555                if let Err(err) = self.process_bind_msg(m) {
556                    self.ignore_util_sync = true;
557                    return Err(err);
558                }
559            }
560            FeMessage::Execute(m) => {
561                if let Err(err) = self.process_execute_msg(m).await {
562                    self.ignore_util_sync = true;
563                    return Err(err);
564                }
565            }
566            FeMessage::Describe(m) => {
567                if let Err(err) = self.process_describe_msg(m) {
568                    self.ignore_util_sync = true;
569                    return Err(err);
570                }
571            }
572            FeMessage::Sync => {
573                self.ignore_util_sync = false;
574                self.ready_for_query()?
575            }
576            FeMessage::Close(m) => {
577                if let Err(err) = self.process_close_msg(m) {
578                    self.ignore_util_sync = true;
579                    return Err(err);
580                }
581            }
582            FeMessage::Flush => {
583                if let Err(err) = self.stream.flush().await {
584                    self.ignore_util_sync = true;
585                    return Err(err.into());
586                }
587            }
588            FeMessage::HealthCheck => self.process_health_check(),
589            FeMessage::ServerThrottle(reason) => match reason {
590                ServerThrottleReason::TooLargeMessage => {
591                    return Err(PsqlError::ServerThrottle(format!(
592                        "max_single_query_size_bytes {} has been exceeded, please either reduce the query size or increase the limit",
593                        self.message_memory_manager.max_filter_bytes
594                    )));
595                }
596                ServerThrottleReason::TooManyMemoryUsage => {
597                    return Err(PsqlError::ServerThrottle(format!(
598                        "max_total_query_size_bytes {} has been exceeded, please either retry or increase the limit",
599                        self.message_memory_manager.max_running_bytes
600                    )));
601                }
602            },
603        }
604        self.stream.flush().await?;
605        Ok(())
606    }
607
608    pub async fn read_message(&mut self) -> io::Result<(FeMessage, Option<MessageMemoryGuard>)> {
609        match self.state {
610            PgProtocolState::Startup => self
611                .stream
612                .read_startup()
613                .await
614                .map(|message: FeMessage| (message, None)),
615            PgProtocolState::Regular => {
616                self.stream.read_header().await?;
617                let guard = if let Some(ref header) = self.stream.read_header {
618                    let payload_len = std::cmp::max(header.payload_len, 0) as u64;
619                    let (reason, guard) = self.message_memory_manager.add(payload_len);
620                    if let Some(reason) = reason {
621                        // Release the memory ASAP.
622                        drop(guard);
623                        self.stream.skip_body().await?;
624                        return Ok((FeMessage::ServerThrottle(reason), None));
625                    }
626                    guard
627                } else {
628                    None
629                };
630                let message = self.stream.read_body().await?;
631                Ok((message, guard))
632            }
633        }
634    }
635
636    /// Writes a `ReadyForQuery` message to the client without flushing.
637    fn ready_for_query(&mut self) -> io::Result<()> {
638        self.stream.write_no_flush(BeMessage::ReadyForQuery(
639            self.session
640                .as_ref()
641                .map(|s| s.transaction_status())
642                .unwrap_or(TransactionStatus::Idle),
643        ))
644    }
645
646    async fn process_gss_msg(&mut self) -> PsqlResult<()> {
647        // We don't support GSSAPI, so we just say no gracefully.
648        self.stream.write(BeMessage::EncryptionResponseNo).await?;
649        Ok(())
650    }
651
652    async fn process_ssl_msg(&mut self) -> PsqlResult<()> {
653        if let Some(context) = self.tls_context.as_ref() {
654            // If got and ssl context, say yes for ssl connection.
655            // Construct ssl stream and replace with current one.
656            self.stream.write(BeMessage::EncryptionResponseSsl).await?;
657            self.stream.upgrade_to_ssl(context).await?;
658        } else {
659            // If no, say no for encryption.
660            self.stream.write(BeMessage::EncryptionResponseNo).await?;
661        }
662
663        Ok(())
664    }
665
666    async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
667        // Check SSL enforcement: if SSL is enforced but connection is not using SSL, reject
668        if let Some(ref tls_config) = self.tls_config
669            && tls_config.enforce_ssl
670            && !self.stream.is_ssl_connection().await
671        {
672            return Err(PsqlError::StartupError(
673                "SSL connection is required but not established".into(),
674            ));
675        }
676
677        let db_name = msg
678            .config
679            .get("database")
680            .cloned()
681            .unwrap_or_else(|| "dev".to_owned());
682        let user_name = msg
683            .config
684            .get("user")
685            .cloned()
686            .unwrap_or_else(|| "root".to_owned());
687
688        let session = self
689            .session_mgr
690            .connect(&db_name, &user_name, self.peer_addr.clone())
691            .map_err(|e| PsqlError::StartupError(e.into()))?;
692
693        let application_name = msg.config.get("application_name");
694        if let Some(application_name) = application_name {
695            session
696                .set_config("application_name", application_name.clone())
697                .map_err(|e| PsqlError::StartupError(e.into()))?;
698        }
699
700        match session.user_authenticator() {
701            UserAuthenticator::None => {
702                self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
703
704                // Cancel request need this for identify and verification. According to postgres
705                // doc, it should be written to buffer after receive AuthenticationOk.
706                self.stream
707                    .write_no_flush(BeMessage::BackendKeyData(session.id()))?;
708
709                self.stream.write_no_flush(BeMessage::ParameterStatus(
710                    BeParameterStatusMessage::TimeZone(
711                        &session
712                            .get_config("timezone")
713                            .map_err(|e| PsqlError::StartupError(e.into()))?,
714                    ),
715                ))?;
716                self.stream
717                    .write_parameter_status_msg_no_flush(&ParameterStatus {
718                        application_name: application_name.cloned(),
719                    })?;
720                self.ready_for_query()?;
721            }
722            UserAuthenticator::ClearText(_)
723            | UserAuthenticator::OAuth { .. }
724            | UserAuthenticator::Ldap(..) => {
725                self.stream
726                    .write_no_flush(BeMessage::AuthenticationCleartextPassword)?;
727            }
728            UserAuthenticator::Md5WithSalt { salt, .. } => {
729                self.stream
730                    .write_no_flush(BeMessage::AuthenticationMd5Password(salt))?;
731            }
732        }
733
734        self.session = Some(session);
735        self.state = PgProtocolState::Regular;
736        Ok(())
737    }
738
739    async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
740        let session = self.session.as_ref().unwrap();
741        let authenticator = session.user_authenticator();
742        authenticator.authenticate(&msg.password).await?;
743        self.stream.write_no_flush(BeMessage::AuthenticationOk)?;
744        let timezone = session
745            .get_config("timezone")
746            .map_err(|e| PsqlError::StartupError(e.into()))?;
747        self.stream.write_no_flush(BeMessage::ParameterStatus(
748            BeParameterStatusMessage::TimeZone(&timezone),
749        ))?;
750        self.stream
751            .write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
752        self.ready_for_query()?;
753        self.state = PgProtocolState::Regular;
754        Ok(())
755    }
756
757    fn process_cancel_msg(&mut self, m: FeCancelMessage) -> PsqlResult<()> {
758        let session_id = (m.target_process_id, m.target_secret_key);
759        tracing::trace!("cancel query in session: {:?}", session_id);
760        self.session_mgr.cancel_queries_in_session(session_id);
761        self.session_mgr.cancel_creating_jobs_in_session(session_id);
762        self.is_terminate = true;
763        Ok(())
764    }
765
766    async fn process_query_msg(&mut self, sql: Arc<str>) -> PsqlResult<()> {
767        let truncated_sql =
768            record_sql_in_current_span(&sql, self.redact_sql_option_keywords.clone());
769        let session = self.session.clone().unwrap();
770
771        session.check_idle_in_transaction_timeout()?;
772        // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
773        let _exec_context_guard = session.init_exec_context(truncated_sql.into());
774        self.inner_process_query_msg(sql, session.clone()).await
775    }
776
777    async fn inner_process_query_msg(
778        &mut self,
779        sql: Arc<str>,
780        session: Arc<SM::Session>,
781    ) -> PsqlResult<()> {
782        // Parse sql.
783        let stmts =
784            Parser::parse_sql(&sql).map_err(|err| PsqlError::SimpleQueryError(err.into()))?;
785        // The following inner_process_query_msg_one_stmt can be slow. Release potential large String early.
786        drop(sql);
787        if stmts.is_empty() {
788            self.stream.write_no_flush(BeMessage::EmptyQueryResponse)?;
789        }
790
791        // Execute multiple statements in simple query. KISS later.
792        for stmt in stmts {
793            self.inner_process_query_msg_one_stmt(stmt, session.clone())
794                .await?;
795        }
796        // Put this line inside the for loop above will lead to unfinished/stuck regress test...Not
797        // sure the reason.
798        self.ready_for_query()?;
799        Ok(())
800    }
801
802    async fn inner_process_query_msg_one_stmt(
803        &mut self,
804        stmt: Statement,
805        session: Arc<SM::Session>,
806    ) -> PsqlResult<()> {
807        let session = session.clone();
808
809        // execute query
810        let res = session.clone().run_one_query(stmt, Format::Text).await;
811
812        // Take all remaining notices (if any) and send them before `CommandComplete`.
813        while let Some(notice) = session.next_notice().now_or_never() {
814            self.stream
815                .write_no_flush(BeMessage::NoticeResponse(&notice))?;
816        }
817
818        let mut res = res.map_err(|e| PsqlError::SimpleQueryError(e.into()))?;
819
820        for notice in res.notices() {
821            self.stream
822                .write_no_flush(BeMessage::NoticeResponse(notice))?;
823        }
824
825        let status = res.status();
826        if let Some(ref application_name) = status.application_name {
827            self.stream.write_no_flush(BeMessage::ParameterStatus(
828                BeParameterStatusMessage::ApplicationName(application_name),
829            ))?;
830        }
831
832        if res.is_copy_query_to_stdout() {
833            self.stream
834                .write_no_flush(BeMessage::CopyOutResponse(res.row_desc().len()))?;
835            let mut count = 0;
836            while let Some(row_set) = res.values_stream().next().await {
837                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
838                for row in row_set {
839                    self.stream.write_no_flush(BeMessage::CopyData(&row))?;
840                    count += 1;
841                }
842            }
843
844            self.stream.write_no_flush(BeMessage::CopyDone)?;
845
846            // Run the callback before sending the `CommandComplete` message.
847            res.run_callback().await?;
848
849            self.stream
850                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
851                    stmt_type: res.stmt_type(),
852                    rows_cnt: count,
853                }))?;
854        } else if res.is_query() {
855            self.stream
856                .write_no_flush(BeMessage::RowDescription(res.row_desc()))?;
857
858            let mut rows_cnt = 0;
859
860            while let Some(row_set) = res.values_stream().next().await {
861                let row_set = row_set.map_err(PsqlError::SimpleQueryError)?;
862                for row in row_set {
863                    self.stream.write_no_flush(BeMessage::DataRow(&row))?;
864                    rows_cnt += 1;
865                }
866            }
867
868            // Run the callback before sending the `CommandComplete` message.
869            res.run_callback().await?;
870
871            self.stream
872                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
873                    stmt_type: res.stmt_type(),
874                    rows_cnt,
875                }))?;
876        } else if res.stmt_type().is_dml() && !res.stmt_type().is_returning() {
877            let first_row_set = res.values_stream().next().await;
878            let first_row_set = match first_row_set {
879                None => {
880                    return Err(PsqlError::Uncategorized(
881                        anyhow::anyhow!("no affected rows in output").into(),
882                    ));
883                }
884                Some(row) => row.map_err(PsqlError::SimpleQueryError)?,
885            };
886            let affected_rows_str = first_row_set[0].values()[0]
887                .as_ref()
888                .expect("compute node should return affected rows in output");
889
890            assert!(matches!(res.row_cnt_format(), Some(Format::Text)));
891            let affected_rows_cnt = String::from_utf8(affected_rows_str.to_vec())
892                .unwrap()
893                .parse()
894                .unwrap_or_default();
895
896            // Run the callback before sending the `CommandComplete` message.
897            res.run_callback().await?;
898
899            self.stream
900                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
901                    stmt_type: res.stmt_type(),
902                    rows_cnt: affected_rows_cnt,
903                }))?;
904        } else {
905            // Run the callback before sending the `CommandComplete` message.
906            res.run_callback().await?;
907
908            self.stream
909                .write_no_flush(BeMessage::CommandComplete(BeCommandCompleteMessage {
910                    stmt_type: res.stmt_type(),
911                    rows_cnt: 0,
912                }))?;
913        }
914
915        Ok(())
916    }
917
918    fn process_terminate(&mut self) {
919        self.is_terminate = true;
920    }
921
922    fn process_health_check(&mut self) {
923        tracing::debug!("health check");
924        self.is_terminate = true;
925    }
926
927    async fn process_parse_msg(&mut self, mut msg: FeParseMessage) -> PsqlResult<()> {
928        let sql = Arc::from(cstr_to_str(&msg.sql_bytes).unwrap());
929        record_sql_in_current_span(&sql, self.redact_sql_option_keywords.clone());
930        let session = self.session.clone().unwrap();
931        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
932        let type_ids = std::mem::take(&mut msg.type_ids);
933        // The inner_process_parse_msg can be slow. Release potential large FeParseMessage early.
934        drop(msg);
935        self.inner_process_parse_msg(session, sql, statement_name, type_ids)
936            .await?;
937        Ok(())
938    }
939
940    async fn inner_process_parse_msg(
941        &mut self,
942        session: Arc<SM::Session>,
943        sql: Arc<str>,
944        statement_name: String,
945        type_ids: Vec<i32>,
946    ) -> PsqlResult<()> {
947        if statement_name.is_empty() {
948            // Remove the unnamed prepare statement first, in case the unsupported sql binds a wrong
949            // prepare statement.
950            self.unnamed_prepare_statement.take();
951        } else if self.prepare_statement_store.contains_key(&statement_name) {
952            return Err(PsqlError::ExtendedPrepareError(
953                "Duplicated statement name".into(),
954            ));
955        }
956
957        let stmt = {
958            let stmts = Parser::parse_sql(&sql)
959                .map_err(|err| PsqlError::ExtendedPrepareError(err.into()))?;
960            drop(sql);
961            if stmts.len() > 1 {
962                return Err(PsqlError::ExtendedPrepareError(
963                    "Only one statement is allowed in extended query mode".into(),
964                ));
965            }
966
967            stmts.into_iter().next()
968        };
969
970        let param_types: Vec<Option<DataType>> = type_ids
971            .iter()
972            .map(|&id| {
973                // 0 means unspecified type
974                // ref: https://www.postgresql.org/docs/15/protocol-message-formats.html#:~:text=Placing%20a%20zero%20here%20is%20equivalent%20to%20leaving%20the%20type%20unspecified.
975                if id == 0 {
976                    Ok(None)
977                } else {
978                    DataType::from_oid(id)
979                        .map(Some)
980                        .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))
981                }
982            })
983            .try_collect()?;
984
985        let prepare_statement = session
986            .parse(stmt, param_types)
987            .await
988            .map_err(|e| PsqlError::ExtendedPrepareError(e.into()))?;
989
990        if statement_name.is_empty() {
991            self.unnamed_prepare_statement.replace(prepare_statement);
992        } else {
993            self.prepare_statement_store
994                .insert(statement_name.clone(), prepare_statement);
995        }
996
997        self.statement_portal_dependency
998            .entry(statement_name)
999            .or_default()
1000            .clear();
1001
1002        self.stream.write_no_flush(BeMessage::ParseComplete)?;
1003        Ok(())
1004    }
1005
1006    fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> {
1007        let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_owned();
1008        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1009        let session = self.session.clone().unwrap();
1010
1011        if self.portal_store.contains_key(&portal_name) {
1012            return Err(PsqlError::Uncategorized("Duplicated portal name".into()));
1013        }
1014
1015        let prepare_statement = self.get_statement(&statement_name)?;
1016
1017        let result_formats = msg
1018            .result_format_codes
1019            .iter()
1020            .map(|&format_code| Format::from_i16(format_code))
1021            .try_collect()?;
1022        let param_formats = msg
1023            .param_format_codes
1024            .iter()
1025            .map(|&format_code| Format::from_i16(format_code))
1026            .try_collect()?;
1027
1028        let portal = session
1029            .bind(prepare_statement, msg.params, param_formats, result_formats)
1030            .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1031
1032        if portal_name.is_empty() {
1033            self.result_cache.remove(&portal_name);
1034            self.unnamed_portal.replace(portal);
1035        } else {
1036            assert!(
1037                !self.result_cache.contains_key(&portal_name),
1038                "Named portal never can be overridden."
1039            );
1040            self.portal_store.insert(portal_name.clone(), portal);
1041        }
1042
1043        self.statement_portal_dependency
1044            .get_mut(&statement_name)
1045            .unwrap()
1046            .push(portal_name);
1047
1048        self.stream.write_no_flush(BeMessage::BindComplete)?;
1049        Ok(())
1050    }
1051
1052    async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> {
1053        let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_owned();
1054        let row_max = msg.max_rows as usize;
1055        drop(msg);
1056        let session = self.session.clone().unwrap();
1057
1058        match self.result_cache.remove(&portal_name) {
1059            Some(mut result_cache) => {
1060                assert!(self.portal_store.contains_key(&portal_name));
1061
1062                let is_consume_completed =
1063                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
1064
1065                if !is_consume_completed {
1066                    self.result_cache.insert(portal_name, result_cache);
1067                }
1068            }
1069            _ => {
1070                let portal = self.get_portal(&portal_name)?;
1071                let sql = format!("{}", portal);
1072                let truncated_sql =
1073                    record_sql_in_current_span(&sql, self.redact_sql_option_keywords.clone());
1074                drop(sql);
1075
1076                session.check_idle_in_transaction_timeout()?;
1077                // Store only truncated SQL in context to prevent excessive memory usage from large SQL.
1078                let _exec_context_guard = session.init_exec_context(truncated_sql.into());
1079                let result = session.clone().execute(portal).await;
1080
1081                let pg_response = result.map_err(|e| PsqlError::ExtendedExecuteError(e.into()))?;
1082                let mut result_cache = ResultCache::new(pg_response);
1083                let is_consume_completed =
1084                    result_cache.consume::<S>(row_max, &mut self.stream).await?;
1085                if !is_consume_completed {
1086                    self.result_cache.insert(portal_name, result_cache);
1087                }
1088            }
1089        }
1090
1091        Ok(())
1092    }
1093
1094    fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> {
1095        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1096        let session = self.session.clone().unwrap();
1097        //  b'S' => Statement
1098        //  b'P' => Portal
1099
1100        assert!(msg.kind == b'S' || msg.kind == b'P');
1101        if msg.kind == b'S' {
1102            let prepare_statement = self.get_statement(&name)?;
1103
1104            let (param_types, row_descriptions) = self
1105                .session
1106                .clone()
1107                .unwrap()
1108                .describe_statement(prepare_statement)
1109                .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1110            self.stream.write_no_flush(BeMessage::ParameterDescription(
1111                &param_types.iter().map(|t| t.to_oid()).collect_vec(),
1112            ))?;
1113
1114            if row_descriptions.is_empty() {
1115                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1116                // return NoData message if the statement is not a query.
1117                self.stream.write_no_flush(BeMessage::NoData)?;
1118            } else {
1119                self.stream
1120                    .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1121            }
1122        } else if msg.kind == b'P' {
1123            let portal = self.get_portal(&name)?;
1124
1125            let row_descriptions = session
1126                .describe_portal(portal)
1127                .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1128
1129            if row_descriptions.is_empty() {
1130                // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B,
1131                // return NoData message if the statement is not a query.
1132                self.stream.write_no_flush(BeMessage::NoData)?;
1133            } else {
1134                self.stream
1135                    .write_no_flush(BeMessage::RowDescription(&row_descriptions))?;
1136            }
1137        }
1138        Ok(())
1139    }
1140
1141    fn process_close_msg(&mut self, msg: FeCloseMessage) -> PsqlResult<()> {
1142        let name = cstr_to_str(&msg.name).unwrap().to_owned();
1143        assert!(msg.kind == b'S' || msg.kind == b'P');
1144        if msg.kind == b'S' {
1145            if name.is_empty() {
1146                self.unnamed_prepare_statement = None;
1147            } else {
1148                self.prepare_statement_store.remove(&name);
1149            }
1150            for portal_name in self
1151                .statement_portal_dependency
1152                .remove(&name)
1153                .unwrap_or_default()
1154            {
1155                self.remove_portal(&portal_name);
1156            }
1157        } else if msg.kind == b'P' {
1158            self.remove_portal(&name);
1159        }
1160        self.stream.write_no_flush(BeMessage::CloseComplete)?;
1161        Ok(())
1162    }
1163
1164    fn remove_portal(&mut self, portal_name: &str) {
1165        if portal_name.is_empty() {
1166            self.unnamed_portal = None;
1167        } else {
1168            self.portal_store.remove(portal_name);
1169        }
1170        self.result_cache.remove(portal_name);
1171    }
1172
1173    fn get_portal(&self, portal_name: &str) -> PsqlResult<<SM::Session as Session>::Portal> {
1174        if portal_name.is_empty() {
1175            Ok(self
1176                .unnamed_portal
1177                .as_ref()
1178                .ok_or_else(|| PsqlError::Uncategorized("unnamed portal not found".into()))?
1179                .clone())
1180        } else {
1181            Ok(self
1182                .portal_store
1183                .get(portal_name)
1184                .ok_or_else(|| {
1185                    PsqlError::Uncategorized(format!("Portal {} not found", portal_name).into())
1186                })?
1187                .clone())
1188        }
1189    }
1190
1191    fn get_statement(
1192        &self,
1193        statement_name: &str,
1194    ) -> PsqlResult<<SM::Session as Session>::PreparedStatement> {
1195        if statement_name.is_empty() {
1196            Ok(self
1197                .unnamed_prepare_statement
1198                .as_ref()
1199                .ok_or_else(|| {
1200                    PsqlError::Uncategorized("unnamed prepare statement not found".into())
1201                })?
1202                .clone())
1203        } else {
1204            Ok(self
1205                .prepare_statement_store
1206                .get(statement_name)
1207                .ok_or_else(|| {
1208                    PsqlError::Uncategorized(
1209                        format!("Prepare statement {} not found", statement_name).into(),
1210                    )
1211                })?
1212                .clone())
1213        }
1214    }
1215}
1216
1217enum PgStreamInner<S> {
1218    /// Used for the intermediate state when converting from unencrypted to ssl stream.
1219    Placeholder,
1220    /// An unencrypted stream.
1221    Unencrypted(S),
1222    /// An ssl stream.
1223    Ssl(SslStream<S>),
1224}
1225
1226/// Trait for a byte stream that can be used for pg protocol.
1227pub trait PgByteStream: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1228impl<S> PgByteStream for S where S: AsyncWrite + AsyncRead + Unpin + Send + 'static {}
1229
1230/// Wraps a byte stream and read/write pg messages.
1231///
1232/// Cloning a `PgStream` will share the same stream but a fresh & independent write buffer,
1233/// so that it can be used to write messages concurrently without interference.
1234pub struct PgStream<S> {
1235    /// The underlying stream.
1236    stream: Arc<Mutex<PgStreamInner<S>>>,
1237    /// Write into buffer before flush to stream.
1238    write_buf: BytesMut,
1239    read_header: Option<FeMessageHeader>,
1240}
1241
1242impl<S> PgStream<S> {
1243    /// Create a new `PgStream` with the given stream and default write buffer capacity.
1244    pub fn new(stream: S) -> Self {
1245        const DEFAULT_WRITE_BUF_CAPACITY: usize = 10 * 1024;
1246
1247        Self {
1248            stream: Arc::new(Mutex::new(PgStreamInner::Unencrypted(stream))),
1249            write_buf: BytesMut::with_capacity(DEFAULT_WRITE_BUF_CAPACITY),
1250            read_header: None,
1251        }
1252    }
1253
1254    /// Check if the current connection is using SSL
1255    async fn is_ssl_connection(&self) -> bool {
1256        let stream = self.stream.lock().await;
1257        matches!(*stream, PgStreamInner::Ssl(_))
1258    }
1259}
1260
1261impl<S> Clone for PgStream<S> {
1262    fn clone(&self) -> Self {
1263        Self {
1264            stream: Arc::clone(&self.stream),
1265            write_buf: BytesMut::with_capacity(self.write_buf.capacity()),
1266            read_header: self.read_header.clone(),
1267        }
1268    }
1269}
1270
/// Session parameters reported to the client through `ParameterStatus` messages.
///
/// The PostgreSQL protocol documentation (9.2) lists a hard-wired set of parameters for
/// which the server generates `ParameterStatus` automatically:
///
///  * `server_version`
///  * `server_encoding`
///  * `client_encoding`
///  * `application_name`
///  * `is_superuser`
///  * `session_authorization`
///  * `DateStyle`
///  * `IntervalStyle`
///  * `TimeZone`
///  * `integer_datetimes`
///  * `standard_conforming_strings`
///
/// This struct carries only the per-session `application_name`. It is reported only when
/// set (`None` means it is not sent).
///
/// See: <https://www.postgresql.org/docs/9.2/static/protocol-flow.html#PROTOCOL-ASYNC>.
#[derive(Debug, Default, Clone)]
pub struct ParameterStatus {
    pub application_name: Option<String>,
}
1291
1292impl<S> PgStream<S>
1293where
1294    S: PgByteStream,
1295{
1296    async fn read_startup(&mut self) -> io::Result<FeMessage> {
1297        let mut stream = self.stream.lock().await;
1298        match &mut *stream {
1299            PgStreamInner::Placeholder => unreachable!(),
1300            PgStreamInner::Unencrypted(stream) => FeStartupMessage::read(stream).await,
1301            PgStreamInner::Ssl(ssl_stream) => FeStartupMessage::read(ssl_stream).await,
1302        }
1303    }
1304
1305    async fn read_header(&mut self) -> io::Result<()> {
1306        let mut stream = self.stream.lock().await;
1307        match &mut *stream {
1308            PgStreamInner::Placeholder => unreachable!(),
1309            PgStreamInner::Unencrypted(stream) => {
1310                self.read_header = Some(FeMessage::read_header(stream).await?);
1311                Ok(())
1312            }
1313            PgStreamInner::Ssl(ssl_stream) => {
1314                self.read_header = Some(FeMessage::read_header(ssl_stream).await?);
1315                Ok(())
1316            }
1317        }
1318    }
1319
1320    async fn read_body(&mut self) -> io::Result<FeMessage> {
1321        let mut stream = self.stream.lock().await;
1322        let header = self
1323            .read_header
1324            .take()
1325            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1326        match &mut *stream {
1327            PgStreamInner::Placeholder => unreachable!(),
1328            PgStreamInner::Unencrypted(stream) => FeMessage::read_body(stream, header).await,
1329            PgStreamInner::Ssl(ssl_stream) => FeMessage::read_body(ssl_stream, header).await,
1330        }
1331    }
1332
1333    async fn skip_body(&mut self) -> io::Result<()> {
1334        let mut stream = self.stream.lock().await;
1335        let header = self
1336            .read_header
1337            .take()
1338            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidInput, "header not found"))?;
1339        match &mut *stream {
1340            PgStreamInner::Placeholder => unreachable!(),
1341            PgStreamInner::Unencrypted(stream) => FeMessage::skip_body(stream, header).await,
1342            PgStreamInner::Ssl(ssl_stream) => FeMessage::skip_body(ssl_stream, header).await,
1343        }
1344    }
1345
1346    fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> {
1347        self.write_no_flush(BeMessage::ParameterStatus(
1348            BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING),
1349        ))?;
1350        self.write_no_flush(BeMessage::ParameterStatus(
1351            BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS),
1352        ))?;
1353        self.write_no_flush(BeMessage::ParameterStatus(
1354            BeParameterStatusMessage::ServerVersion(PG_VERSION),
1355        ))?;
1356        if let Some(application_name) = &status.application_name {
1357            self.write_no_flush(BeMessage::ParameterStatus(
1358                BeParameterStatusMessage::ApplicationName(application_name),
1359            ))?;
1360        }
1361        Ok(())
1362    }
1363
1364    pub fn write_no_flush(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1365        BeMessage::write(&mut self.write_buf, message)
1366    }
1367
1368    async fn write(&mut self, message: BeMessage<'_>) -> io::Result<()> {
1369        self.write_no_flush(message)?;
1370        self.flush().await?;
1371        Ok(())
1372    }
1373
1374    async fn flush(&mut self) -> io::Result<()> {
1375        let mut stream = self.stream.lock().await;
1376        match &mut *stream {
1377            PgStreamInner::Placeholder => unreachable!(),
1378            PgStreamInner::Unencrypted(stream) => {
1379                stream.write_all(&self.write_buf).await?;
1380                stream.flush().await?;
1381            }
1382            PgStreamInner::Ssl(ssl_stream) => {
1383                ssl_stream.write_all(&self.write_buf).await?;
1384                ssl_stream.flush().await?;
1385            }
1386        }
1387        self.write_buf.clear();
1388        Ok(())
1389    }
1390}
1391
impl<S> PgStream<S>
where
    S: PgByteStream,
{
    /// Convert the underlying stream to ssl stream based on the given context.
    ///
    /// Runs the server-side TLS handshake (`accept`) over the current unencrypted stream.
    /// On success, the resulting [`SslStream`] replaces it.
    ///
    /// # Errors
    /// Returns the handshake error. Before returning, a TLS shutdown is attempted and its
    /// result is ignored.
    ///
    /// # Panics
    /// - If the stream is already TLS.
    /// - If creating the `Ssl` object or the `SslStream` wrapper fails (both are `unwrap`ped).
    ///
    /// NOTE(review): on handshake failure the inner stream is left as
    /// `PgStreamInner::Placeholder`, because the raw stream was moved into the TLS wrapper.
    /// Any later read/write on this `PgStream` would therefore hit `unreachable!()`.
    /// Presumably the caller drops the connection on this error — confirm.
    async fn upgrade_to_ssl(&mut self, ssl_ctx: &SslContextRef) -> PsqlResult<()> {
        let mut stream = self.stream.lock().await;

        // Move the raw stream out; `Placeholder` stands in until the TLS stream is stored back.
        match std::mem::replace(&mut *stream, PgStreamInner::Placeholder) {
            PgStreamInner::Unencrypted(unencrypted_stream) => {
                let ssl = openssl::ssl::Ssl::new(ssl_ctx).unwrap();
                let mut ssl_stream =
                    tokio_openssl::SslStream::new(ssl, unencrypted_stream).unwrap();

                if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
                    tracing::warn!(error = %e.as_report(), "Unable to set up an ssl connection");
                    // Best effort: the shutdown result is deliberately ignored.
                    let _ = ssl_stream.shutdown().await;
                    return Err(e.into());
                }

                *stream = PgStreamInner::Ssl(ssl_stream);
            }
            PgStreamInner::Ssl(_) => panic!("the stream is already ssl"),
            PgStreamInner::Placeholder => unreachable!(),
        }

        Ok(())
    }
}
1421
1422fn build_ssl_ctx_from_config(tls_config: &TlsConfig) -> PsqlResult<SslContext> {
1423    let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap();
1424
1425    let key_path = &tls_config.key;
1426    let cert_path = &tls_config.cert;
1427
1428    // Build ssl acceptor according to the config.
1429    // Now we set every verify to true.
1430    acceptor
1431        .set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)
1432        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1433    acceptor
1434        .set_ca_file(cert_path)
1435        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1436    acceptor
1437        .set_certificate_chain_file(cert_path)
1438        .map_err(|e| PsqlError::Uncategorized(e.into()))?;
1439    let acceptor = acceptor.build();
1440
1441    Ok(acceptor.into_context())
1442}
1443
pub mod truncated_fmt {
    //! Formatting adapters that cap the number of bytes a value renders to.
    use std::fmt::*;

    /// A `fmt::Write` sink that forwards at most `remaining` bytes to the wrapped formatter.
    ///
    /// The cut is made on a UTF-8 character boundary. The sink keeps counting the full
    /// rendered length, so the truncation marker can report it.
    struct TruncatedFormatter<'a, 'b> {
        /// Bytes that may still be forwarded before output is cut.
        remaining: usize,
        /// Set once output has been cut; later writes are only counted, not forwarded.
        truncated: bool,
        /// Total length in bytes of everything written to this sink, forwarded or not.
        total_len: usize,
        f: &'a mut Formatter<'b>,
    }

    impl<'a, 'b> TruncatedFormatter<'a, 'b> {
        fn new(limit: usize, f: &'a mut Formatter<'b>) -> Self {
            Self {
                remaining: limit,
                truncated: false,
                total_len: 0,
                f,
            }
        }

        /// Appends the `...(truncated,N)` marker if output was cut.
        ///
        /// `N` is the full untruncated length in bytes. The marker is written exactly once,
        /// after the value has been fully rendered, so `N` covers every chunk.
        fn finish(&mut self) -> Result {
            if self.truncated {
                write!(self.f, "...(truncated,{})", self.total_len)
            } else {
                Ok(())
            }
        }
    }

    impl Write for TruncatedFormatter<'_, '_> {
        fn write_str(&mut self, s: &str) -> Result {
            self.total_len += s.len();
            if self.truncated {
                return Ok(());
            }

            if s.len() <= self.remaining {
                self.f.write_str(s)?;
                self.remaining -= s.len();
            } else {
                // Cut on a char boundary so a multi-byte character is never split.
                let cut = floor_char_boundary(s, self.remaining);
                self.f.write_str(&s[..cut])?;
                self.remaining -= cut;
                self.truncated = true;
            }
            Ok(())
        }
    }

    /// Largest char-boundary index of `s` that is `<= index`.
    ///
    /// Stable replacement for `str::floor_char_boundary`. Index 0 is always a boundary, so
    /// the search always succeeds.
    fn floor_char_boundary(s: &str, index: usize) -> usize {
        (0..=index.min(s.len()))
            .rev()
            .find(|&i| s.is_char_boundary(i))
            .unwrap_or(0)
    }

    /// Renders `self.0` (via `Display` or `Debug`) but emits at most `self.1` bytes of it.
    ///
    /// When the output is cut, it is followed by `...(truncated,N)`, where `N` is the full
    /// rendered length in bytes.
    pub struct TruncatedFmt<'a, T>(pub &'a T, pub usize);

    impl<T> Debug for TruncatedFmt<'_, T>
    where
        T: Debug,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            let mut out = TruncatedFormatter::new(self.1, f);
            out.write_fmt(format_args!("{:?}", self.0))?;
            out.finish()
        }
    }

    impl<T> Display for TruncatedFmt<'_, T>
    where
        T: Display,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> Result {
            let mut out = TruncatedFormatter::new(self.1, f);
            out.write_fmt(format_args!("{}", self.0))?;
            out.finish()
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_trunc_utf8() {
            assert_eq!(
                format!("{}", TruncatedFmt(&"select '🌊';", 10)),
                "select '...(truncated,14)",
            );
        }

        #[test]
        fn test_trunc_reports_total_len_across_chunks() {
            // `Debug` for `str` writes in several chunks; the marker reports the full length.
            assert_eq!(
                format!("{:?}", TruncatedFmt(&"a\nbcdef", 3)),
                "\"a\\...(truncated,10)",
            );
        }
    }
}
1515
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// `redact_sql` must replace the values of the configured option keys with `[REDACTED]`.
    ///
    /// This applies both in the `WITH (...)` clause and in the `ENCODE ... (...)` options.
    /// All other options, and the normalized statement text, are left intact.
    #[test]
    fn test_redact_parsable_sql() {
        let sql = r"
        create source temp (k bigint, v varchar) with (
            connector = 'datagen',
            v1 = 123,
            v2 = 'with',
            v3 = false,
            v4 = '',
        ) FORMAT plain ENCODE json (a='1',b='2')
        ";
        let redacted_keys = Arc::new(HashSet::from(["v2".into(), "v4".into(), "b".into()]));
        let expected = "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])";

        let actual = redact_sql(sql, redacted_keys);
        assert_eq!(actual, expected);
    }
}